Make Pollee stateless

This commit is contained in:
Ruihan Li
2024-11-13 23:39:55 +08:00
committed by Tate, Hongliang Tian
parent 5450d0bd71
commit fab61f5f66
30 changed files with 514 additions and 430 deletions

1
Cargo.lock generated
View File

@ -70,6 +70,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
name = "aster-bigtcp" name = "aster-bigtcp"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bitflags 1.3.2",
"keyable-arc", "keyable-arc",
"ostd", "ostd",
"smoltcp", "smoltcp",

View File

@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
bitflags = "1.3"
keyable-arc = { path = "../keyable-arc" } keyable-arc = { path = "../keyable-arc" }
ostd = { path = "../../../ostd" } ostd = { path = "../../../ostd" }
smoltcp = { git = "https://github.com/asterinas/smoltcp", tag = "r_2024-11-08_f07e5b5", default-features = false, features = [ smoltcp = { git = "https://github.com/asterinas/smoltcp", tag = "r_2024-11-08_f07e5b5", default-features = false, features = [

View File

@ -220,13 +220,13 @@ impl<E> IfaceCommon<E> {
context.poll_egress(device, dispatch_phy); context.poll_egress(device, dispatch_phy);
tcp_sockets.iter().for_each(|socket| { tcp_sockets.iter().for_each(|socket| {
if socket.has_new_events() { if socket.has_events() {
socket.on_iface_events(); socket.on_events();
} }
}); });
udp_sockets.iter().for_each(|socket| { udp_sockets.iter().for_each(|socket| {
if socket.has_new_events() { if socket.has_events() {
socket.on_iface_events(); socket.on_events();
} }
}); });

View File

@ -6,7 +6,7 @@ use alloc::{
}; };
use core::{ use core::{
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, AtomicU64, Ordering}, sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering},
}; };
use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard}; use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard};
@ -17,7 +17,10 @@ use smoltcp::{
wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr},
}; };
use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket}; use super::{
event::{SocketEventObserver, SocketEvents},
RawTcpSocket, RawUdpSocket, TcpStateCheck,
};
use crate::iface::Iface; use crate::iface::Iface;
pub struct BoundSocket<T: AnySocket, E>(Arc<BoundSocketInner<T, E>>); pub struct BoundSocket<T: AnySocket, E>(Arc<BoundSocketInner<T, E>>);
@ -44,8 +47,8 @@ pub struct BoundSocketInner<T, E> {
port: u16, port: u16,
socket: T, socket: T,
observer: RwLock<Weak<dyn SocketEventObserver>>, observer: RwLock<Weak<dyn SocketEventObserver>>,
events: AtomicU8,
next_poll_at_ms: AtomicU64, next_poll_at_ms: AtomicU64,
has_new_events: AtomicBool,
} }
/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`]. /// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`].
@ -56,6 +59,7 @@ pub struct TcpSocket {
struct RawTcpSocketExt { struct RawTcpSocketExt {
socket: Box<RawTcpSocket>, socket: Box<RawTcpSocket>,
has_connected: bool,
/// Whether the socket is in the background. /// Whether the socket is in the background.
/// ///
/// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means /// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means
@ -79,6 +83,22 @@ impl DerefMut for RawTcpSocketExt {
} }
} }
impl RawTcpSocketExt {
fn on_new_state(&mut self) -> SocketEvents {
if self.may_send() {
self.has_connected = true;
}
if self.is_peer_closed() {
SocketEvents::PEER_CLOSED
} else if self.is_closed() {
SocketEvents::CLOSED
} else {
SocketEvents::empty()
}
}
}
impl TcpSocket { impl TcpSocket {
fn lock(&self) -> SpinLockGuard<RawTcpSocketExt, LocalIrqDisabled> { fn lock(&self) -> SpinLockGuard<RawTcpSocketExt, LocalIrqDisabled> {
self.socket.lock() self.socket.lock()
@ -123,6 +143,7 @@ impl AnySocket for TcpSocket {
fn new(socket: Box<Self::RawSocket>) -> Self { fn new(socket: Box<Self::RawSocket>) -> Self {
let socket_ext = RawTcpSocketExt { let socket_ext = RawTcpSocketExt {
socket, socket,
has_connected: false,
in_background: false, in_background: false,
}; };
@ -184,8 +205,8 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
port, port,
socket: T::new(socket), socket: T::new(socket),
observer: RwLock::new(observer), observer: RwLock::new(observer),
events: AtomicU8::new(0),
next_poll_at_ms: AtomicU64::new(u64::MAX), next_poll_at_ms: AtomicU64::new(u64::MAX),
has_new_events: AtomicBool::new(false),
})) }))
} }
@ -204,7 +225,7 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
pub fn set_observer(&self, new_observer: Weak<dyn SocketEventObserver>) { pub fn set_observer(&self, new_observer: Weak<dyn SocketEventObserver>) {
*self.0.observer.write_irq_disabled() = new_observer; *self.0.observer.write_irq_disabled() = new_observer;
self.0.on_iface_events(); self.0.on_events();
} }
/// Returns the observer. /// Returns the observer.
@ -229,6 +250,12 @@ impl<T: AnySocket, E> BoundSocket<T, E> {
} }
} }
pub enum ConnectState {
Connecting,
Connected,
Refused,
}
impl<E> BoundTcpSocket<E> { impl<E> BoundTcpSocket<E> {
/// Connects to a remote endpoint. /// Connects to a remote endpoint.
pub fn connect( pub fn connect(
@ -240,11 +267,26 @@ impl<E> BoundTcpSocket<E> {
let mut socket = self.0.socket.lock(); let mut socket = self.0.socket.lock();
let result = socket.connect(iface.context(), remote_endpoint, self.0.port); socket.connect(iface.context(), remote_endpoint, self.0.port)?;
socket.has_connected = false;
self.0 self.0
.update_next_poll_at_ms(socket.poll_at(iface.context())); .update_next_poll_at_ms(socket.poll_at(iface.context()));
result Ok(())
}
/// Returns the state of the connecting procedure.
pub fn connect_state(&self) -> ConnectState {
let socket = self.0.socket.lock();
if socket.state() == State::SynSent || socket.state() == State::SynReceived {
ConnectState::Connecting
} else if socket.has_connected {
ConnectState::Connected
} else {
ConnectState::Refused
}
} }
/// Listens at a specified endpoint. /// Listens at a specified endpoint.
@ -366,22 +408,33 @@ impl<E> BoundUdpSocket<E> {
} }
impl<T, E> BoundSocketInner<T, E> { impl<T, E> BoundSocketInner<T, E> {
pub(crate) fn has_new_events(&self) -> bool { pub(crate) fn has_events(&self) -> bool {
self.has_new_events.load(Ordering::Relaxed) self.events.load(Ordering::Relaxed) != 0
} }
pub(crate) fn on_iface_events(&self) { pub(crate) fn on_events(&self) {
self.has_new_events.store(false, Ordering::Relaxed); // This method can only be called to process network events, so we assume we are holding the
// poll lock and no race conditions can occur.
let events = self.events.load(Ordering::Relaxed);
self.events.store(0, Ordering::Relaxed);
// We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we
// get the read lock. // get the read lock.
let observer = Weak::upgrade(&*self.observer.read()); let observer = Weak::upgrade(&*self.observer.read());
if let Some(inner) = observer { if let Some(inner) = observer {
inner.on_events(); inner.on_events(SocketEvents::from_bits_truncate(events));
} }
} }
fn add_events(&self, new_events: SocketEvents) {
// This method can only be called to add network events, so we assume we are holding the
// poll lock and no race conditions can occur.
let events = self.events.load(Ordering::Relaxed);
self.events
.store(events | new_events.bits(), Ordering::Relaxed);
}
/// Returns the next polling time. /// Returns the next polling time.
/// ///
/// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required /// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required
@ -396,8 +449,6 @@ impl<T, E> BoundSocketInner<T, E> {
/// method also marks that there may be new events, so that the event observer provided by /// method also marks that there may be new events, so that the event observer provided by
/// [`BoundSocket::set_observer`] can be notified later. /// [`BoundSocket::set_observer`] can be notified later.
fn update_next_poll_at_ms(&self, poll_at: PollAt) { fn update_next_poll_at_ms(&self, poll_at: PollAt) {
self.has_new_events.store(true, Ordering::Relaxed);
match poll_at { match poll_at {
PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed), PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed),
PollAt::Time(instant) => self PollAt::Time(instant) => self
@ -484,11 +535,21 @@ impl<E> BoundTcpSocketInner<E> {
return TcpProcessResult::NotProcessed; return TcpProcessResult::NotProcessed;
} }
let old_state = socket.state();
// For TCP, receiving an ACK packet can free up space in the queue, allowing more packets
// to be queued.
let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
let result = match socket.process(cx, ip_repr, tcp_repr) { let result = match socket.process(cx, ip_repr, tcp_repr) {
None => TcpProcessResult::Processed, None => TcpProcessResult::Processed,
Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr),
}; };
if socket.state() != old_state {
events |= socket.on_new_state();
}
self.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx)); self.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket); self.socket.update_dead(&socket);
@ -506,6 +567,9 @@ impl<E> BoundTcpSocketInner<E> {
{ {
let mut socket = self.socket.lock(); let mut socket = self.socket.lock();
let old_state = socket.state();
let mut events = SocketEvents::empty();
let mut reply = None; let mut reply = None;
socket socket
.dispatch(cx, |cx, (ip_repr, tcp_repr)| { .dispatch(cx, |cx, (ip_repr, tcp_repr)| {
@ -521,8 +585,14 @@ impl<E> BoundTcpSocketInner<E> {
break; break;
} }
reply = socket.process(cx, ip_repr, tcp_repr); reply = socket.process(cx, ip_repr, tcp_repr);
events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND;
} }
if socket.state() != old_state {
events |= socket.on_new_state();
}
self.add_events(events);
self.update_next_poll_at_ms(socket.poll_at(cx)); self.update_next_poll_at_ms(socket.poll_at(cx));
self.socket.update_dead(&socket); self.socket.update_dead(&socket);
@ -552,6 +622,8 @@ impl<E> BoundUdpSocketInner<E> {
udp_repr, udp_repr,
udp_payload, udp_payload,
); );
self.add_events(SocketEvents::CAN_RECV);
self.update_next_poll_at_ms(socket.poll_at(cx)); self.update_next_poll_at_ms(socket.poll_at(cx));
true true
@ -570,6 +642,9 @@ impl<E> BoundUdpSocketInner<E> {
Ok::<(), ()>(()) Ok::<(), ()>(())
}) })
.unwrap(); .unwrap();
// For UDP, dequeuing a packet means that we can queue more packets.
self.add_events(SocketEvents::CAN_SEND);
self.update_next_poll_at_ms(socket.poll_at(cx)); self.update_next_poll_at_ms(socket.poll_at(cx));
} }
} }

