Support Channel sending data which does not implement Copy

This commit is contained in:
Jianfeng Jiang
2024-04-18 08:27:57 +00:00
committed by Tate, Hongliang Tian
parent 83c2aba0b0
commit 5189f889a3
2 changed files with 161 additions and 2 deletions

View File

@ -180,6 +180,12 @@ impl From<Errno> for Error {
}
}
impl AsRef<Error> for Error {
fn as_ref(&self) -> &Error {
self
}
}
impl From<aster_frame::Error> for Error {
fn from(frame_error: aster_frame::Error) -> Self {
match frame_error {

View File

@ -181,6 +181,56 @@ impl<T: Copy> Producer<T> {
}
}
impl<T> Producer<T> {
/// Pushes an item into the producer.
///
/// On failure, this method returns `Err` containing
/// the item fails to push.
pub fn push(&self, item: T) -> core::result::Result<(), (Error, T)> {
let is_nonblocking = self.is_nonblocking();
// Fast path
let mut res = self.try_push(item);
if should_io_return(&res, is_nonblocking) {
return res;
}
// Slow path
let mask = IoEvents::OUT;
let poller = Poller::new();
loop {
let (_, item) = res.unwrap_err();
res = self.try_push(item);
if should_io_return(&res, is_nonblocking) {
return res;
}
let events = self.poll(mask, Some(&poller));
if events.is_empty() {
// FIXME: should channel deal with timeout?
if let Err(err) = poller.wait() {
return Err((err, res.unwrap_err().1));
}
}
}
}
fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> {
if self.is_shutdown() || self.is_peer_shutdown() {
let err = Error::with_message(Errno::EPIPE, "the pipe is shutdown");
return Err((err, item));
}
self.0.push(item).map_err(|item| {
let err = Error::with_message(Errno::EAGAIN, "try push again");
(err, item)
})?;
self.update_pollee();
Ok(())
}
}
impl<T> Drop for Producer<T> {
fn drop(&mut self) {
self.shutdown();
@ -273,6 +323,54 @@ impl<T: Copy> Consumer<T> {
}
}
impl<T> Consumer<T> {
/// Pops an item from the consumer
pub fn pop(&self) -> Result<T> {
let is_nonblocking = self.is_nonblocking();
// Fast path
let res = self.try_pop();
if should_io_return(&res, is_nonblocking) {
return res;
}
// Slow path
let mask = IoEvents::IN;
let poller = Poller::new();
loop {
let res = self.try_pop();
if should_io_return(&res, is_nonblocking) {
return res;
}
let events = self.poll(mask, Some(&poller));
if events.is_empty() {
// FIXME: should channel have timeout?
poller.wait()?;
}
}
}
fn try_pop(&self) -> Result<T> {
if self.is_shutdown() {
return_errno_with_message!(Errno::EPIPE, "this end is shut down");
}
let item = self.0.pop();
self.update_pollee();
if let Some(item) = item {
return Ok(item);
}
if self.is_peer_shutdown() {
return_errno_with_message!(Errno::EPIPE, "remote end is shut down");
}
return_errno_with_message!(Errno::EAGAIN, "try pop again")
}
}
impl<T> Drop for Consumer<T> {
fn drop(&mut self) {
self.shutdown();
@ -310,6 +408,24 @@ impl<T: Copy, R: TRights> EndPoint<T, R> {
}
}
impl<T, R: TRights> EndPoint<T, R> {
/// Pushes an item into the endpoint.
/// If the `push` method failes, this method will return
/// `Err` containing the item that hasn't been pushed
#[require(R > Write)]
pub fn push(&self, item: T) -> core::result::Result<(), T> {
let mut rb = self.common.producer.rb();
rb.push(item)
}
/// Pops an item from the endpoint.
#[require(R > Read)]
pub fn pop(&self) -> Option<T> {
let mut rb = self.common.consumer.rb();
rb.pop()
}
}
struct Common<T> {
producer: EndPointInner<HeapRbProducer<T>>,
consumer: EndPointInner<HeapRbConsumer<T>>,
@ -399,13 +515,50 @@ fn check_status_flags(flags: StatusFlags) -> Result<()> {
Ok(())
}
fn should_io_return(res: &Result<usize>, is_nonblocking: bool) -> bool {
fn should_io_return<T, E: AsRef<Error>>(
res: &core::result::Result<T, E>,
is_nonblocking: bool,
) -> bool {
if is_nonblocking {
return true;
}
match res {
Ok(_) => true,
Err(e) if e.error() == Errno::EAGAIN => false,
Err(e) if e.as_ref().error() == Errno::EAGAIN => false,
Err(_) => true,
}
}
impl<T> AsRef<Error> for (Error, T) {
fn as_ref(&self) -> &Error {
&self.0
}
}
#[cfg(ktest)]
mod test {
use alloc::sync::Arc;
use crate::fs::utils::Channel;
#[ktest]
fn test_non_copy() {
#[derive(Clone, Debug, PartialEq, Eq)]
struct NonCopy(Arc<usize>);
let channel = Channel::with_capacity(16).unwrap();
let (producer, consumer) = channel.split();
let data = NonCopy(Arc::new(99));
let expected_data = data.clone();
for _ in 0..3 {
producer.push(data.clone()).unwrap();
}
for _ in 0..3 {
let data = consumer.pop().unwrap();
assert_eq!(data, expected_data);
}
}
}