Remove Arcs in TCP and UDP states

This commit is contained in:
Ruihan Li 2024-01-07 23:55:23 +08:00 committed by Tate, Hongliang Tian
parent 07e8cfe2e7
commit a10d04c5f9
14 changed files with 672 additions and 715 deletions

View File

@ -9,6 +9,7 @@ pub type RawSocketHandle = smoltcp::iface::SocketHandle;
pub struct AnyUnboundSocket { pub struct AnyUnboundSocket {
socket_family: AnyRawSocket, socket_family: AnyRawSocket,
observer: Weak<dyn Observer<()>>,
} }
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
@ -23,7 +24,7 @@ pub(super) enum SocketFamily {
} }
impl AnyUnboundSocket { impl AnyUnboundSocket {
pub fn new_tcp() -> Self { pub fn new_tcp(observer: Weak<dyn Observer<()>>) -> Self {
let raw_tcp_socket = { let raw_tcp_socket = {
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; RECV_BUF_LEN]); let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; RECV_BUF_LEN]);
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]); let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]);
@ -31,10 +32,11 @@ impl AnyUnboundSocket {
}; };
AnyUnboundSocket { AnyUnboundSocket {
socket_family: AnyRawSocket::Tcp(raw_tcp_socket), socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
observer,
} }
} }
pub fn new_udp() -> Self { pub fn new_udp(observer: Weak<dyn Observer<()>>) -> Self {
let raw_udp_socket = { let raw_udp_socket = {
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY; let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
@ -49,18 +51,12 @@ impl AnyUnboundSocket {
}; };
AnyUnboundSocket { AnyUnboundSocket {
socket_family: AnyRawSocket::Udp(raw_udp_socket), socket_family: AnyRawSocket::Udp(raw_udp_socket),
observer,
} }
} }
pub(super) fn raw_socket_family(self) -> AnyRawSocket { pub(super) fn into_raw(self) -> (AnyRawSocket, Weak<dyn Observer<()>>) {
self.socket_family (self.socket_family, self.observer)
}
pub(super) fn socket_family(&self) -> SocketFamily {
match &self.socket_family {
AnyRawSocket::Tcp(_) => SocketFamily::Tcp,
AnyRawSocket::Udp(_) => SocketFamily::Udp,
}
} }
} }
@ -79,13 +75,14 @@ impl AnyBoundSocket {
handle: smoltcp::iface::SocketHandle, handle: smoltcp::iface::SocketHandle,
port: u16, port: u16,
socket_family: SocketFamily, socket_family: SocketFamily,
observer: Weak<dyn Observer<()>>,
) -> Arc<Self> { ) -> Arc<Self> {
Arc::new_cyclic(|weak_self| Self { Arc::new_cyclic(|weak_self| Self {
iface, iface,
handle, handle,
port, port,
socket_family, socket_family,
observer: RwLock::new(Weak::<()>::new()), observer: RwLock::new(observer),
weak_self: weak_self.clone(), weak_self: weak_self.clone(),
}) })
} }

View File

@ -11,7 +11,7 @@ use smoltcp::{
}; };
use super::{ use super::{
any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket}, any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket, SocketFamily},
time::get_network_timestamp, time::get_network_timestamp,
util::BindPortConfig, util::BindPortConfig,
Iface, Ipv4Address, Iface, Ipv4Address,
@ -107,20 +107,28 @@ impl IfaceCommon {
} else { } else {
match self.alloc_ephemeral_port() { match self.alloc_ephemeral_port() {
Ok(port) => port, Ok(port) => port,
Err(e) => return Err((e, socket)), Err(err) => return Err((err, socket)),
} }
}; };
if let Some(e) = self.bind_port(port, config.can_reuse()).err() { if let Some(err) = self.bind_port(port, config.can_reuse()).err() {
return Err((e, socket)); return Err((err, socket));
} }
let socket_family = socket.socket_family();
let mut sockets = self.sockets.lock_irq_disabled(); let (handle, socket_family, observer) = match socket.into_raw() {
let handle = match socket.raw_socket_family() { (AnyRawSocket::Tcp(tcp_socket), observer) => (
AnyRawSocket::Tcp(tcp_socket) => sockets.add(tcp_socket), self.sockets.lock_irq_disabled().add(tcp_socket),
AnyRawSocket::Udp(udp_socket) => sockets.add(udp_socket), SocketFamily::Tcp,
observer,
),
(AnyRawSocket::Udp(udp_socket), observer) => (
self.sockets.lock_irq_disabled().add(udp_socket),
SocketFamily::Udp,
observer,
),
}; };
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family); let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
self.insert_bound_socket(&bound_socket).unwrap(); self.insert_bound_socket(&bound_socket).unwrap();
Ok(bound_socket) Ok(bound_socket)
} }

View File

@ -46,7 +46,6 @@ pub trait Iface: internal::IfaceInternal + Send + Sync {
config: BindPortConfig, config: BindPortConfig,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> { ) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
let common = self.common(); let common = self.common();
let socket_type_inner = socket.socket_family();
common.bind_socket(self.arc_self(), socket, config) common.bind_socket(self.arc_self(), socket, config)
} }

View File