View File

@ -3,9 +3,19 @@
/// A observer that will be invoked whenever events occur on the socket. /// A observer that will be invoked whenever events occur on the socket.
pub trait SocketEventObserver: Send + Sync { pub trait SocketEventObserver: Send + Sync {
/// Notifies that events occurred on the socket. /// Notifies that events occurred on the socket.
fn on_events(&self); fn on_events(&self, events: SocketEvents);
} }
impl SocketEventObserver for () { impl SocketEventObserver for () {
fn on_events(&self) {} fn on_events(&self, _events: SocketEvents) {}
}
bitflags::bitflags! {
/// Socket events caused by the _network_.
pub struct SocketEvents: u8 {
const CAN_RECV = 1;
const CAN_SEND = 2;
const PEER_CLOSED = 4;
const CLOSED = 8;
}
} }

View File

@ -2,12 +2,13 @@
mod bound; mod bound;
mod event; mod event;
mod state;
mod unbound; mod unbound;
pub use bound::{BoundTcpSocket, BoundUdpSocket}; pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState};
pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult};
pub use event::SocketEventObserver; pub use event::{SocketEventObserver, SocketEvents};
pub use smoltcp::socket::tcp::State as TcpState; pub use state::TcpStateCheck;
pub use unbound::{ pub use unbound::{
UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN,
UDP_SEND_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN,

View File

@ -0,0 +1,31 @@
// SPDX-License-Identifier: MPL-2.0
use smoltcp::socket::tcp::State as TcpState;
use super::RawTcpSocket;
pub trait TcpStateCheck {
/// Checks if the peer socket has closed its sending side.
///
/// If the sending side of this socket is also closed, this method will return `false`.
/// In such cases, you should verify using [`is_closed`].
fn is_peer_closed(&self) -> bool;
/// Checks if the socket is fully closed.
///
/// This function returns `true` if both this socket and the peer have closed their sending sides.
///
/// This TCP state corresponds to the `Normal Close Sequence` and `Simultaneous Close Sequence`
/// as outlined in RFC793 (https://datatracker.ietf.org/doc/html/rfc793#page-39).
fn is_closed(&self) -> bool;
}
impl TcpStateCheck for RawTcpSocket {
fn is_peer_closed(&self) -> bool {
self.state() == TcpState::CloseWait
}
fn is_closed(&self) -> bool {
!self.is_open() || self.state() == TcpState::Closing || self.state() == TcpState::LastAck
}
}

View File

@ -48,7 +48,7 @@ impl PtyMaster {
output: ldisc, output: ldisc,
input: SpinLock::new(RingBuffer::new(BUFFER_CAPACITY)), input: SpinLock::new(RingBuffer::new(BUFFER_CAPACITY)),
job_control, job_control,
pollee: Pollee::new(IoEvents::OUT), pollee: Pollee::new(),
weak_self: weak_ref.clone(), weak_self: weak_ref.clone(),
}) })
} }
@ -64,7 +64,7 @@ impl PtyMaster {
pub(super) fn slave_push_char(&self, ch: u8) { pub(super) fn slave_push_char(&self, ch: u8) {
let mut input = self.input.disable_irq().lock(); let mut input = self.input.disable_irq().lock();
input.push_overwrite(ch); input.push_overwrite(ch);
self.update_state(&input); self.pollee.notify(IoEvents::IN);
} }
pub(super) fn slave_poll( pub(super) fn slave_poll(
@ -82,7 +82,9 @@ impl PtyMaster {
let poll_out_mask = mask & IoEvents::OUT; let poll_out_mask = mask & IoEvents::OUT;
if !poll_out_mask.is_empty() { if !poll_out_mask.is_empty() {
let poll_out_status = self.pollee.poll(poll_out_mask, poller); let poll_out_status = self
.pollee
.poll_with(poll_out_mask, poller, || self.check_io_events());
poll_status |= poll_out_status; poll_status |= poll_out_status;
} }
@ -100,17 +102,16 @@ impl PtyMaster {
return_errno_with_message!(Errno::EAGAIN, "the buffer is empty"); return_errno_with_message!(Errno::EAGAIN, "the buffer is empty");
} }
let read_len = input.read_fallible(writer)?; input.read_fallible(writer)
self.update_state(&input);
Ok(read_len)
} }
fn update_state(&self, buf: &RingBuffer<u8>) { fn check_io_events(&self) -> IoEvents {
if buf.is_empty() { let input = self.input.disable_irq().lock();
self.pollee.del_events(IoEvents::IN)
if !input.is_empty() {
IoEvents::IN | IoEvents::OUT
} else { } else {
self.pollee.add_events(IoEvents::IN); IoEvents::OUT
} }
} }
} }
@ -121,7 +122,11 @@ impl Pollable for PtyMaster {
let poll_in_mask = mask & IoEvents::IN; let poll_in_mask = mask & IoEvents::IN;
if !poll_in_mask.is_empty() { if !poll_in_mask.is_empty() {
let poll_in_status = self.pollee.poll(poll_in_mask, poller.as_deref_mut()); let poll_in_status = self
.pollee
.poll_with(poll_in_mask, poller.as_deref_mut(), || {
self.check_io_events()
});
poll_status |= poll_in_status; poll_status |= poll_in_status;
} }
@ -157,7 +162,7 @@ impl FileIo for PtyMaster {
}); });
} }
self.update_state(&input); self.pollee.notify(IoEvents::IN);
Ok(write_len) Ok(write_len)
} }

View File

@ -89,7 +89,8 @@ impl CurrentLine {
impl Pollable for LineDiscipline { impl Pollable for LineDiscipline {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
} }
@ -108,7 +109,7 @@ impl LineDiscipline {
read_buffer: SpinLock::new(RingBuffer::new(BUFFER_CAPACITY)), read_buffer: SpinLock::new(RingBuffer::new(BUFFER_CAPACITY)),
termios: SpinLock::new(KernelTermios::default()), termios: SpinLock::new(KernelTermios::default()),
winsize: SpinLock::new(WinSize::default()), winsize: SpinLock::new(WinSize::default()),
pollee: Pollee::new(IoEvents::empty()), pollee: Pollee::new(),
send_signal, send_signal,
work_item, work_item,
work_item_para: Arc::new(SpinLock::new(LineDisciplineWorkPara::new())), work_item_para: Arc::new(SpinLock::new(LineDisciplineWorkPara::new())),
@ -140,7 +141,7 @@ impl LineDiscipline {
// Raw mode // Raw mode
if !termios.is_canonical_mode() { if !termios.is_canonical_mode() {
self.read_buffer.lock().push_overwrite(ch); self.read_buffer.lock().push_overwrite(ch);
self.update_readable_state(); self.pollee.notify(IoEvents::IN);
return; return;
} }
@ -166,6 +167,7 @@ impl LineDiscipline {
let current_line_chars = current_line.drain(); let current_line_chars = current_line.drain();
for char in current_line_chars { for char in current_line_chars {
self.read_buffer.lock().push_overwrite(char); self.read_buffer.lock().push_overwrite(char);
self.pollee.notify(IoEvents::IN);
} }
} }
@ -173,8 +175,6 @@ impl LineDiscipline {
// Printable character // Printable character
self.current_line.lock().push_char(ch); self.current_line.lock().push_char(ch);
} }
self.update_readable_state();
} }
fn may_send_signal(&self, termios: &KernelTermios, ch: u8) -> bool { fn may_send_signal(&self, termios: &KernelTermios, ch: u8) -> bool {
@ -198,13 +198,13 @@ impl LineDiscipline {
true true
} }
pub fn update_readable_state(&self) { fn check_io_events(&self) -> IoEvents {
let buffer = self.read_buffer.lock(); let buffer = self.read_buffer.lock();
if !buffer.is_empty() { if !buffer.is_empty() {
self.pollee.add_events(IoEvents::IN); IoEvents::IN
} else { } else {
self.pollee.del_events(IoEvents::IN); IoEvents::empty()
} }
} }
@ -265,7 +265,6 @@ impl LineDiscipline {
unreachable!() unreachable!()
} }
}; };
self.update_readable_state();
Ok(read_len) Ok(read_len)
} }

