diff --git a/Cargo.lock b/Cargo.lock index a2726570..860b85a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -70,6 +70,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" name = "aster-bigtcp" version = "0.1.0" dependencies = [ + "bitflags 1.3.2", "keyable-arc", "ostd", "smoltcp", diff --git a/kernel/libs/aster-bigtcp/Cargo.toml b/kernel/libs/aster-bigtcp/Cargo.toml index 0a102ffa..dddbb579 100644 --- a/kernel/libs/aster-bigtcp/Cargo.toml +++ b/kernel/libs/aster-bigtcp/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bitflags = "1.3" keyable-arc = { path = "../keyable-arc" } ostd = { path = "../../../ostd" } smoltcp = { git = "https://github.com/asterinas/smoltcp", tag = "r_2024-11-08_f07e5b5", default-features = false, features = [ diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index 660649a9..c4be367c 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -220,13 +220,13 @@ impl IfaceCommon { context.poll_egress(device, dispatch_phy); tcp_sockets.iter().for_each(|socket| { - if socket.has_new_events() { - socket.on_iface_events(); + if socket.has_events() { + socket.on_events(); } }); udp_sockets.iter().for_each(|socket| { - if socket.has_new_events() { - socket.on_iface_events(); + if socket.has_events() { + socket.on_events(); } }); diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index da914532..9620ff02 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -6,7 +6,7 @@ use alloc::{ }; use core::{ ops::{Deref, DerefMut}, - sync::atomic::{AtomicBool, AtomicU64, Ordering}, + sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}, }; use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard}; 
@@ -17,7 +17,10 @@ use smoltcp::{ wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, }; -use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket}; +use super::{ + event::{SocketEventObserver, SocketEvents}, + RawTcpSocket, RawUdpSocket, TcpStateCheck, +}; use crate::iface::Iface; pub struct BoundSocket(Arc>); @@ -44,8 +47,8 @@ pub struct BoundSocketInner { port: u16, socket: T, observer: RwLock>, + events: AtomicU8, next_poll_at_ms: AtomicU64, - has_new_events: AtomicBool, } /// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`]. @@ -56,6 +59,7 @@ pub struct TcpSocket { struct RawTcpSocketExt { socket: Box, + has_connected: bool, /// Whether the socket is in the background. /// /// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means @@ -79,6 +83,22 @@ impl DerefMut for RawTcpSocketExt { } } +impl RawTcpSocketExt { + fn on_new_state(&mut self) -> SocketEvents { + if self.may_send() { + self.has_connected = true; + } + + if self.is_peer_closed() { + SocketEvents::PEER_CLOSED + } else if self.is_closed() { + SocketEvents::CLOSED + } else { + SocketEvents::empty() + } + } +} + impl TcpSocket { fn lock(&self) -> SpinLockGuard { self.socket.lock() @@ -123,6 +143,7 @@ impl AnySocket for TcpSocket { fn new(socket: Box) -> Self { let socket_ext = RawTcpSocketExt { socket, + has_connected: false, in_background: false, }; @@ -184,8 +205,8 @@ impl BoundSocket { port, socket: T::new(socket), observer: RwLock::new(observer), + events: AtomicU8::new(0), next_poll_at_ms: AtomicU64::new(u64::MAX), - has_new_events: AtomicBool::new(false), })) } @@ -204,7 +225,7 @@ impl BoundSocket { pub fn set_observer(&self, new_observer: Weak) { *self.0.observer.write_irq_disabled() = new_observer; - self.0.on_iface_events(); + self.0.on_events(); } /// Returns the observer. 
@@ -229,6 +250,12 @@ impl BoundSocket { } } +pub enum ConnectState { + Connecting, + Connected, + Refused, +} + impl BoundTcpSocket { /// Connects to a remote endpoint. pub fn connect( @@ -240,11 +267,26 @@ impl BoundTcpSocket { let mut socket = self.0.socket.lock(); - let result = socket.connect(iface.context(), remote_endpoint, self.0.port); + socket.connect(iface.context(), remote_endpoint, self.0.port)?; + + socket.has_connected = false; self.0 .update_next_poll_at_ms(socket.poll_at(iface.context())); - result + Ok(()) + } + + /// Returns the state of the connecting procedure. + pub fn connect_state(&self) -> ConnectState { + let socket = self.0.socket.lock(); + + if socket.state() == State::SynSent || socket.state() == State::SynReceived { + ConnectState::Connecting + } else if socket.has_connected { + ConnectState::Connected + } else { + ConnectState::Refused + } } /// Listens at a specified endpoint. @@ -366,22 +408,33 @@ impl BoundUdpSocket { } impl BoundSocketInner { - pub(crate) fn has_new_events(&self) -> bool { - self.has_new_events.load(Ordering::Relaxed) + pub(crate) fn has_events(&self) -> bool { + self.events.load(Ordering::Relaxed) != 0 } - pub(crate) fn on_iface_events(&self) { - self.has_new_events.store(false, Ordering::Relaxed); + pub(crate) fn on_events(&self) { + // Besides the poll path, `BoundSocket::set_observer` also calls this method, apparently + // without the poll lock. Take and clear the pending events in one atomic step so that + // events added concurrently by `add_events` cannot be overwritten and lost. + let events = self.events.swap(0, Ordering::Relaxed); // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we // get the read lock. 
let observer = Weak::upgrade(&*self.observer.read()); if let Some(inner) = observer { - inner.on_events(); + inner.on_events(SocketEvents::from_bits_truncate(events)); } } + fn add_events(&self, new_events: SocketEvents) { + // This method can only be called to add network events, so we assume we are holding the + // poll lock and no race conditions can occur. + let events = self.events.load(Ordering::Relaxed); + self.events + .store(events | new_events.bits(), Ordering::Relaxed); + } + /// Returns the next polling time. /// /// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required @@ -396,8 +449,6 @@ impl BoundSocketInner { /// method also marks that there may be new events, so that the event observer provided by /// [`BoundSocket::set_observer`] can be notified later. fn update_next_poll_at_ms(&self, poll_at: PollAt) { - self.has_new_events.store(true, Ordering::Relaxed); - match poll_at { PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed), PollAt::Time(instant) => self @@ -484,11 +535,21 @@ impl BoundTcpSocketInner { return TcpProcessResult::NotProcessed; } + let old_state = socket.state(); + // For TCP, receiving an ACK packet can free up space in the queue, allowing more packets + // to be queued. 
+ let mut events = SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; + let result = match socket.process(cx, ip_repr, tcp_repr) { None => TcpProcessResult::Processed, Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), }; + if socket.state() != old_state { + events |= socket.on_new_state(); + } + + self.add_events(events); self.update_next_poll_at_ms(socket.poll_at(cx)); self.socket.update_dead(&socket); @@ -506,6 +567,9 @@ impl BoundTcpSocketInner { { let mut socket = self.socket.lock(); + let old_state = socket.state(); + let mut events = SocketEvents::empty(); + let mut reply = None; socket .dispatch(cx, |cx, (ip_repr, tcp_repr)| { @@ -521,8 +585,14 @@ impl BoundTcpSocketInner { break; } reply = socket.process(cx, ip_repr, tcp_repr); + events |= SocketEvents::CAN_RECV | SocketEvents::CAN_SEND; } + if socket.state() != old_state { + events |= socket.on_new_state(); + } + + self.add_events(events); self.update_next_poll_at_ms(socket.poll_at(cx)); self.socket.update_dead(&socket); @@ -552,6 +622,8 @@ impl BoundUdpSocketInner { udp_repr, udp_payload, ); + + self.add_events(SocketEvents::CAN_RECV); self.update_next_poll_at_ms(socket.poll_at(cx)); true @@ -570,6 +642,9 @@ impl BoundUdpSocketInner { Ok::<(), ()>(()) }) .unwrap(); + + // For UDP, dequeuing a packet means that we can queue more packets. + self.add_events(SocketEvents::CAN_SEND); self.update_next_poll_at_ms(socket.poll_at(cx)); } } diff --git a/kernel/libs/aster-bigtcp/src/socket/event.rs b/kernel/libs/aster-bigtcp/src/socket/event.rs index d0770aae..0aba28c2 100644 --- a/kernel/libs/aster-bigtcp/src/socket/event.rs +++ b/kernel/libs/aster-bigtcp/src/socket/event.rs @@ -3,9 +3,19 @@ /// A observer that will be invoked whenever events occur on the socket. pub trait SocketEventObserver: Send + Sync { /// Notifies that events occurred on the socket. 
- fn on_events(&self); + fn on_events(&self, events: SocketEvents); } impl SocketEventObserver for () { - fn on_events(&self) {} + fn on_events(&self, _events: SocketEvents) {} +} + +bitflags::bitflags! { + /// Socket events caused by the _network_. + pub struct SocketEvents: u8 { + const CAN_RECV = 1; + const CAN_SEND = 2; + const PEER_CLOSED = 4; + const CLOSED = 8; + } } diff --git a/kernel/libs/aster-bigtcp/src/socket/mod.rs b/kernel/libs/aster-bigtcp/src/socket/mod.rs index c8428ac9..36feec33 100644 --- a/kernel/libs/aster-bigtcp/src/socket/mod.rs +++ b/kernel/libs/aster-bigtcp/src/socket/mod.rs @@ -2,12 +2,13 @@ mod bound; mod event; +mod state; mod unbound; -pub use bound::{BoundTcpSocket, BoundUdpSocket}; +pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState}; pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; -pub use event::SocketEventObserver; -pub use smoltcp::socket::tcp::State as TcpState; +pub use event::{SocketEventObserver, SocketEvents}; +pub use state::TcpStateCheck; pub use unbound::{ UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN, diff --git a/kernel/libs/aster-bigtcp/src/socket/state.rs b/kernel/libs/aster-bigtcp/src/socket/state.rs new file mode 100644 index 00000000..df56c16f --- /dev/null +++ b/kernel/libs/aster-bigtcp/src/socket/state.rs @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MPL-2.0 + +use smoltcp::socket::tcp::State as TcpState; + +use super::RawTcpSocket; + +pub trait TcpStateCheck { + /// Checks if the peer socket has closed its sending side. + /// + /// If the sending side of this socket is also closed, this method will return `false`. + /// In such cases, you should verify using [`is_closed`]. + fn is_peer_closed(&self) -> bool; + + /// Checks if the socket is fully closed. + /// + /// This function returns `true` if both this socket and the peer have closed their sending sides. 
+ /// + /// This TCP state corresponds to the `Normal Close Sequence` and `Simultaneous Close Sequence` + /// as outlined in RFC793 (https://datatracker.ietf.org/doc/html/rfc793#page-39). + fn is_closed(&self) -> bool; +} + +impl TcpStateCheck for RawTcpSocket { + fn is_peer_closed(&self) -> bool { + self.state() == TcpState::CloseWait + } + + fn is_closed(&self) -> bool { + !self.is_open() || self.state() == TcpState::Closing || self.state() == TcpState::LastAck + } +} diff --git a/kernel/src/device/pty/pty.rs b/kernel/src/device/pty/pty.rs index 10792be4..51d782c5 100644 --- a/kernel/src/device/pty/pty.rs +++ b/kernel/src/device/pty/pty.rs @@ -48,7 +48,7 @@ impl PtyMaster { output: ldisc, input: SpinLock::new(RingBuffer::new(BUFFER_CAPACITY)), job_control, - pollee: Pollee::new(IoEvents::OUT), + pollee: Pollee::new(), weak_self: weak_ref.clone(), }) } @@ -64,7 +64,7 @@ impl PtyMaster { pub(super) fn slave_push_char(&self, ch: u8) { let mut input = self.input.disable_irq().lock(); input.push_overwrite(ch); - self.update_state(&input); + self.pollee.notify(IoEvents::IN); } pub(super) fn slave_poll( @@ -82,7 +82,9 @@ impl PtyMaster { let poll_out_mask = mask & IoEvents::OUT; if !poll_out_mask.is_empty() { - let poll_out_status = self.pollee.poll(poll_out_mask, poller); + let poll_out_status = self + .pollee + .poll_with(poll_out_mask, poller, || self.check_io_events()); poll_status |= poll_out_status; } @@ -100,17 +102,16 @@ impl PtyMaster { return_errno_with_message!(Errno::EAGAIN, "the buffer is empty"); } - let read_len = input.read_fallible(writer)?; - self.update_state(&input); - - Ok(read_len) + input.read_fallible(writer) } - fn update_state(&self, buf: &RingBuffer) { - if buf.is_empty() { - self.pollee.del_events(IoEvents::IN) + fn check_io_events(&self) -> IoEvents { + let input = self.input.disable_irq().lock(); + + if !input.is_empty() { + IoEvents::IN | IoEvents::OUT } else { - self.pollee.add_events(IoEvents::IN); + IoEvents::OUT } } } @@ -121,7 +122,11 
@@ impl Pollable for PtyMaster { let poll_in_mask = mask & IoEvents::IN; if !poll_in_mask.is_empty() { - let poll_in_status = self.pollee.poll(poll_in_mask, poller.as_deref_mut()); + let poll_in_status = self + .pollee + .poll_with(poll_in_mask, poller.as_deref_mut(), || { + self.check_io_events() + }); poll_status |= poll_in_status; } @@ -157,7 +162,7 @@ impl FileIo for PtyMaster { }); } - self.update_state(&input); + self.pollee.notify(IoEvents::IN); Ok(write_len) } diff --git a/kernel/src/device/tty/line_discipline.rs b/kernel/src/device/tty/line_discipline.rs index db9eb30e..cf6d3b95 100644 --- a/kernel/src/device/tty/line_discipline.rs +++ b/kernel/src/device/tty/line_discipline.rs @@ -89,7 +89,8 @@ impl CurrentLine { impl Pollable for LineDiscipline { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) } } @@ -108,7 +109,7 @@ impl LineDiscipline { read_buffer: SpinLock::new(RingBuffer::new(BUFFER_CAPACITY)), termios: SpinLock::new(KernelTermios::default()), winsize: SpinLock::new(WinSize::default()), - pollee: Pollee::new(IoEvents::empty()), + pollee: Pollee::new(), send_signal, work_item, work_item_para: Arc::new(SpinLock::new(LineDisciplineWorkPara::new())), @@ -140,7 +141,7 @@ impl LineDiscipline { // Raw mode if !termios.is_canonical_mode() { self.read_buffer.lock().push_overwrite(ch); - self.update_readable_state(); + self.pollee.notify(IoEvents::IN); return; } @@ -166,6 +167,7 @@ impl LineDiscipline { let current_line_chars = current_line.drain(); for char in current_line_chars { self.read_buffer.lock().push_overwrite(char); + self.pollee.notify(IoEvents::IN); } } @@ -173,8 +175,6 @@ impl LineDiscipline { // Printable character self.current_line.lock().push_char(ch); } - - self.update_readable_state(); } fn may_send_signal(&self, termios: &KernelTermios, ch: u8) -> bool { @@ -198,13 +198,13 @@ impl LineDiscipline { true } - pub 
fn update_readable_state(&self) { + fn check_io_events(&self) -> IoEvents { let buffer = self.read_buffer.lock(); if !buffer.is_empty() { - self.pollee.add_events(IoEvents::IN); + IoEvents::IN } else { - self.pollee.del_events(IoEvents::IN); + IoEvents::empty() } } @@ -265,7 +265,6 @@ impl LineDiscipline { unreachable!() } }; - self.update_readable_state(); Ok(read_len) } diff --git a/kernel/src/device/tty/mod.rs b/kernel/src/device/tty/mod.rs index f2f2979b..72215c46 100644 --- a/kernel/src/device/tty/mod.rs +++ b/kernel/src/device/tty/mod.rs @@ -126,9 +126,6 @@ impl FileIo for Tty { }; self.set_foreground(&pgid)?; - // Some background processes may be waiting on the wait queue, - // when set_fg, the background processes may be able to read. - self.ldisc.update_readable_state(); Ok(0) } IoctlCmd::TCSETS => { diff --git a/kernel/src/fs/epoll/entry.rs b/kernel/src/fs/epoll/entry.rs index b216dc31..96f329c6 100644 --- a/kernel/src/fs/epoll/entry.rs +++ b/kernel/src/fs/epoll/entry.rs @@ -288,7 +288,7 @@ impl ReadySet { Self { entries: SpinLock::new(VecDeque::new()), pop_guard: Mutex::new(PopGuard), - pollee: Pollee::new(IoEvents::empty()), + pollee: Pollee::new(), } } @@ -315,7 +315,7 @@ impl ReadySet { // Even if the entry is already set to ready, // there might be new events that we are interested in. // Wake the poller anyway. - self.pollee.add_events(IoEvents::IN); + self.pollee.notify(IoEvents::IN); } pub(super) fn lock_pop(&self) -> ReadySetPopIter { @@ -327,7 +327,18 @@ impl ReadySet { } pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) + } + + fn check_io_events(&self) -> IoEvents { + let entries = self.entries.lock(); + + if !entries.is_empty() { + IoEvents::IN + } else { + IoEvents::empty() + } } } @@ -356,11 +367,6 @@ impl Iterator for ReadySetPopIter<'_> { // must exist, so we can just unwrap it. 
let weak_entry = entries.pop_front().unwrap(); - // Clear the epoll file's events if there are no ready entries. - if entries.len() == 0 { - self.ready_set.pollee.del_events(IoEvents::IN); - } - let Some(entry) = Weak::upgrade(&weak_entry) else { // The entry has been deleted. continue; diff --git a/kernel/src/fs/utils/channel.rs b/kernel/src/fs/utils/channel.rs index f4efca68..d9d49329 100644 --- a/kernel/src/fs/utils/channel.rs +++ b/kernel/src/fs/utils/channel.rs @@ -96,7 +96,9 @@ macro_rules! impl_common_methods_for_channel { } pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.this_end().pollee.poll(mask, poller) + self.this_end() + .pollee + .poll_with(mask, poller, || self.check_io_events()) } }; } @@ -110,27 +112,17 @@ impl Producer { &self.0.common.consumer } - fn update_pollee(&self) { - // In theory, `rb.free_len()`/`rb.is_empty()`, where the `rb` is taken from either - // `this_end` or `peer_end`, should reflect the same state. However, we need to take the - // correct lock when updating the events to avoid races between the state check and the - // event update. - + fn check_io_events(&self) -> IoEvents { let this_end = self.this_end(); let rb = this_end.rb(); - if self.is_shutdown() { - // The POLLOUT event is always set in this case. Don't try to remove it. 
- } else if rb.free_len() < PIPE_BUF { - this_end.pollee.del_events(IoEvents::OUT); - } - drop(rb); - let peer_end = self.peer_end(); - let rb = peer_end.rb(); - if !rb.is_empty() { - peer_end.pollee.add_events(IoEvents::IN); + if self.is_shutdown() { + IoEvents::ERR | IoEvents::OUT + } else if rb.free_len() >= PIPE_BUF { + IoEvents::OUT + } else { + IoEvents::empty() } - drop(rb); } impl_common_methods_for_channel!(); @@ -153,7 +145,7 @@ impl Producer { } let written_len = self.0.write(reader)?; - self.update_pollee(); + self.peer_end().pollee.notify(IoEvents::IN); if written_len > 0 { Ok(written_len) @@ -179,7 +171,7 @@ impl Producer { let err = Error::with_message(Errno::EAGAIN, "the channel is full"); (err, item) })?; - self.update_pollee(); + self.peer_end().pollee.notify(IoEvents::IN); Ok(()) } @@ -200,25 +192,18 @@ impl Consumer { &self.0.common.producer } - fn update_pollee(&self) { - // In theory, `rb.free_len()`/`rb.is_empty()`, where the `rb` is taken from either - // `this_end` or `peer_end`, should reflect the same state. However, we need to take the - // correct lock when updating the events to avoid races between the state check and the - // event update. 
- + fn check_io_events(&self) -> IoEvents { let this_end = self.this_end(); let rb = this_end.rb(); - if rb.is_empty() { - this_end.pollee.del_events(IoEvents::IN); - } - drop(rb); - let peer_end = self.peer_end(); - let rb = peer_end.rb(); - if rb.free_len() >= PIPE_BUF { - peer_end.pollee.add_events(IoEvents::OUT); + let mut events = IoEvents::empty(); + if self.is_shutdown() { + events |= IoEvents::HUP; } - drop(rb); + if !rb.is_empty() { + events |= IoEvents::IN; + } + events } impl_common_methods_for_channel!(); @@ -239,7 +224,7 @@ impl Consumer { let is_shutdown = self.is_shutdown(); let read_len = self.0.read(writer)?; - self.update_pollee(); + self.peer_end().pollee.notify(IoEvents::OUT); if read_len > 0 { Ok(read_len) @@ -262,7 +247,7 @@ impl Consumer { let is_shutdown = self.is_shutdown(); let item = self.0.pop(); - self.update_pollee(); + self.peer_end().pollee.notify(IoEvents::OUT); if let Some(item) = item { Ok(Some(item)) @@ -346,25 +331,12 @@ impl Common { let (rb_producer, rb_consumer) = rb.split(); let producer = { - let polee = if let Some(pollee) = producer_pollee { - pollee.reset_events(); - pollee.add_events(IoEvents::OUT); - pollee - } else { - Pollee::new(IoEvents::OUT) - }; - - FifoInner::new(rb_producer, polee) + let pollee = producer_pollee.unwrap_or_default(); + FifoInner::new(rb_producer, pollee) }; let consumer = { - let pollee = if let Some(pollee) = consumer_pollee { - pollee.reset_events(); - pollee - } else { - Pollee::new(IoEvents::empty()) - }; - + let pollee = consumer_pollee.unwrap_or_default(); FifoInner::new(rb_consumer, pollee) }; @@ -389,19 +361,11 @@ impl Common { } // The POLLHUP event indicates that the write end is shut down. - // - // No need to take a lock. There is no race because no one is modifying this particular event. 
- self.consumer.pollee.add_events(IoEvents::HUP); + self.consumer.pollee.notify(IoEvents::HUP); // The POLLERR event indicates that the read end is shut down (so any subsequent writes // will fail with an `EPIPE` error). - // - // The lock is taken because we are also adding the POLLOUT event, which may have races - // with the event updates triggered by the writer. - let _rb = self.producer.rb(); - self.producer - .pollee - .add_events(IoEvents::ERR | IoEvents::OUT); + self.producer.pollee.notify(IoEvents::ERR | IoEvents::OUT); } } diff --git a/kernel/src/net/socket/ip/datagram/bound.rs b/kernel/src/net/socket/ip/datagram/bound.rs index fc2ed7b3..03b3d64d 100644 --- a/kernel/src/net/socket/ip/datagram/bound.rs +++ b/kernel/src/net/socket/ip/datagram/bound.rs @@ -9,7 +9,6 @@ use crate::{ events::IoEvents, net::{iface::BoundUdpSocket, socket::util::send_recv_flags::SendRecvFlags}, prelude::*, - process::signal::Pollee, util::{MultiRead, MultiWrite}, }; @@ -93,24 +92,19 @@ impl BoundDatagram { } } - pub(super) fn init_pollee(&self, pollee: &Pollee) { - pollee.reset_events(); - self.update_io_events(pollee) - } - - pub(super) fn update_io_events(&self, pollee: &Pollee) { + pub(super) fn check_io_events(&self) -> IoEvents { self.bound_socket.raw_with(|socket| { + let mut events = IoEvents::empty(); + if socket.can_recv() { - pollee.add_events(IoEvents::IN); - } else { - pollee.del_events(IoEvents::IN); + events |= IoEvents::IN; } if socket.can_send() { - pollee.add_events(IoEvents::OUT); - } else { - pollee.del_events(IoEvents::OUT); + events |= IoEvents::OUT; } - }); + + events + }) } } diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index b2d8b76d..86538d3d 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -2,7 +2,10 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use aster_bigtcp::{socket::SocketEventObserver, wire::IpEndpoint}; +use aster_bigtcp::{ + 
socket::{SocketEventObserver, SocketEvents}, + wire::IpEndpoint, +}; use takeable::Takeable; use self::{bound::BoundDatagram, unbound::UnboundDatagram}; @@ -98,12 +101,10 @@ impl DatagramSocket { pub fn new(nonblocking: bool) -> Arc { Arc::new_cyclic(|me| { let unbound_datagram = UnboundDatagram::new(me.clone() as _); - let pollee = Pollee::new(IoEvents::empty()); - unbound_datagram.init_pollee(&pollee); Self { inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))), nonblocking: AtomicBool::new(nonblocking), - pollee, + pollee: Pollee::new(), options: RwLock::new(OptionSet::new()), } }) @@ -141,7 +142,6 @@ impl DatagramSocket { return (err_inner, Err(err)); } }; - bound_datagram.init_pollee(&self.pollee); (Inner::Bound(bound_datagram), Ok(())) }) } @@ -199,18 +199,20 @@ impl DatagramSocket { sent_bytes } - fn update_io_events(&self) { + fn check_io_events(&self) -> IoEvents { let inner = self.inner.read(); - let Inner::Bound(bound_datagram) = inner.as_ref() else { - return; - }; - bound_datagram.update_io_events(&self.pollee); + + match inner.as_ref() { + Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(), + Inner::Bound(bound_socket) => bound_socket.check_io_events(), + } } } impl Pollable for DatagramSocket { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) } } @@ -283,7 +285,6 @@ impl Socket for DatagramSocket { return (err_inner, Err(err)); } }; - bound_datagram.init_pollee(&self.pollee); (Inner::Bound(bound_datagram), Ok(())) }) } @@ -388,7 +389,17 @@ impl Socket for DatagramSocket { } impl SocketEventObserver for DatagramSocket { - fn on_events(&self) { - self.update_io_events(); + fn on_events(&self, events: SocketEvents) { + let mut io_events = IoEvents::empty(); + + if events.contains(SocketEvents::CAN_RECV) { + io_events |= IoEvents::IN; + } + + if events.contains(SocketEvents::CAN_SEND) { + 
io_events |= IoEvents::OUT; + } + + self.pollee.notify(io_events); } } diff --git a/kernel/src/net/socket/ip/datagram/unbound.rs b/kernel/src/net/socket/ip/datagram/unbound.rs index 7415a2b2..5eb51f94 100644 --- a/kernel/src/net/socket/ip/datagram/unbound.rs +++ b/kernel/src/net/socket/ip/datagram/unbound.rs @@ -8,9 +8,7 @@ use aster_bigtcp::{ }; use super::bound::BoundDatagram; -use crate::{ - events::IoEvents, net::socket::ip::common::bind_socket, prelude::*, process::signal::Pollee, -}; +use crate::{events::IoEvents, net::socket::ip::common::bind_socket, prelude::*}; pub struct UnboundDatagram { unbound_socket: Box, @@ -44,8 +42,7 @@ impl UnboundDatagram { Ok(BoundDatagram::new(bound_socket)) } - pub(super) fn init_pollee(&self, pollee: &Pollee) { - pollee.reset_events(); - pollee.add_events(IoEvents::OUT); + pub(super) fn check_io_events(&self) -> IoEvents { + IoEvents::OUT } } diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index 07e8ceca..ca8a13c6 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -5,7 +5,7 @@ use core::sync::atomic::{AtomicBool, Ordering}; use aster_bigtcp::{ errors::tcp::{RecvError, SendError}, - socket::{RawTcpSocket, SocketEventObserver, TcpState}, + socket::{SocketEventObserver, TcpStateCheck}, wire::IpEndpoint, }; @@ -61,16 +61,21 @@ impl ConnectedStream { } pub fn shutdown(&self, cmd: SockShutdownCmd, pollee: &Pollee) -> Result<()> { + let mut events = IoEvents::empty(); + if cmd.shut_read() { self.is_receiving_closed.store(true, Ordering::Relaxed); - self.update_io_events(pollee); + events |= IoEvents::IN | IoEvents::RDHUP; } if cmd.shut_write() { self.is_sending_closed.store(true, Ordering::Relaxed); self.bound_socket.close(); + events |= IoEvents::OUT | IoEvents::HUP; } + pollee.notify(events); + Ok(()) } @@ -132,17 +137,12 @@ impl ConnectedStream { Ok(()) } - pub(super) fn init_pollee(&self, pollee: &Pollee) { - 
pollee.reset_events(); - self.update_io_events(pollee); - } - - pub(super) fn update_io_events(&self, pollee: &Pollee) { + pub(super) fn check_io_events(&self) -> IoEvents { self.bound_socket.raw_with(|socket| { - if is_peer_closed(socket) { + if socket.is_peer_closed() { // Only the sending side of peer socket is closed self.is_receiving_closed.store(true, Ordering::Relaxed); - } else if is_closed(socket) { + } else if socket.is_closed() { // The sending side of both peer socket and this socket are closed self.is_receiving_closed.store(true, Ordering::Relaxed); self.is_sending_closed.store(true, Ordering::Relaxed); @@ -151,50 +151,32 @@ impl ConnectedStream { let is_receiving_closed = self.is_receiving_closed.load(Ordering::Relaxed); let is_sending_closed = self.is_sending_closed.load(Ordering::Relaxed); + let mut events = IoEvents::empty(); + // If the receiving side is closed, always add events IN and RDHUP; // otherwise, check if the socket can receive. if is_receiving_closed { - pollee.add_events(IoEvents::IN | IoEvents::RDHUP); + events |= IoEvents::IN | IoEvents::RDHUP; } else if socket.can_recv() { - pollee.add_events(IoEvents::IN); - } else { - pollee.del_events(IoEvents::IN); + events |= IoEvents::IN; } // If the sending side is closed, always add an OUT event; // otherwise, check if the socket can send. if is_sending_closed || socket.can_send() { - pollee.add_events(IoEvents::OUT); - } else { - pollee.del_events(IoEvents::OUT); + events |= IoEvents::OUT; } // If both sending and receiving sides are closed, add a HUP event. if is_receiving_closed && is_sending_closed { - pollee.add_events(IoEvents::HUP); + events |= IoEvents::HUP; } - }); + + events + }) } pub(super) fn set_observer(&self, observer: Weak) { self.bound_socket.set_observer(observer) } } - -/// Checks if the peer socket has closed its sending side. -/// -/// If the sending side of this socket is also closed, this method will return `false`. 
-/// In such cases, you should verify using [`is_closed`]. -fn is_peer_closed(socket: &RawTcpSocket) -> bool { - socket.state() == TcpState::CloseWait -} - -/// Checks if the socket is fully closed. -/// -/// This function returns `true` if both this socket and the peer have closed their sending sides. -/// -/// This TCP state corresponds to the `Normal Close Sequence` and `Simultaneous Close Sequence` -/// as outlined in RFC793 (https://datatracker.ietf.org/doc/html/rfc793#page-39). -fn is_closed(socket: &RawTcpSocket) -> bool { - !socket.is_open() || socket.state() == TcpState::Closing || socket.state() == TcpState::LastAck -} diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index d2ca8d7f..4448dabf 100644 --- a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -1,21 +1,19 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_bigtcp::wire::IpEndpoint; -use ostd::sync::LocalIrqDisabled; +use aster_bigtcp::{socket::ConnectState, wire::IpEndpoint}; use super::{connected::ConnectedStream, init::InitStream}; -use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; +use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*}; pub struct ConnectingStream { bound_socket: BoundTcpSocket, remote_endpoint: IpEndpoint, - conn_result: SpinLock, LocalIrqDisabled>, } -#[derive(Clone, Copy)] -enum ConnResult { - Connected, - Refused, +pub enum ConnResult { + Connecting(ConnectingStream), + Connected(ConnectedStream), + Refused(InitStream), } impl ConnectingStream { @@ -41,27 +39,28 @@ impl ConnectingStream { Ok(Self { bound_socket, remote_endpoint, - conn_result: SpinLock::new(None), }) } pub fn has_result(&self) -> bool { - self.conn_result.lock().is_some() + match self.bound_socket.connect_state() { + ConnectState::Connecting => false, + ConnectState::Connected => true, + ConnectState::Refused => true, + } } - pub fn 
into_result(self) -> core::result::Result { - let conn_result = *self.conn_result.lock(); - match conn_result { - Some(ConnResult::Connected) => Ok(ConnectedStream::new( + pub fn into_result(self) -> ConnResult { + let next_state = self.bound_socket.connect_state(); + + match next_state { + ConnectState::Connecting => ConnResult::Connecting(self), + ConnectState::Connected => ConnResult::Connected(ConnectedStream::new( self.bound_socket, self.remote_endpoint, true, )), - Some(ConnResult::Refused) => Err(( - Error::with_message(Errno::ECONNREFUSED, "the connection is refused"), - InitStream::new_bound(self.bound_socket), - )), - None => unreachable!("`has_result` must be true before calling `into_result`"), + ConnectState::Refused => ConnResult::Refused(InitStream::new_bound(self.bound_socket)), } } @@ -73,43 +72,7 @@ impl ConnectingStream { self.remote_endpoint } - pub(super) fn init_pollee(&self, pollee: &Pollee) { - pollee.reset_events(); - } - - pub(super) fn update_io_events(&self, pollee: &Pollee) { - if self.conn_result.lock().is_some() { - return; - } - - self.bound_socket.raw_with(|socket| { - let mut result = self.conn_result.lock(); - if result.is_some() { - return; - } - - // Connected - if socket.can_send() { - *result = Some(ConnResult::Connected); - pollee.add_events(IoEvents::OUT); - return; - } - // Connecting - if socket.is_open() { - return; - } - // Refused - *result = Some(ConnResult::Refused); - pollee.add_events(IoEvents::OUT); - - // Add `IoEvents::OUT` because the man pages say "EINPROGRESS [..] It is possible to - // select(2) or poll(2) for completion by selecting the socket for writing". For - // details, see . - // - // TODO: It is better to do the state transition and let `ConnectedStream` or - // `InitStream` set the correct I/O events. However, the state transition is delayed - // because we're probably in IRQ handlers. Maybe mark the `pollee` as obsolete and - // re-calculate the I/O events in `poll`. 
- }) + pub(super) fn check_io_events(&self) -> IoEvents { + IoEvents::empty() } } diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index d188099c..5c28824f 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -15,7 +15,6 @@ use crate::{ socket::ip::common::{bind_socket, get_ephemeral_endpoint}, }, prelude::*, - process::signal::Pollee, }; pub enum InitStream { @@ -101,9 +100,8 @@ impl InitStream { } } - pub(super) fn init_pollee(&self, pollee: &Pollee) { - pollee.reset_events(); + pub(super) fn check_io_events(&self) -> IoEvents { // Linux adds OUT and HUP events for a newly created socket - pollee.add_events(IoEvents::OUT | IoEvents::HUP); + IoEvents::OUT | IoEvents::HUP } } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index 8f2e91ca..a26f0e92 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -5,7 +5,7 @@ use aster_bigtcp::{ }; use super::connected::ConnectedStream; -use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; +use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*}; pub struct ListenStream { backlog: usize, @@ -80,22 +80,17 @@ impl ListenStream { self.bound_socket.local_endpoint().unwrap() } - pub(super) fn init_pollee(&self, pollee: &Pollee) { - pollee.reset_events(); - self.update_io_events(pollee); - } - - pub(super) fn update_io_events(&self, pollee: &Pollee) { + pub(super) fn check_io_events(&self) -> IoEvents { let backlog_sockets = self.backlog_sockets.read(); let can_accept = backlog_sockets.iter().any(|socket| socket.can_accept()); - // FIXME: If network packets come in simultaneously, the socket state may change in the - // middle. This can cause the wrong I/O events to be added or deleted. + // If network packets come in simultaneously, the socket state may change in the middle. 
+ // However, the current pollee implementation should be able to handle this race condition. if can_accept { - pollee.add_events(IoEvents::IN); + IoEvents::IN } else { - pollee.del_events(IoEvents::IN); + IoEvents::empty() } } } diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index e1910a12..9c14d3c9 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -2,9 +2,12 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use aster_bigtcp::{socket::SocketEventObserver, wire::IpEndpoint}; +use aster_bigtcp::{ + socket::{SocketEventObserver, SocketEvents}, + wire::IpEndpoint, +}; use connected::ConnectedStream; -use connecting::ConnectingStream; +use connecting::{ConnResult, ConnectingStream}; use init::InitStream; use listen::ListenStream; use options::{Congestion, MaxSegment, NoDelay, WindowClamp}; @@ -81,27 +84,23 @@ impl StreamSocket { pub fn new(nonblocking: bool) -> Arc { Arc::new_cyclic(|me| { let init_stream = InitStream::new(me.clone() as _); - let pollee = Pollee::new(IoEvents::empty()); - init_stream.init_pollee(&pollee); Self { options: RwLock::new(OptionSet::new()), state: RwLock::new(Takeable::new(State::Init(init_stream))), is_nonblocking: AtomicBool::new(nonblocking), - pollee, + pollee: Pollee::new(), } }) } fn new_connected(connected_stream: ConnectedStream) -> Arc { Arc::new_cyclic(move |me| { - let pollee = Pollee::new(IoEvents::empty()); connected_stream.set_observer(me.clone() as _); - connected_stream.init_pollee(&pollee); Self { options: RwLock::new(OptionSet::new()), state: RwLock::new(Takeable::new(State::Connected(connected_stream))), is_nonblocking: AtomicBool::new(false), - pollee, + pollee: Pollee::new(), } }) } @@ -161,23 +160,26 @@ impl StreamSocket { _ => return (options, state), } - let result = state.borrow_result(|owned_state| { + state.borrow(|owned_state| { let State::Connecting(connecting_stream) = owned_state else { 
unreachable!("`State::Connecting` is checked before calling `borrow_result`"); }; - let connected_stream = match connecting_stream.into_result() { - Ok(connected_stream) => connected_stream, - Err((err, init_stream)) => { - init_stream.init_pollee(&self.pollee); - return (State::Init(init_stream), Err(err)); + match connecting_stream.into_result() { + ConnResult::Connecting(connecting_stream) => State::Connecting(connecting_stream), + ConnResult::Connected(connected_stream) => { + options.socket.set_sock_errors(None); + State::Connected(connected_stream) } - }; - connected_stream.init_pollee(&self.pollee); - - (State::Connected(connected_stream), Ok(())) + ConnResult::Refused(init_stream) => { + options.socket.set_sock_errors(Some(Error::with_message( + Errno::ECONNREFUSED, + "the connection is refused", + ))); + State::Init(init_stream) + } + } }); - options.socket.set_sock_errors(result.err()); (options, state) } @@ -224,7 +226,6 @@ impl StreamSocket { return (State::Init(init_stream), Some(Err(err))); } }; - connecting_stream.init_pollee(&self.pollee); ( State::Connecting(connecting_stream), @@ -269,8 +270,6 @@ impl StreamSocket { }; let accepted = listen_stream.try_accept().map(|connected_stream| { - listen_stream.update_io_events(&self.pollee); - let remote_endpoint = connected_stream.remote_endpoint(); let accepted_socket = Self::new_connected(connected_stream); (accepted_socket as _, remote_endpoint.into()) @@ -354,30 +353,22 @@ impl StreamSocket { } } - fn update_io_events(&self) { - let state = self.state.read(); - match state.as_ref() { - State::Init(_) => (), - State::Connecting(connecting_stream) => { - connecting_stream.update_io_events(&self.pollee) - } - State::Listen(listen_stream) => { - listen_stream.update_io_events(&self.pollee); - } - State::Connected(connected_stream) => { - connected_stream.update_io_events(&self.pollee); - } - } + fn check_io_events(&self) -> IoEvents { + let state = self.read_updated_state(); - // Note: Network events can 
cause a state transition from `State::Connecting` to - // `State::Connected`/`State::Init`. The state transition is delayed until - // `update_connecting`is triggered by user events, see that method for details. + match state.as_ref() { + State::Init(init_stream) => init_stream.check_io_events(), + State::Connecting(connecting_stream) => connecting_stream.check_io_events(), + State::Listen(listen_stream) => listen_stream.check_io_events(), + State::Connected(connected_stream) => connected_stream.check_io_events(), + } } } impl Pollable for StreamSocket { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) } } @@ -492,7 +483,6 @@ impl Socket for StreamSocket { return (State::Init(init_stream), Err(err)); } }; - listen_stream.init_pollee(&self.pollee); (State::Listen(listen_stream), Ok(())) }) @@ -677,8 +667,26 @@ impl Socket for StreamSocket { } impl SocketEventObserver for StreamSocket { - fn on_events(&self) { - self.update_io_events(); + fn on_events(&self, events: SocketEvents) { + let mut io_events = IoEvents::empty(); + + if events.contains(SocketEvents::CAN_RECV) { + io_events |= IoEvents::IN; + } + + if events.contains(SocketEvents::CAN_SEND) { + io_events |= IoEvents::OUT; + } + + if events.contains(SocketEvents::PEER_CLOSED) { + io_events |= IoEvents::IN | IoEvents::RDHUP; + } + + if events.contains(SocketEvents::CLOSED) { + io_events |= IoEvents::IN | IoEvents::OUT | IoEvents::RDHUP | IoEvents::HUP; + } + + self.pollee.notify(io_events); } } diff --git a/kernel/src/net/socket/unix/stream/init.rs b/kernel/src/net/socket/unix/stream/init.rs index 2766ba1d..eedf2012 100644 --- a/kernel/src/net/socket/unix/stream/init.rs +++ b/kernel/src/net/socket/unix/stream/init.rs @@ -28,8 +28,8 @@ impl Init { pub(super) fn new() -> Self { Self { addr: None, - reader_pollee: Pollee::new(IoEvents::empty()), - writer_pollee: 
Pollee::new(IoEvents::OUT), + reader_pollee: Pollee::new(), + writer_pollee: Pollee::new(), is_read_shutdown: AtomicBool::new(false), is_write_shutdown: AtomicBool::new(false), } @@ -87,6 +87,7 @@ impl Init { self.writer_pollee, backlog, self.is_read_shutdown.into_inner(), + self.is_write_shutdown.into_inner(), )) } @@ -94,7 +95,7 @@ impl Init { match cmd { SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { self.is_write_shutdown.store(true, Ordering::Relaxed); - self.writer_pollee.add_events(IoEvents::ERR); + self.writer_pollee.notify(IoEvents::ERR); } SockShutdownCmd::SHUT_RD => (), } @@ -102,7 +103,7 @@ impl Init { match cmd { SockShutdownCmd::SHUT_RD | SockShutdownCmd::SHUT_RDWR => { self.is_read_shutdown.store(true, Ordering::Relaxed); - self.reader_pollee.add_events(IoEvents::HUP); + self.reader_pollee.notify(IoEvents::HUP); } SockShutdownCmd::SHUT_WR => (), } @@ -115,8 +116,22 @@ impl Init { pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents { // To avoid loss of events, this must be compatible with // `Connected::poll`/`Listener::poll`. - let reader_events = self.reader_pollee.poll(mask, poller.as_deref_mut()); - let writer_events = self.writer_pollee.poll(mask, poller); + let reader_events = self + .reader_pollee + .poll_with(mask, poller.as_deref_mut(), || { + if self.is_read_shutdown.load(Ordering::Relaxed) { + IoEvents::HUP + } else { + IoEvents::empty() + } + }); + let writer_events = self.writer_pollee.poll_with(mask, poller, || { + if self.is_write_shutdown.load(Ordering::Relaxed) { + IoEvents::OUT | IoEvents::ERR + } else { + IoEvents::OUT + } + }); // According to the Linux implementation, we always have `IoEvents::HUP` in this state. // Meanwhile, it is in `IoEvents::ALWAYS_POLL`, so we always return it. 
diff --git a/kernel/src/net/socket/unix/stream/listener.rs b/kernel/src/net/socket/unix/stream/listener.rs index e9c3139e..4a85565f 100644 --- a/kernel/src/net/socket/unix/stream/listener.rs +++ b/kernel/src/net/socket/unix/stream/listener.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicUsize, Ordering}; +use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use ostd::sync::WaitQueue; @@ -22,6 +22,7 @@ use crate::{ pub(super) struct Listener { backlog: Arc, + is_write_shutdown: AtomicBool, writer_pollee: Pollee, } @@ -31,17 +32,16 @@ impl Listener { reader_pollee: Pollee, writer_pollee: Pollee, backlog: usize, - is_shutdown: bool, + is_read_shutdown: bool, + is_write_shutdown: bool, ) -> Self { - // Note that the I/O events can be correctly inherited from `Init`. There is no need to - // explicitly call `Pollee::reset_io_events`. let backlog = BACKLOG_TABLE - .add_backlog(addr, reader_pollee, backlog, is_shutdown) + .add_backlog(addr, reader_pollee, backlog, is_read_shutdown) .unwrap(); - writer_pollee.del_events(IoEvents::OUT); Self { backlog, + is_write_shutdown: AtomicBool::new(is_write_shutdown), writer_pollee, } } @@ -65,7 +65,8 @@ impl Listener { pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { match cmd { SockShutdownCmd::SHUT_WR | SockShutdownCmd::SHUT_RDWR => { - self.writer_pollee.add_events(IoEvents::ERR); + self.is_write_shutdown.store(true, Ordering::Relaxed); + self.writer_pollee.notify(IoEvents::ERR); } SockShutdownCmd::SHUT_RD => (), } @@ -80,7 +81,14 @@ impl Listener { pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut PollHandle>) -> IoEvents { let reader_events = self.backlog.poll(mask, poller.as_deref_mut()); - let writer_events = self.writer_pollee.poll(mask, poller); + + let writer_events = self.writer_pollee.poll_with(mask, poller, || { + if self.is_write_shutdown.load(Ordering::Relaxed) { + IoEvents::ERR + } else { + IoEvents::empty() + } + }); combine_io_events(mask, 
reader_events, writer_events) } @@ -172,11 +180,7 @@ impl Backlog { let Some(incoming_conns) = &mut *locked_incoming_conns else { return_errno_with_message!(Errno::EINVAL, "the socket is shut down for reading"); }; - let conn = incoming_conns.pop_front(); - if incoming_conns.is_empty() { - self.pollee.del_events(IoEvents::IN); - } drop(locked_incoming_conns); @@ -199,8 +203,7 @@ impl Backlog { let mut incoming_conns = self.incoming_conns.lock(); *incoming_conns = None; - self.pollee.add_events(IoEvents::HUP); - self.pollee.del_events(IoEvents::IN); + self.pollee.notify(IoEvents::HUP); drop(incoming_conns); @@ -208,7 +211,22 @@ impl Backlog { } fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) + } + + fn check_io_events(&self) -> IoEvents { + let incoming_conns = self.incoming_conns.lock(); + + if let Some(conns) = &*incoming_conns { + if !conns.is_empty() { + IoEvents::IN + } else { + IoEvents::empty() + } + } else { + IoEvents::HUP + } } } @@ -240,9 +258,9 @@ impl Backlog { } let (client_conn, server_conn) = init.into_connected(self.addr.clone()); - incoming_conns.push_back(server_conn); - self.pollee.add_events(IoEvents::IN); + incoming_conns.push_back(server_conn); + self.pollee.notify(IoEvents::IN); Ok(client_conn) } diff --git a/kernel/src/net/socket/vsock/common.rs b/kernel/src/net/socket/vsock/common.rs index 8617fa5a..eedf8303 100644 --- a/kernel/src/net/socket/vsock/common.rs +++ b/kernel/src/net/socket/vsock/common.rs @@ -16,7 +16,7 @@ use super::{ listen::Listen, }, }; -use crate::{events::IoEvents, prelude::*, return_errno_with_message, util::MultiRead}; +use crate::{prelude::*, return_errno_with_message, util::MultiRead}; /// Manage all active sockets pub struct VsockSpace { @@ -237,7 +237,6 @@ impl VsockSpace { let connected = Arc::new(Connected::new(peer.into(), listen.addr())); connected.update_info(&event); 
listen.push_incoming(connected).unwrap(); - listen.update_io_events(); } VsockEventType::ConnectionResponse => { let connecting_sockets = self.connecting_sockets.disable_irq().lock(); @@ -253,7 +252,7 @@ impl VsockSpace { connecting.local_addr() ); connecting.update_info(&event); - connecting.add_events(IoEvents::IN); + connecting.set_connected(); } VsockEventType::Disconnected { .. } => { let connected_sockets = self.connected_sockets.read_irq_disabled(); @@ -296,7 +295,6 @@ impl VsockSpace { if !connected.add_connection_buffer(body) { return Err(SocketError::BufferTooShort); } - connected.update_io_events(); } Ok(Some(event)) }) diff --git a/kernel/src/net/socket/vsock/stream/connected.rs b/kernel/src/net/socket/vsock/stream/connected.rs index e1d706d5..3dd53e57 100644 --- a/kernel/src/net/socket/vsock/stream/connected.rs +++ b/kernel/src/net/socket/vsock/stream/connected.rs @@ -27,7 +27,8 @@ impl Connected { Self { connection: SpinLock::new(Connection::new(peer_addr, local_addr.port)), id: ConnectionID::new(local_addr, peer_addr), - pollee: Pollee::new(IoEvents::empty()), + // FIXME: We should reuse `Pollee` from `Init`. + pollee: Pollee::new(), } } @@ -35,7 +36,8 @@ impl Connected { Self { connection: SpinLock::new(Connection::new_from_info(connecting.info())), id: connecting.id(), - pollee: Pollee::new(IoEvents::empty()), + // FIXME: We should reuse `Pollee` from `Init`. 
+ pollee: Pollee::new(), } } pub fn peer_addr(&self) -> VsockSocketAddr { @@ -116,7 +118,11 @@ impl Connected { pub fn add_connection_buffer(&self, bytes: &[u8]) -> bool { let mut connection = self.connection.disable_irq().lock(); - connection.add(bytes) + + let result = connection.add(bytes); + self.pollee.notify(IoEvents::IN); + + result } pub fn set_peer_requested_shutdown(&self) { @@ -127,16 +133,18 @@ impl Connected { } pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) } - pub fn update_io_events(&self) { + fn check_io_events(&self) -> IoEvents { let connection = self.connection.disable_irq().lock(); + // receive if !connection.buffer.is_empty() { - self.pollee.add_events(IoEvents::IN); + IoEvents::IN } else { - self.pollee.del_events(IoEvents::IN); + IoEvents::empty() } } } diff --git a/kernel/src/net/socket/vsock/stream/connecting.rs b/kernel/src/net/socket/vsock/stream/connecting.rs index 80e5fe51..2a23f3cc 100644 --- a/kernel/src/net/socket/vsock/stream/connecting.rs +++ b/kernel/src/net/socket/vsock/stream/connecting.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +use core::sync::atomic::{AtomicBool, Ordering}; + use aster_virtio::device::socket::connect::{ConnectionInfo, VsockEvent}; use super::connected::ConnectionID; @@ -13,6 +15,7 @@ use crate::{ pub struct Connecting { id: ConnectionID, info: SpinLock, + is_connected: AtomicBool, pollee: Pollee, } @@ -21,7 +24,8 @@ impl Connecting { Self { info: SpinLock::new(ConnectionInfo::new(peer_addr.into(), local_addr.port)), id: ConnectionID::new(local_addr, peer_addr), - pollee: Pollee::new(IoEvents::empty()), + is_connected: AtomicBool::new(false), + pollee: Pollee::new(), } } @@ -46,11 +50,21 @@ impl Connecting { } pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, 
|| self.check_io_events()) } - pub fn add_events(&self, events: IoEvents) { - self.pollee.add_events(events) + fn check_io_events(&self) -> IoEvents { + if self.is_connected.load(Ordering::Relaxed) { + IoEvents::IN + } else { + IoEvents::empty() + } + } + + pub fn set_connected(&self) { + self.is_connected.store(true, Ordering::Relaxed); + self.pollee.notify(IoEvents::IN); } } diff --git a/kernel/src/net/socket/vsock/stream/init.rs b/kernel/src/net/socket/vsock/stream/init.rs index 07c9a930..d1f8bfb8 100644 --- a/kernel/src/net/socket/vsock/stream/init.rs +++ b/kernel/src/net/socket/vsock/stream/init.rs @@ -7,19 +7,17 @@ use crate::{ VSOCK_GLOBAL, }, prelude::*, - process::signal::{PollHandle, Pollee}, + process::signal::PollHandle, }; pub struct Init { bound_addr: Mutex>, - pollee: Pollee, } impl Init { pub fn new() -> Self { Self { bound_addr: Mutex::new(None), - pollee: Pollee::new(IoEvents::empty()), } } @@ -61,8 +59,8 @@ impl Init { *self.bound_addr.lock() } - pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + pub fn poll(&self, _mask: IoEvents, _poller: Option<&mut PollHandle>) -> IoEvents { + IoEvents::empty() } } diff --git a/kernel/src/net/socket/vsock/stream/listen.rs b/kernel/src/net/socket/vsock/stream/listen.rs index 83eee925..cb19a4ba 100644 --- a/kernel/src/net/socket/vsock/stream/listen.rs +++ b/kernel/src/net/socket/vsock/stream/listen.rs @@ -18,7 +18,8 @@ impl Listen { pub fn new(addr: VsockSocketAddr, backlog: usize) -> Self { Self { addr, - pollee: Pollee::new(IoEvents::empty()), + // FIXME: We should reuse `Pollee` from `Init`. 
+ pollee: Pollee::new(), backlog, incoming_connection: SpinLock::new(VecDeque::with_capacity(backlog)), } @@ -33,8 +34,11 @@ impl Listen { if incoming_connections.len() >= self.backlog { return_errno_with_message!(Errno::ECONNREFUSED, "queue in listenging socket is full") } + // FIXME: check if the port is already used incoming_connections.push_back(connect); + self.pollee.notify(IoEvents::IN); + Ok(()) } @@ -52,15 +56,17 @@ impl Listen { } pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) } - pub fn update_io_events(&self) { + fn check_io_events(&self) -> IoEvents { let incoming_connection = self.incoming_connection.disable_irq().lock(); + if !incoming_connection.is_empty() { - self.pollee.add_events(IoEvents::IN); + IoEvents::IN } else { - self.pollee.del_events(IoEvents::IN); + IoEvents::empty() } } } diff --git a/kernel/src/net/socket/vsock/stream/socket.rs b/kernel/src/net/socket/vsock/stream/socket.rs index 29a828d1..46b3aecf 100644 --- a/kernel/src/net/socket/vsock/stream/socket.rs +++ b/kernel/src/net/socket/vsock/stream/socket.rs @@ -62,7 +62,6 @@ impl VsockStreamSocket { }; let connected = listen.try_accept()?; - listen.update_io_events(); let peer_addr = connected.peer_addr(); @@ -104,7 +103,6 @@ impl VsockStreamSocket { }; let read_size = connected.try_recv(writer)?; - connected.update_io_events(); let peer_addr = self.peer_addr()?; // If buffer is now empty and the peer requested shutdown, finish shutting down the diff --git a/kernel/src/process/signal/poll.rs b/kernel/src/process/signal/poll.rs index 0c990993..bbae69d0 100644 --- a/kernel/src/process/signal/poll.rs +++ b/kernel/src/process/signal/poll.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use core::{ - sync::atomic::{AtomicU32, AtomicUsize, Ordering}, + sync::atomic::{AtomicUsize, Ordering}, time::Duration, }; @@ -12,8 +12,16 @@ use crate::{ prelude::*, 
}; -/// A pollee maintains a set of active events, which can be polled with -/// pollers or be monitored with observers. +/// A pollee represents any I/O object (e.g., a file or socket) that can be polled. +/// +/// `Pollee` provides a standard mechanism to allow +/// 1. An I/O object to maintain its I/O readiness; and +/// 2. An interested party to poll the object's I/O readiness. +/// +/// To correctly use the pollee, you need to call [`Pollee::notify`] whenever a new event arrives. +/// +/// Then, [`Pollee::poll_with`] can allow you to register a [`Poller`] to wait for certain events, +/// or register a [`PollAdaptor`] to be notified when certain events occur. pub struct Pollee { inner: Arc, } @@ -21,30 +29,45 @@ pub struct Pollee { struct PolleeInner { // A subject which is monitored with pollers. subject: Subject, - // For efficient manipulation, we use AtomicU32 instead of RwLock. - events: AtomicU32, +} + +impl Default for Pollee { + fn default() -> Self { + Self::new() + } } impl Pollee { - /// Creates a new instance of pollee. - pub fn new(init_events: IoEvents) -> Self { + /// Creates a new pollee. + pub fn new() -> Self { let inner = PolleeInner { subject: Subject::new(), - events: AtomicU32::new(init_events.bits()), }; Self { inner: Arc::new(inner), } } - /// Returns the current events of the pollee filtered by the given event mask. + /// Returns the current events filtered by the given event mask. /// /// If a poller is provided, the poller will start monitoring the pollee and receive event /// notification when the pollee receives interesting events. /// /// This operation is _atomic_ in the sense that if there are interesting events, either the /// events are returned or the poller is notified. 
- pub fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { + /// + /// The above statement about atomicity is true even if `check` contains race conditions (and + /// in fact it always will, because even if it holds a lock, the lock will be released when + /// `check` returns). + pub fn poll_with( + &self, + mask: IoEvents, + poller: Option<&mut PollHandle>, + check: F, + ) -> IoEvents + where + F: FnOnce() -> IoEvents, + { let mask = mask | IoEvents::ALWAYS_POLL; // Register the provided poller. @@ -53,7 +76,7 @@ impl Pollee { } // Check events after the registration to prevent race conditions. - self.events() & mask + check() & mask } fn register_poller(&self, poller: &mut PollHandle, mask: IoEvents) { @@ -64,41 +87,18 @@ impl Pollee { poller.pollees.push(Arc::downgrade(&self.inner)); } - /// Add some events to the pollee's state. + /// Notifies pollers of some events. /// - /// This method wakes up all registered pollers that are interested in - /// the added events. - pub fn add_events(&self, events: IoEvents) { - self.inner.events.fetch_or(events.bits(), Ordering::Release); + /// This method wakes up all registered pollers that are interested in the events. + /// + /// The events can be spurious. This way, the caller can avoid expensive calculations and + /// simply add all possible ones. + pub fn notify(&self, events: IoEvents) { self.inner.subject.notify_observers(&events); } - - /// Remove some events from the pollee's state. - /// - /// This method will not wake up registered pollers even when - /// the pollee still has some interesting events to the pollers. - pub fn del_events(&self, events: IoEvents) { - self.inner - .events - .fetch_and(!events.bits(), Ordering::Release); - } - - /// Reset the pollee's state. - /// - /// Reset means removing all events on the pollee. 
- pub fn reset_events(&self) { - self.inner - .events - .fetch_and(!IoEvents::all().bits(), Ordering::Release); - } - - fn events(&self) -> IoEvents { - let event_bits = self.inner.events.load(Ordering::Acquire); - IoEvents::from_bits(event_bits).unwrap() - } } -/// An opaque handle that can be used as an argument of the [`Pollee::poll`] method. +/// An opaque handle that can be used as an argument of the [`Pollable::poll`] method. /// /// This type can represent an entity of [`PollAdaptor`] or [`Poller`], which is done via the /// [`PollAdaptor::as_handle_mut`] and [`Poller::as_handle_mut`] methods. @@ -146,11 +146,11 @@ impl Drop for PollHandle { } } -/// An adaptor to make an [`Observer`] usable for [`Pollee::poll`]. +/// An adaptor to make an [`Observer`] usable for [`Pollable::poll`]. /// -/// Normally, [`Pollee::poll`] accepts a [`Poller`] which is used to wait for events. By using this -/// adaptor, it is possible to use any [`Observer`] with [`Pollee::poll`]. The observer will be -/// notified whenever there are new events. +/// Normally, [`Pollable::poll`] accepts a [`Poller`] which is used to wait for events. By using +/// this adaptor, it is possible to use any [`Observer`] with [`Pollable::poll`]. The observer will +/// be notified whenever there are new events. pub struct PollAdaptor { // The event observer. observer: Arc, @@ -258,18 +258,18 @@ impl Observer for EventCounter { /// The `Pollable` trait allows for waiting for events and performing event-based operations. /// /// Implementors are required to provide a method, [`Pollable::poll`], which is usually implemented -/// by simply calling [`Pollee::poll`] on the internal [`Pollee`]. This trait provides another +/// by simply calling [`Pollee::poll_with`] on the internal [`Pollee`]. This trait provides another /// method, [`Pollable::wait_events`], to allow waiting for events and performing operations /// according to the events. 
/// /// This trait is added instead of creating a new method in [`Pollee`] because sometimes we do not /// have access to the internal [`Pollee`], but there is a method that provides the same semantics -/// as [`Pollee::poll`] and we need to perform event-based operations using that method. +/// as [`Pollee::poll_with`] and we need to perform event-based operations using that method. pub trait Pollable { /// Returns the interesting events now and monitors their occurrence in the future if the /// poller is provided. /// - /// This method has the same semantics as [`Pollee::poll`]. + /// This method has the same semantics as [`Pollee::poll_with`]. fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents; /// Waits for events and performs event-based operations. diff --git a/kernel/src/syscall/eventfd.rs b/kernel/src/syscall/eventfd.rs index e79df51d..e282e51d 100644 --- a/kernel/src/syscall/eventfd.rs +++ b/kernel/src/syscall/eventfd.rs @@ -85,7 +85,7 @@ impl EventFile { fn new(init_val: u64, flags: Flags) -> Self { let counter = Mutex::new(init_val); - let pollee = Pollee::new(IoEvents::OUT); + let pollee = Pollee::new(); let write_wait_queue = WaitQueue::new(); Self { counter, @@ -99,35 +99,24 @@ impl EventFile { self.flags.lock().contains(Flags::EFD_NONBLOCK) } - fn update_io_state(&self, counter: &MutexGuard) { - let is_readable = **counter != 0; + fn check_io_events(&self) -> IoEvents { + let counter = self.counter.lock(); + + let mut events = IoEvents::empty(); + + let is_readable = *counter != 0; + if is_readable { + events |= IoEvents::IN; + } // if it is possible to write a value of at least "1" // without blocking, the file is writable - let is_writable = **counter < Self::MAX_COUNTER_VALUE; - + let is_writable = *counter < Self::MAX_COUNTER_VALUE; if is_writable { - if is_readable { - self.pollee.add_events(IoEvents::IN | IoEvents::OUT); - } else { - self.pollee.add_events(IoEvents::OUT); - self.pollee.del_events(IoEvents::IN); - } - - 
self.write_wait_queue.wake_all(); - - return; + events |= IoEvents::OUT; } - if is_readable { - self.pollee.add_events(IoEvents::IN); - self.pollee.del_events(IoEvents::OUT); - return; - } - - self.pollee.del_events(IoEvents::IN | IoEvents::OUT); - - // TODO: deal with overflow logic + events } fn try_read(&self, writer: &mut VmWriter) -> Result<()> { @@ -147,7 +136,8 @@ impl EventFile { *counter = 0; } - self.update_io_state(&counter); + self.pollee.notify(IoEvents::OUT); + self.write_wait_queue.wake_all(); Ok(()) } @@ -165,7 +155,7 @@ impl EventFile { if new_value <= Self::MAX_COUNTER_VALUE { *counter = new_value; - self.update_io_state(&counter); + self.pollee.notify(IoEvents::IN); return Ok(()); } @@ -175,7 +165,8 @@ impl EventFile { impl Pollable for EventFile { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee.poll(mask, poller) + self.pollee + .poll_with(mask, poller, || self.check_io_events()) } }