Add unix stream socket

This commit is contained in:
Jianfeng Jiang
2023-07-31 19:23:40 +08:00
committed by Tate, Hongliang Tian
parent af04ef874c
commit 4aeedd16d9
13 changed files with 1175 additions and 1 deletions

View File

@ -0,0 +1,55 @@
use crate::{
net::socket::{unix::addr::UnixSocketAddr, SockShutdownCmd},
prelude::*,
};
use super::endpoint::Endpoint;
pub struct Connected {
local_endpoint: Arc<Endpoint>,
// The peer addr is None if peer is unnamed.
// FIXME: can a socket be bound after the socket is connected?
peer_addr: Option<UnixSocketAddr>,
}
impl Connected {
pub fn new(local_endpoint: Arc<Endpoint>) -> Self {
let peer_addr = local_endpoint.peer_addr();
Connected {
local_endpoint,
peer_addr,
}
}
pub fn addr(&self) -> Option<UnixSocketAddr> {
self.local_endpoint.addr()
}
pub fn peer_addr(&self) -> Option<&UnixSocketAddr> {
self.peer_addr.as_ref()
}
pub fn is_bound(&self) -> bool {
self.addr().is_some()
}
pub fn write(&self, buf: &[u8]) -> Result<usize> {
self.local_endpoint.write(buf)
}
pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.local_endpoint.read(buf)
}
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
self.local_endpoint.shutdown(cmd)
}
pub fn is_nonblocking(&self) -> bool {
self.local_endpoint.is_nonblocking()
}
pub fn set_nonblocking(&self, is_nonblocking: bool) {
self.local_endpoint.set_nonblocking(is_nonblocking).unwrap();
}
}

View File

@ -0,0 +1,123 @@
use crate::{
fs::utils::{Channel, Consumer, IoEvents, Poller, Producer, StatusFlags},
net::socket::{unix::addr::UnixSocketAddr, SockShutdownCmd},
prelude::*,
};
pub struct Endpoint(Inner);
struct Inner {
addr: RwLock<Option<UnixSocketAddr>>,
reader: Consumer<u8>,
writer: Producer<u8>,
peer: Weak<Endpoint>,
}
impl Endpoint {
pub fn end_pair(is_nonblocking: bool) -> Result<(Arc<Endpoint>, Arc<Endpoint>)> {
let flags = if is_nonblocking {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
};
let (writer_a, reader_b) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let (writer_b, reader_a) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let mut endpoint_b = None;
let endpoint_a = Arc::new_cyclic(|endpoint_a_ref| {
let peer = Arc::new(Endpoint::new(reader_b, writer_b, endpoint_a_ref.clone()));
let endpoint_a = Endpoint::new(reader_a, writer_a, Arc::downgrade(&peer));
endpoint_b = Some(peer);
endpoint_a
});
Ok((endpoint_a, endpoint_b.unwrap()))
}
fn new(reader: Consumer<u8>, writer: Producer<u8>, peer: Weak<Endpoint>) -> Self {
Self(Inner {
addr: RwLock::new(None),
reader,
writer,
peer,
})
}
pub fn addr(&self) -> Option<UnixSocketAddr> {
self.0.addr.read().clone()
}
pub fn set_addr(&self, addr: UnixSocketAddr) {
*self.0.addr.write() = Some(addr);
}
pub fn peer_addr(&self) -> Option<UnixSocketAddr> {
self.0.peer.upgrade().map(|peer| peer.addr()).flatten()
}
pub fn is_nonblocking(&self) -> bool {
let reader_status = self.0.reader.is_nonblocking();
let writer_status = self.0.writer.is_nonblocking();
debug_assert!(reader_status == writer_status);
reader_status
}
pub fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> {
let reader_flags = self.0.reader.status_flags();
self.0
.reader
.set_status_flags(reader_flags | StatusFlags::O_NONBLOCK)?;
let writer_flags = self.0.writer.status_flags();
self.0
.writer
.set_status_flags(writer_flags | StatusFlags::O_NONBLOCK)?;
Ok(())
}
pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.0.reader.read(buf)
}
pub fn write(&self, buf: &[u8]) -> Result<usize> {
self.0.writer.write(buf)
}
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
if !self.is_connected() {
return_errno_with_message!(Errno::ENOTCONN, "The socket is not connected.");
}
if cmd.shut_read() {
self.0.reader.shutdown();
}
if cmd.shut_write() {
self.0.writer.shutdown();
}
Ok(())
}
pub fn is_connected(&self) -> bool {
self.0.peer.upgrade().is_some()
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let mut events = IoEvents::empty();
// FIXME: should reader and writer use the same mask?
let reader_events = self.0.reader.poll(mask, poller);
let writer_events = self.0.writer.poll(mask, poller);
if reader_events.contains(IoEvents::HUP) || self.0.reader.is_shutdown() {
events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) || self.0.writer.is_shutdown() {
events |= IoEvents::HUP | IoEvents::OUT;
}
}
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events
}
}
const DAFAULT_BUF_SIZE: usize = 4096;