View File

@ -126,9 +126,6 @@ impl FileIo for Tty {
}; };
self.set_foreground(&pgid)?; self.set_foreground(&pgid)?;
// Some background processes may be waiting on the wait queue,
// when set_fg, the background processes may be able to read.
self.ldisc.update_readable_state();
Ok(0) Ok(0)
} }
IoctlCmd::TCSETS => { IoctlCmd::TCSETS => {

View File

@ -288,7 +288,7 @@ impl ReadySet {
Self { Self {
entries: SpinLock::new(VecDeque::new()), entries: SpinLock::new(VecDeque::new()),
pop_guard: Mutex::new(PopGuard), pop_guard: Mutex::new(PopGuard),
pollee: Pollee::new(IoEvents::empty()), pollee: Pollee::new(),
} }
} }
@ -315,7 +315,7 @@ impl ReadySet {
// Even if the entry is already set to ready, // Even if the entry is already set to ready,
// there might be new events that we are interested in. // there might be new events that we are interested in.
// Wake the poller anyway. // Wake the poller anyway.
self.pollee.add_events(IoEvents::IN); self.pollee.notify(IoEvents::IN);
} }
pub(super) fn lock_pop(&self) -> ReadySetPopIter { pub(super) fn lock_pop(&self) -> ReadySetPopIter {
@ -327,7 +327,18 @@ impl ReadySet {
} }
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
}
fn check_io_events(&self) -> IoEvents {
let entries = self.entries.lock();
if !entries.is_empty() {
IoEvents::IN
} else {
IoEvents::empty()
}
} }
} }
@ -356,11 +367,6 @@ impl Iterator for ReadySetPopIter<'_> {
// must exist, so we can just unwrap it. // must exist, so we can just unwrap it.
let weak_entry = entries.pop_front().unwrap(); let weak_entry = entries.pop_front().unwrap();
// Clear the epoll file's events if there are no ready entries.
if entries.len() == 0 {
self.ready_set.pollee.del_events(IoEvents::IN);
}
let Some(entry) = Weak::upgrade(&weak_entry) else { let Some(entry) = Weak::upgrade(&weak_entry) else {
// The entry has been deleted. // The entry has been deleted.
continue; continue;

View File

@ -96,7 +96,9 @@ macro_rules! impl_common_methods_for_channel {
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.this_end().pollee.poll(mask, poller) self.this_end()
.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
}; };
} }
@ -110,27 +112,17 @@ impl<T> Producer<T> {
&self.0.common.consumer &self.0.common.consumer
} }
fn update_pollee(&self) { fn check_io_events(&self) -> IoEvents {
// In theory, `rb.free_len()`/`rb.is_empty()`, where the `rb` is taken from either
// `this_end` or `peer_end`, should reflect the same state. However, we need to take the
// correct lock when updating the events to avoid races between the state check and the
// event update.
let this_end = self.this_end(); let this_end = self.this_end();
let rb = this_end.rb(); let rb = this_end.rb();
if self.is_shutdown() {
// The POLLOUT event is always set in this case. Don't try to remove it.
} else if rb.free_len() < PIPE_BUF {
this_end.pollee.del_events(IoEvents::OUT);
}
drop(rb);
let peer_end = self.peer_end(); if self.is_shutdown() {
let rb = peer_end.rb(); IoEvents::ERR | IoEvents::OUT
if !rb.is_empty() { } else if rb.free_len() > PIPE_BUF {
peer_end.pollee.add_events(IoEvents::IN); IoEvents::OUT
} else {
IoEvents::empty()
} }
drop(rb);
} }
impl_common_methods_for_channel!(); impl_common_methods_for_channel!();
@ -153,7 +145,7 @@ impl Producer<u8> {
} }
let written_len = self.0.write(reader)?; let written_len = self.0.write(reader)?;
self.update_pollee(); self.peer_end().pollee.notify(IoEvents::IN);
if written_len > 0 { if written_len > 0 {
Ok(written_len) Ok(written_len)
@ -179,7 +171,7 @@ impl<T: Pod> Producer<T> {
let err = Error::with_message(Errno::EAGAIN, "the channel is full"); let err = Error::with_message(Errno::EAGAIN, "the channel is full");
(err, item) (err, item)
})?; })?;
self.update_pollee(); self.peer_end().pollee.notify(IoEvents::IN);
Ok(()) Ok(())
} }
@ -200,25 +192,18 @@ impl<T> Consumer<T> {
&self.0.common.producer &self.0.common.producer
} }
fn update_pollee(&self) { fn check_io_events(&self) -> IoEvents {
// In theory, `rb.free_len()`/`rb.is_empty()`, where the `rb` is taken from either
// `this_end` or `peer_end`, should reflect the same state. However, we need to take the
// correct lock when updating the events to avoid races between the state check and the
// event update.
let this_end = self.this_end(); let this_end = self.this_end();
let rb = this_end.rb(); let rb = this_end.rb();
if rb.is_empty() {
this_end.pollee.del_events(IoEvents::IN);
}
drop(rb);
let peer_end = self.peer_end(); let mut events = IoEvents::empty();
let rb = peer_end.rb(); if self.is_shutdown() {
if rb.free_len() >= PIPE_BUF { events |= IoEvents::HUP;
peer_end.pollee.add_events(IoEvents::OUT);
} }
drop(rb); if !rb.is_empty() {
events |= IoEvents::IN;
}
events
} }
impl_common_methods_for_channel!(); impl_common_methods_for_channel!();
@ -239,7 +224,7 @@ impl Consumer<u8> {
let is_shutdown = self.is_shutdown(); let is_shutdown = self.is_shutdown();
let read_len = self.0.read(writer)?; let read_len = self.0.read(writer)?;
self.update_pollee(); self.peer_end().pollee.notify(IoEvents::OUT);
if read_len > 0 { if read_len > 0 {
Ok(read_len) Ok(read_len)
@ -262,7 +247,7 @@ impl<T: Pod> Consumer<T> {
let is_shutdown = self.is_shutdown(); let is_shutdown = self.is_shutdown();
let item = self.0.pop(); let item = self.0.pop();
self.update_pollee(); self.peer_end().pollee.notify(IoEvents::OUT);
if let Some(item) = item { if let Some(item) = item {
Ok(Some(item)) Ok(Some(item))
@ -346,25 +331,12 @@ impl<T> Common<T> {
let (rb_producer, rb_consumer) = rb.split(); let (rb_producer, rb_consumer) = rb.split();
let producer = { let producer = {
let polee = if let Some(pollee) = producer_pollee { let pollee = producer_pollee.unwrap_or_default();
pollee.reset_events(); FifoInner::new(rb_producer, pollee)
pollee.add_events(IoEvents::OUT);
pollee
} else {
Pollee::new(IoEvents::OUT)
};
FifoInner::new(rb_producer, polee)
}; };
let consumer = { let consumer = {
let pollee = if let Some(pollee) = consumer_pollee { let pollee = consumer_pollee.unwrap_or_default();
pollee.reset_events();
pollee
} else {
Pollee::new(IoEvents::empty())
};
FifoInner::new(rb_consumer, pollee) FifoInner::new(rb_consumer, pollee)
}; };
@ -389,19 +361,11 @@ impl<T> Common<T> {
} }
// The POLLHUP event indicates that the write end is shut down. // The POLLHUP event indicates that the write end is shut down.
// self.consumer.pollee.notify(IoEvents::HUP);
// No need to take a lock. There is no race because no one is modifying this particular event.
self.consumer.pollee.add_events(IoEvents::HUP);
// The POLLERR event indicates that the read end is shut down (so any subsequent writes // The POLLERR event indicates that the read end is shut down (so any subsequent writes
// will fail with an `EPIPE` error). // will fail with an `EPIPE` error).
// self.producer.pollee.notify(IoEvents::ERR | IoEvents::OUT);
// The lock is taken because we are also adding the POLLOUT event, which may have races
// with the event updates triggered by the writer.
let _rb = self.producer.rb();
self.producer
.pollee
.add_events(IoEvents::ERR | IoEvents::OUT);
} }
} }

View File

@ -9,7 +9,6 @@ use crate::{
events::IoEvents, events::IoEvents,
net::{iface::BoundUdpSocket, socket::util::send_recv_flags::SendRecvFlags}, net::{iface::BoundUdpSocket, socket::util::send_recv_flags::SendRecvFlags},
prelude::*, prelude::*,
process::signal::Pollee,
util::{MultiRead, MultiWrite}, util::{MultiRead, MultiWrite},
}; };
@ -93,24 +92,19 @@ impl BoundDatagram {
} }
} }
pub(super) fn init_pollee(&self, pollee: &Pollee) { pub(super) fn check_io_events(&self) -> IoEvents {
pollee.reset_events();
self.update_io_events(pollee)
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket| { self.bound_socket.raw_with(|socket| {
let mut events = IoEvents::empty();
if socket.can_recv() { if socket.can_recv() {
pollee.add_events(IoEvents::IN); events |= IoEvents::IN;
} else {
pollee.del_events(IoEvents::IN);
} }
if socket.can_send() { if socket.can_send() {
pollee.add_events(IoEvents::OUT); events |= IoEvents::OUT;
} else {
pollee.del_events(IoEvents::OUT);
} }
});
events
})
} }
} }

View File

