From 985813c7f974ac2022b379902c168e785c76464c Mon Sep 17 00:00:00 2001 From: Jianfeng Jiang Date: Wed, 11 Sep 2024 02:36:24 +0000 Subject: [PATCH] Use IoVec-based reader/writer to refactor network APIs --- kernel/src/device/pty/pty.rs | 5 +- kernel/src/fs/utils/channel.rs | 25 +- kernel/src/net/socket/ip/datagram/bound.rs | 65 +++-- kernel/src/net/socket/ip/datagram/mod.rs | 71 ++--- kernel/src/net/socket/ip/stream/connected.rs | 37 ++- kernel/src/net/socket/ip/stream/mod.rs | 62 ++--- kernel/src/net/socket/mod.rs | 14 +- .../src/net/socket/unix/stream/connected.rs | 11 +- kernel/src/net/socket/unix/stream/socket.rs | 52 ++-- kernel/src/net/socket/util/message_header.rs | 65 +---- kernel/src/net/socket/util/mod.rs | 3 - kernel/src/net/socket/vsock/common.rs | 11 +- .../src/net/socket/vsock/stream/connected.rs | 13 +- kernel/src/net/socket/vsock/stream/socket.rs | 54 ++-- kernel/src/syscall/preadv.rs | 43 ++- kernel/src/syscall/pwritev.rs | 38 +-- kernel/src/syscall/recvfrom.rs | 15 +- kernel/src/syscall/recvmsg.rs | 4 +- kernel/src/syscall/sendmsg.rs | 8 +- kernel/src/syscall/sendto.rs | 14 +- kernel/src/util/iovec.rs | 256 ++++++++++++------ kernel/src/util/mod.rs | 2 +- kernel/src/util/net/socket.rs | 10 +- kernel/src/util/ring_buffer.rs | 59 +--- 24 files changed, 484 insertions(+), 453 deletions(-) diff --git a/kernel/src/device/pty/pty.rs b/kernel/src/device/pty/pty.rs index 674a0f683..bb1aad257 100644 --- a/kernel/src/device/pty/pty.rs +++ b/kernel/src/device/pty/pty.rs @@ -124,10 +124,7 @@ impl FileIo for PtyMaster { continue; } - let read_len = match input.read_fallible(writer) { - Ok(len) => len, - Err((_, len)) => len, - }; + let read_len = input.read_fallible(writer)?; self.update_state(&input); return Ok(read_len); } diff --git a/kernel/src/fs/utils/channel.rs b/kernel/src/fs/utils/channel.rs index f2f537fec..7b11d3833 100644 --- a/kernel/src/fs/utils/channel.rs +++ b/kernel/src/fs/utils/channel.rs @@ -9,7 +9,10 @@ use crate::{ events::{IoEvents, Observer}, prelude::*, process::signal::{Pollee, Poller}, - util::ring_buffer::{RbConsumer, RbProducer, RingBuffer}, + util::{ + ring_buffer::{RbConsumer, RbProducer, RingBuffer}, + MultiRead, MultiWrite, + }, }; /// A unidirectional communication channel, intended to implement IPC, e.g., pipe, @@ -142,8 +145,8 @@ impl Producer { /// - Returns `Ok(_)` with the number of bytes written if successful. /// - Returns `Err(EPIPE)` if the channel is shut down. /// - Returns `Err(EAGAIN)` if the channel is full. - pub fn try_write(&self, reader: &mut VmReader) -> Result { - if reader.remain() == 0 { + pub fn try_write(&self, reader: &mut dyn MultiRead) -> Result { + if reader.is_empty() { // Even after shutdown, writing an empty buffer is still fine. return Ok(0); } @@ -230,8 +233,8 @@ impl Consumer { /// - Returns `Ok(_)` with the number of bytes read if successful. /// - Returns `Ok(0)` if the channel is shut down and there is no data left. /// - Returns `Err(EAGAIN)` if the channel is empty. - pub fn try_read(&self, writer: &mut VmWriter) -> Result { - if writer.avail() == 0 { + pub fn try_read(&self, writer: &mut dyn MultiWrite) -> Result { + if writer.is_empty() { return Ok(0); } @@ -296,25 +299,25 @@ impl Fifo { impl Fifo { #[require(R > Read)] - pub fn read(&self, writer: &mut VmWriter) -> usize { + pub fn read(&self, writer: &mut dyn MultiWrite) -> Result { let mut rb = self.common.consumer.rb(); match rb.read_fallible(writer) { Ok(len) => len, - Err((e, len)) => { + Err(e) => { error!("memory read failed on the ring buffer, error: {e:?}"); - len + 0 } } } #[require(R > Write)] - pub fn write(&self, reader: &mut VmReader) -> usize { + pub fn write(&self, reader: &mut dyn MultiRead) -> Result { let mut rb = self.common.producer.rb(); match rb.write_fallible(reader) { Ok(len) => len, - Err((e, len)) => { + Err(e) => { error!("memory write failed on the ring buffer, error: {e:?}"); - len + 0 } } } diff --git a/kernel/src/net/socket/ip/datagram/bound.rs b/kernel/src/net/socket/ip/datagram/bound.rs index 32592a4f6..ddd9da673 100644 --- a/kernel/src/net/socket/ip/datagram/bound.rs +++ b/kernel/src/net/socket/ip/datagram/bound.rs @@ -11,6 +11,7 @@ use crate::{ net::{iface::AnyBoundSocket, socket::util::send_recv_flags::SendRecvFlags}, prelude::*, process::signal::Pollee, + util::{MultiRead, MultiWrite}, }; pub struct BoundDatagram { @@ -38,43 +39,63 @@ impl BoundDatagram { self.remote_endpoint = Some(*endpoint) } - pub fn try_recv(&self, buf: &mut [u8], _flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> { - let result = self - .bound_socket - .raw_with(|socket: &mut RawUdpSocket| socket.recv_slice(buf)); + pub fn try_recv( + &self, + writer: &mut dyn MultiWrite, + _flags: SendRecvFlags, + ) -> Result<(usize, IpEndpoint)> { + let result = self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { + socket.recv().map(|(packet, udp_metadata)| { + let copied_res = writer.write(&mut VmReader::from(packet)); + let endpoint = udp_metadata.endpoint; + (copied_res, endpoint) + }) + }); + match result { - Ok((recv_len, udp_metadata)) => Ok((recv_len, udp_metadata.endpoint)), + Ok((Ok(res), endpoint)) => Ok((res, endpoint)), + Ok((Err(e), _)) => Err(e), Err(RecvError::Exhausted) => { return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty") } Err(RecvError::Truncated) => { - todo!(); + unreachable!("`Socket::recv` should never fail with `RecvError::Truncated`") } } } pub fn try_send( &self, - buf: &[u8], + reader: &mut dyn MultiRead, remote: &IpEndpoint, _flags: SendRecvFlags, ) -> Result { - let result = self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { - if socket.payload_send_capacity() < buf.len() { - return None; + let reader_len = reader.sum_lens(); + + self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { + if socket.payload_send_capacity() < reader_len { + return_errno_with_message!(Errno::EMSGSIZE, "the message is too large"); } - Some(socket.send_slice(buf, *remote)) - }); - match result { - Some(Ok(())) => Ok(buf.len()), - Some(Err(SendError::BufferFull)) => { - return_errno_with_message!(Errno::EAGAIN, "the send buffer is full") - } - Some(Err(SendError::Unaddressable)) => { - return_errno_with_message!(Errno::EINVAL, "the destination address is invalid") - } - None => return_errno_with_message!(Errno::EMSGSIZE, "the message is too large"), - } + + let socket_buffer = match socket.send(reader_len, *remote) { + Ok(socket_buffer) => socket_buffer, + Err(SendError::BufferFull) => { + return_errno_with_message!(Errno::EAGAIN, "the send buffer is full") + } + Err(SendError::Unaddressable) => { + return_errno_with_message!(Errno::EINVAL, "the destination address is invalid") + } + }; + + // FIXME: If copy failed, we should not send any packet. + // But current smoltcp API seems not to support this behavior. + reader + .read(&mut VmWriter::from(socket_buffer)) + .map_err(|e| { + warn!("unexpected UDP packet will be sent"); + e + }) + }) } pub(super) fn init_pollee(&self, pollee: &Pollee) { diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 0c97d9381..61bc9dc33 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -16,7 +16,6 @@ use crate::{ socket::{ options::{Error as SocketError, SocketOption}, util::{ - copy_message_from_user, copy_message_to_user, create_message_buffer, options::SocketOptionSet, send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader, }, @@ -25,7 +24,7 @@ use crate::{ }, prelude::*, process::signal::{Pollable, Pollee, Poller}, - util::IoVec, + util::{MultiRead, MultiWrite}, }; mod bound; @@ -144,19 +143,24 @@ impl DatagramSocket { }) } - fn try_recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn try_recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { let inner = self.inner.read(); let Inner::Bound(bound_datagram) = inner.as_ref() else { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); }; - let received = bound_datagram - .try_recv(buf, flags) - .map(|(recv_bytes, remote_endpoint)| { - bound_datagram.update_io_events(&self.pollee); - (recv_bytes, remote_endpoint.into()) - }); + let received = + bound_datagram + .try_recv(writer, flags) + .map(|(recv_bytes, remote_endpoint)| { + bound_datagram.update_io_events(&self.pollee); + (recv_bytes, remote_endpoint.into()) + }); drop(inner); poll_ifaces(); @@ -164,15 +168,24 @@ impl DatagramSocket { received } - fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { if self.is_nonblocking() { - self.try_recv(buf, flags) + self.try_recv(writer, flags) } else { - self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + self.wait_events(IoEvents::IN, || self.try_recv(writer, flags)) } } - fn try_send(&self, buf: &[u8], remote: &IpEndpoint, flags: SendRecvFlags) -> Result { + fn try_send( + &self, + reader: &mut dyn MultiRead, + remote: &IpEndpoint, + flags: SendRecvFlags, + ) -> Result { let inner = self.inner.read(); let Inner::Bound(bound_datagram) = inner.as_ref() else { @@ -180,7 +193,7 @@ impl DatagramSocket { }; let sent_bytes = bound_datagram - .try_send(buf, remote, flags) + .try_send(reader, remote, flags) .map(|sent_bytes| { bound_datagram.update_io_events(&self.pollee); sent_bytes @@ -209,16 +222,13 @@ impl Pollable for DatagramSocket { impl FileLike for DatagramSocket { fn read(&self, writer: &mut VmWriter) -> Result { - let mut buf = vec![0u8; writer.avail()]; // TODO: set correct flags let flags = SendRecvFlags::empty(); - let read_len = self.recv(&mut buf, flags).map(|(len, _)| len)?; - writer.write_fallible(&mut buf.as_slice().into())?; + let read_len = self.recv(writer, flags).map(|(len, _)| len)?; Ok(read_len) } fn write(&self, reader: &mut VmReader) -> Result { - let buf = reader.collect()?; let remote = self.remote_endpoint().ok_or_else(|| { Error::with_message( Errno::EDESTADDRREQ, @@ -230,7 +240,7 @@ impl FileLike for DatagramSocket { let flags = SendRecvFlags::empty(); // TODO: Block if send buffer is full - self.try_send(&buf, &remote, flags) + self.try_send(reader, &remote, flags) } fn as_socket(self: Arc) -> Option> { @@ -320,7 +330,7 @@ impl Socket for DatagramSocket { fn sendmsg( &self, - io_vecs: &[IoVec], + reader: &mut dyn MultiRead, message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { @@ -351,30 +361,25 @@ impl Socket for DatagramSocket { warn!("sending control message is not supported"); } - let buf = copy_message_from_user(io_vecs); - // TODO: Block if the send buffer is full - self.try_send(&buf, &remote_endpoint, flags) + self.try_send(reader, &remote_endpoint, flags) } - fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + fn recvmsg( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, MessageHeader)> { // TODO: Deal with flags debug_assert!(flags.is_all_supported()); - let mut buf = create_message_buffer(io_vecs); - - let (received_bytes, peer_addr) = self.recv(&mut buf, flags)?; - - let copied_bytes = { - let message = &buf[..received_bytes]; - copy_message_to_user(io_vecs, message) - }; + let (received_bytes, peer_addr) = self.recv(writer, flags)?; // TODO: Receive control message let message_header = MessageHeader::new(Some(peer_addr), None); - Ok((copied_bytes, message_header)) + Ok((received_bytes, message_header)) } fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index e0f4fd90b..c8ae4b40b 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -16,6 +16,7 @@ use crate::{ }, prelude::*, process::signal::Pollee, + util::{MultiRead, MultiWrite}, }; pub struct ConnectedStream { @@ -55,14 +56,20 @@ impl ConnectedStream { Ok(()) } - pub fn try_recv(&self, buf: &mut [u8], _flags: SendRecvFlags) -> Result { - let result = self - .bound_socket - .raw_with(|socket: &mut RawTcpSocket| socket.recv_slice(buf)); + pub fn try_recv(&self, writer: &mut dyn MultiWrite, _flags: SendRecvFlags) -> Result { + let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { + socket.recv( + |socket_buffer| match writer.write(&mut VmReader::from(&*socket_buffer)) { + Ok(len) => (len, Ok(len)), + Err(e) => (0, Err(e)), + }, + ) + }); match result { - Ok(0) => return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"), - Ok(recv_bytes) => Ok(recv_bytes), + Ok(Ok(0)) => return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"), + Ok(Ok(recv_bytes)) => Ok(recv_bytes), + Ok(Err(e)) => Err(e), Err(RecvError::Finished) => Ok(0), Err(RecvError::InvalidState) => { return_errno_with_message!(Errno::ECONNRESET, "the connection is reset") @@ -70,14 +77,20 @@ impl ConnectedStream { } } - pub fn try_send(&self, buf: &[u8], _flags: SendRecvFlags) -> Result { - let result = self - .bound_socket - .raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf)); + pub fn try_send(&self, reader: &mut dyn MultiRead, _flags: SendRecvFlags) -> Result { + let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { + socket.send( + |socket_buffer| match reader.read(&mut VmWriter::from(socket_buffer)) { + Ok(len) => (len, Ok(len)), + Err(e) => (0, Err(e)), + }, + ) + }); match result { - Ok(0) => return_errno_with_message!(Errno::EAGAIN, "the send buffer is full"), - Ok(sent_bytes) => Ok(sent_bytes), + Ok(Ok(0)) => return_errno_with_message!(Errno::EAGAIN, "the send buffer is full"), + Ok(Ok(sent_bytes)) => Ok(sent_bytes), + Ok(Err(e)) => Err(e), Err(SendError::InvalidState) => { // FIXME: `EPIPE` is another possibility, which means that the socket is shut down // for writing. In that case, we should also trigger a `SIGPIPE` if `MSG_NOSIGNAL` diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 921233f61..812a506af 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -21,7 +21,6 @@ use crate::{ socket::{ options::{Error as SocketError, SocketOption}, util::{ - copy_message_from_user, copy_message_to_user, create_message_buffer, options::SocketOptionSet, send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd, socket_addr::SocketAddr, MessageHeader, }, @@ -30,7 +29,7 @@ use crate::{ }, prelude::*, process::signal::{Pollable, Pollee, Poller}, - util::IoVec, + util::{MultiRead, MultiWrite}, }; mod connected; @@ -245,7 +244,11 @@ impl StreamSocket { accepted } - fn try_recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn try_recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { let state = self.state.read(); let connected_stream = match state.as_ref() { @@ -258,7 +261,7 @@ impl StreamSocket { } }; - let received = connected_stream.try_recv(buf, flags).map(|recv_bytes| { + let received = connected_stream.try_recv(writer, flags).map(|recv_bytes| { connected_stream.update_io_events(&self.pollee); let remote_endpoint = connected_stream.remote_endpoint(); @@ -271,15 +274,19 @@ impl StreamSocket { received } - fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { if self.is_nonblocking() { - self.try_recv(buf, flags) + self.try_recv(writer, flags) } else { - self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + self.wait_events(IoEvents::IN, || self.try_recv(writer, flags)) } } - fn try_send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + fn try_send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { let state = self.state.read(); let connected_stream = match state.as_ref() { @@ -295,7 +302,7 @@ impl StreamSocket { } }; - let sent_bytes = connected_stream.try_send(buf, flags).map(|sent_bytes| { + let sent_bytes = connected_stream.try_send(reader, flags).map(|sent_bytes| { connected_stream.update_io_events(&self.pollee); sent_bytes }); @@ -306,11 +313,11 @@ impl StreamSocket { sent_bytes } - fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { if self.is_nonblocking() { - self.try_send(buf, flags) + self.try_send(reader, flags) } else { - self.wait_events(IoEvents::OUT, || self.try_send(buf, flags)) + self.wait_events(IoEvents::OUT, || self.try_send(reader, flags)) } } @@ -340,19 +347,15 @@ impl Pollable for StreamSocket { impl FileLike for StreamSocket { fn read(&self, writer: &mut VmWriter) -> Result { - let mut buf = vec![0u8; writer.avail()]; // TODO: Set correct flags let flags = SendRecvFlags::empty(); - let read_len = self.recv(&mut buf, flags).map(|(len, _)| len)?; - writer.write_fallible(&mut buf.as_slice().into())?; - Ok(read_len) + self.recv(writer, flags).map(|(len, _)| len) } fn write(&self, reader: &mut VmReader) -> Result { - let buf = reader.collect()?; // TODO: Set correct flags let flags = SendRecvFlags::empty(); - self.send(&buf, flags) + self.send(reader, flags) } fn status_flags(&self) -> StatusFlags { @@ -509,7 +512,7 @@ impl Socket for StreamSocket { fn sendmsg( &self, - io_vecs: &[IoVec], + reader: &mut dyn MultiRead, message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { @@ -529,23 +532,18 @@ impl Socket for StreamSocket { warn!("sending control message is not supported"); } - let buf = copy_message_from_user(io_vecs); - - self.send(&buf, flags) + self.send(reader, flags) } - fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + fn recvmsg( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, MessageHeader)> { // TODO: Deal with flags debug_assert!(flags.is_all_supported()); - let mut buf = create_message_buffer(io_vecs); - - let (received_bytes, _) = self.recv(&mut buf, flags)?; - - let copied_bytes = { - let message = &buf[..received_bytes]; - copy_message_to_user(io_vecs, message) - }; + let (received_bytes, _) = self.recv(writer, flags)?; // TODO: Receive control message @@ -553,7 +551,7 @@ impl Socket for StreamSocket { // peer address is ignored for connected socket. let message_header = MessageHeader::new(None, None); - Ok((copied_bytes, message_header)) + Ok((received_bytes, message_header)) } fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index d3f23bdc7..1418ed302 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -5,7 +5,11 @@ pub use self::util::{ options::LingerOption, send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd, socket_addr::SocketAddr, MessageHeader, }; -use crate::{fs::file_handle::FileLike, prelude::*, util::IoVec}; +use crate::{ + fs::file_handle::FileLike, + prelude::*, + util::{MultiRead, MultiWrite}, +}; pub mod ip; pub mod options; @@ -64,7 +68,7 @@ pub trait Socket: FileLike + Send + Sync { /// Sends a message on a socket. fn sendmsg( &self, - io_vecs: &[IoVec], + reader: &mut dyn MultiRead, message_header: MessageHeader, flags: SendRecvFlags, ) -> Result; @@ -74,5 +78,9 @@ pub trait Socket: FileLike + Send + Sync { /// If successful, the `io_vecs` buffer will be filled with the received content. /// This method returns the length of the received message, /// and the message header. - fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)>; + fn recvmsg( + &self, + writers: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, MessageHeader)>; } diff --git a/kernel/src/net/socket/unix/stream/connected.rs b/kernel/src/net/socket/unix/stream/connected.rs index 90dbf3a92..68be25f70 100644 --- a/kernel/src/net/socket/unix/stream/connected.rs +++ b/kernel/src/net/socket/unix/stream/connected.rs @@ -13,6 +13,7 @@ use crate::{ }, prelude::*, process::signal::{Pollee, Poller}, + util::{MultiRead, MultiWrite}, }; pub(super) struct Connected { @@ -70,14 +71,12 @@ impl Connected { Ok(()) } - pub(super) fn try_read(&self, buf: &mut [u8]) -> Result { - let mut writer = VmWriter::from(buf).to_fallible(); - self.reader.try_read(&mut writer) + pub(super) fn try_read(&self, writer: &mut dyn MultiWrite) -> Result { + self.reader.try_read(writer) } - pub(super) fn try_write(&self, buf: &[u8]) -> Result { - let mut reader = VmReader::from(buf).to_fallible(); - self.writer.try_write(&mut reader) + pub(super) fn try_write(&self, reader: &mut dyn MultiRead) -> Result { + self.writer.try_write(reader) } pub(super) fn shutdown(&self, cmd: SockShutdownCmd) { diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index c47d16891..053f2f81b 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -15,15 +15,12 @@ use crate::{ fs::{file_handle::FileLike, utils::StatusFlags}, net::socket::{ unix::UnixSocketAddr, - util::{ - copy_message_from_user, copy_message_to_user, create_message_buffer, - send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader, - }, + util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader}, SockShutdownCmd, Socket, }, prelude::*, process::signal::{Pollable, Poller}, - util::IoVec, + util::{MultiRead, MultiWrite}, }; pub struct UnixStreamSocket { @@ -66,15 +63,15 @@ impl UnixStreamSocket { ) } - fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { if self.is_nonblocking() { - self.try_send(buf, flags) + self.try_send(reader, flags) } else { - self.wait_events(IoEvents::OUT, || self.try_send(buf, flags)) + self.wait_events(IoEvents::OUT, || self.try_send(reader, flags)) } } - fn try_send(&self, buf: &[u8], _flags: SendRecvFlags) -> Result { + fn try_send(&self, buf: &mut dyn MultiRead, _flags: SendRecvFlags) -> Result { match self.state.read().as_ref() { State::Connected(connected) => connected.try_write(buf), State::Init(_) | State::Listen(_) => { @@ -83,15 +80,15 @@ impl UnixStreamSocket { } } - fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result { + fn recv(&self, writer: &mut dyn MultiWrite, flags: SendRecvFlags) -> Result { if self.is_nonblocking() { - self.try_recv(buf, flags) + self.try_recv(writer, flags) } else { - self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + self.wait_events(IoEvents::IN, || self.try_recv(writer, flags)) } } - fn try_recv(&self, buf: &mut [u8], _flags: SendRecvFlags) -> Result { + fn try_recv(&self, buf: &mut dyn MultiWrite, _flags: SendRecvFlags) -> Result { match self.state.read().as_ref() { State::Connected(connected) => connected.try_read(buf), State::Init(_) | State::Listen(_) => { @@ -170,19 +167,16 @@ impl FileLike for UnixStreamSocket { } fn read(&self, writer: &mut VmWriter) -> Result { - let mut buf = vec![0u8; writer.avail()]; // TODO: Set correct flags let flags = SendRecvFlags::empty(); - let read_len = self.recv(&mut buf, flags)?; - writer.write_fallible(&mut buf.as_slice().into())?; + let read_len = self.recv(writer, flags)?; Ok(read_len) } fn write(&self, reader: &mut VmReader) -> Result { - let buf = reader.collect()?; // TODO: Set correct flags let flags = SendRecvFlags::empty(); - self.send(&buf, flags) + self.send(reader, flags) } fn status_flags(&self) -> StatusFlags { @@ -327,7 +321,7 @@ impl Socket for UnixStreamSocket { fn sendmsg( &self, - io_vecs: &[IoVec], + reader: &mut dyn MultiRead, message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { @@ -343,27 +337,23 @@ impl Socket for UnixStreamSocket { warn!("sending control message is not supported"); } - let buf = copy_message_from_user(io_vecs); - - self.send(&buf, flags) + self.send(reader, flags) } - fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + fn recvmsg( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, MessageHeader)> { // TODO: Deal with flags debug_assert!(flags.is_all_supported()); - let mut buf = create_message_buffer(io_vecs); - let received_bytes = self.recv(&mut buf, flags)?; - - let copied_bytes = { - let message = &buf[..received_bytes]; - copy_message_to_user(io_vecs, message) - }; + let received_bytes = self.recv(writer, flags)?; // TODO: Receive control message let message_header = MessageHeader::new(None, None); - Ok((copied_bytes, message_header)) + Ok((received_bytes, message_header)) } } diff --git a/kernel/src/net/socket/util/message_header.rs b/kernel/src/net/socket/util/message_header.rs index 278a088b4..a396ffd52 100644 --- a/kernel/src/net/socket/util/message_header.rs +++ b/kernel/src/net/socket/util/message_header.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use super::socket_addr::SocketAddr; -use crate::{prelude::*, util::IoVec}; +use crate::prelude::*; /// Message header used for sendmsg/recvmsg. #[derive(Debug)] @@ -30,66 +30,3 @@ impl MessageHeader { /// TODO: Implement the struct. The struct is empty now. #[derive(Debug)] pub struct ControlMessage; - -/// Copies a message from user space. -/// -/// Since udp allows sending and receiving packet of length 0, -/// The returned buffer may have length of zero. -pub fn copy_message_from_user(io_vecs: &[IoVec]) -> Box<[u8]> { - let mut buffer = create_message_buffer(io_vecs); - - let mut total_bytes = 0; - for io_vec in io_vecs { - if io_vec.is_empty() { - continue; - } - let dst = &mut buffer[total_bytes..total_bytes + io_vec.len()]; - // FIXME: short read should be allowed here - match io_vec.read_exact_from_user(dst) { - Ok(()) => total_bytes += io_vec.len(), - Err(_) => { - warn!("fails to copy message from user"); - break; - } - } - } - - buffer.truncate(total_bytes); - buffer.into_boxed_slice() -} - -/// Creates a buffer whose length -/// is equal to the total length of `io_vecs`. -pub fn create_message_buffer(io_vecs: &[IoVec]) -> Vec { - let buffer_len: usize = io_vecs.iter().map(|iovec| iovec.len()).sum(); - vec![0; buffer_len] -} - -/// Copies a message to user space. -/// -/// This method returns the actual copied length. -pub fn copy_message_to_user(io_vecs: &[IoVec], message: &[u8]) -> usize { - let mut total_bytes = 0; - - for io_vec in io_vecs { - if io_vec.is_empty() { - continue; - } - - let len = io_vec.len().min(message.len() - total_bytes); - if len == 0 { - break; - } - - let src = &message[total_bytes..total_bytes + len]; - match io_vec.write_to_user(src) { - Ok(len) => total_bytes += len, - Err(_) => { - warn!("fails to copy message to user"); - break; - } - } - } - - total_bytes -} diff --git a/kernel/src/net/socket/util/mod.rs b/kernel/src/net/socket/util/mod.rs index fa8f399bc..72694110e 100644 --- a/kernel/src/net/socket/util/mod.rs +++ b/kernel/src/net/socket/util/mod.rs @@ -7,6 +7,3 @@ pub mod shutdown_cmd; pub mod socket_addr; pub use message_header::MessageHeader; -pub(in crate::net) use message_header::{ - copy_message_from_user, copy_message_to_user, create_message_buffer, -}; diff --git a/kernel/src/net/socket/vsock/common.rs b/kernel/src/net/socket/vsock/common.rs index 9f6b007ae..8617fa5ad 100644 --- a/kernel/src/net/socket/vsock/common.rs +++ b/kernel/src/net/socket/vsock/common.rs @@ -16,7 +16,7 @@ use super::{ listen::Listen, }, }; -use crate::{events::IoEvents, prelude::*, return_errno_with_message}; +use crate::{events::IoEvents, prelude::*, return_errno_with_message, util::MultiRead}; /// Manage all active sockets pub struct VsockSpace { @@ -190,10 +190,15 @@ impl VsockSpace { } /// Send a data packet - pub fn send(&self, buffer: &[u8], info: &mut ConnectionInfo) -> Result<()> { + pub fn send(&self, reader: &mut dyn MultiRead, info: &mut ConnectionInfo) -> Result<()> { + // FIXME: Creating this buffer should be avoided + // if the underlying driver can accept reader. + let mut buffer = vec![0u8; reader.sum_lens()]; + reader.read(&mut VmWriter::from(buffer.as_mut_slice()))?; + let mut driver = self.driver.disable_irq().lock(); driver - .send(buffer, info) + .send(&buffer, info) .map_err(|_| Error::with_message(Errno::EIO, "cannot send data packet")) } diff --git a/kernel/src/net/socket/vsock/stream/connected.rs b/kernel/src/net/socket/vsock/stream/connected.rs index 8dcf14810..44c94ccc3 100644 --- a/kernel/src/net/socket/vsock/stream/connected.rs +++ b/kernel/src/net/socket/vsock/stream/connected.rs @@ -11,7 +11,7 @@ use crate::{ }, prelude::*, process::signal::{Pollee, Poller}, - util::ring_buffer::RingBuffer, + util::{ring_buffer::RingBuffer, MultiRead, MultiWrite}, }; const PER_CONNECTION_BUFFER_CAPACITY: usize = 4096; @@ -50,10 +50,9 @@ impl Connected { self.id } - pub fn try_recv(&self, buf: &mut [u8]) -> Result { + pub fn try_recv(&self, writer: &mut dyn MultiWrite) -> Result { let mut connection = self.connection.disable_irq().lock(); - let bytes_read = connection.buffer.len().min(buf.len()); - connection.buffer.pop_slice(&mut buf[..bytes_read]).unwrap(); + let bytes_read = connection.buffer.read_fallible(writer)?; connection.info.done_forwarding(bytes_read); match bytes_read { @@ -68,14 +67,14 @@ impl Connected { } } - pub fn send(&self, packet: &[u8], flags: SendRecvFlags) -> Result { + pub fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { let mut connection = self.connection.disable_irq().lock(); debug_assert!(flags.is_all_supported()); - let buf_len = packet.len(); + let buf_len = reader.sum_lens(); VSOCK_GLOBAL .get() .unwrap() - .send(packet, &mut connection.info)?; + .send(reader, &mut connection.info)?; Ok(buf_len) } diff --git a/kernel/src/net/socket/vsock/stream/socket.rs b/kernel/src/net/socket/vsock/stream/socket.rs index 65654eaaa..40b1a3487 100644 --- a/kernel/src/net/socket/vsock/stream/socket.rs +++ b/kernel/src/net/socket/vsock/stream/socket.rs @@ -9,13 +9,12 @@ use crate::{ events::IoEvents, fs::{file_handle::FileLike, utils::StatusFlags}, net::socket::{ - util::{copy_message_from_user, copy_message_to_user, create_message_buffer}, vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, MessageHeader, SendRecvFlags, SockShutdownCmd, Socket, SocketAddr, }, prelude::*, process::signal::{Pollable, Poller}, - util::IoVec, + util::{MultiRead, MultiWrite}, }; pub struct VsockStreamSocket { @@ -81,17 +80,21 @@ impl VsockStreamSocket { Ok((socket, peer_addr.into())) } - fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + fn send(&self, reader: &mut dyn MultiRead, flags: SendRecvFlags) -> Result { let inner = self.status.read(); match &*inner { - Status::Connected(connected) => connected.send(buf, flags), + Status::Connected(connected) => connected.send(reader, flags), Status::Init(_) | Status::Listen(_) => { return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); } } } - fn try_recv(&self, buf: &mut [u8], _flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn try_recv( + &self, + writer: &mut dyn MultiWrite, + _flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { let connected = match &*self.status.read() { Status::Connected(connected) => connected.clone(), Status::Init(_) | Status::Listen(_) => { @@ -99,7 +102,7 @@ impl VsockStreamSocket { } }; - let read_size = connected.try_recv(buf)?; + let read_size = connected.try_recv(writer)?; connected.update_io_events(); let peer_addr = self.peer_addr()?; @@ -113,11 +116,15 @@ impl VsockStreamSocket { Ok((read_size, peer_addr)) } - fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { if self.is_nonblocking() { - self.try_recv(buf, flags) + self.try_recv(writer, flags) } else { - self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + self.wait_events(IoEvents::IN, || self.try_recv(writer, flags)) } } } @@ -138,19 +145,16 @@ impl FileLike for VsockStreamSocket { } fn read(&self, writer: &mut VmWriter) -> Result { - let mut buf = vec![0u8; writer.avail()]; // TODO: Set correct flags let read_len = self - .recv(&mut buf, SendRecvFlags::empty()) + .recv(writer, SendRecvFlags::empty()) .map(|(len, _)| len)?; - writer.write_fallible(&mut buf.as_slice().into())?; Ok(read_len) } fn write(&self, reader: &mut VmReader) -> Result { - let buf = reader.collect()?; // TODO: Set correct flags - self.send(&buf, SendRecvFlags::empty()) + self.send(reader, SendRecvFlags::empty()) } fn status_flags(&self) -> StatusFlags { @@ -285,7 +289,7 @@ impl Socket for VsockStreamSocket { fn sendmsg( &self, - io_vecs: &[IoVec], + reader: &mut dyn MultiRead, message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { @@ -301,28 +305,24 @@ impl Socket for VsockStreamSocket { warn!("sending control message is not supported"); } - let buf = copy_message_from_user(io_vecs); - self.send(&buf, flags) + self.send(reader, flags) } - fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + fn recvmsg( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, MessageHeader)> { // TODO: Deal with flags debug_assert!(flags.is_all_supported()); - let mut buf = create_message_buffer(io_vecs); - - let (received_bytes, _) = self.recv(&mut buf, flags)?; - - let copied_bytes = { - let message = &buf[..received_bytes]; - copy_message_to_user(io_vecs, message) - }; + let (received_bytes, _) = self.recv(writer, flags)?; // TODO: Receive control message let messsge_header = MessageHeader::new(None, None); - Ok((copied_bytes, messsge_header)) + Ok((received_bytes, messsge_header)) } fn addr(&self) -> Result { diff --git a/kernel/src/syscall/preadv.rs b/kernel/src/syscall/preadv.rs index 0e37fa63f..a87cfc94f 100644 --- a/kernel/src/syscall/preadv.rs +++ b/kernel/src/syscall/preadv.rs @@ -4,7 +4,7 @@ use super::SyscallReturn; use crate::{ fs::file_table::FileDesc, prelude::*, - util::{copy_iovs_from_user, IoVec}, + util::{MultiWrite, VmWriterArray}, }; pub fn sys_readv( @@ -74,29 +74,23 @@ fn do_sys_preadv( return Ok(0); } - // Calculate the total buffer length and check for overflow - let total_len = io_vec_count - .checked_mul(core::mem::size_of::()) - .and_then(|val| val.checked_add(offset as usize)); - if total_len.is_none() { - return_errno_with_message!(Errno::EINVAL, "offset + io_vec_count overflow"); - } - let mut total_len: usize = 0; let mut cur_offset = offset as usize; - let io_vecs = copy_iovs_from_user(io_vec_ptr, io_vec_count)?; - for io_vec in io_vecs.as_ref() { - if io_vec.is_empty() { + let mut writer_array = VmWriterArray::from_user_io_vecs(ctx, io_vec_ptr, io_vec_count)?; + for writer in writer_array.writers_mut() { + if !writer.has_avail() { continue; } - if total_len.checked_add(io_vec.len()).is_none() + + let writer_len = writer.sum_lens(); + if total_len.checked_add(writer_len).is_none() || total_len - .checked_add(io_vec.len()) + .checked_add(writer_len) .and_then(|sum| sum.checked_add(cur_offset)) .is_none() || total_len - .checked_add(io_vec.len()) + .checked_add(writer_len) .and_then(|sum| sum.checked_add(cur_offset)) .map(|sum| sum > isize::MAX as usize) .unwrap_or(false) @@ -104,19 +98,16 @@ fn do_sys_preadv( return_errno_with_message!(Errno::EINVAL, "Total length overflow"); } - let mut buffer = vec![0u8; io_vec.len()]; - // TODO: According to the man page // at , // readv must be atomic, // but the current implementation does not ensure atomicity. // A suitable fix would be to add a `readv` method for the `FileLike` trait, // allowing each subsystem to implement atomicity. - let read_len = file.read_bytes_at(cur_offset, &mut buffer)?; - io_vec.write_exact_to_user(&buffer)?; + let read_len = file.read_at(cur_offset, writer)?; total_len += read_len; cur_offset += read_len; - if read_len == 0 || read_len < buffer.len() { + if read_len == 0 || writer.has_avail() { // End of file reached or no more data to read break; } @@ -147,23 +138,21 @@ fn do_sys_readv( let mut total_len = 0; - let io_vecs = copy_iovs_from_user(io_vec_ptr, io_vec_count)?; - for io_vec in io_vecs.as_ref() { - if io_vec.is_empty() { + let mut writer_array = VmWriterArray::from_user_io_vecs(ctx, io_vec_ptr, io_vec_count)?; + for writer in writer_array.writers_mut() { + if !writer.has_avail() { continue; } - let mut buffer = vec![0u8; io_vec.len()]; // TODO: According to the man page // at , // readv must be atomic, // but the current implementation does not ensure atomicity. // A suitable fix would be to add a `readv` method for the `FileLike` trait, // allowing each subsystem to implement atomicity. - let read_len = file.read_bytes(&mut buffer)?; - io_vec.write_exact_to_user(&buffer)?; + let read_len = file.read(writer)?; total_len += read_len; - if read_len == 0 || read_len < buffer.len() { + if read_len == 0 || writer.has_avail() { // End of file reached or no more data to read break; } diff --git a/kernel/src/syscall/pwritev.rs b/kernel/src/syscall/pwritev.rs index 79c93ae0b..939c048cc 100644 --- a/kernel/src/syscall/pwritev.rs +++ b/kernel/src/syscall/pwritev.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use super::SyscallReturn; -use crate::{fs::file_table::FileDesc, prelude::*, util::copy_iovs_from_user}; +use crate::{fs::file_table::FileDesc, prelude::*, util::VmReaderArray}; pub fn sys_writev( fd: FileDesc, @@ -72,18 +72,20 @@ fn do_sys_pwritev( let mut total_len: usize = 0; let mut cur_offset = offset as usize; - let io_vecs = copy_iovs_from_user(io_vec_ptr, io_vec_count)?; - for io_vec in io_vecs.as_ref() { - if io_vec.is_empty() { + let mut reader_array = VmReaderArray::from_user_io_vecs(ctx, io_vec_ptr, io_vec_count)?; + for reader in reader_array.readers_mut() { + if !reader.has_remain() { continue; } - if total_len.checked_add(io_vec.len()).is_none() + + let reader_len = reader.remain(); + if total_len.checked_add(reader_len).is_none() || total_len - .checked_add(io_vec.len()) + .checked_add(reader_len) .and_then(|sum| sum.checked_add(cur_offset)) .is_none() || total_len - .checked_add(io_vec.len()) + .checked_add(reader_len) .and_then(|sum| sum.checked_add(cur_offset)) .map(|sum| sum > isize::MAX as usize) .unwrap_or(false) @@ -91,19 +93,13 @@ fn do_sys_pwritev( return_errno_with_message!(Errno::EINVAL, "Total length overflow"); } - let buffer = { - let mut buffer = vec![0u8; io_vec.len()]; - io_vec.read_exact_from_user(&mut buffer)?; - buffer - }; - // TODO: According to the man page // at , // writev must be atomic, // but the current implementation does not ensure atomicity. // A suitable fix would be to add a `writev` method for the `FileLike` trait, // allowing each subsystem to implement atomicity. - let write_len = file.write_bytes_at(cur_offset, &buffer)?; + let write_len = file.write_at(cur_offset, reader)?; total_len += write_len; cur_offset += write_len; } @@ -126,25 +122,19 @@ fn do_sys_writev( }; let mut total_len = 0; - let io_vecs = copy_iovs_from_user(io_vec_ptr, io_vec_count)?; - for io_vec in io_vecs.as_ref() { - if io_vec.is_empty() { + let mut reader_array = VmReaderArray::from_user_io_vecs(ctx, io_vec_ptr, io_vec_count)?; + for reader in reader_array.readers_mut() { + if !reader.has_remain() { continue; } - let buffer = { - let mut buffer = vec![0u8; io_vec.len()]; - io_vec.read_exact_from_user(&mut buffer)?; - buffer - }; - // TODO: According to the man page // at , // writev must be atomic, // but the current implementation does not ensure atomicity. // A suitable fix would be to add a `writev` method for the `FileLike` trait, // allowing each subsystem to implement atomicity. - let write_len = file.write_bytes(&buffer)?; + let write_len = file.write(reader)?; total_len += write_len; } Ok(total_len) diff --git a/kernel/src/syscall/recvfrom.rs b/kernel/src/syscall/recvfrom.rs index 6bff30098..8b1bb0022 100644 --- a/kernel/src/syscall/recvfrom.rs +++ b/kernel/src/syscall/recvfrom.rs @@ -5,10 +5,7 @@ use crate::{ fs::file_table::FileDesc, net::socket::SendRecvFlags, prelude::*, - util::{ - net::{get_socket_from_fd, write_socket_addr_to_user}, - IoVec, - }, + util::net::{get_socket_from_fd, write_socket_addr_to_user}, }; pub fn sys_recvfrom( @@ -18,15 +15,19 @@ pub fn sys_recvfrom( flags: i32, src_addr: Vaddr, addrlen_ptr: Vaddr, - _ctx: &Context, + ctx: &Context, ) -> Result { let flags = SendRecvFlags::from_bits_truncate(flags); debug!("sockfd = {sockfd}, buf = 0x{buf:x}, len = {len}, flags = {flags:?}, src_addr = 0x{src_addr:x}, addrlen_ptr = 0x{addrlen_ptr:x}"); let socket = get_socket_from_fd(sockfd)?; - let io_vecs = [IoVec::new(buf, len)]; - let (recv_size, message_header) = socket.recvmsg(&io_vecs, flags)?; + let mut writers = { + let vm_space = ctx.process.root_vmar().vm_space(); + vm_space.writer(buf, len)? + }; + + let (recv_size, message_header) = socket.recvmsg(&mut writers, flags)?; if let Some(socket_addr) = message_header.addr() && src_addr != 0 diff --git a/kernel/src/syscall/recvmsg.rs b/kernel/src/syscall/recvmsg.rs index 3901bc543..efa1cf3d2 100644 --- a/kernel/src/syscall/recvmsg.rs +++ b/kernel/src/syscall/recvmsg.rs @@ -24,8 +24,8 @@ pub fn sys_recvmsg( let (total_bytes, message_header) = { let socket = get_socket_from_fd(sockfd)?; - let io_vecs = c_user_msghdr.copy_iovs_from_user()?; - socket.recvmsg(&io_vecs, flags)? + let mut io_vec_writer = c_user_msghdr.copy_writer_array_from_user(ctx)?; + socket.recvmsg(&mut io_vec_writer, flags)? }; if let Some(addr) = message_header.addr() { diff --git a/kernel/src/syscall/sendmsg.rs b/kernel/src/syscall/sendmsg.rs index ce4fdc6cb..fe9ea559d 100644 --- a/kernel/src/syscall/sendmsg.rs +++ b/kernel/src/syscall/sendmsg.rs @@ -24,9 +24,9 @@ pub fn sys_sendmsg( let socket = get_socket_from_fd(sockfd)?; - let (io_vecs, message_header) = { + let (mut io_vec_reader, message_header) = { let addr = c_user_msghdr.read_socket_addr_from_user()?; - let io_vecs = c_user_msghdr.copy_iovs_from_user()?; + let io_vec_reader = c_user_msghdr.copy_reader_array_from_user(ctx)?; let control_message = { if c_user_msghdr.msg_control != 0 { @@ -36,10 +36,10 @@ pub fn sys_sendmsg( None }; - (io_vecs, MessageHeader::new(addr, control_message)) + (io_vec_reader, MessageHeader::new(addr, control_message)) }; - let total_bytes = socket.sendmsg(&io_vecs, message_header, flags)?; + let total_bytes = socket.sendmsg(&mut io_vec_reader, message_header, flags)?; Ok(SyscallReturn::Return(total_bytes as _)) } diff --git a/kernel/src/syscall/sendto.rs b/kernel/src/syscall/sendto.rs index e33538fa9..29f915732 100644 --- a/kernel/src/syscall/sendto.rs +++ b/kernel/src/syscall/sendto.rs @@ -5,10 +5,7 @@ use crate::{ fs::file_table::FileDesc, net::socket::{MessageHeader, SendRecvFlags}, prelude::*, - util::{ - net::{get_socket_from_fd, read_socket_addr_from_user}, - IoVec, - }, + util::net::{get_socket_from_fd, read_socket_addr_from_user}, }; pub fn sys_sendto( @@ -18,7 +15,7 @@ pub fn sys_sendto( flags: i32, dest_addr: Vaddr, addrlen: usize, - _ctx: &Context, + ctx: &Context, ) -> Result { let flags = SendRecvFlags::from_bits_truncate(flags); let socket_addr = if dest_addr == 0 { @@ -31,10 +28,13 @@ pub fn sys_sendto( let socket = get_socket_from_fd(sockfd)?; - let io_vecs = [IoVec::new(buf, len)]; let message_header = MessageHeader::new(socket_addr, None); - let send_size = socket.sendmsg(&io_vecs, message_header, flags)?; + let mut reader = { + let vm_space = ctx.process.root_vmar().vm_space(); + vm_space.reader(buf, len)? + }; + let send_size = socket.sendmsg(&mut reader, message_header, flags)?; Ok(SyscallReturn::Return(send_size as _)) } diff --git a/kernel/src/util/iovec.rs b/kernel/src/util/iovec.rs index f4db857a5..02e966131 100644 --- a/kernel/src/util/iovec.rs +++ b/kernel/src/util/iovec.rs @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MPL-2.0 +use ostd::mm::{Infallible, VmSpace}; + use crate::prelude::*; /// A kernel space IO vector. #[derive(Debug, Clone, Copy)] -pub struct IoVec { +struct IoVec { base: Vaddr, len: usize, } @@ -37,92 +39,194 @@ impl TryFrom for IoVec { } impl IoVec { - /// Creates a new `IoVec`. - pub const fn new(base: Vaddr, len: usize) -> Self { - Self { base, len } - } - - /// Returns the base address. - pub const fn base(&self) -> Vaddr { - self.base - } - - /// Returns the length. - pub const fn len(&self) -> usize { - self.len - } - /// Returns whether the `IoVec` points to an empty user buffer. - pub const fn is_empty(&self) -> bool { + const fn is_empty(&self) -> bool { self.len == 0 || self.base == 0 } - /// Reads bytes from the user space buffer pointed by - /// the `IoVec` to `dst`. - /// - /// If successful, the read length will be equal to `dst.len()`. - /// - /// # Panics - /// - /// This method will panic if - /// 1.`dst.len()` is not the same as `self.len()`; - /// 2. `self.is_empty()` is `true`. - pub fn read_exact_from_user(&self, dst: &mut [u8]) -> Result<()> { - assert_eq!(dst.len(), self.len); - assert!(!self.is_empty()); - - CurrentUserSpace::get().read_bytes(self.base, &mut VmWriter::from(dst)) + fn reader<'a>(&self, vm_space: &'a VmSpace) -> Result> { + Ok(vm_space.reader(self.base, self.len)?) } - /// Writes bytes from the `src` buffer - /// to the user space buffer pointed by the `IoVec`. - /// - /// If successful, the written length will be equal to `src.len()`. - /// - /// # Panics - /// - /// This method will panic if - /// 1. `src.len()` is not the same as `self.len()`; - /// 2. `self.is_empty()` is `true`. - pub fn write_exact_to_user(&self, src: &[u8]) -> Result<()> { - assert_eq!(src.len(), self.len); - assert!(!self.is_empty()); - - CurrentUserSpace::get().write_bytes(self.base, &mut VmReader::from(src)) - } - - /// Reads bytes to the `dst` buffer - /// from the user space buffer pointed by the `IoVec`. - /// - /// If successful, returns the length of actually read bytes. - pub fn read_from_user(&self, dst: &mut [u8]) -> Result { - let len = self.len.min(dst.len()); - CurrentUserSpace::get().read_bytes(self.base, &mut VmWriter::from(&mut dst[..len]))?; - Ok(len) - } - - /// Writes bytes from the `src` buffer - /// to the user space buffer pointed by the `IoVec`. - /// - /// If successful, returns the length of actually written bytes. - pub fn write_to_user(&self, src: &[u8]) -> Result { - let len = self.len.min(src.len()); - CurrentUserSpace::get().write_bytes(self.base, &mut VmReader::from(&src[..len]))?; - Ok(len) + fn writer<'a>(&self, vm_space: &'a VmSpace) -> Result> { + Ok(vm_space.writer(self.base, self.len)?) } } -/// Copies IO vectors from user space. -pub fn copy_iovs_from_user(start_addr: Vaddr, count: usize) -> Result> { - let mut io_vecs = Vec::with_capacity(count); +/// The util function for create [`VmReader`]/[`VmWriter`]s. +fn copy_iovs_and_convert<'a, T: 'a>( + ctx: &'a Context, + start_addr: Vaddr, + count: usize, + convert_iovec: impl Fn(&IoVec, &'a VmSpace) -> Result, +) -> Result> { + let vm_space = ctx.process.root_vmar().vm_space(); - let user_space = CurrentUserSpace::get(); + let mut v = Vec::with_capacity(count); for idx in 0..count { - let addr = start_addr + idx * core::mem::size_of::(); - let uiov = user_space.read_val::(addr)?; - let iov = IoVec::try_from(uiov)?; - io_vecs.push(iov); + let iov = { + let addr = start_addr + idx * core::mem::size_of::(); + let uiov: UserIoVec = vm_space + .reader(addr, core::mem::size_of::())? + .read_val()?; + IoVec::try_from(uiov)? + }; + + if iov.is_empty() { + continue; + } + + let converted = convert_iovec(&iov, vm_space)?; + v.push(converted) } - Ok(io_vecs.into_boxed_slice()) + Ok(v.into_boxed_slice()) +} + +/// A collection of [`VmReader`]s. +/// +/// Such readers are built from user-provided buffer, so it's always fallible. +pub struct VmReaderArray<'a>(Box<[VmReader<'a>]>); + +/// A collection of [`VmWriter`]s. +/// +/// Such writers are built from user-provided buffer, so it's always fallible. +pub struct VmWriterArray<'a>(Box<[VmWriter<'a>]>); + +impl<'a> VmReaderArray<'a> { + /// Creates a new `IoVecReader` from user-provided io vec buffer. + pub fn from_user_io_vecs( + ctx: &'a Context<'a>, + start_addr: Vaddr, + count: usize, + ) -> Result { + let readers = copy_iovs_and_convert(ctx, start_addr, count, IoVec::reader)?; + Ok(Self(readers)) + } + + /// Returns mutable reference to [`VmReader`]s. + pub fn readers_mut(&'a mut self) -> &'a mut [VmReader<'a>] { + &mut self.0 + } +} + +impl<'a> VmWriterArray<'a> { + /// Creates a new `IoVecWriter` from user-provided io vec buffer. + pub fn from_user_io_vecs( + ctx: &'a Context<'a>, + start_addr: Vaddr, + count: usize, + ) -> Result { + let writers = copy_iovs_and_convert(ctx, start_addr, count, IoVec::writer)?; + Ok(Self(writers)) + } + + /// Returns mutable reference to [`VmWriter`]s. + pub fn writers_mut(&'a mut self) -> &'a mut [VmWriter<'a>] { + &mut self.0 + } +} + +/// Trait defining the read behavior for a collection of [`VmReader`]s. +pub trait MultiRead { + /// Reads the exact number of bytes required to exhaust `self` or fill `writer`, + /// accumulating total bytes read. + /// + /// If the return value is `Ok(n)`, + /// then `n` should be `min(self.sum_lens(), writer.avail())`. + /// + /// # Errors + /// + /// This method returns [`Errno::EFAULT`] if a page fault occurs. + /// The position of `self` and the `writer` is left unspecified when this method returns error. + fn read(&mut self, writer: &mut VmWriter<'_, Infallible>) -> Result; + + /// Calculates the total length of data remaining to read. + fn sum_lens(&self) -> usize; + + /// Checks if the data remaining to read is empty. + fn is_empty(&self) -> bool { + self.sum_lens() == 0 + } +} + +/// Trait defining the write behavior for a collection of [`VmWriter`]s. +pub trait MultiWrite { + /// Writes the exact number of bytes required to exhaust `writer` or fill `self`, + /// accumulating total bytes read. + /// + /// If the return value is `Ok(n)`, + /// then `n` should be `min(self.sum_lens(), reader.remain())`. + /// + /// # Errors + /// + /// This method returns [`Errno::EFAULT`] if a page fault occurs. + /// The position of `self` and the `reader` is left unspecified when this method returns error. + fn write(&mut self, reader: &mut VmReader<'_, Infallible>) -> Result; + + /// Calculates the length of space available to write. + fn sum_lens(&self) -> usize; + + /// Checks if the space available to write is empty. + fn is_empty(&self) -> bool { + self.sum_lens() == 0 + } +} + +impl<'a> MultiRead for VmReaderArray<'a> { + fn read(&mut self, writer: &mut VmWriter<'_, Infallible>) -> Result { + let mut total_len = 0; + + for reader in &mut self.0 { + let copied_len = reader.read_fallible(writer)?; + total_len += copied_len; + if !writer.has_avail() { + break; + } + } + Ok(total_len) + } + + fn sum_lens(&self) -> usize { + self.0.iter().map(|vm_reader| vm_reader.remain()).sum() + } +} + +impl<'a> MultiRead for VmReader<'a> { + fn read(&mut self, writer: &mut VmWriter<'_, Infallible>) -> Result { + Ok(self.read_fallible(writer)?) + } + + fn sum_lens(&self) -> usize { + self.remain() + } +} + +impl<'a> MultiWrite for VmWriterArray<'a> { + fn write(&mut self, reader: &mut VmReader<'_, Infallible>) -> Result { + let mut total_len = 0; + + for writer in &mut self.0 { + let copied_len = writer.write_fallible(reader)?; + total_len += copied_len; + if !reader.has_remain() { + break; + } + } + Ok(total_len) + } + + fn sum_lens(&self) -> usize { + self.0.iter().map(|vm_writer| vm_writer.avail()).sum() + } +} + +impl<'a> MultiWrite for VmWriter<'a> { + fn write(&mut self, reader: &mut VmReader<'_, Infallible>) -> Result { + Ok(self.write_fallible(reader)?) + } + + fn sum_lens(&self) -> usize { + self.avail() + } } diff --git a/kernel/src/util/mod.rs b/kernel/src/util/mod.rs index 5857a4d72..cbe498c94 100644 --- a/kernel/src/util/mod.rs +++ b/kernel/src/util/mod.rs @@ -5,4 +5,4 @@ pub mod net; pub mod random; pub mod ring_buffer; -pub use iovec::{copy_iovs_from_user, IoVec}; +pub use iovec::{MultiRead, MultiWrite, VmReaderArray, VmWriterArray}; diff --git a/kernel/src/util/net/socket.rs b/kernel/src/util/net/socket.rs index c398bed19..ecd26122d 100644 --- a/kernel/src/util/net/socket.rs +++ b/kernel/src/util/net/socket.rs @@ -4,7 +4,7 @@ use super::read_socket_addr_from_user; use crate::{ net::socket::SocketAddr, prelude::*, - util::{copy_iovs_from_user, net::write_socket_addr_with_max_len, IoVec}, + util::{net::write_socket_addr_with_max_len, VmReaderArray, VmWriterArray}, }; /// Standard well-defined IP protocols. @@ -112,7 +112,11 @@ impl CUserMsgHdr { Ok(()) } - pub fn copy_iovs_from_user(&self) -> Result> { - copy_iovs_from_user(self.msg_iov, self.msg_iovlen as usize) + pub fn copy_reader_array_from_user<'a>(&self, ctx: &'a Context) -> Result> { + VmReaderArray::from_user_io_vecs(ctx, self.msg_iov, self.msg_iovlen as usize) + } + + pub fn copy_writer_array_from_user<'a>(&self, ctx: &'a Context) -> Result> { + VmWriterArray::from_user_io_vecs(ctx, self.msg_iov, self.msg_iovlen as usize) } } diff --git a/kernel/src/util/ring_buffer.rs b/kernel/src/util/ring_buffer.rs index e3df106c5..7f7ea5661 100644 --- a/kernel/src/util/ring_buffer.rs +++ b/kernel/src/util/ring_buffer.rs @@ -10,6 +10,7 @@ use align_ext::AlignExt; use inherit_methods_macro::inherit_methods; use ostd::mm::{FrameAllocOptions, Segment, VmIo}; +use super::{MultiRead, MultiWrite}; use crate::prelude::*; /// A lock-free SPSC FIFO ring buffer backed by a [`Segment`]. @@ -220,13 +221,10 @@ impl RingBuffer { } impl RingBuffer { - /// Writes data from the `VmReader` to the `RingBuffer`. + /// Writes data from the `reader` to the `RingBuffer`. /// /// Returns the number of bytes written. - pub fn write_fallible( - &mut self, - reader: &mut VmReader, - ) -> core::result::Result { + pub fn write_fallible(&mut self, reader: &mut dyn MultiRead) -> Result { let mut producer = Producer { rb: self, phantom: PhantomData, @@ -234,13 +232,10 @@ impl RingBuffer { producer.write_fallible(reader) } - /// Reads data from the `VmWriter` to the `RingBuffer`. + /// Reads data from the `writer` to the `RingBuffer`. /// /// Returns the number of bytes read. - pub fn read_fallible( - &mut self, - writer: &mut VmWriter, - ) -> core::result::Result { + pub fn read_fallible(&mut self, writer: &mut dyn MultiWrite) -> Result { let mut consumer = Consumer { rb: self, phantom: PhantomData, @@ -310,38 +305,26 @@ impl>> Producer { /// Writes data from the `VmReader` to the `RingBuffer`. /// /// Returns the number of bytes written. - pub fn write_fallible( - &mut self, - reader: &mut VmReader, - ) -> core::result::Result { + pub fn write_fallible(&mut self, reader: &mut dyn MultiRead) -> Result { let rb = &self.rb; let free_len = rb.free_len(); if free_len == 0 { return Ok(0); } - let write_len = reader.remain().min(free_len); + let write_len = reader.sum_lens().min(free_len); let tail = rb.tail(); let write_len = if tail + write_len > rb.capacity { // Write into two separate parts let mut writer = rb.segment.writer().skip(tail).limit(rb.capacity - tail); - let mut len = writer.write_fallible(reader).map_err(|(e, l1)| { - rb.advance_tail(tail, l1); - (e.into(), l1) - })?; + let mut len = reader.read(&mut writer)?; let mut writer = rb.segment.writer().limit(write_len - (rb.capacity - tail)); - len += writer.write_fallible(reader).map_err(|(e, l2)| { - rb.advance_tail(tail, len + l2); - (e.into(), len + l2) - })?; + len += reader.read(&mut writer)?; len } else { let mut writer = rb.segment.writer().skip(tail).limit(write_len); - writer.write_fallible(reader).map_err(|(e, len)| { - rb.advance_tail(tail, len); - (e.into(), len) - })? + reader.read(&mut writer)? }; rb.advance_tail(tail, write_len); @@ -418,38 +401,26 @@ impl>> Consumer { /// Reads data from the `VmWriter` to the `RingBuffer`. /// /// Returns the number of bytes read. - pub fn read_fallible( - &mut self, - writer: &mut VmWriter, - ) -> core::result::Result { + pub fn read_fallible(&mut self, writer: &mut dyn MultiWrite) -> Result { let rb = &self.rb; let len = rb.len(); if len == 0 { return Ok(0); } - let read_len = writer.avail().min(len); + let read_len = writer.sum_lens().min(len); let head = rb.head(); let read_len = if head + read_len > rb.capacity { // Read from two separate parts let mut reader = rb.segment.reader().skip(head).limit(rb.capacity - head); - let mut len = reader.read_fallible(writer).map_err(|(e, l1)| { - rb.advance_head(head, l1); - (e.into(), l1) - })?; + let mut len = writer.write(&mut reader)?; let mut reader = rb.segment.reader().limit(read_len - (rb.capacity - head)); - len += reader.read_fallible(writer).map_err(|(e, l2)| { - rb.advance_head(head, len + l2); - (e.into(), len + l2) - })?; + len += writer.write(&mut reader)?; len } else { let mut reader = rb.segment.reader().skip(head).limit(read_len); - reader.read_fallible(writer).map_err(|(e, len)| { - rb.advance_head(head, len); - (e.into(), len) - })? + writer.write(&mut reader)? }; rb.advance_head(head, read_len);