View File

@ -0,0 +1,50 @@
use core::sync::atomic::{AtomicBool, Ordering};
use crate::fs::utils::{IoEvents, Pollee, Poller};
use crate::net::socket::unix::addr::UnixSocketAddr;
use crate::prelude::*;
pub struct Init {
is_nonblocking: AtomicBool,
bind_addr: Option<UnixSocketAddr>,
pollee: Pollee,
}
impl Init {
pub fn new(is_nonblocking: bool) -> Self {
Self {
is_nonblocking: AtomicBool::new(is_nonblocking),
bind_addr: None,
pollee: Pollee::new(IoEvents::empty()),
}
}
pub fn bind(&mut self, mut addr: UnixSocketAddr) -> Result<()> {
if self.bind_addr.is_some() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
}
addr.create_file_and_bind()?;
self.bind_addr = Some(addr);
Ok(())
}
pub fn is_bound(&self) -> bool {
self.bind_addr.is_none()
}
pub fn bound_addr(&self) -> Option<&UnixSocketAddr> {
self.bind_addr.as_ref()
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Acquire)
}
pub fn set_nonblocking(&self, is_nonblocking: bool) {
self.is_nonblocking.store(is_nonblocking, Ordering::Release);
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
}

View File

@ -0,0 +1,29 @@
use core::sync::atomic::{AtomicBool, Ordering};
use crate::net::socket::unix::addr::UnixSocketAddr;
pub struct Listen {
addr: UnixSocketAddr,
is_nonblocking: AtomicBool,
}
impl Listen {
pub fn new(addr: UnixSocketAddr, nonblocking: bool) -> Self {
Self {
addr,
is_nonblocking: AtomicBool::new(nonblocking),
}
}
pub fn addr(&self) -> &UnixSocketAddr {
&self.addr
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Acquire)
}
pub fn set_nonblocking(&self, is_nonblocking: bool) {
self.is_nonblocking.store(is_nonblocking, Ordering::Release);
}
}

View File