@ -2,7 +2,10 @@
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{socket::SocketEventObserver, wire::IpEndpoint}; use aster_bigtcp::{
socket::{SocketEventObserver, SocketEvents},
wire::IpEndpoint,
};
use takeable::Takeable; use takeable::Takeable;
use self::{bound::BoundDatagram, unbound::UnboundDatagram}; use self::{bound::BoundDatagram, unbound::UnboundDatagram};
@ -98,12 +101,10 @@ impl DatagramSocket {
pub fn new(nonblocking: bool) -> Arc<Self> { pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| { Arc::new_cyclic(|me| {
let unbound_datagram = UnboundDatagram::new(me.clone() as _); let unbound_datagram = UnboundDatagram::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty());
unbound_datagram.init_pollee(&pollee);
Self { Self {
inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))), inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))),
nonblocking: AtomicBool::new(nonblocking), nonblocking: AtomicBool::new(nonblocking),
pollee, pollee: Pollee::new(),
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
} }
}) })
@ -141,7 +142,6 @@ impl DatagramSocket {
return (err_inner, Err(err)); return (err_inner, Err(err));
} }
}; };
bound_datagram.init_pollee(&self.pollee);
(Inner::Bound(bound_datagram), Ok(())) (Inner::Bound(bound_datagram), Ok(()))
}) })
} }
@ -199,18 +199,20 @@ impl DatagramSocket {
sent_bytes sent_bytes
} }
fn update_io_events(&self) { fn check_io_events(&self) -> IoEvents {
let inner = self.inner.read(); let inner = self.inner.read();
let Inner::Bound(bound_datagram) = inner.as_ref() else {
return; match inner.as_ref() {
}; Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(),
bound_datagram.update_io_events(&self.pollee); Inner::Bound(bound_socket) => bound_socket.check_io_events(),
}
} }
} }
impl Pollable for DatagramSocket { impl Pollable for DatagramSocket {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
} }
@ -283,7 +285,6 @@ impl Socket for DatagramSocket {
return (err_inner, Err(err)); return (err_inner, Err(err));
} }
}; };
bound_datagram.init_pollee(&self.pollee);
(Inner::Bound(bound_datagram), Ok(())) (Inner::Bound(bound_datagram), Ok(()))
}) })
} }
@ -388,7 +389,17 @@ impl Socket for DatagramSocket {
} }
impl SocketEventObserver for DatagramSocket { impl SocketEventObserver for DatagramSocket {
fn on_events(&self) { fn on_events(&self, events: SocketEvents) {
self.update_io_events(); let mut io_events = IoEvents::empty();
if events.contains(SocketEvents::CAN_RECV) {
io_events |= IoEvents::IN;
}
if events.contains(SocketEvents::CAN_SEND) {
io_events |= IoEvents::OUT;
}
self.pollee.notify(io_events);
} }
} }

View File

@ -8,9 +8,7 @@ use aster_bigtcp::{
}; };
use super::bound::BoundDatagram; use super::bound::BoundDatagram;
use crate::{ use crate::{events::IoEvents, net::socket::ip::common::bind_socket, prelude::*};
events::IoEvents, net::socket::ip::common::bind_socket, prelude::*, process::signal::Pollee,
};
pub struct UnboundDatagram { pub struct UnboundDatagram {
unbound_socket: Box<UnboundUdpSocket>, unbound_socket: Box<UnboundUdpSocket>,
@ -44,8 +42,7 @@ impl UnboundDatagram {
Ok(BoundDatagram::new(bound_socket)) Ok(BoundDatagram::new(bound_socket))
} }
pub(super) fn init_pollee(&self, pollee: &Pollee) { pub(super) fn check_io_events(&self) -> IoEvents {
pollee.reset_events(); IoEvents::OUT
pollee.add_events(IoEvents::OUT);
} }
} }

View File

@ -5,7 +5,7 @@ use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{ use aster_bigtcp::{
errors::tcp::{RecvError, SendError}, errors::tcp::{RecvError, SendError},
socket::{RawTcpSocket, SocketEventObserver, TcpState}, socket::{SocketEventObserver, TcpStateCheck},
wire::IpEndpoint, wire::IpEndpoint,
}; };
@ -61,16 +61,21 @@ impl ConnectedStream {
} }
pub fn shutdown(&self, cmd: SockShutdownCmd, pollee: &Pollee) -> Result<()> { pub fn shutdown(&self, cmd: SockShutdownCmd, pollee: &Pollee) -> Result<()> {
let mut events = IoEvents::empty();
if cmd.shut_read() { if cmd.shut_read() {
self.is_receiving_closed.store(true, Ordering::Relaxed); self.is_receiving_closed.store(true, Ordering::Relaxed);
self.update_io_events(pollee); events |= IoEvents::IN | IoEvents::RDHUP;
} }
if cmd.shut_write() { if cmd.shut_write() {
self.is_sending_closed.store(true, Ordering::Relaxed); self.is_sending_closed.store(true, Ordering::Relaxed);
self.bound_socket.close(); self.bound_socket.close();
events |= IoEvents::OUT | IoEvents::HUP;
} }
pollee.notify(events);
Ok(()) Ok(())
} }
@ -132,17 +137,12 @@ impl ConnectedStream {
Ok(()) Ok(())
} }
pub(super) fn init_pollee(&self, pollee: &Pollee) { pub(super) fn check_io_events(&self) -> IoEvents {
pollee.reset_events();
self.update_io_events(pollee);
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
self.bound_socket.raw_with(|socket| { self.bound_socket.raw_with(|socket| {
if is_peer_closed(socket) { if socket.is_peer_closed() {
// Only the sending side of peer socket is closed // Only the sending side of peer socket is closed
self.is_receiving_closed.store(true, Ordering::Relaxed); self.is_receiving_closed.store(true, Ordering::Relaxed);
} else if is_closed(socket) { } else if socket.is_closed() {
// The sending side of both peer socket and this socket are closed // The sending side of both peer socket and this socket are closed
self.is_receiving_closed.store(true, Ordering::Relaxed); self.is_receiving_closed.store(true, Ordering::Relaxed);
self.is_sending_closed.store(true, Ordering::Relaxed); self.is_sending_closed.store(true, Ordering::Relaxed);
@ -151,50 +151,32 @@ impl ConnectedStream {
let is_receiving_closed = self.is_receiving_closed.load(Ordering::Relaxed); let is_receiving_closed = self.is_receiving_closed.load(Ordering::Relaxed);
let is_sending_closed = self.is_sending_closed.load(Ordering::Relaxed); let is_sending_closed = self.is_sending_closed.load(Ordering::Relaxed);
let mut events = IoEvents::empty();
// If the receiving side is closed, always add events IN and RDHUP; // If the receiving side is closed, always add events IN and RDHUP;
// otherwise, check if the socket can receive. // otherwise, check if the socket can receive.
if is_receiving_closed { if is_receiving_closed {
pollee.add_events(IoEvents::IN | IoEvents::RDHUP); events |= IoEvents::IN | IoEvents::RDHUP;
} else if socket.can_recv() { } else if socket.can_recv() {
pollee.add_events(IoEvents::IN); events |= IoEvents::IN;
} else {
pollee.del_events(IoEvents::IN);
} }
// If the sending side is closed, always add an OUT event; // If the sending side is closed, always add an OUT event;
// otherwise, check if the socket can send. // otherwise, check if the socket can send.
if is_sending_closed || socket.can_send() { if is_sending_closed || socket.can_send() {
pollee.add_events(IoEvents::OUT); events |= IoEvents::OUT;
} else {
pollee.del_events(IoEvents::OUT);
} }
// If both sending and receiving sides are closed, add a HUP event. // If both sending and receiving sides are closed, add a HUP event.
if is_receiving_closed && is_sending_closed { if is_receiving_closed && is_sending_closed {
pollee.add_events(IoEvents::HUP); events |= IoEvents::HUP;
} }
});
events
})
} }
pub(super) fn set_observer(&self, observer: Weak<dyn SocketEventObserver>) { pub(super) fn set_observer(&self, observer: Weak<dyn SocketEventObserver>) {
self.bound_socket.set_observer(observer) self.bound_socket.set_observer(observer)
} }
} }
/// Checks if the peer socket has closed its sending side.
///
/// If the sending side of this socket is also closed, this method will return `false`.
/// In such cases, you should verify using [`is_closed`].
fn is_peer_closed(socket: &RawTcpSocket) -> bool {
socket.state() == TcpState::CloseWait
}
/// Checks if the socket is fully closed.
///
/// This function returns `true` if both this socket and the peer have closed their sending sides.
///
/// This TCP state corresponds to the `Normal Close Sequence` and `Simultaneous Close Sequence`
/// as outlined in RFC793 (https://datatracker.ietf.org/doc/html/rfc793#page-39).
fn is_closed(socket: &RawTcpSocket) -> bool {
!socket.is_open() || socket.state() == TcpState::Closing || socket.state() == TcpState::LastAck
}

View File

@ -1,21 +1,19 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use aster_bigtcp::wire::IpEndpoint; use aster_bigtcp::{socket::ConnectState, wire::IpEndpoint};
use ostd::sync::LocalIrqDisabled;
use super::{connected::ConnectedStream, init::InitStream}; use super::{connected::ConnectedStream, init::InitStream};
use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*};
pub struct ConnectingStream { pub struct ConnectingStream {
bound_socket: BoundTcpSocket, bound_socket: BoundTcpSocket,
remote_endpoint: IpEndpoint, remote_endpoint: IpEndpoint,
conn_result: SpinLock<Option<ConnResult>, LocalIrqDisabled>,
} }
#[derive(Clone, Copy)] pub enum ConnResult {
enum ConnResult { Connecting(ConnectingStream),
Connected, Connected(ConnectedStream),
Refused, Refused(InitStream),
} }
impl ConnectingStream { impl ConnectingStream {
@ -41,27 +39,28 @@ impl ConnectingStream {
Ok(Self { Ok(Self {
bound_socket, bound_socket,
remote_endpoint, remote_endpoint,
conn_result: SpinLock::new(None),
}) })
} }
pub fn has_result(&self) -> bool { pub fn has_result(&self) -> bool {
self.conn_result.lock().is_some() match self.bound_socket.connect_state() {
ConnectState::Connecting => false,
ConnectState::Connected => true,
ConnectState::Refused => true,
}
} }
pub fn into_result(self) -> core::result::Result<ConnectedStream, (Error, InitStream)> { pub fn into_result(self) -> ConnResult {
let conn_result = *self.conn_result.lock(); let next_state = self.bound_socket.connect_state();
match conn_result {
Some(ConnResult::Connected) => Ok(ConnectedStream::new( match next_state {
ConnectState::Connecting => ConnResult::Connecting(self),
ConnectState::Connected => ConnResult::Connected(ConnectedStream::new(
self.bound_socket, self.bound_socket,
self.remote_endpoint, self.remote_endpoint,
true, true,
)), )),
Some(ConnResult::Refused) => Err(( ConnectState::Refused => ConnResult::Refused(InitStream::new_bound(self.bound_socket)),
Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
InitStream::new_bound(self.bound_socket),
)),
None => unreachable!("`has_result` must be true before calling `into_result`"),
} }
} }
@ -73,43 +72,7 @@ impl ConnectingStream {
self.remote_endpoint self.remote_endpoint
} }
pub(super) fn init_pollee(&self, pollee: &Pollee) { pub(super) fn check_io_events(&self) -> IoEvents {
pollee.reset_events(); IoEvents::empty()
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
if self.conn_result.lock().is_some() {
return;
}
self.bound_socket.raw_with(|socket| {
let mut result = self.conn_result.lock();
if result.is_some() {
return;
}
// Connected
if socket.can_send() {
*result = Some(ConnResult::Connected);
pollee.add_events(IoEvents::OUT);
return;
}
// Connecting
if socket.is_open() {
return;
}
// Refused
*result = Some(ConnResult::Refused);
pollee.add_events(IoEvents::OUT);
// Add `IoEvents::OUT` because the man pages say "EINPROGRESS [..] It is possible to
// select(2) or poll(2) for completion by selecting the socket for writing". For
// details, see <https://man7.org/linux/man-pages/man2/connect.2.html>.
//
// TODO: It is better to do the state transition and let `ConnectedStream` or
// `InitStream` set the correct I/O events. However, the state transition is delayed
// because we're probably in IRQ handlers. Maybe mark the `pollee` as obsolete and
// re-calculate the I/O events in `poll`.
})
} }
} }

