From c9f939bcc45620d3e2e0c22f74901768b878a815 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 20 Apr 2025 16:12:14 +0800 Subject: [PATCH] Create `datagram_common` and use it in UDP --- kernel/src/net/socket/ip/datagram/bound.rs | 34 +-- kernel/src/net/socket/ip/datagram/mod.rs | 236 ++++-------------- kernel/src/net/socket/ip/datagram/unbound.rs | 66 +++-- kernel/src/net/socket/util/datagram_common.rs | 221 ++++++++++++++++ kernel/src/net/socket/util/mod.rs | 1 + 5 files changed, 333 insertions(+), 225 deletions(-) create mode 100644 kernel/src/net/socket/util/datagram_common.rs diff --git a/kernel/src/net/socket/ip/datagram/bound.rs b/kernel/src/net/socket/ip/datagram/bound.rs index 784c8e1fb..0bc4f4737 100644 --- a/kernel/src/net/socket/ip/datagram/bound.rs +++ b/kernel/src/net/socket/ip/datagram/bound.rs @@ -9,46 +9,50 @@ use crate::{ events::IoEvents, net::{ iface::{Iface, UdpSocket}, - socket::util::send_recv_flags::SendRecvFlags, + socket::util::{datagram_common, send_recv_flags::SendRecvFlags}, }, prelude::*, util::{MultiRead, MultiWrite}, }; -pub struct BoundDatagram { +pub(super) struct BoundDatagram { bound_socket: UdpSocket, remote_endpoint: Option, } impl BoundDatagram { - pub fn new(bound_socket: UdpSocket) -> Self { + pub(super) fn new(bound_socket: UdpSocket) -> Self { Self { bound_socket, remote_endpoint: None, } } - pub fn local_endpoint(&self) -> IpEndpoint { + pub(super) fn iface(&self) -> &Arc { + self.bound_socket.iface() + } +} + +impl datagram_common::Bound for BoundDatagram { + type Endpoint = IpEndpoint; + + fn local_endpoint(&self) -> Self::Endpoint { self.bound_socket.local_endpoint().unwrap() } - pub fn remote_endpoint(&self) -> Option<&IpEndpoint> { + fn remote_endpoint(&self) -> Option<&Self::Endpoint> { self.remote_endpoint.as_ref() } - pub fn set_remote_endpoint(&mut self, endpoint: &IpEndpoint) { + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) { self.remote_endpoint = Some(*endpoint) } - pub fn iface(&self) -> &Arc { - self.bound_socket.iface() - } - - pub fn try_recv( + fn try_recv( &self, writer: &mut dyn MultiWrite, _flags: SendRecvFlags, - ) -> Result<(usize, IpEndpoint)> { + ) -> Result<(usize, Self::Endpoint)> { let result = self.bound_socket.recv(|packet, udp_metadata| { let copied_res = writer.write(&mut VmReader::from(packet)); let endpoint = udp_metadata.endpoint; @@ -67,10 +71,10 @@ impl BoundDatagram { } } - pub fn try_send( + fn try_send( &self, reader: &mut dyn MultiRead, - remote: &IpEndpoint, + remote: &Self::Endpoint, _flags: SendRecvFlags, ) -> Result { let result = self @@ -99,7 +103,7 @@ impl BoundDatagram { } } - pub(super) fn check_io_events(&self) -> IoEvents { + fn check_io_events(&self) -> IoEvents { self.bound_socket.raw_with(|socket| { let mut events = IoEvents::empty(); diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 7bad0b5eb..f560a6410 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -3,11 +3,10 @@ use core::sync::atomic::{AtomicBool, Ordering}; use aster_bigtcp::wire::IpEndpoint; -use ostd::sync::PreemptDisabled; -use takeable::Takeable; +use unbound::BindOptions; use self::{bound::BoundDatagram, unbound::UnboundDatagram}; -use super::{common::get_ephemeral_endpoint, UNSPECIFIED_LOCAL_ENDPOINT}; +use super::UNSPECIFIED_LOCAL_ENDPOINT; use crate::{ events::IoEvents, match_sock_option_mut, @@ -15,6 +14,7 @@ use crate::{ options::{Error as SocketError, SocketOption}, private::SocketPrivate, util::{ + datagram_common::{select_remote_and_bind, Bound, Inner}, options::{SetSocketLevelOption, SocketOptionSet}, send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, @@ -48,160 +48,32 @@ impl OptionSet { pub struct DatagramSocket { // Lock order: `inner` first, `options` second - inner: RwLock, PreemptDisabled>, + inner: RwMutex>, options: RwLock, is_nonblocking: AtomicBool, pollee: Pollee, } -enum Inner { - Unbound(UnboundDatagram), - Bound(BoundDatagram), -} - -impl Inner { - fn bind( - self, - endpoint: &IpEndpoint, - can_reuse: bool, - observer: DatagramObserver, - ) -> core::result::Result { - let unbound_datagram = match self { - Inner::Unbound(unbound_datagram) => unbound_datagram, - Inner::Bound(bound_datagram) => { - return Err(( - Error::with_message(Errno::EINVAL, "the socket is already bound to an address"), - Inner::Bound(bound_datagram), - )); - } - }; - - let bound_datagram = match unbound_datagram.bind(endpoint, can_reuse, observer) { - Ok(bound_datagram) => bound_datagram, - Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))), - }; - Ok(bound_datagram) - } - - fn bind_to_ephemeral_endpoint( - self, - remote_endpoint: &IpEndpoint, - observer: DatagramObserver, - ) -> core::result::Result { - if let Inner::Bound(bound_datagram) = self { - return Ok(bound_datagram); - } - - let endpoint = get_ephemeral_endpoint(remote_endpoint); - self.bind(&endpoint, false, observer) - } -} - impl DatagramSocket { pub fn new(is_nonblocking: bool) -> Arc { let unbound_datagram = UnboundDatagram::new(); Arc::new(Self { - inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))), + inner: RwMutex::new(Inner::Unbound(unbound_datagram)), options: RwLock::new(OptionSet::new()), is_nonblocking: AtomicBool::new(is_nonblocking), pollee: Pollee::new(), }) } - fn try_bind_ephemeral(&self, remote_endpoint: &IpEndpoint) -> Result<()> { - // Fast path - if let Inner::Bound(_) = self.inner.read().as_ref() { - return Ok(()); - } - - // Slow path - let mut inner = self.inner.write(); - inner.borrow_result(|owned_inner| { - let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint( - remote_endpoint, - DatagramObserver::new(self.pollee.clone()), - ) { - Ok(bound_datagram) => bound_datagram, - Err((err, err_inner)) => { - return (err_inner, Err(err)); - } - }; - (Inner::Bound(bound_datagram), Ok(())) - }) - } - - /// Selects the remote endpoint and binds if the socket is not bound. - /// - /// The remote endpoint specified in the system call (e.g., `sendto`) argument is preferred, - /// otherwise the connected endpoint of the socket is used. If there are no remote endpoints - /// available, this method will fail with [`EDESTADDRREQ`]. - /// - /// If the remote endpoint is specified but the socket is not bound, this method will try to - /// bind the socket to an ephemeral endpoint. - /// - /// If the above steps succeed, `op` will be called with the bound socket and the selected - /// remote endpoint. - /// - /// [`EDESTADDRREQ`]: crate::error::Errno::EDESTADDRREQ - fn select_remote_and_bind(&self, remote: Option<&IpEndpoint>, op: F) -> Result - where - F: FnOnce(&BoundDatagram, &IpEndpoint) -> Result, - { - let mut inner = self.inner.read(); - - // Not really a loop, since we always break on the first iteration. But we need to use - // `loop` here because we want to use `break` later. - #[expect(clippy::never_loop)] - let bound_datagram = loop { - // Fast path: The socket is already bound. - if let Inner::Bound(bound_datagram) = inner.as_ref() { - break bound_datagram; - } - - // Slow path: Try to bind the socket to an ephemeral endpoint. - drop(inner); - if let Some(remote_endpoint) = remote { - self.try_bind_ephemeral(remote_endpoint)?; - } else { - return_errno_with_message!( - Errno::EDESTADDRREQ, - "the destination address is not specified" - ); - } - inner = self.inner.read(); - - // Now the socket must be bound. - if let Inner::Bound(bound_datagram) = inner.as_ref() { - break bound_datagram; - } - unreachable!("`try_bind_ephemeral` succeeds so the socket cannot be unbound"); - }; - - let remote_endpoint = remote - .or_else(|| bound_datagram.remote_endpoint()) - .ok_or_else(|| { - Error::with_message( - Errno::EDESTADDRREQ, - "the destination address is not specified", - ) - })?; - - op(bound_datagram, remote_endpoint) - } - fn try_recv( &self, writer: &mut dyn MultiWrite, flags: SendRecvFlags, ) -> Result<(usize, SocketAddr)> { - let inner = self.inner.read(); - - let Inner::Bound(bound_datagram) = inner.as_ref() else { - return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); - }; - - let recv_bytes = bound_datagram + let recv_bytes = self + .inner + .read() .try_recv(writer, flags) .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; self.pollee.invalidate(); @@ -215,33 +87,38 @@ impl DatagramSocket { remote: Option<&IpEndpoint>, flags: SendRecvFlags, ) -> Result { - let (sent_bytes, iface_to_poll) = - self.select_remote_and_bind(remote, |bound_datagram, remote_endpoint| { + let (sent_bytes, iface_to_poll) = select_remote_and_bind( + &self.inner, + remote, + || { + let remote_endpoint = remote.ok_or_else(|| { + Error::with_message( + Errno::EDESTADDRREQ, + "the destination address is not specified", + ) + })?; + self.inner + .write() + .bind_ephemeral(remote_endpoint, &self.pollee) + }, + |bound_datagram, remote_endpoint| { let sent_bytes = bound_datagram.try_send(reader, remote_endpoint, flags)?; let iface_to_poll = bound_datagram.iface().clone(); Ok((sent_bytes, iface_to_poll)) - })?; + }, + )?; self.pollee.invalidate(); iface_to_poll.poll(); Ok(sent_bytes) } - - fn check_io_events(&self) -> IoEvents { - let inner = self.inner.read(); - - match inner.as_ref() { - Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(), - Inner::Bound(bound_socket) => bound_socket.check_io_events(), - } - } } impl Pollable for DatagramSocket { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { self.pollee - .poll_with(mask, poller, || self.check_io_events()) + .poll_with(mask, poller, || self.inner.read().check_io_events()) } } @@ -258,57 +135,36 @@ impl SocketPrivate for DatagramSocket { impl Socket for DatagramSocket { fn bind(&self, socket_addr: SocketAddr) -> Result<()> { let endpoint = socket_addr.try_into()?; - - let mut inner = self.inner.write(); let can_reuse = self.options.read().socket.reuse_addr(); - inner.borrow_result(|owned_inner| { - let bound_datagram = match owned_inner.bind( - &endpoint, - can_reuse, - DatagramObserver::new(self.pollee.clone()), - ) { - Ok(bound_datagram) => bound_datagram, - Err((err, err_inner)) => { - return (err_inner, Err(err)); - } - }; - (Inner::Bound(bound_datagram), Ok(())) - }) + + self.inner + .write() + .bind(&endpoint, &self.pollee, BindOptions { can_reuse }) } fn connect(&self, socket_addr: SocketAddr) -> Result<()> { let endpoint = socket_addr.try_into()?; - self.try_bind_ephemeral(&endpoint)?; - - let mut inner = self.inner.write(); - let Inner::Bound(bound_datagram) = inner.as_mut() else { - return_errno_with_message!(Errno::EINVAL, "the socket is not bound") - }; - bound_datagram.set_remote_endpoint(&endpoint); - - Ok(()) + self.inner.write().connect(&endpoint, &self.pollee) } fn addr(&self) -> Result { - let inner = self.inner.read(); - match inner.as_ref() { - Inner::Unbound(_) => Ok(UNSPECIFIED_LOCAL_ENDPOINT.into()), - Inner::Bound(bound_datagram) => Ok(bound_datagram.local_endpoint().into()), - } + let endpoint = self + .inner + .read() + .addr() + .unwrap_or(UNSPECIFIED_LOCAL_ENDPOINT); + + Ok(endpoint.into()) } fn peer_addr(&self) -> Result { - let inner = self.inner.read(); + let endpoint = + *self.inner.read().peer_addr().ok_or_else(|| { + Error::with_message(Errno::ENOTCONN, "the socket is not connected") + })?; - let remote_endpoint = match inner.as_ref() { - Inner::Unbound(_) => None, - Inner::Bound(bound_datagram) => bound_datagram.remote_endpoint(), - }; - - remote_endpoint - .map(|endpoint| (*endpoint).into()) - .ok_or_else(|| Error::with_message(Errno::ENOTCONN, "the socket is not connected")) + Ok(endpoint.into()) } fn sendmsg( @@ -378,11 +234,11 @@ impl Socket for DatagramSocket { let inner = self.inner.read(); let mut options = self.options.write(); - match options.socket.set_option(option, inner.as_ref()) { + match options.socket.set_option(option, &*inner) { Err(e) => Err(e), Ok(need_iface_poll) => { let iface_to_poll = need_iface_poll - .then(|| match inner.as_ref() { + .then(|| match &*inner { Inner::Unbound(_) => None, Inner::Bound(bound_datagram) => Some(bound_datagram.iface().clone()), }) @@ -401,4 +257,4 @@ impl Socket for DatagramSocket { } } -impl SetSocketLevelOption for Inner {} +impl SetSocketLevelOption for Inner {} diff --git a/kernel/src/net/socket/ip/datagram/unbound.rs b/kernel/src/net/socket/ip/datagram/unbound.rs index 2ba33dea7..a372652da 100644 --- a/kernel/src/net/socket/ip/datagram/unbound.rs +++ b/kernel/src/net/socket/ip/datagram/unbound.rs @@ -3,39 +3,65 @@ use aster_bigtcp::{socket::UdpSocket, wire::IpEndpoint}; use super::{bound::BoundDatagram, DatagramObserver}; -use crate::{events::IoEvents, net::socket::ip::common::bind_port, prelude::*}; +use crate::{ + events::IoEvents, + net::socket::{ + ip::common::{bind_port, get_ephemeral_endpoint}, + util::datagram_common, + }, + prelude::*, + process::signal::Pollee, +}; -pub struct UnboundDatagram { +pub(super) struct UnboundDatagram { _private: (), } impl UnboundDatagram { - pub fn new() -> Self { + pub(super) fn new() -> Self { Self { _private: () } } +} - pub fn bind( - self, - endpoint: &IpEndpoint, - can_reuse: bool, - observer: DatagramObserver, - ) -> core::result::Result { - let bound_port = match bind_port(endpoint, can_reuse) { - Ok(bound_port) => bound_port, - Err(err) => return Err((err, self)), - }; +pub(super) struct BindOptions { + pub(super) can_reuse: bool, +} - let bound_socket = match UdpSocket::new_bind(bound_port, observer) { - Ok(bound_socket) => bound_socket, - Err((_, err)) => { - unreachable!("`new_bind fails with {:?}, which should not happen", err) - } - }; +impl datagram_common::Unbound for UnboundDatagram { + type Endpoint = IpEndpoint; + type BindOptions = BindOptions; + + type Bound = BoundDatagram; + + fn bind( + &mut self, + endpoint: &Self::Endpoint, + pollee: &Pollee, + options: BindOptions, + ) -> Result { + let bound_port = bind_port(endpoint, options.can_reuse)?; + + let bound_socket = + match UdpSocket::new_bind(bound_port, DatagramObserver::new(pollee.clone())) { + Ok(bound_socket) => bound_socket, + Err((_, err)) => { + unreachable!("`new_bind` fails with {:?}, which should not happen", err) + } + }; Ok(BoundDatagram::new(bound_socket)) } - pub(super) fn check_io_events(&self) -> IoEvents { + fn bind_ephemeral( + &mut self, + remote_endpoint: &Self::Endpoint, + pollee: &Pollee, + ) -> Result { + let endpoint = get_ephemeral_endpoint(remote_endpoint); + self.bind(&endpoint, pollee, BindOptions { can_reuse: false }) + } + + fn check_io_events(&self) -> IoEvents { IoEvents::OUT } } diff --git a/kernel/src/net/socket/util/datagram_common.rs b/kernel/src/net/socket/util/datagram_common.rs new file mode 100644 index 000000000..c9bfc971f --- /dev/null +++ b/kernel/src/net/socket/util/datagram_common.rs @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MPL-2.0 + +use ostd::sync::RwMutex; + +use super::send_recv_flags::SendRecvFlags; +use crate::{ + events::IoEvents, + process::signal::Pollee, + return_errno_with_message, + util::{MultiRead, MultiWrite}, + Errno, Error, Result, +}; + +pub trait Unbound { + type Endpoint; + type BindOptions; + + type Bound; + + fn bind( + &mut self, + endpoint: &Self::Endpoint, + pollee: &Pollee, + options: Self::BindOptions, + ) -> Result; + fn bind_ephemeral( + &mut self, + remote_endpoint: &Self::Endpoint, + pollee: &Pollee, + ) -> Result; + + fn check_io_events(&self) -> IoEvents; +} + +pub trait Bound { + type Endpoint; + + fn local_endpoint(&self) -> Self::Endpoint; + fn remote_endpoint(&self) -> Option<&Self::Endpoint>; + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint); + + fn try_recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, Self::Endpoint)>; + fn try_send( + &self, + reader: &mut dyn MultiRead, + remote: &Self::Endpoint, + flags: SendRecvFlags, + ) -> Result; + + fn check_io_events(&self) -> IoEvents; +} + +pub enum Inner { + Unbound(UnboundSocket), + Bound(BoundSocket), +} + +impl Inner +where + UnboundSocket: Unbound, + BoundSocket: Bound, +{ + pub fn bind( + &mut self, + endpoint: &UnboundSocket::Endpoint, + pollee: &Pollee, + options: UnboundSocket::BindOptions, + ) -> Result<()> { + let unbound_datagram = match self { + Inner::Unbound(unbound_datagram) => unbound_datagram, + Inner::Bound(_) => { + return_errno_with_message!( + Errno::EINVAL, + "the socket is already bound to an address" + ) + } + }; + + let bound_datagram = unbound_datagram.bind(endpoint, pollee, options)?; + *self = Inner::Bound(bound_datagram); + + Ok(()) + } + + pub fn bind_ephemeral( + &mut self, + remote_endpoint: &UnboundSocket::Endpoint, + pollee: &Pollee, + ) -> Result<()> { + let unbound_datagram = match self { + Inner::Unbound(unbound_datagram) => unbound_datagram, + Inner::Bound(_) => return Ok(()), + }; + + let bound_datagram = unbound_datagram.bind_ephemeral(remote_endpoint, pollee)?; + *self = Inner::Bound(bound_datagram); + + Ok(()) + } + + pub fn connect( + &mut self, + remote_endpoint: &UnboundSocket::Endpoint, + pollee: &Pollee, + ) -> Result<()> { + self.bind_ephemeral(remote_endpoint, pollee)?; + + let bound_datagram = match self { + Inner::Unbound(_) => { + unreachable!( + "`bind_to_ephemeral_endpoint` succeeds so the socket cannot be unbound" + ); + } + Inner::Bound(bound_datagram) => bound_datagram, + }; + bound_datagram.set_remote_endpoint(remote_endpoint); + + Ok(()) + } + + pub fn addr(&self) -> Option { + match self { + Inner::Unbound(_) => None, + Inner::Bound(bound_datagram) => Some(bound_datagram.local_endpoint()), + } + } + + pub fn peer_addr(&self) -> Option<&UnboundSocket::Endpoint> { + match self { + Inner::Unbound(_) => None, + Inner::Bound(bound_datagram) => bound_datagram.remote_endpoint(), + } + } + + pub fn try_recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, UnboundSocket::Endpoint)> { + match self { + Inner::Unbound(_) => { + return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); + } + Inner::Bound(bound_datagram) => bound_datagram.try_recv(writer, flags), + } + } + + // If you're looking for `try_send`, there isn't one. Use `select_remote_and_bind` below and + // call `Bound::try_send` directly. + + pub fn check_io_events(&self) -> IoEvents { + match self { + Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(), + Inner::Bound(bound_datagram) => bound_datagram.check_io_events(), + } + } +} + +/// Selects the remote endpoint and binds if the socket is not bound. +/// +/// The remote endpoint specified in the system call (e.g., `sendto`) argument is preferred, +/// otherwise the connected endpoint of the socket is used. If there are no remote endpoints +/// available, this method will fail with [`EDESTADDRREQ`]. +/// +/// If the remote endpoint is specified but the socket is not bound, this method will try to +/// bind the socket to an ephemeral endpoint. +/// +/// If the above steps succeed, `op` will be called with the bound socket and the selected +/// remote endpoint. +/// +/// [`EDESTADDRREQ`]: crate::error::Errno::EDESTADDRREQ +pub fn select_remote_and_bind( + inner_mutex: &RwMutex>, + remote: Option<&UnboundSocket::Endpoint>, + bind_ephemeral: B, + op: F, +) -> Result +where + UnboundSocket: Unbound, + BoundSocket: Bound, + B: FnOnce() -> Result<()>, + F: FnOnce(&BoundSocket, &UnboundSocket::Endpoint) -> Result, +{ + let mut inner = inner_mutex.read(); + + // Not really a loop, since we always break on the first iteration. But we need to use + // `loop` here because we want to use `break` later. + #[expect(clippy::never_loop)] + let bound_datagram = loop { + // Fast path: The socket is already bound. + if let Inner::Bound(bound_datagram) = &*inner { + break bound_datagram; + } + + // Slow path: Try to bind the socket to an ephemeral endpoint. + drop(inner); + bind_ephemeral()?; + inner = inner_mutex.read(); + + // Now the socket must be bound. + if let Inner::Bound(bound_datagram) = &*inner { + break bound_datagram; + } + unreachable!("`try_bind_ephemeral` succeeds so the socket cannot be unbound"); + }; + + let remote_endpoint = remote + .or_else(|| bound_datagram.remote_endpoint()) + .ok_or_else(|| { + Error::with_message( + Errno::EDESTADDRREQ, + "the destination address is not specified", + ) + })?; + + op(bound_datagram, remote_endpoint) +} diff --git a/kernel/src/net/socket/util/mod.rs b/kernel/src/net/socket/util/mod.rs index 72694110e..39706922e 100644 --- a/kernel/src/net/socket/util/mod.rs +++ b/kernel/src/net/socket/util/mod.rs @@ -1,5 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 +pub mod datagram_common; mod message_header; pub mod options; pub mod send_recv_flags;