mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-23 01:13:23 +00:00
Refactor unix stream socket implementation
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
93429ae2c9
commit
3f15bcaf5d
@ -1,6 +1,8 @@
|
||||
use crate::fs::file_handle::FileLike;
|
||||
use crate::fs::utils::{IoEvents, Poller, StatusFlags};
|
||||
use crate::net::socket::unix::addr::UnixSocketAddr;
|
||||
use crate::fs::fs_resolver::FsPath;
|
||||
use crate::fs::utils::{Dentry, InodeType, IoEvents, Poller, StatusFlags};
|
||||
use crate::net::socket::unix::addr::UnixSocketAddrBound;
|
||||
use crate::net::socket::unix::UnixSocketAddr;
|
||||
use crate::net::socket::util::send_recv_flags::SendRecvFlags;
|
||||
use crate::net::socket::util::sockaddr::SocketAddr;
|
||||
use crate::net::socket::{SockShutdownCmd, Socket};
|
||||
@ -9,40 +11,59 @@ use crate::prelude::*;
|
||||
use super::connected::Connected;
|
||||
use super::endpoint::Endpoint;
|
||||
use super::init::Init;
|
||||
use super::listen::Listen;
|
||||
use super::ACTIVE_LISTENERS;
|
||||
use super::listener::{unregister_backlog, Listener};
|
||||
|
||||
pub struct UnixStreamSocket(RwLock<Status>);
|
||||
pub struct UnixStreamSocket(RwLock<State>);
|
||||
|
||||
enum Status {
|
||||
Init(Init),
|
||||
Listen(Listen),
|
||||
Connected(Connected),
|
||||
impl UnixStreamSocket {
|
||||
pub(super) fn new_init(init: Init) -> Self {
|
||||
Self(RwLock::new(State::Init(Arc::new(init))))
|
||||
}
|
||||
|
||||
pub(super) fn new_listen(listen: Listener) -> Self {
|
||||
Self(RwLock::new(State::Listen(Arc::new(listen))))
|
||||
}
|
||||
|
||||
pub(super) fn new_connected(connected: Connected) -> Self {
|
||||
Self(RwLock::new(State::Connected(Arc::new(connected))))
|
||||
}
|
||||
}
|
||||
|
||||
enum State {
|
||||
Init(Arc<Init>),
|
||||
Listen(Arc<Listener>),
|
||||
Connected(Arc<Connected>),
|
||||
}
|
||||
|
||||
impl UnixStreamSocket {
|
||||
pub fn new(nonblocking: bool) -> Self {
|
||||
let status = Status::Init(Init::new(nonblocking));
|
||||
Self(RwLock::new(status))
|
||||
let init = Init::new(nonblocking);
|
||||
Self::new_init(init)
|
||||
}
|
||||
|
||||
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
|
||||
let (end_a, end_b) = Endpoint::end_pair(nonblocking)?;
|
||||
let connected_a = UnixStreamSocket(RwLock::new(Status::Connected(Connected::new(end_a))));
|
||||
let connected_b = UnixStreamSocket(RwLock::new(Status::Connected(Connected::new(end_b))));
|
||||
let (end_a, end_b) = Endpoint::new_pair(nonblocking)?;
|
||||
let connected_a = {
|
||||
let connected = Connected::new(end_a);
|
||||
Self::new_connected(connected)
|
||||
};
|
||||
let connected_b = {
|
||||
let connected = Connected::new(end_b);
|
||||
Self::new_connected(connected)
|
||||
};
|
||||
Ok((Arc::new(connected_a), Arc::new(connected_b)))
|
||||
}
|
||||
|
||||
fn bound_addr(&self) -> Option<UnixSocketAddr> {
|
||||
fn bound_addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
let status = self.0.read();
|
||||
match &*status {
|
||||
Status::Init(init) => init.bound_addr().map(Clone::clone),
|
||||
Status::Listen(listen) => Some(listen.addr().clone()),
|
||||
Status::Connected(connected) => connected.addr(),
|
||||
State::Init(init) => init.addr(),
|
||||
State::Listen(listen) => Some(listen.addr().clone()),
|
||||
State::Connected(connected) => connected.addr(),
|
||||
}
|
||||
}
|
||||
|
||||
fn supported_flags(status_flags: &StatusFlags) -> StatusFlags {
|
||||
fn mask_flags(status_flags: &StatusFlags) -> StatusFlags {
|
||||
const SUPPORTED_FLAGS: StatusFlags = StatusFlags::O_NONBLOCK;
|
||||
const UNSUPPORTED_FLAGS: StatusFlags = SUPPORTED_FLAGS.complement();
|
||||
|
||||
@ -71,22 +92,18 @@ impl FileLike for UnixStreamSocket {
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
let inner = self.0.read();
|
||||
match &*inner {
|
||||
Status::Init(init) => init.poll(mask, poller),
|
||||
Status::Listen(listen) => {
|
||||
let addr = listen.addr();
|
||||
let listener = ACTIVE_LISTENERS.get_listener(addr).unwrap();
|
||||
listener.poll(mask, poller)
|
||||
}
|
||||
Status::Connected(connet) => todo!(),
|
||||
State::Init(init) => init.poll(mask, poller),
|
||||
State::Listen(listen) => listen.poll(mask, poller),
|
||||
State::Connected(connected) => connected.poll(mask, poller),
|
||||
}
|
||||
}
|
||||
|
||||
fn status_flags(&self) -> StatusFlags {
|
||||
let inner = self.0.read();
|
||||
let is_nonblocking = match &*inner {
|
||||
Status::Init(init) => init.is_nonblocking(),
|
||||
Status::Listen(listen) => listen.is_nonblocking(),
|
||||
Status::Connected(connected) => connected.is_nonblocking(),
|
||||
State::Init(init) => init.is_nonblocking(),
|
||||
State::Listen(listen) => listen.is_nonblocking(),
|
||||
State::Connected(connected) => connected.is_nonblocking(),
|
||||
};
|
||||
|
||||
if is_nonblocking {
|
||||
@ -98,15 +115,15 @@ impl FileLike for UnixStreamSocket {
|
||||
|
||||
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
|
||||
let is_nonblocking = {
|
||||
let supported_flags = Self::supported_flags(&new_flags);
|
||||
let supported_flags = Self::mask_flags(&new_flags);
|
||||
supported_flags.contains(StatusFlags::O_NONBLOCK)
|
||||
};
|
||||
|
||||
let mut inner = self.0.write();
|
||||
match &mut *inner {
|
||||
Status::Init(init) => init.set_nonblocking(is_nonblocking),
|
||||
Status::Listen(listen) => listen.set_nonblocking(is_nonblocking),
|
||||
Status::Connected(connected) => connected.set_nonblocking(is_nonblocking),
|
||||
State::Init(init) => init.set_nonblocking(is_nonblocking),
|
||||
State::Listen(listen) => listen.set_nonblocking(is_nonblocking),
|
||||
State::Connected(connected) => connected.set_nonblocking(is_nonblocking),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@ -115,114 +132,93 @@ impl FileLike for UnixStreamSocket {
|
||||
impl Socket for UnixStreamSocket {
|
||||
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||
let addr = UnixSocketAddr::try_from(sockaddr)?;
|
||||
let mut inner = self.0.write();
|
||||
match &mut *inner {
|
||||
Status::Init(init) => init.bind(addr),
|
||||
Status::Listen(_) | Status::Connected(_) => {
|
||||
return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"cannot bind a listening or connected socket"
|
||||
);
|
||||
} // FIXME: Maybe binding a connected sockted should also be allowed?
|
||||
}
|
||||
|
||||
let init = match &*self.0.read() {
|
||||
State::Init(init) => init.clone(),
|
||||
_ => return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"cannot bind a listening or connected socket"
|
||||
),
|
||||
// FIXME: Maybe binding a connected socket should also be allowed?
|
||||
};
|
||||
|
||||
init.bind(&addr)
|
||||
}
|
||||
|
||||
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||
let mut inner = self.0.write();
|
||||
match &*inner {
|
||||
Status::Init(init) => {
|
||||
let remote_addr = UnixSocketAddr::try_from(sockaddr)?;
|
||||
let addr = init.bound_addr();
|
||||
if let Some(addr) = addr {
|
||||
if addr.path() == remote_addr.path() {
|
||||
return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"try to connect to self is invalid"
|
||||
);
|
||||
}
|
||||
let remote_addr = {
|
||||
let unix_socket_addr = UnixSocketAddr::try_from(sockaddr)?;
|
||||
match unix_socket_addr {
|
||||
UnixSocketAddr::Abstract(abstract_name) => {
|
||||
UnixSocketAddrBound::Abstract(abstract_name)
|
||||
}
|
||||
UnixSocketAddr::Path(path) => {
|
||||
let dentry = lookup_socket_file(&path)?;
|
||||
UnixSocketAddrBound::Path(dentry)
|
||||
}
|
||||
let (this_end, remote_end) = Endpoint::end_pair(init.is_nonblocking())?;
|
||||
remote_end.set_addr(remote_addr.clone());
|
||||
if let Some(addr) = addr {
|
||||
this_end.set_addr(addr.clone());
|
||||
};
|
||||
ACTIVE_LISTENERS.push_incoming(&remote_addr, remote_end)?;
|
||||
*inner = Status::Connected(Connected::new(this_end));
|
||||
Ok(())
|
||||
}
|
||||
Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is listened")
|
||||
}
|
||||
Status::Connected(_) => {
|
||||
};
|
||||
|
||||
let init = match &*self.0.read() {
|
||||
State::Init(init) => init.clone(),
|
||||
State::Listen(_) => return_errno_with_message!(Errno::EINVAL, "the socket is listened"),
|
||||
State::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EISCONN, "the socket is connected")
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let connected = init.connect(&remote_addr)?;
|
||||
|
||||
*self.0.write() = State::Connected(Arc::new(connected));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
let mut inner = self.0.write();
|
||||
match &*inner {
|
||||
Status::Init(init) => {
|
||||
let addr = init.bound_addr().ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"the socket is not bound",
|
||||
))?;
|
||||
ACTIVE_LISTENERS.add_listener(addr, backlog)?;
|
||||
*inner = Status::Listen(Listen::new(addr.clone(), init.is_nonblocking()));
|
||||
return Ok(());
|
||||
let init = match &*self.0.read() {
|
||||
State::Init(init) => init.clone(),
|
||||
State::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already listening")
|
||||
}
|
||||
Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already listened")
|
||||
}
|
||||
Status::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already connected")
|
||||
State::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EISCONN, "the socket is already connected")
|
||||
}
|
||||
};
|
||||
|
||||
let addr = init.addr().ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"the socket is not bound",
|
||||
))?;
|
||||
|
||||
let listener = Listener::new(addr.clone(), backlog, init.is_nonblocking())?;
|
||||
*self.0.write() = State::Listen(Arc::new(listener));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let inner = self.0.read();
|
||||
match &*inner {
|
||||
Status::Listen(listen) => {
|
||||
let is_nonblocking = listen.is_nonblocking();
|
||||
let addr = listen.addr().clone();
|
||||
drop(inner);
|
||||
// Avoid lock when waiting
|
||||
let connected = {
|
||||
let local_endpoint = ACTIVE_LISTENERS.pop_incoming(is_nonblocking, &addr)?;
|
||||
Connected::new(local_endpoint)
|
||||
};
|
||||
let listen = match &*self.0.read() {
|
||||
State::Listen(listen) => listen.clone(),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
|
||||
};
|
||||
|
||||
let peer_addr = match connected.peer_addr() {
|
||||
None => SocketAddr::Unix(String::new()),
|
||||
Some(addr) => SocketAddr::from(addr.clone()),
|
||||
};
|
||||
|
||||
let socket = UnixStreamSocket(RwLock::new(Status::Connected(connected)));
|
||||
return Ok((Arc::new(socket), peer_addr));
|
||||
}
|
||||
Status::Connected(_) | Status::Init(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not listened")
|
||||
}
|
||||
}
|
||||
listen.accept()
|
||||
}
|
||||
|
||||
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
let inner = self.0.read();
|
||||
if let Status::Connected(connected) = &*inner {
|
||||
connected.shutdown(cmd)
|
||||
} else {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected");
|
||||
}
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"),
|
||||
};
|
||||
|
||||
connected.shutdown(cmd)
|
||||
}
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
let inner = self.0.read();
|
||||
let addr = match &*inner {
|
||||
Status::Init(init) => init.bound_addr().map(Clone::clone),
|
||||
Status::Listen(listen) => Some(listen.addr().clone()),
|
||||
Status::Connected(connected) => connected.addr(),
|
||||
let addr = match &*self.0.read() {
|
||||
State::Init(init) => init.addr(),
|
||||
State::Listen(listen) => Some(listen.addr().clone()),
|
||||
State::Connected(connected) => connected.addr(),
|
||||
};
|
||||
|
||||
addr.map(Into::<SocketAddr>::into)
|
||||
.ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
@ -231,31 +227,28 @@ impl Socket for UnixStreamSocket {
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
let inner = self.0.read();
|
||||
if let Status::Connected(connected) = &*inner {
|
||||
match connected.peer_addr() {
|
||||
None => return Ok(SocketAddr::Unix(String::new())),
|
||||
Some(peer_addr) => {
|
||||
return Ok(SocketAddr::from(peer_addr.clone()));
|
||||
}
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
|
||||
};
|
||||
|
||||
match connected.peer_addr() {
|
||||
None => return Ok(SocketAddr::Unix(UnixSocketAddr::Path(String::new()))),
|
||||
Some(peer_addr) => {
|
||||
return Ok(SocketAddr::from(peer_addr.clone()));
|
||||
}
|
||||
}
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
}
|
||||
|
||||
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let inner = self.0.read();
|
||||
// TODO: deal with flags
|
||||
match &*inner {
|
||||
Status::Connected(connected) => {
|
||||
let read_size = connected.read(buf)?;
|
||||
let peer_addr = self.peer_addr()?;
|
||||
Ok((read_size, peer_addr))
|
||||
}
|
||||
Status::Init(_) | Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected")
|
||||
}
|
||||
}
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
|
||||
};
|
||||
|
||||
let peer_addr = self.peer_addr()?;
|
||||
let read_size = connected.read(buf)?;
|
||||
Ok((read_size, peer_addr))
|
||||
}
|
||||
|
||||
fn sendto(
|
||||
@ -266,13 +259,12 @@ impl Socket for UnixStreamSocket {
|
||||
) -> Result<usize> {
|
||||
debug_assert!(remote.is_none());
|
||||
// TODO: deal with flags
|
||||
let inner = self.0.read();
|
||||
match &*inner {
|
||||
Status::Connected(connected) => connected.write(buf),
|
||||
Status::Init(_) | Status::Listen(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected")
|
||||
}
|
||||
}
|
||||
let connected = match &*self.0.read() {
|
||||
State::Connected(connected) => connected.clone(),
|
||||
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
|
||||
};
|
||||
|
||||
connected.write(buf)
|
||||
}
|
||||
}
|
||||
|
||||
@ -282,6 +274,26 @@ impl Drop for UnixStreamSocket {
|
||||
return;
|
||||
};
|
||||
|
||||
ACTIVE_LISTENERS.remove_listener(&bound_addr);
|
||||
if let State::Listen(_) = &*self.0.read() {
|
||||
unregister_backlog(&bound_addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lookup_socket_file(path: &str) -> Result<Arc<Dentry>> {
|
||||
let dentry = {
|
||||
let current = current!();
|
||||
let fs = current.fs().read();
|
||||
let fs_path = FsPath::try_from(path)?;
|
||||
fs.lookup(&fs_path)?
|
||||
};
|
||||
|
||||
if dentry.inode_type() != InodeType::Socket {
|
||||
return_errno_with_message!(Errno::ENOTSOCK, "not a socket file")
|
||||
}
|
||||
|
||||
if !dentry.inode_mode().is_readable() || !dentry.inode_mode().is_writable() {
|
||||
return_errno_with_message!(Errno::EACCES, "the socket cannot be read or written")
|
||||
}
|
||||
return Ok(dentry);
|
||||
}
|
||||
|
Reference in New Issue
Block a user