View File

@ -15,7 +15,6 @@ use crate::{
socket::ip::common::{bind_socket, get_ephemeral_endpoint}, socket::ip::common::{bind_socket, get_ephemeral_endpoint},
}, },
prelude::*, prelude::*,
process::signal::Pollee,
}; };
pub enum InitStream { pub enum InitStream {
@ -101,9 +100,8 @@ impl InitStream {
} }
} }
pub(super) fn init_pollee(&self, pollee: &Pollee) { pub(super) fn check_io_events(&self) -> IoEvents {
pollee.reset_events();
// Linux adds OUT and HUP events for a newly created socket // Linux adds OUT and HUP events for a newly created socket
pollee.add_events(IoEvents::OUT | IoEvents::HUP); IoEvents::OUT | IoEvents::HUP
} }
} }

View File

@ -5,7 +5,7 @@ use aster_bigtcp::{
}; };
use super::connected::ConnectedStream; use super::connected::ConnectedStream;
use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*};
pub struct ListenStream { pub struct ListenStream {
backlog: usize, backlog: usize,
@ -80,22 +80,17 @@ impl ListenStream {
self.bound_socket.local_endpoint().unwrap() self.bound_socket.local_endpoint().unwrap()
} }
pub(super) fn init_pollee(&self, pollee: &Pollee) { pub(super) fn check_io_events(&self) -> IoEvents {
pollee.reset_events();
self.update_io_events(pollee);
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
let backlog_sockets = self.backlog_sockets.read(); let backlog_sockets = self.backlog_sockets.read();
let can_accept = backlog_sockets.iter().any(|socket| socket.can_accept()); let can_accept = backlog_sockets.iter().any(|socket| socket.can_accept());
// FIXME: If network packets come in simultaneously, the socket state may change in the // If network packets come in simultaneously, the socket state may change in the middle.
// middle. This can cause the wrong I/O events to be added or deleted. // However, the current pollee implementation should be able to handle this race condition.
if can_accept { if can_accept {
pollee.add_events(IoEvents::IN); IoEvents::IN
} else { } else {
pollee.del_events(IoEvents::IN); IoEvents::empty()
} }
} }
} }

View File

@ -2,9 +2,12 @@
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{socket::SocketEventObserver, wire::IpEndpoint}; use aster_bigtcp::{
socket::{SocketEventObserver, SocketEvents},
wire::IpEndpoint,
};
use connected::ConnectedStream; use connected::ConnectedStream;
use connecting::ConnectingStream; use connecting::{ConnResult, ConnectingStream};
use init::InitStream; use init::InitStream;
use listen::ListenStream; use listen::ListenStream;
use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; use options::{Congestion, MaxSegment, NoDelay, WindowClamp};
@ -81,27 +84,23 @@ impl StreamSocket {
pub fn new(nonblocking: bool) -> Arc<Self> { pub fn new(nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| { Arc::new_cyclic(|me| {
let init_stream = InitStream::new(me.clone() as _); let init_stream = InitStream::new(me.clone() as _);
let pollee = Pollee::new(IoEvents::empty());
init_stream.init_pollee(&pollee);
Self { Self {
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
state: RwLock::new(Takeable::new(State::Init(init_stream))), state: RwLock::new(Takeable::new(State::Init(init_stream))),
is_nonblocking: AtomicBool::new(nonblocking), is_nonblocking: AtomicBool::new(nonblocking),
pollee, pollee: Pollee::new(),
} }
}) })
} }
fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> { fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> {
Arc::new_cyclic(move |me| { Arc::new_cyclic(move |me| {
let pollee = Pollee::new(IoEvents::empty());
connected_stream.set_observer(me.clone() as _); connected_stream.set_observer(me.clone() as _);
connected_stream.init_pollee(&pollee);
Self { Self {
options: RwLock::new(OptionSet::new()), options: RwLock::new(OptionSet::new()),
state: RwLock::new(Takeable::new(State::Connected(connected_stream))), state: RwLock::new(Takeable::new(State::Connected(connected_stream))),
is_nonblocking: AtomicBool::new(false), is_nonblocking: AtomicBool::new(false),
pollee, pollee: Pollee::new(),
} }
}) })
} }
@ -161,23 +160,26 @@ impl StreamSocket {
_ => return (options, state), _ => return (options, state),
} }
let result = state.borrow_result(|owned_state| { state.borrow(|owned_state| {
let State::Connecting(connecting_stream) = owned_state else { let State::Connecting(connecting_stream) = owned_state else {
unreachable!("`State::Connecting` is checked before calling `borrow_result`"); unreachable!("`State::Connecting` is checked before calling `borrow_result`");
}; };
let connected_stream = match connecting_stream.into_result() { match connecting_stream.into_result() {
Ok(connected_stream) => connected_stream, ConnResult::Connecting(connecting_stream) => State::Connecting(connecting_stream),
Err((err, init_stream)) => { ConnResult::Connected(connected_stream) => {
init_stream.init_pollee(&self.pollee); options.socket.set_sock_errors(None);
return (State::Init(init_stream), Err(err)); State::Connected(connected_stream)
}
ConnResult::Refused(init_stream) => {
options.socket.set_sock_errors(Some(Error::with_message(
Errno::ECONNREFUSED,
"the connection is refused",
)));
State::Init(init_stream)
}
} }
};
connected_stream.init_pollee(&self.pollee);
(State::Connected(connected_stream), Ok(()))
}); });
options.socket.set_sock_errors(result.err());
(options, state) (options, state)
} }
@ -224,7 +226,6 @@ impl StreamSocket {
return (State::Init(init_stream), Some(Err(err))); return (State::Init(init_stream), Some(Err(err)));
} }
}; };
connecting_stream.init_pollee(&self.pollee);
( (
State::Connecting(connecting_stream), State::Connecting(connecting_stream),
@ -269,8 +270,6 @@ impl StreamSocket {
}; };
let accepted = listen_stream.try_accept().map(|connected_stream| { let accepted = listen_stream.try_accept().map(|connected_stream| {
listen_stream.update_io_events(&self.pollee);
let remote_endpoint = connected_stream.remote_endpoint(); let remote_endpoint = connected_stream.remote_endpoint();
let accepted_socket = Self::new_connected(connected_stream); let accepted_socket = Self::new_connected(connected_stream);
(accepted_socket as _, remote_endpoint.into()) (accepted_socket as _, remote_endpoint.into())
@ -354,30 +353,22 @@ impl StreamSocket {
} }
} }
fn update_io_events(&self) { fn check_io_events(&self) -> IoEvents {
let state = self.state.read(); let state = self.read_updated_state();
match state.as_ref() {
State::Init(_) => (),
State::Connecting(connecting_stream) => {
connecting_stream.update_io_events(&self.pollee)
}
State::Listen(listen_stream) => {
listen_stream.update_io_events(&self.pollee);
}
State::Connected(connected_stream) => {
connected_stream.update_io_events(&self.pollee);
}
}
// Note: Network events can cause a state transition from `State::Connecting` to match state.as_ref() {
// `State::Connected`/`State::Init`. The state transition is delayed until State::Init(init_stream) => init_stream.check_io_events(),
// `update_connecting`is triggered by user events, see that method for details. State::Connecting(connecting_stream) => connecting_stream.check_io_events(),
State::Listen(listen_stream) => listen_stream.check_io_events(),
State::Connected(connected_stream) => connected_stream.check_io_events(),
}
} }
} }
impl Pollable for StreamSocket { impl Pollable for StreamSocket {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
} }
@ -492,7 +483,6 @@ impl Socket for StreamSocket {
return (State::Init(init_stream), Err(err)); return (State::Init(init_stream), Err(err));
} }
}; };
listen_stream.init_pollee(&self.pollee);
(State::Listen(listen_stream), Ok(())) (State::Listen(listen_stream), Ok(()))
}) })
@ -677,8 +667,26 @@ impl Socket for StreamSocket {
} }
impl SocketEventObserver for StreamSocket { impl SocketEventObserver for StreamSocket {
fn on_events(&self) { fn on_events(&self, events: SocketEvents) {
self.update_io_events(); let mut io_events = IoEvents::empty();
if events.contains(SocketEvents::CAN_RECV) {
io_events |= IoEvents::IN;
}
if events.contains(SocketEvents::CAN_SEND) {
io_events |= IoEvents::OUT;
}
if events.contains(SocketEvents::PEER_CLOSED) {
io_events |= IoEvents::IN | IoEvents::RDHUP;
}
if events.contains(SocketEvents::CLOSED) {
io_events |= IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP | IoEvents::HUP;
}
self.pollee.notify(io_events);
} }
} }

