Refactor project structure

This commit is contained in:
Zhang Junyang
2024-02-27 16:40:16 +08:00
committed by Tate, Hongliang Tian
parent bd878dd1c9
commit e3c227ae06
474 changed files with 77 additions and 77 deletions

View File

@ -0,0 +1,54 @@
// SPDX-License-Identifier: MPL-2.0
use core::ops::{Deref, DerefMut};
use crate::prelude::*;
/// AlwaysSome is a wrapper for Option.
///
/// AlwaysSome should always be Some(T), so we can treat it as a smart pointer.
/// If it becomes None, the AlwaysSome should be viewed invalid and cannot be used anymore.
pub struct AlwaysSome<T>(Option<T>);
impl<T> AlwaysSome<T> {
pub fn new(value: T) -> Self {
AlwaysSome(Some(value))
}
pub fn try_take_with<R, E: Into<Error>, F: FnOnce(T) -> core::result::Result<R, (E, T)>>(
&mut self,
f: F,
) -> Result<R> {
let value = if let Some(value) = self.0.take() {
value
} else {
return_errno_with_message!(Errno::EINVAL, "the take cell is none");
};
match f(value) {
Ok(res) => Ok(res),
Err((err, t)) => {
self.0 = Some(t);
Err(err.into())
}
}
}
/// Takes inner value
pub fn take(&mut self) -> T {
debug_assert!(self.0.is_some());
self.0.take().unwrap()
}
}
impl<T> Deref for AlwaysSome<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.0.as_ref().unwrap()
}
}
impl<T> DerefMut for AlwaysSome<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut().unwrap()
}
}

View File

@ -0,0 +1,68 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
net::{
iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, Iface, IpAddress, IpEndpoint},
IFACES,
},
prelude::*,
};
pub fn get_iface_to_bind(ip_addr: &IpAddress) -> Option<Arc<dyn Iface>> {
let ifaces = IFACES.get().unwrap();
let IpAddress::Ipv4(ipv4_addr) = ip_addr;
ifaces
.iter()
.find(|iface| {
if let Some(iface_ipv4_addr) = iface.ipv4_addr() {
iface_ipv4_addr == *ipv4_addr
} else {
false
}
})
.map(Clone::clone)
}
/// Get a suitable iface to deal with sendto/connect request if the socket is not bound to an iface.
/// If the remote address is the same as that of some iface, we will use the iface.
/// Otherwise, we will use a default interface.
fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<dyn Iface> {
let ifaces = IFACES.get().unwrap();
let IpAddress::Ipv4(remote_ipv4_addr) = remote_ip_addr;
if let Some(iface) = ifaces.iter().find(|iface| {
if let Some(iface_ipv4_addr) = iface.ipv4_addr() {
iface_ipv4_addr == *remote_ipv4_addr
} else {
false
}
}) {
return iface.clone();
}
// FIXME: use the virtio-net as the default interface
ifaces[0].clone()
}
pub(super) fn bind_socket(
unbound_socket: Box<AnyUnboundSocket>,
endpoint: IpEndpoint,
can_reuse: bool,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
let iface = match get_iface_to_bind(&endpoint.addr) {
Some(iface) => iface,
None => {
let err = Error::with_message(Errno::EADDRNOTAVAIL, "Request iface is not available");
return Err((err, unbound_socket));
}
};
let bind_port_config = match BindPortConfig::new(endpoint.port, can_reuse) {
Ok(config) => config,
Err(e) => return Err((e, unbound_socket)),
};
iface.bind_socket(unbound_socket, bind_port_config)
}
pub fn get_ephemeral_endpoint(remote_endpoint: &IpEndpoint) -> IpEndpoint {
let iface = get_ephemeral_iface(&remote_endpoint.addr);
let ip_addr = iface.ipv4_addr().unwrap();
IpEndpoint::new(IpAddress::Ipv4(ip_addr), 0)
}

View File

@ -0,0 +1,108 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
events::{IoEvents, Observer},
net::{
iface::{AnyBoundSocket, IpEndpoint, RawUdpSocket},
poll_ifaces,
socket::util::send_recv_flags::SendRecvFlags,
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct BoundDatagram {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: RwLock<Option<IpEndpoint>>,
pollee: Pollee,
}
impl BoundDatagram {
pub fn new(bound_socket: Arc<AnyBoundSocket>, pollee: Pollee) -> Arc<Self> {
let bound = Arc::new(Self {
bound_socket,
remote_endpoint: RwLock::new(None),
pollee,
});
bound.bound_socket.set_observer(Arc::downgrade(&bound) as _);
bound
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
self.remote_endpoint
.read()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "remote endpoint is not specified"))
}
pub fn set_remote_endpoint(&self, endpoint: IpEndpoint) {
*self.remote_endpoint.write() = Some(endpoint);
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket.local_endpoint().ok_or_else(|| {
Error::with_message(Errno::EINVAL, "socket does not bind to local endpoint")
})
}
pub fn try_recvfrom(
&self,
buf: &mut [u8],
flags: &SendRecvFlags,
) -> Result<(usize, IpEndpoint)> {
poll_ifaces();
let recv_slice = |socket: &mut RawUdpSocket| {
socket
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty"))
};
self.bound_socket.raw_with(recv_slice)
}
pub fn try_sendto(
&self,
buf: &[u8],
remote: Option<IpEndpoint>,
flags: SendRecvFlags,
) -> Result<usize> {
let remote_endpoint = remote
.or_else(|| self.remote_endpoint().ok())
.ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?;
let send_slice = |socket: &mut RawUdpSocket| {
socket
.send_slice(buf, remote_endpoint)
.map(|_| buf.len())
.map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails"))
};
let len = self.bound_socket.raw_with(send_slice)?;
poll_ifaces();
Ok(len)
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
fn update_io_events(&self) {
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
let pollee = &self.pollee;
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
pollee.del_events(IoEvents::IN);
}
if socket.can_send() {
pollee.add_events(IoEvents::OUT);
} else {
pollee.del_events(IoEvents::OUT);
}
});
}
}
impl Observer<()> for BoundDatagram {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -0,0 +1,211 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
use super::{always_some::AlwaysSome, common::get_ephemeral_endpoint};
use crate::{
events::IoEvents,
fs::{file_handle::FileLike, utils::StatusFlags},
net::{
iface::IpEndpoint,
socket::{
util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr},
Socket,
},
},
prelude::*,
process::signal::Poller,
};
mod bound;
mod unbound;
pub struct DatagramSocket {
nonblocking: AtomicBool,
inner: RwLock<Inner>,
}
enum Inner {
Unbound(AlwaysSome<UnboundDatagram>),
Bound(Arc<BoundDatagram>),
}
impl Inner {
fn is_bound(&self) -> bool {
matches!(self, Inner::Bound { .. })
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<Arc<BoundDatagram>> {
let unbound = match self {
Inner::Unbound(unbound) => unbound,
Inner::Bound(..) => return_errno_with_message!(
Errno::EINVAL,
"the socket is already bound to an address"
),
};
let bound = unbound.try_take_with(|unbound| unbound.bind(endpoint))?;
*self = Inner::Bound(bound.clone());
Ok(bound)
}
fn bind_to_ephemeral_endpoint(
&mut self,
remote_endpoint: &IpEndpoint,
) -> Result<Arc<BoundDatagram>> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(endpoint)
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
match self {
Inner::Unbound(unbound) => unbound.poll(mask, poller),
Inner::Bound(bound) => bound.poll(mask, poller),
}
}
}
impl DatagramSocket {
pub fn new(nonblocking: bool) -> Self {
let unbound = UnboundDatagram::new();
Self {
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(unbound))),
nonblocking: AtomicBool::new(nonblocking),
}
}
pub fn is_bound(&self) -> bool {
self.inner.read().is_bound()
}
pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::SeqCst)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::SeqCst);
}
fn bound(&self) -> Result<Arc<BoundDatagram>> {
if let Inner::Bound(bound) = &*self.inner.read() {
Ok(bound.clone())
} else {
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint")
}
}
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<BoundDatagram>> {
// Fast path
if let Inner::Bound(bound) = &*self.inner.read() {
return Ok(bound.clone());
}
// Slow path
let mut inner = self.inner.write();
if let Inner::Bound(bound) = &*inner {
return Ok(bound.clone());
}
inner.bind_to_ephemeral_endpoint(remote_endpoint)
}
}
impl FileLike for DatagramSocket {
fn read(&self, buf: &mut [u8]) -> Result<usize> {
// FIXME: respect flags
let flags = SendRecvFlags::empty();
let (recv_len, _) = self.recvfrom(buf, flags)?;
Ok(recv_len)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
// FIXME: set correct flags
let flags = SendRecvFlags::empty();
self.sendto(buf, None, flags)
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.inner.read().poll(mask, poller)
}
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
Some(self)
}
fn status_flags(&self) -> StatusFlags {
if self.is_nonblocking() {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
}
}
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
if new_flags.contains(StatusFlags::O_NONBLOCK) {
self.set_nonblocking(true);
} else {
self.set_nonblocking(false);
}
Ok(())
}
}
impl Socket for DatagramSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.inner.write().bind(endpoint)?;
Ok(())
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
let bound = self.try_bind_empheral(&endpoint)?;
bound.set_remote_endpoint(endpoint);
Ok(())
}
fn addr(&self) -> Result<SocketAddr> {
self.bound()?.local_endpoint()?.try_into()
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.bound()?.remote_endpoint()?.try_into()
}
// FIXME: respect RecvFromFlags
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
debug_assert!(flags.is_all_supported());
let bound = self.bound()?;
let poller = Poller::new();
loop {
if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) {
let remote_addr = remote_endpoint.try_into()?;
return Ok((recv_len, remote_addr));
}
let events = bound.poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to receive again");
}
// FIXME: deal with recvfrom timeout
poller.wait()?;
}
}
}
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(flags.is_all_supported());
let (bound, remote_endpoint) = if let Some(addr) = remote {
let endpoint = addr.try_into()?;
(self.try_bind_empheral(&endpoint)?, Some(endpoint))
} else {
let bound = self.bound()?;
(bound, None)
};
bound.try_sendto(buf, remote_endpoint, flags)
}
}

