diff --git a/kernel/aster-nix/src/error.rs b/kernel/aster-nix/src/error.rs index d351233e6..ec03c5fd1 100644 --- a/kernel/aster-nix/src/error.rs +++ b/kernel/aster-nix/src/error.rs @@ -180,6 +180,12 @@ impl From for Error { } } +impl AsRef for Error { + fn as_ref(&self) -> &Error { + self + } +} + impl From for Error { fn from(frame_error: aster_frame::Error) -> Self { match frame_error { diff --git a/kernel/aster-nix/src/fs/utils/channel.rs b/kernel/aster-nix/src/fs/utils/channel.rs index 2cfe00599..fe724c961 100644 --- a/kernel/aster-nix/src/fs/utils/channel.rs +++ b/kernel/aster-nix/src/fs/utils/channel.rs @@ -181,6 +181,56 @@ impl Producer { } } +impl Producer { + /// Pushes an item into the producer. + /// + /// On failure, this method returns `Err` containing + /// the item fails to push. + pub fn push(&self, item: T) -> core::result::Result<(), (Error, T)> { + let is_nonblocking = self.is_nonblocking(); + + // Fast path + let mut res = self.try_push(item); + if should_io_return(&res, is_nonblocking) { + return res; + } + + // Slow path + let mask = IoEvents::OUT; + let poller = Poller::new(); + loop { + let (_, item) = res.unwrap_err(); + + res = self.try_push(item); + if should_io_return(&res, is_nonblocking) { + return res; + } + let events = self.poll(mask, Some(&poller)); + if events.is_empty() { + // FIXME: should channel deal with timeout? + if let Err(err) = poller.wait() { + return Err((err, res.unwrap_err().1)); + } + } + } + } + + fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> { + if self.is_shutdown() || self.is_peer_shutdown() { + let err = Error::with_message(Errno::EPIPE, "the pipe is shutdown"); + return Err((err, item)); + } + + self.0.push(item).map_err(|item| { + let err = Error::with_message(Errno::EAGAIN, "try push again"); + (err, item) + })?; + + self.update_pollee(); + Ok(()) + } +} + impl Drop for Producer { fn drop(&mut self) { self.shutdown(); @@ -273,6 +323,54 @@ impl Consumer { } } +impl Consumer { + /// Pops an item from the consumer + pub fn pop(&self) -> Result { + let is_nonblocking = self.is_nonblocking(); + + // Fast path + let res = self.try_pop(); + if should_io_return(&res, is_nonblocking) { + return res; + } + + // Slow path + let mask = IoEvents::IN; + let poller = Poller::new(); + loop { + let res = self.try_pop(); + if should_io_return(&res, is_nonblocking) { + return res; + } + let events = self.poll(mask, Some(&poller)); + if events.is_empty() { + // FIXME: should channel have timeout? + poller.wait()?; + } + } + } + + fn try_pop(&self) -> Result { + if self.is_shutdown() { + return_errno_with_message!(Errno::EPIPE, "this end is shut down"); + } + + let item = self.0.pop(); + + self.update_pollee(); + + if let Some(item) = item { + return Ok(item); + } + + if self.is_peer_shutdown() { + return_errno_with_message!(Errno::EPIPE, "remote end is shut down"); + } + + return_errno_with_message!(Errno::EAGAIN, "try pop again") + } +} + impl Drop for Consumer { fn drop(&mut self) { self.shutdown(); @@ -310,6 +408,24 @@ impl EndPoint { } } +impl EndPoint { + /// Pushes an item into the endpoint. + /// If the `push` method failes, this method will return + /// `Err` containing the item that hasn't been pushed + #[require(R > Write)] + pub fn push(&self, item: T) -> core::result::Result<(), T> { + let mut rb = self.common.producer.rb(); + rb.push(item) + } + + /// Pops an item from the endpoint. + #[require(R > Read)] + pub fn pop(&self) -> Option { + let mut rb = self.common.consumer.rb(); + rb.pop() + } +} + struct Common { producer: EndPointInner>, consumer: EndPointInner>, @@ -399,13 +515,50 @@ fn check_status_flags(flags: StatusFlags) -> Result<()> { Ok(()) } -fn should_io_return(res: &Result, is_nonblocking: bool) -> bool { +fn should_io_return>( + res: &core::result::Result, + is_nonblocking: bool, +) -> bool { if is_nonblocking { return true; } match res { Ok(_) => true, - Err(e) if e.error() == Errno::EAGAIN => false, + Err(e) if e.as_ref().error() == Errno::EAGAIN => false, Err(_) => true, } } + +impl AsRef for (Error, T) { + fn as_ref(&self) -> &Error { + &self.0 + } +} + +#[cfg(ktest)] +mod test { + use alloc::sync::Arc; + + use crate::fs::utils::Channel; + + #[ktest] + fn test_non_copy() { + #[derive(Clone, Debug, PartialEq, Eq)] + struct NonCopy(Arc); + + let channel = Channel::with_capacity(16).unwrap(); + let (producer, consumer) = channel.split(); + + let data = NonCopy(Arc::new(99)); + let expected_data = data.clone(); + + for _ in 0..3 { + producer.push(data.clone()).unwrap(); + } + + for _ in 0..3 { + let data = consumer.pop().unwrap(); + assert_eq!(data, expected_data); + } + } +}