Implement iface event observers and move Pollee to them

Finally, this commit implements an iface event observer trait for the
`ConnectingStream`, `ListenStream`, and `ConnectedStream` states in
`StreamSocket`, as well as the `BoundDatagram` state in
`DatagramSocket`. It also moves the `Pollee` from `AnyBoundSocket` to
these observer implementors.

What I have tried to do is minimize the semantic changes. Ideally, this
commit should be a pure refactor commit, meaning that even if the
sematics of the previous code is wrong, the sematics after this commit
should be wrong in the same way. Fixing the wrong sematics should be
done in a separate commit afterwards.

However, keeping exactly the same sematics for `ListenStream` is hard.
We used to maintain a `Pollee` for each `BacklogSocket`, but now we can
just maintain one `Pollee` for the whole `ListenStream`. However,
implementing the correct semantics looks much easier, so we just do it.

For `ConnectingStream`, it used to share the same `Pollee` logic with
`ConnectedStream` (because the `Pollee` was maintained by
`AnyBoundSocket`, which is used by both). Now we write the `Pollee`
logic separately for `ConnectingStream`, so we can just write a correct
one or try to reuse the logic in `ConnectingStream`. This commit does
the former.

There should be no semantic changes for `ConnectedStream` in
`StreamSocket` and `BoundDatagram` in `DatagramSocket`.
This commit is contained in:
Ruihan Li
2023-11-19 23:54:58 +08:00
committed by Tate, Hongliang Tian
parent 6b903d0c10
commit 0cc9c5fb3a
10 changed files with 349 additions and 179 deletions

View File

@ -1,6 +1,14 @@
/// A trait to represent any events.
///
/// # The unit event
///
/// The unit type `()` can serve as a unit event.
/// It can be used if there is only one kind of event
/// and the event carries no additional information.
pub trait Events: Copy + Clone + Send + Sync + 'static {}
impl Events for () {}
/// A trait to filter events.
///
/// # The no-op event filter

View File

@ -5,7 +5,31 @@ use super::Events;
/// In a sense, event observers are just a fancy form of callback functions.
/// An observer's `on_events` methods are supposed to be called when
/// some events that are interesting to the observer happen.
///
/// # The no-op observer
///
/// The unit type `()` can serve as a no-op observer.
/// It implements `Observer<E>` for any events type `E`,
/// with an `on_events` method that simply does nothing.
///
/// It can be used to create an empty `Weak`, as shown in the example below.
/// Using the unit type is necessary, as creating an empty `Weak` needs to
/// have a sized type (e.g. the unit type).
///
/// # Examples
///
/// ```
/// use alloc::sync::Weak;
/// use crate::events::Observer;
///
/// let empty: Weak<dyn Observer<()>> = Weak::<()>::new();
/// assert!(empty.upgrade().is_empty());
/// ```
pub trait Observer<E: Events>: Send + Sync {
/// Notify the observer that some interesting events happen.
fn on_events(&self, events: &E);
}
impl<E: Events> Observer<E> for () {
fn on_events(&self, events: &E) {}
}

View File