View File

@ -0,0 +1,53 @@
// SPDX-License-Identifier: MPL-2.0
use super::bound::BoundDatagram;
use crate::{
events::IoEvents,
net::{
iface::{AnyUnboundSocket, IpEndpoint, RawUdpSocket},
socket::ip::common::bind_socket,
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct UnboundDatagram {
unbound_socket: Box<AnyUnboundSocket>,
pollee: Pollee,
}
impl UnboundDatagram {
pub fn new() -> Self {
Self {
unbound_socket: Box::new(AnyUnboundSocket::new_udp()),
pollee: Pollee::new(IoEvents::empty()),
}
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn bind(
self,
endpoint: IpEndpoint,
) -> core::result::Result<Arc<BoundDatagram>, (Error, Self)> {
let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => {
return Err((
err,
Self {
unbound_socket,
pollee: self.pollee,
},
))
}
};
let bound_endpoint = bound_socket.local_endpoint().unwrap();
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket.bind(bound_endpoint).unwrap();
});
Ok(BoundDatagram::new(bound_socket, self.pollee))
}
}

View File

@ -0,0 +1,9 @@
// SPDX-License-Identifier: MPL-2.0
mod always_some;
mod common;
mod datagram;
pub mod stream;
pub use datagram::DatagramSocket;
pub use stream::StreamSocket;

View File

@ -0,0 +1,170 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use crate::{
events::{IoEvents, Observer},
net::{
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
poll_ifaces,
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct ConnectedStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
}
impl ConnectedStream {
pub fn new(
is_nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
) -> Arc<Self> {
let connected = Arc::new(Self {
nonblocking: AtomicBool::new(is_nonblocking),
bound_socket,
remote_endpoint,
pollee,
});
connected
.bound_socket
.set_observer(Arc::downgrade(&connected) as _);
connected
}
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
// TODO: deal with cmd
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket.close();
});
poll_ifaces();
Ok(())
}
pub fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> {
debug_assert!(flags.is_all_supported());
let poller = Poller::new();
loop {
let recv_len = self.try_recvfrom(buf, flags)?;
if recv_len > 0 {
let remote_endpoint = self.remote_endpoint()?;
return Ok((recv_len, remote_endpoint));
}
let events = self.poll(IoEvents::IN, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOTCONN, "recv packet fails");
}
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to recv again");
}
// FIXME: deal with receive timeout
poller.wait()?;
}
}
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> {
poll_ifaces();
let res = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))
});
self.update_io_events();
res
}
pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
debug_assert!(flags.is_all_supported());
let poller = Poller::new();
loop {
let sent_len = self.try_sendto(buf, flags)?;
if sent_len > 0 {
return Ok(sent_len);
}
let events = self.poll(IoEvents::OUT, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOBUFS, "fail to send packets");
}
if !events.contains(IoEvents::OUT) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to send again");
}
// FIXME: deal with send timeout
poller.wait()?;
}
}
}
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let res = self
.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf))
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"));
match res {
// We have to explicitly invoke `update_io_events` when the send buffer becomes
// full. Note that smoltcp does not think it is an interface event, so calling
// `poll_ifaces` alone is not enough.
Ok(0) => self.update_io_events(),
Ok(_) => poll_ifaces(),
_ => (),
};
res
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
Ok(self.remote_endpoint)
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
fn update_io_events(&self) {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
let pollee = &self.pollee;
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
pollee.del_events(IoEvents::IN);
}
if socket.can_send() {
pollee.add_events(IoEvents::OUT);
} else {
pollee.del_events(IoEvents::OUT);
}
});
}
pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Observer<()> for ConnectedStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -0,0 +1,155 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::{connected::ConnectedStream, init::InitStream};
use crate::{
events::{IoEvents, Observer},
net::{
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
poll_ifaces,
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct ConnectingStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
conn_result: RwLock<Option<ConnResult>>,
pollee: Pollee,
}
enum ConnResult {
Connected,
Refused,
}
impl ConnectingStream {
pub fn new(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
) -> Result<Arc<Self>> {
bound_socket.do_connect(remote_endpoint)?;
let connecting = Arc::new(Self {
nonblocking: AtomicBool::new(nonblocking),
bound_socket,
remote_endpoint,
conn_result: RwLock::new(None),
pollee,
});
connecting.pollee.reset_events();
connecting
.bound_socket
.set_observer(Arc::downgrade(&connecting) as _);
Ok(connecting)
}
pub fn wait_conn(
&self,
) -> core::result::Result<Arc<ConnectedStream>, (Error, Arc<InitStream>)> {
debug_assert!(!self.is_nonblocking());
let poller = Poller::new();
loop {
poll_ifaces();
match *self.conn_result.read() {
Some(ConnResult::Connected) => {
return Ok(ConnectedStream::new(
self.is_nonblocking(),
self.bound_socket.clone(),
self.remote_endpoint,
self.pollee.clone(),
));
}
Some(ConnResult::Refused) => {
return Err((
Error::with_message(Errno::ECONNREFUSED, "connection refused"),
InitStream::new_bound(
self.is_nonblocking(),
self.bound_socket.clone(),
self.pollee.clone(),
),
));
}
None => (),
};
let events = self.poll(IoEvents::OUT, Some(&poller));
if !events.contains(IoEvents::OUT) {
// FIXME: deal with nonblocking mode & connecting timeout
poller.wait().expect("async connect() not implemented");
}
}
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint"))
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
Ok(self.remote_endpoint)
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
fn update_io_events(&self) {
if self.conn_result.read().is_some() {
return;
}
let became_writable = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
let mut result = self.conn_result.write();
if result.is_some() {
return false;
}
// Connected
if socket.can_send() {
*result = Some(ConnResult::Connected);
return true;
}
// Connecting
if socket.is_open() {
return false;
}
// Refused
*result = Some(ConnResult::Refused);
true
});
// Either when the connection is established, or when the connection fails, the socket
// shall indicate that it is writable.
//
// TODO: Find a way to turn `ConnectingStream` into `ConnectedStream` or `InitStream`
// here, so non-blocking `connect()` can work correctly. Meanwhile, the latter should
// be responsible to initialize all the I/O events including `IoEvents::OUT`, so the
// following hard-coded event addition can be removed.
if became_writable {
self.pollee.add_events(IoEvents::OUT);
}
}
}
impl Observer<()> for ConnectingStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -0,0 +1,156 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::{connecting::ConnectingStream, listen::ListenStream};
use crate::{
events::IoEvents,
net::{
iface::{AnyBoundSocket, AnyUnboundSocket, Iface, IpEndpoint},
socket::ip::{
always_some::AlwaysSome,
common::{bind_socket, get_ephemeral_endpoint},
},
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct InitStream {
inner: RwLock<Inner>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
enum Inner {
Unbound(AlwaysSome<Box<AnyUnboundSocket>>),
Bound(AlwaysSome<Arc<AnyBoundSocket>>),
}
impl Inner {
fn new() -> Inner {
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
Inner::Unbound(AlwaysSome::new(unbound_socket))
}
fn is_bound(&self) -> bool {
match self {
Self::Unbound(_) => false,
Self::Bound(_) => true,
}
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> {
let unbound_socket = if let Inner::Unbound(unbound_socket) = self {
unbound_socket
} else {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
};
let bound_socket =
unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?;
*self = Inner::Bound(AlwaysSome::new(bound_socket));
Ok(())
}
fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(endpoint)
}
fn bound_socket(&self) -> Option<&Arc<AnyBoundSocket>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket),
Inner::Unbound(_) => None,
}
}
fn iface(&self) -> Option<Arc<dyn Iface>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
Inner::Unbound(_) => None,
}
}
fn local_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket()
.and_then(|socket| socket.local_endpoint())
}
}
impl InitStream {
// FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling
// `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by
// experimentation and Linux source code.
pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new(Self {
inner: RwLock::new(Inner::new()),
is_nonblocking: AtomicBool::new(nonblocking),
pollee: Pollee::new(IoEvents::empty()),
})
}
pub fn new_bound(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
pollee: Pollee,
) -> Arc<Self> {
bound_socket.set_observer(Weak::<()>::new());
let inner = Inner::Bound(AlwaysSome::new(bound_socket));
Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
inner: RwLock::new(inner),
pollee,
})
}
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> {
self.inner.write().bind(endpoint)
}
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
if !self.inner.read().is_bound() {
self.inner
.write()
.bind_to_ephemeral_endpoint(remote_endpoint)?
}
ConnectingStream::new(
self.is_nonblocking(),
self.inner.read().bound_socket().unwrap().clone(),
*remote_endpoint,
self.pollee.clone(),
)
}
pub fn listen(&self, backlog: usize) -> Result<Arc<ListenStream>> {
let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() {
bound_socket.clone()
} else {
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound")
};
ListenStream::new(
self.is_nonblocking(),
bound_socket,
backlog,
self.pollee.clone(),
)
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.inner
.read()
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint"))
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}