@ -44,7 +44,7 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<dyn Iface> {
pub(super) fn bind_socket( pub(super) fn bind_socket(
unbound_socket: Box<AnyUnboundSocket>, unbound_socket: Box<AnyUnboundSocket>,
endpoint: IpEndpoint, endpoint: &IpEndpoint,
can_reuse: bool, can_reuse: bool,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> { ) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
let iface = match get_iface_to_bind(&endpoint.addr) { let iface = match get_iface_to_bind(&endpoint.addr) {

View File

@ -1,61 +1,49 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use crate::{ use crate::{
events::{IoEvents, Observer}, events::IoEvents,
net::{ net::{
iface::{AnyBoundSocket, IpEndpoint, RawUdpSocket}, iface::{AnyBoundSocket, IpEndpoint, RawUdpSocket},
poll_ifaces,
socket::util::send_recv_flags::SendRecvFlags, socket::util::send_recv_flags::SendRecvFlags,
}, },
prelude::*, prelude::*,
process::signal::{Pollee, Poller}, process::signal::Pollee,
}; };
pub struct BoundDatagram { pub struct BoundDatagram {
bound_socket: Arc<AnyBoundSocket>, bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: RwLock<Option<IpEndpoint>>, remote_endpoint: Option<IpEndpoint>,
pollee: Pollee,
} }
impl BoundDatagram { impl BoundDatagram {
pub fn new(bound_socket: Arc<AnyBoundSocket>, pollee: Pollee) -> Arc<Self> { pub fn new(bound_socket: Arc<AnyBoundSocket>) -> Self {
let bound = Arc::new(Self { Self {
bound_socket, bound_socket,
remote_endpoint: RwLock::new(None), remote_endpoint: None,
pollee, }
}); }
bound.bound_socket.set_observer(Arc::downgrade(&bound) as _);
bound pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket.local_endpoint().unwrap()
} }
pub fn remote_endpoint(&self) -> Result<IpEndpoint> { pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
self.remote_endpoint self.remote_endpoint
.read()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "remote endpoint is not specified")) .ok_or_else(|| Error::with_message(Errno::EINVAL, "remote endpoint is not specified"))
} }
pub fn set_remote_endpoint(&self, endpoint: IpEndpoint) { pub fn set_remote_endpoint(&mut self, endpoint: &IpEndpoint) {
*self.remote_endpoint.write() = Some(endpoint); self.remote_endpoint = Some(*endpoint)
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket.local_endpoint().ok_or_else(|| {
Error::with_message(Errno::EINVAL, "socket does not bind to local endpoint")
})
} }
pub fn try_recvfrom( pub fn try_recvfrom(
&self, &self,
buf: &mut [u8], buf: &mut [u8],
flags: &SendRecvFlags, flags: SendRecvFlags,
) -> Result<(usize, IpEndpoint)> { ) -> Result<(usize, IpEndpoint)> {
poll_ifaces(); self.bound_socket
let recv_slice = |socket: &mut RawUdpSocket| { .raw_with(|socket: &mut RawUdpSocket| socket.recv_slice(buf))
socket
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty")) .map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty"))
};
self.bound_socket.raw_with(recv_slice)
} }
pub fn try_sendto( pub fn try_sendto(
@ -65,27 +53,21 @@ impl BoundDatagram {
flags: SendRecvFlags, flags: SendRecvFlags,
) -> Result<usize> { ) -> Result<usize> {
let remote_endpoint = remote let remote_endpoint = remote
.or_else(|| self.remote_endpoint().ok()) .or(self.remote_endpoint)
.ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?; .ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?;
let send_slice = |socket: &mut RawUdpSocket| { self.bound_socket
socket .raw_with(|socket: &mut RawUdpSocket| socket.send_slice(buf, remote_endpoint))
.send_slice(buf, remote_endpoint)
.map(|_| buf.len()) .map(|_| buf.len())
.map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails")) .map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails"))
};
let len = self.bound_socket.raw_with(send_slice)?;
poll_ifaces();
Ok(len)
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { pub(super) fn init_pollee(&self, pollee: &Pollee) {
self.pollee.poll(mask, poller) pollee.reset_events();
self.update_io_events(pollee)
} }
fn update_io_events(&self) { pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
let pollee = &self.pollee;
if socket.can_recv() { if socket.can_recv() {
pollee.add_events(IoEvents::IN); pollee.add_events(IoEvents::IN);
} else { } else {
@ -100,9 +82,3 @@ impl BoundDatagram {
}); });
} }
} }
impl Observer<()> for BoundDatagram {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -1,81 +1,91 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering}; use core::{
mem,
sync::atomic::{AtomicBool, Ordering},
};
use self::{bound::BoundDatagram, unbound::UnboundDatagram}; use self::{bound::BoundDatagram, unbound::UnboundDatagram};
use super::{always_some::AlwaysSome, common::get_ephemeral_endpoint}; use super::common::get_ephemeral_endpoint;
use crate::{ use crate::{
events::IoEvents, events::{IoEvents, Observer},
fs::{file_handle::FileLike, utils::StatusFlags}, fs::{file_handle::FileLike, utils::StatusFlags},
net::{ net::{
iface::IpEndpoint, iface::IpEndpoint,
poll_ifaces,
socket::{ socket::{
util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr}, util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr},
Socket, Socket,
}, },
}, },
prelude::*, prelude::*,
process::signal::Poller, process::signal::{Pollee, Poller},
}; };
mod bound; mod bound;
mod unbound; mod unbound;
pub struct DatagramSocket { pub struct DatagramSocket {
nonblocking: AtomicBool,
inner: RwLock<Inner>, inner: RwLock<Inner>,
nonblocking: AtomicBool,
pollee: Pollee,
} }
enum Inner { enum Inner {
Unbound(AlwaysSome<UnboundDatagram>), Unbound(UnboundDatagram),
Bound(Arc<BoundDatagram>), Bound(BoundDatagram),
Poisoned,
} }
impl Inner { impl Inner {
fn is_bound(&self) -> bool { fn bind(self, endpoint: &IpEndpoint) -> core::result::Result<BoundDatagram, (Error, Self)> {
matches!(self, Inner::Bound { .. }) let unbound_datagram = match self {
Inner::Unbound(unbound_datagram) => unbound_datagram,
Inner::Bound(bound_datagram) => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
Inner::Bound(bound_datagram),
));
}
Inner::Poisoned => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is poisoned"),
Inner::Poisoned,
));
} }
fn bind(&mut self, endpoint: IpEndpoint) -> Result<Arc<BoundDatagram>> {
let unbound = match self {
Inner::Unbound(unbound) => unbound,
Inner::Bound(..) => return_errno_with_message!(
Errno::EINVAL,
"the socket is already bound to an address"
),
}; };
let bound = unbound.try_take_with(|unbound| unbound.bind(endpoint))?;
*self = Inner::Bound(bound.clone()); let bound_datagram = match unbound_datagram.bind(endpoint) {
Ok(bound) Ok(bound_datagram) => bound_datagram,
Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))),
};
Ok(bound_datagram)
} }
fn bind_to_ephemeral_endpoint( fn bind_to_ephemeral_endpoint(
&mut self, self,
remote_endpoint: &IpEndpoint, remote_endpoint: &IpEndpoint,
) -> Result<Arc<BoundDatagram>> { ) -> core::result::Result<BoundDatagram, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint); if let Inner::Bound(bound_datagram) = self {
self.bind(endpoint) return Ok(bound_datagram);
} }
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { let endpoint = get_ephemeral_endpoint(remote_endpoint);
match self { self.bind(&endpoint)
Inner::Unbound(unbound) => unbound.poll(mask, poller),
Inner::Bound(bound) => bound.poll(mask, poller),
}
} }
} }
impl DatagramSocket { impl DatagramSocket {
pub fn new(nonblocking: bool) -> Self { pub fn new(nonblocking: bool) -> Arc<Self> {
let unbound = UnboundDatagram::new(); Arc::new_cyclic(|me| {
let unbound_datagram = UnboundDatagram::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty());
Self { Self {
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(unbound))), inner: RwLock::new(Inner::Unbound(unbound_datagram)),
nonblocking: AtomicBool::new(nonblocking), nonblocking: AtomicBool::new(nonblocking),
pollee,
} }
} })
pub fn is_bound(&self) -> bool {
self.inner.read().is_bound()
} }
pub fn is_nonblocking(&self) -> bool { pub fn is_nonblocking(&self) -> bool {
@ -86,26 +96,81 @@ impl DatagramSocket {
self.nonblocking.store(nonblocking, Ordering::SeqCst); self.nonblocking.store(nonblocking, Ordering::SeqCst);
} }
fn bound(&self) -> Result<Arc<BoundDatagram>> { fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
if let Inner::Bound(bound) = &*self.inner.read() {
Ok(bound.clone())
} else {
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint")
}
}
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<BoundDatagram>> {
// Fast path // Fast path
if let Inner::Bound(bound) = &*self.inner.read() { if let Inner::Bound(_) = &*self.inner.read() {
return Ok(bound.clone()); return Ok(());
} }
// Slow path // Slow path
let mut inner = self.inner.write(); let mut inner = self.inner.write();
if let Inner::Bound(bound) = &*inner { let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
return Ok(bound.clone());
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
*inner = err_inner;
return Err(err);
} }
inner.bind_to_ephemeral_endpoint(remote_endpoint) };
bound_datagram.init_pollee(&self.pollee);
*inner = Inner::Bound(bound_datagram);
Ok(())
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?;
bound_datagram.update_io_events(&self.pollee);
Ok((recv_bytes, remote_endpoint.try_into()?))
}
fn try_sendto(
&self,
buf: &[u8],
remote: Option<IpEndpoint>,
flags: SendRecvFlags,
) -> Result<usize> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?;
bound_datagram.update_io_events(&self.pollee);
Ok(sent_bytes)
}
// TODO: Support timeout
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
where
F: FnMut() -> Result<R>,
{
let poller = Poller::new();
loop {
match cond() {
Err(err) if err.error() == Errno::EAGAIN => (),
result => return result,
};
let events = self.poll(mask, Some(&poller));
if !events.is_empty() {
continue;
}
poller.wait()?;
}
}
fn update_io_events(&self) {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return;
};
bound_datagram.update_io_events(&self.pollee);
} }
} }
@ -124,7 +189,7 @@ impl FileLike for DatagramSocket {
} }
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.inner.read().poll(mask, poller) self.pollee.poll(mask, poller)
} }
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> { fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
@ -152,43 +217,61 @@ impl FileLike for DatagramSocket {
impl Socket for DatagramSocket { impl Socket for DatagramSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> { fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?; let endpoint = socket_addr.try_into()?;
self.inner.write().bind(endpoint)?;
let mut inner = self.inner.write();
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
let bound_datagram = match owned_inner.bind(&endpoint) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
*inner = err_inner;
return Err(err);
}
};
bound_datagram.init_pollee(&self.pollee);
*inner = Inner::Bound(bound_datagram);
Ok(()) Ok(())
} }
fn connect(&self, socket_addr: SocketAddr) -> Result<()> { fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?; let endpoint = socket_addr.try_into()?;
let bound = self.try_bind_empheral(&endpoint)?;
bound.set_remote_endpoint(endpoint); self.try_bind_empheral(&endpoint)?;
let mut inner = self.inner.write();
let Inner::Bound(bound_datagram) = &mut *inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
};
bound_datagram.set_remote_endpoint(&endpoint);
Ok(()) Ok(())
} }
fn addr(&self) -> Result<SocketAddr> { fn addr(&self) -> Result<SocketAddr> {
self.bound()?.local_endpoint()?.try_into() let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
bound_datagram.local_endpoint().try_into()
} }
fn peer_addr(&self) -> Result<SocketAddr> { fn peer_addr(&self) -> Result<SocketAddr> {
self.bound()?.remote_endpoint()?.try_into() let inner = self.inner.read();
let Inner::Bound(bound_datagram) = &*inner else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
};
bound_datagram.remote_endpoint()?.try_into()
} }
// FIXME: respect RecvFromFlags // FIXME: respect RecvFromFlags
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
debug_assert!(flags.is_all_supported()); debug_assert!(flags.is_all_supported());
let bound = self.bound()?;
let poller = Poller::new(); poll_ifaces();
loop {
if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) {
let remote_addr = remote_endpoint.try_into()?;
return Ok((recv_len, remote_addr));
}
let events = bound.poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() { if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to receive again"); self.try_recvfrom(buf, flags)
} } else {
// FIXME: deal with recvfrom timeout self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
poller.wait()?;
}
} }
} }
@ -199,13 +282,24 @@ impl Socket for DatagramSocket {
flags: SendRecvFlags, flags: SendRecvFlags,
) -> Result<usize> { ) -> Result<usize> {
debug_assert!(flags.is_all_supported()); debug_assert!(flags.is_all_supported());
let (bound, remote_endpoint) = if let Some(addr) = remote {
let endpoint = addr.try_into()?; let remote_endpoint = match remote {
(self.try_bind_empheral(&endpoint)?, Some(endpoint)) Some(remote_addr) => Some(remote_addr.try_into()?),
} else { None => None,
let bound = self.bound()?;
(bound, None)
}; };
bound.try_sendto(buf, remote_endpoint, flags) if let Some(endpoint) = remote_endpoint {
self.try_bind_empheral(&endpoint)?;
}
// TODO: Block if the send buffer is full
let sent_bytes = self.try_sendto(buf, remote_endpoint, flags)?;
poll_ifaces();
Ok(sent_bytes)
}
}
impl Observer<()> for DatagramSocket {
fn on_events(&self, events: &()) {
self.update_io_events();
} }
} }

