Reuse the wait_events method in pipes

This commit is contained in:
Ruihan Li
2024-07-01 00:19:58 +08:00
committed by Tate, Hongliang Tian
parent 1c0d865373
commit 8e72451448

View File

@ -10,7 +10,7 @@ use super::StatusFlags;
use crate::{
events::{IoEvents, Observer},
prelude::*,
process::signal::{Pollee, Poller},
process::signal::{Pollable, Pollee, Poller},
};
/// A unidirectional communication channel, intended to implement IPC, e.g., pipe,
@ -137,29 +137,18 @@ impl<T> Producer<T> {
impl_common_methods_for_channel!();
}
impl<T> Pollable for Producer<T> {
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.poll(mask, poller)
}
}
impl<T: Copy> Producer<T> {
pub fn write(&self, buf: &[T]) -> Result<usize> {
let is_nonblocking = self.is_nonblocking();
// Fast path
let res = self.try_write(buf);
if should_io_return(&res, is_nonblocking) {
return res;
}
// Slow path
let mask = IoEvents::OUT;
let poller = Poller::new();
loop {
let res = self.try_write(buf);
if should_io_return(&res, is_nonblocking) {
return res;
}
let events = self.poll(mask, Some(&poller));
if events.is_empty() {
// FIXME: should channel deal with timeout?
poller.wait()?;
}
if self.is_nonblocking() {
self.try_write(buf)
} else {
self.wait_events(IoEvents::OUT, || self.try_write(buf))
}
}
@ -190,31 +179,25 @@ impl<T> Producer<T> {
/// On failure, this method returns `Err` containing
/// the item fails to push.
pub fn push(&self, item: T) -> core::result::Result<(), (Error, T)> {
let is_nonblocking = self.is_nonblocking();
// Fast path
let mut res = self.try_push(item);
if should_io_return(&res, is_nonblocking) {
return res;
if self.is_nonblocking() {
return self.try_push(item);
}
// Slow path
let mask = IoEvents::OUT;
let poller = Poller::new();
loop {
let (_, item) = res.unwrap_err();
let mut stored_item = Some(item);
res = self.try_push(item);
if should_io_return(&res, is_nonblocking) {
return res;
}
let events = self.poll(mask, Some(&poller));
if events.is_empty() {
// FIXME: should channel deal with timeout?
if let Err(err) = poller.wait() {
return Err((err, res.unwrap_err().1));
let result = self.wait_events(IoEvents::OUT, || {
match self.try_push(stored_item.take().unwrap()) {
Ok(()) => Ok(()),
Err((err, item)) => {
stored_item = Some(item);
Err(err)
}
}
});
match result {
Ok(()) => Ok(()),
Err(err) => Err((err, stored_item.unwrap())),
}
}
@ -277,29 +260,18 @@ impl<T> Consumer<T> {
impl_common_methods_for_channel!();
}
impl<T> Pollable for Consumer<T> {
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.poll(mask, poller)
}
}
impl<T: Copy> Consumer<T> {
pub fn read(&self, buf: &mut [T]) -> Result<usize> {
let is_nonblocking = self.is_nonblocking();
// Fast path
let res = self.try_read(buf);
if should_io_return(&res, is_nonblocking) {
return res;
}
// Slow path
let mask = IoEvents::IN;
let poller = Poller::new();
loop {
let res = self.try_read(buf);
if should_io_return(&res, is_nonblocking) {
return res;
}
let events = self.poll(mask, Some(&poller));
if events.is_empty() {
// FIXME: should channel have timeout?
poller.wait()?;
}
if self.is_nonblocking() {
self.try_read(buf)
} else {
self.wait_events(IoEvents::IN, || self.try_read(buf))
}
}
@ -330,27 +302,10 @@ impl<T: Copy> Consumer<T> {
impl<T> Consumer<T> {
/// Pops an item from the consumer
pub fn pop(&self) -> Result<T> {
let is_nonblocking = self.is_nonblocking();
// Fast path
let res = self.try_pop();
if should_io_return(&res, is_nonblocking) {
return res;
}
// Slow path
let mask = IoEvents::IN;
let poller = Poller::new();
loop {
let res = self.try_pop();
if should_io_return(&res, is_nonblocking) {
return res;
}
let events = self.poll(mask, Some(&poller));
if events.is_empty() {
// FIXME: should channel have timeout?
poller.wait()?;
}
if self.is_nonblocking() {
self.try_pop()
} else {
self.wait_events(IoEvents::IN, || self.try_pop())
}
}
@ -510,26 +465,6 @@ fn check_status_flags(flags: StatusFlags) -> Result<()> {
Ok(())
}
fn should_io_return<T, E: AsRef<Error>>(
res: &core::result::Result<T, E>,
is_nonblocking: bool,
) -> bool {
if is_nonblocking {
return true;
}
match res {
Ok(_) => true,
Err(e) if e.as_ref().error() == Errno::EAGAIN => false,
Err(_) => true,
}
}
impl<T> AsRef<Error> for (Error, T) {
fn as_ref(&self) -> &Error {
&self.0
}
}
#[cfg(ktest)]
mod test {
use alloc::sync::Arc;