@ -0,0 +1,151 @@
use keyable_arc::KeyableWeak;
use spin::RwLockReadGuard;
use crate::{
fs::utils::{Inode, IoEvents, Pollee, Poller},
net::socket::unix::addr::UnixSocketAddr,
prelude::*,
};
use super::endpoint::Endpoint;
pub static ACTIVE_LISTENERS: ActiveListeners = ActiveListeners::new();
pub struct ActiveListeners {
listeners: RwLock<BTreeMap<KeyableWeak<dyn Inode>, Arc<Listener>>>,
// TODO: For linux, there is also abstract socket domain that a socket addr is not bound to an inode.
}
impl ActiveListeners {
pub const fn new() -> Self {
Self {
listeners: RwLock::new(BTreeMap::new()),
}
}
pub(super) fn add_listener(&self, addr: &UnixSocketAddr, backlog: usize) -> Result<()> {
let inode = create_keyable_inode(addr)?;
let mut listeners = self.listeners.write();
if listeners.contains_key(&inode) {
return_errno_with_message!(Errno::EADDRINUSE, "the addr is already used");
}
let new_listener = Arc::new(Listener::new(backlog));
listeners.insert(inode, new_listener);
Ok(())
}
pub(super) fn get_listener(&self, addr: &UnixSocketAddr) -> Result<Arc<Listener>> {
let listeners = self.listeners.read();
get_listener(&listeners, addr)
}
pub(super) fn pop_incoming(
&self,
nonblocking: bool,
addr: &UnixSocketAddr,
) -> Result<Arc<Endpoint>> {
let poller = Poller::new();
loop {
let listener = {
let listeners = self.listeners.read();
get_listener(&listeners, addr)?
};
if let Some(endpoint) = listener.pop_incoming() {
return Ok(endpoint);
}
if nonblocking {
return_errno_with_message!(Errno::EAGAIN, "no connection comes");
}
let events = {
let mask = IoEvents::IN;
listener.poll(mask, Some(&poller))
};
if events.contains(IoEvents::ERR) | events.contains(IoEvents::HUP) {
return_errno_with_message!(Errno::EINVAL, "connection is refused");
}
if events.is_empty() {
poller.wait();
}
}
}
pub(super) fn push_incoming(
&self,
addr: &UnixSocketAddr,
endpoint: Arc<Endpoint>,
) -> Result<()> {
let listeners = self.listeners.read();
let listener = get_listener(&listeners, addr).map_err(|_| {
Error::with_message(
Errno::ECONNREFUSED,
"no socket is listened at the remote address",
)
})?;
listener.push_incoming(endpoint)
}
pub(super) fn remove_listener(&self, addr: &UnixSocketAddr) {
let Ok(inode) = create_keyable_inode(addr) else {
return;
};
self.listeners.write().remove(&inode);
}
}
fn get_listener(
listeners: &RwLockReadGuard<BTreeMap<KeyableWeak<dyn Inode>, Arc<Listener>>>,
addr: &UnixSocketAddr,
) -> Result<Arc<Listener>> {
let dentry = create_keyable_inode(addr)?;
listeners
.get(&dentry)
.map(Arc::clone)
.ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened"))
}
pub(super) struct Listener {
pollee: Pollee,
backlog: usize,
incoming_endpoints: Mutex<VecDeque<Arc<Endpoint>>>,
}
impl Listener {
pub fn new(backlog: usize) -> Self {
Self {
pollee: Pollee::new(IoEvents::empty()),
backlog,
incoming_endpoints: Mutex::new(VecDeque::with_capacity(backlog)),
}
}
pub fn push_incoming(&self, endpoint: Arc<Endpoint>) -> Result<()> {
let mut endpoints = self.incoming_endpoints.lock();
if endpoints.len() >= self.backlog {
return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full");
}
endpoints.push_back(endpoint);
self.pollee.add_events(IoEvents::IN);
Ok(())
}
pub fn pop_incoming(&self) -> Option<Arc<Endpoint>> {
let mut incoming_endpoints = self.incoming_endpoints.lock();
let endpoint = incoming_endpoints.pop_front();
if endpoint.is_none() {
self.pollee.del_events(IoEvents::IN);
}
endpoint
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
// Lock to avoid any events may change pollee state when we poll
let _lock = self.incoming_endpoints.lock();
self.pollee.poll(mask, poller)
}
}
fn create_keyable_inode(addr: &UnixSocketAddr) -> Result<KeyableWeak<dyn Inode>> {
let dentry = addr.dentry()?;
let inode = dentry.inode();
Ok(KeyableWeak::from(inode))
}

View File