View File

@ -1,53 +1,39 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use alloc::sync::Weak;
use super::bound::BoundDatagram; use super::bound::BoundDatagram;
use crate::{ use crate::{
events::IoEvents, events::Observer,
net::{ net::{
iface::{AnyUnboundSocket, IpEndpoint, RawUdpSocket}, iface::{AnyUnboundSocket, IpEndpoint, RawUdpSocket},
socket::ip::common::bind_socket, socket::ip::common::bind_socket,
}, },
prelude::*, prelude::*,
process::signal::{Pollee, Poller},
}; };
pub struct UnboundDatagram { pub struct UnboundDatagram {
unbound_socket: Box<AnyUnboundSocket>, unbound_socket: Box<AnyUnboundSocket>,
pollee: Pollee,
} }
impl UnboundDatagram { impl UnboundDatagram {
pub fn new() -> Self { pub fn new(observer: Weak<dyn Observer<()>>) -> Self {
Self { Self {
unbound_socket: Box::new(AnyUnboundSocket::new_udp()), unbound_socket: Box::new(AnyUnboundSocket::new_udp(observer)),
pollee: Pollee::new(IoEvents::empty()),
} }
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { pub fn bind(self, endpoint: &IpEndpoint) -> core::result::Result<BoundDatagram, (Error, Self)> {
self.pollee.poll(mask, poller)
}
pub fn bind(
self,
endpoint: IpEndpoint,
) -> core::result::Result<Arc<BoundDatagram>, (Error, Self)> {
let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) { let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) {
Ok(bound_socket) => bound_socket, Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => { Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })),
return Err((
err,
Self {
unbound_socket,
pollee: self.pollee,
},
))
}
}; };
let bound_endpoint = bound_socket.local_endpoint().unwrap(); let bound_endpoint = bound_socket.local_endpoint().unwrap();
bound_socket.raw_with(|socket: &mut RawUdpSocket| { bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket.bind(bound_endpoint).unwrap(); socket.bind(bound_endpoint).unwrap();
}); });
Ok(BoundDatagram::new(bound_socket, self.pollee))
Ok(BoundDatagram::new(bound_socket))
} }
} }

View File

@ -1,6 +1,5 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
mod always_some;
mod common; mod common;
mod datagram; mod datagram;
pub mod stream; pub mod stream;