View File

@ -0,0 +1,183 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::connected::ConnectedStream;
use crate::{
events::{IoEvents, Observer},
net::{
iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket},
poll_ifaces,
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct ListenStream {
is_nonblocking: AtomicBool,
backlog: usize,
/// A bound socket held to ensure the TCP port cannot be released
bound_socket: Arc<AnyBoundSocket>,
/// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>,
pollee: Pollee,
}
impl ListenStream {
pub fn new(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
backlog: usize,
pollee: Pollee,
) -> Result<Arc<Self>> {
let listen_stream = Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
backlog,
bound_socket,
backlog_sockets: RwLock::new(Vec::new()),
pollee,
});
listen_stream.fill_backlog_sockets()?;
listen_stream.pollee.reset_events();
listen_stream
.bound_socket
.set_observer(Arc::downgrade(&listen_stream) as _);
Ok(listen_stream)
}
pub fn accept(&self) -> Result<(Arc<ConnectedStream>, IpEndpoint)> {
// wait to accept
let poller = Poller::new();
loop {
poll_ifaces();
let accepted_socket = if let Some(accepted_socket) = self.try_accept() {
accepted_socket
} else {
let events = self.poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try accept again");
}
// FIXME: deal with accept timeout
poller.wait()?;
}
continue;
};
let remote_endpoint = accepted_socket.remote_endpoint().unwrap();
let connected_stream = {
let BacklogSocket {
bound_socket: backlog_socket,
} = accepted_socket;
ConnectedStream::new(
false,
backlog_socket,
remote_endpoint,
Pollee::new(IoEvents::empty()),
)
};
return Ok((connected_stream, remote_endpoint));
}
}
/// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> {
let backlog = self.backlog;
let mut backlog_sockets = self.backlog_sockets.write();
let current_backlog_len = backlog_sockets.len();
debug_assert!(backlog >= current_backlog_len);
if backlog == current_backlog_len {
return Ok(());
}
for _ in current_backlog_len..backlog {
let backlog_socket = BacklogSocket::new(&self.bound_socket)?;
backlog_sockets.push(backlog_socket);
}
Ok(())
}
fn try_accept(&self) -> Option<BacklogSocket> {
let backlog_socket = {
let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets
.iter()
.position(|backlog_socket| backlog_socket.is_active())?;
backlog_sockets.remove(index)
};
self.fill_backlog_sockets().unwrap();
self.update_io_events();
Some(backlog_socket)
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
fn update_io_events(&self) {
// The lock should be held to avoid data races
let backlog_sockets = self.backlog_sockets.read();
let can_accept = backlog_sockets.iter().any(|socket| socket.is_active());
if can_accept {
self.pollee.add_events(IoEvents::IN);
} else {
self.pollee.del_events(IoEvents::IN);
}
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Observer<()> for ListenStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}
struct BacklogSocket {
bound_socket: Arc<AnyBoundSocket>,
}
impl BacklogSocket {
fn new(bound_socket: &Arc<AnyBoundSocket>) -> Result<Self> {
let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
let bound_socket = {
let iface = bound_socket.iface();
let bind_port_config = BindPortConfig::new(local_endpoint.port, true)?;
iface
.bind_socket(unbound_socket, bind_port_config)
.map_err(|(e, _)| e)?
};
bound_socket.raw_with(|raw_tcp_socket: &mut RawTcpSocket| {
raw_tcp_socket
.listen(local_endpoint)
.map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen"))
})?;
Ok(Self { bound_socket })
}
fn is_active(&self) -> bool {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.is_active())
}
fn remote_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
}
}

View File

