From ec2c6ab7a3f2e611181d6763ea6f9f61ea21b25a Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 17 Nov 2024 17:43:34 +0800 Subject: [PATCH] Make `Pollee` semi-stateless --- kernel/src/device/pty/pty.rs | 5 +- kernel/src/device/tty/line_discipline.rs | 2 + kernel/src/fs/epoll/entry.rs | 5 + kernel/src/fs/utils/channel.rs | 10 +- kernel/src/net/socket/ip/datagram/mod.rs | 8 +- kernel/src/net/socket/ip/stream/mod.rs | 7 + kernel/src/net/socket/unix/stream/listener.rs | 3 + .../src/net/socket/vsock/stream/connected.rs | 1 + kernel/src/net/socket/vsock/stream/listen.rs | 1 + kernel/src/process/signal/poll.rs | 251 +++++++++++++++++- ostd/src/mm/mod.rs | 3 +- 11 files changed, 282 insertions(+), 14 deletions(-) diff --git a/kernel/src/device/pty/pty.rs b/kernel/src/device/pty/pty.rs index 51d782c5f..d500b647f 100644 --- a/kernel/src/device/pty/pty.rs +++ b/kernel/src/device/pty/pty.rs @@ -102,7 +102,10 @@ impl PtyMaster { return_errno_with_message!(Errno::EAGAIN, "the buffer is empty"); } - input.read_fallible(writer) + let read_len = input.read_fallible(writer)?; + self.pollee.invalidate(); + + Ok(read_len) } fn check_io_events(&self) -> IoEvents { diff --git a/kernel/src/device/tty/line_discipline.rs b/kernel/src/device/tty/line_discipline.rs index cf6d3b95d..8f213a6c1 100644 --- a/kernel/src/device/tty/line_discipline.rs +++ b/kernel/src/device/tty/line_discipline.rs @@ -265,6 +265,7 @@ impl LineDiscipline { unreachable!() } }; + self.pollee.invalidate(); Ok(read_len) } @@ -344,6 +345,7 @@ impl LineDiscipline { pub fn drain_input(&self) { self.current_line.lock().drain(); self.read_buffer.lock().clear(); + self.pollee.invalidate(); } pub fn buffer_len(&self) -> usize { diff --git a/kernel/src/fs/epoll/entry.rs b/kernel/src/fs/epoll/entry.rs index 96f329c65..cacdd02f7 100644 --- a/kernel/src/fs/epoll/entry.rs +++ b/kernel/src/fs/epoll/entry.rs @@ -367,6 +367,11 @@ impl Iterator for ReadySetPopIter<'_> { // must exist, so we can just unwrap it. let weak_entry = entries.pop_front().unwrap(); + // Clear the epoll file's events if there are no ready entries. + if entries.len() == 0 { + self.ready_set.pollee.invalidate(); + } + let Some(entry) = Weak::upgrade(&weak_entry) else { // The entry has been deleted. continue; diff --git a/kernel/src/fs/utils/channel.rs b/kernel/src/fs/utils/channel.rs index d9d49329b..de5ec624a 100644 --- a/kernel/src/fs/utils/channel.rs +++ b/kernel/src/fs/utils/channel.rs @@ -225,6 +225,7 @@ impl Consumer { let read_len = self.0.read(writer)?; self.peer_end().pollee.notify(IoEvents::OUT); + self.this_end().pollee.invalidate(); if read_len > 0 { Ok(read_len) @@ -248,6 +249,7 @@ impl Consumer { let item = self.0.pop(); self.peer_end().pollee.notify(IoEvents::OUT); + self.this_end().pollee.invalidate(); if let Some(item) = item { Ok(Some(item)) @@ -331,12 +333,16 @@ impl Common { let (rb_producer, rb_consumer) = rb.split(); let producer = { - let pollee = producer_pollee.unwrap_or_default(); + let pollee = producer_pollee + .inspect(|pollee| pollee.invalidate()) + .unwrap_or_default(); FifoInner::new(rb_producer, pollee) }; let consumer = { - let pollee = consumer_pollee.unwrap_or_default(); + let pollee = consumer_pollee + .inspect(|pollee| pollee.invalidate()) + .unwrap_or_default(); FifoInner::new(rb_consumer, pollee) }; diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 7618fcdc8..8f136eddd 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -157,9 +157,12 @@ impl DatagramSocket { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); }; - bound_datagram + let recv_bytes = bound_datagram .try_recv(writer, flags) - .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into())) + .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; + self.pollee.invalidate(); + + Ok(recv_bytes) } fn recv( @@ -190,6 +193,7 @@ impl DatagramSocket { let iface_to_poll = bound_datagram.iface().clone(); drop(inner); + self.pollee.invalidate(); iface_to_poll.poll(); Ok(sent_bytes) diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 3a3177673..5eb89d9fd 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -248,6 +248,7 @@ impl StreamSocket { }); drop(state); + self.pollee.invalidate(); if let Some(iface) = iface_to_poll { iface.poll(); } @@ -286,6 +287,7 @@ impl StreamSocket { let iface_to_poll = listen_stream.iface().clone(); drop(state); + self.pollee.invalidate(); iface_to_poll.poll(); accepted @@ -313,6 +315,7 @@ impl StreamSocket { let remote_endpoint = connected_stream.remote_endpoint(); drop(state); + self.pollee.invalidate(); if let Some(iface) = iface_to_poll { iface.poll(); } @@ -352,6 +355,7 @@ impl StreamSocket { let iface_to_poll = need_poll.then(|| connected_stream.iface().clone()); drop(state); + self.pollee.invalidate(); if let Some(iface) = iface_to_poll { iface.poll(); } @@ -498,6 +502,7 @@ impl Socket for StreamSocket { } }; + self.pollee.invalidate(); (State::Listen(listen_stream), Ok(())) }) } @@ -523,6 +528,8 @@ impl Socket for StreamSocket { }; drop(state); + // No need to call `Pollee::invalidate` because `ConnectedStream::shutdown` will call + // `Pollee::notify`. iface_to_poll.poll(); result diff --git a/kernel/src/net/socket/unix/stream/listener.rs b/kernel/src/net/socket/unix/stream/listener.rs index 4a85565f4..a10602d17 100644 --- a/kernel/src/net/socket/unix/stream/listener.rs +++ b/kernel/src/net/socket/unix/stream/listener.rs @@ -38,6 +38,7 @@ impl Listener { let backlog = BACKLOG_TABLE .add_backlog(addr, reader_pollee, backlog, is_read_shutdown) .unwrap(); + writer_pollee.invalidate(); Self { backlog, @@ -130,6 +131,8 @@ impl BacklogTable { return None; } + // Note that the cached events can be correctly inherited from `Init`, so there is no need + // to explicitly call `Pollee::invalidate`. let new_backlog = Arc::new(Backlog::new(addr, pollee, backlog, is_shutdown)); backlog_sockets.insert(addr_key, new_backlog.clone()); diff --git a/kernel/src/net/socket/vsock/stream/connected.rs b/kernel/src/net/socket/vsock/stream/connected.rs index 3dd53e578..a8ca3cf89 100644 --- a/kernel/src/net/socket/vsock/stream/connected.rs +++ b/kernel/src/net/socket/vsock/stream/connected.rs @@ -56,6 +56,7 @@ impl Connected { let mut connection = self.connection.disable_irq().lock(); let bytes_read = connection.buffer.read_fallible(writer)?; connection.info.done_forwarding(bytes_read); + self.pollee.invalidate(); match bytes_read { 0 => { diff --git a/kernel/src/net/socket/vsock/stream/listen.rs b/kernel/src/net/socket/vsock/stream/listen.rs index cb19a4bad..1eeb254dd 100644 --- a/kernel/src/net/socket/vsock/stream/listen.rs +++ b/kernel/src/net/socket/vsock/stream/listen.rs @@ -51,6 +51,7 @@ impl Listen { .ok_or_else(|| { Error::with_message(Errno::EAGAIN, "no pending connection is available") })?; + self.pollee.invalidate(); Ok(connection) } diff --git a/kernel/src/process/signal/poll.rs b/kernel/src/process/signal/poll.rs index bbae69d04..1f97a6f93 100644 --- a/kernel/src/process/signal/poll.rs +++ b/kernel/src/process/signal/poll.rs @@ -1,11 +1,14 @@ // SPDX-License-Identifier: MPL-2.0 use core::{ - sync::atomic::{AtomicUsize, Ordering}, + sync::atomic::{AtomicIsize, AtomicUsize, Ordering}, time::Duration, }; -use ostd::sync::{Waiter, Waker}; +use ostd::{ + sync::{Waiter, Waker}, + task::Task, +}; use crate::{ events::{IoEvents, Observer, Subject}, @@ -18,7 +21,10 @@ use crate::{ /// 1. An I/O object to maintain its I/O readiness; and /// 2. An interested part to poll the object's I/O readiness. /// -/// To correctly use the pollee, you need to call [`Pollee::notify`] whenever a new event arrives. +/// To use the pollee correctly, you must follow the rules below carefully: +/// * [`Pollee::notify`] needs to be called whenever a new event arrives. +/// * [`Pollee::invalidate`] needs to be called whenever an old event disappears and no new event +/// arrives. /// /// Then, [`Pollee::poll_with`] can allow you to register a [`Poller`] to wait for certain events, /// or register a [`PollAdaptor`] to be notified when certain events occur. @@ -26,9 +32,28 @@ pub struct Pollee { inner: Arc, } +const INV_STATE: isize = -1; + struct PolleeInner { - // A subject which is monitored with pollers. + /// A subject which is monitored with pollers. subject: Subject, + /// A state that describes how events are cached in the pollee. + /// + /// The meaning of this field depends on its value: + /// + /// * A non-negative value represents cached events. The events are guaranteed to be + /// up-to-date, i.e., no one has called [`Pollee::notify`] or [`Pollee::invalidate`] since we + /// started checking the events. + /// + /// * A value of [`INV_STATE`] means no cached events. We may have previously cached some + /// events, but they are no longer valid due to calls of [`Pollee::notify`] or + /// [`Pollee::invalidate`]. + /// + /// * A negative value other than [`INV_STATE`] represents a [`Task`] that is currently + /// checking events. When the task has finished checking and the state is neither invalidated + /// nor overwritten by another task checking events, the state can be used to cache the + /// checked events. + state: AtomicIsize, } impl Default for Pollee { @@ -42,6 +67,7 @@ impl Pollee { pub fn new() -> Self { let inner = PolleeInner { subject: Subject::new(), + state: AtomicIsize::new(INV_STATE), }; Self { inner: Arc::new(inner), @@ -75,8 +101,52 @@ impl Pollee { self.register_poller(poller, mask); } + // Return the cached events, if any. + let events = self.inner.state.load(Ordering::Acquire); + if events >= 0 { + return IoEvents::from_bits_truncate(events as _) & mask; + } + + // If we know some task is checking the events, let it finish. + if events != INV_STATE { + return check() & mask; + } + + // We will store `task_ptr` in `state` to indicate that we're checking the events. But we + // need to make sure it's a negative value. + const { + use ostd::mm::KERNEL_VADDR_RANGE; + assert!((KERNEL_VADDR_RANGE.start as isize) < 0); + } + let task_ptr = Task::current().unwrap().as_ref() as *const _ as isize; + + // Store `task_ptr` in `state` to indicate we're checking the events. + // + // Note that: + // * If there are race conditions, `state` may contain something other than `INV_STATE` (as + // checked above), but that's okay. + // * Given the first point, we only need to do a store here. However, we need the `Acquire` + // order, which forces us to do a `swap` operation. We ignore the returned value to allow + // the compiler to produce better assembly code. + let _ = self.inner.state.swap(task_ptr, Ordering::Acquire); + // Check events after the registration to prevent race conditions. - check() & mask + let new_events = check(); + + // If this `compare_exchange_weak` succeeds, we can guarantee that we are the only task + // trying to cache the checked events, and that the events are not invalidated in the + // middle, so we can cache them with confidence. + // + // Otherwise, we cache nothing, but returning the obsolete events is still okay. + let _ = self.inner.state.compare_exchange_weak( + task_ptr, + new_events.bits() as _, + Ordering::Release, + Ordering::Relaxed, + ); + + // Return the events filtered by the mask. + new_events & mask } fn register_poller(&self, poller: &mut PollHandle, mask: IoEvents) { @@ -89,13 +159,27 @@ impl Pollee { /// Notifies pollers of some events. /// - /// This method wakes up all registered pollers that are interested in the events. + /// This method invalidates the (internal) cached events and wakes up all registered pollers + /// that are interested in the events. /// - /// The events can be spurious. This way, the caller can avoid expensive calculations and - /// simply add all possible ones. + /// This method should be called whenever new events arrive. The events can be spurious. This + /// way, the caller can avoid expensive calculations and simply add all possible ones. pub fn notify(&self, events: IoEvents) { + self.invalidate(); + self.inner.subject.notify_observers(&events); } + + /// Invalidates the (internal) cached events. + /// + /// This method should be called whenever old events disappear but no new events arrive. The + /// invalidation can be spurious, so the caller can avoid complex calculations and simply + /// invalidate even if no events disappear. + pub fn invalidate(&self) { + // The memory order must be `Release`, so that the reader is guaranteed to see the changes + // that trigger the invalidation. + self.inner.state.store(INV_STATE, Ordering::Release); + } } /// An opaque handle that can be used as an argument of the [`Pollable::poll`] method. @@ -325,3 +409,154 @@ pub trait Pollable { } } } + +#[cfg(ktest)] +mod test { + use ostd::prelude::*; + + use super::*; + + #[ktest] + fn test_notify_before() { + let pollee = Pollee::new(); + + pollee.notify(IoEvents::OUT); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::IN), + // This is allowed, as we invoke the checking closure. + IoEvents::IN + ); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as the cached state is still valid. + IoEvents::IN + ); + } + + #[ktest] + fn test_notify_middle() { + let pollee = Pollee::new(); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || { + pollee.notify(IoEvents::OUT); + IoEvents::IN + }), + // This is allowed, as we invoke the checking closure. + IoEvents::IN + ); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as we invoke the checking closure. + // + // Reusing the cached state is NOT allowed as we've been notified above. + IoEvents::OUT + ); + } + + #[ktest] + fn test_notify_after() { + let pollee = Pollee::new(); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::IN), + // This is allowed, as we invoke the checking closure. + IoEvents::IN + ); + + pollee.notify(IoEvents::OUT); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as we invoke the checking closure. + // + // Reusing the cached state is NOT allowed as we've been notified above. + IoEvents::OUT + ); + } + + #[ktest] + fn test_nested_notify_before() { + let pollee = Pollee::new(); + + pollee.notify(IoEvents::OUT); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || { + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as we invoke the checking closure. + IoEvents::OUT + ); + IoEvents::IN + }), + // This is allowed, as we invoke the checking closure. + IoEvents::IN + ); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as the cached state is still valid. + IoEvents::IN + ); + } + + #[ktest] + fn test_nested_notify_between() { + let pollee = Pollee::new(); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || { + pollee.notify(IoEvents::OUT); + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as we invoke the checking closure. + IoEvents::OUT + ); + IoEvents::IN + }), + // This is allowed, as we invoke the checking closure. + IoEvents::IN + ); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as we invoke the checking closure. + // + // Reusing the cached state is NOT allowed as we've been notified above. + IoEvents::OUT + ); + } + + #[ktest] + fn test_nested_notify_inside() { + let pollee = Pollee::new(); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || { + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || { + pollee.notify(IoEvents::OUT); + IoEvents::OUT + }), + // This is allowed, as we invoke the checking closure. + IoEvents::OUT + ); + IoEvents::IN + }), + // This is allowed, as we invoke the checking closure. + IoEvents::IN + ); + + assert_eq!( + pollee.poll_with(IoEvents::all(), None, || IoEvents::OUT), + // This is allowed, as we invoke the checking closure. + // + // Reusing the cached state is NOT allowed as we've been notified above. + IoEvents::OUT, + ); + } +} diff --git a/ostd/src/mm/mod.rs b/ostd/src/mm/mod.rs index bfe2cafec..ea493a195 100644 --- a/ostd/src/mm/mod.rs +++ b/ostd/src/mm/mod.rs @@ -100,9 +100,10 @@ pub(crate) const fn nr_base_per_page(level: PagingLevel) - pub const MAX_USERSPACE_VADDR: Vaddr = 0x0000_8000_0000_0000 - PAGE_SIZE; /// The kernel address space. +/// /// There are the high canonical addresses defined in most 48-bit width /// architectures. -pub(crate) const KERNEL_VADDR_RANGE: Range = 0xffff_8000_0000_0000..0xffff_ffff_ffff_0000; +pub const KERNEL_VADDR_RANGE: Range = 0xffff_8000_0000_0000..0xffff_ffff_ffff_0000; /// Gets physical address trait pub trait HasPaddr {