diff --git a/kernel/aster-nix/src/fs/pipe.rs b/kernel/aster-nix/src/fs/pipe.rs index 1ff70554b..adcaefeed 100644 --- a/kernel/aster-nix/src/fs/pipe.rs +++ b/kernel/aster-nix/src/fs/pipe.rs @@ -1,5 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 +use core::sync::atomic::AtomicU32; + +use atomic::Ordering; + use super::{ file_handle::FileLike, utils::{AccessMode, Consumer, InodeMode, InodeType, Metadata, Producer, StatusFlags}, @@ -16,11 +20,17 @@ use crate::{ pub struct PipeReader { consumer: Consumer, + status_flags: AtomicU32, } impl PipeReader { - pub fn new(consumer: Consumer) -> Self { - Self { consumer } + pub fn new(consumer: Consumer, status_flags: StatusFlags) -> Result> { + check_status_flags(status_flags)?; + + Ok(Arc::new(Self { + consumer, + status_flags: AtomicU32::new(status_flags.bits()), + })) } } @@ -32,15 +42,22 @@ impl Pollable for PipeReader { impl FileLike for PipeReader { fn read(&self, buf: &mut [u8]) -> Result { - self.consumer.read(buf) + if self.status_flags().contains(StatusFlags::O_NONBLOCK) { + self.consumer.try_read(buf) + } else { + self.wait_events(IoEvents::IN, || self.consumer.try_read(buf)) + } } fn status_flags(&self) -> StatusFlags { - self.consumer.status_flags() + StatusFlags::from_bits_truncate(self.status_flags.load(Ordering::Relaxed)) } fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - self.consumer.set_status_flags(new_flags) + check_status_flags(new_flags)?; + + self.status_flags.store(new_flags.bits(), Ordering::Relaxed); + Ok(()) } fn access_mode(&self) -> AccessMode { @@ -85,11 +102,17 @@ impl FileLike for PipeReader { pub struct PipeWriter { producer: Producer, + status_flags: AtomicU32, } impl PipeWriter { - pub fn new(producer: Producer) -> Self { - Self { producer } + pub fn new(producer: Producer, status_flags: StatusFlags) -> Result> { + check_status_flags(status_flags)?; + + Ok(Arc::new(Self { + producer, + status_flags: AtomicU32::new(status_flags.bits()), + })) } } @@ -101,15 +124,22 @@ impl Pollable for PipeWriter { impl FileLike for PipeWriter { fn write(&self, buf: &[u8]) -> Result { - self.producer.write(buf) + if self.status_flags().contains(StatusFlags::O_NONBLOCK) { + self.producer.try_write(buf) + } else { + self.wait_events(IoEvents::OUT, || self.producer.try_write(buf)) + } } fn status_flags(&self) -> StatusFlags { - self.producer.status_flags() + StatusFlags::from_bits_truncate(self.status_flags.load(Ordering::Relaxed)) } fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - self.producer.set_status_flags(new_flags) + check_status_flags(new_flags)?; + + self.status_flags.store(new_flags.bits(), Ordering::Relaxed); + Ok(()) } fn access_mode(&self) -> AccessMode { @@ -151,3 +181,150 @@ impl FileLike for PipeWriter { self.producer.unregister_observer(observer) } } + +fn check_status_flags(status_flags: StatusFlags) -> Result<()> { + if status_flags.contains(StatusFlags::O_DIRECT) { + // "O_DIRECT .. Older kernels that do not support this flag will indicate this via an + // EINVAL error." + // + // See . + return_errno_with_message!(Errno::EINVAL, "the `O_DIRECT` flag is not supported"); + } + + // TODO: Setting most of the other flags will succeed on Linux, but their effects need to be + // validated. + + Ok(()) +} + +#[cfg(ktest)] +mod test { + use alloc::sync::Arc; + use core::sync::atomic::{self, AtomicBool}; + + use ostd::prelude::*; + + use super::*; + use crate::{ + fs::utils::Channel, + thread::{ + kernel_thread::{KernelThreadExt, ThreadOptions}, + Thread, + }, + }; + + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + enum Ordering { + WriteThenRead, + ReadThenWrite, + } + + fn test_blocking(write: W, read: R, ordering: Ordering) + where + W: Fn(Arc) + Sync + Send + 'static, + R: Fn(Arc) + Sync + Send + 'static, + { + let channel = Channel::new(1); + let (writer, readr) = channel.split(); + + let writer = PipeWriter::new(writer, StatusFlags::empty()).unwrap(); + let reader = PipeReader::new(readr, StatusFlags::empty()).unwrap(); + + // FIXME: `ThreadOptions::new` currently accepts `Fn`, forcing us to use `SpinLock` to gain + // internal mutability. We should avoid this `SpinLock` by making `ThreadOptions::new` + // accept `FnOnce`. + let writer_with_lock = SpinLock::new(Some(writer)); + let reader_with_lock = SpinLock::new(Some(reader)); + + let signal_writer = Arc::new(AtomicBool::new(false)); + let signal_reader = signal_writer.clone(); + + let writer = Thread::spawn_kernel_thread(ThreadOptions::new(move || { + let writer = writer_with_lock.lock().take().unwrap(); + + if ordering == Ordering::ReadThenWrite { + while !signal_writer.load(atomic::Ordering::Relaxed) { + Thread::yield_now(); + } + } else { + signal_writer.store(true, atomic::Ordering::Relaxed); + } + + write(writer); + })); + + let reader = Thread::spawn_kernel_thread(ThreadOptions::new(move || { + let reader = reader_with_lock.lock().take().unwrap(); + + if ordering == Ordering::WriteThenRead { + while !signal_reader.load(atomic::Ordering::Relaxed) { + Thread::yield_now(); + } + } else { + signal_reader.store(true, atomic::Ordering::Relaxed); + } + + read(reader); + })); + + writer.join(); + reader.join(); + } + + #[ktest] + fn test_read_empty() { + test_blocking( + |writer| { + assert_eq!(writer.write(&[1]).unwrap(), 1); + }, + |reader| { + let mut buf = [0; 1]; + assert_eq!(reader.read(&mut buf).unwrap(), 1); + assert_eq!(&buf, &[1]); + }, + Ordering::ReadThenWrite, + ); + } + + #[ktest] + fn test_write_full() { + test_blocking( + |writer| { + assert_eq!(writer.write(&[1, 2]).unwrap(), 1); + assert_eq!(writer.write(&[2]).unwrap(), 1); + }, + |reader| { + let mut buf = [0; 2]; + assert_eq!(reader.read(&mut buf).unwrap(), 1); + assert_eq!(&buf[..1], &[1]); + assert_eq!(reader.read(&mut buf).unwrap(), 1); + assert_eq!(&buf[..1], &[2]); + }, + Ordering::WriteThenRead, + ); + } + + #[ktest] + fn test_read_closed() { + test_blocking( + |writer| drop(writer), + |reader| { + let mut buf = [0; 1]; + assert_eq!(reader.read(&mut buf).unwrap(), 0); + }, + Ordering::ReadThenWrite, + ); + } + + #[ktest] + fn test_write_closed() { + test_blocking( + |writer| { + assert_eq!(writer.write(&[1, 2]).unwrap(), 1); + assert_eq!(writer.write(&[2]).unwrap_err().error(), Errno::EPIPE); + }, + |reader| drop(reader), + Ordering::WriteThenRead, + ); + } +} diff --git a/kernel/aster-nix/src/fs/utils/channel.rs b/kernel/aster-nix/src/fs/utils/channel.rs index ebf09465f..9fcabca1f 100644 --- a/kernel/aster-nix/src/fs/utils/channel.rs +++ b/kernel/aster-nix/src/fs/utils/channel.rs @@ -1,16 +1,15 @@ // SPDX-License-Identifier: MPL-2.0 -use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use core::sync::atomic::{AtomicBool, Ordering}; use aster_rights::{Read, ReadOp, TRights, Write, WriteOp}; use aster_rights_proc::require; use ringbuf::{HeapConsumer as HeapRbConsumer, HeapProducer as HeapRbProducer, HeapRb}; -use super::StatusFlags; use crate::{ events::{IoEvents, Observer}, prelude::*, - process::signal::{Pollable, Pollee, Poller}, + process::signal::{Pollee, Poller}, }; /// A unidirectional communication channel, intended to implement IPC, e.g., pipe, @@ -21,15 +20,18 @@ pub struct Channel { } impl Channel { - pub fn with_capacity(capacity: usize) -> Result { - Self::with_capacity_and_flags(capacity, StatusFlags::empty()) - } + /// Creates a new channel with the given capacity. + /// + /// # Panics + /// + /// This method will panic if the given capacity is zero. + pub fn new(capacity: usize) -> Self { + let common = Arc::new(Common::new(capacity)); - pub fn with_capacity_and_flags(capacity: usize, flags: StatusFlags) -> Result { - let common = Arc::new(Common::with_capacity_and_flags(capacity, flags)?); let producer = Producer(Fifo::new(common.clone())); let consumer = Consumer(Fifo::new(common)); - Ok(Self { producer, consumer }) + + Self { producer, consumer } } pub fn split(self) -> (Producer, Consumer) { @@ -68,20 +70,6 @@ macro_rules! impl_common_methods_for_channel { self.peer_end().is_shutdown() } - pub fn status_flags(&self) -> StatusFlags { - self.this_end().status_flags() - } - - pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - self.this_end().set_status_flags(new_flags) - } - - pub fn is_nonblocking(&self) -> bool { - self.this_end() - .status_flags() - .contains(StatusFlags::O_NONBLOCK) - } - pub fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { self.this_end().pollee.poll(mask, poller) } @@ -139,23 +127,13 @@ impl Producer { impl_common_methods_for_channel!(); } -impl Pollable for Producer { - fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - self.poll(mask, poller) - } -} - impl Producer { - pub fn write(&self, buf: &[T]) -> Result { - if self.is_nonblocking() { - self.try_write(buf) - } else { - // The POLLOUT event is set after shutdown, so waiting for the single event is enough. - self.wait_events(IoEvents::OUT, || self.try_write(buf)) - } - } - - fn try_write(&self, buf: &[T]) -> Result { + /// Tries to write `buf` to the channel. + /// + /// - Returns `Ok(_)` with the number of bytes written if successful. + /// - Returns `Err(EPIPE)` if the channel is shut down. + /// - Returns `Err(EAGAIN)` if the channel is full. + pub fn try_write(&self, buf: &[T]) -> Result { if buf.is_empty() { // Even after shutdown, writing an empty buffer is still fine. return Ok(0); @@ -177,35 +155,12 @@ impl Producer { } impl Producer { - /// Pushes an item into the producer. + /// Tries to push `item` to the channel. /// - /// On failure, this method returns `Err` containing - /// the item fails to push. - pub fn push(&self, item: T) -> core::result::Result<(), (Error, T)> { - if self.is_nonblocking() { - return self.try_push(item); - } - - let mut stored_item = Some(item); - - // The POLLOUT event is set after shutdown, so waiting for the single event is enough. - let result = self.wait_events(IoEvents::OUT, || { - match self.try_push(stored_item.take().unwrap()) { - Ok(()) => Ok(()), - Err((err, item)) => { - stored_item = Some(item); - Err(err) - } - } - }); - - match result { - Ok(()) => Ok(()), - Err(err) => Err((err, stored_item.unwrap())), - } - } - - fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> { + /// - Returns `Ok(())` if successful. + /// - Returns `Err(EPIPE)` if the channel is shut down. + /// - Returns `Err(EAGAIN)` if the channel is full. + pub fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> { if self.is_shutdown() || self.is_peer_shutdown() { let err = Error::with_message(Errno::EPIPE, "the channel is shut down"); return Err((err, item)); @@ -265,23 +220,13 @@ impl Consumer { impl_common_methods_for_channel!(); } -impl Pollable for Consumer { - fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { - self.poll(mask, poller) - } -} - impl Consumer { - pub fn read(&self, buf: &mut [T]) -> Result { - if self.is_nonblocking() { - self.try_read(buf) - } else { - // The POLLHUP event is in `IoEvents::ALWAYS_POLL`, which is not specified again. - self.wait_events(IoEvents::IN, || self.try_read(buf)) - } - } - - fn try_read(&self, buf: &mut [T]) -> Result { + /// Tries to read `buf` from the channel. + /// + /// - Returns `Ok(_)` with the number of bytes read if successful. + /// - Returns `Ok(0)` if the channel is shut down and there is no data left. + /// - Returns `Err(EAGAIN)` if the channel is empty. + pub fn try_read(&self, buf: &mut [T]) -> Result { if buf.is_empty() { return Ok(0); } @@ -303,17 +248,12 @@ impl Consumer { } impl Consumer { - /// Pops an item from the consumer. - pub fn pop(&self) -> Result> { - if self.is_nonblocking() { - self.try_pop() - } else { - // The POLLHUP event is in `IoEvents::ALWAYS_POLL`, which is not specified again. - self.wait_events(IoEvents::IN, || self.try_pop()) - } - } - - fn try_pop(&self) -> Result> { + /// Tries to read an item from the channel. + /// + /// - Returns `Ok(Some(_))` with the popped item if successful. + /// - Returns `Ok(None)` if the channel is shut down and there is no data left. + /// - Returns `Err(EAGAIN)` if the channel is empty. + pub fn try_pop(&self) -> Result> { // This must be recorded before the actual operation to avoid race conditions. let is_shutdown = self.is_shutdown() || self.is_peer_shutdown(); @@ -397,20 +337,14 @@ struct Common { } impl Common { - fn with_capacity_and_flags(capacity: usize, flags: StatusFlags) -> Result { - check_status_flags(flags)?; - - if capacity == 0 { - return_errno_with_message!(Errno::EINVAL, "the channel capacity cannot be zero"); - } - + fn new(capacity: usize) -> Self { let rb: HeapRb = HeapRb::new(capacity); let (rb_producer, rb_consumer) = rb.split(); - let producer = FifoInner::new(rb_producer, IoEvents::OUT, flags); - let consumer = FifoInner::new(rb_consumer, IoEvents::empty(), flags); + let producer = FifoInner::new(rb_producer, IoEvents::OUT); + let consumer = FifoInner::new(rb_consumer, IoEvents::empty()); - Ok(Self { producer, consumer }) + Self { producer, consumer } } pub fn capacity(&self) -> usize { @@ -422,16 +356,14 @@ struct FifoInner { rb: Mutex, pollee: Pollee, is_shutdown: AtomicBool, - status_flags: AtomicU32, } impl FifoInner { - pub fn new(rb: T, init_events: IoEvents, status_flags: StatusFlags) -> Self { + pub fn new(rb: T, init_events: IoEvents) -> Self { Self { rb: Mutex::new(rb), pollee: Pollee::new(init_events), is_shutdown: AtomicBool::new(false), - status_flags: AtomicU32::new(status_flags.bits()), } } @@ -446,178 +378,32 @@ impl FifoInner { pub fn shutdown(&self) { self.is_shutdown.store(true, Ordering::Release) } - - pub fn status_flags(&self) -> StatusFlags { - let bits = self.status_flags.load(Ordering::Relaxed); - StatusFlags::from_bits(bits).unwrap() - } - - pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - check_status_flags(new_flags)?; - self.status_flags.store(new_flags.bits(), Ordering::Relaxed); - Ok(()) - } -} - -fn check_status_flags(flags: StatusFlags) -> Result<()> { - let valid_flags: StatusFlags = StatusFlags::O_NONBLOCK | StatusFlags::O_DIRECT; - - if !valid_flags.contains(flags) { - // FIXME: Linux seems to silently ignore invalid flags. See - // . - return_errno_with_message!(Errno::EINVAL, "the flags are invalid"); - } - - if flags.contains(StatusFlags::O_DIRECT) { - return_errno_with_message!(Errno::EINVAL, "the `O_DIRECT` flag is not supported"); - } - - Ok(()) } #[cfg(ktest)] mod test { - use alloc::sync::Arc; - use core::sync::atomic; - - use ostd::{prelude::*, sync::AtomicBits}; + use ostd::prelude::*; use super::*; - use crate::thread::{ - kernel_thread::{KernelThreadExt, ThreadOptions}, - Thread, - }; #[ktest] fn test_non_copy() { #[derive(Clone, Debug, PartialEq, Eq)] struct NonCopy(Arc); - let channel = Channel::with_capacity(16).unwrap(); + let channel = Channel::new(16); let (producer, consumer) = channel.split(); let data = NonCopy(Arc::new(99)); let expected_data = data.clone(); for _ in 0..3 { - producer.push(data.clone()).unwrap(); + producer.try_push(data.clone()).unwrap(); } for _ in 0..3 { - let data = consumer.pop().unwrap().unwrap(); + let data = consumer.try_pop().unwrap().unwrap(); assert_eq!(data, expected_data); } } - - #[derive(Clone, Copy, Debug, PartialEq, Eq)] - enum Ordering { - ProduceThenConsume, - ConsumeThenProduce, - } - - fn test_blocking(produce: P, consume: C, ordering: Ordering) - where - P: Fn(Producer) + Sync + Send + 'static, - C: Fn(Consumer) + Sync + Send + 'static, - { - let channel = Channel::with_capacity(1).unwrap(); - let (producer, consumer) = channel.split(); - - // FIXME: `ThreadOptions::new` currently accepts `Fn`, forcing us to use `SpinLock` to gain - // internal mutability. We should avoid this `SpinLock` by making `ThreadOptions::new` - // accept `FnOnce`. - let producer_with_lock = SpinLock::new(Some(producer)); - let consumer_with_lock = SpinLock::new(Some(consumer)); - - let signal_producer = Arc::new(AtomicBool::new(false)); - let signal_consumer = signal_producer.clone(); - - let producer = Thread::spawn_kernel_thread(ThreadOptions::new(move || { - let producer = producer_with_lock.lock().take().unwrap(); - - if ordering == Ordering::ConsumeThenProduce { - while !signal_producer.load(atomic::Ordering::Relaxed) { - Thread::yield_now(); - } - } else { - signal_producer.store(true, atomic::Ordering::Relaxed); - } - - produce(producer); - })); - - let consumer = Thread::spawn_kernel_thread(ThreadOptions::new(move || { - let consumer = consumer_with_lock.lock().take().unwrap(); - - if ordering == Ordering::ProduceThenConsume { - while !signal_consumer.load(atomic::Ordering::Relaxed) { - Thread::yield_now(); - } - } else { - signal_consumer.store(true, atomic::Ordering::Relaxed); - } - - consume(consumer); - })); - - producer.join(); - consumer.join(); - } - - #[ktest] - fn test_read_empty() { - test_blocking( - |producer| { - assert_eq!(producer.write(&[1]).unwrap(), 1); - }, - |consumer| { - let mut buf = [0; 1]; - assert_eq!(consumer.read(&mut buf).unwrap(), 1); - assert_eq!(&buf, &[1]); - }, - Ordering::ConsumeThenProduce, - ); - } - - #[ktest] - fn test_write_full() { - test_blocking( - |producer| { - assert_eq!(producer.write(&[1, 2]).unwrap(), 1); - assert_eq!(producer.write(&[2]).unwrap(), 1); - }, - |consumer| { - let mut buf = [0; 2]; - assert_eq!(consumer.read(&mut buf).unwrap(), 1); - assert_eq!(&buf[..1], &[1]); - assert_eq!(consumer.read(&mut buf).unwrap(), 1); - assert_eq!(&buf[..1], &[2]); - }, - Ordering::ProduceThenConsume, - ); - } - - #[ktest] - fn test_read_closed() { - test_blocking( - |producer| drop(producer), - |consumer| { - let mut buf = [0; 1]; - assert_eq!(consumer.read(&mut buf).unwrap(), 0); - }, - Ordering::ConsumeThenProduce, - ); - } - - #[ktest] - fn test_write_closed() { - test_blocking( - |producer| { - assert_eq!(producer.write(&[1, 2]).unwrap(), 1); - assert_eq!(producer.write(&[2]).unwrap_err().error(), Errno::EPIPE); - }, - |consumer| drop(consumer), - Ordering::ProduceThenConsume, - ); - } } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs b/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs index 7486a3c65..769d0932a 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/endpoint.rs @@ -1,11 +1,15 @@ // SPDX-License-Identifier: MPL-2.0 +use core::sync::atomic::AtomicBool; + +use atomic::Ordering; + use crate::{ events::IoEvents, - fs::utils::{Channel, Consumer, Producer, StatusFlags}, + fs::utils::{Channel, Consumer, Producer}, net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, prelude::*, - process::signal::Poller, + process::signal::{Pollable, Poller}, }; pub(super) struct Endpoint { @@ -13,6 +17,7 @@ pub(super) struct Endpoint { peer_addr: Option, reader: Consumer, writer: Producer, + is_nonblocking: AtomicBool, } impl Endpoint { @@ -20,32 +25,26 @@ impl Endpoint { addr: Option, peer_addr: Option, is_nonblocking: bool, - ) -> Result<(Endpoint, Endpoint)> { - let flags = if is_nonblocking { - StatusFlags::O_NONBLOCK - } else { - StatusFlags::empty() - }; - - let (writer_this, reader_peer) = - Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split(); - let (writer_peer, reader_this) = - Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split(); + ) -> (Endpoint, Endpoint) { + let (writer_this, reader_peer) = Channel::new(DAFAULT_BUF_SIZE).split(); + let (writer_peer, reader_this) = Channel::new(DAFAULT_BUF_SIZE).split(); let this = Endpoint { addr: addr.clone(), peer_addr: peer_addr.clone(), reader: reader_this, writer: writer_this, + is_nonblocking: AtomicBool::new(is_nonblocking), }; let peer = Endpoint { addr: peer_addr, peer_addr: addr, reader: reader_peer, writer: writer_peer, + is_nonblocking: AtomicBool::new(is_nonblocking), }; - Ok((this, peer)) + (this, peer) } pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { @@ -57,32 +56,28 @@ impl Endpoint { } pub(super) fn is_nonblocking(&self) -> bool { - let reader_status = self.reader.is_nonblocking(); - let writer_status = self.writer.is_nonblocking(); - - debug_assert!(reader_status == writer_status); - - reader_status + self.is_nonblocking.load(Ordering::Relaxed) } pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> { - let mut reader_flags = self.reader.status_flags(); - reader_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking); - self.reader.set_status_flags(reader_flags)?; - - let mut writer_flags = self.writer.status_flags(); - writer_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking); - self.writer.set_status_flags(writer_flags)?; - + self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); Ok(()) } pub(super) fn read(&self, buf: &mut [u8]) -> Result { - self.reader.read(buf) + if self.is_nonblocking() { + self.reader.try_read(buf) + } else { + self.wait_events(IoEvents::IN, || self.reader.try_read(buf)) + } } pub(super) fn write(&self, buf: &[u8]) -> Result { - self.writer.write(buf) + if self.is_nonblocking() { + self.writer.try_write(buf) + } else { + self.wait_events(IoEvents::OUT, || self.writer.try_write(buf)) + } } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -120,4 +115,10 @@ impl Endpoint { } } +impl Pollable for Endpoint { + fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { + self.poll(mask, poller) + } +} + const DAFAULT_BUF_SIZE: usize = 4096; diff --git a/kernel/aster-nix/src/net/socket/unix/stream/init.rs b/kernel/aster-nix/src/net/socket/unix/stream/init.rs index 39245de11..d53331a28 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/init.rs @@ -16,17 +16,17 @@ use crate::{ }; pub(super) struct Init { - is_nonblocking: AtomicBool, addr: Mutex>, pollee: Pollee, + is_nonblocking: AtomicBool, } impl Init { pub(super) fn new(is_nonblocking: bool) -> Self { Self { - is_nonblocking: AtomicBool::new(is_nonblocking), addr: Mutex::new(None), pollee: Pollee::new(IoEvents::empty()), + is_nonblocking: AtomicBool::new(is_nonblocking), } } @@ -58,7 +58,7 @@ impl Init { } let (this_end, remote_end) = - Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking())?; + Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking()); push_incoming(remote_addr, remote_end)?; Ok(Connected::new(this_end)) @@ -69,11 +69,11 @@ impl Init { } pub(super) fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Acquire) + self.is_nonblocking.load(Ordering::Relaxed) } pub(super) fn set_nonblocking(&self, is_nonblocking: bool) { - self.is_nonblocking.store(is_nonblocking, Ordering::Release); + self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); } pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { diff --git a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs index 9f75c9926..3d540312b 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs @@ -39,11 +39,11 @@ impl Listener { } pub(super) fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Acquire) + self.is_nonblocking.load(Ordering::Relaxed) } pub(super) fn set_nonblocking(&self, is_nonblocking: bool) { - self.is_nonblocking.store(is_nonblocking, Ordering::Release); + self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); } pub(super) fn accept(&self) -> Result<(Arc, SocketAddr)> { diff --git a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs index ac82bf396..41e41529c 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs @@ -51,8 +51,8 @@ impl UnixStreamSocket { Self::new_init(init) } - pub fn new_pair(nonblocking: bool) -> Result<(Arc, Arc)> { - let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking)?; + pub fn new_pair(nonblocking: bool) -> (Arc, Arc) { + let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking); let connected_a = { let connected = Connected::new(end_a); @@ -63,7 +63,7 @@ impl UnixStreamSocket { Self::new_connected(connected) }; - Ok((Arc::new(connected_a), Arc::new(connected_b))) + (Arc::new(connected_a), Arc::new(connected_b)) } fn bound_addr(&self) -> Option { diff --git a/kernel/aster-nix/src/syscall/pipe.rs b/kernel/aster-nix/src/syscall/pipe.rs index ba27a0cfe..56a60533a 100644 --- a/kernel/aster-nix/src/syscall/pipe.rs +++ b/kernel/aster-nix/src/syscall/pipe.rs @@ -8,23 +8,23 @@ use crate::{ utils::{Channel, CreationFlags, StatusFlags}, }, prelude::*, - util::{read_val_from_user, write_val_to_user}, + util::write_val_to_user, }; pub fn sys_pipe2(fds: Vaddr, flags: u32) -> Result { debug!("flags: {:?}", flags); - let mut pipe_fds = read_val_from_user::(fds)?; - let (reader, writer) = { - let (producer, consumer) = Channel::with_capacity_and_flags( - PIPE_BUF_SIZE, - StatusFlags::from_bits_truncate(flags), - )? - .split(); - (PipeReader::new(consumer), PipeWriter::new(producer)) + let (pipe_reader, pipe_writer) = { + let (producer, consumer) = Channel::new(PIPE_BUF_SIZE).split(); + + let status_flags = StatusFlags::from_bits_truncate(flags); + + ( + PipeReader::new(consumer, status_flags)?, + PipeWriter::new(producer, status_flags)?, + ) }; - let pipe_reader = Arc::new(reader); - let pipe_writer = Arc::new(writer); + let fd_flags = if CreationFlags::from_bits_truncate(flags).contains(CreationFlags::O_CLOEXEC) { FdFlags::CLOEXEC } else { @@ -33,10 +33,18 @@ pub fn sys_pipe2(fds: Vaddr, flags: u32) -> Result { let current = current!(); let mut file_table = current.file_table().lock(); - pipe_fds.reader_fd = file_table.insert(pipe_reader, fd_flags); - pipe_fds.writer_fd = file_table.insert(pipe_writer, fd_flags); + + let pipe_fds = PipeFds { + reader_fd: file_table.insert(pipe_reader, fd_flags), + writer_fd: file_table.insert(pipe_writer, fd_flags), + }; debug!("pipe_fds: {:?}", pipe_fds); - write_val_to_user(fds, &pipe_fds)?; + + if let Err(err) = write_val_to_user(fds, &pipe_fds) { + file_table.close_file(pipe_fds.reader_fd).unwrap(); + file_table.close_file(pipe_fds.writer_fd).unwrap(); + return Err(err); + } Ok(SyscallReturn::Return(0)) } diff --git a/kernel/aster-nix/src/syscall/socketpair.rs b/kernel/aster-nix/src/syscall/socketpair.rs index 7a55ef5c1..8bbb8dba2 100644 --- a/kernel/aster-nix/src/syscall/socketpair.rs +++ b/kernel/aster-nix/src/syscall/socketpair.rs @@ -25,7 +25,7 @@ pub fn sys_socketpair(domain: i32, type_: i32, protocol: i32, sv: Vaddr) -> Resu let nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK); let (socket_a, socket_b) = match (domain, sock_type) { (CSocketAddrFamily::AF_UNIX, SockType::SOCK_STREAM) => { - UnixStreamSocket::new_pair(nonblocking)? + UnixStreamSocket::new_pair(nonblocking) } _ => return_errno_with_message!( Errno::EAFNOSUPPORT,