Create datagram_common and use it in UDP

This commit is contained in:
Ruihan Li
2025-04-20 16:12:14 +08:00
committed by Tate, Hongliang Tian
parent 96e62b8fa5
commit c9f939bcc4
5 changed files with 333 additions and 225 deletions

View File

@ -9,46 +9,50 @@ use crate::{
events::IoEvents,
net::{
iface::{Iface, UdpSocket},
socket::util::send_recv_flags::SendRecvFlags,
socket::util::{datagram_common, send_recv_flags::SendRecvFlags},
},
prelude::*,
util::{MultiRead, MultiWrite},
};
pub struct BoundDatagram {
pub(super) struct BoundDatagram {
bound_socket: UdpSocket,
remote_endpoint: Option<IpEndpoint>,
}
impl BoundDatagram {
pub fn new(bound_socket: UdpSocket) -> Self {
pub(super) fn new(bound_socket: UdpSocket) -> Self {
Self {
bound_socket,
remote_endpoint: None,
}
}
pub fn local_endpoint(&self) -> IpEndpoint {
pub(super) fn iface(&self) -> &Arc<Iface> {
self.bound_socket.iface()
}
}
impl datagram_common::Bound for BoundDatagram {
type Endpoint = IpEndpoint;
fn local_endpoint(&self) -> Self::Endpoint {
self.bound_socket.local_endpoint().unwrap()
}
pub fn remote_endpoint(&self) -> Option<&IpEndpoint> {
fn remote_endpoint(&self) -> Option<&Self::Endpoint> {
self.remote_endpoint.as_ref()
}
pub fn set_remote_endpoint(&mut self, endpoint: &IpEndpoint) {
fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) {
self.remote_endpoint = Some(*endpoint)
}
pub fn iface(&self) -> &Arc<Iface> {
self.bound_socket.iface()
}
pub fn try_recv(
fn try_recv(
&self,
writer: &mut dyn MultiWrite,
_flags: SendRecvFlags,
) -> Result<(usize, IpEndpoint)> {
) -> Result<(usize, Self::Endpoint)> {
let result = self.bound_socket.recv(|packet, udp_metadata| {
let copied_res = writer.write(&mut VmReader::from(packet));
let endpoint = udp_metadata.endpoint;
@ -67,10 +71,10 @@ impl BoundDatagram {
}
}
pub fn try_send(
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: &IpEndpoint,
remote: &Self::Endpoint,
_flags: SendRecvFlags,
) -> Result<usize> {
let result = self
@ -99,7 +103,7 @@ impl BoundDatagram {
}
}
pub(super) fn check_io_events(&self) -> IoEvents {
fn check_io_events(&self) -> IoEvents {
self.bound_socket.raw_with(|socket| {
let mut events = IoEvents::empty();

View File

@ -3,11 +3,10 @@
use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::wire::IpEndpoint;
use ostd::sync::PreemptDisabled;
use takeable::Takeable;
use unbound::BindOptions;
use self::{bound::BoundDatagram, unbound::UnboundDatagram};
use super::{common::get_ephemeral_endpoint, UNSPECIFIED_LOCAL_ENDPOINT};
use super::UNSPECIFIED_LOCAL_ENDPOINT;
use crate::{
events::IoEvents,
match_sock_option_mut,
@ -15,6 +14,7 @@ use crate::{
options::{Error as SocketError, SocketOption},
private::SocketPrivate,
util::{
datagram_common::{select_remote_and_bind, Bound, Inner},
options::{SetSocketLevelOption, SocketOptionSet},
send_recv_flags::SendRecvFlags,
socket_addr::SocketAddr,
@ -48,160 +48,32 @@ impl OptionSet {
pub struct DatagramSocket {
// Lock order: `inner` first, `options` second
inner: RwLock<Takeable<Inner>, PreemptDisabled>,
inner: RwMutex<Inner<UnboundDatagram, BoundDatagram>>,
options: RwLock<OptionSet>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
enum Inner {
Unbound(UnboundDatagram),
Bound(BoundDatagram),
}
impl Inner {
fn bind(
self,
endpoint: &IpEndpoint,
can_reuse: bool,
observer: DatagramObserver,
) -> core::result::Result<BoundDatagram, (Error, Self)> {
let unbound_datagram = match self {
Inner::Unbound(unbound_datagram) => unbound_datagram,
Inner::Bound(bound_datagram) => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
Inner::Bound(bound_datagram),
));
}
};
let bound_datagram = match unbound_datagram.bind(endpoint, can_reuse, observer) {
Ok(bound_datagram) => bound_datagram,
Err((err, unbound_datagram)) => return Err((err, Inner::Unbound(unbound_datagram))),
};
Ok(bound_datagram)
}
fn bind_to_ephemeral_endpoint(
self,
remote_endpoint: &IpEndpoint,
observer: DatagramObserver,
) -> core::result::Result<BoundDatagram, (Error, Self)> {
if let Inner::Bound(bound_datagram) = self {
return Ok(bound_datagram);
}
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint, false, observer)
}
}
impl DatagramSocket {
pub fn new(is_nonblocking: bool) -> Arc<Self> {
let unbound_datagram = UnboundDatagram::new();
Arc::new(Self {
inner: RwLock::new(Takeable::new(Inner::Unbound(unbound_datagram))),
inner: RwMutex::new(Inner::Unbound(unbound_datagram)),
options: RwLock::new(OptionSet::new()),
is_nonblocking: AtomicBool::new(is_nonblocking),
pollee: Pollee::new(),
})
}
fn try_bind_ephemeral(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
// Fast path
if let Inner::Bound(_) = self.inner.read().as_ref() {
return Ok(());
}
// Slow path
let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind_to_ephemeral_endpoint(
remote_endpoint,
DatagramObserver::new(self.pollee.clone()),
) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
return (err_inner, Err(err));
}
};
(Inner::Bound(bound_datagram), Ok(()))
})
}
/// Selects the remote endpoint and binds if the socket is not bound.
///
/// The remote endpoint specified in the system call (e.g., `sendto`) argument is preferred,
/// otherwise the connected endpoint of the socket is used. If there are no remote endpoints
/// available, this method will fail with [`EDESTADDRREQ`].
///
/// If the remote endpoint is specified but the socket is not bound, this method will try to
/// bind the socket to an ephemeral endpoint.
///
/// If the above steps succeed, `op` will be called with the bound socket and the selected
/// remote endpoint.
///
/// [`EDESTADDRREQ`]: crate::error::Errno::EDESTADDRREQ
fn select_remote_and_bind<F, R>(&self, remote: Option<&IpEndpoint>, op: F) -> Result<R>
where
F: FnOnce(&BoundDatagram, &IpEndpoint) -> Result<R>,
{
let mut inner = self.inner.read();
// Not really a loop, since we always break on the first iteration. But we need to use
// `loop` here because we want to use `break` later.
#[expect(clippy::never_loop)]
let bound_datagram = loop {
// Fast path: The socket is already bound.
if let Inner::Bound(bound_datagram) = inner.as_ref() {
break bound_datagram;
}
// Slow path: Try to bind the socket to an ephemeral endpoint.
drop(inner);
if let Some(remote_endpoint) = remote {
self.try_bind_ephemeral(remote_endpoint)?;
} else {
return_errno_with_message!(
Errno::EDESTADDRREQ,
"the destination address is not specified"
);
}
inner = self.inner.read();
// Now the socket must be bound.
if let Inner::Bound(bound_datagram) = inner.as_ref() {
break bound_datagram;
}
unreachable!("`try_bind_ephemeral` succeeds so the socket cannot be unbound");
};
let remote_endpoint = remote
.or_else(|| bound_datagram.remote_endpoint())
.ok_or_else(|| {
Error::with_message(
Errno::EDESTADDRREQ,
"the destination address is not specified",
)
})?;
op(bound_datagram, remote_endpoint)
}
fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, SocketAddr)> {
let inner = self.inner.read();
let Inner::Bound(bound_datagram) = inner.as_ref() else {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound");
};
let recv_bytes = bound_datagram
let recv_bytes = self
.inner
.read()
.try_recv(writer, flags)
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?;
self.pollee.invalidate();
@ -215,33 +87,38 @@ impl DatagramSocket {
remote: Option<&IpEndpoint>,
flags: SendRecvFlags,
) -> Result<usize> {
let (sent_bytes, iface_to_poll) =
self.select_remote_and_bind(remote, |bound_datagram, remote_endpoint| {
let (sent_bytes, iface_to_poll) = select_remote_and_bind(
&self.inner,
remote,
|| {
let remote_endpoint = remote.ok_or_else(|| {
Error::with_message(
Errno::EDESTADDRREQ,
"the destination address is not specified",
)
})?;
self.inner
.write()
.bind_ephemeral(remote_endpoint, &self.pollee)
},
|bound_datagram, remote_endpoint| {
let sent_bytes = bound_datagram.try_send(reader, remote_endpoint, flags)?;
let iface_to_poll = bound_datagram.iface().clone();
Ok((sent_bytes, iface_to_poll))
})?;
},
)?;
self.pollee.invalidate();
iface_to_poll.poll();
Ok(sent_bytes)
}
fn check_io_events(&self) -> IoEvents {
let inner = self.inner.read();
match inner.as_ref() {
Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(),
Inner::Bound(bound_socket) => bound_socket.check_io_events(),
}
}
}
impl Pollable for DatagramSocket {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee
.poll_with(mask, poller, || self.check_io_events())
.poll_with(mask, poller, || self.inner.read().check_io_events())
}
}
@ -258,57 +135,36 @@ impl SocketPrivate for DatagramSocket {
impl Socket for DatagramSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
let mut inner = self.inner.write();
let can_reuse = self.options.read().socket.reuse_addr();
inner.borrow_result(|owned_inner| {
let bound_datagram = match owned_inner.bind(
&endpoint,
can_reuse,
DatagramObserver::new(self.pollee.clone()),
) {
Ok(bound_datagram) => bound_datagram,
Err((err, err_inner)) => {
return (err_inner, Err(err));
}
};
(Inner::Bound(bound_datagram), Ok(()))
})
self.inner
.write()
.bind(&endpoint, &self.pollee, BindOptions { can_reuse })
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.try_bind_ephemeral(&endpoint)?;
let mut inner = self.inner.write();
let Inner::Bound(bound_datagram) = inner.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound")
};
bound_datagram.set_remote_endpoint(&endpoint);
Ok(())
self.inner.write().connect(&endpoint, &self.pollee)
}
fn addr(&self) -> Result<SocketAddr> {
let inner = self.inner.read();
match inner.as_ref() {
Inner::Unbound(_) => Ok(UNSPECIFIED_LOCAL_ENDPOINT.into()),
Inner::Bound(bound_datagram) => Ok(bound_datagram.local_endpoint().into()),
}
let endpoint = self
.inner
.read()
.addr()
.unwrap_or(UNSPECIFIED_LOCAL_ENDPOINT);
Ok(endpoint.into())
}
fn peer_addr(&self) -> Result<SocketAddr> {
let inner = self.inner.read();
let endpoint =
*self.inner.read().peer_addr().ok_or_else(|| {
Error::with_message(Errno::ENOTCONN, "the socket is not connected")
})?;
let remote_endpoint = match inner.as_ref() {
Inner::Unbound(_) => None,
Inner::Bound(bound_datagram) => bound_datagram.remote_endpoint(),
};
remote_endpoint
.map(|endpoint| (*endpoint).into())
.ok_or_else(|| Error::with_message(Errno::ENOTCONN, "the socket is not connected"))
Ok(endpoint.into())
}
fn sendmsg(
@ -378,11 +234,11 @@ impl Socket for DatagramSocket {
let inner = self.inner.read();
let mut options = self.options.write();
match options.socket.set_option(option, inner.as_ref()) {
match options.socket.set_option(option, &*inner) {
Err(e) => Err(e),
Ok(need_iface_poll) => {
let iface_to_poll = need_iface_poll
.then(|| match inner.as_ref() {
.then(|| match &*inner {
Inner::Unbound(_) => None,
Inner::Bound(bound_datagram) => Some(bound_datagram.iface().clone()),
})
@ -401,4 +257,4 @@ impl Socket for DatagramSocket {
}
}
impl SetSocketLevelOption for Inner {}
impl SetSocketLevelOption for Inner<UnboundDatagram, BoundDatagram> {}

View File

@ -3,39 +3,65 @@
use aster_bigtcp::{socket::UdpSocket, wire::IpEndpoint};
use super::{bound::BoundDatagram, DatagramObserver};
use crate::{events::IoEvents, net::socket::ip::common::bind_port, prelude::*};
use crate::{
events::IoEvents,
net::socket::{
ip::common::{bind_port, get_ephemeral_endpoint},
util::datagram_common,
},
prelude::*,
process::signal::Pollee,
};
pub struct UnboundDatagram {
pub(super) struct UnboundDatagram {
_private: (),
}
impl UnboundDatagram {
pub fn new() -> Self {
pub(super) fn new() -> Self {
Self { _private: () }
}
}
pub fn bind(
self,
endpoint: &IpEndpoint,
can_reuse: bool,
observer: DatagramObserver,
) -> core::result::Result<BoundDatagram, (Error, Self)> {
let bound_port = match bind_port(endpoint, can_reuse) {
Ok(bound_port) => bound_port,
Err(err) => return Err((err, self)),
};
pub(super) struct BindOptions {
pub(super) can_reuse: bool,
}
let bound_socket = match UdpSocket::new_bind(bound_port, observer) {
Ok(bound_socket) => bound_socket,
Err((_, err)) => {
unreachable!("`new_bind fails with {:?}, which should not happen", err)
}
};
impl datagram_common::Unbound for UnboundDatagram {
type Endpoint = IpEndpoint;
type BindOptions = BindOptions;
type Bound = BoundDatagram;
fn bind(
&mut self,
endpoint: &Self::Endpoint,
pollee: &Pollee,
options: BindOptions,
) -> Result<Self::Bound> {
let bound_port = bind_port(endpoint, options.can_reuse)?;
let bound_socket =
match UdpSocket::new_bind(bound_port, DatagramObserver::new(pollee.clone())) {
Ok(bound_socket) => bound_socket,
Err((_, err)) => {
unreachable!("`new_bind` fails with {:?}, which should not happen", err)
}
};
Ok(BoundDatagram::new(bound_socket))
}
pub(super) fn check_io_events(&self) -> IoEvents {
fn bind_ephemeral(
&mut self,
remote_endpoint: &Self::Endpoint,
pollee: &Pollee,
) -> Result<Self::Bound> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint, pollee, BindOptions { can_reuse: false })
}
fn check_io_events(&self) -> IoEvents {
IoEvents::OUT
}
}

View File

@ -0,0 +1,221 @@
// SPDX-License-Identifier: MPL-2.0
use ostd::sync::RwMutex;
use super::send_recv_flags::SendRecvFlags;
use crate::{
events::IoEvents,
process::signal::Pollee,
return_errno_with_message,
util::{MultiRead, MultiWrite},
Errno, Error, Result,
};
pub trait Unbound {
type Endpoint;
type BindOptions;
type Bound;
fn bind(
&mut self,
endpoint: &Self::Endpoint,
pollee: &Pollee,
options: Self::BindOptions,
) -> Result<Self::Bound>;
fn bind_ephemeral(
&mut self,
remote_endpoint: &Self::Endpoint,
pollee: &Pollee,
) -> Result<Self::Bound>;
fn check_io_events(&self) -> IoEvents;
}
pub trait Bound {
type Endpoint;
fn local_endpoint(&self) -> Self::Endpoint;
fn remote_endpoint(&self) -> Option<&Self::Endpoint>;
fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint);
fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, Self::Endpoint)>;
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: &Self::Endpoint,
flags: SendRecvFlags,
) -> Result<usize>;
fn check_io_events(&self) -> IoEvents;
}
pub enum Inner<UnboundSocket, BoundSocket> {
Unbound(UnboundSocket),
Bound(BoundSocket),
}
impl<UnboundSocket, BoundSocket> Inner<UnboundSocket, BoundSocket>
where
UnboundSocket: Unbound<Endpoint = BoundSocket::Endpoint, Bound = BoundSocket>,
BoundSocket: Bound,
{
pub fn bind(
&mut self,
endpoint: &UnboundSocket::Endpoint,
pollee: &Pollee,
options: UnboundSocket::BindOptions,
) -> Result<()> {
let unbound_datagram = match self {
Inner::Unbound(unbound_datagram) => unbound_datagram,
Inner::Bound(_) => {
return_errno_with_message!(
Errno::EINVAL,
"the socket is already bound to an address"
)
}
};
let bound_datagram = unbound_datagram.bind(endpoint, pollee, options)?;
*self = Inner::Bound(bound_datagram);
Ok(())
}
pub fn bind_ephemeral(
&mut self,
remote_endpoint: &UnboundSocket::Endpoint,
pollee: &Pollee,
) -> Result<()> {
let unbound_datagram = match self {
Inner::Unbound(unbound_datagram) => unbound_datagram,
Inner::Bound(_) => return Ok(()),
};
let bound_datagram = unbound_datagram.bind_ephemeral(remote_endpoint, pollee)?;
*self = Inner::Bound(bound_datagram);
Ok(())
}
pub fn connect(
&mut self,
remote_endpoint: &UnboundSocket::Endpoint,
pollee: &Pollee,
) -> Result<()> {
self.bind_ephemeral(remote_endpoint, pollee)?;
let bound_datagram = match self {
Inner::Unbound(_) => {
unreachable!(
"`bind_to_ephemeral_endpoint` succeeds so the socket cannot be unbound"
);
}
Inner::Bound(bound_datagram) => bound_datagram,
};
bound_datagram.set_remote_endpoint(remote_endpoint);
Ok(())
}
pub fn addr(&self) -> Option<UnboundSocket::Endpoint> {
match self {
Inner::Unbound(_) => None,
Inner::Bound(bound_datagram) => Some(bound_datagram.local_endpoint()),
}
}
pub fn peer_addr(&self) -> Option<&UnboundSocket::Endpoint> {
match self {
Inner::Unbound(_) => None,
Inner::Bound(bound_datagram) => bound_datagram.remote_endpoint(),
}
}
pub fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, UnboundSocket::Endpoint)> {
match self {
Inner::Unbound(_) => {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound");
}
Inner::Bound(bound_datagram) => bound_datagram.try_recv(writer, flags),
}
}
// If you're looking for `try_send`, there isn't one. Use `select_remote_and_bind` below and
// call `Bound::try_send` directly.
pub fn check_io_events(&self) -> IoEvents {
match self {
Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(),
Inner::Bound(bound_datagram) => bound_datagram.check_io_events(),
}
}
}
/// Selects the remote endpoint and binds if the socket is not bound.
///
/// The remote endpoint specified in the system call (e.g., `sendto`) argument is preferred,
/// otherwise the connected endpoint of the socket is used. If there are no remote endpoints
/// available, this method will fail with [`EDESTADDRREQ`].
///
/// If the remote endpoint is specified but the socket is not bound, this method will try to
/// bind the socket to an ephemeral endpoint.
///
/// If the above steps succeed, `op` will be called with the bound socket and the selected
/// remote endpoint.
///
/// [`EDESTADDRREQ`]: crate::error::Errno::EDESTADDRREQ
pub fn select_remote_and_bind<UnboundSocket, BoundSocket, B, F, R>(
inner_mutex: &RwMutex<Inner<UnboundSocket, BoundSocket>>,
remote: Option<&UnboundSocket::Endpoint>,
bind_ephemeral: B,
op: F,
) -> Result<R>
where
UnboundSocket: Unbound<Endpoint = BoundSocket::Endpoint, Bound = BoundSocket>,
BoundSocket: Bound,
B: FnOnce() -> Result<()>,
F: FnOnce(&BoundSocket, &UnboundSocket::Endpoint) -> Result<R>,
{
let mut inner = inner_mutex.read();
// Not really a loop, since we always break on the first iteration. But we need to use
// `loop` here because we want to use `break` later.
#[expect(clippy::never_loop)]
let bound_datagram = loop {
// Fast path: The socket is already bound.
if let Inner::Bound(bound_datagram) = &*inner {
break bound_datagram;
}
// Slow path: Try to bind the socket to an ephemeral endpoint.
drop(inner);
bind_ephemeral()?;
inner = inner_mutex.read();
// Now the socket must be bound.
if let Inner::Bound(bound_datagram) = &*inner {
break bound_datagram;
}
unreachable!("`try_bind_ephemeral` succeeds so the socket cannot be unbound");
};
let remote_endpoint = remote
.or_else(|| bound_datagram.remote_endpoint())
.ok_or_else(|| {
Error::with_message(
Errno::EDESTADDRREQ,
"the destination address is not specified",
)
})?;
op(bound_datagram, remote_endpoint)
}

View File

@ -1,5 +1,6 @@
// SPDX-License-Identifier: MPL-2.0
pub mod datagram_common;
mod message_header;
pub mod options;
pub mod send_recv_flags;