Remove Arcs in TCP and UDP states

This commit is contained in:
Ruihan Li 2024-01-07 23:55:23 +08:00 committed by Tate, Hongliang Tian
parent 07e8cfe2e7
commit a10d04c5f9
14 changed files with 672 additions and 715 deletions

View File

@ -9,6 +9,7 @@ pub type RawSocketHandle = smoltcp::iface::SocketHandle;
pub struct AnyUnboundSocket {
socket_family: AnyRawSocket,
observer: Weak<dyn Observer<()>>,
}
#[allow(clippy::large_enum_variant)]
@ -23,7 +24,7 @@ pub(super) enum SocketFamily {
}
impl AnyUnboundSocket {
pub fn new_tcp() -> Self {
pub fn new_tcp(observer: Weak<dyn Observer<()>>) -> Self {
let raw_tcp_socket = {
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; RECV_BUF_LEN]);
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]);
@ -31,10 +32,11 @@ impl AnyUnboundSocket {
};
AnyUnboundSocket {
socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
observer,
}
}
pub fn new_udp() -> Self {
pub fn new_udp(observer: Weak<dyn Observer<()>>) -> Self {
let raw_udp_socket = {
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
@ -49,18 +51,12 @@ impl AnyUnboundSocket {
};
AnyUnboundSocket {
socket_family: AnyRawSocket::Udp(raw_udp_socket),
observer,
}
}
pub(super) fn raw_socket_family(self) -> AnyRawSocket {
self.socket_family
}
pub(super) fn socket_family(&self) -> SocketFamily {
match &self.socket_family {
AnyRawSocket::Tcp(_) => SocketFamily::Tcp,
AnyRawSocket::Udp(_) => SocketFamily::Udp,
}
pub(super) fn into_raw(self) -> (AnyRawSocket, Weak<dyn Observer<()>>) {
(self.socket_family, self.observer)
}
}
@ -79,13 +75,14 @@ impl AnyBoundSocket {
handle: smoltcp::iface::SocketHandle,
port: u16,
socket_family: SocketFamily,
observer: Weak<dyn Observer<()>>,
) -> Arc<Self> {
Arc::new_cyclic(|weak_self| Self {
iface,
handle,
port,
socket_family,
observer: RwLock::new(Weak::<()>::new()),
observer: RwLock::new(observer),
weak_self: weak_self.clone(),
})
}

View File

@ -11,7 +11,7 @@ use smoltcp::{
};
use super::{
any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket},
any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket, SocketFamily},
time::get_network_timestamp,
util::BindPortConfig,
Iface, Ipv4Address,
@ -107,20 +107,28 @@ impl IfaceCommon {
} else {
match self.alloc_ephemeral_port() {
Ok(port) => port,
Err(e) => return Err((e, socket)),
Err(err) => return Err((err, socket)),
}
};
if let Some(e) = self.bind_port(port, config.can_reuse()).err() {
return Err((e, socket));
if let Some(err) = self.bind_port(port, config.can_reuse()).err() {
return Err((err, socket));
}
let socket_family = socket.socket_family();
let mut sockets = self.sockets.lock_irq_disabled();
let handle = match socket.raw_socket_family() {
AnyRawSocket::Tcp(tcp_socket) => sockets.add(tcp_socket),
AnyRawSocket::Udp(udp_socket) => sockets.add(udp_socket),
let (handle, socket_family, observer) = match socket.into_raw() {
(AnyRawSocket::Tcp(tcp_socket), observer) => (
self.sockets.lock_irq_disabled().add(tcp_socket),
SocketFamily::Tcp,
observer,
),
(AnyRawSocket::Udp(udp_socket), observer) => (
self.sockets.lock_irq_disabled().add(udp_socket),
SocketFamily::Udp,
observer,
),
};
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family);
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
self.insert_bound_socket(&bound_socket).unwrap();
Ok(bound_socket)
}

View File

@ -46,7 +46,6 @@ pub trait Iface: internal::IfaceInternal + Send + Sync {
config: BindPortConfig,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
let common = self.common();
let socket_type_inner = socket.socket_family();
common.bind_socket(self.arc_self(), socket, config)
}

View File

@ -44,7 +44,7 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<dyn Iface> {
pub(super) fn bind_socket(
unbound_socket: Box<AnyUnboundSocket>,
endpoint: IpEndpoint,
endpoint: &IpEndpoint,
can_reuse: bool,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
let iface = match get_iface_to_bind(&endpoint.addr) {

View File

@ -1,61 +1,49 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
events::{IoEvents, Observer},
events::IoEvents,
net::{
iface::{AnyBoundSocket, IpEndpoint, RawUdpSocket},
poll_ifaces,
socket::util::send_recv_flags::SendRecvFlags,
},
prelude::*,
process::signal::{Pollee, Poller},
process::signal::Pollee,
};
pub struct BoundDatagram {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: RwLock<Option<IpEndpoint>>,
pollee: Pollee,
remote_endpoint: Option<IpEndpoint>,
}
impl BoundDatagram {
pub fn new(bound_socket: Arc<AnyBoundSocket>, pollee: Pollee) -> Arc<Self> {
let bound = Arc::new(Self {
pub fn new(bound_socket: Arc<AnyBoundSocket>) -> Self {
Self {
bound_socket,
remote_endpoint: RwLock::new(None),
pollee,
});
bound.bound_socket.set_observer(Arc::downgrade(&bound) as _);
bound
remote_endpoint: None,
}
}
pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap()
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
self.remote_endpoint
.read()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "remote endpoint is not specified"))
}
pub fn set_remote_endpoint(&self, endpoint: IpEndpoint) {
*self.remote_endpoint.write() = Some(endpoint);
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket.local_endpoint().ok_or_else(|| {
Error::with_message(Errno::EINVAL, "socket does not bind to local endpoint")
})
pub fn set_remote_endpoint(&mut self, endpoint: &IpEndpoint) {
self.remote_endpoint = Some(*endpoint)
}
pub fn try_recvfrom(
&self,
buf: &mut [u8],
flags: &SendRecvFlags,
flags: SendRecvFlags,
) -> Result<(usize, IpEndpoint)> {
poll_ifaces();
let recv_slice = |socket: &mut RawUdpSocket| {
socket
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty"))
};
self.bound_socket.raw_with(recv_slice)
self.bound_socket
.raw_with(|socket: &mut RawUdpSocket| socket.recv_slice(buf))
.map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty"))
}
pub fn try_sendto(
@ -65,27 +53,21 @@ impl BoundDatagram {
flags: SendRecvFlags,
) -> Result<usize> {
let remote_endpoint = remote
.or_else(|| self.remote_endpoint().ok())
.or(self.remote_endpoint)
.ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?;
let send_slice = |socket: &mut RawUdpSocket| {
socket
.send_slice(buf, remote_endpoint)
.map(|_| buf.len())
.map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails"))
};
let len = self.bound_socket.raw_with(send_slice)?;
poll_ifaces();
Ok(len)
self.bound_socket
.raw_with(|socket: &mut RawUdpSocket| socket.send_slice(buf, remote_endpoint))
.map(|_| buf.len())
.map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails"))
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
pub(super) fn init_pollee(&self, pollee: &Pollee) {
pollee.reset_events();
self.update_io_events(pollee)
}
fn update_io_events(&self) {
pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
let pollee = &self.pollee;
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
@ -100,9 +82,3 @@ impl BoundDatagram {
});
}
}
impl Observer<()> for BoundDatagram {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -1,81 +1,91 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use core::{
mem,
sync::atomic::{AtomicBool, Ordering},
};
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
use super::{always_some::AlwaysSome, common::get_ephemeral_endpoint};
use super::common::get_ephemeral_endpoint;
use crate::{
events::IoEvents,
events::{IoEvents, Observer},
fs::{file_handle::FileLike, utils::StatusFlags},
net::{
iface::IpEndpoint,
poll_ifaces,
socket::{
util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr},
Socket,
},
},
prelude::*,
process::signal::Poller,
process::signal::{Pollee, Poller},
};
mod bound;
mod unbound;
pub struct DatagramSocket {
nonblocking: AtomicBool,
inner: RwLock<Inner>,
nonblocking: AtomicBool,
pollee: Pollee,
}
enum Inner {
Unbound(AlwaysSome<UnboundDatagram>),
Bound(Arc<BoundDatagram>),
Unbound(UnboundDatagram),
Bound(BoundDatagram),
Poisoned,
}
impl Inner {
fn is_bound(&self) -> bool {
matches!(self, Inner::Bound { .. })
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<Arc<BoundDatagram>> {
let unbound = match self {
Inner::Unbound(unbound) => unbound,
Inner::Bound(..) => return_errno_with_message!(
Errno::EINVAL,
"the socket is already bound to an address"
),
fn bind(self, endpoint: &IpEndpoint) -> core::result::Result<BoundDatagram, (Error, Self)> {
let unbound_datagram = match self {
Inner::Unbound(unbound_datagram) => unbound_datagram,
Inner::Bound(bound_datagram) => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
Inner::Bound(bound_datagram),
));
}
Inner::Poisoned => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is poisoned"),
Inner::Poisoned,
));
}
};
let bound = unbound.try_take_with(|unbound| unbound.bind(endpoint))?;
*self = Inner::Bound(bound.clone());
Ok(bound)
let bound_datagram = match unbound_datagram.bind(endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))),
};
Ok(bound_datagram)
}
fn bind_to_ephemeral_endpoint(
&mut self,
self,
remote_endpoint: &IpEndpoint,
) -> Result<Arc<BoundDatagram>> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(endpoint)
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
match self {
Inner::Unbound(unbound) => unbound.poll(mask, poller),
Inner::Bound(bound) => bound.poll(mask, poller),
) -> core::result::Result<BoundDatagram, (Error, Self)> {
if let Inner::Bound(bound_datagram) = self {
return Ok(bound_datagram);
}
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint)
}
}
impl DatagramSocket {
pub fn new(nonblocking: bool) -> Self {
let unbound = UnboundDatagram::new();
Self {
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(unbound))),
nonblocking: AtomicBool::new(nonblocking),
}
}
pub fn is_bound(&self) -> bool {
self.inner.read().is_bound()
pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| {
let unbound_datagram = UnboundDatagram::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty());
Self {
inner: RwLock::new(Inner::Unbound(unbound_datagram)),
nonblocking: AtomicBool::new(nonblocking),
pollee,
}
})
}
pub fn is_nonblocking(&self) -> bool {
@ -86,26 +96,81 @@ impl DatagramSocket {
self.nonblocking.store(nonblocking, Ordering::SeqCst);
}
fn bound(&self) -> Result<Arc<BoundDatagram>> {
if let Inner::Bound(bound) = &*self.inner.read() {
Ok(bound.clone())
} else {
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint")
}
}
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<BoundDatagram>> {
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
// Fast path
if let Inner::Bound(bound) = &*self.inner.read() {
return Ok(bound.clone());
if let Inner::Bound(_) = &*self.inner.read() {
return Ok(());
}
// Slow path
let mut inner = self.inner.write();
if let Inner::Bound(bound) = &*inner {
return Ok(bound.clone());
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
*inner = err_inner;
return Err(err);
}
};
bound_datagram.init_pollee(&self.pollee);
*inner = Inner::Bound(bound_datagram);
Ok(())
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?;
bound_datagram.update_io_events(&self.pollee);
Ok((recv_bytes, remote_endpoint.try_into()?))
}
fn try_sendto(
&self,
buf: &[u8],
remote: Option<IpEndpoint>,
flags: SendRecvFlags,
) -> Result<usize> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?;
bound_datagram.update_io_events(&self.pollee);
Ok(sent_bytes)
}
// TODO: Support timeout
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
where
F: FnMut() -> Result<R>,
{
let poller = Poller::new();
loop {
match cond() {
Err(err) if err.error() == Errno::EAGAIN => (),
result => return result,
};
let events = self.poll(mask, Some(&poller));
if !events.is_empty() {
continue;
}
poller.wait()?;
}
inner.bind_to_ephemeral_endpoint(remote_endpoint)
}
fn update_io_events(&self) {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return;
};
bound_datagram.update_io_events(&self.pollee);
}
}
@ -124,7 +189,7 @@ impl FileLike for DatagramSocket {
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.inner.read().poll(mask, poller)
self.pollee.poll(mask, poller)
}
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
@ -152,43 +217,61 @@ impl FileLike for DatagramSocket {
impl Socket for DatagramSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.inner.write().bind(endpoint)?;
let mut inner = self.inner.write();
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
let bound_datagram = match owned_inner.bind(&endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
*inner = err_inner;
return Err(err);
}
};
bound_datagram.init_pollee(&self.pollee);
*inner = Inner::Bound(bound_datagram);
Ok(())
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
let bound = self.try_bind_empheral(&endpoint)?;
bound.set_remote_endpoint(endpoint);
self.try_bind_empheral(&endpoint)?;
let mut inner = self.inner.write();
let Inner::Bound(bound_datagram) = &mut *inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
};
bound_datagram.set_remote_endpoint(&endpoint);
Ok(())
}
fn addr(&self) -> Result<SocketAddr> {
self.bound()?.local_endpoint()?.try_into()
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
bound_datagram.local_endpoint().try_into()
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.bound()?.remote_endpoint()?.try_into()
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
bound_datagram.remote_endpoint()?.try_into()
}
// FIXME: respect RecvFromFlags
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
debug_assert!(flags.is_all_supported());
let bound = self.bound()?;
let poller = Poller::new();
loop {
if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) {
let remote_addr = remote_endpoint.try_into()?;
return Ok((recv_len, remote_addr));
}
let events = bound.poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to receive again");
}
// FIXME: deal with recvfrom timeout
poller.wait()?;
}
poll_ifaces();
if self.is_nonblocking() {
self.try_recvfrom(buf, flags)
} else {
self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
}
}
@ -199,13 +282,24 @@ impl Socket for DatagramSocket {
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(flags.is_all_supported());
let (bound, remote_endpoint) = if let Some(addr) = remote {
let endpoint = addr.try_into()?;
(self.try_bind_empheral(&endpoint)?, Some(endpoint))
} else {
let bound = self.bound()?;
(bound, None)
let remote_endpoint = match remote {
Some(remote_addr) => Some(remote_addr.try_into()?),
None => None,
};
bound.try_sendto(buf, remote_endpoint, flags)
if let Some(endpoint) = remote_endpoint {
self.try_bind_empheral(&endpoint)?;
}
// TODO: Block if the send buffer is full
let sent_bytes = self.try_sendto(buf, remote_endpoint, flags)?;
poll_ifaces();
Ok(sent_bytes)
}
}
impl Observer<()> for DatagramSocket {
fn on_events(&self, events: &()) {
self.update_io_events();
}
}

