Refactor unix stream socket implementation

This commit is contained in:
Jianfeng Jiang
2023-08-09 15:10:33 +08:00
committed by Tate, Hongliang Tian
parent 93429ae2c9
commit 3f15bcaf5d
17 changed files with 495 additions and 421 deletions

View File

@ -1,6 +1,8 @@
use crate::fs::file_handle::FileLike;
use crate::fs::utils::{IoEvents, Poller, StatusFlags};
use crate::net::socket::unix::addr::UnixSocketAddr;
use crate::fs::fs_resolver::FsPath;
use crate::fs::utils::{Dentry, InodeType, IoEvents, Poller, StatusFlags};
use crate::net::socket::unix::addr::UnixSocketAddrBound;
use crate::net::socket::unix::UnixSocketAddr;
use crate::net::socket::util::send_recv_flags::SendRecvFlags;
use crate::net::socket::util::sockaddr::SocketAddr;
use crate::net::socket::{SockShutdownCmd, Socket};
@ -9,40 +11,59 @@ use crate::prelude::*;
use super::connected::Connected;
use super::endpoint::Endpoint;
use super::init::Init;
use super::listen::Listen;
use super::ACTIVE_LISTENERS;
use super::listener::{unregister_backlog, Listener};
pub struct UnixStreamSocket(RwLock<Status>);
pub struct UnixStreamSocket(RwLock<State>);
enum Status {
Init(Init),
Listen(Listen),
Connected(Connected),
impl UnixStreamSocket {
pub(super) fn new_init(init: Init) -> Self {
Self(RwLock::new(State::Init(Arc::new(init))))
}
pub(super) fn new_listen(listen: Listener) -> Self {
Self(RwLock::new(State::Listen(Arc::new(listen))))
}
pub(super) fn new_connected(connected: Connected) -> Self {
Self(RwLock::new(State::Connected(Arc::new(connected))))
}
}
enum State {
Init(Arc<Init>),
Listen(Arc<Listener>),
Connected(Arc<Connected>),
}
impl UnixStreamSocket {
pub fn new(nonblocking: bool) -> Self {
let status = Status::Init(Init::new(nonblocking));
Self(RwLock::new(status))
let init = Init::new(nonblocking);
Self::new_init(init)
}
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
let (end_a, end_b) = Endpoint::end_pair(nonblocking)?;
let connected_a = UnixStreamSocket(RwLock::new(Status::Connected(Connected::new(end_a))));
let connected_b = UnixStreamSocket(RwLock::new(Status::Connected(Connected::new(end_b))));
let (end_a, end_b) = Endpoint::new_pair(nonblocking)?;
let connected_a = {
let connected = Connected::new(end_a);
Self::new_connected(connected)
};
let connected_b = {
let connected = Connected::new(end_b);
Self::new_connected(connected)
};
Ok((Arc::new(connected_a), Arc::new(connected_b)))
}
fn bound_addr(&self) -> Option<UnixSocketAddr> {
fn bound_addr(&self) -> Option<UnixSocketAddrBound> {
let status = self.0.read();
match &*status {
Status::Init(init) => init.bound_addr().map(Clone::clone),
Status::Listen(listen) => Some(listen.addr().clone()),
Status::Connected(connected) => connected.addr(),
State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(),
}
}
fn supported_flags(status_flags: &StatusFlags) -> StatusFlags {
fn mask_flags(status_flags: &StatusFlags) -> StatusFlags {
const SUPPORTED_FLAGS: StatusFlags = StatusFlags::O_NONBLOCK;
const UNSUPPORTED_FLAGS: StatusFlags = SUPPORTED_FLAGS.complement();
@ -71,22 +92,18 @@ impl FileLike for UnixStreamSocket {
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let inner = self.0.read();
match &*inner {
Status::Init(init) => init.poll(mask, poller),
Status::Listen(listen) => {
let addr = listen.addr();
let listener = ACTIVE_LISTENERS.get_listener(addr).unwrap();
listener.poll(mask, poller)
}
Status::Connected(connet) => todo!(),
State::Init(init) => init.poll(mask, poller),
State::Listen(listen) => listen.poll(mask, poller),
State::Connected(connected) => connected.poll(mask, poller),
}
}
fn status_flags(&self) -> StatusFlags {
let inner = self.0.read();
let is_nonblocking = match &*inner {
Status::Init(init) => init.is_nonblocking(),
Status::Listen(listen) => listen.is_nonblocking(),
Status::Connected(connected) => connected.is_nonblocking(),
State::Init(init) => init.is_nonblocking(),
State::Listen(listen) => listen.is_nonblocking(),
State::Connected(connected) => connected.is_nonblocking(),
};
if is_nonblocking {
@ -98,15 +115,15 @@ impl FileLike for UnixStreamSocket {
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
let is_nonblocking = {
let supported_flags = Self::supported_flags(&new_flags);
let supported_flags = Self::mask_flags(&new_flags);
supported_flags.contains(StatusFlags::O_NONBLOCK)
};
let mut inner = self.0.write();
match &mut *inner {
Status::Init(init) => init.set_nonblocking(is_nonblocking),
Status::Listen(listen) => listen.set_nonblocking(is_nonblocking),
Status::Connected(connected) => connected.set_nonblocking(is_nonblocking),
State::Init(init) => init.set_nonblocking(is_nonblocking),
State::Listen(listen) => listen.set_nonblocking(is_nonblocking),
State::Connected(connected) => connected.set_nonblocking(is_nonblocking),
}
Ok(())
}
@ -115,114 +132,93 @@ impl FileLike for UnixStreamSocket {
impl Socket for UnixStreamSocket {
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
let addr = UnixSocketAddr::try_from(sockaddr)?;
let mut inner = self.0.write();
match &mut *inner {
Status::Init(init) => init.bind(addr),
Status::Listen(_) | Status::Connected(_) => {
return_errno_with_message!(
Errno::EINVAL,
"cannot bind a listening or connected socket"
);
} // FIXME: Maybe binding a connected sockted should also be allowed?
}
let init = match &*self.0.read() {
State::Init(init) => init.clone(),
_ => return_errno_with_message!(
Errno::EINVAL,
"cannot bind a listening or connected socket"
),
// FIXME: Maybe binding a connected socket should also be allowed?
};
init.bind(&addr)
}
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let mut inner = self.0.write();
match &*inner {
Status::Init(init) => {
let remote_addr = UnixSocketAddr::try_from(sockaddr)?;
let addr = init.bound_addr();
if let Some(addr) = addr {
if addr.path() == remote_addr.path() {
return_errno_with_message!(
Errno::EINVAL,
"try to connect to self is invalid"
);
}
let remote_addr = {
let unix_socket_addr = UnixSocketAddr::try_from(sockaddr)?;
match unix_socket_addr {
UnixSocketAddr::Abstract(abstract_name) => {
UnixSocketAddrBound::Abstract(abstract_name)
}
UnixSocketAddr::Path(path) => {
let dentry = lookup_socket_file(&path)?;
UnixSocketAddrBound::Path(dentry)
}
let (this_end, remote_end) = Endpoint::end_pair(init.is_nonblocking())?;
remote_end.set_addr(remote_addr.clone());
if let Some(addr) = addr {
this_end.set_addr(addr.clone());
};
ACTIVE_LISTENERS.push_incoming(&remote_addr, remote_end)?;
*inner = Status::Connected(Connected::new(this_end));
Ok(())
}
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is listened")
}
Status::Connected(_) => {
};
let init = match &*self.0.read() {
State::Init(init) => init.clone(),
State::Listen(_) => return_errno_with_message!(Errno::EINVAL, "the socket is listened"),
State::Connected(_) => {
return_errno_with_message!(Errno::EISCONN, "the socket is connected")
}
}
};
let connected = init.connect(&remote_addr)?;
*self.0.write() = State::Connected(Arc::new(connected));
Ok(())
}
fn listen(&self, backlog: usize) -> Result<()> {
let mut inner = self.0.write();
match &*inner {
Status::Init(init) => {
let addr = init.bound_addr().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
ACTIVE_LISTENERS.add_listener(addr, backlog)?;
*inner = Status::Listen(Listen::new(addr.clone(), init.is_nonblocking()));
return Ok(());
let init = match &*self.0.read() {
State::Init(init) => init.clone(),
State::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is already listening")
}
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is already listened")
}
Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is already connected")
State::Connected(_) => {
return_errno_with_message!(Errno::EISCONN, "the socket is already connected")
}
};
let addr = init.addr().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
let listener = Listener::new(addr.clone(), backlog, init.is_nonblocking())?;
*self.0.write() = State::Listen(Arc::new(listener));
Ok(())
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let inner = self.0.read();
match &*inner {
Status::Listen(listen) => {
let is_nonblocking = listen.is_nonblocking();
let addr = listen.addr().clone();
drop(inner);
// Avoid lock when waiting
let connected = {
let local_endpoint = ACTIVE_LISTENERS.pop_incoming(is_nonblocking, &addr)?;
Connected::new(local_endpoint)
};
let listen = match &*self.0.read() {
State::Listen(listen) => listen.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
};
let peer_addr = match connected.peer_addr() {
None => SocketAddr::Unix(String::new()),
Some(addr) => SocketAddr::from(addr.clone()),
};
let socket = UnixStreamSocket(RwLock::new(Status::Connected(connected)));
return Ok((Arc::new(socket), peer_addr));
}
Status::Connected(_) | Status::Init(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not listened")
}
}
listen.accept()
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let inner = self.0.read();
if let Status::Connected(connected) = &*inner {
connected.shutdown(cmd)
} else {
return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected");
}
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"),
};
connected.shutdown(cmd)
}
fn addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
let addr = match &*inner {
Status::Init(init) => init.bound_addr().map(Clone::clone),
Status::Listen(listen) => Some(listen.addr().clone()),
Status::Connected(connected) => connected.addr(),
let addr = match &*self.0.read() {
State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(),
};
addr.map(Into::<SocketAddr>::into)
.ok_or(Error::with_message(
Errno::EINVAL,
@ -231,31 +227,28 @@ impl Socket for UnixStreamSocket {
}
fn peer_addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
if let Status::Connected(connected) = &*inner {
match connected.peer_addr() {
None => return Ok(SocketAddr::Unix(String::new())),
Some(peer_addr) => {
return Ok(SocketAddr::from(peer_addr.clone()));
}
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
};
match connected.peer_addr() {
None => return Ok(SocketAddr::Unix(UnixSocketAddr::Path(String::new()))),
Some(peer_addr) => {
return Ok(SocketAddr::from(peer_addr.clone()));
}
}
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let inner = self.0.read();
// TODO: deal with flags
match &*inner {
Status::Connected(connected) => {
let read_size = connected.read(buf)?;
let peer_addr = self.peer_addr()?;
Ok((read_size, peer_addr))
}
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected")
}
}
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
};
let peer_addr = self.peer_addr()?;
let read_size = connected.read(buf)?;
Ok((read_size, peer_addr))
}
fn sendto(
@ -266,13 +259,12 @@ impl Socket for UnixStreamSocket {
) -> Result<usize> {
debug_assert!(remote.is_none());
// TODO: deal with flags
let inner = self.0.read();
match &*inner {
Status::Connected(connected) => connected.write(buf),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected")
}
}
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
};
connected.write(buf)
}
}
@ -282,6 +274,26 @@ impl Drop for UnixStreamSocket {
return;
};
ACTIVE_LISTENERS.remove_listener(&bound_addr);
if let State::Listen(_) = &*self.0.read() {
unregister_backlog(&bound_addr);
}
}
}
fn lookup_socket_file(path: &str) -> Result<Arc<Dentry>> {
let dentry = {
let current = current!();
let fs = current.fs().read();
let fs_path = FsPath::try_from(path)?;
fs.lookup(&fs_path)?
};
if dentry.inode_type() != InodeType::Socket {
return_errno_with_message!(Errno::ENOTSOCK, "not a socket file")
}
if !dentry.inode_mode().is_readable() || !dentry.inode_mode().is_writable() {
return_errno_with_message!(Errno::EACCES, "the socket cannot be read or written")
}
return Ok(dentry);
}