@ -1,6 +1,5 @@
use crate::events::IoEvents;
use crate::events::Observer;
use crate::prelude::*;
use crate::process::signal::{Pollee, Poller};
use super::Iface;
use super::{IpAddress, IpEndpoint};
@ -11,7 +10,6 @@ pub type RawSocketHandle = smoltcp::iface::SocketHandle;
pub struct AnyUnboundSocket {
socket_family: AnyRawSocket,
pollee: Pollee,
}
#[allow(clippy::large_enum_variant)]
@ -32,10 +30,8 @@ impl AnyUnboundSocket {
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]);
RawTcpSocket::new(rx_buffer, tx_buffer)
};
let pollee = Pollee::new(IoEvents::empty());
AnyUnboundSocket {
socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
pollee,
}
}
@ -54,7 +50,6 @@ impl AnyUnboundSocket {
};
AnyUnboundSocket {
socket_family: AnyRawSocket::Udp(raw_udp_socket),
pollee: Pollee::new(IoEvents::empty()),
}
}
@ -68,22 +63,14 @@ impl AnyUnboundSocket {
AnyRawSocket::Udp(_) => SocketFamily::Udp,
}
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub(super) fn pollee(&self) -> Pollee {
self.pollee.clone()
}
}
pub struct AnyBoundSocket {
iface: Arc<dyn Iface>,
handle: smoltcp::iface::SocketHandle,
port: u16,
pollee: Pollee,
socket_family: SocketFamily,
observer: RwLock<Weak<dyn Observer<()>>>,
weak_self: Weak<Self>,
}
@ -92,19 +79,36 @@ impl AnyBoundSocket {
iface: Arc<dyn Iface>,
handle: smoltcp::iface::SocketHandle,
port: u16,
pollee: Pollee,
socket_family: SocketFamily,
) -> Arc<Self> {
Arc::new_cyclic(|weak_self| Self {
iface,
handle,
port,
pollee,
socket_family,
observer: RwLock::new(Weak::<()>::new()),
weak_self: weak_self.clone(),
})
}
pub(super) fn on_iface_events(&self) {
if let Some(observer) = Weak::upgrade(&*self.observer.read()) {
observer.on_events(&())
}
}
/// Set the observer whose `on_events` will be called when certain iface events happen. After
/// setting, the new observer will fire once immediately to avoid missing any events.
///
/// If there is an existing observer, due to race conditions, this function does not guarentee
/// that the old observer will never be called after the setting. Users should be aware of this
/// and proactively handle the race conditions if necessary.
pub fn set_observer(&self, handler: Weak<dyn Observer<()>>) {
*self.observer.write() = handler;
self.on_iface_events();
}
pub fn local_endpoint(&self) -> Option<IpEndpoint> {
let ip_addr = {
let ipv4_addr = self.iface.ipv4_addr()?;
@ -135,30 +139,10 @@ impl AnyBoundSocket {
Ok(())
}
pub fn update_socket_state(&self) {
let handle = &self.handle;
let pollee = &self.pollee;
let sockets = self.iface().sockets();
match self.socket_family {
SocketFamily::Tcp => {
let socket = sockets.get::<RawTcpSocket>(*handle);
update_tcp_socket_state(socket, pollee);
}
SocketFamily::Udp => {
let udp_socket = sockets.get::<RawUdpSocket>(*handle);
update_udp_socket_state(udp_socket, pollee);
}
}
}
pub fn iface(&self) -> &Arc<dyn Iface> {
&self.iface
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
pub(super) fn weak_ref(&self) -> Weak<Self> {
self.weak_self.clone()
}
@ -181,34 +165,6 @@ impl Drop for AnyBoundSocket {
}
}
fn update_tcp_socket_state(socket: &RawTcpSocket, pollee: &Pollee) {
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
pollee.del_events(IoEvents::IN);
}
if socket.can_send() {
pollee.add_events(IoEvents::OUT);
} else {
pollee.del_events(IoEvents::OUT);
}
}
fn update_udp_socket_state(socket: &RawUdpSocket, pollee: &Pollee) {
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
pollee.del_events(IoEvents::IN);
}
if socket.can_send() {
pollee.add_events(IoEvents::OUT);
} else {
pollee.del_events(IoEvents::OUT);
}
}
// For TCP
const RECV_BUF_LEN: usize = 65536;
const SEND_BUF_LEN: usize = 65536;

View File