View File

@ -1,42 +1,28 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering}; use alloc::sync::Weak;
use crate::{ use crate::{
events::{IoEvents, Observer}, events::{IoEvents, Observer},
net::{ net::{
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket}, iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
poll_ifaces,
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
}, },
prelude::*, prelude::*,
process::signal::{Pollee, Poller}, process::signal::Pollee,
}; };
pub struct ConnectedStream { pub struct ConnectedStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>, bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
pollee: Pollee,
} }
impl ConnectedStream { impl ConnectedStream {
pub fn new( pub fn new(bound_socket: Arc<AnyBoundSocket>, remote_endpoint: IpEndpoint) -> Self {
is_nonblocking: bool, Self {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
) -> Arc<Self> {
let connected = Arc::new(Self {
nonblocking: AtomicBool::new(is_nonblocking),
bound_socket, bound_socket,
remote_endpoint, remote_endpoint,
pollee, }
});
connected
.bound_socket
.set_observer(Arc::downgrade(&connected) as _);
connected
} }
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
@ -44,102 +30,46 @@ impl ConnectedStream {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket.close(); socket.close();
}); });
poll_ifaces();
Ok(()) Ok(())
} }
pub fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> { pub fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> {
debug_assert!(flags.is_all_supported()); let recv_bytes = self
.bound_socket
let poller = Poller::new(); .raw_with(|socket: &mut RawTcpSocket| socket.recv_slice(buf))
loop { .map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))?;
let recv_len = self.try_recvfrom(buf, flags)?; if recv_bytes == 0 {
if recv_len > 0 {
let remote_endpoint = self.remote_endpoint()?;
return Ok((recv_len, remote_endpoint));
}
let events = self.poll(IoEvents::IN, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOTCONN, "recv packet fails");
}
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to recv again"); return_errno_with_message!(Errno::EAGAIN, "try to recv again");
} }
// FIXME: deal with receive timeout Ok(recv_bytes)
poller.wait()?;
}
}
} }
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> { pub fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
poll_ifaces(); let sent_bytes = self
let res = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))
});
self.update_io_events();
res
}
pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
debug_assert!(flags.is_all_supported());
let poller = Poller::new();
loop {
let sent_len = self.try_sendto(buf, flags)?;
if sent_len > 0 {
return Ok(sent_len);
}
let events = self.poll(IoEvents::OUT, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOBUFS, "fail to send packets");
}
if !events.contains(IoEvents::OUT) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to send again");
}
// FIXME: deal with send timeout
poller.wait()?;
}
}
}
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let res = self
.bound_socket .bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf)) .raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf))
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet")); .map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
match res { if sent_bytes == 0 {
// We have to explicitly invoke `update_io_events` when the send buffer becomes return_errno_with_message!(Errno::EAGAIN, "try to send again");
// full. Note that smoltcp does not think it is an interface event, so calling }
// `poll_ifaces` alone is not enough. Ok(sent_bytes)
Ok(0) => self.update_io_events(),
Ok(_) => poll_ifaces(),
_ => (),
};
res
} }
pub fn local_endpoint(&self) -> Result<IpEndpoint> { pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket self.bound_socket.local_endpoint().unwrap()
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
} }
pub fn remote_endpoint(&self) -> Result<IpEndpoint> { pub fn remote_endpoint(&self) -> IpEndpoint {
Ok(self.remote_endpoint) self.remote_endpoint
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { pub(super) fn init_pollee(&self, pollee: &Pollee) {
self.pollee.poll(mask, poller) pollee.reset_events();
self.update_io_events(pollee);
} }
fn update_io_events(&self) { pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
let pollee = &self.pollee;
if socket.can_recv() { if socket.can_recv() {
pollee.add_events(IoEvents::IN); pollee.add_events(IoEvents::IN);
} else { } else {
@ -154,17 +84,7 @@ impl ConnectedStream {
}); });
} }
pub fn is_nonblocking(&self) -> bool { pub(super) fn set_observer(&self, observer: Weak<dyn Observer<()>>) {
self.nonblocking.load(Ordering::Relaxed) self.bound_socket.set_observer(observer)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Observer<()> for ConnectedStream {
fn on_events(&self, _: &()) {
self.update_io_events();
} }
} }

View File