View File

@ -28,8 +28,8 @@ impl Init {
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Self { Self {
addr: None, addr: None,
reader_pollee: Pollee::new(IoEvents::empty()), reader_pollee: Pollee::new(),
writer_pollee: Pollee::new(IoEvents::OUT), writer_pollee: Pollee::new(),
is_read_shutdown: AtomicBool::new(false), is_read_shutdown: AtomicBool::new(false),
is_write_shutdown: AtomicBool::new(false), is_write_shutdown: AtomicBool::new(false),
} }
@ -87,6 +87,7 @@ impl Init {
self.writer_pollee, self.writer_pollee,
backlog, backlog,
self.is_read_shutdown.into_inner(), self.is_read_shutdown.into_inner(),
self.is_write_shutdown.into_inner(),
)) ))
} }
@ -94,7 +95,7 @@ impl Init {
match cmd { match cmd {
SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => {
self.is_write_shutdown.store(true, Ordering::Relaxed); self.is_write_shutdown.store(true, Ordering::Relaxed);
self.writer_pollee.add_events(IoEvents::ERR); self.writer_pollee.notify(IoEvents::ERR);
} }
SockShutdownCmd::SHUT_RD => (), SockShutdownCmd::SHUT_RD => (),
} }
@ -102,7 +103,7 @@ impl Init {
match cmd { match cmd {
SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => { SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => {
self.is_read_shutdown.store(true, Ordering::Relaxed); self.is_read_shutdown.store(true, Ordering::Relaxed);
self.reader_pollee.add_events(IoEvents::HUP); self.reader_pollee.notify(IoEvents::HUP);
} }
SockShutdownCmd::SHUT_WR => (), SockShutdownCmd::SHUT_WR => (),
} }
@ -115,8 +116,22 @@ impl Init {
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents { pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents {
// To avoid loss of events, this must be compatible with // To avoid loss of events, this must be compatible with
// `Connected::poll`/`Listener::poll`. // `Connected::poll`/`Listener::poll`.
let reader_events = self.reader_pollee.poll(mask, poller.as_deref_mut()); let reader_events = self
let writer_events = self.writer_pollee.poll(mask, poller); .reader_pollee
.poll_with(mask, poller.as_deref_mut(), || {
if self.is_read_shutdown.load(Ordering::Relaxed) {
IoEvents::HUP
} else {
IoEvents::empty()
}
});
let writer_events = self.writer_pollee.poll_with(mask, poller, || {
if self.is_write_shutdown.load(Ordering::Relaxed) {
IoEvents::OUT | IoEvents::ERR
} else {
IoEvents::OUT
}
});
// According to the Linux implementation, we always have `IoEvents::HUP` in this state. // According to the Linux implementation, we always have `IoEvents::HUP` in this state.
// Meanwhile, it is in `IoEvents::ALWAYS_POLL`, so we always return it. // Meanwhile, it is in `IoEvents::ALWAYS_POLL`, so we always return it.

View File

@ -1,6 +1,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicUsize, Ordering}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use ostd::sync::WaitQueue; use ostd::sync::WaitQueue;
@ -22,6 +22,7 @@ use crate::{
pub(super) struct Listener { pub(super) struct Listener {
backlog: Arc<Backlog>, backlog: Arc<Backlog>,
is_write_shutdown: AtomicBool,
writer_pollee: Pollee, writer_pollee: Pollee,
} }
@ -31,17 +32,16 @@ impl Listener {
reader_pollee: Pollee, reader_pollee: Pollee,
writer_pollee: Pollee, writer_pollee: Pollee,
backlog: usize, backlog: usize,
is_shutdown: bool, is_read_shutdown: bool,
is_write_shutdown: bool,
) -> Self { ) -> Self {
// Note that the I/O events can be correctly inherited from `Init`. There is no need to
// explicitly call `Pollee::reset_io_events`.
let backlog = BACKLOG_TABLE let backlog = BACKLOG_TABLE
.add_backlog(addr, reader_pollee, backlog, is_shutdown) .add_backlog(addr, reader_pollee, backlog, is_read_shutdown)
.unwrap(); .unwrap();
writer_pollee.del_events(IoEvents::OUT);
Self { Self {
backlog, backlog,
is_write_shutdown: AtomicBool::new(is_write_shutdown),
writer_pollee, writer_pollee,
} }
} }
@ -65,7 +65,8 @@ impl Listener {
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { pub(super) fn shutdown(&self, cmd: SockShutdownCmd) {
match cmd { match cmd {
SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => {
self.writer_pollee.add_events(IoEvents::ERR); self.is_write_shutdown.store(true, Ordering::Relaxed);
self.writer_pollee.notify(IoEvents::ERR);
} }
SockShutdownCmd::SHUT_RD => (), SockShutdownCmd::SHUT_RD => (),
} }
@ -80,7 +81,14 @@ impl Listener {
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents { pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents {
let reader_events = self.backlog.poll(mask, poller.as_deref_mut()); let reader_events = self.backlog.poll(mask, poller.as_deref_mut());
let writer_events = self.writer_pollee.poll(mask, poller);
let writer_events = self.writer_pollee.poll_with(mask, poller, || {
if self.is_write_shutdown.load(Ordering::Relaxed) {
IoEvents::ERR
} else {
IoEvents::empty()
}
});
combine_io_events(mask, reader_events, writer_events) combine_io_events(mask, reader_events, writer_events)
} }
@ -172,11 +180,7 @@ impl Backlog {
let Some(incoming_conns) = &mut *locked_incoming_conns else { let Some(incoming_conns) = &mut *locked_incoming_conns else {
return_errno_with_message!(Errno::EINVAL, "the socket is shut down for reading"); return_errno_with_message!(Errno::EINVAL, "the socket is shut down for reading");
}; };
let conn = incoming_conns.pop_front(); let conn = incoming_conns.pop_front();
if incoming_conns.is_empty() {
self.pollee.del_events(IoEvents::IN);
}
drop(locked_incoming_conns); drop(locked_incoming_conns);
@ -199,8 +203,7 @@ impl Backlog {
let mut incoming_conns = self.incoming_conns.lock(); let mut incoming_conns = self.incoming_conns.lock();
*incoming_conns = None; *incoming_conns = None;
self.pollee.add_events(IoEvents::HUP); self.pollee.notify(IoEvents::HUP);
self.pollee.del_events(IoEvents::IN);
drop(incoming_conns); drop(incoming_conns);
@ -208,7 +211,22 @@ impl Backlog {
} }
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
}
fn check_io_events(&self) -> IoEvents {
let incoming_conns = self.incoming_conns.lock();
if let Some(conns) = &*incoming_conns {
if !conns.is_empty() {
IoEvents::IN
} else {
IoEvents::empty()
}
} else {
IoEvents::HUP
}
} }
} }
@ -240,9 +258,9 @@ impl Backlog {
} }
let (client_conn, server_conn) = init.into_connected(self.addr.clone()); let (client_conn, server_conn) = init.into_connected(self.addr.clone());
incoming_conns.push_back(server_conn);
self.pollee.add_events(IoEvents::IN); incoming_conns.push_back(server_conn);
self.pollee.notify(IoEvents::IN);
Ok(client_conn) Ok(client_conn)
} }

View File