@ -0,0 +1,408 @@
// SPDX-License-Identifier: MPL-2.0
use connected::ConnectedStream;
use connecting::ConnectingStream;
use init::InitStream;
use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
use smoltcp::wire::IpEndpoint;
use util::{TcpOptionSet, DEFAULT_MAXSEG};
use crate::{
events::IoEvents,
fs::{file_handle::FileLike, utils::StatusFlags},
match_sock_option_mut, match_sock_option_ref,
net::socket::{
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
util::{
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
send_recv_flags::SendRecvFlags,
shutdown_cmd::SockShutdownCmd,
socket_addr::SocketAddr,
},
Socket,
},
prelude::*,
process::signal::Poller,
};
mod connected;
mod connecting;
mod init;
mod listen;
pub mod options;
mod util;
pub use self::util::CongestionControl;
pub struct StreamSocket {
options: RwLock<OptionSet>,
state: RwLock<State>,
}
enum State {
// Start state
Init(Arc<InitStream>),
// Intermediate state
Connecting(Arc<ConnectingStream>),
// Final State 1
Connected(Arc<ConnectedStream>),
// Final State 2
Listen(Arc<ListenStream>),
}
#[derive(Debug, Clone)]
struct OptionSet {
socket: SocketOptionSet,
tcp: TcpOptionSet,
}
impl OptionSet {
fn new() -> Self {
let socket = SocketOptionSet::new_tcp();
let tcp = TcpOptionSet::new();
OptionSet { socket, tcp }
}
}
impl StreamSocket {
pub fn new(nonblocking: bool) -> Self {
let options = OptionSet::new();
let state = State::Init(InitStream::new(nonblocking));
Self {
options: RwLock::new(options),
state: RwLock::new(state),
}
}
fn is_nonblocking(&self) -> bool {
match &*self.state.read() {
State::Init(init) => init.is_nonblocking(),
State::Connecting(connecting) => connecting.is_nonblocking(),
State::Connected(connected) => connected.is_nonblocking(),
State::Listen(listen) => listen.is_nonblocking(),
}
}
fn set_nonblocking(&self, nonblocking: bool) {
match &*self.state.read() {
State::Init(init) => init.set_nonblocking(nonblocking),
State::Connecting(connecting) => connecting.set_nonblocking(nonblocking),
State::Connected(connected) => connected.set_nonblocking(nonblocking),
State::Listen(listen) => listen.set_nonblocking(nonblocking),
}
}
fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
let mut state = self.state.write();
let init_stream = match &*state {
State::Init(init_stream) => init_stream,
State::Listen(_) | State::Connecting(_) | State::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "cannot connect")
}
};
let connecting = init_stream.connect(remote_endpoint)?;
*state = State::Connecting(connecting.clone());
Ok(connecting)
}
}
impl FileLike for StreamSocket {
fn read(&self, buf: &mut [u8]) -> Result<usize> {
// FIXME: set correct flags
let flags = SendRecvFlags::empty();
let (recv_len, _) = self.recvfrom(buf, flags)?;
Ok(recv_len)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
// FIXME: set correct flags
let flags = SendRecvFlags::empty();
self.sendto(buf, None, flags)
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let state = self.state.read();
match &*state {
State::Init(init) => init.poll(mask, poller),
State::Connecting(connecting) => connecting.poll(mask, poller),
State::Connected(connected) => connected.poll(mask, poller),
State::Listen(listen) => listen.poll(mask, poller),
}
}
fn status_flags(&self) -> StatusFlags {
if self.is_nonblocking() {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
}
}
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
if new_flags.contains(StatusFlags::O_NONBLOCK) {
self.set_nonblocking(true);
} else {
self.set_nonblocking(false);
}
Ok(())
}
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
Some(self)
}
}
impl Socket for StreamSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
let state = self.state.read();
match &*state {
State::Init(init_stream) => init_stream.bind(endpoint),
_ => return_errno_with_message!(Errno::EINVAL, "cannot bind"),
}
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let remote_endpoint = socket_addr.try_into()?;
let connecting_stream = self.do_connect(&remote_endpoint)?;
match connecting_stream.wait_conn() {
Ok(connected_stream) => {
*self.state.write() = State::Connected(connected_stream);
Ok(())
}
Err((err, init_stream)) => {
*self.state.write() = State::Init(init_stream);
Err(err)
}
}
}
fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write();
let init_stream = match &*state {
State::Init(init_stream) => init_stream,
State::Connecting(connecting_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream")
}
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream")
}
State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
};
let listener = init_stream.listen(backlog)?;
*state = State::Listen(listener);
Ok(())
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen_stream = match &*self.state.read() {
State::Listen(listen_stream) => listen_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
};
let (connected_stream, remote_endpoint) = {
let listen_stream = listen_stream.clone();
listen_stream.accept()?
};
let accepted_socket = {
let state = RwLock::new(State::Connected(connected_stream));
Arc::new(StreamSocket {
options: RwLock::new(OptionSet::new()),
state,
})
};
let socket_addr = remote_endpoint.try_into()?;
Ok((accepted_socket, socket_addr))
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let state = self.state.read();
match &*state {
State::Connected(connected_stream) => connected_stream.shutdown(cmd),
// TDOD: shutdown listening stream
_ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"),
}
}
fn addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let local_endpoint = match &*state {
State::Init(init_stream) => init_stream.local_endpoint(),
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
State::Listen(listen_stream) => listen_stream.local_endpoint(),
State::Connected(connected_stream) => connected_stream.local_endpoint(),
}?;
local_endpoint.try_into()
}
fn peer_addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let remote_endpoint = match &*state {
State::Init(init_stream) => {
return_errno_with_message!(Errno::EINVAL, "init socket does not have peer")
}
State::Connecting(connecting_stream) => connecting_stream.remote_endpoint(),
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
}
State::Connected(connected_stream) => connected_stream.remote_endpoint(),
}?;
remote_endpoint.try_into()
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let connected_stream = match &*self.state.read() {
State::Connected(connected_stream) => connected_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
};
let (recv_size, remote_endpoint) = connected_stream.recvfrom(buf, flags)?;
let socket_addr = remote_endpoint.try_into()?;
Ok((recv_size, socket_addr))
}
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(remote.is_none());
if remote.is_some() {
return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr");
}
let connected_stream = match &*self.state.read() {
State::Connected(connected_stream) => connected_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
};
connected_stream.sendto(buf, flags)
}
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
let options = self.options.read();
match_sock_option_mut!(option, {
// Socket Options
socket_errors: Error => {
let sock_errors = options.socket.sock_errors();
socket_errors.set(sock_errors);
},
socket_reuse_addr: ReuseAddr => {
let reuse_addr = options.socket.reuse_addr();
socket_reuse_addr.set(reuse_addr);
},
socket_send_buf: SendBuf => {
let send_buf = options.socket.send_buf();
socket_send_buf.set(send_buf);
},
socket_recv_buf: RecvBuf => {
let recv_buf = options.socket.recv_buf();
socket_recv_buf.set(recv_buf);
},
socket_reuse_port: ReusePort => {
let reuse_port = options.socket.reuse_port();
socket_reuse_port.set(reuse_port);
},
// Tcp Options
tcp_no_delay: NoDelay => {
let no_delay = options.tcp.no_delay();
tcp_no_delay.set(no_delay);
},
tcp_congestion: Congestion => {
let congestion = options.tcp.congestion();
tcp_congestion.set(congestion);
},
tcp_maxseg: MaxSegment => {
// It will always return the default MSS value defined above for an unconnected socket
// and always return the actual current MSS for a connected one.
// FIXME: how to get the current MSS?
let maxseg = match &*self.state.read() {
State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG,
State::Connected(_) => options.tcp.maxseg(),
};
tcp_maxseg.set(maxseg);
},
tcp_window_clamp: WindowClamp => {
let window_clamp = options.tcp.window_clamp();
tcp_window_clamp.set(window_clamp);
},
_ => return_errno_with_message!(Errno::ENOPROTOOPT, "get unknown option")
});
Ok(())
}
fn set_option(&self, option: &dyn SocketOption) -> Result<()> {
let mut options = self.options.write();
// FIXME: here we have only set the value of the option, without actually
// making any real modifications.
match_sock_option_ref!(option, {
// Socket options
socket_recv_buf: RecvBuf => {
let recv_buf = socket_recv_buf.get().unwrap();
if *recv_buf <= MIN_RECVBUF {
options.socket.set_recv_buf(MIN_RECVBUF);
} else{
options.socket.set_recv_buf(*recv_buf);
}
},
socket_send_buf: SendBuf => {
let send_buf = socket_send_buf.get().unwrap();
if *send_buf <= MIN_SENDBUF {
options.socket.set_send_buf(MIN_SENDBUF);
} else {
options.socket.set_send_buf(*send_buf);
}
},
socket_reuse_addr: ReuseAddr => {
let reuse_addr = socket_reuse_addr.get().unwrap();
options.socket.set_reuse_addr(*reuse_addr);
},
socket_reuse_port: ReusePort => {
let reuse_port = socket_reuse_port.get().unwrap();
options.socket.set_reuse_port(*reuse_port);
},
socket_linger: Linger => {
let linger = socket_linger.get().unwrap();
options.socket.set_linger(*linger);
},
// Tcp options
tcp_no_delay: NoDelay => {
let no_delay = tcp_no_delay.get().unwrap();
options.tcp.set_no_delay(*no_delay);
},
tcp_congestion: Congestion => {
let congestion = tcp_congestion.get().unwrap();
options.tcp.set_congestion(*congestion);
},
tcp_maxseg: MaxSegment => {
const MIN_MAXSEG: u32 = 536;
const MAX_MAXSEG: u32 = 65535;
let maxseg = tcp_maxseg.get().unwrap();
if *maxseg < MIN_MAXSEG || *maxseg > MAX_MAXSEG {
return_errno_with_message!(Errno::EINVAL, "New maxseg should be in allowed range.");
}
options.tcp.set_maxseg(*maxseg);
},
tcp_window_clamp: WindowClamp => {
let window_clamp = tcp_window_clamp.get().unwrap();
let half_recv_buf = (options.socket.recv_buf()) / 2;
if *window_clamp <= half_recv_buf {
options.tcp.set_window_clamp(half_recv_buf);
} else {
options.tcp.set_window_clamp(*window_clamp);
}
},
_ => return_errno_with_message!(Errno::ENOPROTOOPT, "set unknown option")
});
Ok(())
}
}

View File

@ -0,0 +1,11 @@
// SPDX-License-Identifier: MPL-2.0
use super::CongestionControl;
use crate::impl_socket_options;
impl_socket_options!(
pub struct NoDelay(bool);
pub struct Congestion(CongestionControl);
pub struct MaxSegment(u32);
pub struct WindowClamp(u32);
);