View File

@ -1,53 +1,39 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::sync::Weak;
use super::bound::BoundDatagram;
use crate::{
events::IoEvents,
events::Observer,
net::{
iface::{AnyUnboundSocket, IpEndpoint, RawUdpSocket},
socket::ip::common::bind_socket,
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct UnboundDatagram {
unbound_socket: Box<AnyUnboundSocket>,
pollee: Pollee,
}
impl UnboundDatagram {
pub fn new() -> Self {
pub fn new(observer: Weak<dyn Observer<()>>) -> Self {
Self {
unbound_socket: Box::new(AnyUnboundSocket::new_udp()),
pollee: Pollee::new(IoEvents::empty()),
unbound_socket: Box::new(AnyUnboundSocket::new_udp(observer)),
}
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn bind(
self,
endpoint: IpEndpoint,
) -> core::result::Result<Arc<BoundDatagram>, (Error, Self)> {
pub fn bind(self, endpoint: &IpEndpoint) -> core::result::Result<BoundDatagram, (Error, Self)> {
let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => {
return Err((
err,
Self {
unbound_socket,
pollee: self.pollee,
},
))
}
Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })),
};
let bound_endpoint = bound_socket.local_endpoint().unwrap();
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket.bind(bound_endpoint).unwrap();
});
Ok(BoundDatagram::new(bound_socket, self.pollee))
Ok(BoundDatagram::new(bound_socket))
}
}

