diff --git a/docs/src/kernel/linux-compatibility.md b/docs/src/kernel/linux-compatibility.md index b78ff267e..a0bd177f6 100644 --- a/docs/src/kernel/linux-compatibility.md +++ b/docs/src/kernel/linux-compatibility.md @@ -66,8 +66,8 @@ provided by Linux on x86-64 architecture. | 43 | accept | ✅ | | 44 | sendto | ✅ | | 45 | recvfrom | ✅ | -| 46 | sendmsg | ❌ | -| 47 | recvmsg | ❌ | +| 46 | sendmsg | ✅ | +| 47 | recvmsg | ✅ | | 48 | shutdown | ✅ | | 49 | bind | ✅ | | 50 | listen | ✅ | diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs b/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs index 91ac162bf..f246f5adc 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/bound.rs @@ -39,11 +39,7 @@ impl BoundDatagram { self.remote_endpoint = Some(*endpoint) } - pub fn try_recvfrom( - &self, - buf: &mut [u8], - flags: SendRecvFlags, - ) -> Result<(usize, IpEndpoint)> { + pub fn try_recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> { let result = self .bound_socket .raw_with(|socket: &mut RawUdpSocket| socket.recv_slice(buf)); @@ -55,12 +51,7 @@ impl BoundDatagram { } } - pub fn try_sendto( - &self, - buf: &[u8], - remote: &IpEndpoint, - flags: SendRecvFlags, - ) -> Result { + pub fn try_send(&self, buf: &[u8], remote: &IpEndpoint, flags: SendRecvFlags) -> Result { let result = self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { if socket.payload_send_capacity() < buf.len() { return None; diff --git a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs index 298443127..0e6e21f1a 100644 --- a/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/datagram/mod.rs @@ -15,12 +15,16 @@ use crate::{ iface::IpEndpoint, poll_ifaces, socket::{ - util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr}, + util::{ + copy_message_from_user, copy_message_to_user, create_message_buffer, + send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader, + }, Socket, }, }, prelude::*, process::signal::{Pollee, Poller}, + util::IoVec, }; mod bound; @@ -120,20 +124,19 @@ impl DatagramSocket { }) } - fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn try_recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { let inner = self.inner.read(); let Inner::Bound(bound_datagram) = inner.as_ref() else { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); }; - let received = - bound_datagram - .try_recvfrom(buf, flags) - .map(|(recv_bytes, remote_endpoint)| { - bound_datagram.update_io_events(&self.pollee); - (recv_bytes, remote_endpoint.into()) - }); + let received = bound_datagram + .try_recv(buf, flags) + .map(|(recv_bytes, remote_endpoint)| { + bound_datagram.update_io_events(&self.pollee); + (recv_bytes, remote_endpoint.into()) + }); drop(inner); poll_ifaces(); @@ -141,7 +144,15 @@ impl DatagramSocket { received } - fn try_sendto(&self, buf: &[u8], remote: &IpEndpoint, flags: SendRecvFlags) -> Result { + fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + if self.is_nonblocking() { + self.try_recv(buf, flags) + } else { + self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + } + } + + fn try_send(&self, buf: &[u8], remote: &IpEndpoint, flags: SendRecvFlags) -> Result { let inner = self.inner.read(); let Inner::Bound(bound_datagram) = inner.as_ref() else { @@ -149,7 +160,7 @@ impl DatagramSocket { }; let sent_bytes = bound_datagram - .try_sendto(buf, remote, flags) + .try_send(buf, remote, flags) .map(|sent_bytes| { bound_datagram.update_io_events(&self.pollee); sent_bytes @@ -194,16 +205,24 @@ impl DatagramSocket { impl FileLike for DatagramSocket { fn read(&self, buf: &mut [u8]) -> Result { - // FIXME: respect flags + // TODO: set correct flags let flags = SendRecvFlags::empty(); - let (recv_len, _) = self.recvfrom(buf, flags)?; - Ok(recv_len) + self.recv(buf, flags).map(|(len, _)| len) } fn write(&self, buf: &[u8]) -> Result { - // FIXME: set correct flags + let remote = self.remote_endpoint().ok_or_else(|| { + Error::with_message( + Errno::EDESTADDRREQ, + "the destination address is not specified", + ) + })?; + + // TODO: Set correct flags let flags = SendRecvFlags::empty(); - self.sendto(buf, None, flags) + + // TODO: Block if send buffer is full + self.try_send(buf, &remote, flags) } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { @@ -293,26 +312,21 @@ impl Socket for DatagramSocket { .ok_or_else(|| Error::with_message(Errno::ENOTCONN, "the socket is not connected")) } - // FIXME: respect RecvFromFlags - fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { - debug_assert!(flags.is_all_supported()); - - if self.is_nonblocking() { - self.try_recvfrom(buf, flags) - } else { - self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags)) - } - } - - fn sendto( + fn sendmsg( &self, - buf: &[u8], - remote: Option, + io_vecs: &[IoVec], + message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { + // TODO: Deal with flags debug_assert!(flags.is_all_supported()); - let remote_endpoint = match remote { + let MessageHeader { + addr, + control_message, + } = message_header; + + let remote_endpoint = match addr { Some(remote_addr) => { let endpoint = remote_addr.try_into()?; self.try_bind_empheral(&endpoint)?; @@ -326,8 +340,35 @@ impl Socket for DatagramSocket { })?, }; + if control_message.is_some() { + // TODO: Support sending control message + warn!("sending control message is not supported"); + } + + let buf = copy_message_from_user(io_vecs); + // TODO: Block if the send buffer is full - self.try_sendto(buf, &remote_endpoint, flags) + self.try_send(&buf, &remote_endpoint, flags) + } + + fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + // TODO: Deal with flags + debug_assert!(flags.is_all_supported()); + + let mut buf = create_message_buffer(io_vecs); + + let (received_bytes, peer_addr) = self.recv(&mut buf, flags)?; + + let copied_bytes = { + let message = &buf[..received_bytes]; + copy_message_to_user(io_vecs, message) + }; + + // TODO: Receive control message + + let message_header = MessageHeader::new(Some(peer_addr), None); + + Ok((copied_bytes, message_header)) } } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs index c0b13eba7..0887fb81c 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs @@ -53,10 +53,11 @@ impl ConnectedStream { Ok(()) } - pub fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result { + pub fn try_recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result { let result = self .bound_socket .raw_with(|socket: &mut RawTcpSocket| socket.recv_slice(buf)); + match result { Ok(0) => return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty"), Ok(recv_bytes) => Ok(recv_bytes), @@ -67,10 +68,11 @@ impl ConnectedStream { } } - pub fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + pub fn try_send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { let result = self .bound_socket .raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf)); + match result { Ok(0) => return_errno_with_message!(Errno::EAGAIN, "the send buffer is full"), Ok(sent_bytes) => Ok(sent_bytes), diff --git a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs index 418fbcdb2..b106f7e62 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs @@ -25,16 +25,19 @@ use crate::{ Error as SocketError, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption, }, util::{ + copy_message_from_user, copy_message_to_user, create_message_buffer, options::{SocketOptionSet, MIN_RECVBUF, MIN_SENDBUF}, send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd, socket_addr::SocketAddr, + MessageHeader, }, Socket, }, }, prelude::*, process::signal::{Pollee, Poller}, + util::IoVec, }; mod connected; @@ -249,7 +252,7 @@ impl StreamSocket { accepted } - fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn try_recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { let state = self.state.read(); let connected_stream = match state.as_ref() { @@ -262,7 +265,7 @@ impl StreamSocket { } }; - let received = connected_stream.try_recvfrom(buf, flags).map(|recv_bytes| { + let received = connected_stream.try_recv(buf, flags).map(|recv_bytes| { connected_stream.update_io_events(&self.pollee); let remote_endpoint = connected_stream.remote_endpoint(); @@ -275,7 +278,15 @@ impl StreamSocket { received } - fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + if self.is_nonblocking() { + self.try_recv(buf, flags) + } else { + self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + } + } + + fn try_send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { let state = self.state.read(); let connected_stream = match state.as_ref() { @@ -291,7 +302,7 @@ impl StreamSocket { } }; - let sent_bytes = connected_stream.try_sendto(buf, flags).map(|sent_bytes| { + let sent_bytes = connected_stream.try_send(buf, flags).map(|sent_bytes| { connected_stream.update_io_events(&self.pollee); sent_bytes }); @@ -302,6 +313,14 @@ impl StreamSocket { sent_bytes } + fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + if self.is_nonblocking() { + self.try_send(buf, flags) + } else { + self.wait_events(IoEvents::OUT, || self.try_send(buf, flags)) + } + } + // TODO: Support timeout fn wait_events(&self, mask: IoEvents, mut cond: F) -> Result where @@ -344,16 +363,15 @@ impl StreamSocket { impl FileLike for StreamSocket { fn read(&self, buf: &mut [u8]) -> Result { - // FIXME: set correct flags + // TODO: Set correct flags let flags = SendRecvFlags::empty(); - let (recv_len, _) = self.recvfrom(buf, flags)?; - Ok(recv_len) + self.recv(buf, flags).map(|(len, _)| len) } fn write(&self, buf: &[u8]) -> Result { - // FIXME: set correct flags + // TODO: Set correct flags let flags = SendRecvFlags::empty(); - self.sendto(buf, None, flags) + self.send(buf, flags) } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { @@ -510,33 +528,53 @@ impl Socket for StreamSocket { Ok(remote_endpoint.into()) } - fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { - debug_assert!(flags.is_all_supported()); - - if self.is_nonblocking() { - self.try_recvfrom(buf, flags) - } else { - self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags)) - } - } - - fn sendto( + fn sendmsg( &self, - buf: &[u8], - remote: Option, + io_vecs: &[IoVec], + message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { + // TODO: Deal with flags debug_assert!(flags.is_all_supported()); + let MessageHeader { + control_message, .. + } = message_header; + // According to the Linux man pages, `EISCONN` _may_ be returned when the destination // address is specified for a connection-mode socket. In practice, the destination address // is simply ignored. We follow the same behavior as the Linux implementation to ignore it. - if self.is_nonblocking() { - self.try_sendto(buf, flags) - } else { - self.wait_events(IoEvents::OUT, || self.try_sendto(buf, flags)) + if control_message.is_some() { + // TODO: Support sending control message + warn!("sending control message is not supported"); } + + let buf = copy_message_from_user(io_vecs); + + self.send(&buf, flags) + } + + fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + // TODO: Deal with flags + debug_assert!(flags.is_all_supported()); + + let mut buf = create_message_buffer(io_vecs); + + let (received_bytes, _) = self.recv(&mut buf, flags)?; + + let copied_bytes = { + let message = &buf[..received_bytes]; + copy_message_to_user(io_vecs, message) + }; + + // TODO: Receive control message + + // According to , + // peer address is ignored for connected socket. + let message_header = MessageHeader::new(None, None); + + Ok((copied_bytes, message_header)) } fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { diff --git a/kernel/aster-nix/src/net/socket/mod.rs b/kernel/aster-nix/src/net/socket/mod.rs index ce367e8fc..6fcd8247e 100644 --- a/kernel/aster-nix/src/net/socket/mod.rs +++ b/kernel/aster-nix/src/net/socket/mod.rs @@ -5,9 +5,9 @@ use self::options::SocketOption; pub use self::util::{ options::LingerOption, send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd, - socket_addr::SocketAddr, + socket_addr::SocketAddr, MessageHeader, }; -use crate::{fs::file_handle::FileLike, prelude::*}; +use crate::{fs::file_handle::FileLike, prelude::*, util::IoVec}; pub mod ip; pub mod options; @@ -63,18 +63,18 @@ pub trait Socket: FileLike + Send + Sync { return_errno_with_message!(Errno::EOPNOTSUPP, "setsockopt() is not supported"); } - /// Receive a message from a socket - fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { - return_errno_with_message!(Errno::EOPNOTSUPP, "recvfrom() is not supported"); - } - - /// Send a message on a socket - fn sendto( + /// Sends a message on a socket. + fn sendmsg( &self, - buf: &[u8], - remote: Option, + io_vecs: &[IoVec], + message_header: MessageHeader, flags: SendRecvFlags, - ) -> Result { - return_errno_with_message!(Errno::EOPNOTSUPP, "sendto() is not supported"); - } + ) -> Result; + + /// Receives a message from a socket. + /// + /// If successful, the `io_vecs` buffer will be filled with the received content. + /// This method returns the length of the received message, + /// and the message header. + fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)>; } diff --git a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs index f4c5e405f..58514f027 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs @@ -19,11 +19,15 @@ use crate::{ }, net::socket::{ unix::{addr::UnixSocketAddrBound, UnixSocketAddr}, - util::{send_recv_flags::SendRecvFlags, socket_addr::SocketAddr}, + util::{ + copy_message_from_user, copy_message_to_user, create_message_buffer, + send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, MessageHeader, + }, SockShutdownCmd, Socket, }, prelude::*, process::signal::Poller, + util::IoVec, }; pub struct UnixStreamSocket(RwLock); @@ -86,6 +90,24 @@ impl UnixStreamSocket { status_flags.intersection(SUPPORTED_FLAGS) } + + fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + let connected = match &*self.0.read() { + State::Connected(connected) => connected.clone(), + _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), + }; + + connected.write(buf) + } + + fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result { + let connected = match &*self.0.read() { + State::Connected(connected) => connected.clone(), + _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), + }; + + connected.read(buf) + } } impl FileLike for UnixStreamSocket { @@ -94,12 +116,15 @@ impl FileLike for UnixStreamSocket { } fn read(&self, buf: &mut [u8]) -> Result { - self.recvfrom(buf, SendRecvFlags::empty()) - .map(|(read_size, _)| read_size) + // TODO: Set correct flags + let flags = SendRecvFlags::empty(); + self.recv(buf, flags) } fn write(&self, buf: &[u8]) -> Result { - self.sendto(buf, None, SendRecvFlags::empty()) + // TODO: Set correct flags + let flags = SendRecvFlags::empty(); + self.send(buf, flags) } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { @@ -251,31 +276,46 @@ impl Socket for UnixStreamSocket { } } - fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { - let connected = match &*self.0.read() { - State::Connected(connected) => connected.clone(), - _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), - }; - - let peer_addr = self.peer_addr()?; - let read_size = connected.read(buf)?; - Ok((read_size, peer_addr)) - } - - fn sendto( + fn sendmsg( &self, - buf: &[u8], - remote: Option, + io_vecs: &[IoVec], + message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { - debug_assert!(remote.is_none()); - // TODO: deal with flags - let connected = match &*self.0.read() { - State::Connected(connected) => connected.clone(), - _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), + // TODO: Deal with flags + debug_assert!(flags.is_all_supported()); + + let MessageHeader { + control_message, .. + } = message_header; + + if control_message.is_some() { + // TODO: Support sending control message + warn!("sending control message is not supported"); + } + + let buf = copy_message_from_user(io_vecs); + + self.send(&buf, flags) + } + + fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + // TODO: Deal with flags + debug_assert!(flags.is_all_supported()); + + let mut buf = create_message_buffer(io_vecs); + let received_bytes = self.recv(&mut buf, flags)?; + + let copied_bytes = { + let message = &buf[..received_bytes]; + copy_message_to_user(io_vecs, message) }; - connected.write(buf) + // TODO: Receive control message + + let message_header = MessageHeader::new(None, None); + + Ok((copied_bytes, message_header)) } } diff --git a/kernel/aster-nix/src/net/socket/util/message_header.rs b/kernel/aster-nix/src/net/socket/util/message_header.rs new file mode 100644 index 000000000..1a0184153 --- /dev/null +++ b/kernel/aster-nix/src/net/socket/util/message_header.rs @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::socket_addr::SocketAddr; +use crate::{prelude::*, util::IoVec}; + +/// Message header used for sendmsg/recvmsg. +#[derive(Debug)] +pub struct MessageHeader { + pub(in crate::net) addr: Option, + pub(in crate::net) control_message: Option, +} + +impl MessageHeader { + /// Creates a new `MessageHeader`. + pub const fn new(addr: Option, control_message: Option) -> Self { + Self { + addr, + control_message, + } + } + + /// Returns the socket address. + pub fn addr(&self) -> Option<&SocketAddr> { + self.addr.as_ref() + } +} + +/// Control message carried by MessageHeader. +/// +/// TODO: Implement the struct. The struct is empty now. +#[derive(Debug)] +pub struct ControlMessage; + +/// Copies a message from user space. +/// +/// Since udp allows sending and receiving packet of length 0, +/// The returned buffer may have length of zero. +pub fn copy_message_from_user(io_vecs: &[IoVec]) -> Box<[u8]> { + let mut buffer = create_message_buffer(io_vecs); + + let mut total_bytes = 0; + for io_vec in io_vecs { + if io_vec.is_empty() { + continue; + } + let dst = &mut buffer[total_bytes..total_bytes + io_vec.len()]; + // FIXME: short read should be allowed here + match io_vec.read_exact_from_user(dst) { + Ok(()) => total_bytes += io_vec.len(), + Err(e) => { + warn!("fails to copy message from user"); + break; + } + } + } + + buffer.truncate(total_bytes); + buffer.into_boxed_slice() +} + +/// Creates a buffer whose length +/// is equal to the total length of `io_vecs`. +pub fn create_message_buffer(io_vecs: &[IoVec]) -> Vec { + let buffer_len: usize = io_vecs.iter().map(|iovec| iovec.len()).sum(); + vec![0; buffer_len] +} + +/// Copies a message to user space. +/// +/// This method returns the actual copied length. +pub fn copy_message_to_user(io_vecs: &[IoVec], message: &[u8]) -> usize { + let mut total_bytes = 0; + + for io_vec in io_vecs { + if io_vec.is_empty() { + continue; + } + + let len = io_vec.len().min(message.len() - total_bytes); + if len == 0 { + break; + } + + let src = &message[total_bytes..total_bytes + len]; + match io_vec.write_to_user(src) { + Ok(len) => total_bytes += len, + Err(e) => { + warn!("fails to copy message to user"); + break; + } + } + } + + total_bytes +} diff --git a/kernel/aster-nix/src/net/socket/util/mod.rs b/kernel/aster-nix/src/net/socket/util/mod.rs index 7478fe95e..fa8f399bc 100644 --- a/kernel/aster-nix/src/net/socket/util/mod.rs +++ b/kernel/aster-nix/src/net/socket/util/mod.rs @@ -1,6 +1,12 @@ // SPDX-License-Identifier: MPL-2.0 +mod message_header; pub mod options; pub mod send_recv_flags; pub mod shutdown_cmd; pub mod socket_addr; + +pub use message_header::MessageHeader; +pub(in crate::net) use message_header::{ + copy_message_from_user, copy_message_to_user, create_message_buffer, +}; diff --git a/kernel/aster-nix/src/net/socket/util/socket_addr.rs b/kernel/aster-nix/src/net/socket/util/socket_addr.rs index 72169ca3d..e54f529dd 100644 --- a/kernel/aster-nix/src/net/socket/util/socket_addr.rs +++ b/kernel/aster-nix/src/net/socket/util/socket_addr.rs @@ -10,7 +10,7 @@ use crate::{ type PortNum = u16; -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum SocketAddr { Unix(UnixSocketAddr), IPv4(Ipv4Address, PortNum), diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs b/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs index 01ee8a1da..11a12c732 100644 --- a/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/vsock/stream/connected.rs @@ -68,14 +68,14 @@ impl Connected { } } - pub fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + pub fn send(&self, packet: &[u8], flags: SendRecvFlags) -> Result { let mut connection = self.connection.lock_irq_disabled(); debug_assert!(flags.is_all_supported()); - let buf_len = buf.len(); + let buf_len = packet.len(); VSOCK_GLOBAL .get() .unwrap() - .send(buf, &mut connection.info)?; + .send(packet, &mut connection.info)?; Ok(buf_len) } diff --git a/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs b/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs index 7f55e8452..869a97b9b 100644 --- a/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/vsock/stream/socket.rs @@ -9,11 +9,13 @@ use crate::{ events::IoEvents, fs::{file_handle::FileLike, utils::StatusFlags}, net::socket::{ + util::{copy_message_from_user, copy_message_to_user, create_message_buffer}, vsock::{addr::VsockSocketAddr, VSOCK_GLOBAL}, - SendRecvFlags, SockShutdownCmd, Socket, SocketAddr, + MessageHeader, SendRecvFlags, SockShutdownCmd, Socket, SocketAddr, }, prelude::*, process::signal::Poller, + util::IoVec, }; pub struct VsockStreamSocket { @@ -112,7 +114,17 @@ impl VsockStreamSocket { Ok((socket, peer_addr.into())) } - fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + fn send(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + let inner = self.status.read(); + match &*inner { + Status::Connected(connected) => connected.send(buf, flags), + Status::Init(_) | Status::Listen(_) => { + return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); + } + } + } + + fn try_recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { let connected = match &*self.status.read() { Status::Connected(connected) => connected.clone(), Status::Init(_) | Status::Listen(_) => { @@ -133,6 +145,14 @@ impl VsockStreamSocket { } Ok((read_size, peer_addr)) } + + fn recv(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { + if self.is_nonblocking() { + self.try_recv(buf, flags) + } else { + self.wait_events(IoEvents::IN, || self.try_recv(buf, flags)) + } + } } impl FileLike for VsockStreamSocket { @@ -141,12 +161,15 @@ impl FileLike for VsockStreamSocket { } fn read(&self, buf: &mut [u8]) -> Result { - self.recvfrom(buf, SendRecvFlags::empty()) - .map(|(read_size, _)| read_size) + // TODO: Set correct flags + let flags = SendRecvFlags::empty(); + self.recv(buf, SendRecvFlags::empty()).map(|(len, _)| len) } fn write(&self, buf: &[u8]) -> Result { - self.sendto(buf, None, SendRecvFlags::empty()) + // TODO: Set correct flags + let flags = SendRecvFlags::empty(); + self.send(buf, flags) } fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { @@ -219,7 +242,7 @@ impl Socket for VsockStreamSocket { // Send request vsockspace.request(&connecting.info()).unwrap(); // wait for response from driver - // TODO: add timeout + // TODO: Add timeout let poller = Poller::new(); if !connecting .poll(IoEvents::IN, Some(&poller)) @@ -287,33 +310,46 @@ impl Socket for VsockStreamSocket { } } - fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { - debug_assert!(flags.is_all_supported()); - - if self.is_nonblocking() { - self.try_recvfrom(buf, flags) - } else { - self.wait_events(IoEvents::IN, || self.try_recvfrom(buf, flags)) - } - } - - fn sendto( + fn sendmsg( &self, - buf: &[u8], - remote: Option, + io_vecs: &[IoVec], + message_header: MessageHeader, flags: SendRecvFlags, ) -> Result { - debug_assert!(remote.is_none()); - if remote.is_some() { - return_errno_with_message!(Errno::EINVAL, "vsock should not provide remote addr"); - } - let inner = self.status.read(); - match &*inner { - Status::Connected(connected) => connected.send(buf, flags), - Status::Init(_) | Status::Listen(_) => { - return_errno_with_message!(Errno::EINVAL, "the socket is not connected"); - } + // TODO: Deal with flags + debug_assert!(flags.is_all_supported()); + + let MessageHeader { + control_message, .. + } = message_header; + + if control_message.is_some() { + // TODO: Support sending control message + warn!("sending control message is not supported"); } + + let buf = copy_message_from_user(io_vecs); + self.send(&buf, flags) + } + + fn recvmsg(&self, io_vecs: &[IoVec], flags: SendRecvFlags) -> Result<(usize, MessageHeader)> { + // TODO: Deal with flags + debug_assert!(flags.is_all_supported()); + + let mut buf = create_message_buffer(io_vecs); + + let (received_bytes, _) = self.recv(&mut buf, flags)?; + + let copied_bytes = { + let message = &buf[..received_bytes]; + copy_message_to_user(io_vecs, message) + }; + + // TODO: Receive control message + + let messsge_header = MessageHeader::new(None, None); + + Ok((copied_bytes, messsge_header)) } fn addr(&self) -> Result { diff --git a/kernel/aster-nix/src/syscall/arch/x86.rs b/kernel/aster-nix/src/syscall/arch/x86.rs index 293af2b9d..e6fab2b1c 100644 --- a/kernel/aster-nix/src/syscall/arch/x86.rs +++ b/kernel/aster-nix/src/syscall/arch/x86.rs @@ -70,6 +70,7 @@ use crate::syscall::{ read::sys_read, readlink::{sys_readlink, sys_readlinkat}, recvfrom::sys_recvfrom, + recvmsg::sys_recvmsg, rename::{sys_rename, sys_renameat}, rmdir::sys_rmdir, rt_sigaction::sys_rt_sigaction, @@ -81,6 +82,7 @@ use crate::syscall::{ sched_yield::sys_sched_yield, select::sys_select, sendfile::sys_sendfile, + sendmsg::sys_sendmsg, sendto::sys_sendto, set_get_priority::{sys_get_priority, sys_set_priority}, set_robust_list::sys_set_robust_list, @@ -161,6 +163,8 @@ impl_syscall_nums_and_dispatch_fn! { SYS_ACCEPT = 43 => sys_accept(args[..3]); SYS_SENDTO = 44 => sys_sendto(args[..6]); SYS_RECVFROM = 45 => sys_recvfrom(args[..6]); + SYS_SENDMSG = 46 => sys_sendmsg(args[..3]); + SYS_RECVMSG = 47 => sys_recvmsg(args[..3]); SYS_SHUTDOWN = 48 => sys_shutdown(args[..2]); SYS_BIND = 49 => sys_bind(args[..3]); SYS_LISTEN = 50 => sys_listen(args[..2]); diff --git a/kernel/aster-nix/src/syscall/mod.rs b/kernel/aster-nix/src/syscall/mod.rs index 36bfccd5e..9fa4ea591 100644 --- a/kernel/aster-nix/src/syscall/mod.rs +++ b/kernel/aster-nix/src/syscall/mod.rs @@ -77,6 +77,7 @@ mod prlimit64; mod read; mod readlink; mod recvfrom; +mod recvmsg; mod rename; mod rmdir; mod rt_sigaction; @@ -88,6 +89,7 @@ mod sched_getaffinity; mod sched_yield; mod select; mod sendfile; +mod sendmsg; mod sendto; mod set_get_priority; mod set_robust_list; diff --git a/kernel/aster-nix/src/syscall/recvfrom.rs b/kernel/aster-nix/src/syscall/recvfrom.rs index b431a699e..6d5001f1a 100644 --- a/kernel/aster-nix/src/syscall/recvfrom.rs +++ b/kernel/aster-nix/src/syscall/recvfrom.rs @@ -7,7 +7,7 @@ use crate::{ prelude::*, util::{ net::{get_socket_from_fd, write_socket_addr_to_user}, - write_bytes_to_user, + IoVec, }, }; @@ -24,14 +24,14 @@ pub fn sys_recvfrom( let socket = get_socket_from_fd(sockfd)?; - let mut buffer = vec![0u8; len]; + let io_vecs = [IoVec::new(buf, len)]; + let (recv_size, message_header) = socket.recvmsg(&io_vecs, flags)?; - let (recv_size, socket_addr) = socket.recvfrom(&mut buffer, flags)?; - if buf != 0 { - write_bytes_to_user(buf, &buffer[..recv_size])?; - } - if src_addr != 0 { - write_socket_addr_to_user(&socket_addr, src_addr, addrlen_ptr)?; + if let Some(socket_addr) = message_header.addr() + && src_addr != 0 + { + write_socket_addr_to_user(socket_addr, src_addr, addrlen_ptr)?; } + Ok(SyscallReturn::Return(recv_size as _)) } diff --git a/kernel/aster-nix/src/syscall/recvmsg.rs b/kernel/aster-nix/src/syscall/recvmsg.rs new file mode 100644 index 000000000..21b1f8feb --- /dev/null +++ b/kernel/aster-nix/src/syscall/recvmsg.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::SyscallReturn; +use crate::{ + fs::file_table::FileDesc, + net::socket::SendRecvFlags, + prelude::*, + util::{ + net::{get_socket_from_fd, CUserMsgHdr}, + read_val_from_user, + }, +}; + +pub fn sys_recvmsg(sockfd: FileDesc, user_msghdr_ptr: Vaddr, flags: i32) -> Result { + let c_user_msghdr: CUserMsgHdr = read_val_from_user(user_msghdr_ptr)?; + let flags = SendRecvFlags::from_bits_truncate(flags); + + debug!( + "sockfd = {}, user_msghdr = {:x?}, flags = {:?}", + sockfd, c_user_msghdr, flags + ); + + let (total_bytes, message_header) = { + let socket = get_socket_from_fd(sockfd)?; + let io_vecs = c_user_msghdr.copy_iovs_from_user()?; + socket.recvmsg(&io_vecs, flags)? + }; + + if let Some(addr) = message_header.addr() { + c_user_msghdr.write_socket_addr_to_user(addr)?; + } + + if c_user_msghdr.msg_control != 0 { + warn!("receiving control message is not supported"); + } + + Ok(SyscallReturn::Return(total_bytes as _)) +} diff --git a/kernel/aster-nix/src/syscall/sendmsg.rs b/kernel/aster-nix/src/syscall/sendmsg.rs new file mode 100644 index 000000000..47276b472 --- /dev/null +++ b/kernel/aster-nix/src/syscall/sendmsg.rs @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::SyscallReturn; +use crate::{ + fs::file_table::FileDesc, + net::socket::{MessageHeader, SendRecvFlags}, + prelude::*, + util::{ + net::{get_socket_from_fd, CUserMsgHdr}, + read_val_from_user, + }, +}; + +pub fn sys_sendmsg(sockfd: FileDesc, user_msghdr_ptr: Vaddr, flags: i32) -> Result { + let c_user_msghdr: CUserMsgHdr = read_val_from_user(user_msghdr_ptr)?; + let flags = SendRecvFlags::from_bits_truncate(flags); + + debug!( + "sockfd = {}, user_msghdr = {:x?}, flags = {:?}", + sockfd, c_user_msghdr, flags + ); + + let socket = get_socket_from_fd(sockfd)?; + + let (io_vecs, message_header) = { + let addr = c_user_msghdr.read_socket_addr_from_user()?; + let io_vecs = c_user_msghdr.copy_iovs_from_user()?; + + let control_message = { + if c_user_msghdr.msg_control != 0 { + // TODO: support sending control message + warn!("control message is not supported now"); + } + None + }; + + (io_vecs, MessageHeader::new(addr, control_message)) + }; + + let total_bytes = socket.sendmsg(&io_vecs, message_header, flags)?; + + Ok(SyscallReturn::Return(total_bytes as _)) +} diff --git a/kernel/aster-nix/src/syscall/sendto.rs b/kernel/aster-nix/src/syscall/sendto.rs index 7e70d4cf7..683ad5428 100644 --- a/kernel/aster-nix/src/syscall/sendto.rs +++ b/kernel/aster-nix/src/syscall/sendto.rs @@ -3,11 +3,11 @@ use super::SyscallReturn; use crate::{ fs::file_table::FileDesc, - net::socket::SendRecvFlags, + net::socket::{MessageHeader, SendRecvFlags}, prelude::*, util::{ net::{get_socket_from_fd, read_socket_addr_from_user}, - read_bytes_from_user, + IoVec, }, }; @@ -27,12 +27,13 @@ pub fn sys_sendto( Some(socket_addr) }; debug!("sockfd = {sockfd}, buf = 0x{buf:x}, len = 0x{len:x}, flags = {flags:?}, socket_addr = {socket_addr:?}"); - let mut buffer = vec![0u8; len]; - read_bytes_from_user(buf, &mut buffer)?; let socket = get_socket_from_fd(sockfd)?; - let send_size = socket.sendto(&buffer, socket_addr, flags)?; + let io_vecs = [IoVec::new(buf, len)]; + let message_header = MessageHeader::new(socket_addr, None); + + let send_size = socket.sendmsg(&io_vecs, message_header, flags)?; Ok(SyscallReturn::Return(send_size as _)) } diff --git a/kernel/aster-nix/src/syscall/writev.rs b/kernel/aster-nix/src/syscall/writev.rs index b136a69d9..accf0239c 100644 --- a/kernel/aster-nix/src/syscall/writev.rs +++ b/kernel/aster-nix/src/syscall/writev.rs @@ -3,20 +3,7 @@ #![allow(dead_code)] use super::SyscallReturn; -use crate::{ - fs::file_table::FileDesc, - prelude::*, - util::{read_bytes_from_user, read_val_from_user}, -}; - -const IOVEC_MAX: usize = 256; - -#[repr(C)] -#[derive(Debug, Clone, Copy, Pod)] -pub struct IoVec { - base: Vaddr, - len: usize, -} +use crate::{fs::file_table::FileDesc, prelude::*, util::copy_iovs_from_user}; pub fn sys_writev(fd: FileDesc, io_vec_ptr: Vaddr, io_vec_count: usize) -> Result { let res = do_sys_writev(fd, io_vec_ptr, io_vec_count)?; @@ -33,19 +20,27 @@ fn do_sys_writev(fd: FileDesc, io_vec_ptr: Vaddr, io_vec_count: usize) -> Result let filetable = current.file_table().lock(); filetable.get_file(fd)?.clone() }; + let mut total_len = 0; - for i in 0..io_vec_count { - let io_vec = read_val_from_user::(io_vec_ptr + i * core::mem::size_of::())?; - if io_vec.base == 0 || io_vec.len == 0 { + + let io_vecs = copy_iovs_from_user(io_vec_ptr, io_vec_count)?; + for io_vec in io_vecs.as_ref() { + if io_vec.is_empty() { continue; } + let buffer = { - let base = io_vec.base; - let len = io_vec.len; - let mut buffer = vec![0u8; len]; - read_bytes_from_user(base, &mut buffer)?; + let mut buffer = vec![0u8; io_vec.len()]; + io_vec.read_exact_from_user(&mut buffer)?; buffer }; + + // FIXME: According to the man page + // at , + // writev must be atomic, + // but the current implementation does not ensure atomicity. + // A suitable fix would be to add a `writev` method for the `FileLike` trait, + // allowing each subsystem to implement atomicity. let write_len = file.write(&buffer)?; total_len += write_len; } diff --git a/kernel/aster-nix/src/util/iovec.rs b/kernel/aster-nix/src/util/iovec.rs new file mode 100644 index 000000000..068efd967 --- /dev/null +++ b/kernel/aster-nix/src/util/iovec.rs @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::read_val_from_user; +use crate::{ + prelude::*, + util::{read_bytes_from_user, write_bytes_to_user}, +}; + +/// A kernel space IO vector. +#[derive(Debug, Clone, Copy)] +pub struct IoVec { + base: Vaddr, + len: usize, +} + +/// A user space IO vector. +/// +/// The difference between `IoVec` and `UserIoVec` +/// is that `UserIoVec` uses `isize` as the length type, +/// while `IoVec` uses `usize`. +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod)] +struct UserIoVec { + base: Vaddr, + len: isize, +} + +impl TryFrom for IoVec { + type Error = Error; + + fn try_from(value: UserIoVec) -> Result { + if value.len < 0 { + return_errno_with_message!(Errno::EINVAL, "the length of IO vector cannot be negative"); + } + + Ok(IoVec { + base: value.base, + len: value.len as usize, + }) + } +} + +impl IoVec { + /// Creates a new `IoVec`. + pub const fn new(base: Vaddr, len: usize) -> Self { + Self { base, len } + } + + /// Returns the base address. + pub const fn base(&self) -> Vaddr { + self.base + } + + /// Returns the length. + pub const fn len(&self) -> usize { + self.len + } + + /// Returns whether the `IoVec` points to an empty user buffer. + pub const fn is_empty(&self) -> bool { + self.len == 0 || self.base == 0 + } + + /// Reads bytes from the user space buffer pointed by + /// the `IoVec` to `dst`. + /// + /// If successful, the read length will be equal to `dst.len()`. + /// + /// # Panics + /// + /// This method will panic if + /// 1.`dst.len()` is not the same as `self.len()`; + /// 2. `self.is_empty()` is `true`. + pub fn read_exact_from_user(&self, dst: &mut [u8]) -> Result<()> { + assert_eq!(dst.len(), self.len); + assert!(!self.is_empty()); + + read_bytes_from_user(self.base, dst) + } + + /// Writes bytes from the `src` buffer + /// to the user space buffer pointed by the `IoVec`. + /// + /// If successful, the written length will be equal to `src.len()`. + /// + /// # Panics + /// + /// This method will panic if + /// 1. `src.len()` is not the same as `self.len()`; + /// 2. `self.is_empty()` is `true`. + pub fn write_exact_to_user(&self, src: &[u8]) -> Result<()> { + assert_eq!(src.len(), self.len); + assert!(!self.is_empty()); + + write_bytes_to_user(self.base, src) + } + + /// Reads bytes to the `dst` buffer + /// from the user space buffer pointed by the `IoVec`. + /// + /// If successful, returns the length of actually read bytes. + pub fn read_from_user(&self, dst: &mut [u8]) -> Result { + let len = self.len.min(dst.len()); + read_bytes_from_user(self.base, &mut dst[..len])?; + Ok(len) + } + + /// Writes bytes from the `src` buffer + /// to the user space buffer pointed by the `IoVec`. + /// + /// If successful, returns the length of actually written bytes. + pub fn write_to_user(&self, src: &[u8]) -> Result { + let len = self.len.min(src.len()); + write_bytes_to_user(self.base, &src[..len])?; + Ok(len) + } +} + +/// Copies IO vectors from user space. +pub fn copy_iovs_from_user(start_addr: Vaddr, count: usize) -> Result> { + let mut io_vecs = Vec::with_capacity(count); + + for idx in 0..count { + let addr = start_addr + idx * core::mem::size_of::(); + let uiov = read_val_from_user::(addr)?; + let iov = IoVec::try_from(uiov)?; + io_vecs.push(iov); + } + + Ok(io_vecs.into_boxed_slice()) +} diff --git a/kernel/aster-nix/src/util/mod.rs b/kernel/aster-nix/src/util/mod.rs index 718637731..5cf53588c 100644 --- a/kernel/aster-nix/src/util/mod.rs +++ b/kernel/aster-nix/src/util/mod.rs @@ -6,9 +6,12 @@ use aster_frame::mm::VmIo; use aster_rights::Full; use crate::{prelude::*, vm::vmar::Vmar}; +mod iovec; pub mod net; pub mod random; +pub use iovec::{copy_iovs_from_user, IoVec}; + /// Read bytes into the `dest` buffer /// from the user space of the current process. /// If successful, diff --git a/kernel/aster-nix/src/util/net/addr.rs b/kernel/aster-nix/src/util/net/addr.rs index 56a4c1101..869e04a99 100644 --- a/kernel/aster-nix/src/util/net/addr.rs +++ b/kernel/aster-nix/src/util/net/addr.rs @@ -79,7 +79,25 @@ pub fn write_socket_addr_to_user( if addrlen_ptr == 0 { return_errno_with_message!(Errno::EINVAL, "must provide the addrlen ptr"); } - let max_len = read_val_from_user::(addrlen_ptr)? as usize; + + let write_size = { + let max_len = read_val_from_user::(addrlen_ptr)?; + write_socket_addr_with_max_len(socket_addr, dest, max_len)? + }; + + if addrlen_ptr != 0 { + write_val_to_user(addrlen_ptr, &write_size)?; + } + Ok(()) +} + +pub fn write_socket_addr_with_max_len( + socket_addr: &SocketAddr, + dest: Vaddr, + max_len: i32, +) -> Result { + let max_len = max_len as usize; + let write_size = match socket_addr { SocketAddr::Unix(path) => { let sock_addr_unix = CSocketAddrUnix::try_from(path)?; @@ -104,10 +122,8 @@ pub fn write_socket_addr_to_user( write_size as i32 } }; - if addrlen_ptr != 0 { - write_val_to_user(addrlen_ptr, &write_size)?; - } - Ok(()) + + Ok(write_size) } /// PlaceHolder diff --git a/kernel/aster-nix/src/util/net/mod.rs b/kernel/aster-nix/src/util/net/mod.rs index e47c40369..7561d9a11 100644 --- a/kernel/aster-nix/src/util/net/mod.rs +++ b/kernel/aster-nix/src/util/net/mod.rs @@ -4,9 +4,12 @@ mod addr; mod options; mod socket; -pub use addr::{read_socket_addr_from_user, write_socket_addr_to_user, CSocketAddrFamily}; +pub use addr::{ + read_socket_addr_from_user, write_socket_addr_to_user, write_socket_addr_with_max_len, + CSocketAddrFamily, +}; pub use options::{new_raw_socket_option, CSocketOptionLevel}; -pub use socket::{Protocol, SockFlags, SockType, SOCK_TYPE_MASK}; +pub use socket::{CUserMsgHdr, Protocol, SockFlags, SockType, SOCK_TYPE_MASK}; use crate::{fs::file_table::FileDesc, net::socket::Socket, prelude::*}; diff --git a/kernel/aster-nix/src/util/net/socket.rs b/kernel/aster-nix/src/util/net/socket.rs index 6464854e7..c398bed19 100644 --- a/kernel/aster-nix/src/util/net/socket.rs +++ b/kernel/aster-nix/src/util/net/socket.rs @@ -1,6 +1,11 @@ // SPDX-License-Identifier: MPL-2.0 -use crate::prelude::*; +use super::read_socket_addr_from_user; +use crate::{ + net::socket::SocketAddr, + prelude::*, + util::{copy_iovs_from_user, net::write_socket_addr_with_max_len, IoVec}, +}; /// Standard well-defined IP protocols. /// From https://elixir.bootlin.com/linux/v6.0.9/source/include/uapi/linux/in.h. @@ -68,3 +73,46 @@ bitflags! { const SOCK_CLOEXEC = 1 << 19; } } + +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod)] +pub struct CUserMsgHdr { + /// Pointer to socket address structure + pub msg_name: Vaddr, + /// Size of socket address + pub msg_namelen: i32, + /// Scatter/Gather iov array + pub msg_iov: Vaddr, + /// The # of elements in msg_iov + pub msg_iovlen: u32, + /// Ancillary data + pub msg_control: Vaddr, + /// Ancillary data buffer length + pub msg_controllen: u32, + /// Flags on received message + pub msg_flags: u32, +} + +impl CUserMsgHdr { + pub fn read_socket_addr_from_user(&self) -> Result> { + if self.msg_name == 0 { + return Ok(None); + } + + let socket_addr = read_socket_addr_from_user(self.msg_name, self.msg_namelen as usize)?; + Ok(Some(socket_addr)) + } + + pub fn write_socket_addr_to_user(&self, addr: &SocketAddr) -> Result<()> { + if self.msg_name == 0 { + return Ok(()); + } + + write_socket_addr_with_max_len(addr, self.msg_name, self.msg_namelen)?; + Ok(()) + } + + pub fn copy_iovs_from_user(&self) -> Result> { + copy_iovs_from_user(self.msg_iov, self.msg_iovlen as usize) + } +} diff --git a/regression/apps/network/tcp_err.c b/regression/apps/network/tcp_err.c index 90ab9d8b4..384f7d179 100644 --- a/regression/apps/network/tcp_err.c +++ b/regression/apps/network/tcp_err.c @@ -6,6 +6,7 @@ #include #include #include +#include #include "test.h" @@ -290,3 +291,73 @@ FN_TEST(async_connect) errlen == sizeof(err) && err == 0); } END_TEST() + +void set_blocking(int sockfd) +{ + int flags = CHECK(fcntl(sockfd, F_GETFL, 0)); + CHECK(fcntl(sockfd, F_SETFL, flags & (~O_NONBLOCK))); +} + +FN_SETUP(enter_blocking_mode) +{ + set_blocking(sk_connected); + set_blocking(sk_bound); +} +END_SETUP() + +FN_TEST(sendmsg_and_recvmsg) +{ + struct msghdr msg = { 0 }; + struct iovec iov[2]; + char *message = "Message:"; + char *message2 = "Hello"; + iov[0].iov_base = message; + iov[0].iov_len = strlen(message); + iov[1].iov_base = message2; + iov[1].iov_len = strlen(message2); + msg.msg_iov = iov; + msg.msg_iovlen = 2; + + // Send one message and recv one message + TEST_RES(sendmsg(sk_connected, &msg, 0), + _ret == strlen(message) + strlen(message2)); + +#define BUFFER_SIZE 50 + char concatenated[BUFFER_SIZE] = { 0 }; + strcat(concatenated, message); + strcat(concatenated, message2); + + char buffer[BUFFER_SIZE] = { 0 }; + iov[0].iov_base = buffer; + iov[0].iov_len = BUFFER_SIZE; + msg.msg_iovlen = 1; + + TEST_RES(recvmsg(sk_accepted, &msg, 0), + _ret == strlen(concatenated) && + strcmp(buffer, concatenated) == 0); + + // Send two message and receive two message + + // This test is commented out due to a known issue: + // See + + // iov[0].iov_base = message; + // iov[0].iov_len = strlen(message); + // msg.msg_iovlen = 1; + // TEST_RES(sendmsg(sk_accepted, &msg, 0), _ret == strlen(message)); + // TEST_RES(sendmsg(sk_accepted, &msg, 0), _ret == strlen(message)); + + // char first_buffer[BUFFER_SIZE] = { 0 }; + // char second_buffer[BUFFER_SIZE] = { 0 }; + // iov[0].iov_base = first_buffer; + // iov[0].iov_len = BUFFER_SIZE; + // iov[1].iov_base = second_buffer; + // iov[1].iov_len = BUFFER_SIZE; + // msg.msg_iovlen = 2; + + // // Ensure two messages are prepared for receiving + // sleep(1); + + // TEST_RES(recvmsg(sk_connected, &msg, 0), _ret == strlen(message) * 2); +} +END_TEST() diff --git a/regression/apps/network/udp_err.c b/regression/apps/network/udp_err.c index 9769d30ab..03b778e5a 100644 --- a/regression/apps/network/udp_err.c +++ b/regression/apps/network/udp_err.c @@ -6,6 +6,7 @@ #include #include #include +#include #include "test.h" @@ -198,3 +199,63 @@ FN_TEST(connect) TEST_SUCC(connect(sk_connected, psaddr, addrlen)); } END_TEST() + +void set_blocking(int sockfd) +{ + int flags = CHECK(fcntl(sockfd, F_GETFL, 0)); + CHECK(fcntl(sockfd, F_SETFL, flags & (~O_NONBLOCK))); +} + +FN_SETUP(enter_blocking_mode) +{ + set_blocking(sk_connected); + set_blocking(sk_bound); +} +END_SETUP() + +FN_TEST(sendmsg_and_recvmsg) +{ + struct sockaddr_in saddr; + socklen_t addrlen = sizeof(saddr); + + sk_addr.sin_port = C_PORT; + + struct msghdr msg = { 0 }; + struct iovec iov[1]; + char *message = "Message"; + iov[0].iov_base = message; + iov[0].iov_len = strlen(message); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = (struct sockaddr *)&sk_addr; + msg.msg_namelen = addrlen; + + // Send one message and receive one message + TEST_RES(sendmsg(sk_connected, &msg, 0), _ret == strlen(message)); + +#define BUFFER_SIZE 50 + char buffer[BUFFER_SIZE]; + iov[0].iov_base = buffer; + iov[0].iov_len = BUFFER_SIZE; + msg.msg_name = 0; + TEST_RES(recvmsg(sk_bound, &msg, 0), + _ret == strlen(message) && strcmp(message, buffer) == 0); + + // Send two messages and receive two messages + iov[0].iov_base = message; + iov[0].iov_len = strlen(message); + msg.msg_name = (struct sockaddr *)&sk_addr; + msg.msg_namelen = addrlen; + + TEST_RES(sendmsg(sk_connected, &msg, 0), _ret == strlen(message)); + TEST_RES(sendmsg(sk_connected, &msg, 0), _ret == strlen(message)); + + iov[0].iov_base = buffer; + iov[0].iov_len = BUFFER_SIZE; + + TEST_RES(recvmsg(sk_bound, &msg, 0), + _ret == strlen(message) && strcmp(message, buffer) == 0); + TEST_RES(recvmsg(sk_bound, &msg, 0), + _ret == strlen(message) && strcmp(message, buffer) == 0); +} +END_TEST()