mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-28 03:43:23 +00:00
Refactor project structure
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
bd878dd1c9
commit
e3c227ae06
55
kernel/aster-nix/src/net/socket/unix/addr.rs
Normal file
55
kernel/aster-nix/src/net/socket/unix/addr.rs
Normal file
@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use crate::{fs::utils::Dentry, net::socket::util::socket_addr::SocketAddr, prelude::*};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum UnixSocketAddr {
|
||||
Path(String),
|
||||
Abstract(String),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) enum UnixSocketAddrBound {
|
||||
Path(Arc<Dentry>),
|
||||
Abstract(String),
|
||||
}
|
||||
|
||||
impl PartialEq for UnixSocketAddrBound {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Self::Abstract(l0), Self::Abstract(r0)) => l0 == r0,
|
||||
(Self::Path(l0), Self::Path(r0)) => Arc::ptr_eq(l0.inode(), r0.inode()),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<SocketAddr> for UnixSocketAddr {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: SocketAddr) -> Result<Self> {
|
||||
match value {
|
||||
SocketAddr::Unix(unix_socket_addr) => Ok(unix_socket_addr),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "Invalid unix socket addr"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UnixSocketAddrBound> for UnixSocketAddr {
|
||||
fn from(value: UnixSocketAddrBound) -> Self {
|
||||
match value {
|
||||
UnixSocketAddrBound::Path(dentry) => {
|
||||
let abs_path = dentry.abs_path();
|
||||
Self::Path(abs_path)
|
||||
}
|
||||
UnixSocketAddrBound::Abstract(name) => Self::Abstract(name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UnixSocketAddrBound> for SocketAddr {
|
||||
fn from(value: UnixSocketAddrBound) -> Self {
|
||||
let unix_socket_addr = UnixSocketAddr::from(value);
|
||||
SocketAddr::Unix(unix_socket_addr)
|
||||
}
|
||||
}
|
7
kernel/aster-nix/src/net/socket/unix/mod.rs
Normal file
7
kernel/aster-nix/src/net/socket/unix/mod.rs
Normal file
@ -0,0 +1,7 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
mod addr;
|
||||
mod stream;
|
||||
|
||||
pub use addr::UnixSocketAddr;
|
||||
pub use stream::UnixStreamSocket;
|
55
kernel/aster-nix/src/net/socket/unix/stream/connected.rs
Normal file
55
kernel/aster-nix/src/net/socket/unix/stream/connected.rs
Normal file
@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use super::endpoint::Endpoint;
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd},
|
||||
prelude::*,
|
||||
process::signal::Poller,
|
||||
};
|
||||
|
||||
pub(super) struct Connected {
|
||||
local_endpoint: Arc<Endpoint>,
|
||||
}
|
||||
|
||||
impl Connected {
|
||||
pub(super) fn new(local_endpoint: Arc<Endpoint>) -> Self {
|
||||
Connected { local_endpoint }
|
||||
}
|
||||
|
||||
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
self.local_endpoint.addr()
|
||||
}
|
||||
|
||||
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
self.local_endpoint.peer_addr()
|
||||
}
|
||||
|
||||
pub(super) fn is_bound(&self) -> bool {
|
||||
self.addr().is_some()
|
||||
}
|
||||
|
||||
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
|
||||
self.local_endpoint.write(buf)
|
||||
}
|
||||
|
||||
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
self.local_endpoint.read(buf)
|
||||
}
|
||||
|
||||
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
self.local_endpoint.shutdown(cmd)
|
||||
}
|
||||
|
||||
pub(super) fn is_nonblocking(&self) -> bool {
|
||||
self.local_endpoint.is_nonblocking()
|
||||
}
|
||||
|
||||
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
|
||||
self.local_endpoint.set_nonblocking(is_nonblocking).unwrap();
|
||||
}
|
||||
|
||||
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.local_endpoint.poll(mask, poller)
|
||||
}
|
||||
}
|
127
kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs
Normal file
127
kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs
Normal file
@ -0,0 +1,127 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
fs::utils::{Channel, Consumer, Producer, StatusFlags},
|
||||
net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd},
|
||||
prelude::*,
|
||||
process::signal::Poller,
|
||||
};
|
||||
|
||||
pub(super) struct Endpoint(Inner);
|
||||
|
||||
struct Inner {
|
||||
addr: RwLock<Option<UnixSocketAddrBound>>,
|
||||
reader: Consumer<u8>,
|
||||
writer: Producer<u8>,
|
||||
peer: Weak<Endpoint>,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub(super) fn new_pair(is_nonblocking: bool) -> Result<(Arc<Endpoint>, Arc<Endpoint>)> {
|
||||
let flags = if is_nonblocking {
|
||||
StatusFlags::O_NONBLOCK
|
||||
} else {
|
||||
StatusFlags::empty()
|
||||
};
|
||||
let (writer_a, reader_b) =
|
||||
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
|
||||
let (writer_b, reader_a) =
|
||||
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
|
||||
let mut endpoint_b = None;
|
||||
let endpoint_a = Arc::new_cyclic(|endpoint_a_ref| {
|
||||
let peer = Arc::new(Endpoint::new(reader_b, writer_b, endpoint_a_ref.clone()));
|
||||
let endpoint_a = Endpoint::new(reader_a, writer_a, Arc::downgrade(&peer));
|
||||
endpoint_b = Some(peer);
|
||||
endpoint_a
|
||||
});
|
||||
Ok((endpoint_a, endpoint_b.unwrap()))
|
||||
}
|
||||
|
||||
fn new(reader: Consumer<u8>, writer: Producer<u8>, peer: Weak<Endpoint>) -> Self {
|
||||
Self(Inner {
|
||||
addr: RwLock::new(None),
|
||||
reader,
|
||||
writer,
|
||||
peer,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
self.0.addr.read().clone()
|
||||
}
|
||||
|
||||
pub(super) fn set_addr(&self, addr: UnixSocketAddrBound) {
|
||||
*self.0.addr.write() = Some(addr);
|
||||
}
|
||||
|
||||
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
self.0.peer.upgrade().and_then(|peer| peer.addr())
|
||||
}
|
||||
|
||||
pub(super) fn is_nonblocking(&self) -> bool {
|
||||
let reader_status = self.0.reader.is_nonblocking();
|
||||
let writer_status = self.0.writer.is_nonblocking();
|
||||
debug_assert!(reader_status == writer_status);
|
||||
reader_status
|
||||
}
|
||||
|
||||
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> {
|
||||
let reader_flags = self.0.reader.status_flags();
|
||||
self.0
|
||||
.reader
|
||||
.set_status_flags(reader_flags | StatusFlags::O_NONBLOCK)?;
|
||||
let writer_flags = self.0.writer.status_flags();
|
||||
self.0
|
||||
.writer
|
||||
.set_status_flags(writer_flags | StatusFlags::O_NONBLOCK)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
self.0.reader.read(buf)
|
||||
}
|
||||
|
||||
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
|
||||
self.0.writer.write(buf)
|
||||
}
|
||||
|
||||
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
if !self.is_connected() {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "The socket is not connected.");
|
||||
}
|
||||
|
||||
if cmd.shut_read() {
|
||||
self.0.reader.shutdown();
|
||||
}
|
||||
|
||||
if cmd.shut_write() {
|
||||
self.0.writer.shutdown();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn is_connected(&self) -> bool {
|
||||
self.0.peer.upgrade().is_some()
|
||||
}
|
||||
|
||||
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
let mut events = IoEvents::empty();
|
||||
// FIXME: should reader and writer use the same mask?
|
||||
let reader_events = self.0.reader.poll(mask, poller);
|
||||
let writer_events = self.0.writer.poll(mask, poller);
|
||||
|
||||
if reader_events.contains(IoEvents::HUP) || self.0.reader.is_shutdown() {
|
||||
events |= IoEvents::RDHUP | IoEvents::IN;
|
||||
if writer_events.contains(IoEvents::ERR) || self.0.writer.is_shutdown() {
|
||||
events |= IoEvents::HUP | IoEvents::OUT;
|
||||
}
|
||||
}
|
||||
|
||||
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
|
||||
events
|
||||
}
|
||||
}
|
||||
|
||||
const DAFAULT_BUF_SIZE: usize = 4096;
|
104
kernel/aster-nix/src/net/socket/unix/stream/init.rs
Normal file
104
kernel/aster-nix/src/net/socket/unix/stream/init.rs
Normal file
@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use super::{connected::Connected, endpoint::Endpoint, listener::push_incoming};
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
fs::{
|
||||
fs_resolver::{split_path, FsPath},
|
||||
utils::{Dentry, InodeMode, InodeType},
|
||||
},
|
||||
net::socket::unix::addr::{UnixSocketAddr, UnixSocketAddrBound},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
pub(super) struct Init {
|
||||
is_nonblocking: AtomicBool,
|
||||
addr: Mutex<Option<UnixSocketAddrBound>>,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
impl Init {
|
||||
pub(super) fn new(is_nonblocking: bool) -> Self {
|
||||
Self {
|
||||
is_nonblocking: AtomicBool::new(is_nonblocking),
|
||||
addr: Mutex::new(None),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn bind(&self, addr_to_bind: &UnixSocketAddr) -> Result<()> {
|
||||
let mut addr = self.addr.lock();
|
||||
if addr.is_some() {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
|
||||
}
|
||||
|
||||
let bound_addr = match addr_to_bind {
|
||||
UnixSocketAddr::Abstract(_) => todo!(),
|
||||
UnixSocketAddr::Path(path) => {
|
||||
let dentry = create_socket_file(path)?;
|
||||
UnixSocketAddrBound::Path(dentry)
|
||||
}
|
||||
};
|
||||
|
||||
*addr = Some(bound_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn connect(&self, remote_addr: &UnixSocketAddrBound) -> Result<Connected> {
|
||||
let addr = self.addr();
|
||||
|
||||
if let Some(ref addr) = addr {
|
||||
if *addr == *remote_addr {
|
||||
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
|
||||
}
|
||||
}
|
||||
|
||||
let (this_end, remote_end) = Endpoint::new_pair(self.is_nonblocking())?;
|
||||
remote_end.set_addr(remote_addr.clone());
|
||||
if let Some(addr) = addr {
|
||||
this_end.set_addr(addr.clone());
|
||||
};
|
||||
|
||||
push_incoming(remote_addr, remote_end)?;
|
||||
Ok(Connected::new(this_end))
|
||||
}
|
||||
|
||||
pub(super) fn is_bound(&self) -> bool {
|
||||
self.addr.lock().is_some()
|
||||
}
|
||||
|
||||
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
self.addr.lock().clone()
|
||||
}
|
||||
|
||||
pub(super) fn is_nonblocking(&self) -> bool {
|
||||
self.is_nonblocking.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
|
||||
self.is_nonblocking.store(is_nonblocking, Ordering::Release);
|
||||
}
|
||||
|
||||
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_socket_file(path: &str) -> Result<Arc<Dentry>> {
|
||||
let (parent_pathname, file_name) = split_path(path);
|
||||
let parent = {
|
||||
let current = current!();
|
||||
let fs = current.fs().read();
|
||||
let parent_path = FsPath::try_from(parent_pathname)?;
|
||||
fs.lookup(&parent_path)?
|
||||
};
|
||||
let dentry = parent.create(
|
||||
file_name,
|
||||
InodeType::Socket,
|
||||
InodeMode::S_IRUSR | InodeMode::S_IWUSR,
|
||||
)?;
|
||||
Ok(dentry)
|
||||
}
|
229
kernel/aster-nix/src/net/socket/unix/stream/listener.rs
Normal file
229
kernel/aster-nix/src/net/socket/unix/stream/listener.rs
Normal file
@ -0,0 +1,229 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use keyable_arc::KeyableWeak;
|
||||
|
||||
use super::{connected::Connected, endpoint::Endpoint, UnixStreamSocket};
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
fs::{
|
||||
file_handle::FileLike,
|
||||
utils::{Dentry, Inode},
|
||||
},
|
||||
net::socket::{
|
||||
unix::addr::{UnixSocketAddr, UnixSocketAddrBound},
|
||||
SocketAddr,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
pub(super) struct Listener {
|
||||
addr: UnixSocketAddrBound,
|
||||
is_nonblocking: AtomicBool,
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
pub(super) fn new(
|
||||
addr: UnixSocketAddrBound,
|
||||
backlog: usize,
|
||||
nonblocking: bool,
|
||||
) -> Result<Self> {
|
||||
BACKLOG_TABLE.add_backlog(&addr, backlog)?;
|
||||
Ok(Self {
|
||||
addr,
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn addr(&self) -> &UnixSocketAddrBound {
|
||||
&self.addr
|
||||
}
|
||||
|
||||
pub(super) fn is_nonblocking(&self) -> bool {
|
||||
self.is_nonblocking.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
|
||||
self.is_nonblocking.store(is_nonblocking, Ordering::Release);
|
||||
}
|
||||
|
||||
pub(super) fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let addr = self.addr().clone();
|
||||
let is_nonblocking = self.is_nonblocking();
|
||||
|
||||
let connected = {
|
||||
let local_endpoint = BACKLOG_TABLE.pop_incoming(is_nonblocking, &addr)?;
|
||||
Connected::new(local_endpoint)
|
||||
};
|
||||
|
||||
let peer_addr = match connected.peer_addr() {
|
||||
None => SocketAddr::Unix(UnixSocketAddr::Path(String::new())),
|
||||
Some(addr) => SocketAddr::from(addr.clone()),
|
||||
};
|
||||
|
||||
let socket = Arc::new(UnixStreamSocket::new_connected(connected));
|
||||
|
||||
Ok((socket, peer_addr))
|
||||
}
|
||||
|
||||
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
let addr = self.addr();
|
||||
let backlog = BACKLOG_TABLE.get_backlog(addr).unwrap();
|
||||
backlog.poll(mask, poller)
|
||||
}
|
||||
}
|
||||
|
||||
static BACKLOG_TABLE: BacklogTable = BacklogTable::new();
|
||||
|
||||
struct BacklogTable {
|
||||
backlog_sockets: RwLock<BTreeMap<KeyableWeak<dyn Inode>, Arc<Backlog>>>,
|
||||
// TODO: For linux, there is also abstract socket domain that a socket addr is not bound to an inode.
|
||||
}
|
||||
|
||||
impl BacklogTable {
|
||||
const fn new() -> Self {
|
||||
Self {
|
||||
backlog_sockets: RwLock::new(BTreeMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_backlog(&self, addr: &UnixSocketAddrBound, backlog: usize) -> Result<()> {
|
||||
let inode = {
|
||||
let UnixSocketAddrBound::Path(dentry) = addr else {
|
||||
todo!()
|
||||
};
|
||||
create_keyable_inode(dentry)
|
||||
};
|
||||
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
if backlog_sockets.contains_key(&inode) {
|
||||
return_errno_with_message!(Errno::EADDRINUSE, "the addr is already used");
|
||||
}
|
||||
let new_backlog = Arc::new(Backlog::new(backlog));
|
||||
backlog_sockets.insert(inode, new_backlog);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_backlog(&self, addr: &UnixSocketAddrBound) -> Result<Arc<Backlog>> {
|
||||
let inode = {
|
||||
let UnixSocketAddrBound::Path(dentry) = addr else {
|
||||
todo!()
|
||||
};
|
||||
create_keyable_inode(dentry)
|
||||
};
|
||||
|
||||
let backlog_sockets = self.backlog_sockets.read();
|
||||
backlog_sockets
|
||||
.get(&inode)
|
||||
.map(Arc::clone)
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened"))
|
||||
}
|
||||
|
||||
fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Arc<Endpoint>> {
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
let backlog = self.get_backlog(addr)?;
|
||||
|
||||
if let Some(endpoint) = backlog.pop_incoming() {
|
||||
return Ok(endpoint);
|
||||
}
|
||||
|
||||
if nonblocking {
|
||||
return_errno_with_message!(Errno::EAGAIN, "no connection comes");
|
||||
}
|
||||
|
||||
let events = {
|
||||
let mask = IoEvents::IN;
|
||||
backlog.poll(mask, Some(&poller))
|
||||
};
|
||||
|
||||
if events.contains(IoEvents::ERR) | events.contains(IoEvents::HUP) {
|
||||
return_errno_with_message!(Errno::ECONNABORTED, "connection is aborted");
|
||||
}
|
||||
|
||||
// FIXME: deal with accept timeout
|
||||
if events.is_empty() {
|
||||
poller.wait()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Arc<Endpoint>) -> Result<()> {
|
||||
let backlog = self.get_backlog(addr).map_err(|_| {
|
||||
Error::with_message(
|
||||
Errno::ECONNREFUSED,
|
||||
"no socket is listened at the remote address",
|
||||
)
|
||||
})?;
|
||||
|
||||
backlog.push_incoming(endpoint)
|
||||
}
|
||||
|
||||
fn remove_backlog(&self, addr: &UnixSocketAddrBound) {
|
||||
let UnixSocketAddrBound::Path(dentry) = addr else {
|
||||
todo!()
|
||||
};
|
||||
|
||||
let inode = create_keyable_inode(dentry);
|
||||
self.backlog_sockets.write().remove(&inode);
|
||||
}
|
||||
}
|
||||
|
||||
struct Backlog {
|
||||
pollee: Pollee,
|
||||
backlog: usize,
|
||||
incoming_endpoints: Mutex<VecDeque<Arc<Endpoint>>>,
|
||||
}
|
||||
|
||||
impl Backlog {
|
||||
fn new(backlog: usize) -> Self {
|
||||
Self {
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
backlog,
|
||||
incoming_endpoints: Mutex::new(VecDeque::with_capacity(backlog)),
|
||||
}
|
||||
}
|
||||
|
||||
fn push_incoming(&self, endpoint: Arc<Endpoint>) -> Result<()> {
|
||||
let mut endpoints = self.incoming_endpoints.lock();
|
||||
if endpoints.len() >= self.backlog {
|
||||
return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full");
|
||||
}
|
||||
endpoints.push_back(endpoint);
|
||||
self.pollee.add_events(IoEvents::IN);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pop_incoming(&self) -> Option<Arc<Endpoint>> {
|
||||
let mut incoming_endpoints = self.incoming_endpoints.lock();
|
||||
let endpoint = incoming_endpoints.pop_front();
|
||||
if incoming_endpoints.is_empty() {
|
||||
self.pollee.del_events(IoEvents::IN);
|
||||
}
|
||||
endpoint
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
// Lock to avoid any events may change pollee state when we poll
|
||||
let _lock = self.incoming_endpoints.lock();
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_keyable_inode(dentry: &Arc<Dentry>) -> KeyableWeak<dyn Inode> {
|
||||
let weak_inode = Arc::downgrade(dentry.inode());
|
||||
KeyableWeak::from(weak_inode)
|
||||
}
|
||||
|
||||
pub(super) fn unregister_backlog(addr: &UnixSocketAddrBound) {
|
||||
BACKLOG_TABLE.remove_backlog(addr);
|
||||
}
|
||||
|
||||
pub(super) fn push_incoming(
|
||||
remote_addr: &UnixSocketAddrBound,
|
||||
remote_end: Arc<Endpoint>,
|
||||
) -> Result<()> {
|
||||
BACKLOG_TABLE.push_incoming(remote_addr, remote_end)
|
||||
}
|
9
kernel/aster-nix/src/net/socket/unix/stream/mod.rs
Normal file
9
kernel/aster-nix/src/net/socket/unix/stream/mod.rs
Normal file
@ -0,0 +1,9 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
mod connected;
|
||||
mod endpoint;
|
||||
mod init;
|
||||
mod listener;
|
||||
mod socket;
|
||||
|
||||
pub use socket::UnixStreamSocket;
|
306
kernel/aster-nix/src/net/socket/unix/stream/socket.rs
Normal file
306
kernel/aster-nix/src/net/socket/unix/stream/socket.rs
Normal file
@ -0,0 +1,306 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use super::{
|
||||
connected::Connected,
|
||||
endpoint::Endpoint,
|
||||
init::Init,
|
||||
listener::{unregister_backlog, Listener},
|
||||
};
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
fs::{
|
||||
file_handle::FileLike,
|
||||
fs_resolver::FsPath,
|
||||
utils::{Dentry, InodeType, StatusFlags},
|
||||
},
|
||||
net::socket::{
|
||||
unix::{addr::UnixSocketAddrBound, UnixSocketAddr},
|
||||
util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr},
|
||||
SockShutdownCmd, Socket,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::Poller,
|
||||
};
|
||||
|
||||
pub struct UnixStreamSocket(RwLock<State>);
|
||||
|
||||
impl UnixStreamSocket {
|
||||
pub(super) fn new_init(init: Init) -> Self {
|
||||
Self(RwLock::new(State::Init(Arc::new(init))))
|
||||
}
|
||||
|
||||
pub(super) fn new_listen(listen: Listener) -> Self {
|
||||
Self(RwLock::new(State::Listen(Arc::new(listen))))
|
||||
}
|
||||
|
||||
pub(super) fn new_connected(connected: Connected) -> Self {
|
||||
Self(RwLock::new(State::Connected(Arc::new(connected))))
|
||||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
Init(Arc<Init>),
|
||||
Listen(Arc<Listener>),
|
||||
Connected(Arc<Connected>),
|
||||
}
|
||||
|
||||
impl UnixStreamSocket {
|
||||
pub fn new(nonblocking: bool) -> Self {
|
||||
let init = Init::new(nonblocking);
|
||||
Self::new_init(init)
|
||||
}
|
||||
|
||||
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
|
||||
let (end_a, end_b) = Endpoint::new_pair(nonblocking)?;
|
||||
let connected_a = {
|
||||
let connected = Connected::new(end_a);
|
||||
Self::new_connected(connected)
|
||||
};
|
||||
let connected_b = {
|
||||
let connected = Connected::new(end_b);
|
||||
Self::new_connected(connected)
|
||||
};
|
||||
Ok((Arc::new(connected_a), Arc::new(connected_b)))
|
||||
}
|
||||
|
||||
fn bound_addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
let status = self.0.read();
|
||||
match &*status {
|
||||
State::Init(init) => init.addr(),
|
||||
State::Listen(listen) => Some(listen.addr().clone()),
|
||||
State::Connected(connected) => connected.addr(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mask_flags(status_flags: &StatusFlags) -> StatusFlags {
|
||||
const SUPPORTED_FLAGS: StatusFlags = StatusFlags::O_NONBLOCK;
|
||||
const UNSUPPORTED_FLAGS: StatusFlags = SUPPORTED_FLAGS.complement();
|
||||
|
||||
if status_flags.intersects(UNSUPPORTED_FLAGS) {
|
||||
warn!("ignore unsupported flags");
|
||||
}
|
||||
|
||||
status_flags.intersection(SUPPORTED_FLAGS)
|
||||
}
|
||||
}
|
||||
|
||||
impl FileLike for UnixStreamSocket {
|
||||
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
|
||||
Some(self)
|
||||
}
|
||||
|
||||
fn read(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
self.recvfrom(buf, SendRecvFlags::empty())
|
||||
.map(|(read_size, _)| read_size)
|
||||
}
|
||||
|
||||
fn write(&self, buf: &[u8]) -> Result<usize> {
|
||||
self.sendto(buf, None, SendRecvFlags::empty())
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
let inner = self.0.read();
|
||||
match &*inner {
|
||||
State::Init(init) => init.poll(mask, poller),
|
||||
State::Listen(listen) => listen.poll(mask, poller),
|
||||
State::Connected(connected) => connected.poll(mask, poller),
|
||||
}
|
||||
}
|
||||
|
||||
fn status_flags(&self) -> StatusFlags {
|
||||
let inner = self.0.read();
|
||||
let is_nonblocking = match &*inner {
|
||||
State::Init(init) => init.is_nonblocking(),
|
||||
State::Listen(listen) => listen.is_nonblocking(),
|
||||
State::Connected(connected) => connected.is_nonblocking(),
|
||||
};
|
||||
|
||||
if is_nonblocking {
|
||||
StatusFlags::O_NONBLOCK
|
||||
} else {
|
||||
StatusFlags::empty()
|
||||
}
|
||||
}
|
||||
|
||||
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
|
||||
let is_nonblocking = {
|
||||
let supported_flags = Self::mask_flags(&new_flags);
|
||||
supported_flags.contains(StatusFlags::O_NONBLOCK)
|
||||
};
|
||||
|
||||
let mut inner = self.0.write();
|
||||
match &mut *inner {
|
||||
State::Init(init) => init.set_nonblocking(is_nonblocking),
|
||||
State::Listen(listen) => listen.set_nonblocking(is_nonblocking),
|
||||
State::Connected(connected) => connected.set_nonblocking(is_nonblocking),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Socket for UnixStreamSocket {
|
||||
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
let addr = UnixSocketAddr::try_from(socket_addr)?;
|
||||
|
||||
let init = match &*self.0.read() {
|
||||
State::Init(init) => init.clone(),
|
||||
_ => return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"cannot bind a listening or connected socket"
|
||||
),
|
||||
// FIXME: Maybe binding a connected socket should also be allowed?
|
||||
};
|
||||
|
||||
init.bind(&addr)
|
||||
}
|
||||
|
||||
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
let remote_addr = {
|
||||
let unix_socket_addr = UnixSocketAddr::try_from(socket_addr)?;
|
||||
match unix_socket_addr {
|
||||
UnixSocketAddr::Abstract(abstract_name) => {
|
||||
UnixSocketAddrBound::Abstract(abstract_name)
|
||||
}
|
||||
UnixSocketAddr::Path(path) => {
|
||||
let dentry = lookup_socket_file(&path)?;
|
||||
UnixSocketAddrBound::Path(dentry)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let init = match &*self.0.read() {
|
||||
State::Init(init) => init.clone(),
|
||||
State::Listen(_) => return_errno_with_message!(Errno::EINVAL, "the socket is listened"),
|
||||
State::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EISCONN, "the socket is connected")
|
||||
}
|
||||
};
|
||||
|
||||
let connected = init.connect(&remote_addr)?;
|
||||
|
||||
*self.0.write() = State::Connected(Arc::new(connected));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
let init = match &*self.0.read() {
|
||||
State::Init(init) => init.clone(),
|
||||
State::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already listening")
|
||||
}
|
||||
State::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EISCONN, "the socket is already connected")
|
||||
}
|
||||
};
|
||||
|
||||
let addr = init.addr().ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"the socket is not bound",
|
||||
))?;
|
||||
|
||||
let listener = Listener::new(addr.clone(), backlog, init.is_nonblocking())?;
|
||||
*self.0.write() = State::Listen(Arc::new(listener));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let listen = match &*self.0.read() {
|
||||
State::Listen(listen) => listen.clone(),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
|
||||
};
|
||||
|
||||
listen.accept()
|
||||
}
|
||||
|
||||
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"),
|
||||
};
|
||||
|
||||
connected.shutdown(cmd)
|
||||
}
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
let addr = match &*self.0.read() {
|
||||
State::Init(init) => init.addr(),
|
||||
State::Listen(listen) => Some(listen.addr().clone()),
|
||||
State::Connected(connected) => connected.addr(),
|
||||
};
|
||||
|
||||
addr.map(Into::<SocketAddr>::into)
|
||||
.ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"the socket does not bind to addr",
|
||||
))
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
|
||||
};
|
||||
|
||||
match connected.peer_addr() {
|
||||
None => Ok(SocketAddr::Unix(UnixSocketAddr::Path(String::new()))),
|
||||
Some(peer_addr) => Ok(SocketAddr::from(peer_addr.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
|
||||
};
|
||||
|
||||
let peer_addr = self.peer_addr()?;
|
||||
let read_size = connected.read(buf)?;
|
||||
Ok((read_size, peer_addr))
|
||||
}
|
||||
|
||||
fn sendto(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
remote: Option<SocketAddr>,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
debug_assert!(remote.is_none());
|
||||
// TODO: deal with flags
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
|
||||
};
|
||||
|
||||
connected.write(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UnixStreamSocket {
|
||||
fn drop(&mut self) {
|
||||
let Some(bound_addr) = self.bound_addr() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let State::Listen(_) = &*self.0.read() {
|
||||
unregister_backlog(&bound_addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lookup_socket_file(path: &str) -> Result<Arc<Dentry>> {
|
||||
let dentry = {
|
||||
let current = current!();
|
||||
let fs = current.fs().read();
|
||||
let fs_path = FsPath::try_from(path)?;
|
||||
fs.lookup(&fs_path)?
|
||||
};
|
||||
|
||||
if dentry.type_() != InodeType::Socket {
|
||||
return_errno_with_message!(Errno::ENOTSOCK, "not a socket file")
|
||||
}
|
||||
|
||||
if !dentry.mode()?.is_readable() || !dentry.mode()?.is_writable() {
|
||||
return_errno_with_message!(Errno::EACCES, "the socket cannot be read or written")
|
||||
}
|
||||
Ok(dentry)
|
||||
}
|
Reference in New Issue
Block a user