View File

@ -0,0 +1,61 @@
// SPDX-License-Identifier: MPL-2.0
use crate::prelude::*;
#[derive(Debug, Clone, Copy, CopyGetters, Setters)]
#[get_copy = "pub"]
#[set = "pub"]
pub struct TcpOptionSet {
no_delay: bool,
congestion: CongestionControl,
maxseg: u32,
window_clamp: u32,
}
pub const DEFAULT_MAXSEG: u32 = 536;
pub const DEFAULT_WINDOW_CLAMP: u32 = 0x8000_0000;
impl TcpOptionSet {
pub fn new() -> Self {
Self {
no_delay: false,
congestion: CongestionControl::Reno,
maxseg: DEFAULT_MAXSEG,
window_clamp: DEFAULT_WINDOW_CLAMP,
}
}
}
impl Default for TcpOptionSet {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone, Copy)]
pub enum CongestionControl {
Reno,
Cubic,
}
impl CongestionControl {
const RENO: &'static str = "reno";
const CUBIC: &'static str = "cubic";
pub fn new(name: &str) -> Result<Self> {
let congestion = match name {
Self::RENO => Self::Reno,
Self::CUBIC => Self::Cubic,
_ => return_errno_with_message!(Errno::EINVAL, "unsupported congestion name"),
};
Ok(congestion)
}
pub fn name(&self) -> &'static str {
match self {
Self::Reno => Self::RENO,
Self::Cubic => Self::CUBIC,
}
}
}

View File

@ -0,0 +1,77 @@
// SPDX-License-Identifier: MPL-2.0
use self::options::SocketOption;
pub use self::util::{
options::LingerOption, send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd,
socket_addr::SocketAddr,
};
use crate::{fs::file_handle::FileLike, prelude::*};
pub mod ip;
pub mod options;
pub mod unix;
mod util;
/// Operations defined on a socket.
pub trait Socket: FileLike + Send + Sync {
/// Assign the address specified by socket_addr to the socket
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "bind not implemented");
}
/// Build connection for a given address
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "connect not implemented");
}
/// Listen for connections on a socket
fn listen(&self, backlog: usize) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "connect not implemented");
}
/// Accept a connection on a socket
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
return_errno_with_message!(Errno::EINVAL, "accept not implemented");
}
/// Shut down part of a full-duplex connection
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "shutdown not implemented");
}
/// Get address of this socket.
fn addr(&self) -> Result<SocketAddr> {
return_errno_with_message!(Errno::EINVAL, "getsockname not implemented");
}
/// Get address of peer socket
fn peer_addr(&self) -> Result<SocketAddr> {
return_errno_with_message!(Errno::EINVAL, "getpeername not implemented");
}
/// Get options on the socket. The resulted option will put in the `option` parameter, if
/// this method returns success.
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "getsockopt not implemented");
}
/// Set options on the socket.
fn set_option(&self, option: &dyn SocketOption) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "setsockopt not implemented");
}
/// Receive a message from a socket
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
return_errno_with_message!(Errno::EINVAL, "recvfrom not implemented");
}
/// Send a message on a socket
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
return_errno_with_message!(Errno::EINVAL, "recvfrom not implemented");
}
}

View File

@ -0,0 +1,85 @@
// SPDX-License-Identifier: MPL-2.0
#[macro_export]
macro_rules! impl_socket_options {
($(
$(#[$outer:meta])*
pub struct $name: ident ( $value_ty:ty );
)*) => {
$(
$(#[$outer])*
#[derive(Debug)]
pub struct $name (Option<$value_ty>);
impl $name {
pub fn new() -> Self {
Self (None)
}
pub fn get(&self) -> Option<&$value_ty> {
self.0.as_ref()
}
pub fn set(&mut self, value: $value_ty) {
self.0 = Some(value);
}
}
impl $crate::net::socket::SocketOption for $name {
fn as_any(&self) -> &dyn core::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
self
}
}
impl Default for $name {
fn default() -> Self {
Self::new()
}
}
)*
};
}
#[macro_export]
macro_rules! match_sock_option_ref {
(
$option:expr, {
$( $bind: ident : $ty:ty => $arm:expr ),*,
_ => $default:expr
}
) => {{
let __option : &dyn SocketOption = $option;
$(
if let Some($bind) = __option.as_any().downcast_ref::<$ty>() {
$arm
} else
)*
{
$default
}
}};
}
#[macro_export]
macro_rules! match_sock_option_mut {
(
$option:expr, {
$( $bind: ident : $ty:ty => $arm:expr ),*,
_ => $default:expr
}
) => {{
let __option : &mut dyn SocketOption = $option;
$(
if let Some($bind) = __option.as_any_mut().downcast_mut::<$ty>() {
$arm
} else
)*
{
$default
}
}};
}

View File

@ -0,0 +1,22 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{impl_socket_options, prelude::*};
mod macros;
use super::LingerOption;
/// Socket options. This trait represents all options that can be set or got for a socket, including
/// socket level options and options for specific socket type like tcp socket.
pub trait SocketOption: Any + Send + Sync + Debug {
fn as_any(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
}
impl_socket_options!(
pub struct ReuseAddr(bool);
pub struct ReusePort(bool);
pub struct SendBuf(u32);
pub struct RecvBuf(u32);
pub struct Error(Option<crate::error::Error>);
pub struct Linger(LingerOption);
);

View File

@ -0,0 +1,55 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{fs::utils::Dentry, net::socket::util::socket_addr::SocketAddr, prelude::*};
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum UnixSocketAddr {
Path(String),
Abstract(String),
}
#[derive(Clone)]
pub(super) enum UnixSocketAddrBound {
Path(Arc<Dentry>),
Abstract(String),
}
impl PartialEq for UnixSocketAddrBound {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Abstract(l0), Self::Abstract(r0)) => l0 == r0,
(Self::Path(l0), Self::Path(r0)) => Arc::ptr_eq(l0.inode(), r0.inode()),
_ => false,
}
}
}
impl TryFrom<SocketAddr> for UnixSocketAddr {
type Error = Error;
fn try_from(value: SocketAddr) -> Result<Self> {
match value {
SocketAddr::Unix(unix_socket_addr) => Ok(unix_socket_addr),
_ => return_errno_with_message!(Errno::EINVAL, "Invalid unix socket addr"),
}
}
}
impl From<UnixSocketAddrBound> for UnixSocketAddr {
fn from(value: UnixSocketAddrBound) -> Self {
match value {
UnixSocketAddrBound::Path(dentry) => {
let abs_path = dentry.abs_path();
Self::Path(abs_path)
}
UnixSocketAddrBound::Abstract(name) => Self::Abstract(name),
}
}
}
impl From<UnixSocketAddrBound> for SocketAddr {
fn from(value: UnixSocketAddrBound) -> Self {
let unix_socket_addr = UnixSocketAddr::from(value);
SocketAddr::Unix(unix_socket_addr)
}
}

View File

@ -0,0 +1,7 @@
// SPDX-License-Identifier: MPL-2.0
mod addr;
mod stream;
pub use addr::UnixSocketAddr;
pub use stream::UnixStreamSocket;

View File

@ -0,0 +1,55 @@
// SPDX-License-Identifier: MPL-2.0
use super::endpoint::Endpoint;
use crate::{
events::IoEvents,
net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd},
prelude::*,
process::signal::Poller,
};
pub(super) struct Connected {
local_endpoint: Arc<Endpoint>,
}
impl Connected {
pub(super) fn new(local_endpoint: Arc<Endpoint>) -> Self {
Connected { local_endpoint }
}
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
self.local_endpoint.addr()
}
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
self.local_endpoint.peer_addr()
}
pub(super) fn is_bound(&self) -> bool {
self.addr().is_some()
}
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
self.local_endpoint.write(buf)
}
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.local_endpoint.read(buf)
}
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
self.local_endpoint.shutdown(cmd)
}
pub(super) fn is_nonblocking(&self) -> bool {
self.local_endpoint.is_nonblocking()
}
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
self.local_endpoint.set_nonblocking(is_nonblocking).unwrap();
}
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.local_endpoint.poll(mask, poller)
}
}

View File