@ -113,13 +113,12 @@ impl IfaceCommon {
return Err((e, socket));
}
let socket_family = socket.socket_family();
let pollee = socket.pollee();
let mut sockets = self.sockets.lock_irq_disabled();
let handle = match socket.raw_socket_family() {
AnyRawSocket::Tcp(tcp_socket) => sockets.add(tcp_socket),
AnyRawSocket::Udp(udp_socket) => sockets.add(udp_socket),
};
let bound_socket = AnyBoundSocket::new(iface, handle, port, pollee, socket_family);
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family);
self.insert_bound_socket(&bound_socket).unwrap();
Ok(bound_socket)
}
@ -140,7 +139,7 @@ impl IfaceCommon {
if has_events {
self.bound_sockets.read().iter().for_each(|bound_socket| {
if let Some(bound_socket) = bound_socket.upgrade() {
bound_socket.update_socket_state();
bound_socket.on_iface_events();
}
});
}

View File

@ -1,10 +1,10 @@
use core::sync::atomic::{AtomicBool, Ordering};
use crate::events::IoEvents;
use crate::events::{IoEvents, Observer};
use crate::fs::utils::StatusFlags;
use crate::net::iface::IpEndpoint;
use crate::process::signal::Poller;
use crate::process::signal::{Pollee, Poller};
use crate::{
fs::file_handle::FileLike,
net::{
@ -27,21 +27,63 @@ pub struct DatagramSocket {
}
enum Inner {
Unbound(AlwaysSome<Box<AnyUnboundSocket>>),
Unbound(AlwaysSome<UnboundDatagram>),
Bound(Arc<BoundDatagram>),
}
struct UnboundDatagram {
unbound_socket: Box<AnyUnboundSocket>,
pollee: Pollee,
}
impl UnboundDatagram {
fn new() -> Self {
Self {
unbound_socket: Box::new(AnyUnboundSocket::new_udp()),
pollee: Pollee::new(IoEvents::empty()),
}
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.pollee.poll(mask, poller)
}
fn bind(self, endpoint: IpEndpoint) -> core::result::Result<Arc<BoundDatagram>, (Error, Self)> {
let bound_socket = match bind_socket(self.unbound_socket, endpoint, false) {
Ok(bound_socket) => bound_socket,
Err((err, unbound_socket)) => {
return Err((
err,
Self {
unbound_socket,
pollee: self.pollee,
},
))
}
};
let bound_endpoint = bound_socket.local_endpoint().unwrap();
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket.bind(bound_endpoint).unwrap();
});
Ok(BoundDatagram::new(bound_socket, self.pollee))
}
}
struct BoundDatagram {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: RwLock<Option<IpEndpoint>>,
pollee: Pollee,
}
impl BoundDatagram {
fn new(bound_socket: Arc<AnyBoundSocket>) -> Arc<Self> {
Arc::new(Self {
fn new(bound_socket: Arc<AnyBoundSocket>, pollee: Pollee) -> Arc<Self> {
let bound = Arc::new(Self {
bound_socket,
remote_endpoint: RwLock::new(None),
})
pollee,
});
bound.bound_socket.set_observer(Arc::downgrade(&bound) as _);
bound
}
fn remote_endpoint(&self) -> Result<IpEndpoint> {
@ -94,11 +136,31 @@ impl BoundDatagram {
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.bound_socket.poll(mask, poller)
self.pollee.poll(mask, poller)
}
fn update_socket_state(&self) {
self.bound_socket.update_socket_state();
fn update_io_events(&self) {
self.bound_socket.raw_with(|socket: &mut RawUdpSocket| {
let pollee = &self.pollee;
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
pollee.del_events(IoEvents::IN);
}
if socket.can_send() {
pollee.add_events(IoEvents::OUT);
} else {
pollee.del_events(IoEvents::OUT);
}
});
}
}
impl Observer<()> for BoundDatagram {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}
@ -108,25 +170,15 @@ impl Inner {
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<Arc<BoundDatagram>> {
if self.is_bound() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
}
let unbound_socket = match self {
Inner::Unbound(unbound_socket) => unbound_socket,
_ => unreachable!(),
let unbound = match self {
Inner::Unbound(unbound) => unbound,
Inner::Bound(..) => return_errno_with_message!(
Errno::EINVAL,
"the socket is already bound to an address"
),
};
let bound_socket =
unbound_socket.try_take_with(|socket| bind_socket(socket, endpoint, false))?;
let bound_endpoint = bound_socket.local_endpoint().unwrap();
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket
.bind(bound_endpoint)
.map_err(|_| Error::with_message(Errno::EINVAL, "cannot bind socket"))
})?;
let bound = BoundDatagram::new(bound_socket);
let bound = unbound.try_take_with(|unbound| unbound.bind(endpoint))?;
*self = Inner::Bound(bound.clone());
// Once the socket is bound, we should update the socket state at once.
bound.update_socket_state();
Ok(bound)
}
@ -140,7 +192,7 @@ impl Inner {
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
match self {
Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller),
Inner::Unbound(unbound) => unbound.poll(mask, poller),
Inner::Bound(bound) => bound.poll(mask, poller),
}
}
@ -148,9 +200,9 @@ impl Inner {
impl DatagramSocket {
pub fn new(nonblocking: bool) -> Self {
let udp_socket = Box::new(AnyUnboundSocket::new_udp());
let unbound = UnboundDatagram::new();
Self {
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(udp_socket))),
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(unbound))),
nonblocking: AtomicBool::new(nonblocking),
}
}

