mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-16 00:36:48 +00:00
Remove Arc
s in TCP and UDP states
This commit is contained in:
parent
07e8cfe2e7
commit
a10d04c5f9
@ -9,6 +9,7 @@ pub type RawSocketHandle = smoltcp::iface::SocketHandle;
|
||||
|
||||
pub struct AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket,
|
||||
observer: Weak<dyn Observer<()>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
@ -23,7 +24,7 @@ pub(super) enum SocketFamily {
|
||||
}
|
||||
|
||||
impl AnyUnboundSocket {
|
||||
pub fn new_tcp() -> Self {
|
||||
pub fn new_tcp(observer: Weak<dyn Observer<()>>) -> Self {
|
||||
let raw_tcp_socket = {
|
||||
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; RECV_BUF_LEN]);
|
||||
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]);
|
||||
@ -31,10 +32,11 @@ impl AnyUnboundSocket {
|
||||
};
|
||||
AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_udp() -> Self {
|
||||
pub fn new_udp(observer: Weak<dyn Observer<()>>) -> Self {
|
||||
let raw_udp_socket = {
|
||||
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
|
||||
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
|
||||
@ -49,18 +51,12 @@ impl AnyUnboundSocket {
|
||||
};
|
||||
AnyUnboundSocket {
|
||||
socket_family: AnyRawSocket::Udp(raw_udp_socket),
|
||||
observer,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn raw_socket_family(self) -> AnyRawSocket {
|
||||
self.socket_family
|
||||
}
|
||||
|
||||
pub(super) fn socket_family(&self) -> SocketFamily {
|
||||
match &self.socket_family {
|
||||
AnyRawSocket::Tcp(_) => SocketFamily::Tcp,
|
||||
AnyRawSocket::Udp(_) => SocketFamily::Udp,
|
||||
}
|
||||
pub(super) fn into_raw(self) -> (AnyRawSocket, Weak<dyn Observer<()>>) {
|
||||
(self.socket_family, self.observer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,13 +75,14 @@ impl AnyBoundSocket {
|
||||
handle: smoltcp::iface::SocketHandle,
|
||||
port: u16,
|
||||
socket_family: SocketFamily,
|
||||
observer: Weak<dyn Observer<()>>,
|
||||
) -> Arc<Self> {
|
||||
Arc::new_cyclic(|weak_self| Self {
|
||||
iface,
|
||||
handle,
|
||||
port,
|
||||
socket_family,
|
||||
observer: RwLock::new(Weak::<()>::new()),
|
||||
observer: RwLock::new(observer),
|
||||
weak_self: weak_self.clone(),
|
||||
})
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ use smoltcp::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket},
|
||||
any_socket::{AnyBoundSocket, AnyRawSocket, AnyUnboundSocket, SocketFamily},
|
||||
time::get_network_timestamp,
|
||||
util::BindPortConfig,
|
||||
Iface, Ipv4Address,
|
||||
@ -107,20 +107,28 @@ impl IfaceCommon {
|
||||
} else {
|
||||
match self.alloc_ephemeral_port() {
|
||||
Ok(port) => port,
|
||||
Err(e) => return Err((e, socket)),
|
||||
Err(err) => return Err((err, socket)),
|
||||
}
|
||||
};
|
||||
if let Some(e) = self.bind_port(port, config.can_reuse()).err() {
|
||||
return Err((e, socket));
|
||||
if let Some(err) = self.bind_port(port, config.can_reuse()).err() {
|
||||
return Err((err, socket));
|
||||
}
|
||||
let socket_family = socket.socket_family();
|
||||
let mut sockets = self.sockets.lock_irq_disabled();
|
||||
let handle = match socket.raw_socket_family() {
|
||||
AnyRawSocket::Tcp(tcp_socket) => sockets.add(tcp_socket),
|
||||
AnyRawSocket::Udp(udp_socket) => sockets.add(udp_socket),
|
||||
|
||||
let (handle, socket_family, observer) = match socket.into_raw() {
|
||||
(AnyRawSocket::Tcp(tcp_socket), observer) => (
|
||||
self.sockets.lock_irq_disabled().add(tcp_socket),
|
||||
SocketFamily::Tcp,
|
||||
observer,
|
||||
),
|
||||
(AnyRawSocket::Udp(udp_socket), observer) => (
|
||||
self.sockets.lock_irq_disabled().add(udp_socket),
|
||||
SocketFamily::Udp,
|
||||
observer,
|
||||
),
|
||||
};
|
||||
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family);
|
||||
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
|
||||
self.insert_bound_socket(&bound_socket).unwrap();
|
||||
|
||||
Ok(bound_socket)
|
||||
}
|
||||
|
||||
|
@ -46,7 +46,6 @@ pub trait Iface: internal::IfaceInternal + Send + Sync {
|
||||
config: BindPortConfig,
|
||||
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
|
||||
let common = self.common();
|
||||
let socket_type_inner = socket.socket_family();
|
||||
common.bind_socket(self.arc_self(), socket, config)
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,7 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<dyn Iface> {
|
||||
|
||||
pub(super) fn bind_socket(
|
||||
unbound_socket: Box<AnyUnboundSocket>,
|
||||
endpoint: IpEndpoint,
|
||||
endpoint: &IpEndpoint,
|
||||
can_reuse: bool,
|
||||
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Box<AnyUnboundSocket>)> {
|
||||
let iface = match get_iface_to_bind(&endpoint.addr) {
|
||||
|
@ -1,61 +1,49 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use crate::{
|
||||
events::{IoEvents, Observer},
|
||||
events::IoEvents,
|
||||
net::{
|
||||
iface::{AnyBoundSocket, IpEndpoint, RawUdpSocket},
|
||||
poll_ifaces,
|
||||
socket::util::send_recv_flags::SendRecvFlags,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
process::signal::Pollee,
|
||||
};
|
||||
|
||||
pub struct BoundDatagram {
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: RwLock<Option<IpEndpoint>>,
|
||||
pollee: Pollee,
|
||||
remote_endpoint: Option<IpEndpoint>,
|
||||
}
|
||||
|
||||
impl BoundDatagram {
|
||||
pub fn new(bound_socket: Arc<AnyBoundSocket>, pollee: Pollee) -> Arc<Self> {
|
||||
let bound = Arc::new(Self {
|
||||
pub fn new(bound_socket: Arc<AnyBoundSocket>) -> Self {
|
||||
Self {
|
||||
bound_socket,
|
||||
remote_endpoint: RwLock::new(None),
|
||||
pollee,
|
||||
});
|
||||
bound.bound_socket.set_observer(Arc::downgrade(&bound) as _);
|
||||
bound
|
||||
remote_endpoint: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> IpEndpoint {
|
||||
self.bound_socket.local_endpoint().unwrap()
|
||||
}
|
||||
|
||||
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.remote_endpoint
|
||||
.read()
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "remote endpoint is not specified"))
|
||||
}
|
||||
|
||||
pub fn set_remote_endpoint(&self, endpoint: IpEndpoint) {
|
||||
*self.remote_endpoint.write() = Some(endpoint);
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.bound_socket.local_endpoint().ok_or_else(|| {
|
||||
Error::with_message(Errno::EINVAL, "socket does not bind to local endpoint")
|
||||
})
|
||||
pub fn set_remote_endpoint(&mut self, endpoint: &IpEndpoint) {
|
||||
self.remote_endpoint = Some(*endpoint)
|
||||
}
|
||||
|
||||
pub fn try_recvfrom(
|
||||
&self,
|
||||
buf: &mut [u8],
|
||||
flags: &SendRecvFlags,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<(usize, IpEndpoint)> {
|
||||
poll_ifaces();
|
||||
let recv_slice = |socket: &mut RawUdpSocket| {
|
||||
socket
|
||||
.recv_slice(buf)
|
||||
.map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty"))
|
||||
};
|
||||
self.bound_socket.raw_with(recv_slice)
|
||||
self.bound_socket
|
||||
.raw_with(|socket: &mut RawUdpSocket| socket.recv_slice(buf))
|
||||
.map_err(|_| Error::with_message(Errno::EAGAIN, "recv buf is empty"))
|
||||
}
|
||||
|
||||
pub fn try_sendto(
|
||||
@ -65,27 +53,21 @@ impl BoundDatagram {
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
let remote_endpoint = remote
|
||||
.or_else(|| self.remote_endpoint().ok())
|
||||
.or(self.remote_endpoint)
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?;
|
||||
let send_slice = |socket: &mut RawUdpSocket| {
|
||||
socket
|
||||
.send_slice(buf, remote_endpoint)
|
||||
.map(|_| buf.len())
|
||||
.map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails"))
|
||||
};
|
||||
let len = self.bound_socket.raw_with(send_slice)?;
|
||||
poll_ifaces();
|
||||
Ok(len)
|
||||
self.bound_socket
|
||||
.raw_with(|socket: &mut RawUdpSocket| socket.send_slice(buf, remote_endpoint))
|
||||
.map(|_| buf.len())
|
||||
.map_err(|_| Error::with_message(Errno::EAGAIN, "send udp packet fails"))
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
pub(super) fn init_pollee(&self, pollee: &Pollee) {
|
||||
pollee.reset_events();
|
||||
self.update_io_events(pollee)
|
||||
}
|
||||
|
||||
fn update_io_events(&self) {
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
|
||||
let pollee = &self.pollee;
|
||||
|
||||
if socket.can_recv() {
|
||||
pollee.add_events(IoEvents::IN);
|
||||
} else {
|
||||
@ -100,9 +82,3 @@ impl BoundDatagram {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer<()> for BoundDatagram {
|
||||
fn on_events(&self, _: &()) {
|
||||
self.update_io_events();
|
||||
}
|
||||
}
|
||||
|
@ -1,81 +1,91 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
use core::{
|
||||
mem,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
|
||||
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
|
||||
use super::{always_some::AlwaysSome, common::get_ephemeral_endpoint};
|
||||
use super::common::get_ephemeral_endpoint;
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
events::{IoEvents, Observer},
|
||||
fs::{file_handle::FileLike, utils::StatusFlags},
|
||||
net::{
|
||||
iface::IpEndpoint,
|
||||
poll_ifaces,
|
||||
socket::{
|
||||
util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr},
|
||||
Socket,
|
||||
},
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::Poller,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
mod bound;
|
||||
mod unbound;
|
||||
|
||||
pub struct DatagramSocket {
|
||||
nonblocking: AtomicBool,
|
||||
inner: RwLock<Inner>,
|
||||
nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
enum Inner {
|
||||
Unbound(AlwaysSome<UnboundDatagram>),
|
||||
Bound(Arc<BoundDatagram>),
|
||||
Unbound(UnboundDatagram),
|
||||
Bound(BoundDatagram),
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
fn is_bound(&self) -> bool {
|
||||
matches!(self, Inner::Bound { .. })
|
||||
}
|
||||
|
||||
fn bind(&mut self, endpoint: IpEndpoint) -> Result<Arc<BoundDatagram>> {
|
||||
let unbound = match self {
|
||||
Inner::Unbound(unbound) => unbound,
|
||||
Inner::Bound(..) => return_errno_with_message!(
|
||||
Errno::EINVAL,
|
||||
"the socket is already bound to an address"
|
||||
),
|
||||
fn bind(self, endpoint: &IpEndpoint) -> core::result::Result<BoundDatagram, (Error, Self)> {
|
||||
let unbound_datagram = match self {
|
||||
Inner::Unbound(unbound_datagram) => unbound_datagram,
|
||||
Inner::Bound(bound_datagram) => {
|
||||
return Err((
|
||||
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
|
||||
Inner::Bound(bound_datagram),
|
||||
));
|
||||
}
|
||||
Inner::Poisoned => {
|
||||
return Err((
|
||||
Error::with_message(Errno::EINVAL, "the socket is poisoned"),
|
||||
Inner::Poisoned,
|
||||
));
|
||||
}
|
||||
};
|
||||
let bound = unbound.try_take_with(|unbound| unbound.bind(endpoint))?;
|
||||
*self = Inner::Bound(bound.clone());
|
||||
Ok(bound)
|
||||
|
||||
let bound_datagram = match unbound_datagram.bind(endpoint) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))),
|
||||
};
|
||||
Ok(bound_datagram)
|
||||
}
|
||||
|
||||
fn bind_to_ephemeral_endpoint(
|
||||
&mut self,
|
||||
self,
|
||||
remote_endpoint: &IpEndpoint,
|
||||
) -> Result<Arc<BoundDatagram>> {
|
||||
let endpoint = get_ephemeral_endpoint(remote_endpoint);
|
||||
self.bind(endpoint)
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
match self {
|
||||
Inner::Unbound(unbound) => unbound.poll(mask, poller),
|
||||
Inner::Bound(bound) => bound.poll(mask, poller),
|
||||
) -> core::result::Result<BoundDatagram, (Error, Self)> {
|
||||
if let Inner::Bound(bound_datagram) = self {
|
||||
return Ok(bound_datagram);
|
||||
}
|
||||
|
||||
let endpoint = get_ephemeral_endpoint(remote_endpoint);
|
||||
self.bind(&endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
impl DatagramSocket {
|
||||
pub fn new(nonblocking: bool) -> Self {
|
||||
let unbound = UnboundDatagram::new();
|
||||
Self {
|
||||
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(unbound))),
|
||||
nonblocking: AtomicBool::new(nonblocking),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_bound(&self) -> bool {
|
||||
self.inner.read().is_bound()
|
||||
pub fn new(nonblocking: bool) -> Arc<Self> {
|
||||
Arc::new_cyclic(|me| {
|
||||
let unbound_datagram = UnboundDatagram::new(me.clone() as _);
|
||||
let pollee = Pollee::new(IoEvents::empty());
|
||||
Self {
|
||||
inner: RwLock::new(Inner::Unbound(unbound_datagram)),
|
||||
nonblocking: AtomicBool::new(nonblocking),
|
||||
pollee,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
@ -86,26 +96,81 @@ impl DatagramSocket {
|
||||
self.nonblocking.store(nonblocking, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
fn bound(&self) -> Result<Arc<BoundDatagram>> {
|
||||
if let Inner::Bound(bound) = &*self.inner.read() {
|
||||
Ok(bound.clone())
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint")
|
||||
}
|
||||
}
|
||||
|
||||
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<BoundDatagram>> {
|
||||
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
// Fast path
|
||||
if let Inner::Bound(bound) = &*self.inner.read() {
|
||||
return Ok(bound.clone());
|
||||
if let Inner::Bound(_) = &*self.inner.read() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Slow path
|
||||
let mut inner = self.inner.write();
|
||||
if let Inner::Bound(bound) = &*inner {
|
||||
return Ok(bound.clone());
|
||||
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
|
||||
|
||||
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(remote_endpoint) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
Err((err, err_inner)) => {
|
||||
*inner = err_inner;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
bound_datagram.init_pollee(&self.pollee);
|
||||
*inner = Inner::Bound(bound_datagram);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
let (recv_bytes, remote_endpoint) = bound_datagram.try_recvfrom(buf, flags)?;
|
||||
bound_datagram.update_io_events(&self.pollee);
|
||||
Ok((recv_bytes, remote_endpoint.try_into()?))
|
||||
}
|
||||
|
||||
fn try_sendto(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
remote: Option<IpEndpoint>,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
let sent_bytes = bound_datagram.try_sendto(buf, remote, flags)?;
|
||||
bound_datagram.update_io_events(&self.pollee);
|
||||
Ok(sent_bytes)
|
||||
}
|
||||
|
||||
// TODO: Support timeout
|
||||
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
|
||||
where
|
||||
F: FnMut() -> Result<R>,
|
||||
{
|
||||
let poller = Poller::new();
|
||||
|
||||
loop {
|
||||
match cond() {
|
||||
Err(err) if err.error() == Errno::EAGAIN => (),
|
||||
result => return result,
|
||||
};
|
||||
|
||||
let events = self.poll(mask, Some(&poller));
|
||||
if !events.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
poller.wait()?;
|
||||
}
|
||||
inner.bind_to_ephemeral_endpoint(remote_endpoint)
|
||||
}
|
||||
|
||||
fn update_io_events(&self) {
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
return;
|
||||
};
|
||||
bound_datagram.update_io_events(&self.pollee);
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,7 +189,7 @@ impl FileLike for DatagramSocket {
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.inner.read().poll(mask, poller)
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
|
||||
fn as_socket(self: Arc<Self>) -> Option<Arc<dyn Socket>> {
|
||||
@ -152,43 +217,61 @@ impl FileLike for DatagramSocket {
|
||||
impl Socket for DatagramSocket {
|
||||
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
let endpoint = socket_addr.try_into()?;
|
||||
self.inner.write().bind(endpoint)?;
|
||||
|
||||
let mut inner = self.inner.write();
|
||||
let owned_inner = mem::replace(&mut *inner, Inner::Poisoned);
|
||||
|
||||
let bound_datagram = match owned_inner.bind(&endpoint) {
|
||||
Ok(bound_datagram) => bound_datagram,
|
||||
Err((err, err_inner)) => {
|
||||
*inner = err_inner;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
bound_datagram.init_pollee(&self.pollee);
|
||||
*inner = Inner::Bound(bound_datagram);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
let endpoint = socket_addr.try_into()?;
|
||||
let bound = self.try_bind_empheral(&endpoint)?;
|
||||
bound.set_remote_endpoint(endpoint);
|
||||
|
||||
self.try_bind_empheral(&endpoint)?;
|
||||
|
||||
let mut inner = self.inner.write();
|
||||
let Inner::Bound(bound_datagram) = &mut *inner else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
|
||||
};
|
||||
bound_datagram.set_remote_endpoint(&endpoint);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
self.bound()?.local_endpoint()?.try_into()
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
bound_datagram.local_endpoint().try_into()
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
self.bound()?.remote_endpoint()?.try_into()
|
||||
let inner = self.inner.read();
|
||||
let Inner::Bound(bound_datagram) = &*inner else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
};
|
||||
bound_datagram.remote_endpoint()?.try_into()
|
||||
}
|
||||
|
||||
// FIXME: respect RecvFromFlags
|
||||
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
debug_assert!(flags.is_all_supported());
|
||||
let bound = self.bound()?;
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) {
|
||||
let remote_addr = remote_endpoint.try_into()?;
|
||||
return Ok((recv_len, remote_addr));
|
||||
}
|
||||
let events = bound.poll(IoEvents::IN, Some(&poller));
|
||||
if !events.contains(IoEvents::IN) {
|
||||
if self.is_nonblocking() {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try to receive again");
|
||||
}
|
||||
// FIXME: deal with recvfrom timeout
|
||||
poller.wait()?;
|
||||
}
|
||||
|
||||
poll_ifaces();
|
||||
if self.is_nonblocking() {
|
||||
self.try_recvfrom(buf, flags)
|
||||
} else {
|
||||
self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,13 +282,24 @@ impl Socket for DatagramSocket {
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
debug_assert!(flags.is_all_supported());
|
||||
let (bound, remote_endpoint) = if let Some(addr) = remote {
|
||||
let endpoint = addr.try_into()?;
|
||||
(self.try_bind_empheral(&endpoint)?, Some(endpoint))
|
||||
} else {
|
||||
let bound = self.bound()?;
|
||||
(bound, None)
|
||||
|
||||
let remote_endpoint = match remote {
|
||||
Some(remote_addr) => Some(remote_addr.try_into()?),
|
||||
None => None,
|
||||
};
|
||||
bound.try_sendto(buf, remote_endpoint, flags)
|
||||
if let Some(endpoint) = remote_endpoint {
|
||||
self.try_bind_empheral(&endpoint)?;
|
||||
}
|
||||
|
||||
// TODO: Block if the send buffer is full
|
||||
let sent_bytes = self.try_sendto(buf, remote_endpoint, flags)?;
|
||||
poll_ifaces();
|
||||
Ok(sent_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer<()> for DatagramSocket {
|
||||
fn on_events(&self, events: &()) {
|
||||
self.update_io_events();
|
||||
}
|
||||
}
|
||||
|
@ -1,53 +1,39 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use alloc::sync::Weak;
|
||||
|
||||
use super::bound::BoundDatagram;
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
events::Observer,
|
||||
net::{
|
||||
iface::{AnyUnboundSocket, IpEndpoint, RawUdpSocket},
|
||||
socket::ip::common::bind_socket,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
pub struct UnboundDatagram {
|
||||
unbound_socket: Box<AnyUnboundSocket>,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
impl UnboundDatagram {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(observer: Weak<dyn Observer<()>>) -> Self {
|
||||
Self {
|
||||
unbound_socket: Box::new(AnyUnboundSocket::new_udp()),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
unbound_socket: Box::new(AnyUnboundSocket::new_udp(observer)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
|
||||
pub fn bind(
|
||||
self,
|
||||
endpoint: IpEndpoint,
|
||||
) -> core::result::Result<Arc<BoundDatagram>, (Error, Self)> {
|
||||
pub fn bind(self, endpoint: &IpEndpoint) -> core::result::Result<BoundDatagram, (Error, Self)> {
|
||||
let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) {
|
||||
Ok(bound_socket) => bound_socket,
|
||||
Err((err, unbound_socket)) => {
|
||||
return Err((
|
||||
err,
|
||||
Self {
|
||||
unbound_socket,
|
||||
pollee: self.pollee,
|
||||
},
|
||||
))
|
||||
}
|
||||
Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })),
|
||||
};
|
||||
|
||||
let bound_endpoint = bound_socket.local_endpoint().unwrap();
|
||||
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
|
||||
socket.bind(bound_endpoint).unwrap();
|
||||
});
|
||||
Ok(BoundDatagram::new(bound_socket, self.pollee))
|
||||
|
||||
Ok(BoundDatagram::new(bound_socket))
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
mod always_some;
|
||||
mod common;
|
||||
mod datagram;
|
||||
pub mod stream;
|
||||
|
@ -1,42 +1,28 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
use alloc::sync::Weak;
|
||||
|
||||
use crate::{
|
||||
events::{IoEvents, Observer},
|
||||
net::{
|
||||
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
|
||||
poll_ifaces,
|
||||
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
process::signal::Pollee,
|
||||
};
|
||||
|
||||
pub struct ConnectedStream {
|
||||
nonblocking: AtomicBool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
impl ConnectedStream {
|
||||
pub fn new(
|
||||
is_nonblocking: bool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
pollee: Pollee,
|
||||
) -> Arc<Self> {
|
||||
let connected = Arc::new(Self {
|
||||
nonblocking: AtomicBool::new(is_nonblocking),
|
||||
pub fn new(bound_socket: Arc<AnyBoundSocket>, remote_endpoint: IpEndpoint) -> Self {
|
||||
Self {
|
||||
bound_socket,
|
||||
remote_endpoint,
|
||||
pollee,
|
||||
});
|
||||
connected
|
||||
.bound_socket
|
||||
.set_observer(Arc::downgrade(&connected) as _);
|
||||
connected
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
@ -44,102 +30,46 @@ impl ConnectedStream {
|
||||
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket.close();
|
||||
});
|
||||
poll_ifaces();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> {
|
||||
debug_assert!(flags.is_all_supported());
|
||||
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
let recv_len = self.try_recvfrom(buf, flags)?;
|
||||
if recv_len > 0 {
|
||||
let remote_endpoint = self.remote_endpoint()?;
|
||||
return Ok((recv_len, remote_endpoint));
|
||||
}
|
||||
let events = self.poll(IoEvents::IN, Some(&poller));
|
||||
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "recv packet fails");
|
||||
}
|
||||
if !events.contains(IoEvents::IN) {
|
||||
if self.is_nonblocking() {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try to recv again");
|
||||
}
|
||||
// FIXME: deal with receive timeout
|
||||
poller.wait()?;
|
||||
}
|
||||
pub fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
let recv_bytes = self
|
||||
.bound_socket
|
||||
.raw_with(|socket: &mut RawTcpSocket| socket.recv_slice(buf))
|
||||
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))?;
|
||||
if recv_bytes == 0 {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try to recv again");
|
||||
}
|
||||
Ok(recv_bytes)
|
||||
}
|
||||
|
||||
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
poll_ifaces();
|
||||
let res = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
socket
|
||||
.recv_slice(buf)
|
||||
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))
|
||||
});
|
||||
self.update_io_events();
|
||||
res
|
||||
}
|
||||
|
||||
pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
debug_assert!(flags.is_all_supported());
|
||||
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
let sent_len = self.try_sendto(buf, flags)?;
|
||||
if sent_len > 0 {
|
||||
return Ok(sent_len);
|
||||
}
|
||||
let events = self.poll(IoEvents::OUT, Some(&poller));
|
||||
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
|
||||
return_errno_with_message!(Errno::ENOBUFS, "fail to send packets");
|
||||
}
|
||||
if !events.contains(IoEvents::OUT) {
|
||||
if self.is_nonblocking() {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try to send again");
|
||||
}
|
||||
// FIXME: deal with send timeout
|
||||
poller.wait()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
let res = self
|
||||
pub fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
let sent_bytes = self
|
||||
.bound_socket
|
||||
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf))
|
||||
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"));
|
||||
match res {
|
||||
// We have to explicitly invoke `update_io_events` when the send buffer becomes
|
||||
// full. Note that smoltcp does not think it is an interface event, so calling
|
||||
// `poll_ifaces` alone is not enough.
|
||||
Ok(0) => self.update_io_events(),
|
||||
Ok(_) => poll_ifaces(),
|
||||
_ => (),
|
||||
};
|
||||
res
|
||||
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
|
||||
if sent_bytes == 0 {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try to send again");
|
||||
}
|
||||
Ok(sent_bytes)
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.bound_socket
|
||||
.local_endpoint()
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
|
||||
pub fn local_endpoint(&self) -> IpEndpoint {
|
||||
self.bound_socket.local_endpoint().unwrap()
|
||||
}
|
||||
|
||||
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
|
||||
Ok(self.remote_endpoint)
|
||||
pub fn remote_endpoint(&self) -> IpEndpoint {
|
||||
self.remote_endpoint
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
pub(super) fn init_pollee(&self, pollee: &Pollee) {
|
||||
pollee.reset_events();
|
||||
self.update_io_events(pollee);
|
||||
}
|
||||
|
||||
fn update_io_events(&self) {
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
let pollee = &self.pollee;
|
||||
|
||||
if socket.can_recv() {
|
||||
pollee.add_events(IoEvents::IN);
|
||||
} else {
|
||||
@ -154,17 +84,7 @@ impl ConnectedStream {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
self.nonblocking.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn set_nonblocking(&self, nonblocking: bool) {
|
||||
self.nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer<()> for ConnectedStream {
|
||||
fn on_events(&self, _: &()) {
|
||||
self.update_io_events();
|
||||
pub(super) fn set_observer(&self, observer: Weak<dyn Observer<()>>) {
|
||||
self.bound_socket.set_observer(observer)
|
||||
}
|
||||
}
|
||||
|
@ -1,116 +1,77 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use super::{connected::ConnectedStream, init::InitStream};
|
||||
use crate::{
|
||||
events::{IoEvents, Observer},
|
||||
net::{
|
||||
iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
|
||||
poll_ifaces,
|
||||
},
|
||||
events::IoEvents,
|
||||
net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
process::signal::Pollee,
|
||||
};
|
||||
|
||||
pub struct ConnectingStream {
|
||||
nonblocking: AtomicBool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
conn_result: RwLock<Option<ConnResult>>,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ConnResult {
|
||||
Connected,
|
||||
Refused,
|
||||
}
|
||||
|
||||
pub enum NonConnectedStream {
|
||||
Init(InitStream),
|
||||
Connecting(ConnectingStream),
|
||||
}
|
||||
|
||||
impl ConnectingStream {
|
||||
pub fn new(
|
||||
nonblocking: bool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
pollee: Pollee,
|
||||
) -> Result<Arc<Self>> {
|
||||
bound_socket.do_connect(remote_endpoint)?;
|
||||
|
||||
let connecting = Arc::new(Self {
|
||||
nonblocking: AtomicBool::new(nonblocking),
|
||||
) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> {
|
||||
if let Err(err) = bound_socket.do_connect(remote_endpoint) {
|
||||
return Err((err, bound_socket));
|
||||
}
|
||||
Ok(Self {
|
||||
bound_socket,
|
||||
remote_endpoint,
|
||||
conn_result: RwLock::new(None),
|
||||
pollee,
|
||||
});
|
||||
connecting.pollee.reset_events();
|
||||
connecting
|
||||
.bound_socket
|
||||
.set_observer(Arc::downgrade(&connecting) as _);
|
||||
Ok(connecting)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn wait_conn(
|
||||
&self,
|
||||
) -> core::result::Result<Arc<ConnectedStream>, (Error, Arc<InitStream>)> {
|
||||
debug_assert!(!self.is_nonblocking());
|
||||
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
poll_ifaces();
|
||||
|
||||
match *self.conn_result.read() {
|
||||
Some(ConnResult::Connected) => {
|
||||
return Ok(ConnectedStream::new(
|
||||
self.is_nonblocking(),
|
||||
self.bound_socket.clone(),
|
||||
self.remote_endpoint,
|
||||
self.pollee.clone(),
|
||||
));
|
||||
}
|
||||
Some(ConnResult::Refused) => {
|
||||
return Err((
|
||||
Error::with_message(Errno::ECONNREFUSED, "connection refused"),
|
||||
InitStream::new_bound(
|
||||
self.is_nonblocking(),
|
||||
self.bound_socket.clone(),
|
||||
self.pollee.clone(),
|
||||
),
|
||||
));
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
let events = self.poll(IoEvents::OUT, Some(&poller));
|
||||
if !events.contains(IoEvents::OUT) {
|
||||
// FIXME: deal with nonblocking mode & connecting timeout
|
||||
poller.wait().expect("async connect() not implemented");
|
||||
}
|
||||
pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, NonConnectedStream)> {
|
||||
let conn_result = *self.conn_result.read();
|
||||
match conn_result {
|
||||
Some(ConnResult::Connected) => Ok(ConnectedStream::new(
|
||||
self.bound_socket,
|
||||
self.remote_endpoint,
|
||||
)),
|
||||
Some(ConnResult::Refused) => Err((
|
||||
Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
|
||||
NonConnectedStream::Init(InitStream::new_bound(self.bound_socket)),
|
||||
)),
|
||||
None => Err((
|
||||
Error::with_message(Errno::EAGAIN, "the connection is pending"),
|
||||
NonConnectedStream::Connecting(self),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.bound_socket
|
||||
.local_endpoint()
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint"))
|
||||
pub fn local_endpoint(&self) -> IpEndpoint {
|
||||
self.bound_socket.local_endpoint().unwrap()
|
||||
}
|
||||
|
||||
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
|
||||
Ok(self.remote_endpoint)
|
||||
pub fn remote_endpoint(&self) -> IpEndpoint {
|
||||
self.remote_endpoint
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
pub(super) fn init_pollee(&self, pollee: &Pollee) {
|
||||
pollee.reset_events();
|
||||
self.update_io_events(pollee);
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
self.nonblocking.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn set_nonblocking(&self, nonblocking: bool) {
|
||||
self.nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn update_io_events(&self) {
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
if self.conn_result.read().is_some() {
|
||||
return;
|
||||
}
|
||||
@ -143,13 +104,7 @@ impl ConnectingStream {
|
||||
// be responsible to initialize all the I/O events including `IoEvents::OUT`, so the
|
||||
// following hard-coded event addition can be removed.
|
||||
if became_writable {
|
||||
self.pollee.add_events(IoEvents::OUT);
|
||||
pollee.add_events(IoEvents::OUT);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer<()> for ConnectingStream {
|
||||
fn on_events(&self, _: &()) {
|
||||
self.update_io_events();
|
||||
}
|
||||
}
|
||||
|
@ -1,156 +1,93 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
use alloc::sync::Weak;
|
||||
|
||||
use super::{connecting::ConnectingStream, listen::ListenStream};
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
events::Observer,
|
||||
net::{
|
||||
iface::{AnyBoundSocket, AnyUnboundSocket, Iface, IpEndpoint},
|
||||
socket::ip::{
|
||||
always_some::AlwaysSome,
|
||||
common::{bind_socket, get_ephemeral_endpoint},
|
||||
},
|
||||
iface::{AnyBoundSocket, AnyUnboundSocket, IpEndpoint},
|
||||
socket::ip::common::{bind_socket, get_ephemeral_endpoint},
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
pub struct InitStream {
|
||||
inner: RwLock<Inner>,
|
||||
is_nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
enum Inner {
|
||||
Unbound(AlwaysSome<Box<AnyUnboundSocket>>),
|
||||
Bound(AlwaysSome<Arc<AnyBoundSocket>>),
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
fn new() -> Inner {
|
||||
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
|
||||
Inner::Unbound(AlwaysSome::new(unbound_socket))
|
||||
}
|
||||
|
||||
fn is_bound(&self) -> bool {
|
||||
match self {
|
||||
Self::Unbound(_) => false,
|
||||
Self::Bound(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> {
|
||||
let unbound_socket = if let Inner::Unbound(unbound_socket) = self {
|
||||
unbound_socket
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
|
||||
};
|
||||
let bound_socket =
|
||||
unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?;
|
||||
*self = Inner::Bound(AlwaysSome::new(bound_socket));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
let endpoint = get_ephemeral_endpoint(remote_endpoint);
|
||||
self.bind(endpoint)
|
||||
}
|
||||
|
||||
fn bound_socket(&self) -> Option<&Arc<AnyBoundSocket>> {
|
||||
match self {
|
||||
Inner::Bound(bound_socket) => Some(bound_socket),
|
||||
Inner::Unbound(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn iface(&self) -> Option<Arc<dyn Iface>> {
|
||||
match self {
|
||||
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
|
||||
Inner::Unbound(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn local_endpoint(&self) -> Option<IpEndpoint> {
|
||||
self.bound_socket()
|
||||
.and_then(|socket| socket.local_endpoint())
|
||||
}
|
||||
pub enum InitStream {
|
||||
Unbound(Box<AnyUnboundSocket>),
|
||||
Bound(Arc<AnyBoundSocket>),
|
||||
}
|
||||
|
||||
impl InitStream {
|
||||
// FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling
|
||||
// `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by
|
||||
// experimentation and Linux source code.
|
||||
pub fn new(nonblocking: bool) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
inner: RwLock::new(Inner::new()),
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
pollee: Pollee::new(IoEvents::empty()),
|
||||
})
|
||||
pub fn new(observer: Weak<dyn Observer<()>>) -> Self {
|
||||
InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer)))
|
||||
}
|
||||
|
||||
pub fn new_bound(
|
||||
nonblocking: bool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
pollee: Pollee,
|
||||
) -> Arc<Self> {
|
||||
bound_socket.set_observer(Weak::<()>::new());
|
||||
let inner = Inner::Bound(AlwaysSome::new(bound_socket));
|
||||
Arc::new(Self {
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
inner: RwLock::new(inner),
|
||||
pollee,
|
||||
})
|
||||
pub fn new_bound(bound_socket: Arc<AnyBoundSocket>) -> Self {
|
||||
InitStream::Bound(bound_socket)
|
||||
}
|
||||
|
||||
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> {
|
||||
self.inner.write().bind(endpoint)
|
||||
}
|
||||
|
||||
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
|
||||
if !self.inner.read().is_bound() {
|
||||
self.inner
|
||||
.write()
|
||||
.bind_to_ephemeral_endpoint(remote_endpoint)?
|
||||
}
|
||||
ConnectingStream::new(
|
||||
self.is_nonblocking(),
|
||||
self.inner.read().bound_socket().unwrap().clone(),
|
||||
*remote_endpoint,
|
||||
self.pollee.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn listen(&self, backlog: usize) -> Result<Arc<ListenStream>> {
|
||||
let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() {
|
||||
bound_socket.clone()
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound")
|
||||
pub fn bind(
|
||||
self,
|
||||
endpoint: &IpEndpoint,
|
||||
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> {
|
||||
let unbound_socket = match self {
|
||||
InitStream::Unbound(unbound_socket) => unbound_socket,
|
||||
InitStream::Bound(bound_socket) => {
|
||||
return Err((
|
||||
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
|
||||
InitStream::Bound(bound_socket),
|
||||
));
|
||||
}
|
||||
};
|
||||
ListenStream::new(
|
||||
self.is_nonblocking(),
|
||||
bound_socket,
|
||||
backlog,
|
||||
self.pollee.clone(),
|
||||
)
|
||||
let bound_socket = match bind_socket(unbound_socket, endpoint, false) {
|
||||
Ok(bound_socket) => bound_socket,
|
||||
Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))),
|
||||
};
|
||||
Ok(bound_socket)
|
||||
}
|
||||
|
||||
fn bind_to_ephemeral_endpoint(
|
||||
self,
|
||||
remote_endpoint: &IpEndpoint,
|
||||
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, Self)> {
|
||||
let endpoint = get_ephemeral_endpoint(remote_endpoint);
|
||||
self.bind(&endpoint)
|
||||
}
|
||||
|
||||
pub fn connect(
|
||||
self,
|
||||
remote_endpoint: &IpEndpoint,
|
||||
) -> core::result::Result<ConnectingStream, (Error, Self)> {
|
||||
let bound_socket = match self {
|
||||
InitStream::Bound(bound_socket) => bound_socket,
|
||||
InitStream::Unbound(_) => self.bind_to_ephemeral_endpoint(remote_endpoint)?,
|
||||
};
|
||||
|
||||
ConnectingStream::new(bound_socket, *remote_endpoint)
|
||||
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
|
||||
}
|
||||
|
||||
pub fn listen(self, backlog: usize) -> core::result::Result<ListenStream, (Error, Self)> {
|
||||
let InitStream::Bound(bound_socket) = self else {
|
||||
return Err((
|
||||
Error::with_message(Errno::EINVAL, "cannot listen without bound"),
|
||||
self,
|
||||
));
|
||||
};
|
||||
|
||||
ListenStream::new(bound_socket, backlog)
|
||||
.map_err(|(err, bound_socket)| (err, InitStream::Bound(bound_socket)))
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.inner
|
||||
.read()
|
||||
.local_endpoint()
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint"))
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
self.is_nonblocking.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn set_nonblocking(&self, nonblocking: bool) {
|
||||
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
match self {
|
||||
InitStream::Unbound(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "does not has local endpoint")
|
||||
}
|
||||
InitStream::Bound(bound_socket) => Ok(bound_socket.local_endpoint().unwrap()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,148 +1,97 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use super::connected::ConnectedStream;
|
||||
use crate::{
|
||||
events::{IoEvents, Observer},
|
||||
net::{
|
||||
iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket},
|
||||
poll_ifaces,
|
||||
},
|
||||
events::IoEvents,
|
||||
net::iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, IpEndpoint, RawTcpSocket},
|
||||
prelude::*,
|
||||
process::signal::{Pollee, Poller},
|
||||
process::signal::Pollee,
|
||||
};
|
||||
|
||||
pub struct ListenStream {
|
||||
is_nonblocking: AtomicBool,
|
||||
backlog: usize,
|
||||
/// A bound socket held to ensure the TCP port cannot be released
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
/// Backlog sockets listening at the local endpoint
|
||||
backlog_sockets: RwLock<Vec<BacklogSocket>>,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
impl ListenStream {
|
||||
pub fn new(
|
||||
nonblocking: bool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
backlog: usize,
|
||||
pollee: Pollee,
|
||||
) -> Result<Arc<Self>> {
|
||||
let listen_stream = Arc::new(Self {
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
) -> core::result::Result<Self, (Error, Arc<AnyBoundSocket>)> {
|
||||
let listen_stream = Self {
|
||||
backlog,
|
||||
bound_socket,
|
||||
backlog_sockets: RwLock::new(Vec::new()),
|
||||
pollee,
|
||||
});
|
||||
listen_stream.fill_backlog_sockets()?;
|
||||
listen_stream.pollee.reset_events();
|
||||
listen_stream
|
||||
.bound_socket
|
||||
.set_observer(Arc::downgrade(&listen_stream) as _);
|
||||
Ok(listen_stream)
|
||||
}
|
||||
|
||||
pub fn accept(&self) -> Result<(Arc<ConnectedStream>, IpEndpoint)> {
|
||||
// wait to accept
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
poll_ifaces();
|
||||
let accepted_socket = if let Some(accepted_socket) = self.try_accept() {
|
||||
accepted_socket
|
||||
} else {
|
||||
let events = self.poll(IoEvents::IN, Some(&poller));
|
||||
if !events.contains(IoEvents::IN) {
|
||||
if self.is_nonblocking() {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try accept again");
|
||||
}
|
||||
// FIXME: deal with accept timeout
|
||||
poller.wait()?;
|
||||
}
|
||||
continue;
|
||||
};
|
||||
let remote_endpoint = accepted_socket.remote_endpoint().unwrap();
|
||||
let connected_stream = {
|
||||
let BacklogSocket {
|
||||
bound_socket: backlog_socket,
|
||||
} = accepted_socket;
|
||||
ConnectedStream::new(
|
||||
false,
|
||||
backlog_socket,
|
||||
remote_endpoint,
|
||||
Pollee::new(IoEvents::empty()),
|
||||
)
|
||||
};
|
||||
return Ok((connected_stream, remote_endpoint));
|
||||
};
|
||||
if let Err(err) = listen_stream.fill_backlog_sockets() {
|
||||
return Err((err, listen_stream.bound_socket));
|
||||
}
|
||||
Ok(listen_stream)
|
||||
}
|
||||
|
||||
/// Append sockets listening at LocalEndPoint to support backlog
|
||||
fn fill_backlog_sockets(&self) -> Result<()> {
|
||||
let backlog = self.backlog;
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
|
||||
let backlog = self.backlog;
|
||||
let current_backlog_len = backlog_sockets.len();
|
||||
debug_assert!(backlog >= current_backlog_len);
|
||||
if backlog == current_backlog_len {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for _ in current_backlog_len..backlog {
|
||||
let backlog_socket = BacklogSocket::new(&self.bound_socket)?;
|
||||
backlog_sockets.push(backlog_socket);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn try_accept(&self) -> Option<BacklogSocket> {
|
||||
let backlog_socket = {
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
let index = backlog_sockets
|
||||
.iter()
|
||||
.position(|backlog_socket| backlog_socket.is_active())?;
|
||||
backlog_sockets.remove(index)
|
||||
};
|
||||
self.fill_backlog_sockets().unwrap();
|
||||
self.update_io_events();
|
||||
Some(backlog_socket)
|
||||
pub fn try_accept(&self) -> Result<ConnectedStream> {
|
||||
let mut backlog_sockets = self.backlog_sockets.write();
|
||||
|
||||
let index = backlog_sockets
|
||||
.iter()
|
||||
.position(|backlog_socket| backlog_socket.is_active())
|
||||
.ok_or_else(|| Error::with_message(Errno::EAGAIN, "try to accept again"))?;
|
||||
let active_backlog_socket = backlog_sockets.remove(index);
|
||||
|
||||
match BacklogSocket::new(&self.bound_socket) {
|
||||
Ok(backlog_socket) => backlog_sockets.push(backlog_socket),
|
||||
Err(err) => (),
|
||||
}
|
||||
|
||||
let remote_endpoint = active_backlog_socket.remote_endpoint().unwrap();
|
||||
Ok(ConnectedStream::new(
|
||||
active_backlog_socket.into_bound_socket(),
|
||||
remote_endpoint,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.bound_socket
|
||||
.local_endpoint()
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
|
||||
pub fn local_endpoint(&self) -> IpEndpoint {
|
||||
self.bound_socket.local_endpoint().unwrap()
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.pollee.poll(mask, poller)
|
||||
pub(super) fn init_pollee(&self, pollee: &Pollee) {
|
||||
pollee.reset_events();
|
||||
self.update_io_events(pollee);
|
||||
}
|
||||
|
||||
fn update_io_events(&self) {
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
// The lock should be held to avoid data races
|
||||
let backlog_sockets = self.backlog_sockets.read();
|
||||
|
||||
let can_accept = backlog_sockets.iter().any(|socket| socket.is_active());
|
||||
if can_accept {
|
||||
self.pollee.add_events(IoEvents::IN);
|
||||
pollee.add_events(IoEvents::IN);
|
||||
} else {
|
||||
self.pollee.del_events(IoEvents::IN);
|
||||
pollee.del_events(IoEvents::IN);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
self.is_nonblocking.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn set_nonblocking(&self, nonblocking: bool) {
|
||||
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer<()> for ListenStream {
|
||||
fn on_events(&self, _: &()) {
|
||||
self.update_io_events();
|
||||
}
|
||||
}
|
||||
|
||||
struct BacklogSocket {
|
||||
@ -155,19 +104,21 @@ impl BacklogSocket {
|
||||
Errno::EINVAL,
|
||||
"the socket is not bound",
|
||||
))?;
|
||||
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
|
||||
|
||||
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp(Weak::<()>::new()));
|
||||
let bound_socket = {
|
||||
let iface = bound_socket.iface();
|
||||
let bind_port_config = BindPortConfig::new(local_endpoint.port, true)?;
|
||||
iface
|
||||
.bind_socket(unbound_socket, bind_port_config)
|
||||
.map_err(|(e, _)| e)?
|
||||
.map_err(|(err, _)| err)?
|
||||
};
|
||||
bound_socket.raw_with(|raw_tcp_socket: &mut RawTcpSocket| {
|
||||
raw_tcp_socket
|
||||
.listen(local_endpoint)
|
||||
.map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen"))
|
||||
})?;
|
||||
|
||||
Ok(Self { bound_socket })
|
||||
}
|
||||
|
||||
@ -180,4 +131,8 @@ impl BacklogSocket {
|
||||
self.bound_socket
|
||||
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
|
||||
}
|
||||
|
||||
fn into_bound_socket(self) -> Arc<AnyBoundSocket> {
|
||||
self.bound_socket
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,10 @@
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
use core::{
|
||||
mem,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
|
||||
use connected::ConnectedStream;
|
||||
use connecting::ConnectingStream;
|
||||
use init::InitStream;
|
||||
@ -9,21 +14,24 @@ use smoltcp::wire::IpEndpoint;
|
||||
use util::{TcpOptionSet, DEFAULT_MAXSEG};
|
||||
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
events::{IoEvents, Observer},
|
||||
fs::{file_handle::FileLike, utils::StatusFlags},
|
||||
match_sock_option_mut, match_sock_option_ref,
|
||||
net::socket::{
|
||||
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
|
||||
util::{
|
||||
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
|
||||
send_recv_flags::SendRecvFlags,
|
||||
shutdown_cmd::SockShutdownCmd,
|
||||
socket_addr::SocketAddr,
|
||||
net::{
|
||||
poll_ifaces,
|
||||
socket::{
|
||||
options::{Error, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption},
|
||||
util::{
|
||||
options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF},
|
||||
send_recv_flags::SendRecvFlags,
|
||||
shutdown_cmd::SockShutdownCmd,
|
||||
socket_addr::SocketAddr,
|
||||
},
|
||||
Socket,
|
||||
},
|
||||
Socket,
|
||||
},
|
||||
prelude::*,
|
||||
process::signal::Poller,
|
||||
process::signal::{Pollee, Poller},
|
||||
};
|
||||
|
||||
mod connected;
|
||||
@ -33,22 +41,27 @@ mod listen;
|
||||
pub mod options;
|
||||
mod util;
|
||||
|
||||
use self::connecting::NonConnectedStream;
|
||||
pub use self::util::CongestionControl;
|
||||
|
||||
pub struct StreamSocket {
|
||||
options: RwLock<OptionSet>,
|
||||
state: RwLock<State>,
|
||||
is_nonblocking: AtomicBool,
|
||||
pollee: Pollee,
|
||||
}
|
||||
|
||||
enum State {
|
||||
// Start state
|
||||
Init(Arc<InitStream>),
|
||||
Init(InitStream),
|
||||
// Intermediate state
|
||||
Connecting(Arc<ConnectingStream>),
|
||||
Connecting(ConnectingStream),
|
||||
// Final State 1
|
||||
Connected(Arc<ConnectedStream>),
|
||||
Connected(ConnectedStream),
|
||||
// Final State 2
|
||||
Listen(Arc<ListenStream>),
|
||||
Listen(ListenStream),
|
||||
// Poisoned state
|
||||
Poisoned,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -66,45 +79,159 @@ impl OptionSet {
|
||||
}
|
||||
|
||||
impl StreamSocket {
|
||||
pub fn new(nonblocking: bool) -> Self {
|
||||
let options = OptionSet::new();
|
||||
let state = State::Init(InitStream::new(nonblocking));
|
||||
Self {
|
||||
options: RwLock::new(options),
|
||||
state: RwLock::new(state),
|
||||
}
|
||||
pub fn new(nonblocking: bool) -> Arc<Self> {
|
||||
Arc::new_cyclic(|me| {
|
||||
let init_stream = InitStream::new(me.clone() as _);
|
||||
let pollee = Pollee::new(IoEvents::empty());
|
||||
Self {
|
||||
options: RwLock::new(OptionSet::new()),
|
||||
state: RwLock::new(State::Init(init_stream)),
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
pollee,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> {
|
||||
Arc::new_cyclic(move |me| {
|
||||
let pollee = Pollee::new(IoEvents::empty());
|
||||
connected_stream.set_observer(me.clone() as _);
|
||||
connected_stream.init_pollee(&pollee);
|
||||
Self {
|
||||
options: RwLock::new(OptionSet::new()),
|
||||
state: RwLock::new(State::Connected(connected_stream)),
|
||||
is_nonblocking: AtomicBool::new(false),
|
||||
pollee,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn is_nonblocking(&self) -> bool {
|
||||
match &*self.state.read() {
|
||||
State::Init(init) => init.is_nonblocking(),
|
||||
State::Connecting(connecting) => connecting.is_nonblocking(),
|
||||
State::Connected(connected) => connected.is_nonblocking(),
|
||||
State::Listen(listen) => listen.is_nonblocking(),
|
||||
}
|
||||
self.is_nonblocking.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn set_nonblocking(&self, nonblocking: bool) {
|
||||
match &*self.state.read() {
|
||||
State::Init(init) => init.set_nonblocking(nonblocking),
|
||||
State::Connecting(connecting) => connecting.set_nonblocking(nonblocking),
|
||||
State::Connected(connected) => connected.set_nonblocking(nonblocking),
|
||||
State::Listen(listen) => listen.set_nonblocking(nonblocking),
|
||||
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot connect")
|
||||
};
|
||||
|
||||
let connecting_stream = match init_stream.connect(remote_endpoint) {
|
||||
Ok(connecting_stream) => connecting_stream,
|
||||
Err((err, init_stream)) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
connecting_stream.init_pollee(&self.pollee);
|
||||
*state = State::Connecting(connecting_stream);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn finish_connect(&self) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
let State::Connecting(connecting_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
debug_assert!(false, "the socket unexpectedly left the connecting state");
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connecting");
|
||||
};
|
||||
|
||||
let connected_stream = match connecting_stream.into_result() {
|
||||
Ok(connected_stream) => connected_stream,
|
||||
Err((err, NonConnectedStream::Init(init_stream))) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
}
|
||||
Err((err, NonConnectedStream::Connecting(connecting_stream))) => {
|
||||
*state = State::Connecting(connecting_stream);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
connected_stream.init_pollee(&self.pollee);
|
||||
*state = State::Connected(connected_stream);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let state = self.state.read();
|
||||
|
||||
let State::Listen(listen_stream) = &*state else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not listening");
|
||||
};
|
||||
|
||||
let connected_stream = listen_stream.try_accept()?;
|
||||
listen_stream.update_io_events(&self.pollee);
|
||||
|
||||
let remote_endpoint = connected_stream.remote_endpoint();
|
||||
let accepted_socket = Self::new_connected(connected_stream);
|
||||
Ok((accepted_socket, remote_endpoint.try_into()?))
|
||||
}
|
||||
|
||||
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let state = self.state.read();
|
||||
|
||||
let State::Connected(connected_stream) = &*state else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
};
|
||||
let recv_bytes = connected_stream.try_recvfrom(buf, flags)?;
|
||||
connected_stream.update_io_events(&self.pollee);
|
||||
Ok((recv_bytes, connected_stream.remote_endpoint().try_into()?))
|
||||
}
|
||||
|
||||
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
let state = self.state.read();
|
||||
|
||||
let State::Connected(connected_stream) = &*state else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not connected");
|
||||
};
|
||||
let sent_bytes = connected_stream.try_sendto(buf, flags)?;
|
||||
connected_stream.update_io_events(&self.pollee);
|
||||
Ok(sent_bytes)
|
||||
}
|
||||
|
||||
// TODO: Support timeout
|
||||
fn wait_events<F, R>(&self, mask: IoEvents, mut cond: F) -> Result<R>
|
||||
where
|
||||
F: FnMut() -> Result<R>,
|
||||
{
|
||||
let poller = Poller::new();
|
||||
|
||||
loop {
|
||||
match cond() {
|
||||
Err(err) if err.error() == Errno::EAGAIN => (),
|
||||
result => return result,
|
||||
};
|
||||
|
||||
let events = self.poll(mask, Some(&poller));
|
||||
if !events.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
poller.wait()?;
|
||||
}
|
||||
}
|
||||
|
||||
fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
|
||||
let mut state = self.state.write();
|
||||
let init_stream = match &*state {
|
||||
State::Init(init_stream) => init_stream,
|
||||
State::Listen(_) | State::Connecting(_) | State::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot connect")
|
||||
fn update_io_events(&self) {
|
||||
let state = self.state.read();
|
||||
match &*state {
|
||||
State::Init(_) | State::Poisoned => (),
|
||||
State::Connecting(connecting_stream) => {
|
||||
connecting_stream.update_io_events(&self.pollee)
|
||||
}
|
||||
};
|
||||
|
||||
let connecting = init_stream.connect(remote_endpoint)?;
|
||||
*state = State::Connecting(connecting.clone());
|
||||
Ok(connecting)
|
||||
State::Listen(listen_stream) => listen_stream.update_io_events(&self.pollee),
|
||||
State::Connected(connected_stream) => connected_stream.update_io_events(&self.pollee),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,13 +250,7 @@ impl FileLike for StreamSocket {
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
let state = self.state.read();
|
||||
match &*state {
|
||||
State::Init(init) => init.poll(mask, poller),
|
||||
State::Connecting(connecting) => connecting.poll(mask, poller),
|
||||
State::Connected(connected) => connected.poll(mask, poller),
|
||||
State::Listen(listen) => listen.poll(mask, poller),
|
||||
}
|
||||
self.pollee.poll(mask, poller)
|
||||
}
|
||||
|
||||
fn status_flags(&self) -> StatusFlags {
|
||||
@ -157,68 +278,65 @@ impl FileLike for StreamSocket {
|
||||
impl Socket for StreamSocket {
|
||||
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
let endpoint = socket_addr.try_into()?;
|
||||
let state = self.state.read();
|
||||
match &*state {
|
||||
State::Init(init_stream) => init_stream.bind(endpoint),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "cannot bind"),
|
||||
}
|
||||
|
||||
let mut state = self.state.write();
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot bind");
|
||||
};
|
||||
|
||||
let bound_socket = match init_stream.bind(&endpoint) {
|
||||
Ok(bound_socket) => bound_socket,
|
||||
Err((err, init_stream)) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
*state = State::Init(InitStream::new_bound(bound_socket));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO: Support nonblocking mode
|
||||
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
let remote_endpoint = socket_addr.try_into()?;
|
||||
self.start_connect(&remote_endpoint)?;
|
||||
|
||||
let connecting_stream = self.do_connect(&remote_endpoint)?;
|
||||
match connecting_stream.wait_conn() {
|
||||
Ok(connected_stream) => {
|
||||
*self.state.write() = State::Connected(connected_stream);
|
||||
Ok(())
|
||||
}
|
||||
Err((err, init_stream)) => {
|
||||
*self.state.write() = State::Init(init_stream);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
poll_ifaces();
|
||||
self.wait_events(IoEvents::OUT, || self.finish_connect())
|
||||
}
|
||||
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
let init_stream = match &*state {
|
||||
State::Init(init_stream) => init_stream,
|
||||
State::Connecting(connecting_stream) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream")
|
||||
}
|
||||
State::Listen(listen_stream) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream")
|
||||
}
|
||||
State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
|
||||
|
||||
let owned_state = mem::replace(&mut *state, State::Poisoned);
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
*state = owned_state;
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen");
|
||||
};
|
||||
|
||||
let listener = init_stream.listen(backlog)?;
|
||||
*state = State::Listen(listener);
|
||||
let listen_stream = match init_stream.listen(backlog) {
|
||||
Ok(listen_stream) => listen_stream,
|
||||
Err((err, init_stream)) => {
|
||||
*state = State::Init(init_stream);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
listen_stream.init_pollee(&self.pollee);
|
||||
*state = State::Listen(listen_stream);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let listen_stream = match &*self.state.read() {
|
||||
State::Listen(listen_stream) => listen_stream.clone(),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
|
||||
};
|
||||
|
||||
let (connected_stream, remote_endpoint) = {
|
||||
let listen_stream = listen_stream.clone();
|
||||
listen_stream.accept()?
|
||||
};
|
||||
|
||||
let accepted_socket = {
|
||||
let state = RwLock::new(State::Connected(connected_stream));
|
||||
Arc::new(StreamSocket {
|
||||
options: RwLock::new(OptionSet::new()),
|
||||
state,
|
||||
})
|
||||
};
|
||||
|
||||
let socket_addr = remote_endpoint.try_into()?;
|
||||
Ok((accepted_socket, socket_addr))
|
||||
poll_ifaces();
|
||||
if self.is_nonblocking() {
|
||||
self.try_accept()
|
||||
} else {
|
||||
self.wait_events(IoEvents::IN, || self.try_accept())
|
||||
}
|
||||
}
|
||||
|
||||
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
@ -233,11 +351,12 @@ impl Socket for StreamSocket {
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
let state = self.state.read();
|
||||
let local_endpoint = match &*state {
|
||||
State::Init(init_stream) => init_stream.local_endpoint(),
|
||||
State::Init(init_stream) => init_stream.local_endpoint()?,
|
||||
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
|
||||
State::Listen(listen_stream) => listen_stream.local_endpoint(),
|
||||
State::Connected(connected_stream) => connected_stream.local_endpoint(),
|
||||
}?;
|
||||
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
|
||||
};
|
||||
local_endpoint.try_into()
|
||||
}
|
||||
|
||||
@ -252,19 +371,20 @@ impl Socket for StreamSocket {
|
||||
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
|
||||
}
|
||||
State::Connected(connected_stream) => connected_stream.remote_endpoint(),
|
||||
}?;
|
||||
State::Poisoned => return_errno_with_message!(Errno::EINVAL, "socket is poisoned"),
|
||||
};
|
||||
remote_endpoint.try_into()
|
||||
}
|
||||
|
||||
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
let connected_stream = match &*self.state.read() {
|
||||
State::Connected(connected_stream) => connected_stream.clone(),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
|
||||
};
|
||||
debug_assert!(flags.is_all_supported());
|
||||
|
||||
let (recv_size, remote_endpoint) = connected_stream.recvfrom(buf, flags)?;
|
||||
let socket_addr = remote_endpoint.try_into()?;
|
||||
Ok((recv_size, socket_addr))
|
||||
poll_ifaces();
|
||||
if self.is_nonblocking() {
|
||||
self.try_recvfrom(buf, flags)
|
||||
} else {
|
||||
self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags))
|
||||
}
|
||||
}
|
||||
|
||||
fn sendto(
|
||||
@ -273,16 +393,19 @@ impl Socket for StreamSocket {
|
||||
remote: Option<SocketAddr>,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
debug_assert!(remote.is_none());
|
||||
debug_assert!(flags.is_all_supported());
|
||||
|
||||
if remote.is_some() {
|
||||
return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr");
|
||||
}
|
||||
|
||||
let connected_stream = match &*self.state.read() {
|
||||
State::Connected(connected_stream) => connected_stream.clone(),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
|
||||
let sent_bytes = if self.is_nonblocking() {
|
||||
self.try_sendto(buf, flags)?
|
||||
} else {
|
||||
self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags))?
|
||||
};
|
||||
connected_stream.sendto(buf, flags)
|
||||
poll_ifaces();
|
||||
Ok(sent_bytes)
|
||||
}
|
||||
|
||||
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
|
||||
@ -324,7 +447,9 @@ impl Socket for StreamSocket {
|
||||
|
||||
// FIXME: how to get the current MSS?
|
||||
let maxseg = match &*self.state.read() {
|
||||
State::Init(_) | State::Listen(_) | State::Connecting(_) => DEFAULT_MAXSEG,
|
||||
State::Init(_) | State::Listen(_) | State::Connecting(_) | State::Poisoned => {
|
||||
DEFAULT_MAXSEG
|
||||
}
|
||||
State::Connected(_) => options.tcp.maxseg(),
|
||||
};
|
||||
tcp_maxseg.set(maxseg);
|
||||
@ -348,7 +473,7 @@ impl Socket for StreamSocket {
|
||||
let recv_buf = socket_recv_buf.get().unwrap();
|
||||
if *recv_buf <= MIN_RECVBUF {
|
||||
options.socket.set_recv_buf(MIN_RECVBUF);
|
||||
} else{
|
||||
} else {
|
||||
options.socket.set_recv_buf(*recv_buf);
|
||||
}
|
||||
},
|
||||
@ -406,3 +531,9 @@ impl Socket for StreamSocket {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer<()> for StreamSocket {
|
||||
fn on_events(&self, events: &()) {
|
||||
self.update_io_events();
|
||||
}
|
||||
}
|
||||
|
@ -31,12 +31,12 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32) -> Result<SyscallRetur
|
||||
CSocketAddrFamily::AF_INET,
|
||||
SockType::SOCK_STREAM,
|
||||
Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP,
|
||||
) => Arc::new(StreamSocket::new(nonblocking)) as Arc<dyn FileLike>,
|
||||
) => StreamSocket::new(nonblocking) as Arc<dyn FileLike>,
|
||||
(
|
||||
CSocketAddrFamily::AF_INET,
|
||||
SockType::SOCK_DGRAM,
|
||||
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP,
|
||||
) => Arc::new(DatagramSocket::new(nonblocking)) as Arc<dyn FileLike>,
|
||||
) => DatagramSocket::new(nonblocking) as Arc<dyn FileLike>,
|
||||
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
|
||||
};
|
||||
let fd = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user