Move status from channels to pipes

This commit is contained in:
Ruihan Li
2024-07-15 17:35:34 +08:00
committed by Tate, Hongliang Tian
parent d24c7f8b9c
commit 8df51ab001
8 changed files with 294 additions and 322 deletions

View File

@ -1,5 +1,9 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::AtomicU32;
use atomic::Ordering;
use super::{ use super::{
file_handle::FileLike, file_handle::FileLike,
utils::{AccessMode, Consumer, InodeMode, InodeType, Metadata, Producer, StatusFlags}, utils::{AccessMode, Consumer, InodeMode, InodeType, Metadata, Producer, StatusFlags},
@ -16,11 +20,17 @@ use crate::{
pub struct PipeReader { pub struct PipeReader {
consumer: Consumer<u8>, consumer: Consumer<u8>,
status_flags: AtomicU32,
} }
impl PipeReader { impl PipeReader {
pub fn new(consumer: Consumer<u8>) -> Self { pub fn new(consumer: Consumer<u8>, status_flags: StatusFlags) -> Result<Arc<Self>> {
Self { consumer } check_status_flags(status_flags)?;
Ok(Arc::new(Self {
consumer,
status_flags: AtomicU32::new(status_flags.bits()),
}))
} }
} }
@ -32,15 +42,22 @@ impl Pollable for PipeReader {
impl FileLike for PipeReader { impl FileLike for PipeReader {
fn read(&self, buf: &mut [u8]) -> Result<usize> { fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.consumer.read(buf) if self.status_flags().contains(StatusFlags::O_NONBLOCK) {
self.consumer.try_read(buf)
} else {
self.wait_events(IoEvents::IN, || self.consumer.try_read(buf))
}
} }
fn status_flags(&self) -> StatusFlags { fn status_flags(&self) -> StatusFlags {
self.consumer.status_flags() StatusFlags::from_bits_truncate(self.status_flags.load(Ordering::Relaxed))
} }
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
self.consumer.set_status_flags(new_flags) check_status_flags(new_flags)?;
self.status_flags.store(new_flags.bits(), Ordering::Relaxed);
Ok(())
} }
fn access_mode(&self) -> AccessMode { fn access_mode(&self) -> AccessMode {
@ -85,11 +102,17 @@ impl FileLike for PipeReader {
pub struct PipeWriter { pub struct PipeWriter {
producer: Producer<u8>, producer: Producer<u8>,
status_flags: AtomicU32,
} }
impl PipeWriter { impl PipeWriter {
pub fn new(producer: Producer<u8>) -> Self { pub fn new(producer: Producer<u8>, status_flags: StatusFlags) -> Result<Arc<Self>> {
Self { producer } check_status_flags(status_flags)?;
Ok(Arc::new(Self {
producer,
status_flags: AtomicU32::new(status_flags.bits()),
}))
} }
} }
@ -101,15 +124,22 @@ impl Pollable for PipeWriter {
impl FileLike for PipeWriter { impl FileLike for PipeWriter {
fn write(&self, buf: &[u8]) -> Result<usize> { fn write(&self, buf: &[u8]) -> Result<usize> {
self.producer.write(buf) if self.status_flags().contains(StatusFlags::O_NONBLOCK) {
self.producer.try_write(buf)
} else {
self.wait_events(IoEvents::OUT, || self.producer.try_write(buf))
}
} }
fn status_flags(&self) -> StatusFlags { fn status_flags(&self) -> StatusFlags {
self.producer.status_flags() StatusFlags::from_bits_truncate(self.status_flags.load(Ordering::Relaxed))
} }
fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
self.producer.set_status_flags(new_flags) check_status_flags(new_flags)?;
self.status_flags.store(new_flags.bits(), Ordering::Relaxed);
Ok(())
} }
fn access_mode(&self) -> AccessMode { fn access_mode(&self) -> AccessMode {
@ -151,3 +181,150 @@ impl FileLike for PipeWriter {
self.producer.unregister_observer(observer) self.producer.unregister_observer(observer)
} }
} }
fn check_status_flags(status_flags: StatusFlags) -> Result<()> {
if status_flags.contains(StatusFlags::O_DIRECT) {
// "O_DIRECT .. Older kernels that do not support this flag will indicate this via an
// EINVAL error."
//
// See <https://man7.org/linux/man-pages/man2/pipe.2.html>.
return_errno_with_message!(Errno::EINVAL, "the `O_DIRECT` flag is not supported");
}
// TODO: Setting most of the other flags will succeed on Linux, but their effects need to be
// validated.
Ok(())
}
#[cfg(ktest)]
mod test {
use alloc::sync::Arc;
use core::sync::atomic::{self, AtomicBool};
use ostd::prelude::*;
use super::*;
use crate::{
fs::utils::Channel,
thread::{
kernel_thread::{KernelThreadExt, ThreadOptions},
Thread,
},
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Ordering {
WriteThenRead,
ReadThenWrite,
}
fn test_blocking<W, R>(write: W, read: R, ordering: Ordering)
where
W: Fn(Arc<PipeWriter>) + Sync + Send + 'static,
R: Fn(Arc<PipeReader>) + Sync + Send + 'static,
{
let channel = Channel::new(1);
let (writer, readr) = channel.split();
let writer = PipeWriter::new(writer, StatusFlags::empty()).unwrap();
let reader = PipeReader::new(readr, StatusFlags::empty()).unwrap();
// FIXME: `ThreadOptions::new` currently accepts `Fn`, forcing us to use `SpinLock` to gain
// internal mutability. We should avoid this `SpinLock` by making `ThreadOptions::new`
// accept `FnOnce`.
let writer_with_lock = SpinLock::new(Some(writer));
let reader_with_lock = SpinLock::new(Some(reader));
let signal_writer = Arc::new(AtomicBool::new(false));
let signal_reader = signal_writer.clone();
let writer = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
let writer = writer_with_lock.lock().take().unwrap();
if ordering == Ordering::ReadThenWrite {
while !signal_writer.load(atomic::Ordering::Relaxed) {
Thread::yield_now();
}
} else {
signal_writer.store(true, atomic::Ordering::Relaxed);
}
write(writer);
}));
let reader = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
let reader = reader_with_lock.lock().take().unwrap();
if ordering == Ordering::WriteThenRead {
while !signal_reader.load(atomic::Ordering::Relaxed) {
Thread::yield_now();
}
} else {
signal_reader.store(true, atomic::Ordering::Relaxed);
}
read(reader);
}));
writer.join();
reader.join();
}
#[ktest]
fn test_read_empty() {
test_blocking(
|writer| {
assert_eq!(writer.write(&[1]).unwrap(), 1);
},
|reader| {
let mut buf = [0; 1];
assert_eq!(reader.read(&mut buf).unwrap(), 1);
assert_eq!(&buf, &[1]);
},
Ordering::ReadThenWrite,
);
}
#[ktest]
fn test_write_full() {
test_blocking(
|writer| {
assert_eq!(writer.write(&[1, 2]).unwrap(), 1);
assert_eq!(writer.write(&[2]).unwrap(), 1);
},
|reader| {
let mut buf = [0; 2];
assert_eq!(reader.read(&mut buf).unwrap(), 1);
assert_eq!(&buf[..1], &[1]);
assert_eq!(reader.read(&mut buf).unwrap(), 1);
assert_eq!(&buf[..1], &[2]);
},
Ordering::WriteThenRead,
);
}
#[ktest]
fn test_read_closed() {
test_blocking(
|writer| drop(writer),
|reader| {
let mut buf = [0; 1];
assert_eq!(reader.read(&mut buf).unwrap(), 0);
},
Ordering::ReadThenWrite,
);
}
#[ktest]
fn test_write_closed() {
test_blocking(
|writer| {
assert_eq!(writer.write(&[1, 2]).unwrap(), 1);
assert_eq!(writer.write(&[2]).unwrap_err().error(), Errno::EPIPE);
},
|reader| drop(reader),
Ordering::WriteThenRead,
);
}
}

View File

@ -1,16 +1,15 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use aster_rights::{Read, ReadOp, TRights, Write, WriteOp}; use aster_rights::{Read, ReadOp, TRights, Write, WriteOp};
use aster_rights_proc::require; use aster_rights_proc::require;
use ringbuf::{HeapConsumer as HeapRbConsumer, HeapProducer as HeapRbProducer, HeapRb}; use ringbuf::{HeapConsumer as HeapRbConsumer, HeapProducer as HeapRbProducer, HeapRb};
use super::StatusFlags;
use crate::{ use crate::{
events::{IoEvents, Observer}, events::{IoEvents, Observer},
prelude::*, prelude::*,
process::signal::{Pollable, Pollee, Poller}, process::signal::{Pollee, Poller},
}; };
/// A unidirectional communication channel, intended to implement IPC, e.g., pipe, /// A unidirectional communication channel, intended to implement IPC, e.g., pipe,
@ -21,15 +20,18 @@ pub struct Channel<T> {
} }
impl<T> Channel<T> { impl<T> Channel<T> {
pub fn with_capacity(capacity: usize) -> Result<Self> { /// Creates a new channel with the given capacity.
Self::with_capacity_and_flags(capacity, StatusFlags::empty()) ///
} /// # Panics
///
/// This method will panic if the given capacity is zero.
pub fn new(capacity: usize) -> Self {
let common = Arc::new(Common::new(capacity));
pub fn with_capacity_and_flags(capacity: usize, flags: StatusFlags) -> Result<Self> {
let common = Arc::new(Common::with_capacity_and_flags(capacity, flags)?);
let producer = Producer(Fifo::new(common.clone())); let producer = Producer(Fifo::new(common.clone()));
let consumer = Consumer(Fifo::new(common)); let consumer = Consumer(Fifo::new(common));
Ok(Self { producer, consumer })
Self { producer, consumer }
} }
pub fn split(self) -> (Producer<T>, Consumer<T>) { pub fn split(self) -> (Producer<T>, Consumer<T>) {
@ -68,20 +70,6 @@ macro_rules! impl_common_methods_for_channel {
self.peer_end().is_shutdown() self.peer_end().is_shutdown()
} }
pub fn status_flags(&self) -> StatusFlags {
self.this_end().status_flags()
}
pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
self.this_end().set_status_flags(new_flags)
}
pub fn is_nonblocking(&self) -> bool {
self.this_end()
.status_flags()
.contains(StatusFlags::O_NONBLOCK)
}
pub fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { pub fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
self.this_end().pollee.poll(mask, poller) self.this_end().pollee.poll(mask, poller)
} }
@ -139,23 +127,13 @@ impl<T> Producer<T> {
impl_common_methods_for_channel!(); impl_common_methods_for_channel!();
} }
impl<T> Pollable for Producer<T> {
fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
self.poll(mask, poller)
}
}
impl<T: Copy> Producer<T> { impl<T: Copy> Producer<T> {
pub fn write(&self, buf: &[T]) -> Result<usize> { /// Tries to write `buf` to the channel.
if self.is_nonblocking() { ///
self.try_write(buf) /// - Returns `Ok(_)` with the number of bytes written if successful.
} else { /// - Returns `Err(EPIPE)` if the channel is shut down.
// The POLLOUT event is set after shutdown, so waiting for the single event is enough. /// - Returns `Err(EAGAIN)` if the channel is full.
self.wait_events(IoEvents::OUT, || self.try_write(buf)) pub fn try_write(&self, buf: &[T]) -> Result<usize> {
}
}
fn try_write(&self, buf: &[T]) -> Result<usize> {
if buf.is_empty() { if buf.is_empty() {
// Even after shutdown, writing an empty buffer is still fine. // Even after shutdown, writing an empty buffer is still fine.
return Ok(0); return Ok(0);
@ -177,35 +155,12 @@ impl<T: Copy> Producer<T> {
} }
impl<T> Producer<T> { impl<T> Producer<T> {
/// Pushes an item into the producer. /// Tries to push `item` to the channel.
/// ///
/// On failure, this method returns `Err` containing /// - Returns `Ok(())` if successful.
/// the item fails to push. /// - Returns `Err(EPIPE)` if the channel is shut down.
pub fn push(&self, item: T) -> core::result::Result<(), (Error, T)> { /// - Returns `Err(EAGAIN)` if the channel is full.
if self.is_nonblocking() { pub fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> {
return self.try_push(item);
}
let mut stored_item = Some(item);
// The POLLOUT event is set after shutdown, so waiting for the single event is enough.
let result = self.wait_events(IoEvents::OUT, || {
match self.try_push(stored_item.take().unwrap()) {
Ok(()) => Ok(()),
Err((err, item)) => {
stored_item = Some(item);
Err(err)
}
}
});
match result {
Ok(()) => Ok(()),
Err(err) => Err((err, stored_item.unwrap())),
}
}
fn try_push(&self, item: T) -> core::result::Result<(), (Error, T)> {
if self.is_shutdown() || self.is_peer_shutdown() { if self.is_shutdown() || self.is_peer_shutdown() {
let err = Error::with_message(Errno::EPIPE, "the channel is shut down"); let err = Error::with_message(Errno::EPIPE, "the channel is shut down");
return Err((err, item)); return Err((err, item));
@ -265,23 +220,13 @@ impl<T> Consumer<T> {
impl_common_methods_for_channel!(); impl_common_methods_for_channel!();
} }
impl<T> Pollable for Consumer<T> {
fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
self.poll(mask, poller)
}
}
impl<T: Copy> Consumer<T> { impl<T: Copy> Consumer<T> {
pub fn read(&self, buf: &mut [T]) -> Result<usize> { /// Tries to read `buf` from the channel.
if self.is_nonblocking() { ///
self.try_read(buf) /// - Returns `Ok(_)` with the number of bytes read if successful.
} else { /// - Returns `Ok(0)` if the channel is shut down and there is no data left.
// The POLLHUP event is in `IoEvents::ALWAYS_POLL`, which is not specified again. /// - Returns `Err(EAGAIN)` if the channel is empty.
self.wait_events(IoEvents::IN, || self.try_read(buf)) pub fn try_read(&self, buf: &mut [T]) -> Result<usize> {
}
}
fn try_read(&self, buf: &mut [T]) -> Result<usize> {
if buf.is_empty() { if buf.is_empty() {
return Ok(0); return Ok(0);
} }
@ -303,17 +248,12 @@ impl<T: Copy> Consumer<T> {
} }
impl<T> Consumer<T> { impl<T> Consumer<T> {
/// Pops an item from the consumer. /// Tries to read an item from the channel.
pub fn pop(&self) -> Result<Option<T>> { ///
if self.is_nonblocking() { /// - Returns `Ok(Some(_))` with the popped item if successful.
self.try_pop() /// - Returns `Ok(None)` if the channel is shut down and there is no data left.
} else { /// - Returns `Err(EAGAIN)` if the channel is empty.
// The POLLHUP event is in `IoEvents::ALWAYS_POLL`, which is not specified again. pub fn try_pop(&self) -> Result<Option<T>> {
self.wait_events(IoEvents::IN, || self.try_pop())
}
}
fn try_pop(&self) -> Result<Option<T>> {
// This must be recorded before the actual operation to avoid race conditions. // This must be recorded before the actual operation to avoid race conditions.
let is_shutdown = self.is_shutdown() || self.is_peer_shutdown(); let is_shutdown = self.is_shutdown() || self.is_peer_shutdown();
@ -397,20 +337,14 @@ struct Common<T> {
} }
impl<T> Common<T> { impl<T> Common<T> {
fn with_capacity_and_flags(capacity: usize, flags: StatusFlags) -> Result<Self> { fn new(capacity: usize) -> Self {
check_status_flags(flags)?;
if capacity == 0 {
return_errno_with_message!(Errno::EINVAL, "the channel capacity cannot be zero");
}
let rb: HeapRb<T> = HeapRb::new(capacity); let rb: HeapRb<T> = HeapRb::new(capacity);
let (rb_producer, rb_consumer) = rb.split(); let (rb_producer, rb_consumer) = rb.split();
let producer = FifoInner::new(rb_producer, IoEvents::OUT, flags); let producer = FifoInner::new(rb_producer, IoEvents::OUT);
let consumer = FifoInner::new(rb_consumer, IoEvents::empty(), flags); let consumer = FifoInner::new(rb_consumer, IoEvents::empty());
Ok(Self { producer, consumer }) Self { producer, consumer }
} }
pub fn capacity(&self) -> usize { pub fn capacity(&self) -> usize {
@ -422,16 +356,14 @@ struct FifoInner<T> {
rb: Mutex<T>, rb: Mutex<T>,
pollee: Pollee, pollee: Pollee,
is_shutdown: AtomicBool, is_shutdown: AtomicBool,
status_flags: AtomicU32,
} }
impl<T> FifoInner<T> { impl<T> FifoInner<T> {
pub fn new(rb: T, init_events: IoEvents, status_flags: StatusFlags) -> Self { pub fn new(rb: T, init_events: IoEvents) -> Self {
Self { Self {
rb: Mutex::new(rb), rb: Mutex::new(rb),
pollee: Pollee::new(init_events), pollee: Pollee::new(init_events),
is_shutdown: AtomicBool::new(false), is_shutdown: AtomicBool::new(false),
status_flags: AtomicU32::new(status_flags.bits()),
} }
} }
@ -446,178 +378,32 @@ impl<T> FifoInner<T> {
pub fn shutdown(&self) { pub fn shutdown(&self) {
self.is_shutdown.store(true, Ordering::Release) self.is_shutdown.store(true, Ordering::Release)
} }
pub fn status_flags(&self) -> StatusFlags {
let bits = self.status_flags.load(Ordering::Relaxed);
StatusFlags::from_bits(bits).unwrap()
}
pub fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> {
check_status_flags(new_flags)?;
self.status_flags.store(new_flags.bits(), Ordering::Relaxed);
Ok(())
}
}
fn check_status_flags(flags: StatusFlags) -> Result<()> {
let valid_flags: StatusFlags = StatusFlags::O_NONBLOCK | StatusFlags::O_DIRECT;
if !valid_flags.contains(flags) {
// FIXME: Linux seems to silently ignore invalid flags. See
// <https://man7.org/linux/man-pages/man2/fcntl.2.html>.
return_errno_with_message!(Errno::EINVAL, "the flags are invalid");
}
if flags.contains(StatusFlags::O_DIRECT) {
return_errno_with_message!(Errno::EINVAL, "the `O_DIRECT` flag is not supported");
}
Ok(())
} }
#[cfg(ktest)] #[cfg(ktest)]
mod test { mod test {
use alloc::sync::Arc; use ostd::prelude::*;
use core::sync::atomic;
use ostd::{prelude::*, sync::AtomicBits};
use super::*; use super::*;
use crate::thread::{
kernel_thread::{KernelThreadExt, ThreadOptions},
Thread,
};
#[ktest] #[ktest]
fn test_non_copy() { fn test_non_copy() {
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
struct NonCopy(Arc<usize>); struct NonCopy(Arc<usize>);
let channel = Channel::with_capacity(16).unwrap(); let channel = Channel::new(16);
let (producer, consumer) = channel.split(); let (producer, consumer) = channel.split();
let data = NonCopy(Arc::new(99)); let data = NonCopy(Arc::new(99));
let expected_data = data.clone(); let expected_data = data.clone();
for _ in 0..3 { for _ in 0..3 {
producer.push(data.clone()).unwrap(); producer.try_push(data.clone()).unwrap();
} }
for _ in 0..3 { for _ in 0..3 {
let data = consumer.pop().unwrap().unwrap(); let data = consumer.try_pop().unwrap().unwrap();
assert_eq!(data, expected_data); assert_eq!(data, expected_data);
} }
} }
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Ordering {
ProduceThenConsume,
ConsumeThenProduce,
}
fn test_blocking<P, C>(produce: P, consume: C, ordering: Ordering)
where
P: Fn(Producer<u8>) + Sync + Send + 'static,
C: Fn(Consumer<u8>) + Sync + Send + 'static,
{
let channel = Channel::with_capacity(1).unwrap();
let (producer, consumer) = channel.split();
// FIXME: `ThreadOptions::new` currently accepts `Fn`, forcing us to use `SpinLock` to gain
// internal mutability. We should avoid this `SpinLock` by making `ThreadOptions::new`
// accept `FnOnce`.
let producer_with_lock = SpinLock::new(Some(producer));
let consumer_with_lock = SpinLock::new(Some(consumer));
let signal_producer = Arc::new(AtomicBool::new(false));
let signal_consumer = signal_producer.clone();
let producer = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
let producer = producer_with_lock.lock().take().unwrap();
if ordering == Ordering::ConsumeThenProduce {
while !signal_producer.load(atomic::Ordering::Relaxed) {
Thread::yield_now();
}
} else {
signal_producer.store(true, atomic::Ordering::Relaxed);
}
produce(producer);
}));
let consumer = Thread::spawn_kernel_thread(ThreadOptions::new(move || {
let consumer = consumer_with_lock.lock().take().unwrap();
if ordering == Ordering::ProduceThenConsume {
while !signal_consumer.load(atomic::Ordering::Relaxed) {
Thread::yield_now();
}
} else {
signal_consumer.store(true, atomic::Ordering::Relaxed);
}
consume(consumer);
}));
producer.join();
consumer.join();
}
#[ktest]
fn test_read_empty() {
test_blocking(
|producer| {
assert_eq!(producer.write(&[1]).unwrap(), 1);
},
|consumer| {
let mut buf = [0; 1];
assert_eq!(consumer.read(&mut buf).unwrap(), 1);
assert_eq!(&buf, &[1]);
},
Ordering::ConsumeThenProduce,
);
}
#[ktest]
fn test_write_full() {
test_blocking(
|producer| {
assert_eq!(producer.write(&[1, 2]).unwrap(), 1);
assert_eq!(producer.write(&[2]).unwrap(), 1);
},
|consumer| {
let mut buf = [0; 2];
assert_eq!(consumer.read(&mut buf).unwrap(), 1);
assert_eq!(&buf[..1], &[1]);
assert_eq!(consumer.read(&mut buf).unwrap(), 1);
assert_eq!(&buf[..1], &[2]);
},
Ordering::ProduceThenConsume,
);
}
#[ktest]
fn test_read_closed() {
test_blocking(
|producer| drop(producer),
|consumer| {
let mut buf = [0; 1];
assert_eq!(consumer.read(&mut buf).unwrap(), 0);
},
Ordering::ConsumeThenProduce,
);
}
#[ktest]
fn test_write_closed() {
test_blocking(
|producer| {
assert_eq!(producer.write(&[1, 2]).unwrap(), 1);
assert_eq!(producer.write(&[2]).unwrap_err().error(), Errno::EPIPE);
},
|consumer| drop(consumer),
Ordering::ProduceThenConsume,
);
}
} }

View File

@ -1,11 +1,15 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::AtomicBool;
use atomic::Ordering;
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
fs::utils::{Channel, Consumer, Producer, StatusFlags}, fs::utils::{Channel, Consumer, Producer},
net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd},
prelude::*, prelude::*,
process::signal::Poller, process::signal::{Pollable, Poller},
}; };
pub(super) struct Endpoint { pub(super) struct Endpoint {
@ -13,6 +17,7 @@ pub(super) struct Endpoint {
peer_addr: Option<UnixSocketAddrBound>, peer_addr: Option<UnixSocketAddrBound>,
reader: Consumer<u8>, reader: Consumer<u8>,
writer: Producer<u8>, writer: Producer<u8>,
is_nonblocking: AtomicBool,
} }
impl Endpoint { impl Endpoint {
@ -20,32 +25,26 @@ impl Endpoint {
addr: Option<UnixSocketAddrBound>, addr: Option<UnixSocketAddrBound>,
peer_addr: Option<UnixSocketAddrBound>, peer_addr: Option<UnixSocketAddrBound>,
is_nonblocking: bool, is_nonblocking: bool,
) -> Result<(Endpoint, Endpoint)> { ) -> (Endpoint, Endpoint) {
let flags = if is_nonblocking { let (writer_this, reader_peer) = Channel::new(DAFAULT_BUF_SIZE).split();
StatusFlags::O_NONBLOCK let (writer_peer, reader_this) = Channel::new(DAFAULT_BUF_SIZE).split();
} else {
StatusFlags::empty()
};
let (writer_this, reader_peer) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let (writer_peer, reader_this) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let this = Endpoint { let this = Endpoint {
addr: addr.clone(), addr: addr.clone(),
peer_addr: peer_addr.clone(), peer_addr: peer_addr.clone(),
reader: reader_this, reader: reader_this,
writer: writer_this, writer: writer_this,
is_nonblocking: AtomicBool::new(is_nonblocking),
}; };
let peer = Endpoint { let peer = Endpoint {
addr: peer_addr, addr: peer_addr,
peer_addr: addr, peer_addr: addr,
reader: reader_peer, reader: reader_peer,
writer: writer_peer, writer: writer_peer,
is_nonblocking: AtomicBool::new(is_nonblocking),
}; };
Ok((this, peer)) (this, peer)
} }
pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
@ -57,32 +56,28 @@ impl Endpoint {
} }
pub(super) fn is_nonblocking(&self) -> bool { pub(super) fn is_nonblocking(&self) -> bool {
let reader_status = self.reader.is_nonblocking(); self.is_nonblocking.load(Ordering::Relaxed)
let writer_status = self.writer.is_nonblocking();
debug_assert!(reader_status == writer_status);
reader_status
} }
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> { pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> {
let mut reader_flags = self.reader.status_flags(); self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed);
reader_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
self.reader.set_status_flags(reader_flags)?;
let mut writer_flags = self.writer.status_flags();
writer_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
self.writer.set_status_flags(writer_flags)?;
Ok(()) Ok(())
} }
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> { pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.reader.read(buf) if self.is_nonblocking() {
self.reader.try_read(buf)
} else {
self.wait_events(IoEvents::IN, || self.reader.try_read(buf))
}
} }
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> { pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
self.writer.write(buf) if self.is_nonblocking() {
self.writer.try_write(buf)
} else {
self.wait_events(IoEvents::OUT, || self.writer.try_write(buf))
}
} }
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
@ -120,4 +115,10 @@ impl Endpoint {
} }
} }
impl Pollable for Endpoint {
fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {
self.poll(mask, poller)
}
}
const DAFAULT_BUF_SIZE: usize = 4096; const DAFAULT_BUF_SIZE: usize = 4096;

View File

@ -16,17 +16,17 @@ use crate::{
}; };
pub(super) struct Init { pub(super) struct Init {
is_nonblocking: AtomicBool,
addr: Mutex<Option<UnixSocketAddrBound>>, addr: Mutex<Option<UnixSocketAddrBound>>,
pollee: Pollee, pollee: Pollee,
is_nonblocking: AtomicBool,
} }
impl Init { impl Init {
pub(super) fn new(is_nonblocking: bool) -> Self { pub(super) fn new(is_nonblocking: bool) -> Self {
Self { Self {
is_nonblocking: AtomicBool::new(is_nonblocking),
addr: Mutex::new(None), addr: Mutex::new(None),
pollee: Pollee::new(IoEvents::empty()), pollee: Pollee::new(IoEvents::empty()),
is_nonblocking: AtomicBool::new(is_nonblocking),
} }
} }
@ -58,7 +58,7 @@ impl Init {
} }
let (this_end, remote_end) = let (this_end, remote_end) =
Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking())?; Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking());
push_incoming(remote_addr, remote_end)?; push_incoming(remote_addr, remote_end)?;
Ok(Connected::new(this_end)) Ok(Connected::new(this_end))
@ -69,11 +69,11 @@ impl Init {
} }
pub(super) fn is_nonblocking(&self) -> bool { pub(super) fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Acquire) self.is_nonblocking.load(Ordering::Relaxed)
} }
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) { pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
self.is_nonblocking.store(is_nonblocking, Ordering::Release); self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed);
} }
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {

View File

@ -39,11 +39,11 @@ impl Listener {
} }
pub(super) fn is_nonblocking(&self) -> bool { pub(super) fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Acquire) self.is_nonblocking.load(Ordering::Relaxed)
} }
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) { pub(super) fn set_nonblocking(&self, is_nonblocking: bool) {
self.is_nonblocking.store(is_nonblocking, Ordering::Release); self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed);
} }
pub(super) fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> { pub(super) fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {

View File

@ -51,8 +51,8 @@ impl UnixStreamSocket {
Self::new_init(init) Self::new_init(init)
} }
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> { pub fn new_pair(nonblocking: bool) -> (Arc<Self>, Arc<Self>) {
let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking)?; let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking);
let connected_a = { let connected_a = {
let connected = Connected::new(end_a); let connected = Connected::new(end_a);
@ -63,7 +63,7 @@ impl UnixStreamSocket {
Self::new_connected(connected) Self::new_connected(connected)
}; };
Ok((Arc::new(connected_a), Arc::new(connected_b))) (Arc::new(connected_a), Arc::new(connected_b))
} }
fn bound_addr(&self) -> Option<UnixSocketAddrBound> { fn bound_addr(&self) -> Option<UnixSocketAddrBound> {

View File

@ -8,23 +8,23 @@ use crate::{
utils::{Channel, CreationFlags, StatusFlags}, utils::{Channel, CreationFlags, StatusFlags},
}, },
prelude::*, prelude::*,
util::{read_val_from_user, write_val_to_user}, util::write_val_to_user,
}; };
pub fn sys_pipe2(fds: Vaddr, flags: u32) -> Result<SyscallReturn> { pub fn sys_pipe2(fds: Vaddr, flags: u32) -> Result<SyscallReturn> {
debug!("flags: {:?}", flags); debug!("flags: {:?}", flags);
let mut pipe_fds = read_val_from_user::<PipeFds>(fds)?; let (pipe_reader, pipe_writer) = {
let (reader, writer) = { let (producer, consumer) = Channel::new(PIPE_BUF_SIZE).split();
let (producer, consumer) = Channel::with_capacity_and_flags(
PIPE_BUF_SIZE, let status_flags = StatusFlags::from_bits_truncate(flags);
StatusFlags::from_bits_truncate(flags),
)? (
.split(); PipeReader::new(consumer, status_flags)?,
(PipeReader::new(consumer), PipeWriter::new(producer)) PipeWriter::new(producer, status_flags)?,
)
}; };
let pipe_reader = Arc::new(reader);
let pipe_writer = Arc::new(writer);
let fd_flags = if CreationFlags::from_bits_truncate(flags).contains(CreationFlags::O_CLOEXEC) { let fd_flags = if CreationFlags::from_bits_truncate(flags).contains(CreationFlags::O_CLOEXEC) {
FdFlags::CLOEXEC FdFlags::CLOEXEC
} else { } else {
@ -33,10 +33,18 @@ pub fn sys_pipe2(fds: Vaddr, flags: u32) -> Result<SyscallReturn> {
let current = current!(); let current = current!();
let mut file_table = current.file_table().lock(); let mut file_table = current.file_table().lock();
pipe_fds.reader_fd = file_table.insert(pipe_reader, fd_flags);
pipe_fds.writer_fd = file_table.insert(pipe_writer, fd_flags); let pipe_fds = PipeFds {
reader_fd: file_table.insert(pipe_reader, fd_flags),
writer_fd: file_table.insert(pipe_writer, fd_flags),
};
debug!("pipe_fds: {:?}", pipe_fds); debug!("pipe_fds: {:?}", pipe_fds);
write_val_to_user(fds, &pipe_fds)?;
if let Err(err) = write_val_to_user(fds, &pipe_fds) {
file_table.close_file(pipe_fds.reader_fd).unwrap();
file_table.close_file(pipe_fds.writer_fd).unwrap();
return Err(err);
}
Ok(SyscallReturn::Return(0)) Ok(SyscallReturn::Return(0))
} }

View File

@ -25,7 +25,7 @@ pub fn sys_socketpair(domain: i32, type_: i32, protocol: i32, sv: Vaddr) -> Resu
let nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK); let nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK);
let (socket_a, socket_b) = match (domain, sock_type) { let (socket_a, socket_b) = match (domain, sock_type) {
(CSocketAddrFamily::AF_UNIX, SockType::SOCK_STREAM) => { (CSocketAddrFamily::AF_UNIX, SockType::SOCK_STREAM) => {
UnixStreamSocket::new_pair(nonblocking)? UnixStreamSocket::new_pair(nonblocking)
} }
_ => return_errno_with_message!( _ => return_errno_with_message!(
Errno::EAFNOSUPPORT, Errno::EAFNOSUPPORT,