diff --git a/kernel/aster-nix/src/fs/utils/channel.rs b/kernel/aster-nix/src/fs/utils/channel.rs index a35dc23cb..92b429bf1 100644 --- a/kernel/aster-nix/src/fs/utils/channel.rs +++ b/kernel/aster-nix/src/fs/utils/channel.rs @@ -10,7 +10,7 @@ use super::StatusFlags; use crate::{ events::{IoEvents, Observer}, prelude::*, - process::signal::{Pollee, Poller}, + process::signal::{Pollable, Pollee, Poller}, }; /// A unidirectional communication channel, intended to implement IPC, e.g., pipe, @@ -137,29 +137,18 @@ impl Producer { impl_common_methods_for_channel!(); } +impl Pollable for Producer { + fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.poll(mask, poller) + } +} + impl Producer { pub fn write(&self, buf: &[T]) -> Result { - let is_nonblocking = self.is_nonblocking(); - - // Fast path - let res = self.try_write(buf); - if should_io_return(&res, is_nonblocking) { - return res; - } - - // Slow path - let mask = IoEvents::OUT; - let poller = Poller::new(); - loop { - let res = self.try_write(buf); - if should_io_return(&res, is_nonblocking) { - return res; - } - let events = self.poll(mask, Some(&poller)); - if events.is_empty() { - // FIXME: should channel deal with timeout? - poller.wait()?; - } + if self.is_nonblocking() { + self.try_write(buf) + } else { + self.wait_events(IoEvents::OUT, || self.try_write(buf)) } } @@ -190,31 +179,25 @@ impl Producer { /// On failure, this method returns `Err` containing /// the item fails to push. pub fn push(&self, item: T) -> core::result::Result<(), (Error, T)> { - let is_nonblocking = self.is_nonblocking(); - - // Fast path - let mut res = self.try_push(item); - if should_io_return(&res, is_nonblocking) { - return res; + if self.is_nonblocking() { + return self.try_push(item); } - // Slow path - let mask = IoEvents::OUT; - let poller = Poller::new(); - loop { - let (_, item) = res.unwrap_err(); + let mut stored_item = Some(item); - res = self.try_push(item); - if should_io_return(&res, is_nonblocking) { - return res; - } - let events = self.poll(mask, Some(&poller)); - if events.is_empty() { - // FIXME: should channel deal with timeout? - if let Err(err) = poller.wait() { - return Err((err, res.unwrap_err().1)); + let result = self.wait_events(IoEvents::OUT, || { + match self.try_push(stored_item.take().unwrap()) { + Ok(()) => Ok(()), + Err((err, item)) => { + stored_item = Some(item); + Err(err) } } + }); + + match result { + Ok(()) => Ok(()), + Err(err) => Err((err, stored_item.unwrap())), } } @@ -277,29 +260,18 @@ impl Consumer { impl_common_methods_for_channel!(); } +impl Pollable for Consumer { + fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.poll(mask, poller) + } +} + impl Consumer { pub fn read(&self, buf: &mut [T]) -> Result { - let is_nonblocking = self.is_nonblocking(); - - // Fast path - let res = self.try_read(buf); - if should_io_return(&res, is_nonblocking) { - return res; - } - - // Slow path - let mask = IoEvents::IN; - let poller = Poller::new(); - loop { - let res = self.try_read(buf); - if should_io_return(&res, is_nonblocking) { - return res; - } - let events = self.poll(mask, Some(&poller)); - if events.is_empty() { - // FIXME: should channel have timeout? - poller.wait()?; - } + if self.is_nonblocking() { + self.try_read(buf) + } else { + self.wait_events(IoEvents::IN, || self.try_read(buf)) } } @@ -330,27 +302,10 @@ impl Consumer { impl Consumer { /// Pops an item from the consumer pub fn pop(&self) -> Result { - let is_nonblocking = self.is_nonblocking(); - - // Fast path - let res = self.try_pop(); - if should_io_return(&res, is_nonblocking) { - return res; - } - - // Slow path - let mask = IoEvents::IN; - let poller = Poller::new(); - loop { - let res = self.try_pop(); - if should_io_return(&res, is_nonblocking) { - return res; - } - let events = self.poll(mask, Some(&poller)); - if events.is_empty() { - // FIXME: should channel have timeout? - poller.wait()?; - } + if self.is_nonblocking() { + self.try_pop() + } else { + self.wait_events(IoEvents::IN, || self.try_pop()) } } @@ -510,26 +465,6 @@ fn check_status_flags(flags: StatusFlags) -> Result<()> { Ok(()) } -fn should_io_return>( - res: &core::result::Result, - is_nonblocking: bool, -) -> bool { - if is_nonblocking { - return true; - } - match res { - Ok(_) => true, - Err(e) if e.as_ref().error() == Errno::EAGAIN => false, - Err(_) => true, - } -} - -impl AsRef for (Error, T) { - fn as_ref(&self) -> &Error { - &self.0 - } -} - #[cfg(ktest)] mod test { use alloc::sync::Arc;