View File

@ -1,6 +1,5 @@
// SPDX-License-Identifier: MPL-2.0
mod always_some;
mod common;
mod datagram;
pub mod stream;

View File

@ -1,42 +1,28 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use alloc::sync::Weak;
use crate::{
events::{IoEvents, Observer},
net::{
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
poll_ifaces,
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
},
prelude::*,
process::signal::{Pollee, Poller},
process::signal::Pollee,
};
pub struct ConnectedStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
}
impl ConnectedStream {
pub fn new(
is_nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
) -> Arc<Self> {
let connected = Arc::new(Self {
nonblocking: AtomicBool::new(is_nonblocking),
pub fn new(bound_socket: Arc<AnyBoundSocket>, remote_endpoint: IpEndpoint) -> Self {
Self {
bound_socket,
remote_endpoint,
pollee,
});
connected
.bound_socket
.set_observer(Arc::downgrade(&connected) as _);
connected
}
}
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
@ -44,102 +30,46 @@ impl ConnectedStream {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket.close();
});
poll_ifaces();
Ok(())
}
pub fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> {
debug_assert!(flags.is_all_supported());
let poller = Poller::new();
loop {
let recv_len = self.try_recvfrom(buf, flags)?;
if recv_len > 0 {
let remote_endpoint = self.remote_endpoint()?;
return Ok((recv_len, remote_endpoint));
}
let events = self.poll(IoEvents::IN, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOTCONN, "recv packet fails");
}
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to recv again");
}
// FIXME: deal with receive timeout
poller.wait()?;
}
pub fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> {
let recv_bytes = self
.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.recv_slice(buf))
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))?;
if recv_bytes == 0 {
return_errno_with_message!(Errno::EAGAIN, "try to recv again");
}
Ok(recv_bytes)
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> {
poll_ifaces();
let res = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))
});
self.update_io_events();
res
}
pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
debug_assert!(flags.is_all_supported());
let poller = Poller::new();
loop {
let sent_len = self.try_sendto(buf, flags)?;
if sent_len > 0 {
return Ok(sent_len);
}
let events = self.poll(IoEvents::OUT, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOBUFS, "fail to send packets");
}
if !events.contains(IoEvents::OUT) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to send again");
}
// FIXME: deal with send timeout
poller.wait()?;
}
}
}
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let res = self
pub fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let sent_bytes = self
.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf))
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"));
match res {
// We have to explicitly invoke `update_io_events` when the send buffer becomes
// full. Note that smoltcp does not think it is an interface event, so calling
// `poll_ifaces` alone is not enough.
Ok(0) => self.update_io_events(),
Ok(_) => poll_ifaces(),
_ => (),
};
res
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
if sent_bytes == 0 {
return_errno_with_message!(Errno::EAGAIN, "try to send again");
}
Ok(sent_bytes)
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap()
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
Ok(self.remote_endpoint)
pub fn remote_endpoint(&self) -> IpEndpoint {
self.remote_endpoint
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
pub(super) fn init_pollee(&self, pollee: &Pollee) {
pollee.reset_events();
self.update_io_events(pollee);
}
fn update_io_events(&self) {
pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
let pollee = &self.pollee;
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
@ -154,17 +84,7 @@ impl ConnectedStream {
});
}
pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Observer<()> for ConnectedStream {
fn on_events(&self, _: &()) {
self.update_io_events();
pub(super) fn set_observer(&self, observer: Weak<dyn Observer<()>>) {
self.bound_socket.set_observer(observer)
}
}

View File

@ -1,116 +1,77 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::{connected::ConnectedStream, init::InitStream};
use crate::{
events::{IoEvents, Observer},
net::{
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
poll_ifaces,
},
events::IoEvents,
net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
prelude::*,
process::signal::{Pollee, Poller},
process::signal::Pollee,
};
pub struct ConnectingStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
conn_result: RwLock<Option<ConnResult>>,
pollee: Pollee,
}
#[derive(Clone, Copy)]
enum ConnResult {
Connected,
Refused,
}
pub enum NonConnectedStream {
Init(InitStream),
Connecting(ConnectingStream),
}
impl ConnectingStream {
pub fn new(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
) -> Result<Arc<Self>> {
bound_socket.do_connect(remote_endpoint)?;
let connecting = Arc::new(Self {
nonblocking: AtomicBool::new(nonblocking),
) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> {
if let Err(err) = bound_socket.do_connect(remote_endpoint) {
return Err((err, bound_socket));
}
Ok(Self {
bound_socket,
remote_endpoint,
conn_result: RwLock::new(None),
pollee,
});
connecting.pollee.reset_events();
connecting
.bound_socket
.set_observer(Arc::downgrade(&connecting) as _);
Ok(connecting)
})
}
pub fn wait_conn(
&self,
) -> core::result::Result<Arc<ConnectedStream>, (Error, Arc<InitStream>)> {
debug_assert!(!self.is_nonblocking());
let poller = Poller::new();
loop {
poll_ifaces();
match *self.conn_result.read() {
Some(ConnResult::Connected) => {
return Ok(ConnectedStream::new(
self.is_nonblocking(),
self.bound_socket.clone(),
self.remote_endpoint,
self.pollee.clone(),
));
}
Some(ConnResult::Refused) => {
return Err((
Error::with_message(Errno::ECONNREFUSED, "connection refused"),
InitStream::new_bound(
self.is_nonblocking(),
self.bound_socket.clone(),
self.pollee.clone(),
),
));
}
None => (),
};
let events = self.poll(IoEvents::OUT, Some(&poller));
if !events.contains(IoEvents::OUT) {
// FIXME: deal with nonblocking mode & connecting timeout
poller.wait().expect("async connect() not implemented");
}
pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, NonConnectedStream)> {
let conn_result = *self.conn_result.read();
match conn_result {
Some(ConnResult::Connected) => Ok(ConnectedStream::new(
self.bound_socket,
self.remote_endpoint,
)),
Some(ConnResult::Refused) => Err((
Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
NonConnectedStream::Init(InitStream::new_bound(self.bound_socket)),
)),
None => Err((
Error::with_message(Errno::EAGAIN, "the connection is pending"),
NonConnectedStream::Connecting(self),
)),
}
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint"))
pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap()
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
Ok(self.remote_endpoint)
pub fn remote_endpoint(&self) -> IpEndpoint {
self.remote_endpoint
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
pub(super) fn init_pollee(&self, pollee: &Pollee) {
pollee.reset_events();
self.update_io_events(pollee);
}
pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
fn update_io_events(&self) {
pub(super) fn update_io_events(&self, pollee: &Pollee) {
if self.conn_result.read().is_some() {
return;
}
@ -143,13 +104,7 @@ impl ConnectingStream {
// be responsible to initialize all the I/O events including `IoEvents::OUT`, so the
// following hard-coded event addition can be removed.
if became_writable {
self.pollee.add_events(IoEvents::OUT);
pollee.add_events(IoEvents::OUT);
}
}
}
impl Observer<()> for ConnectingStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -1,156 +1,93 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use alloc::sync::Weak;
use super::{connecting::ConnectingStream, listen::ListenStream};
use crate::{
events::IoEvents,
events::Observer,
net::{
iface::{AnyBoundSocket, AnyUnboundSocket, Iface, IpEndpoint},
socket::ip::{
always_some::AlwaysSome,
common::{bind_socket, get_ephemeral_endpoint},
},
iface::{AnyBoundSocket, AnyUnboundSocket, IpEndpoint},
socket::ip::common::{bind_socket, get_ephemeral_endpoint},
},
prelude::*,
process::signal::{Pollee, Poller},
};
pub struct InitStream {
inner: RwLock<Inner>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
enum Inner {
Unbound(AlwaysSome<Box<AnyUnboundSocket>>),
Bound(AlwaysSome<Arc<AnyBoundSocket>>),
}
impl Inner {
fn new() -> Inner {
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
Inner::Unbound(AlwaysSome::new(unbound_socket))
}
fn is_bound(&self) -> bool {
match self {
Self::Unbound(_) => false,
Self::Bound(_) => true,
}
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> {
let unbound_socket = if let Inner::Unbound(unbound_socket) = self {
unbound_socket
} else {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
};
let bound_socket =
unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?;
*self = Inner::Bound(AlwaysSome::new(bound_socket));
Ok(())
}
fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(endpoint)
}
fn bound_socket(&self) -> Option<&Arc<AnyBoundSocket>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket),
Inner::Unbound(_) => None,
}
}
fn iface(&self) -> Option<Arc<dyn Iface>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
Inner::Unbound(_) => None,
}
}
fn local_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket()
.and_then(|socket| socket.local_endpoint())
}
pub enum InitStream {
Unbound(Box<AnyUnboundSocket>),
Bound(Arc<AnyBoundSocket>),
}
impl InitStream {
// FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling
// `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by
// experimentation and Linux source code.
pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new(Self {
inner: RwLock::new(Inner::new()),
is_nonblocking: AtomicBool::new(nonblocking),
pollee: Pollee::new(IoEvents::empty()),
})
pub fn new(observer: Weak<dyn Observer<()>>) -> Self {
InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer)))
}
pub fn new_bound(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
pollee: Pollee,
) -> Arc<Self> {
bound_socket.set_observer(Weak::<()>::new());
let inner = Inner::Bound(AlwaysSome::new(bound_socket));
Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
inner: RwLock::new(inner),
pollee,
})
pub fn new_bound(bound_socket: Arc<AnyBoundSocket>) -> Self {
InitStream::Bound(bound_socket)
}
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> {
self.inner.write().bind(endpoint)
}
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
if !self.inner.read().is_bound() {
self.inner
.write()
.bind_to_ephemeral_endpoint(remote_endpoint)?
}
ConnectingStream::new(
self.is_nonblocking(),
self.inner.read().bound_socket().unwrap().clone(),
*remote_endpoint,
self.pollee.clone(),
)
}
pub fn listen(&self, backlog: usize) -> Result<Arc<ListenStream>> {
let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() {
bound_socket.clone()
} else {
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound")
pub fn bind(
self,
endpoint: &IpEndpoint,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> {
let unbound_socket = match self {
InitStream::Unbound(unbound_socket) => unbound_socket,
InitStream::Bound(bound_socket) => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
InitStream::Bound(bound_socket),
));
}
};
ListenStream::new(
self.is_nonblocking(),
bound_socket,
backlog,
self.pollee.clone(),
)
let bound_socket = match bind_socket(unbound_socket, endpoint, false) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))),
};
Ok(bound_socket)
}
fn bind_to_ephemeral_endpoint(
self,
remote_endpoint: &IpEndpoint,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint)
}
pub fn connect(
self,
remote_endpoint: &IpEndpoint,
) -> core::result::Result<ConnectingStream, (Error, Self)> {
let bound_socket = match self {
InitStream::Bound(bound_socket) => bound_socket,
InitStream::Unbound(_) => self.bind_to_ephemeral_endpoint(remote_endpoint)?,
};
ConnectingStream::new(bound_socket, *remote_endpoint)
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
}
pub fn listen(self, backlog: usize) -> core::result::Result<ListenStream, (Error, Self)> {
let InitStream::Bound(bound_socket) = self else {
return Err((
Error::with_message(Errno::EINVAL, "cannot listen without bound"),
self,
));
};
ListenStream::new(bound_socket, backlog)
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.inner
.read()
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint"))
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
match self {
InitStream::Unbound(_) => {
return_errno_with_message!(Errno::EINVAL, "does not has local endpoint")
}
InitStream::Bound(bound_socket) => Ok(bound_socket.local_endpoint().unwrap()),
}
}
}