@ -1,116 +1,77 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::{connected::ConnectedStream, init::InitStream}; use super::{connected::ConnectedStream, init::InitStream};
use crate::{ use crate::{
events::{IoEvents, Observer}, events::IoEvents,
net::{ net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
poll_ifaces,
},
prelude::*, prelude::*,
process::signal::{Pollee, Poller}, process::signal::Pollee,
}; };
pub struct ConnectingStream { pub struct ConnectingStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>, bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
conn_result: RwLock<Option<ConnResult>>, conn_result: RwLock<Option<ConnResult>>,
pollee: Pollee,
} }
#[derive(Clone, Copy)]
enum ConnResult { enum ConnResult {
Connected, Connected,
Refused, Refused,
} }
pub enum NonConnectedStream {
Init(InitStream),
Connecting(ConnectingStream),
}
impl ConnectingStream { impl ConnectingStream {
pub fn new( pub fn new(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>, bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
pollee: Pollee, ) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> {
) -> Result<Arc<Self>> { if let Err(err) = bound_socket.do_connect(remote_endpoint) {
bound_socket.do_connect(remote_endpoint)?; return Err((err, bound_socket));
}
let connecting = Arc::new(Self { Ok(Self {
nonblocking: AtomicBool::new(nonblocking),
bound_socket, bound_socket,
remote_endpoint, remote_endpoint,
conn_result: RwLock::new(None), conn_result: RwLock::new(None),
pollee, })
});
connecting.pollee.reset_events();
connecting
.bound_socket
.set_observer(Arc::downgrade(&connecting) as _);
Ok(connecting)
} }
pub fn wait_conn( pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, NonConnectedStream)> {
&self, let conn_result = *self.conn_result.read();
) -> core::result::Result<Arc<ConnectedStream>, (Error, Arc<InitStream>)> { match conn_result {
debug_assert!(!self.is_nonblocking()); Some(ConnResult::Connected) => Ok(ConnectedStream::new(
self.bound_socket,
let poller = Poller::new();
loop {
poll_ifaces();
match *self.conn_result.read() {
Some(ConnResult::Connected) => {
return Ok(ConnectedStream::new(
self.is_nonblocking(),
self.bound_socket.clone(),
self.remote_endpoint, self.remote_endpoint,
self.pollee.clone(), )),
)); Some(ConnResult::Refused) => Err((
} Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
Some(ConnResult::Refused) => { NonConnectedStream::Init(InitStream::new_bound(self.bound_socket)),
return Err(( )),
Error::with_message(Errno::ECONNREFUSED, "connection refused"), None => Err((
InitStream::new_bound( Error::with_message(Errno::EAGAIN, "the connection is pending"),
self.is_nonblocking(), NonConnectedStream::Connecting(self),
self.bound_socket.clone(), )),
self.pollee.clone(),
),
));
}
None => (),
};
let events = self.poll(IoEvents::OUT, Some(&poller));
if !events.contains(IoEvents::OUT) {
// FIXME: deal with nonblocking mode & connecting timeout
poller.wait().expect("async connect() not implemented");
}
} }
} }
pub fn local_endpoint(&self) -> Result<IpEndpoint> { pub fn local_endpoint(&self) -> IpEndpoint {
self.bound_socket self.bound_socket.local_endpoint().unwrap()
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint"))
} }
pub fn remote_endpoint(&self) -> Result<IpEndpoint> { pub fn remote_endpoint(&self) -> IpEndpoint {
Ok(self.remote_endpoint) self.remote_endpoint
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { pub(super) fn init_pollee(&self, pollee: &Pollee) {
self.pollee.poll(mask, poller) pollee.reset_events();
self.update_io_events(pollee);
} }
pub fn is_nonblocking(&self) -> bool { pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
fn update_io_events(&self) {
if self.conn_result.read().is_some() { if self.conn_result.read().is_some() {
return; return;
} }
@ -143,13 +104,7 @@ impl ConnectingStream {
// be responsible to initialize all the I/O events including `IoEvents::OUT`, so the // be responsible to initialize all the I/O events including `IoEvents::OUT`, so the
// following hard-coded event addition can be removed. // following hard-coded event addition can be removed.
if became_writable { if became_writable {
self.pollee.add_events(IoEvents::OUT); pollee.add_events(IoEvents::OUT);
} }
} }
} }
impl Observer<()> for ConnectingStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -1,156 +1,93 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering}; use alloc::sync::Weak;
use super::{connecting::ConnectingStream, listen::ListenStream}; use super::{connecting::ConnectingStream, listen::ListenStream};
use crate::{ use crate::{
events::IoEvents, events::Observer,
net::{ net::{
iface::{AnyBoundSocket, AnyUnboundSocket, Iface, IpEndpoint}, iface::{AnyBoundSocket, AnyUnboundSocket, IpEndpoint},
socket::ip::{ socket::ip::common::{bind_socket, get_ephemeral_endpoint},
always_some::AlwaysSome,
common::{bind_socket, get_ephemeral_endpoint},
},
}, },
prelude::*, prelude::*,
process::signal::{Pollee, Poller},
}; };
pub struct InitStream { pub enum InitStream {
inner: RwLock<Inner>, Unbound(Box<AnyUnboundSocket>),
is_nonblocking: AtomicBool, Bound(Arc<AnyBoundSocket>),
pollee: Pollee,
}
enum Inner {
Unbound(AlwaysSome<Box<AnyUnboundSocket>>),
Bound(AlwaysSome<Arc<AnyBoundSocket>>),
}
impl Inner {
fn new() -> Inner {
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
Inner::Unbound(AlwaysSome::new(unbound_socket))
}
fn is_bound(&self) -> bool {
match self {
Self::Unbound(_) => false,
Self::Bound(_) => true,
}
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> {
let unbound_socket = if let Inner::Unbound(unbound_socket) = self {
unbound_socket
} else {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
};
let bound_socket =
unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?;
*self = Inner::Bound(AlwaysSome::new(bound_socket));
Ok(())
}
fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(endpoint)
}
fn bound_socket(&self) -> Option<&Arc<AnyBoundSocket>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket),
Inner::Unbound(_) => None,
}
}
fn iface(&self) -> Option<Arc<dyn Iface>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
Inner::Unbound(_) => None,
}
}
fn local_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket()
.and_then(|socket| socket.local_endpoint())
}
} }
impl InitStream { impl InitStream {
// FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling // FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling
// `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by // `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by
// experimentation and Linux source code. // experimentation and Linux source code.
pub fn new(nonblocking: bool) -> Arc<Self> { pub fn new(observer: Weak<dyn Observer<()>>) -> Self {
Arc::new(Self { InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer)))
inner: RwLock::new(Inner::new()),
is_nonblocking: AtomicBool::new(nonblocking),
pollee: Pollee::new(IoEvents::empty()),
})
} }
pub fn new_bound( pub fn new_bound(bound_socket: Arc<AnyBoundSocket>) -> Self {
nonblocking: bool, InitStream::Bound(bound_socket)
bound_socket: Arc<AnyBoundSocket>,
pollee: Pollee,
) -> Arc<Self> {
bound_socket.set_observer(Weak::<()>::new());
let inner = Inner::Bound(AlwaysSome::new(bound_socket));
Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
inner: RwLock::new(inner),
pollee,
})
} }
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> { pub fn bind(
self.inner.write().bind(endpoint) self,
endpoint: &IpEndpoint,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> {
let unbound_socket = match self {
InitStream::Unbound(unbound_socket) => unbound_socket,
InitStream::Bound(bound_socket) => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
InitStream::Bound(bound_socket),
));
} }
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
if !self.inner.read().is_bound() {
self.inner
.write()
.bind_to_ephemeral_endpoint(remote_endpoint)?
}
ConnectingStream::new(
self.is_nonblocking(),
self.inner.read().bound_socket().unwrap().clone(),
*remote_endpoint,
self.pollee.clone(),
)
}
pub fn listen(&self, backlog: usize) -> Result<Arc<ListenStream>> {
let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() {
bound_socket.clone()
} else {
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound")
}; };
ListenStream::new( let bound_socket = match bind_socket(unbound_socket, endpoint, false) {
self.is_nonblocking(), Ok(bound_socket) => bound_socket,
bound_socket, Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))),
backlog, };
self.pollee.clone(), Ok(bound_socket)
) }
fn bind_to_ephemeral_endpoint(
self,
remote_endpoint: &IpEndpoint,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint)
}
pub fn connect(
self,
remote_endpoint: &IpEndpoint,
) -> core::result::Result<ConnectingStream, (Error, Self)> {
let bound_socket = match self {
InitStream::Bound(bound_socket) => bound_socket,
InitStream::Unbound(_) => self.bind_to_ephemeral_endpoint(remote_endpoint)?,
};
ConnectingStream::new(bound_socket, *remote_endpoint)
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
}
pub fn listen(self, backlog: usize) -> core::result::Result<ListenStream, (Error, Self)> {
let InitStream::Bound(bound_socket) = self else {
return Err((
Error::with_message(Errno::EINVAL, "cannot listen without bound"),
self,
));
};
ListenStream::new(bound_socket, backlog)
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
} }
pub fn local_endpoint(&self) -> Result<IpEndpoint> { pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.inner match self {
.read() InitStream::Unbound(_) => {
.local_endpoint() return_errno_with_message!(Errno::EINVAL, "does not has local endpoint")
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint")) }
InitStream::Bound(bound_socket) => Ok(bound_socket.local_endpoint().unwrap()),
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
} }
} }

View File