@ -16,7 +16,7 @@ use super::{
listen::Listen, listen::Listen,
}, },
}; };
use crate::{events::IoEvents, prelude::*, return_errno_with_message, util::MultiRead}; use crate::{prelude::*, return_errno_with_message, util::MultiRead};
/// Manage all active sockets /// Manage all active sockets
pub struct VsockSpace { pub struct VsockSpace {
@ -237,7 +237,6 @@ impl VsockSpace {
let connected = Arc::new(Connected::new(peer.into(), listen.addr())); let connected = Arc::new(Connected::new(peer.into(), listen.addr()));
connected.update_info(&event); connected.update_info(&event);
listen.push_incoming(connected).unwrap(); listen.push_incoming(connected).unwrap();
listen.update_io_events();
} }
VsockEventType::ConnectionResponse => { VsockEventType::ConnectionResponse => {
let connecting_sockets = self.connecting_sockets.disable_irq().lock(); let connecting_sockets = self.connecting_sockets.disable_irq().lock();
@ -253,7 +252,7 @@ impl VsockSpace {
connecting.local_addr() connecting.local_addr()
); );
connecting.update_info(&event); connecting.update_info(&event);
connecting.add_events(IoEvents::IN); connecting.set_connected();
} }
VsockEventType::Disconnected { .. } => { VsockEventType::Disconnected { .. } => {
let connected_sockets = self.connected_sockets.read_irq_disabled(); let connected_sockets = self.connected_sockets.read_irq_disabled();
@ -296,7 +295,6 @@ impl VsockSpace {
if !connected.add_connection_buffer(body) { if !connected.add_connection_buffer(body) {
return Err(SocketError::BufferTooShort); return Err(SocketError::BufferTooShort);
} }
connected.update_io_events();
} }
Ok(Some(event)) Ok(Some(event))
}) })

View File

@ -27,7 +27,8 @@ impl Connected {
Self { Self {
connection: SpinLock::new(Connection::new(peer_addr, local_addr.port)), connection: SpinLock::new(Connection::new(peer_addr, local_addr.port)),
id: ConnectionID::new(local_addr, peer_addr), id: ConnectionID::new(local_addr, peer_addr),
pollee: Pollee::new(IoEvents::empty()), // FIXME: We should reuse `Pollee` from `Init`.
pollee: Pollee::new(),
} }
} }
@ -35,7 +36,8 @@ impl Connected {
Self { Self {
connection: SpinLock::new(Connection::new_from_info(connecting.info())), connection: SpinLock::new(Connection::new_from_info(connecting.info())),
id: connecting.id(), id: connecting.id(),
pollee: Pollee::new(IoEvents::empty()), // FIXME: We should reuse `Pollee` from `Init`.
pollee: Pollee::new(),
} }
} }
pub fn peer_addr(&self) -> VsockSocketAddr { pub fn peer_addr(&self) -> VsockSocketAddr {
@ -116,7 +118,11 @@ impl Connected {
pub fn add_connection_buffer(&self, bytes: &[u8]) -> bool { pub fn add_connection_buffer(&self, bytes: &[u8]) -> bool {
let mut connection = self.connection.disable_irq().lock(); let mut connection = self.connection.disable_irq().lock();
connection.add(bytes)
let result = connection.add(bytes);
self.pollee.notify(IoEvents::IN);
result
} }
pub fn set_peer_requested_shutdown(&self) { pub fn set_peer_requested_shutdown(&self) {
@ -127,16 +133,18 @@ impl Connected {
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
pub fn update_io_events(&self) { fn check_io_events(&self) -> IoEvents {
let connection = self.connection.disable_irq().lock(); let connection = self.connection.disable_irq().lock();
// receive // receive
if !connection.buffer.is_empty() { if !connection.buffer.is_empty() {
self.pollee.add_events(IoEvents::IN); IoEvents::IN
} else { } else {
self.pollee.del_events(IoEvents::IN); IoEvents::empty()
} }
} }
} }

View File

@ -1,5 +1,7 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent}; use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent};
use super::connected::ConnectionID; use super::connected::ConnectionID;
@ -13,6 +15,7 @@ use crate::{
pub struct Connecting { pub struct Connecting {
id: ConnectionID, id: ConnectionID,
info: SpinLock<ConnectionInfo>, info: SpinLock<ConnectionInfo>,
is_connected: AtomicBool,
pollee: Pollee, pollee: Pollee,
} }
@ -21,7 +24,8 @@ impl Connecting {
Self { Self {
info: SpinLock::new(ConnectionInfo::new(peer_addr.into(), local_addr.port)), info: SpinLock::new(ConnectionInfo::new(peer_addr.into(), local_addr.port)),
id: ConnectionID::new(local_addr, peer_addr), id: ConnectionID::new(local_addr, peer_addr),
pollee: Pollee::new(IoEvents::empty()), is_connected: AtomicBool::new(false),
pollee: Pollee::new(),
} }
} }
@ -46,11 +50,21 @@ impl Connecting {
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
pub fn add_events(&self, events: IoEvents) { fn check_io_events(&self) -> IoEvents {
self.pollee.add_events(events) if self.is_connected.load(Ordering::Relaxed) {
IoEvents::IN
} else {
IoEvents::empty()
}
}
pub fn set_connected(&self) {
self.is_connected.store(true, Ordering::Relaxed);
self.pollee.notify(IoEvents::IN);
} }
} }

View File

@ -7,19 +7,17 @@ use crate::{
VSOCK_GLOBAL, VSOCK_GLOBAL,
}, },
prelude::*, prelude::*,
process::signal::{PollHandle, Pollee}, process::signal::PollHandle,
}; };
pub struct Init { pub struct Init {
bound_addr: Mutex<Option<VsockSocketAddr>>, bound_addr: Mutex<Option<VsockSocketAddr>>,
pollee: Pollee,
} }
impl Init { impl Init {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
bound_addr: Mutex::new(None), bound_addr: Mutex::new(None),
pollee: Pollee::new(IoEvents::empty()),
} }
} }
@ -61,8 +59,8 @@ impl Init {
*self.bound_addr.lock() *self.bound_addr.lock()
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { pub fn poll(&self, _mask: IoEvents, _poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) IoEvents::empty()
} }
} }

View File

@ -18,7 +18,8 @@ impl Listen {
pub fn new(addr: VsockSocketAddr, backlog: usize) -> Self { pub fn new(addr: VsockSocketAddr, backlog: usize) -> Self {
Self { Self {
addr, addr,
pollee: Pollee::new(IoEvents::empty()), // FIXME: We should reuse `Pollee` from `Init`.
pollee: Pollee::new(),
backlog, backlog,
incoming_connection: SpinLock::new(VecDeque::with_capacity(backlog)), incoming_connection: SpinLock::new(VecDeque::with_capacity(backlog)),
} }
@ -33,8 +34,11 @@ impl Listen {
if incoming_connections.len() >= self.backlog { if incoming_connections.len() >= self.backlog {
return_errno_with_message!(Errno::ECONNREFUSED, "queue in listenging socket is full") return_errno_with_message!(Errno::ECONNREFUSED, "queue in listenging socket is full")
} }
// FIXME: check if the port is already used // FIXME: check if the port is already used
incoming_connections.push_back(connect); incoming_connections.push_back(connect);
self.pollee.notify(IoEvents::IN);
Ok(()) Ok(())
} }
@ -52,15 +56,17 @@ impl Listen {
} }
pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
pub fn update_io_events(&self) { fn check_io_events(&self) -> IoEvents {
let incoming_connection = self.incoming_connection.disable_irq().lock(); let incoming_connection = self.incoming_connection.disable_irq().lock();
if !incoming_connection.is_empty() { if !incoming_connection.is_empty() {
self.pollee.add_events(IoEvents::IN); IoEvents::IN
} else { } else {
self.pollee.del_events(IoEvents::IN); IoEvents::empty()
} }
} }
} }

View File

@ -62,7 +62,6 @@ impl VsockStreamSocket {
}; };
let connected = listen.try_accept()?; let connected = listen.try_accept()?;
listen.update_io_events();
let peer_addr = connected.peer_addr(); let peer_addr = connected.peer_addr();
@ -104,7 +103,6 @@ impl VsockStreamSocket {
}; };
let read_size = connected.try_recv(writer)?; let read_size = connected.try_recv(writer)?;
connected.update_io_events();
let peer_addr = self.peer_addr()?; let peer_addr = self.peer_addr()?;
// If buffer is now empty and the peer requested shutdown, finish shutting down the // If buffer is now empty and the peer requested shutdown, finish shutting down the

View File