@ -0,0 +1,127 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
events::IoEvents,
fs::utils::{Channel, Consumer, Producer, StatusFlags},
net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd},
prelude::*,
process::signal::Poller,
};
pub(super) struct Endpoint(Inner);
struct Inner {
addr: RwLock<Option<UnixSocketAddrBound>>,
reader: Consumer<u8>,
writer: Producer<u8>,
peer: Weak<Endpoint>,
}
impl Endpoint {
pub(super) fn new_pair(is_nonblocking: bool) -> Result<(Arc<Endpoint>, Arc<Endpoint>)> {
let flags = if is_nonblocking {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
};
let (writer_a, reader_b) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let (writer_b, reader_a) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let mut endpoint_b = None;
let endpoint_a = Arc::new_cyclic(|endpoint_a_ref| {
let peer = Arc::new(Endpoint::new(reader_b, writer_b, endpoint_a_ref.clone()));
let endpoint_a = Endpoint::new(reader_a, writer_a, Arc::downgrade(&peer));
endpoint_b = Some(peer);
endpoint_a
});
Ok((endpoint_a, endpoint_b.unwrap()))
}
fn new(reader: Consumer<u8>, writer: Producer<u8>, peer: Weak<Endpoint>) -> Self {
Self(Inner {
addr: RwLock::new(None),
reader,
writer,
peer,
})
}
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
self.0.addr.read().clone()
}
pub(super) fn set_addr(&self, addr: UnixSocketAddrBound) {
*self.0.addr.write() = Some(addr);
}
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
self.0.peer.upgrade().and_then(|peer| peer.addr())
}
pub(super) fn is_nonblocking(&self) -> bool {
let reader_status = self.0.reader.is_nonblocking();
let writer_status = self.0.writer.is_nonblocking();
debug_assert!(reader_status == writer_status);
reader_status
}
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> {
let reader_flags = self.0.reader.status_flags();
self.0
.reader
.set_status_flags(reader_flags | StatusFlags::O_NONBLOCK)?;
let writer_flags = self.0.writer.status_flags();
self.0
.writer
.set_status_flags(writer_flags | StatusFlags::O_NONBLOCK)?;
Ok(())
}
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.0.reader.read(buf)
}
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
self.0.writer.write(buf)
}
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
if !self.is_connected() {
return_errno_with_message!(Errno::ENOTCONN, "The socket is not connected.");
}
if cmd.shut_read() {
self.0.reader.shutdown();
}
if cmd.shut_write() {
self.0.writer.shutdown();
}
Ok(())
}
pub(super) fn is_connected(&self) -> bool {
self.0.peer.upgrade().is_some()
}
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let mut events = IoEvents::empty();
// FIXME: should reader and writer use the same mask?
let reader_events = self.0.reader.poll(mask, poller);
let writer_events = self.0.writer.poll(mask, poller);
if reader_events.contains(IoEvents::HUP) || self.0.reader.is_shutdown() {
events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) || self.0.writer.is_shutdown() {
events |= IoEvents::HUP | IoEvents::OUT;
}
}
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events
}
}
const DAFAULT_BUF_SIZE: usize = 4096;

View File

@ -0,0 +1,104 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::{connected::Connected, endpoint::Endpoint, listener::push_incoming};
use crate::{
events::IoEvents,
fs::{
fs_resolver::{split_path, FsPath},
utils::{Dentry, InodeMode, InodeType},
},
net::socket::unix::addr::{UnixSocketAddr, UnixSocketAddrBound},
prelude::*,
process::signal::{Pollee, Poller},
};
pub(super) struct Init {
is_nonblocking: AtomicBool,
addr: Mutex<Option<UnixSocketAddrBound>>,
pollee: Pollee,
}
impl Init {
pub(super) fn new(is_nonblocking: bool) -> Self {
Self {
is_nonblocking: AtomicBool::new(is_nonblocking),
addr: Mutex::new(None),
pollee: Pollee::new(IoEvents::empty()),
}
}
pub(super) fn bind(&self, addr_to_bind: &UnixSocketAddr) -> Result<()> {
let mut addr = self.addr.lock();
if addr.is_some() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
}
let bound_addr = match addr_to_bind {
UnixSocketAddr::Abstract(_) => todo!(),
UnixSocketAddr::Path(path) => {
let dentry = create_socket_file(path)?;
UnixSocketAddrBound::Path(dentry)
}
};
*addr = Some(bound_addr);
Ok(())
}
pub(super) fn connect(&self, remote_addr: &UnixSocketAddrBound) -> Result<Connected> {
let addr = self.addr();
if let Some(ref addr) = addr {
if *addr == *remote_addr {
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
}
}
let (this_end, remote_end) = Endpoint::new_pair(self.is_nonblocking())?;
remote_end.set_addr(remote_addr.clone());
if let Some(addr) = addr {
this_end.set_addr(addr.clone());
};
push_incoming(remote_addr, remote_end)?;
Ok(Connected::new(this_end))
}
pub(super) fn is_bound(&self) -> bool {
self.addr.lock().is_some()
}
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
self.addr.lock().clone()
}
pub(super) fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Acquire)
}
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
self.is_nonblocking.store(is_nonblocking, Ordering::Release);
}
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
}
fn create_socket_file(path: &str) -> Result<Arc<Dentry>> {
let (parent_pathname, file_name) = split_path(path);
let parent = {
let current = current!();
let fs = current.fs().read();
let parent_path = FsPath::try_from(parent_pathname)?;
fs.lookup(&parent_path)?
};
let dentry = parent.create(
file_name,
InodeType::Socket,
InodeMode::S_IRUSR | InodeMode::S_IWUSR,
)?;
Ok(dentry)
}

View File