@ -1,148 +1,97 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use super::connected::ConnectedStream; use super::connected::ConnectedStream;
use crate::{ use crate::{
events::{IoEvents, Observer}, events::IoEvents,
net::{ net::iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket},
iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket},
poll_ifaces,
},
prelude::*, prelude::*,
process::signal::{Pollee, Poller}, process::signal::Pollee,
}; };
pub struct ListenStream { pub struct ListenStream {
is_nonblocking: AtomicBool,
backlog: usize, backlog: usize,
/// A bound socket held to ensure the TCP port cannot be released /// A bound socket held to ensure the TCP port cannot be released
bound_socket: Arc<AnyBoundSocket>, bound_socket: Arc<AnyBoundSocket>,
/// Backlog sockets listening at the local endpoint /// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>, backlog_sockets: RwLock<Vec<BacklogSocket>>,
pollee: Pollee,
} }
impl ListenStream { impl ListenStream {
pub fn new( pub fn new(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>, bound_socket: Arc<AnyBoundSocket>,
backlog: usize, backlog: usize,
pollee: Pollee, ) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> {
) -> Result<Arc<Self>> { let listen_stream = Self {
let listen_stream = Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
backlog, backlog,
bound_socket, bound_socket,
backlog_sockets: RwLock::new(Vec::new()), backlog_sockets: RwLock::new(Vec::new()),
pollee, };
}); if let Err(err) = listen_stream.fill_backlog_sockets() {
listen_stream.fill_backlog_sockets()?; return Err((err, listen_stream.bound_socket));
listen_stream.pollee.reset_events(); }
listen_stream
.bound_socket
.set_observer(Arc::downgrade(&listen_stream) as _);
Ok(listen_stream) Ok(listen_stream)
} }
pub fn accept(&self) -> Result<(Arc<ConnectedStream>, IpEndpoint)> {
// wait to accept
let poller = Poller::new();
loop {
poll_ifaces();
let accepted_socket = if let Some(accepted_socket) = self.try_accept() {
accepted_socket
} else {
let events = self.poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try accept again");
}
// FIXME: deal with accept timeout
poller.wait()?;
}
continue;
};
let remote_endpoint = accepted_socket.remote_endpoint().unwrap();
let connected_stream = {
let BacklogSocket {
bound_socket: backlog_socket,
} = accepted_socket;
ConnectedStream::new(
false,
backlog_socket,
remote_endpoint,
Pollee::new(IoEvents::empty()),
)
};
return Ok((connected_stream, remote_endpoint));
}
}
/// Append sockets listening at LocalEndPoint to support backlog /// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> { fn fill_backlog_sockets(&self) -> Result<()> {
let backlog = self.backlog;
let mut backlog_sockets = self.backlog_sockets.write(); let mut backlog_sockets = self.backlog_sockets.write();
let backlog = self.backlog;
let current_backlog_len = backlog_sockets.len(); let current_backlog_len = backlog_sockets.len();
debug_assert!(backlog >= current_backlog_len); debug_assert!(backlog >= current_backlog_len);
if backlog == current_backlog_len { if backlog == current_backlog_len {
return Ok(()); return Ok(());
} }
for _ in current_backlog_len..backlog { for _ in current_backlog_len..backlog {
let backlog_socket = BacklogSocket::new(&self.bound_socket)?; let backlog_socket = BacklogSocket::new(&self.bound_socket)?;
backlog_sockets.push(backlog_socket); backlog_sockets.push(backlog_socket);
} }
Ok(()) Ok(())
} }
fn try_accept(&self) -> Option<BacklogSocket> { pub fn try_accept(&self) -> Result<ConnectedStream> {
let backlog_socket = {
let mut backlog_sockets = self.backlog_sockets.write(); let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets let index = backlog_sockets
.iter() .iter()
.position(|backlog_socket| backlog_socket.is_active())?; .position(|backlog_socket| backlog_socket.is_active())
backlog_sockets.remove(index) .ok_or_else(|| Error::with_message(Errno::EAGAIN, "try to accept again"))?;
}; let active_backlog_socket = backlog_sockets.remove(index);
self.fill_backlog_sockets().unwrap();
self.update_io_events(); match BacklogSocket::new(&self.bound_socket) {
Some(backlog_socket) Ok(backlog_socket) => backlog_sockets.push(backlog_socket),
Err(err) => (),
} }
pub fn local_endpoint(&self) -> Result<IpEndpoint> { let remote_endpoint = active_backlog_socket.remote_endpoint().unwrap();
self.bound_socket Ok(ConnectedStream::new(
.local_endpoint() active_backlog_socket.into_bound_socket(),
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint")) remote_endpoint,
))
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { pub fn local_endpoint(&self) -> IpEndpoint {
self.pollee.poll(mask, poller) self.bound_socket.local_endpoint().unwrap()
} }
fn update_io_events(&self) { pub(super) fn init_pollee(&self, pollee: &Pollee) {
pollee.reset_events();
self.update_io_events(pollee);
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
// The lock should be held to avoid data races // The lock should be held to avoid data races
let backlog_sockets = self.backlog_sockets.read(); let backlog_sockets = self.backlog_sockets.read();
let can_accept = backlog_sockets.iter().any(|socket| socket.is_active()); let can_accept = backlog_sockets.iter().any(|socket| socket.is_active());
if can_accept { if can_accept {
self.pollee.add_events(IoEvents::IN); pollee.add_events(IoEvents::IN);
} else { } else {
self.pollee.del_events(IoEvents::IN); pollee.del_events(IoEvents::IN);
} }
} }
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Observer<()> for ListenStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
} }
struct BacklogSocket { struct BacklogSocket {
@ -155,19 +104,21 @@ impl BacklogSocket {
Errno::EINVAL, Errno::EINVAL,
"the socket is not bound", "the socket is not bound",
))?; ))?;
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp(Weak::<()>::new()));
let bound_socket = { let bound_socket = {
let iface = bound_socket.iface(); let iface = bound_socket.iface();
let bind_port_config = BindPortConfig::new(local_endpoint.port, true)?; let bind_port_config = BindPortConfig::new(local_endpoint.port, true)?;
iface iface
.bind_socket(unbound_socket, bind_port_config) .bind_socket(unbound_socket, bind_port_config)
.map_err(|(e, _)| e)? .map_err(|(err, _)| err)?
}; };
bound_socket.raw_with(|raw_tcp_socket: &mut RawTcpSocket| { bound_socket.raw_with(|raw_tcp_socket: &mut RawTcpSocket| {
raw_tcp_socket raw_tcp_socket
.listen(local_endpoint) .listen(local_endpoint)
.map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen")) .map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen"))
})?; })?;
Ok(Self { bound_socket }) Ok(Self { bound_socket })
} }
@ -180,4 +131,8 @@ impl BacklogSocket {
self.bound_socket self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint()) .raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
} }
fn into_bound_socket(self) -> Arc<AnyBoundSocket> {
self.bound_socket
}
} }

View File