View File

@ -1,148 +1,97 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::connected::ConnectedStream;
use crate::{
events::{IoEvents, Observer},
net::{
iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket},
poll_ifaces,
},
events::IoEvents,
net::iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket},
prelude::*,
process::signal::{Pollee, Poller},
process::signal::Pollee,
};
pub struct ListenStream {
is_nonblocking: AtomicBool,
backlog: usize,
/// A bound socket held to ensure the TCP port cannot be released
bound_socket: Arc<AnyBoundSocket>,
/// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>,
pollee: Pollee,
}
impl ListenStream {
pub fn new(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
backlog: usize,
pollee: Pollee,
) -> Result<Arc<Self>> {
let listen_stream = Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> {
let listen_stream = Self {
backlog,
bound_socket,
backlog_sockets: RwLock::new(Vec::new()),
pollee,
});
listen_stream.fill_backlog_sockets()?;
listen_stream.pollee.reset_events();
listen_stream
.bound_socket
.set_observer(Arc::downgrade(&listen_stream) as _);
Ok(listen_stream)
}
pub fn accept(&self) -> Result<(Arc<ConnectedStream>, IpEndpoint)> {
// wait to accept
let poller = Poller::new();
loop {
poll_ifaces();
let accepted_socket = if let Some(accepted_socket) = self.try_accept() {
accepted_socket
} else {
let events = self.poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try accept again");
}
// FIXME: deal with accept timeout
poller.wait()?;
}
continue;
};
let remote_endpoint = accepted_socket.remote_endpoint().unwrap();
let connected_stream = {
let BacklogSocket {
bound_socket: backlog_socket,
} = accepted_socket;
ConnectedStream::new(
false,
backlog_socket,
remote_endpoint,
Pollee::new(IoEvents::empty()),
)
};
return Ok((connected_stream, remote_endpoint));
};
if let Err(err) = listen_stream.fill_backlog_sockets() {
return Err((err, listen_stream.bound_socket));
}
Ok(listen_stream)
}
/// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> {
let backlog = self.backlog;
let mut backlog_sockets = self.backlog_sockets.write();
let backlog = self.backlog;
let current_backlog_len = backlog_sockets.len();
debug_assert!(backlog >= current_backlog_len);
if backlog == current_backlog_len {
return Ok(());
}
for _ in current_backlog_len..backlog {
let backlog_socket = BacklogSocket::new(&self.bound_socket)?;
backlog_sockets.push(backlog_socket);
}
Ok(())
}
fn try_accept(&self) -> Option<BacklogSocket> {
let backlog_socket = {
let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets
.iter()
.position(|backlog_socket| backlog_socket.is_active())?;
backlog_sockets.remove(index)
};
self.fill_backlog_sockets().unwrap();
self.update_io_events();
Some(backlog_socket)
pub fn try_accept(&self) -> Result<ConnectedStream> {
let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets
.iter()
.position(|backlog_socket| backlog_socket.is_active())
.ok_or_else(|| Error::with_message(Errno::EAGAIN, "try to accept again"))?;
let active_backlog_socket = backlog_sockets.remove(index);
match BacklogSocket::new(&self.bound_socket) {
Ok(backlog_socket) => backlog_sockets.push(backlog_socket),
Err(err) => (),
}
let remote_endpoint = active_backlog_socket.remote_endpoint().unwrap();
Ok(ConnectedStream::new(
active_backlog_socket.into_bound_socket(),
remote_endpoint,
))
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap()
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
pub(super) fn init_pollee(&self, pollee: &Pollee) {
pollee.reset_events();
self.update_io_events(pollee);
}
fn update_io_events(&self) {
pub(super) fn update_io_events(&self, pollee: &Pollee) {
// The lock should be held to avoid data races
let backlog_sockets = self.backlog_sockets.read();
let can_accept = backlog_sockets.iter().any(|socket| socket.is_active());
if can_accept {
self.pollee.add_events(IoEvents::IN);
pollee.add_events(IoEvents::IN);
} else {
self.pollee.del_events(IoEvents::IN);
pollee.del_events(IoEvents::IN);
}
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Observer<()> for ListenStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}
struct BacklogSocket {
@ -155,19 +104,21 @@ impl BacklogSocket {
Errno::EINVAL,
"the socket is not bound",
))?;
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp(Weak::<()>::new()));
let bound_socket = {
let iface = bound_socket.iface();
let bind_port_config = BindPortConfig::new(local_endpoint.port, true)?;
iface
.bind_socket(unbound_socket, bind_port_config)
.map_err(|(e, _)| e)?
.map_err(|(err, _)| err)?
};
bound_socket.raw_with(|raw_tcp_socket: &mut RawTcpSocket| {
raw_tcp_socket
.listen(local_endpoint)
.map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen"))
})?;
Ok(Self { bound_socket })
}
@ -180,4 +131,8 @@ impl BacklogSocket {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
}
fn into_bound_socket(self) -> Arc<AnyBoundSocket> {
self.bound_socket
}
}