@ -1,7 +1,7 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::{ use core::{
sync::atomic::{AtomicU32, AtomicUsize, Ordering}, sync::atomic::{AtomicUsize, Ordering},
time::Duration, time::Duration,
}; };
@ -12,8 +12,16 @@ use crate::{
prelude::*, prelude::*,
}; };
/// A pollee maintains a set of active events, which can be polled with /// A pollee represents any I/O object (e.g., a file or socket) that can be polled.
/// pollers or be monitored with observers. ///
/// `Pollee` provides a standard mechanism to allow
/// 1. An I/O object to maintain its I/O readiness; and
/// 2. An interested part to poll the object's I/O readiness.
///
/// To correctly use the pollee, you need to call [`Pollee::notify`] whenever a new event arrives.
///
/// Then, [`Pollee::poll_with`] can allow you to register a [`Poller`] to wait for certain events,
/// or register a [`PollAdaptor`] to be notified when certain events occur.
pub struct Pollee { pub struct Pollee {
inner: Arc<PolleeInner>, inner: Arc<PolleeInner>,
} }
@ -21,30 +29,45 @@ pub struct Pollee {
struct PolleeInner { struct PolleeInner {
// A subject which is monitored with pollers. // A subject which is monitored with pollers.
subject: Subject<IoEvents, IoEvents>, subject: Subject<IoEvents, IoEvents>,
// For efficient manipulation, we use AtomicU32 instead of RwLock<IoEvents>. }
events: AtomicU32,
impl Default for Pollee {
fn default() -> Self {
Self::new()
}
} }
impl Pollee { impl Pollee {
/// Creates a new instance of pollee. /// Creates a new pollee.
pub fn new(init_events: IoEvents) -> Self { pub fn new() -> Self {
let inner = PolleeInner { let inner = PolleeInner {
subject: Subject::new(), subject: Subject::new(),
events: AtomicU32::new(init_events.bits()),
}; };
Self { Self {
inner: Arc::new(inner), inner: Arc::new(inner),
} }
} }
/// Returns the current events of the pollee filtered by the given event mask. /// Returns the current events filtered by the given event mask.
/// ///
/// If a poller is provided, the poller will start monitoring the pollee and receive event /// If a poller is provided, the poller will start monitoring the pollee and receive event
/// notification when the pollee receives interesting events. /// notification when the pollee receives interesting events.
/// ///
/// This operation is _atomic_ in the sense that if there are interesting events, either the /// This operation is _atomic_ in the sense that if there are interesting events, either the
/// events are returned or the poller is notified. /// events are returned or the poller is notified.
pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { ///
/// The above statement about atomicity is true even if `check` contains race conditions (and
/// in fact it always will, because even if it holds a lock, the lock will be released when
/// `check` returns).
pub fn poll_with<F>(
&self,
mask: IoEvents,
poller: Option<&mut PollHandle>,
check: F,
) -> IoEvents
where
F: FnOnce() -> IoEvents,
{
let mask = mask | IoEvents::ALWAYS_POLL; let mask = mask | IoEvents::ALWAYS_POLL;
// Register the provided poller. // Register the provided poller.
@ -53,7 +76,7 @@ impl Pollee {
} }
// Check events after the registration to prevent race conditions. // Check events after the registration to prevent race conditions.
self.events() & mask check() & mask
} }
fn register_poller(&self, poller: &mut PollHandle, mask: IoEvents) { fn register_poller(&self, poller: &mut PollHandle, mask: IoEvents) {
@ -64,41 +87,18 @@ impl Pollee {
poller.pollees.push(Arc::downgrade(&self.inner)); poller.pollees.push(Arc::downgrade(&self.inner));
} }
/// Add some events to the pollee's state. /// Notifies pollers of some events.
/// ///
/// This method wakes up all registered pollers that are interested in /// This method wakes up all registered pollers that are interested in the events.
/// the added events. ///
pub fn add_events(&self, events: IoEvents) { /// The events can be spurious. This way, the caller can avoid expensive calculations and
self.inner.events.fetch_or(events.bits(), Ordering::Release); /// simply add all possible ones.
pub fn notify(&self, events: IoEvents) {
self.inner.subject.notify_observers(&events); self.inner.subject.notify_observers(&events);
} }
/// Remove some events from the pollee's state.
///
/// This method will not wake up registered pollers even when
/// the pollee still has some interesting events to the pollers.
pub fn del_events(&self, events: IoEvents) {
self.inner
.events
.fetch_and(!events.bits(), Ordering::Release);
} }
/// Reset the pollee's state. /// An opaque handle that can be used as an argument of the [`Pollable::poll`] method.
///
/// Reset means removing all events on the pollee.
pub fn reset_events(&self) {
self.inner
.events
.fetch_and(!IoEvents::all().bits(), Ordering::Release);
}
fn events(&self) -> IoEvents {
let event_bits = self.inner.events.load(Ordering::Acquire);
IoEvents::from_bits(event_bits).unwrap()
}
}
/// An opaque handle that can be used as an argument of the [`Pollee::poll`] method.
/// ///
/// This type can represent an entity of [`PollAdaptor`] or [`Poller`], which is done via the /// This type can represent an entity of [`PollAdaptor`] or [`Poller`], which is done via the
/// [`PollAdaptor::as_handle_mut`] and [`Poller::as_handle_mut`] methods. /// [`PollAdaptor::as_handle_mut`] and [`Poller::as_handle_mut`] methods.
@ -146,11 +146,11 @@ impl Drop for PollHandle {
} }
} }
/// An adaptor to make an [`Observer`] usable for [`Pollee::poll`]. /// An adaptor to make an [`Observer`] usable for [`Pollable::poll`].
/// ///
/// Normally, [`Pollee::poll`] accepts a [`Poller`] which is used to wait for events. By using this /// Normally, [`Pollable::poll`] accepts a [`Poller`] which is used to wait for events. By using
/// adaptor, it is possible to use any [`Observer`] with [`Pollee::poll`]. The observer will be /// this adaptor, it is possible to use any [`Observer`] with [`Pollable::poll`]. The observer will
/// notified whenever there are new events. /// be notified whenever there are new events.
pub struct PollAdaptor<O> { pub struct PollAdaptor<O> {
// The event observer. // The event observer.
observer: Arc<O>, observer: Arc<O>,
@ -258,18 +258,18 @@ impl Observer<IoEvents> for EventCounter {
/// The `Pollable` trait allows for waiting for events and performing event-based operations. /// The `Pollable` trait allows for waiting for events and performing event-based operations.
/// ///
/// Implementors are required to provide a method, [`Pollable::poll`], which is usually implemented /// Implementors are required to provide a method, [`Pollable::poll`], which is usually implemented
/// by simply calling [`Pollee::poll`] on the internal [`Pollee`]. This trait provides another /// by simply calling [`Pollable::poll`] on the internal [`Pollee`]. This trait provides another
/// method, [`Pollable::wait_events`], to allow waiting for events and performing operations /// method, [`Pollable::wait_events`], to allow waiting for events and performing operations
/// according to the events. /// according to the events.
/// ///
/// This trait is added instead of creating a new method in [`Pollee`] because sometimes we do not /// This trait is added instead of creating a new method in [`Pollee`] because sometimes we do not
/// have access to the internal [`Pollee`], but there is a method that provides the same semantics /// have access to the internal [`Pollee`], but there is a method that provides the same semantics
/// as [`Pollee::poll`] and we need to perform event-based operations using that method. /// as [`Pollable::poll`] and we need to perform event-based operations using that method.
pub trait Pollable { pub trait Pollable {
/// Returns the interesting events now and monitors their occurrence in the future if the /// Returns the interesting events now and monitors their occurrence in the future if the
/// poller is provided. /// poller is provided.
/// ///
/// This method has the same semantics as [`Pollee::poll`]. /// This method has the same semantics as [`Pollee::poll_with`].
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents; fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents;
/// Waits for events and performs event-based operations. /// Waits for events and performs event-based operations.

View File

@ -85,7 +85,7 @@ impl EventFile {
fn new(init_val: u64, flags: Flags) -> Self { fn new(init_val: u64, flags: Flags) -> Self {
let counter = Mutex::new(init_val); let counter = Mutex::new(init_val);
let pollee = Pollee::new(IoEvents::OUT); let pollee = Pollee::new();
let write_wait_queue = WaitQueue::new(); let write_wait_queue = WaitQueue::new();
Self { Self {
counter, counter,
@ -99,35 +99,24 @@ impl EventFile {
self.flags.lock().contains(Flags::EFD_NONBLOCK) self.flags.lock().contains(Flags::EFD_NONBLOCK)
} }
fn update_io_state(&self, counter: &MutexGuard<u64>) { fn check_io_events(&self) -> IoEvents {
let is_readable = **counter != 0; let counter = self.counter.lock();
let mut events = IoEvents::empty();
let is_readable = *counter != 0;
if is_readable {
events |= IoEvents::IN;
}
// if it is possible to write a value of at least "1" // if it is possible to write a value of at least "1"
// without blocking, the file is writable // without blocking, the file is writable
let is_writable = **counter < Self::MAX_COUNTER_VALUE; let is_writable = *counter < Self::MAX_COUNTER_VALUE;
if is_writable { if is_writable {
if is_readable { events |= IoEvents::OUT;
self.pollee.add_events(IoEvents::IN | IoEvents::OUT);
} else {
self.pollee.add_events(IoEvents::OUT);
self.pollee.del_events(IoEvents::IN);
} }
self.write_wait_queue.wake_all(); events
return;
}
if is_readable {
self.pollee.add_events(IoEvents::IN);
self.pollee.del_events(IoEvents::OUT);
return;
}
self.pollee.del_events(IoEvents::IN | IoEvents::OUT);
// TODO: deal with overflow logic
} }
fn try_read(&self, writer: &mut VmWriter) -> Result<()> { fn try_read(&self, writer: &mut VmWriter) -> Result<()> {
@ -147,7 +136,8 @@ impl EventFile {
*counter = 0; *counter = 0;
} }
self.update_io_state(&counter); self.pollee.notify(IoEvents::OUT);
self.write_wait_queue.wake_all();
Ok(()) Ok(())
} }
@ -165,7 +155,7 @@ impl EventFile {
if new_value <= Self::MAX_COUNTER_VALUE { if new_value <= Self::MAX_COUNTER_VALUE {
*counter = new_value; *counter = new_value;
self.update_io_state(&counter); self.pollee.notify(IoEvents::IN);
return Ok(()); return Ok(());
} }
@ -175,7 +165,8 @@ impl EventFile {
impl Pollable for EventFile { impl Pollable for EventFile {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee.poll(mask, poller) self.pollee
.poll_with(mask, poller, || self.check_io_events())
} }
} }