View File

@ -1,8 +1,8 @@
use core::sync::atomic::{AtomicBool, Ordering};
use crate::events::IoEvents;
use crate::events::{IoEvents, Observer};
use crate::net::iface::IpEndpoint;
use crate::process::signal::Poller;
use crate::process::signal::{Pollee, Poller};
use crate::{
net::{
iface::{AnyBoundSocket, RawTcpSocket},
@ -16,6 +16,7 @@ pub struct ConnectedStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
pollee: Pollee,
}
impl ConnectedStream {
@ -23,12 +24,18 @@ impl ConnectedStream {
is_nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
) -> Self {
Self {
pollee: Pollee,
) -> Arc<Self> {
let connected = Arc::new(Self {
nonblocking: AtomicBool::new(is_nonblocking),
bound_socket,
remote_endpoint,
}
pollee,
});
connected
.bound_socket
.set_observer(Arc::downgrade(&connected) as _);
connected
}
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
@ -50,7 +57,7 @@ impl ConnectedStream {
let remote_endpoint = self.remote_endpoint()?;
return Ok((recv_len, remote_endpoint));
}
let events = self.bound_socket.poll(IoEvents::IN, Some(&poller));
let events = self.poll(IoEvents::IN, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOTCONN, "recv packet fails");
}
@ -71,7 +78,7 @@ impl ConnectedStream {
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))
});
self.bound_socket.update_socket_state();
self.update_io_events();
res
}
@ -84,7 +91,7 @@ impl ConnectedStream {
if sent_len > 0 {
return Ok(sent_len);
}
let events = self.bound_socket.poll(IoEvents::OUT, Some(&poller));
let events = self.poll(IoEvents::OUT, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOBUFS, "fail to send packets");
}
@ -104,10 +111,10 @@ impl ConnectedStream {
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf))
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"));
match res {
// We have to explicitly invoke `update_socket_state` when the send buffer becomes
// We have to explicitly invoke `update_io_events` when the send buffer becomes
// full. Note that smoltcp does not think it is an interface event, so calling
// `poll_ifaces` alone is not enough.
Ok(0) => self.bound_socket.update_socket_state(),
Ok(0) => self.update_io_events(),
Ok(_) => poll_ifaces(),
_ => (),
};
@ -125,7 +132,25 @@ impl ConnectedStream {
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.bound_socket.poll(mask, poller)
self.pollee.poll(mask, poller)
}
fn update_io_events(&self) {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
let pollee = &self.pollee;
if socket.can_recv() {
pollee.add_events(IoEvents::IN);
} else {
pollee.del_events(IoEvents::IN);
}
if socket.can_send() {
pollee.add_events(IoEvents::OUT);
} else {
pollee.del_events(IoEvents::OUT);
}
});
}
pub fn is_nonblocking(&self) -> bool {
@ -136,3 +161,9 @@ impl ConnectedStream {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Observer<()> for ConnectedStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -2,12 +2,13 @@ use core::sync::atomic::{AtomicBool, Ordering};
use alloc::sync::Arc;
use crate::events::IoEvents;
use crate::events::{IoEvents, Observer};
use crate::net::iface::RawTcpSocket;
use crate::net::poll_ifaces;
use crate::prelude::*;
use crate::net::iface::{AnyBoundSocket, IpEndpoint};
use crate::process::signal::Poller;
use crate::process::signal::{Pollee, Poller};
use super::connected::ConnectedStream;
use super::init::InitStream;
@ -16,6 +17,13 @@ pub struct ConnectingStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
conn_result: RwLock<Option<ConnResult>>,
pollee: Pollee,
}
enum ConnResult {
Connected,
Refused,
}
impl ConnectingStream {
@ -23,36 +31,57 @@ impl ConnectingStream {
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
) -> Result<Self> {
pollee: Pollee,
) -> Result<Arc<Self>> {
bound_socket.do_connect(remote_endpoint)?;
Ok(Self {
let connecting = Arc::new(Self {
nonblocking: AtomicBool::new(nonblocking),
bound_socket,
remote_endpoint,
})
conn_result: RwLock::new(None),
pollee,
});
connecting.pollee.reset_events();
connecting
.bound_socket
.set_observer(Arc::downgrade(&connecting) as _);
Ok(connecting)
}
pub fn wait_conn(&self) -> core::result::Result<ConnectedStream, (Error, InitStream)> {
pub fn wait_conn(
&self,
) -> core::result::Result<Arc<ConnectedStream>, (Error, Arc<InitStream>)> {
debug_assert!(!self.is_nonblocking());
let poller = Poller::new();
loop {
poll_ifaces();
let events = self.poll(IoEvents::OUT | IoEvents::IN, Some(&poller));
if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) {
return Ok(ConnectedStream::new(
self.is_nonblocking(),
self.bound_socket.clone(),
self.remote_endpoint,
));
} else if !events.is_empty() {
return Err((
Error::with_message(Errno::ECONNREFUSED, "connection refused"),
InitStream::new_bound(self.is_nonblocking(), self.bound_socket.clone()),
));
} else {
match *self.conn_result.read() {
Some(ConnResult::Connected) => {
return Ok(ConnectedStream::new(
self.is_nonblocking(),
self.bound_socket.clone(),
self.remote_endpoint,
self.pollee.clone(),
));
}
Some(ConnResult::Refused) => {
return Err((
Error::with_message(Errno::ECONNREFUSED, "connection refused"),
InitStream::new_bound(
self.is_nonblocking(),
self.bound_socket.clone(),
self.pollee.clone(),
),
));
}
None => (),
};
let events = self.poll(IoEvents::OUT, Some(&poller));
if !events.contains(IoEvents::OUT) {
// FIXME: deal with nonblocking mode & connecting timeout
poller.wait().expect("async connect() not implemented");
}
@ -70,7 +99,7 @@ impl ConnectingStream {
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.bound_socket.poll(mask, poller)
self.pollee.poll(mask, poller)
}
pub fn is_nonblocking(&self) -> bool {
@ -80,4 +109,47 @@ impl ConnectingStream {
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
fn update_io_events(&self) {
if self.conn_result.read().is_some() {
return;
}
let became_writable = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
let mut result = self.conn_result.write();
if result.is_some() {
return false;
}
// Connected
if socket.can_send() {
*result = Some(ConnResult::Connected);
return true;
}
// Connecting
if socket.is_open() {
return false;
}
// Refused
*result = Some(ConnResult::Refused);
true
});
// Either when the connection is established, or when the connection fails, the socket
// shall indicate that it is writable.
//
// TODO: Find a way to turn `ConnectingStream` into `ConnectedStream` or `InitStream`
// here, so non-blocking `connect()` can work correctly. Meanwhile, the latter should
// be responsible to initialize all the I/O events including `IoEvents::OUT`, so the
// following hard-coded event addition can be removed.
if became_writable {
self.pollee.add_events(IoEvents::OUT);
}
}
}
impl Observer<()> for ConnectingStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}

View File

@ -7,6 +7,7 @@ use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket};
use crate::net::socket::ip::always_some::AlwaysSome;
use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint};
use crate::prelude::*;
use crate::process::signal::Pollee;
use crate::process::signal::Poller;
use super::connecting::ConnectingStream;
@ -15,6 +16,7 @@ use super::listen::ListenStream;
pub struct InitStream {
inner: RwLock<Inner>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
enum Inner {
@ -23,6 +25,11 @@ enum Inner {
}
impl Inner {
fn new() -> Inner {
let unbound_socket = Box::new(AnyUnboundSocket::new_tcp());
Inner::Unbound(AlwaysSome::new(unbound_socket))
}
fn is_bound(&self) -> bool {
match self {
Self::Unbound(_) => false,
@ -38,7 +45,6 @@ impl Inner {
};
let bound_socket =
unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?;
bound_socket.update_socket_state();
*self = Inner::Bound(AlwaysSome::new(bound_socket));
Ok(())
}
@ -55,13 +61,6 @@ impl Inner {
}
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
match self {
Inner::Bound(bound_socket) => bound_socket.poll(mask, poller),
Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller),
}
}
fn iface(&self) -> Option<Arc<dyn Iface>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
@ -76,28 +75,36 @@ impl Inner {
}
impl InitStream {
pub fn new(nonblocking: bool) -> Self {
let socket = Box::new(AnyUnboundSocket::new_tcp());
let inner = Inner::Unbound(AlwaysSome::new(socket));
Self {
// FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling
// `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by
// experimentation and Linux source code.
pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new(Self {
inner: RwLock::new(Inner::new()),
is_nonblocking: AtomicBool::new(nonblocking),
inner: RwLock::new(inner),
}
pollee: Pollee::new(IoEvents::empty()),
})
}
pub fn new_bound(nonblocking: bool, bound_socket: Arc<AnyBoundSocket>) -> Self {
pub fn new_bound(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
pollee: Pollee,
) -> Arc<Self> {
bound_socket.set_observer(Weak::<()>::new());
let inner = Inner::Bound(AlwaysSome::new(bound_socket));
Self {
Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
inner: RwLock::new(inner),
}
pollee,
})
}
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> {
self.inner.write().bind(endpoint)
}
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<ConnectingStream> {
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
if !self.inner.read().is_bound() {
self.inner
.write()
@ -107,16 +114,22 @@ impl InitStream {
self.is_nonblocking(),
self.inner.read().bound_socket().unwrap().clone(),
*remote_endpoint,
self.pollee.clone(),
)
}
pub fn listen(&self, backlog: usize) -> Result<ListenStream> {
pub fn listen(&self, backlog: usize) -> Result<Arc<ListenStream>> {
let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() {
bound_socket.clone()
} else {
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound")
};
ListenStream::new(self.is_nonblocking(), bound_socket, backlog)
ListenStream::new(
self.is_nonblocking(),
bound_socket,
backlog,
self.pollee.clone(),
)
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
@ -127,7 +140,7 @@ impl InitStream {
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.inner.read().poll(mask, poller)
self.pollee.poll(mask, poller)
}
pub fn is_nonblocking(&self) -> bool {

View File

@ -1,10 +1,10 @@
use core::sync::atomic::{AtomicBool, Ordering};
use crate::events::IoEvents;
use crate::events::{IoEvents, Observer};
use crate::net::iface::{AnyUnboundSocket, BindPortConfig, IpEndpoint};
use crate::net::iface::{AnyBoundSocket, RawTcpSocket};
use crate::process::signal::Poller;
use crate::process::signal::{Pollee, Poller};
use crate::{net::poll_ifaces, prelude::*};
use super::connected::ConnectedStream;
@ -16,6 +16,7 @@ pub struct ListenStream {
bound_socket: Arc<AnyBoundSocket>,
/// Backlog sockets listening at the local endpoint
backlog_sockets: RwLock<Vec<BacklogSocket>>,
pollee: Pollee,
}
impl ListenStream {
@ -23,18 +24,24 @@ impl ListenStream {
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
backlog: usize,
) -> Result<Self> {
let listen_stream = Self {
pollee: Pollee,
) -> Result<Arc<Self>> {
let listen_stream = Arc::new(Self {
is_nonblocking: AtomicBool::new(nonblocking),
backlog,
bound_socket,
backlog_sockets: RwLock::new(Vec::new()),
};
pollee,
});
listen_stream.fill_backlog_sockets()?;
listen_stream.pollee.reset_events();
listen_stream
.bound_socket
.set_observer(Arc::downgrade(&listen_stream) as _);
Ok(listen_stream)
}
pub fn accept(&self) -> Result<(ConnectedStream, IpEndpoint)> {
pub fn accept(&self) -> Result<(Arc<ConnectedStream>, IpEndpoint)> {
// wait to accept
let poller = Poller::new();
loop {
@ -42,8 +49,8 @@ impl ListenStream {
let accepted_socket = if let Some(accepted_socket) = self.try_accept() {
accepted_socket
} else {
let events = self.poll(IoEvents::IN | IoEvents::OUT, Some(&poller));
if !events.contains(IoEvents::IN) && !events.contains(IoEvents::OUT) {
let events = self.poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try accept again");
}
@ -57,7 +64,12 @@ impl ListenStream {
let BacklogSocket {
bound_socket: backlog_socket,
} = accepted_socket;
ConnectedStream::new(false, backlog_socket, remote_endpoint)
ConnectedStream::new(
false,
backlog_socket,
remote_endpoint,
Pollee::new(IoEvents::empty()),
)
};
return Ok((connected_stream, remote_endpoint));
}
@ -88,6 +100,7 @@ impl ListenStream {
backlog_sockets.remove(index)
};
self.fill_backlog_sockets().unwrap();
self.update_io_events();
Some(backlog_socket)
}
@ -98,22 +111,25 @@ impl ListenStream {
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let backlog_sockets = self.backlog_sockets.read();
for backlog_socket in backlog_sockets.iter() {
if backlog_socket.is_active() {
return IoEvents::IN;
} else {
// regiser poller to the backlog socket
backlog_socket.poll(mask, poller);
}
}
IoEvents::empty()
self.pollee.poll(mask, poller)
}
fn bound_socket(&self) -> Arc<AnyBoundSocket> {
self.backlog_sockets.read()[0].bound_socket.clone()
}
fn update_io_events(&self) {
// The lock should be held to avoid data races
let backlog_sockets = self.backlog_sockets.read();
let can_accept = backlog_sockets.iter().any(|socket| socket.is_active());
if can_accept {
self.pollee.add_events(IoEvents::IN);
} else {
self.pollee.del_events(IoEvents::IN);
}
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
@ -123,6 +139,12 @@ impl ListenStream {
}
}
impl Observer<()> for ListenStream {
fn on_events(&self, _: &()) {
self.update_io_events();
}
}
struct BacklogSocket {
bound_socket: Arc<AnyBoundSocket>,
}
@ -146,7 +168,6 @@ impl BacklogSocket {
.listen(local_endpoint)
.map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen"))
})?;
bound_socket.update_socket_state();
Ok(Self { bound_socket })
}
@ -159,8 +180,4 @@ impl BacklogSocket {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.bound_socket.poll(mask, poller)
}
}

View File

@ -38,7 +38,7 @@ enum State {
impl StreamSocket {
pub fn new(nonblocking: bool) -> Self {
let state = State::Init(Arc::new(InitStream::new(nonblocking)));
let state = State::Init(InitStream::new(nonblocking));
Self {
state: RwLock::new(state),
}
@ -71,7 +71,7 @@ impl StreamSocket {
}
};
let connecting = Arc::new(init_stream.connect(remote_endpoint)?);
let connecting = init_stream.connect(remote_endpoint)?;
*state = State::Connecting(connecting.clone());
Ok(connecting)
}
@ -139,12 +139,10 @@ impl Socket for StreamSocket {
let connecting_stream = self.do_connect(&remote_endpoint)?;
match connecting_stream.wait_conn() {
Ok(connected_stream) => {
let connected_stream = Arc::new(connected_stream);
*self.state.write() = State::Connected(connected_stream);
Ok(())
}
Err((err, init_stream)) => {
let init_stream = Arc::new(init_stream);
*self.state.write() = State::Init(init_stream);
Err(err)
}
@ -164,7 +162,7 @@ impl Socket for StreamSocket {
State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
};
let listener = Arc::new(init_stream.listen(backlog)?);
let listener = init_stream.listen(backlog)?;
*state = State::Listen(listener);
Ok(())
}
@ -181,7 +179,7 @@ impl Socket for StreamSocket {
};
let accepted_socket = {
let state = RwLock::new(State::Connected(Arc::new(connected_stream)));
let state = RwLock::new(State::Connected(connected_stream));
Arc::new(StreamSocket { state })
};