@ -0,0 +1,229 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use keyable_arc::KeyableWeak;
use super::{connected::Connected, endpoint::Endpoint, UnixStreamSocket};
use crate::{
events::IoEvents,
fs::{
file_handle::FileLike,
utils::{Dentry, Inode},
},
net::socket::{
unix::addr::{UnixSocketAddr, UnixSocketAddrBound},
SocketAddr,
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub(super) struct Listener {
addr: UnixSocketAddrBound,
is_nonblocking: AtomicBool,
}
impl Listener {
pub(super) fn new(
addr: UnixSocketAddrBound,
backlog: usize,
nonblocking: bool,
) -> Result<Self> {
BACKLOG_TABLE.add_backlog(&addr, backlog)?;
Ok(Self {
addr,
is_nonblocking: AtomicBool::new(nonblocking),
})
}
pub(super) fn addr(&self) -> &UnixSocketAddrBound {
&self.addr
}
pub(super) fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Acquire)
}
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
self.is_nonblocking.store(is_nonblocking, Ordering::Release);
}
pub(super) fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let addr = self.addr().clone();
let is_nonblocking = self.is_nonblocking();
let connected = {
let local_endpoint = BACKLOG_TABLE.pop_incoming(is_nonblocking, &addr)?;
Connected::new(local_endpoint)
};
let peer_addr = match connected.peer_addr() {
None => SocketAddr::Unix(UnixSocketAddr::Path(String::new())),
Some(addr) => SocketAddr::from(addr.clone()),
};
let socket = Arc::new(UnixStreamSocket::new_connected(connected));
Ok((socket, peer_addr))
}
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let addr = self.addr();
let backlog = BACKLOG_TABLE.get_backlog(addr).unwrap();
backlog.poll(mask, poller)
}
}
static BACKLOG_TABLE: BacklogTable = BacklogTable::new();
struct BacklogTable {
backlog_sockets: RwLock<BTreeMap<KeyableWeak<dyn Inode>, Arc<Backlog>>>,
// TODO: For linux, there is also abstract socket domain that a socket addr is not bound to an inode.
}
impl BacklogTable {
const fn new() -> Self {
Self {
backlog_sockets: RwLock::new(BTreeMap::new()),
}
}
fn add_backlog(&self, addr: &UnixSocketAddrBound, backlog: usize) -> Result<()> {
let inode = {
let UnixSocketAddrBound::Path(dentry) = addr else {
todo!()
};
create_keyable_inode(dentry)
};
let mut backlog_sockets = self.backlog_sockets.write();
if backlog_sockets.contains_key(&inode) {
return_errno_with_message!(Errno::EADDRINUSE, "the addr is already used");
}
let new_backlog = Arc::new(Backlog::new(backlog));
backlog_sockets.insert(inode, new_backlog);
Ok(())
}
fn get_backlog(&self, addr: &UnixSocketAddrBound) -> Result<Arc<Backlog>> {
let inode = {
let UnixSocketAddrBound::Path(dentry) = addr else {
todo!()
};
create_keyable_inode(dentry)
};
let backlog_sockets = self.backlog_sockets.read();
backlog_sockets
.get(&inode)
.map(Arc::clone)
.ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened"))
}
fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Arc<Endpoint>> {
let poller = Poller::new();
loop {
let backlog = self.get_backlog(addr)?;
if let Some(endpoint) = backlog.pop_incoming() {
return Ok(endpoint);
}
if nonblocking {
return_errno_with_message!(Errno::EAGAIN, "no connection comes");
}
let events = {
let mask = IoEvents::IN;
backlog.poll(mask, Some(&poller))
};
if events.contains(IoEvents::ERR) | events.contains(IoEvents::HUP) {
return_errno_with_message!(Errno::ECONNABORTED, "connection is aborted");
}
// FIXME: deal with accept timeout
if events.is_empty() {
poller.wait()?;
}
}
}
fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Arc<Endpoint>) -> Result<()> {
let backlog = self.get_backlog(addr).map_err(|_| {
Error::with_message(
Errno::ECONNREFUSED,
"no socket is listened at the remote address",
)
})?;
backlog.push_incoming(endpoint)
}
fn remove_backlog(&self, addr: &UnixSocketAddrBound) {
let UnixSocketAddrBound::Path(dentry) = addr else {
todo!()
};
let inode = create_keyable_inode(dentry);
self.backlog_sockets.write().remove(&inode);
}
}
struct Backlog {
pollee: Pollee,
backlog: usize,
incoming_endpoints: Mutex<VecDeque<Arc<Endpoint>>>,
}
impl Backlog {
fn new(backlog: usize) -> Self {
Self {
pollee: Pollee::new(IoEvents::empty()),
backlog,
incoming_endpoints: Mutex::new(VecDeque::with_capacity(backlog)),
}
}
fn push_incoming(&self, endpoint: Arc<Endpoint>) -> Result<()> {
let mut endpoints = self.incoming_endpoints.lock();
if endpoints.len() >= self.backlog {
return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full");
}
endpoints.push_back(endpoint);
self.pollee.add_events(IoEvents::IN);
Ok(())
}
fn pop_incoming(&self) -> Option<Arc<Endpoint>> {
let mut incoming_endpoints = self.incoming_endpoints.lock();
let endpoint = incoming_endpoints.pop_front();
if incoming_endpoints.is_empty() {
self.pollee.del_events(IoEvents::IN);
}
endpoint
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
// Lock to avoid any events may change pollee state when we poll
let _lock = self.incoming_endpoints.lock();
self.pollee.poll(mask, poller)
}
}
fn create_keyable_inode(dentry: &Arc<Dentry>) -> KeyableWeak<dyn Inode> {
let weak_inode = Arc::downgrade(dentry.inode());
KeyableWeak::from(weak_inode)
}
pub(super) fn unregister_backlog(addr: &UnixSocketAddrBound) {
BACKLOG_TABLE.remove_backlog(addr);
}
pub(super) fn push_incoming(
remote_addr: &UnixSocketAddrBound,
remote_end: Arc<Endpoint>,
) -> Result<()> {
BACKLOG_TABLE.push_incoming(remote_addr, remote_end)
}

View File

@ -0,0 +1,9 @@
// SPDX-License-Identifier: MPL-2.0
mod connected;
mod endpoint;
mod init;
mod listener;
mod socket;
pub use socket::UnixStreamSocket;

View File

@ -0,0 +1,306 @@
// SPDX-License-Identifier: MPL-2.0
use super::{
connected::Connected,
endpoint::Endpoint,
init::Init,
listener::{unregister_backlog, Listener},
};
use crate::{
events::IoEvents,
fs::{
file_handle::FileLike,
fs_resolver::FsPath,
utils::{Dentry, InodeType, StatusFlags},
},
net::socket::{
unix::{addr::UnixSocketAddrBound, UnixSocketAddr},
util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr},
SockShutdownCmd, Socket,
},
prelude::*,
process::signal::Poller,
};
pub struct UnixStreamSocket(RwLock<State>);
impl UnixStreamSocket {
pub(super) fn new_init(init: Init) -> Self {
Self(RwLock::new(State::Init(Arc::new(init))))
}
pub(super) fn new_listen(listen: Listener) -> Self {
Self(RwLock::new(State::Listen(Arc::new(listen))))
}
pub(super) fn new_connected(connected: Connected) -> Self {
Self(RwLock::new(State::Connected(Arc::new(connected))))
}
}
enum State {
Init(Arc<Init>),
Listen(Arc<Listener>),
Connected(Arc<Connected>),
}
impl UnixStreamSocket {
pub fn new(nonblocking: bool) -> Self {
let init = Init::new(nonblocking);
Self::new_init(init)
}
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
let (end_a, end_b) = Endpoint::new_pair(nonblocking)?;
let connected_a = {
let connected = Connected::new(end_a);
Self::new_connected(connected)
};
let connected_b = {
let connected = Connected::new(end_b);
Self::new_connected(connected)
};
Ok((Arc::new(connected_a), Arc::new(connected_b)))
}
fn bound_addr(&self) -> Option<UnixSocketAddrBound> {
let status = self.0.read();
match &*status {
State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(),
}
}
fn mask_flags(status_flags: &StatusFlags) -> StatusFlags {
const SUPPORTED_FLAGS: StatusFlags = StatusFlags::O_NONBLOCK;
const UNSUPPORTED_FLAGS: StatusFlags = SUPPORTED_FLAGS.complement();
if status_flags.intersects(UNSUPPORTED_FLAGS) {
warn!("ignore unsupported flags");
}
status_flags.intersection(SUPPORTED_FLAGS)
}
}
impl FileLike for UnixStreamSocket {
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
Some(self)
}
fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.recvfrom(buf, SendRecvFlags::empty())
.map(|(read_size, _)| read_size)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
self.sendto(buf, None, SendRecvFlags::empty())
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let inner = self.0.read();
match &*inner {
State::Init(init) => init.poll(mask, poller),
State::Listen(listen) => listen.poll(mask, poller),
State::Connected(connected) => connected.poll(mask, poller),
}
}
fn status_flags(&self) -> StatusFlags {
let inner = self.0.read();
let is_nonblocking = match &*inner {
State::Init(init) => init.is_nonblocking(),
State::Listen(listen) => listen.is_nonblocking(),
State::Connected(connected) => connected.is_nonblocking(),
};
if is_nonblocking {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
}
}
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
let is_nonblocking = {
let supported_flags = Self::mask_flags(&new_flags);
supported_flags.contains(StatusFlags::O_NONBLOCK)
};
let mut inner = self.0.write();
match &mut *inner {
State::Init(init) => init.set_nonblocking(is_nonblocking),
State::Listen(listen) => listen.set_nonblocking(is_nonblocking),
State::Connected(connected) => connected.set_nonblocking(is_nonblocking),
}
Ok(())
}
}
impl Socket for UnixStreamSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let addr = UnixSocketAddr::try_from(socket_addr)?;
let init = match &*self.0.read() {
State::Init(init) => init.clone(),
_ => return_errno_with_message!(
Errno::EINVAL,
"cannot bind a listening or connected socket"
),
// FIXME: Maybe binding a connected socket should also be allowed?
};
init.bind(&addr)
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let remote_addr = {
let unix_socket_addr = UnixSocketAddr::try_from(socket_addr)?;
match unix_socket_addr {
UnixSocketAddr::Abstract(abstract_name) => {
UnixSocketAddrBound::Abstract(abstract_name)
}
UnixSocketAddr::Path(path) => {
let dentry = lookup_socket_file(&path)?;
UnixSocketAddrBound::Path(dentry)
}
}
};
let init = match &*self.0.read() {
State::Init(init) => init.clone(),
State::Listen(_) => return_errno_with_message!(Errno::EINVAL, "the socket is listened"),
State::Connected(_) => {
return_errno_with_message!(Errno::EISCONN, "the socket is connected")
}
};
let connected = init.connect(&remote_addr)?;
*self.0.write() = State::Connected(Arc::new(connected));
Ok(())
}
fn listen(&self, backlog: usize) -> Result<()> {
let init = match &*self.0.read() {
State::Init(init) => init.clone(),
State::Listen(_) => {
return_errno_with_message!(Errno::EINVAL, "the socket is already listening")
}
State::Connected(_) => {
return_errno_with_message!(Errno::EISCONN, "the socket is already connected")
}
};
let addr = init.addr().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
let listener = Listener::new(addr.clone(), backlog, init.is_nonblocking())?;
*self.0.write() = State::Listen(Arc::new(listener));
Ok(())
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen = match &*self.0.read() {
State::Listen(listen) => listen.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
};
listen.accept()
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socked is not connected"),
};
connected.shutdown(cmd)
}
fn addr(&self) -> Result<SocketAddr> {
let addr = match &*self.0.read() {
State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(),
};
addr.map(Into::<SocketAddr>::into)
.ok_or(Error::with_message(
Errno::EINVAL,
"the socket does not bind to addr",
))
}
fn peer_addr(&self) -> Result<SocketAddr> {
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
};
match connected.peer_addr() {
None => Ok(SocketAddr::Unix(UnixSocketAddr::Path(String::new()))),
Some(peer_addr) => Ok(SocketAddr::from(peer_addr.clone())),
}
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
};
let peer_addr = self.peer_addr()?;
let read_size = connected.read(buf)?;
Ok((read_size, peer_addr))
}
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(remote.is_none());
// TODO: deal with flags
let connected = match &*self.0.read() {
State::Connected(connected) => connected.clone(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"),
};
connected.write(buf)
}
}
impl Drop for UnixStreamSocket {
fn drop(&mut self) {
let Some(bound_addr) = self.bound_addr() else {
return;
};
if let State::Listen(_) = &*self.0.read() {
unregister_backlog(&bound_addr);
}
}
}
fn lookup_socket_file(path: &str) -> Result<Arc<Dentry>> {
let dentry = {
let current = current!();
let fs = current.fs().read();
let fs_path = FsPath::try_from(path)?;
fs.lookup(&fs_path)?
};
if dentry.type_() != InodeType::Socket {
return_errno_with_message!(Errno::ENOTSOCK, "not a socket file")
}
if !dentry.mode()?.is_readable() || !dentry.mode()?.is_writable() {
return_errno_with_message!(Errno::EACCES, "the socket cannot be read or written")
}
Ok(dentry)
}