@ -1,5 +1,10 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::{
mem,
sync::atomic::{AtomicBool, Ordering},
};
use connected::ConnectedStream; use connected::ConnectedStream;
use connecting::ConnectingStream; use connecting::ConnectingStream;
use init::InitStream; use init::InitStream;
@ -9,10 +14,12 @@ use smoltcp::wire::IpEndpoint;
use util::{TcpOptionSet, DEFAULT_MAXSEG}; use util::{TcpOptionSet, DEFAULT_MAXSEG};
use crate::{ use crate::{
events::IoEvents, events::{IoEvents, Observer},
fs::{file_handle::FileLike, utils::StatusFlags}, fs::{file_handle::FileLike, utils::StatusFlags},
match_sock_option_mut, match_sock_option_ref, match_sock_option_mut, match_sock_option_ref,
net::socket::{ net::{
poll_ifaces,
socket::{
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption}, options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
util::{ util::{
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF}, options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
@ -22,8 +29,9 @@ use crate::{
}, },
Socket, Socket,
}, },
},
prelude::*, prelude::*,
process::signal::Poller, process::signal::{Pollee, Poller},
}; };
mod connected; mod connected;
@ -33,22 +41,27 @@ mod listen;
pub mod options; pub mod options;
mod util; mod util;
use self::connecting::NonConnectedStream;
pub use self::util::CongestionControl; pub use self::util::CongestionControl;
pub struct StreamSocket { pub struct StreamSocket {
options: RwLock<OptionSet>, options: RwLock<OptionSet>,
state: RwLock<State>, state: RwLock<State>,
is_nonblocking: AtomicBool,
pollee: Pollee,
} }
enum State { enum State {
// Start state // Start state
Init(Arc<InitStream>), Init(InitStream),
// Intermediate state // Intermediate state
Connecting(Arc<ConnectingStream>), Connecting(ConnectingStream),
// Final State 1 // Final State 1
Connected(Arc<ConnectedStream>), Connected(ConnectedStream),
// Final State 2 // Final State 2
Listen(Arc<ListenStream>), Listen(ListenStream),
// Poisoned state
Poisoned,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -66,45 +79,159 @@ impl OptionSet {
} }
impl StreamSocket { impl StreamSocket {
pub fn new(nonblocking: bool) -> Self { pub fn new(nonblocking: bool) -> Arc<Self> {
let options = OptionSet::new(); Arc::new_cyclic(|me| {
let state = State::Init(InitStream::new(nonblocking)); let init_stream = InitStream::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty());
Self { Self {
options: RwLock::new(options), options: RwLock::new(OptionSet::new()),
state: RwLock::new(state), state: RwLock::new(State::Init(init_stream)),
is_nonblocking: AtomicBool::new(nonblocking),
pollee,
} }
})
}
fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> {
Arc::new_cyclic(move |me| {
let pollee = Pollee::new(IoEvents::empty());
connected_stream.set_observer(me.clone() as _);
connected_stream.init_pollee(&pollee);
Self {
options: RwLock::new(OptionSet::new()),
state: RwLock::new(State::Connected(connected_stream)),
is_nonblocking: AtomicBool::new(false),
pollee,
}
})
} }
fn is_nonblocking(&self) -> bool { fn is_nonblocking(&self) -> bool {
match &*self.state.read() { self.is_nonblocking.load(Ordering::Relaxed)
State::Init(init) => init.is_nonblocking(),
State::Connecting(connecting) => connecting.is_nonblocking(),
State::Connected(connected) => connected.is_nonblocking(),
State::Listen(listen) => listen.is_nonblocking(),
}
} }
fn set_nonblocking(&self, nonblocking: bool) { fn set_nonblocking(&self, nonblocking: bool) {
match &*self.state.read() { self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
State::Init(init) => init.set_nonblocking(nonblocking),
State::Connecting(connecting) => connecting.set_nonblocking(nonblocking),
State::Connected(connected) => connected.set_nonblocking(nonblocking),
State::Listen(listen) => listen.set_nonblocking(nonblocking),
}
} }
fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> { fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
let mut state = self.state.write(); let mut state = self.state.write();
let init_stream = match &*state {
State::Init(init_stream) => init_stream, let owned_state = mem::replace(&mut *state, State::Poisoned);
State::Listen(_) | State::Connecting(_) | State::Connected(_) => { let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot connect") return_errno_with_message!(Errno::EINVAL, "cannot connect")
}
}; };
let connecting = init_stream.connect(remote_endpoint)?; let connecting_stream = match init_stream.connect(remote_endpoint) {
*state = State::Connecting(connecting.clone()); Ok(connecting_stream) => connecting_stream,
Ok(connecting) Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
connecting_stream.init_pollee(&self.pollee);
*state = State::Connecting(connecting_stream);
Ok(())
}
fn finish_connect(&self) -> Result<()> {
let mut state = self.state.write();
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Connecting(connecting_stream) = owned_state else {
*state = owned_state;
debug_assert!(false, "the socket unexpectedly left the connecting state");
return_errno_with_message!(Errno::EINVAL, "the socket is not connecting");
};
let connected_stream = match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream,
Err((err, NonConnectedStream::Init(init_stream))) => {
*state = State::Init(init_stream);
return Err(err);
}
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
*state = State::Connecting(connecting_stream);
return Err(err);
}
};
connected_stream.init_pollee(&self.pollee);
*state = State::Connected(connected_stream);
Ok(())
}
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let state = self.state.read();
let State::Listen(listen_stream) = &*state else {
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
};
let connected_stream = listen_stream.try_accept()?;
listen_stream.update_io_events(&self.pollee);
let remote_endpoint = connected_stream.remote_endpoint();
let accepted_socket = Self::new_connected(connected_stream);
Ok((accepted_socket, remote_endpoint.try_into()?))
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let state = self.state.read();
let State::Connected(connected_stream) = &*state else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
};
let recv_bytes = connected_stream.try_recvfrom(buf, flags)?;
connected_stream.update_io_events(&self.pollee);
Ok((recv_bytes, connected_stream.remote_endpoint().try_into()?))
}
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let state = self.state.read();
let State::Connected(connected_stream) = &*state else {
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
};
let sent_bytes = connected_stream.try_sendto(buf, flags)?;
connected_stream.update_io_events(&self.pollee);
Ok(sent_bytes)
}
// TODO: Support timeout
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
where
F: FnMut() -> Result<R>,
{
let poller = Poller::new();
loop {
match cond() {
Err(err) if err.error() == Errno::EAGAIN => (),
result => return result,
};
let events = self.poll(mask, Some(&poller));
if !events.is_empty() {
continue;
}
poller.wait()?;
}
}
fn update_io_events(&self) {
let state = self.state.read();
match &*state {
State::Init(_) | State::Poisoned => (),
State::Connecting(connecting_stream) => {
connecting_stream.update_io_events(&self.pollee)
}
State::Listen(listen_stream) => listen_stream.update_io_events(&self.pollee),
State::Connected(connected_stream) => connected_stream.update_io_events(&self.pollee),
}
} }
} }
@ -123,13 +250,7 @@ impl FileLike for StreamSocket {
} }
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let state = self.state.read(); self.pollee.poll(mask, poller)
match &*state {
State::Init(init) => init.poll(mask, poller),
State::Connecting(connecting) => connecting.poll(mask, poller),
State::Connected(connected) => connected.poll(mask, poller),
State::Listen(listen) => listen.poll(mask, poller),
}
} }
fn status_flags(&self) -> StatusFlags { fn status_flags(&self) -> StatusFlags {
@ -157,68 +278,65 @@ impl FileLike for StreamSocket {
impl Socket for StreamSocket { impl Socket for StreamSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> { fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?; let endpoint = socket_addr.try_into()?;
let state = self.state.read();
match &*state {
State::Init(init_stream) => init_stream.bind(endpoint),
_ => return_errno_with_message!(Errno::EINVAL, "cannot bind"),
}
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> { let mut state = self.state.write();
let remote_endpoint = socket_addr.try_into()?;
let owned_state = mem::replace(&mut *state, State::Poisoned);
let State::Init(init_stream) = owned_state else {
*state = owned_state;
return_errno_with_message!(Errno::EINVAL, "cannot bind");
};
let bound_socket = match init_stream.bind(&endpoint) {
Ok(bound_socket) => bound_socket,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
*state = State::Init(InitStream::new_bound(bound_socket));
let connecting_stream = self.do_connect(&remote_endpoint)?;
match connecting_stream.wait_conn() {
Ok(connected_stream) => {
*self.state.write() = State::Connected(connected_stream);
Ok(()) Ok(())
} }
Err((err, init_stream)) => {
*self.state.write() = State::Init(init_stream); // TODO: Support nonblocking mode
Err(err) fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
} let remote_endpoint = socket_addr.try_into()?;
} self.start_connect(&remote_endpoint)?;
poll_ifaces();
self.wait_events(IoEvents::OUT, || self.finish_connect())
} }
fn listen(&self, backlog: usize) -> Result<()> { fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write(); let mut state = self.state.write();
let init_stream = match &*state {
State::Init(init_stream) => init_stream, let owned_state = mem::replace(&mut *state, State::Poisoned);
State::Connecting(connecting_stream) => { let State::Init(init_stream) = owned_state else {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream") *state = owned_state;
} return_errno_with_message!(Errno::EINVAL, "cannot listen");
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream")
}
State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
}; };
let listener = init_stream.listen(backlog)?; let listen_stream = match init_stream.listen(backlog) {
*state = State::Listen(listener); Ok(listen_stream) => listen_stream,
Err((err, init_stream)) => {
*state = State::Init(init_stream);
return Err(err);
}
};
listen_stream.init_pollee(&self.pollee);
*state = State::Listen(listen_stream);
Ok(()) Ok(())
} }
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> { fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let listen_stream = match &*self.state.read() { poll_ifaces();
State::Listen(listen_stream) => listen_stream.clone(), if self.is_nonblocking() {
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"), self.try_accept()
}; } else {
self.wait_events(IoEvents::IN, || self.try_accept())
let (connected_stream, remote_endpoint) = { }
let listen_stream = listen_stream.clone();
listen_stream.accept()?
};
let accepted_socket = {
let state = RwLock::new(State::Connected(connected_stream));
Arc::new(StreamSocket {
options: RwLock::new(OptionSet::new()),
state,
})
};
let socket_addr = remote_endpoint.try_into()?;
Ok((accepted_socket, socket_addr))
} }
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
@ -233,11 +351,12 @@ impl Socket for StreamSocket {
fn addr(&self) -> Result<SocketAddr> { fn addr(&self) -> Result<SocketAddr> {
let state = self.state.read(); let state = self.state.read();
let local_endpoint = match &*state { let local_endpoint = match &*state {
State::Init(init_stream) => init_stream.local_endpoint(), State::Init(init_stream) => init_stream.local_endpoint()?,
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(), State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
State::Listen(listen_stream) => listen_stream.local_endpoint(), State::Listen(listen_stream) => listen_stream.local_endpoint(),
State::Connected(connected_stream) => connected_stream.local_endpoint(), State::Connected(connected_stream) => connected_stream.local_endpoint(),
}?; State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
};
local_endpoint.try_into() local_endpoint.try_into()
} }
@ -252,19 +371,20 @@ impl Socket for StreamSocket {
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer") return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
} }
State::Connected(connected_stream) => connected_stream.remote_endpoint(), State::Connected(connected_stream) => connected_stream.remote_endpoint(),
}?; State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
};
remote_endpoint.try_into() remote_endpoint.try_into()
} }
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let connected_stream = match &*self.state.read() { debug_assert!(flags.is_all_supported());
State::Connected(connected_stream) => connected_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
};
let (recv_size, remote_endpoint) = connected_stream.recvfrom(buf, flags)?; poll_ifaces();
let socket_addr = remote_endpoint.try_into()?; if self.is_nonblocking() {
Ok((recv_size, socket_addr)) self.try_recvfrom(buf, flags)
} else {
self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
}
} }
fn sendto( fn sendto(
@ -273,16 +393,19 @@ impl Socket for StreamSocket {
remote: Option<SocketAddr>, remote: Option<SocketAddr>,
flags: SendRecvFlags, flags: SendRecvFlags,
) -> Result<usize> { ) -> Result<usize> {
debug_assert!(remote.is_none()); debug_assert!(flags.is_all_supported());
if remote.is_some() { if remote.is_some() {
return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr"); return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr");
} }
let connected_stream = match &*self.state.read() { let sent_bytes = if self.is_nonblocking() {
State::Connected(connected_stream) => connected_stream.clone(), self.try_sendto(buf, flags)?
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"), } else {
self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags))?
}; };
connected_stream.sendto(buf, flags) poll_ifaces();
Ok(sent_bytes)
} }
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
@ -324,7 +447,9 @@ impl Socket for StreamSocket {
// FIXME: how to get the current MSS? // FIXME: how to get the current MSS?
let maxseg = match &*self.state.read() { let maxseg = match &*self.state.read() {
State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG, State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => {
DEFAULT_MAXSEG
}
State::Connected(_) => options.tcp.maxseg(), State::Connected(_) => options.tcp.maxseg(),
}; };
tcp_maxseg.set(maxseg); tcp_maxseg.set(maxseg);
@ -406,3 +531,9 @@ impl Socket for StreamSocket {
Ok(()) Ok(())
} }
} }
impl Observer<()> for StreamSocket {
fn on_events(&self, events: &()) {
self.update_io_events();
}
}

View File

@ -31,12 +31,12 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result<SyscallRetur
CSocketAddrFamily::AF_INET, CSocketAddrFamily::AF_INET,
SockType::SOCK_STREAM, SockType::SOCK_STREAM,
Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP, Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP,
) => Arc::new(StreamSocket::new(nonblocking)) as Arc<dyn FileLike>, ) => StreamSocket::new(nonblocking) as Arc<dyn FileLike>,
( (
CSocketAddrFamily::AF_INET, CSocketAddrFamily::AF_INET,
SockType::SOCK_DGRAM, SockType::SOCK_DGRAM,
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP, Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP,
) => Arc::new(DatagramSocket::new(nonblocking)) as Arc<dyn FileLike>, ) => DatagramSocket::new(nonblocking) as Arc<dyn FileLike>,
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"), _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
}; };
let fd = { let fd = {