@ -0,0 +1,9 @@
mod connected;
mod endpoint;
mod init;
mod listen;
mod listener;
pub mod stream;
pub use listener::{ActiveListeners, ACTIVE_LISTENERS};
pub use stream::UnixStreamSocket;

View File

@ -0,0 +1,287 @@
use crate::fs::file_handle::FileLike;
use crate::fs::utils::{IoEvents, Poller, StatusFlags};
use crate::net::socket::unix::addr::UnixSocketAddr;
use crate::net::socket::util::send_recv_flags::SendRecvFlags;
use crate::net::socket::util::sockaddr::SocketAddr;
use crate::net::socket::{SockShutdownCmd, Socket};
use crate::prelude::*;
use super::connected::Connected;
use super::endpoint::Endpoint;
use super::init::Init;
use super::listen::Listen;
use super::ACTIVE_LISTENERS;
pub struct UnixStreamSocket(RwLock<Status>);
enum Status {
Init(Init),
Listen(Listen),
Connected(Connected),
}
impl UnixStreamSocket {
pub fn new(nonblocking: bool) -> Self {
let status = Status::Init(Init::new(nonblocking));
Self(RwLock::new(status))
}
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
let (end_a, end_b) = Endpoint::end_pair(nonblocking)?;
let connected_a = UnixStreamSocket(RwLock::new(Status::Connected(Connected::new(end_a))));
let connected_b = UnixStreamSocket(RwLock::new(Status::Connected(Connected::new(end_b))));
Ok((Arc::new(connected_a), Arc::new(connected_b)))
}
fn bound_addr(&self) -> Option<UnixSocketAddr> {
let status = self.0.read();
match &*status {
Status::Init(init) => init.bound_addr().map(Clone::clone),
Status::Listen(listen) => Some(listen.addr().clone()),
Status::Connected(connected) => connected.addr(),
}
}
fn supported_flags(status_flags: &StatusFlags) -> StatusFlags {
const SUPPORTED_FLAGS: StatusFlags = StatusFlags::O_NONBLOCK;
const UNSUPPORTED_FLAGS: StatusFlags = SUPPORTED_FLAGS.complement();
if status_flags.intersects(UNSUPPORTED_FLAGS) {
warn!("ignore unsupported flags");
}
status_flags.intersection(SUPPORTED_FLAGS)
}
}
impl FileLike for UnixStreamSocket {
fn as_socket(&self) -> Option<&dyn Socket> {
Some(self)
}
fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.recvfrom(buf, SendRecvFlags::empty())
.map(|(read_size, _)| read_size)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
self.sendto(buf, None, SendRecvFlags::empty())
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let inner = self.0.read();
match &*inner {
Status::Init(init) => init.poll(mask, poller),
Status::Listen(listen) => {
let addr = listen.addr();
let listener = ACTIVE_LISTENERS.get_listener(addr).unwrap();
listener.poll(mask, poller)
}
Status::Connected(connet) => todo!(),
}
}
fn status_flags(&self) -> StatusFlags {
let inner = self.0.read();
let is_nonblocking = match &*inner {
Status::Init(init) => init.is_nonblocking(),
Status::Listen(listen) => listen.is_nonblocking(),
Status::Connected(connected) => connected.is_nonblocking(),
};
if is_nonblocking {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
}
}
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
let is_nonblocking = {
let supported_flags = Self::supported_flags(&new_flags);
supported_flags.contains(StatusFlags::O_NONBLOCK)
};
let mut inner = self.0.write();
match &mut *inner {
Status::Init(init) => init.set_nonblocking(is_nonblocking),
Status::Listen(listen) => listen.set_nonblocking(is_nonblocking),
Status::Connected(connected) => connected.set_nonblocking(is_nonblocking),
}
Ok(())
}
}
impl Socket for UnixStreamSocket {
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
let addr = UnixSocketAddr::try_from(sockaddr)?;
let mut inner = self.0.write();
match &mut *inner {
Status::Init(init) => init.bind(addr),
Status::Listen(_) | Status::Connected(_) => {
return_errno_with_message!(
Errno::EINVAL,
"cannot bind a listening or connected socket"
);
} // FIXME: Maybe binding a connected sockted should also be allowed?
}
}
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let mut inner = self.0.write();
match &*inner {
Status::Init(init) => {
let remote_addr = UnixSocketAddr::try_from(sockaddr)?;
let addr = init.bound_addr();
if let Some(addr) = addr {
if addr.path() == remote_addr.path() {
return_errno_with_message!(
Errno::EINVAL,
"try to connect to self is invalid"
);
}
}
let (this_end, remote_end) = Endpoint::end_pair(init.is_nonblocking())?;
remote_end.set_addr(remote_addr.clone());
if let Some(addr) = addr {
this_end.set_addr(addr.clone());
};
ACTIVE_LISTENERS.push_incoming(&remote_addr, remote_end)?;
*inner = Status::Connected(Connected::new(this_end));
Ok(())
}
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is listened")
}
Status::Connected(_) => {
return_errno_with_message!(Errno::EISCONN, "the socket is connected")
}
}
}
fn listen(&self, backlog: usize) -> Result<()> {
let mut inner = self.0.write();
match &*inner {
Status::Init(init) => {
let addr = init.bound_addr().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
ACTIVE_LISTENERS.add_listener(addr, backlog)?;
*inner = Status::Listen(Listen::new(addr.clone(), init.is_nonblocking()));
return Ok(());
}
Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is already listened")
}
Status::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is already connected")
}
};
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let inner = self.0.read();
match &*inner {
Status::Listen(listen) => {
let is_nonblocking = listen.is_nonblocking();
let addr = listen.addr().clone();
drop(inner);
// Avoid lock when waiting
let connected = {
let local_endpoint = ACTIVE_LISTENERS.pop_incoming(is_nonblocking, &addr)?;
Connected::new(local_endpoint)
};
let peer_addr = match connected.peer_addr() {
None => SocketAddr::Unix(String::new()),
Some(addr) => SocketAddr::from(addr.clone()),
};
let socket = UnixStreamSocket(RwLock::new(Status::Connected(connected)));
return Ok((Arc::new(socket), peer_addr));
}
Status::Connected(_) | Status::Init(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not listened")
}
}
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let inner = self.0.read();
if let Status::Connected(connected) = &*inner {
connected.shutdown(cmd)
} else {
return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected");
}
}
fn addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
let addr = match &*inner {
Status::Init(init) => init.bound_addr().map(Clone::clone),
Status::Listen(listen) => Some(listen.addr().clone()),
Status::Connected(connected) => connected.addr(),
};
addr.map(Into::<SocketAddr>::into)
.ok_or(Error::with_message(
Errno::EINVAL,
"the socket does not bind to addr",
))
}
fn peer_addr(&self) -> Result<SocketAddr> {
let inner = self.0.read();
if let Status::Connected(connected) = &*inner {
match connected.peer_addr() {
None => return Ok(SocketAddr::Unix(String::new())),
Some(peer_addr) => {
return Ok(SocketAddr::from(peer_addr.clone()));
}
}
}
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let inner = self.0.read();
// TODO: deal with flags
match &*inner {
Status::Connected(connected) => {
let read_size = connected.read(buf)?;
let peer_addr = self.peer_addr()?;
Ok((read_size, peer_addr))
}
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected")
}
}
}
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(remote.is_none());
// TODO: deal with flags
let inner = self.0.read();
match &*inner {
Status::Connected(connected) => connected.write(buf),
Status::Init(_) | Status::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected")
}
}
}
}
impl Drop for UnixStreamSocket {
fn drop(&mut self) {
let Some(bound_addr) = self.bound_addr() else {
return;
};
ACTIVE_LISTENERS.remove_listener(&bound_addr);
}
}