diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index cdf426d80..9f10fc3b1 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -10,13 +10,10 @@ use self::{bound::BoundDatagram, unbound::UnboundDatagram}; use super::{common::get_ephemeral_endpoint, UNSPECIFIED_LOCAL_ENDPOINT}; use crate::{ events::IoEvents, - fs::{ - file_handle::FileLike, - utils::{InodeMode, Metadata, StatusFlags}, - }, match_sock_option_mut, net::socket::{ options::{Error as SocketError, SocketOption}, + private::SocketPrivate, util::{ options::{SetSocketLevelOption, SocketOptionSet}, send_recv_flags::SendRecvFlags, @@ -110,14 +107,6 @@ impl DatagramSocket { }) } - pub fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - pub fn set_nonblocking(&self, is_nonblocking: bool) { - self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); - } - fn remote_endpoint(&self) -> Option { let inner = self.inner.read(); @@ -168,18 +157,6 @@ impl DatagramSocket { Ok(recv_bytes) } - fn recv( - &self, - writer: &mut dyn MultiWrite, - flags: SendRecvFlags, - ) -> Result<(usize, SocketAddr)> { - if self.is_nonblocking() { - self.try_recv(writer, flags) - } else { - self.wait_events(IoEvents::IN, None, || self.try_recv(writer, flags)) - } - } - fn try_send( &self, reader: &mut dyn MultiRead, @@ -219,59 +196,13 @@ impl Pollable for DatagramSocket { } } -impl FileLike for DatagramSocket { - fn read(&self, writer: &mut VmWriter) -> Result { - // TODO: set correct flags - let flags = SendRecvFlags::empty(); - let read_len = self.recv(writer, flags).map(|(len, _)| len)?; - Ok(read_len) +impl SocketPrivate for DatagramSocket { + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) } - fn write(&self, reader: &mut VmReader) -> Result { - let remote = self.remote_endpoint().ok_or_else(|| { - Error::with_message( - Errno::EDESTADDRREQ, - "the destination address is not specified", - ) - })?; - - // TODO: Set correct flags - let flags = SendRecvFlags::empty(); - - // TODO: Block if send buffer is full - self.try_send(reader, &remote, flags) - } - - fn as_socket(&self) -> Option<&dyn Socket> { - Some(self) - } - - fn status_flags(&self) -> StatusFlags { - // TODO: when we fully support O_ASYNC, return the flag - if self.is_nonblocking() { - StatusFlags::O_NONBLOCK - } else { - StatusFlags::empty() - } - } - - fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - if new_flags.contains(StatusFlags::O_NONBLOCK) { - self.set_nonblocking(true); - } else { - self.set_nonblocking(false); - } - Ok(()) - } - - fn metadata(&self) -> Metadata { - // This is a dummy implementation. - // TODO: Add "SockFS" and link `DatagramSocket` to it. - Metadata::new_socket( - 0, - InodeMode::from_bits_truncate(0o140777), - aster_block::BLOCK_SIZE, - ) + fn set_nonblocking(&self, is_nonblocking: bool) { + self.is_nonblocking.store(is_nonblocking, Ordering::Relaxed); } } @@ -373,7 +304,8 @@ impl Socket for DatagramSocket { warn!("unsupported flags: {:?}", flags); } - let (received_bytes, peer_addr) = self.recv(writer, flags)?; + let (received_bytes, peer_addr) = + self.block_on(IoEvents::IN, || self.try_recv(writer, flags))?; // TODO: Receive control message diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 2845899be..956dd03e5 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -18,15 +18,13 @@ use util::TcpOptionSet; use super::UNSPECIFIED_LOCAL_ENDPOINT; use crate::{ events::IoEvents, - fs::{ - file_handle::FileLike, - utils::{InodeMode, Metadata, StatusFlags}, - }, + fs::file_handle::FileLike, match_sock_option_mut, match_sock_option_ref, net::{ iface::Iface, socket::{ options::{Error as SocketError, SocketOption}, + private::SocketPrivate, util::{ options::{SetSocketLevelOption, SocketOptionSet}, send_recv_flags::SendRecvFlags, @@ -131,14 +129,6 @@ impl StreamSocket { }) } - fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - fn set_nonblocking(&self, nonblocking: bool) { - self.is_nonblocking.store(nonblocking, Ordering::Relaxed); - } - /// Ensures that the socket state is up to date and obtains a read lock on it. /// /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. @@ -355,18 +345,6 @@ impl StreamSocket { Ok((recv_bytes, remote_endpoint.into())) } - fn recv( - &self, - writer: &mut dyn MultiWrite, - flags: SendRecvFlags, - ) -> Result<(usize, SocketAddr)> { - if self.is_nonblocking() { - self.try_recv(writer, flags) - } else { - self.wait_events(IoEvents::IN, None, || self.try_recv(writer, flags)) - } - } - fn try_send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { let state = self.read_updated_state(); @@ -395,14 +373,6 @@ impl StreamSocket { Ok(sent_bytes) } - fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { - if self.is_nonblocking() { - self.try_send(reader, flags) - } else { - self.wait_events(IoEvents::OUT, None, || self.try_send(reader, flags)) - } - } - fn check_io_events(&self) -> IoEvents { let state = self.read_updated_state(); @@ -422,49 +392,13 @@ impl Pollable for StreamSocket { } } -impl FileLike for StreamSocket { - fn read(&self, writer: &mut VmWriter) -> Result { - // TODO: Set correct flags - let flags = SendRecvFlags::empty(); - self.recv(writer, flags).map(|(len, _)| len) +impl SocketPrivate for StreamSocket { + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) } - fn write(&self, reader: &mut VmReader) -> Result { - // TODO: Set correct flags - let flags = SendRecvFlags::empty(); - self.send(reader, flags) - } - - fn status_flags(&self) -> StatusFlags { - // TODO: when we fully support O_ASYNC, return the flag - if self.is_nonblocking() { - StatusFlags::O_NONBLOCK - } else { - StatusFlags::empty() - } - } - - fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - if new_flags.contains(StatusFlags::O_NONBLOCK) { - self.set_nonblocking(true); - } else { - self.set_nonblocking(false); - } - Ok(()) - } - - fn as_socket(&self) -> Option<&dyn Socket> { - Some(self) - } - - fn metadata(&self) -> Metadata { - // This is a dummy implementation. - // TODO: Add "SockFS" and link `StreamSocket` to it. - Metadata::new_socket( - 0, - InodeMode::from_bits_truncate(0o140777), - aster_block::BLOCK_SIZE, - ) + fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } } @@ -546,11 +480,7 @@ impl Socket for StreamSocket { } fn accept(&self) -> Result<(Arc, SocketAddr)> { - if self.is_nonblocking() { - self.try_accept() - } else { - self.wait_events(IoEvents::IN, None, || self.try_accept()) - } + self.block_on(IoEvents::IN, || self.try_accept()) } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -622,7 +552,7 @@ impl Socket for StreamSocket { warn!("sending control message is not supported"); } - self.send(reader, flags) + self.block_on(IoEvents::OUT, || self.try_send(reader, flags)) } fn recvmsg( @@ -635,7 +565,7 @@ impl Socket for StreamSocket { warn!("unsupported flags: {:?}", flags); } - let (received_bytes, _) = self.recv(writer, flags)?; + let (received_bytes, _) = self.block_on(IoEvents::IN, || self.try_recv(writer, flags))?; // TODO: Receive control message diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index c6ab68fb9..adbd36dd4 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -6,7 +6,10 @@ pub use self::util::{ socket_addr::SocketAddr, MessageHeader, }; use crate::{ - fs::file_handle::FileLike, + fs::{ + file_handle::FileLike, + utils::{InodeMode, Metadata, StatusFlags}, + }, prelude::*, util::{MultiRead, MultiWrite}, }; @@ -17,8 +20,43 @@ pub mod unix; mod util; pub mod vsock; +mod private { + use crate::{events::IoEvents, prelude::*, process::signal::Pollable}; + + /// Common methods for sockets, but private to the network module. + /// + /// These are implementation details of sockets, so shouldn't be accessed outside the network + /// module. Therefore, the whole trait is sealed. + pub trait SocketPrivate: Pollable { + /// Returns whether the socket is in non-blocking mode. + fn is_nonblocking(&self) -> bool; + + /// Sets whether the socket is in non-blocking mode. + fn set_nonblocking(&self, nonblocking: bool); + + /// Blocks until some events occur to complete I/O operations. + /// + /// If the socket is in non-blocking mode and the I/O operations cannot be completed + /// immediately, this method will fail with [`EAGAIN`] instead of blocking. + /// + /// [`EAGAIN`]: crate::error::Errno::EAGAIN + #[track_caller] + fn block_on(&self, events: IoEvents, mut try_op: F) -> Result + where + Self: Sized, + F: FnMut() -> Result, + { + if self.is_nonblocking() { + try_op() + } else { + self.wait_events(events, None, try_op) + } + } + } +} + /// Operations defined on a socket. -pub trait Socket: FileLike + Send + Sync { +pub trait Socket: private::SocketPrivate + Send + Sync { /// Assigns the specified address to the socket. fn bind(&self, _socket_addr: SocketAddr) -> Result<()> { return_errno_with_message!(Errno::EOPNOTSUPP, "bind() is not supported"); @@ -85,3 +123,56 @@ pub trait Socket: FileLike + Send + Sync { flags: SendRecvFlags, ) -> Result<(usize, MessageHeader)>; } + +impl FileLike for T { + fn read(&self, writer: &mut VmWriter) -> Result { + // TODO: Set correct flags + self.recvmsg(writer, SendRecvFlags::empty()) + .map(|(len, _)| len) + } + + fn write(&self, reader: &mut VmReader) -> Result { + // TODO: Set correct flags + self.sendmsg( + reader, + MessageHeader { + addr: None, + control_message: None, + }, + SendRecvFlags::empty(), + ) + } + + fn status_flags(&self) -> StatusFlags { + // TODO: Support other flags (e.g., `O_ASYNC`) + if self.is_nonblocking() { + StatusFlags::O_NONBLOCK + } else { + StatusFlags::empty() + } + } + + fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { + // TODO: Support other flags (e.g., `O_ASYNC`) + if new_flags.contains(StatusFlags::O_NONBLOCK) { + self.set_nonblocking(true); + } else { + self.set_nonblocking(false); + } + Ok(()) + } + + fn as_socket(&self) -> Option<&dyn Socket> { + Some(self) + } + + fn metadata(&self) -> Metadata { + // This is a dummy implementation. + // TODO: Add "SockFS" and link `Socket` to it. + Metadata::new_socket( + 0, + InodeMode::from_bits_truncate(0o140777), + aster_block::BLOCK_SIZE, + ) + } +} diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index 05b19a9f9..b0bb3733f 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -11,11 +11,9 @@ use super::{ }; use crate::{ events::IoEvents, - fs::{ - file_handle::FileLike, - utils::{InodeMode, Metadata, StatusFlags}, - }, + fs::file_handle::FileLike, net::socket::{ + private::SocketPrivate, unix::UnixSocketAddr, util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader}, SockShutdownCmd, Socket, @@ -65,14 +63,6 @@ impl UnixStreamSocket { ) } - fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { - if self.is_nonblocking() { - self.try_send(reader, flags) - } else { - self.wait_events(IoEvents::OUT, None, || self.try_send(reader, flags)) - } - } - fn try_send(&self, buf: &mut dyn MultiRead, _flags: SendRecvFlags) -> Result { match self.state.read().as_ref() { State::Connected(connected) => connected.try_write(buf), @@ -82,14 +72,6 @@ impl UnixStreamSocket { } } - fn recv(&self, writer: &mut dyn MultiWrite, flags: SendRecvFlags) -> Result { - if self.is_nonblocking() { - self.try_recv(writer, flags) - } else { - self.wait_events(IoEvents::IN, None, || self.try_recv(writer, flags)) - } - } - fn try_recv(&self, buf: &mut dyn MultiWrite, _flags: SendRecvFlags) -> Result { match self.state.read().as_ref() { State::Connected(connected) => connected.try_read(buf), @@ -142,14 +124,6 @@ impl UnixStreamSocket { } } } - - fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - fn set_nonblocking(&self, nonblocking: bool) { - self.is_nonblocking.store(nonblocking, Ordering::Relaxed); - } } impl Pollable for UnixStreamSocket { @@ -163,45 +137,13 @@ impl Pollable for UnixStreamSocket { } } -impl FileLike for UnixStreamSocket { - fn as_socket(&self) -> Option<&dyn Socket> { - Some(self) +impl SocketPrivate for UnixStreamSocket { + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) } - fn read(&self, writer: &mut VmWriter) -> Result { - // TODO: Set correct flags - let flags = SendRecvFlags::empty(); - let read_len = self.recv(writer, flags)?; - Ok(read_len) - } - - fn write(&self, reader: &mut VmReader) -> Result { - // TODO: Set correct flags - let flags = SendRecvFlags::empty(); - self.send(reader, flags) - } - - fn status_flags(&self) -> StatusFlags { - if self.is_nonblocking() { - StatusFlags::O_NONBLOCK - } else { - StatusFlags::empty() - } - } - - fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - self.set_nonblocking(new_flags.contains(StatusFlags::O_NONBLOCK)); - Ok(()) - } - - fn metadata(&self) -> Metadata { - // This is a dummy implementation. - // TODO: Add "SockFS" and link `UnixStreamSocket` to it. - Metadata::new_socket( - 0, - InodeMode::from_bits_truncate(0o140777), - aster_block::BLOCK_SIZE, - ) + fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } } @@ -270,11 +212,7 @@ impl Socket for UnixStreamSocket { } fn accept(&self) -> Result<(Arc, SocketAddr)> { - if self.is_nonblocking() { - self.try_accept() - } else { - self.wait_events(IoEvents::IN, None, || self.try_accept()) - } + self.block_on(IoEvents::IN, || self.try_accept()) } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -328,7 +266,7 @@ impl Socket for UnixStreamSocket { warn!("sending control message is not supported"); } - self.send(reader, flags) + self.block_on(IoEvents::OUT, || self.try_send(reader, flags)) } fn recvmsg( @@ -341,7 +279,7 @@ impl Socket for UnixStreamSocket { warn!("unsupported flags: {:?}", flags); } - let received_bytes = self.recv(writer, flags)?; + let received_bytes = self.block_on(IoEvents::IN, || self.try_recv(writer, flags))?; // TODO: Receive control message diff --git a/kernel/src/net/socket/vsock/stream/socket.rs b/kernel/src/net/socket/vsock/stream/socket.rs index c8fdf9c42..831b45fe8 100644 --- a/kernel/src/net/socket/vsock/stream/socket.rs +++ b/kernel/src/net/socket/vsock/stream/socket.rs @@ -5,11 +5,9 @@ use core::sync::atomic::{AtomicBool, Ordering}; use super::{connected::Connected, connecting::Connecting, init::Init, listen::Listen}; use crate::{ events::IoEvents, - fs::{ - file_handle::FileLike, - utils::{InodeMode, Metadata, StatusFlags}, - }, + fs::file_handle::FileLike, net::socket::{ + private::SocketPrivate, vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, MessageHeader, SendRecvFlags, SockShutdownCmd, Socket, SocketAddr, }, @@ -45,14 +43,6 @@ impl VsockStreamSocket { } } - fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - fn set_nonblocking(&self, nonblocking: bool) { - self.is_nonblocking.store(nonblocking, Ordering::Relaxed); - } - fn try_accept(&self) -> Result<(Arc, SocketAddr)> { let listen = match &*self.status.read() { Status::Listen(listen) => listen.clone(), @@ -114,18 +104,6 @@ impl VsockStreamSocket { } Ok((read_size, peer_addr)) } - - fn recv( - &self, - writer: &mut dyn MultiWrite, - flags: SendRecvFlags, - ) -> Result<(usize, SocketAddr)> { - if self.is_nonblocking() { - self.try_recv(writer, flags) - } else { - self.wait_events(IoEvents::IN, None, || self.try_recv(writer, flags)) - } - } } impl Pollable for VsockStreamSocket { @@ -138,49 +116,13 @@ impl Pollable for VsockStreamSocket { } } -impl FileLike for VsockStreamSocket { - fn as_socket(&self) -> Option<&dyn Socket> { - Some(self) +impl SocketPrivate for VsockStreamSocket { + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) } - fn read(&self, writer: &mut VmWriter) -> Result { - // TODO: Set correct flags - let read_len = self - .recv(writer, SendRecvFlags::empty()) - .map(|(len, _)| len)?; - Ok(read_len) - } - - fn write(&self, reader: &mut VmReader) -> Result { - // TODO: Set correct flags - self.send(reader, SendRecvFlags::empty()) - } - - fn status_flags(&self) -> StatusFlags { - if self.is_nonblocking() { - StatusFlags::O_NONBLOCK - } else { - StatusFlags::empty() - } - } - - fn set_status_flags(&self, new_flags: StatusFlags) -> Result<()> { - if new_flags.contains(StatusFlags::O_NONBLOCK) { - self.set_nonblocking(true); - } else { - self.set_nonblocking(false); - } - Ok(()) - } - - fn metadata(&self) -> Metadata { - // This is a dummy implementation. - // TODO: Add "SockFS" and link `VsockStreamSocket` to it. - Metadata::new_socket( - 0, - InodeMode::from_bits_truncate(0o140777), - aster_block::BLOCK_SIZE, - ) + fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } } @@ -280,11 +222,7 @@ impl Socket for VsockStreamSocket { } fn accept(&self) -> Result<(Arc, SocketAddr)> { - if self.is_nonblocking() { - self.try_accept() - } else { - self.wait_events(IoEvents::IN, None, || self.try_accept()) - } + self.block_on(IoEvents::IN, || self.try_accept()) } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -329,7 +267,7 @@ impl Socket for VsockStreamSocket { warn!("unsupported flags: {:?}", flags); } - let (received_bytes, _) = self.recv(writer, flags)?; + let (received_bytes, _) = self.block_on(IoEvents::IN, || self.try_recv(writer, flags))?; // TODO: Receive control message