Use datagram_common in netlink

This commit is contained in:
Ruihan Li 2025-04-20 17:30:39 +08:00 committed by Tate, Hongliang Tian
parent c9f939bcc4
commit 26253829bb
4 changed files with 126 additions and 115 deletions

View File

@ -10,6 +10,7 @@ use crate::{
message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle, message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle,
NetlinkSocketAddr, NetlinkSocketAddr,
}, },
util::datagram_common,
SendRecvFlags, SendRecvFlags,
}, },
prelude::*, prelude::*,
@ -18,6 +19,7 @@ use crate::{
pub(super) struct BoundNetlinkRoute { pub(super) struct BoundNetlinkRoute {
handle: BoundHandle, handle: BoundHandle,
remote_addr: NetlinkSocketAddr,
receive_queue: Mutex<VecDeque<RtnlMessage>>, receive_queue: Mutex<VecDeque<RtnlMessage>>,
} }
@ -25,18 +27,31 @@ impl BoundNetlinkRoute {
pub(super) const fn new(handle: BoundHandle) -> Self { pub(super) const fn new(handle: BoundHandle) -> Self {
Self { Self {
handle, handle,
remote_addr: NetlinkSocketAddr::new_unspecified(),
receive_queue: Mutex::new(VecDeque::new()), receive_queue: Mutex::new(VecDeque::new()),
} }
} }
}
pub(super) const fn addr(&self) -> NetlinkSocketAddr { impl datagram_common::Bound for BoundNetlinkRoute {
type Endpoint = NetlinkSocketAddr;
fn local_endpoint(&self) -> Self::Endpoint {
self.handle.addr() self.handle.addr()
} }
pub(super) fn try_send( fn remote_endpoint(&self) -> Option<&Self::Endpoint> {
Some(&self.remote_addr)
}
fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) {
self.remote_addr = *endpoint;
}
fn try_send(
&self, &self,
reader: &mut dyn MultiRead, reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>, remote: &Self::Endpoint,
flags: SendRecvFlags, flags: SendRecvFlags,
) -> Result<usize> { ) -> Result<usize> {
// TODO: Deal with flags // TODO: Deal with flags
@ -44,16 +59,12 @@ impl BoundNetlinkRoute {
warn!("unsupported flags: {:?}", flags); warn!("unsupported flags: {:?}", flags);
} }
if let Some(remote) = remote { // TODO: Further check whether other socket address can be supported.
// TODO: Further check whether other socket address can be supported. if *remote != NetlinkSocketAddr::new_unspecified() {
if *remote != NetlinkSocketAddr::new_unspecified() { return_errno_with_message!(
return_errno_with_message!( Errno::ECONNREFUSED,
Errno::ECONNREFUSED, "sending netlink route messages to user space is not supported"
"sending netlink route messages to user space is not supported" );
);
}
} else {
// TODO: We should use the connected remote address, if any.
} }
let mut nlmsg = { let mut nlmsg = {
@ -75,7 +86,7 @@ impl BoundNetlinkRoute {
} }
}; };
let local_port = self.addr().port(); let local_port = self.handle.port();
for segment in nlmsg.segments_mut() { for segment in nlmsg.segments_mut() {
// The header's PID should be the sender's port ID. // The header's PID should be the sender's port ID.
// However, the sender can also leave it unspecified. // However, the sender can also leave it unspecified.
@ -93,7 +104,7 @@ impl BoundNetlinkRoute {
Ok(nlmsg.total_len()) Ok(nlmsg.total_len())
} }
pub(super) fn try_receive( fn try_recv(
&self, &self,
writer: &mut dyn MultiWrite, writer: &mut dyn MultiWrite,
flags: SendRecvFlags, flags: SendRecvFlags,
@ -126,7 +137,7 @@ impl BoundNetlinkRoute {
Ok((len, remote)) Ok((len, remote))
} }
pub(super) fn check_io_events(&self) -> IoEvents { fn check_io_events(&self) -> IoEvents {
let mut events = IoEvents::OUT; let mut events = IoEvents::OUT;
let receive_queue = self.receive_queue.lock(); let receive_queue = self.receive_queue.lock();

View File

@ -5,15 +5,16 @@
use core::sync::atomic::{AtomicBool, Ordering}; use core::sync::atomic::{AtomicBool, Ordering};
use bound::BoundNetlinkRoute; use bound::BoundNetlinkRoute;
use takeable::Takeable;
use unbound::UnboundNetlinkRoute; use unbound::UnboundNetlinkRoute;
use super::NetlinkSocketAddr; use super::NetlinkSocketAddr;
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::socket::{ net::socket::{
options::SocketOption, private::SocketPrivate, MessageHeader, SendRecvFlags, Socket, options::SocketOption,
SocketAddr, private::SocketPrivate,
util::datagram_common::{select_remote_and_bind, Bound, Inner},
MessageHeader, SendRecvFlags, Socket, SocketAddr,
}, },
prelude::*, prelude::*,
process::signal::{PollHandle, Pollable, Pollee}, process::signal::{PollHandle, Pollable, Pollee},
@ -26,96 +27,94 @@ mod message;
mod unbound; mod unbound;
pub struct NetlinkRouteSocket { pub struct NetlinkRouteSocket {
inner: RwMutex<Inner<UnboundNetlinkRoute, BoundNetlinkRoute>>,
is_nonblocking: AtomicBool, is_nonblocking: AtomicBool,
pollee: Pollee, pollee: Pollee,
inner: RwMutex<Takeable<Inner>>,
}
enum Inner {
Unbound(UnboundNetlinkRoute),
Bound(BoundNetlinkRoute),
} }
impl NetlinkRouteSocket { impl NetlinkRouteSocket {
pub fn new(is_nonblocking: bool) -> Self { pub fn new(is_nonblocking: bool) -> Self {
let unbound = UnboundNetlinkRoute::new();
Self { Self {
inner: RwMutex::new(Inner::Unbound(unbound)),
is_nonblocking: AtomicBool::new(is_nonblocking), is_nonblocking: AtomicBool::new(is_nonblocking),
pollee: Pollee::new(), pollee: Pollee::new(),
inner: RwMutex::new(Takeable::new(Inner::Unbound(UnboundNetlinkRoute::new()))),
} }
} }
fn try_receive(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, NetlinkSocketAddr)> {
let inner = self.inner.read();
let bound = match inner.as_ref() {
Inner::Unbound(_) => {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound")
}
Inner::Bound(bound_netlink_route) => bound_netlink_route,
};
let received = bound.try_receive(writer, flags)?;
self.pollee.invalidate();
Ok(received)
}
fn try_send( fn try_send(
&self, &self,
reader: &mut dyn MultiRead, reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>, remote: Option<&NetlinkSocketAddr>,
flags: SendRecvFlags, flags: SendRecvFlags,
) -> Result<usize> { ) -> Result<usize> {
let inner = self.inner.read(); let sent_bytes = select_remote_and_bind(
&self.inner,
let bound = match inner.as_ref() { remote,
Inner::Unbound(_) => todo!(), || {
Inner::Bound(bound) => bound, self.inner
}; .write()
.bind_ephemeral(&NetlinkSocketAddr::new_unspecified(), &self.pollee)
let sent_bytes = bound.try_send(reader, remote, flags)?; },
|bound, remote_endpoint| bound.try_send(reader, remote_endpoint, flags),
)?;
self.pollee.notify(IoEvents::OUT | IoEvents::IN); self.pollee.notify(IoEvents::OUT | IoEvents::IN);
Ok(sent_bytes) Ok(sent_bytes)
} }
fn check_io_events(&self) -> IoEvents { fn try_recv(
let inner = self.inner.read(); &self,
match inner.as_ref() { writer: &mut dyn MultiWrite,
Inner::Unbound(unbound) => unbound.check_io_events(), flags: SendRecvFlags,
Inner::Bound(bound) => bound.check_io_events(), ) -> Result<(usize, SocketAddr)> {
} let recv_bytes = self
.inner
.read()
.try_recv(writer, flags)
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?;
self.pollee.invalidate();
Ok(recv_bytes)
} }
} }
impl Socket for NetlinkRouteSocket { impl Socket for NetlinkRouteSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> { fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let SocketAddr::Netlink(netlink_addr) = socket_addr else { let endpoint = socket_addr.try_into()?;
return_errno_with_message!(
Errno::EAFNOSUPPORT,
"the provided address is not netlink address"
);
};
let mut inner = self.inner.write(); // FIXME: We need to further check the Linux behavior
inner.borrow_result(|owned_inner| match owned_inner.bind(&netlink_addr) { // whether we should return error if the socket is bound.
Ok(bound_inner) => (bound_inner, Ok(())), // The socket may call `bind` syscall to join new multicast groups.
Err((err, err_inner)) => (err_inner, Err(err)), self.inner.write().bind(&endpoint, &self.pollee, ())
}) }
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.inner.write().connect(&endpoint, &self.pollee)
} }
fn addr(&self) -> Result<SocketAddr> { fn addr(&self) -> Result<SocketAddr> {
let netlink_addr = match self.inner.read().as_ref() { let endpoint = self
Inner::Unbound(_) => NetlinkSocketAddr::new_unspecified(), .inner
Inner::Bound(bound) => bound.addr(), .read()
}; .addr()
.unwrap_or(NetlinkSocketAddr::new_unspecified());
Ok(SocketAddr::Netlink(netlink_addr)) Ok(endpoint.into())
}
fn peer_addr(&self) -> Result<SocketAddr> {
let endpoint = self
.inner
.read()
.peer_addr()
.cloned()
.unwrap_or(NetlinkSocketAddr::new_unspecified());
Ok(endpoint.into())
} }
fn sendmsg( fn sendmsg(
@ -148,15 +147,11 @@ impl Socket for NetlinkRouteSocket {
writers: &mut dyn MultiWrite, writers: &mut dyn MultiWrite,
flags: SendRecvFlags, flags: SendRecvFlags,
) -> Result<(usize, MessageHeader)> { ) -> Result<(usize, MessageHeader)> {
let (received_len, addr) = let (received_len, addr) = self.block_on(IoEvents::IN, || self.try_recv(writers, flags))?;
self.block_on(IoEvents::IN, || self.try_receive(writers, flags))?;
// TODO: Receive control message // TODO: Receive control message
let message_header = { let message_header = MessageHeader::new(Some(addr), None);
let addr = SocketAddr::Netlink(addr);
MessageHeader::new(Some(addr), None)
};
Ok((received_len, message_header)) Ok((received_len, message_header))
} }
@ -180,28 +175,6 @@ impl SocketPrivate for NetlinkRouteSocket {
impl Pollable for NetlinkRouteSocket { impl Pollable for NetlinkRouteSocket {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee self.pollee
.poll_with(mask, poller, || self.check_io_events()) .poll_with(mask, poller, || self.inner.read().check_io_events())
}
}
impl Inner {
fn bind(self, addr: &NetlinkSocketAddr) -> core::result::Result<Self, (Error, Self)> {
let unbound = match self {
Inner::Unbound(unbound) => unbound,
Inner::Bound(bound) => {
// FIXME: We need to further check the Linux behavior
// whether we should return error if the socket is bound.
// The socket may call `bind` syscall to join new multicast groups.
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound"),
Self::Bound(bound),
));
}
};
match unbound.bind(addr) {
Ok(bound) => Ok(Self::Bound(bound)),
Err((err, unbound)) => Err((err, Self::Unbound(unbound))),
}
} }
} }

View File

@ -3,10 +3,12 @@
use super::bound::BoundNetlinkRoute; use super::bound::BoundNetlinkRoute;
use crate::{ use crate::{
events::IoEvents, events::IoEvents,
net::socket::netlink::{ net::socket::{
table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol, netlink::{table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol},
util::datagram_common,
}, },
prelude::*, prelude::*,
process::signal::Pollee,
}; };
pub(super) struct UnboundNetlinkRoute { pub(super) struct UnboundNetlinkRoute {
@ -17,19 +19,40 @@ impl UnboundNetlinkRoute {
pub(super) const fn new() -> Self { pub(super) const fn new() -> Self {
Self { _private: () } Self { _private: () }
} }
}
pub(super) fn bind( impl datagram_common::Unbound for UnboundNetlinkRoute {
self, type Endpoint = NetlinkSocketAddr;
addr: &NetlinkSocketAddr, type BindOptions = ();
) -> core::result::Result<BoundNetlinkRoute, (Error, Self)> {
let bound_handle = NETLINK_SOCKET_TABLE type Bound = BoundNetlinkRoute;
.bind(StandardNetlinkProtocol::ROUTE as _, addr)
.map_err(|err| (err, self))?; fn bind(
&mut self,
endpoint: &Self::Endpoint,
_pollee: &Pollee,
_options: Self::BindOptions,
) -> Result<BoundNetlinkRoute> {
let bound_handle =
NETLINK_SOCKET_TABLE.bind(StandardNetlinkProtocol::ROUTE as _, endpoint)?;
Ok(BoundNetlinkRoute::new(bound_handle)) Ok(BoundNetlinkRoute::new(bound_handle))
} }
pub(super) fn check_io_events(&self) -> IoEvents { fn bind_ephemeral(
&mut self,
_remote_endpoint: &Self::Endpoint,
_pollee: &Pollee,
) -> Result<Self::Bound> {
let bound_handle = NETLINK_SOCKET_TABLE.bind(
StandardNetlinkProtocol::ROUTE as _,
&NetlinkSocketAddr::new_unspecified(),
)?;
Ok(BoundNetlinkRoute::new(bound_handle))
}
fn check_io_events(&self) -> IoEvents {
IoEvents::OUT IoEvents::OUT
} }
} }

View File

@ -137,6 +137,10 @@ impl BoundHandle {
} }
} }
pub(super) const fn port(&self) -> PortNum {
self.port
}
pub(super) const fn addr(&self) -> NetlinkSocketAddr { pub(super) const fn addr(&self) -> NetlinkSocketAddr {
NetlinkSocketAddr::new(self.port, self.groups) NetlinkSocketAddr::new(self.port, self.groups)
} }