View File

@ -0,0 +1,6 @@
// SPDX-License-Identifier: MPL-2.0
pub mod options;
pub mod send_recv_flags;
pub mod shutdown_cmd;
pub mod socket_addr;

View File

@ -0,0 +1,57 @@
// SPDX-License-Identifier: MPL-2.0
use core::time::Duration;
use crate::{
net::iface::{RECV_BUF_LEN, SEND_BUF_LEN},
prelude::*,
};
#[derive(Debug, Clone, CopyGetters, Setters)]
#[get_copy = "pub"]
#[set = "pub"]
pub struct SocketOptionSet {
sock_errors: Option<Error>,
reuse_addr: bool,
reuse_port: bool,
send_buf: u32,
recv_buf: u32,
linger: LingerOption,
}
impl SocketOptionSet {
/// Return the default socket level options for tcp socket.
pub fn new_tcp() -> Self {
Self {
sock_errors: None,
reuse_addr: false,
reuse_port: false,
send_buf: SEND_BUF_LEN as u32,
recv_buf: RECV_BUF_LEN as u32,
linger: LingerOption::default(),
}
}
}
pub const MIN_SENDBUF: u32 = 2304;
pub const MIN_RECVBUF: u32 = 2304;
#[derive(Debug, Default, Clone, Copy)]
pub struct LingerOption {
is_on: bool,
timeout: Duration,
}
impl LingerOption {
pub fn new(is_on: bool, timeout: Duration) -> Self {
Self { is_on, timeout }
}
pub fn is_on(&self) -> bool {
self.is_on
}
pub fn timeout(&self) -> Duration {
self.timeout
}
}

View File

@ -0,0 +1,47 @@
// SPDX-License-Identifier: MPL-2.0
use crate::prelude::*;
bitflags! {
/// Flags used for send/recv.
/// The definiton is from https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/socket.h
#[repr(C)]
#[derive(Pod)]
pub struct SendRecvFlags: i32 {
const MSG_OOB = 1;
const MSG_PEEK = 2;
const MSG_DONTROUTE = 4;
// const MSG_TRYHARD = 4; /* Synonym for MSG_DONTROUTE for DECnet */
const MSG_CTRUNC = 8;
const MSG_PROBE = 0x10; /* Do not send. Only probe path f.e. for MTU */
const MSG_TRUNC = 0x20;
const MSG_DONTWAIT = 0x40; /* Nonblocking io */
const MSG_EOR = 0x80; /* End of record */
const MSG_WAITALL = 0x100; /* Wait for a full request */
const MSG_FIN = 0x200;
const MSG_SYN = 0x400;
const MSG_CONFIRM = 0x800; /* Confirm path validity */
const MSG_RST = 0x1000;
const MSG_ERRQUEUE = 0x2000; /* Fetch message from error queue */
const MSG_NOSIGNAL = 0x4000; /* Do not generate SIGPIPE */
const MSG_MORE = 0x8000; /* Sender will send more */
const MSG_WAITFORONE = 0x10000; /* recvmmsg(): block until 1+ packets avail */
const MSG_SENDPAGE_NOPOLICY = 0x10000; /* sendpage() internal : do no apply policy */
const MSG_SENDPAGE_NOTLAST = 0x20000; /* sendpage() internal : not the last page */
const MSG_BATCH = 0x40000; /* sendmmsg(): more messages coming */
// const MSG_EOF MSG_FIN
const MSG_NO_SHARED_FRAGS = 0x80000; /* sendpage() internal : page frags are not shared */
const MSG_SENDPAGE_DECRYPTED = 0x100000; /* sendpage() internal : page may carry plain text and require encryption */
}
}
impl SendRecvFlags {
fn supported_flags() -> Self {
SendRecvFlags::empty()
}
pub fn is_all_supported(&self) -> bool {
let supported_flags = Self::supported_flags();
supported_flags.contains(*self)
}
}

View File

@ -0,0 +1,27 @@
// SPDX-License-Identifier: MPL-2.0
use crate::prelude::*;
/// Shutdown types
/// From https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/net.h
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromInt)]
#[allow(non_camel_case_types)]
pub enum SockShutdownCmd {
/// Shutdown receptions
SHUT_RD = 0,
/// Shutdown transmissions
SHUT_WR = 1,
/// Shutdown receptions and transmissions
SHUT_RDWR = 2,
}
impl SockShutdownCmd {
pub fn shut_read(&self) -> bool {
*self == Self::SHUT_RD || *self == Self::SHUT_RDWR
}
pub fn shut_write(&self) -> bool {
*self == Self::SHUT_WR || *self == Self::SHUT_RDWR
}
}

View File

@ -0,0 +1,57 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
net::{
iface::{IpAddress, IpEndpoint, IpListenEndpoint, Ipv4Address},
socket::unix::UnixSocketAddr,
},
prelude::*,
};
type PortNum = u16;
#[derive(Debug)]
pub enum SocketAddr {
Unix(UnixSocketAddr),
IPv4(Ipv4Address, PortNum),
IPv6,
}
impl TryFrom<SocketAddr> for IpEndpoint {
type Error = Error;
fn try_from(value: SocketAddr) -> Result<Self> {
match value {
SocketAddr::IPv4(addr, port) => Ok(IpEndpoint::new(addr.into_address(), port)),
_ => return_errno_with_message!(
Errno::EINVAL,
"sock addr cannot be converted as IpEndpoint"
),
}
}
}
impl TryFrom<IpEndpoint> for SocketAddr {
type Error = Error;
fn try_from(endpoint: IpEndpoint) -> Result<Self> {
let port = endpoint.port;
let socket_addr = match endpoint.addr {
IpAddress::Ipv4(addr) => SocketAddr::IPv4(addr, port), // TODO: support IPv6
};
Ok(socket_addr)
}
}
impl TryFrom<IpListenEndpoint> for SocketAddr {
type Error = Error;
fn try_from(value: IpListenEndpoint) -> Result<Self> {
let port = value.port;
let socket_addr = match value.addr {
None => return_errno_with_message!(Errno::EINVAL, "address is unspecified"),
Some(IpAddress::Ipv4(address)) => SocketAddr::IPv4(address, port),
};
Ok(socket_addr)
}
}