View File

@ -1,5 +1,10 @@
// SPDX-License-Identifier: MPL-2.0
use core::{
mem,
sync::atomic::{AtomicBool, Ordering},
};
use connected::ConnectedStream;
use connecting::ConnectingStream;
use init::InitStream;
@ -9,21 +14,24 @@ use smoltcp::wire::IpEndpoint;
use util::{TcpOptionSet, DEFAULT_MAXSEG};
use crate::{
events::IoEvents,
events::{IoEvents, Observer},
fs::{file_handle::FileLike, utils::StatusFlags},
match_sock_option_mut, match_sock_option_ref,
net::socket::{
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
util::{
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
send_recv_flags::SendRecvFlags,
shutdown_cmd::SockShutdownCmd,
socket_addr::SocketAddr,
net::{
poll_ifaces,
socket::{
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
util::{
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
send_recv_flags::SendRecvFlags,
shutdown_cmd::SockShutdownCmd,
socket_addr::SocketAddr,
},
Socket,
},
Socket,
},
prelude::*,
process::signal::Poller,
process::signal::{Pollee, Poller},
};
mod connected;
@ -33,22 +41,27 @@ mod listen;
pub mod options;
mod util;
use self::connecting::NonConnectedStream;
pub use self::util::CongestionControl;
pub struct StreamSocket {
options: RwLock<OptionSet>,
state: RwLock<State>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
enum State {
// Start state
Init(Arc<InitStream>),
Init(InitStream),
// Intermediate state
Connecting(Arc<ConnectingStream>),
Connecting(ConnectingStream),
// Final State 1
Connected(Arc<ConnectedStream>),
Connected(ConnectedStream),
// Final State 2
Listen(Arc<ListenStream>),
Listen(ListenStream),
// Poisoned state
Poisoned,
}
#[derive(Debug, Clone)]
@ -66,45 +79,159 @@ impl OptionSet {
}
impl StreamSocket {
pub fn new(nonblocking: bool) -> Self {
let options = OptionSet::new();
let state = State::Init(InitStream::new(nonblocking));
Self {
options: RwLock::new(options),
state: RwLock::new(state),
}
pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| {
let init_stream = InitStream::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty());
Self {
options: RwLock::new(OptionSet::new()),
state: RwLock::new(State::Init(init_stream)),
is_nonblocking: AtomicBool::new(nonblocking),
pollee,
}
})
}
fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> {
Arc::new_cyclic(move |me| {
let pollee = Pollee::new(IoEvents::empty());
connected_stream.set_observer(me.clone() as _);
connected_stream.init_pollee(&pollee);
Self {
options: RwLock::new(OptionSet::new()),
state: RwLock::new(State::Connected(connected_stream)),
is_nonblocking: AtomicBool::new(false),
pollee,
}
})
}
fn is_nonblocking(&self) -> bool {
match &*self.state.read() {
State::Init(init) => init.is_nonblocking(),
State::Connecting(connecting) => connecting.is_nonblocking(),
State::Connected(connected) => connected.is_nonblocking(),
State::Listen(listen) => listen.is_nonblocking(),
}
self.is_nonblocking.load(Ordering::Relaxed)
}
fn set_nonblocking(&self, nonblocking: bool) {
match &*self.state.read() {
State::Init(init) => init.set_nonblocking(nonblocking),
State::Connecting(connecting) => connecting.set_nonblocking(nonblocking),
State::Connected(connected) => connected.set_nonblocking(nonblocking),
State::Listen(listen) => listen.set_nonblocking(nonblocking),
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot connect")
};
let connecting_stream = match init_stream.connect(remote_endpoint) {
Ok(connecting_stream) => connecting_stream,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
connecting_stream.init_pollee(&self.pollee);
*state = State::Connecting(connecting_stream);
Ok(())
}
fn finish_connect(&self) -> Result<()> {
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Connecting(connecting_stream) = owned_state else {
*state = owned_state;
debug_assert!(false, "the socket unexpectedly left the connecting state");
return_errno_with_message!(Errno::EINVAL, "the socket is not connecting");
};
let connected_stream = match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream,
Err((err, NonConnectedStream::Init(init_stream))) => {
*state = State::Init(init_stream);
return Err(err);
}
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
*state = State::Connecting(connecting_stream);
return Err(err);
}
};
connected_stream.init_pollee(&self.pollee);
*state = State::Connected(connected_stream);
Ok(())
}
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let state = self.state.read();
let State::Listen(listen_stream) = &*state else {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
};
let connected_stream = listen_stream.try_accept()?;
listen_stream.update_io_events(&self.pollee);
let remote_endpoint = connected_stream.remote_endpoint();
let accepted_socket = Self::new_connected(connected_stream);
Ok((accepted_socket, remote_endpoint.try_into()?))
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let state = self.state.read();
let State::Connected(connected_stream) = &*state else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
};
let recv_bytes = connected_stream.try_recvfrom(buf, flags)?;
connected_stream.update_io_events(&self.pollee);
Ok((recv_bytes, connected_stream.remote_endpoint().try_into()?))
}
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let state = self.state.read();
let State::Connected(connected_stream) = &*state else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
};
let sent_bytes = connected_stream.try_sendto(buf, flags)?;
connected_stream.update_io_events(&self.pollee);
Ok(sent_bytes)
}
// TODO: Support timeout
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
where
F: FnMut() -> Result<R>,
{
let poller = Poller::new();
loop {
match cond() {
Err(err) if err.error() == Errno::EAGAIN => (),
result => return result,
};
let events = self.poll(mask, Some(&poller));
if !events.is_empty() {
continue;
}
poller.wait()?;
}
}
fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
let mut state = self.state.write();
let init_stream = match &*state {
State::Init(init_stream) => init_stream,
State::Listen(_) | State::Connecting(_) | State::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "cannot connect")
fn update_io_events(&self) {
let state = self.state.read();
match &*state {
State::Init(_) | State::Poisoned => (),
State::Connecting(connecting_stream) => {
connecting_stream.update_io_events(&self.pollee)
}
};
let connecting = init_stream.connect(remote_endpoint)?;
*state = State::Connecting(connecting.clone());
Ok(connecting)
State::Listen(listen_stream) => listen_stream.update_io_events(&self.pollee),
State::Connected(connected_stream) => connected_stream.update_io_events(&self.pollee),
}
}
}
@ -123,13 +250,7 @@ impl FileLike for StreamSocket {
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let state = self.state.read();
match &*state {
State::Init(init) => init.poll(mask, poller),
State::Connecting(connecting) => connecting.poll(mask, poller),
State::Connected(connected) => connected.poll(mask, poller),
State::Listen(listen) => listen.poll(mask, poller),
}
self.pollee.poll(mask, poller)
}
fn status_flags(&self) -> StatusFlags {
@ -157,68 +278,65 @@ impl FileLike for StreamSocket {
impl Socket for StreamSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
let state = self.state.read();
match &*state {
State::Init(init_stream) => init_stream.bind(endpoint),
_ => return_errno_with_message!(Errno::EINVAL, "cannot bind"),
}
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot bind");
};
let bound_socket = match init_stream.bind(&endpoint) {
Ok(bound_socket) => bound_socket,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
*state = State::Init(InitStream::new_bound(bound_socket));
Ok(())
}
// TODO: Support nonblocking mode
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let remote_endpoint = socket_addr.try_into()?;
self.start_connect(&remote_endpoint)?;
let connecting_stream = self.do_connect(&remote_endpoint)?;
match connecting_stream.wait_conn() {
Ok(connected_stream) => {
*self.state.write() = State::Connected(connected_stream);
Ok(())
}
Err((err, init_stream)) => {
*self.state.write() = State::Init(init_stream);
Err(err)
}
}
poll_ifaces();
self.wait_events(IoEvents::OUT, || self.finish_connect())
}
fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write();
let init_stream = match &*state {
State::Init(init_stream) => init_stream,
State::Connecting(connecting_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream")
}
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream")
}
State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot listen");
};
let listener = init_stream.listen(backlog)?;
*state = State::Listen(listener);
let listen_stream = match init_stream.listen(backlog) {
Ok(listen_stream) => listen_stream,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
listen_stream.init_pollee(&self.pollee);
*state = State::Listen(listen_stream);
Ok(())
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen_stream = match &*self.state.read() {
State::Listen(listen_stream) => listen_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
};
let (connected_stream, remote_endpoint) = {
let listen_stream = listen_stream.clone();
listen_stream.accept()?
};
let accepted_socket = {
let state = RwLock::new(State::Connected(connected_stream));
Arc::new(StreamSocket {
options: RwLock::new(OptionSet::new()),
state,
})
};
let socket_addr = remote_endpoint.try_into()?;
Ok((accepted_socket, socket_addr))
poll_ifaces();
if self.is_nonblocking() {
self.try_accept()
} else {
self.wait_events(IoEvents::IN, || self.try_accept())
}
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
@ -233,11 +351,12 @@ impl Socket for StreamSocket {
fn addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let local_endpoint = match &*state {
State::Init(init_stream) => init_stream.local_endpoint(),
State::Init(init_stream) => init_stream.local_endpoint()?,
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
State::Listen(listen_stream) => listen_stream.local_endpoint(),
State::Connected(connected_stream) => connected_stream.local_endpoint(),
}?;
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
};
local_endpoint.try_into()
}
@ -252,19 +371,20 @@ impl Socket for StreamSocket {
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
}
State::Connected(connected_stream) => connected_stream.remote_endpoint(),
}?;
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
};
remote_endpoint.try_into()
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let connected_stream = match &*self.state.read() {
State::Connected(connected_stream) => connected_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
};
debug_assert!(flags.is_all_supported());
let (recv_size, remote_endpoint) = connected_stream.recvfrom(buf, flags)?;
let socket_addr = remote_endpoint.try_into()?;
Ok((recv_size, socket_addr))
poll_ifaces();
if self.is_nonblocking() {
self.try_recvfrom(buf, flags)
} else {
self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
}
}
fn sendto(
@ -273,16 +393,19 @@ impl Socket for StreamSocket {
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(remote.is_none());
debug_assert!(flags.is_all_supported());
if remote.is_some() {
return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr");
}
let connected_stream = match &*self.state.read() {
State::Connected(connected_stream) => connected_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
let sent_bytes = if self.is_nonblocking() {
self.try_sendto(buf, flags)?
} else {
self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags))?
};
connected_stream.sendto(buf, flags)
poll_ifaces();
Ok(sent_bytes)
}
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
@ -324,7 +447,9 @@ impl Socket for StreamSocket {
// FIXME: how to get the current MSS?
let maxseg = match &*self.state.read() {
State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG,
State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => {
DEFAULT_MAXSEG
}
State::Connected(_) => options.tcp.maxseg(),
};
tcp_maxseg.set(maxseg);
@ -348,7 +473,7 @@ impl Socket for StreamSocket {
let recv_buf = socket_recv_buf.get().unwrap();
if *recv_buf <= MIN_RECVBUF {
options.socket.set_recv_buf(MIN_RECVBUF);
} else{
} else {
options.socket.set_recv_buf(*recv_buf);
}
},
@ -406,3 +531,9 @@ impl Socket for StreamSocket {
Ok(())
}
}
impl Observer<()> for StreamSocket {
fn on_events(&self, events: &()) {
self.update_io_events();
}
}

View File

@ -31,12 +31,12 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result<SyscallRetur
CSocketAddrFamily::AF_INET,
SockType::SOCK_STREAM,
Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP,
) => Arc::new(StreamSocket::new(nonblocking)) as Arc<dyn FileLike>,
) => StreamSocket::new(nonblocking) as Arc<dyn FileLike>,
(
CSocketAddrFamily::AF_INET,
SockType::SOCK_DGRAM,
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP,
) => Arc::new(DatagramSocket::new(nonblocking)) as Arc<dyn FileLike>,
) => DatagramSocket::new(nonblocking) as Arc<dyn FileLike>,
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
};
let fd = {