Use datagram_common in netlink

This commit is contained in:
Ruihan Li 2025-04-20 17:30:39 +08:00 committed by Tate, Hongliang Tian
parent c9f939bcc4
commit 26253829bb
4 changed files with 126 additions and 115 deletions

View File

@ -10,6 +10,7 @@ use crate::{
message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle,
NetlinkSocketAddr,
},
util::datagram_common,
SendRecvFlags,
},
prelude::*,
@ -18,6 +19,7 @@ use crate::{
pub(super) struct BoundNetlinkRoute {
handle: BoundHandle,
remote_addr: NetlinkSocketAddr,
receive_queue: Mutex<VecDeque<RtnlMessage>>,
}
@ -25,18 +27,31 @@ impl BoundNetlinkRoute {
pub(super) const fn new(handle: BoundHandle) -> Self {
Self {
handle,
remote_addr: NetlinkSocketAddr::new_unspecified(),
receive_queue: Mutex::new(VecDeque::new()),
}
}
}
pub(super) const fn addr(&self) -> NetlinkSocketAddr {
impl datagram_common::Bound for BoundNetlinkRoute {
type Endpoint = NetlinkSocketAddr;
fn local_endpoint(&self) -> Self::Endpoint {
self.handle.addr()
}
pub(super) fn try_send(
fn remote_endpoint(&self) -> Option<&Self::Endpoint> {
Some(&self.remote_addr)
}
fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) {
self.remote_addr = *endpoint;
}
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>,
remote: &Self::Endpoint,
flags: SendRecvFlags,
) -> Result<usize> {
// TODO: Deal with flags
@ -44,7 +59,6 @@ impl BoundNetlinkRoute {
warn!("unsupported flags: {:?}", flags);
}
if let Some(remote) = remote {
// TODO: Further check whether other socket address can be supported.
if *remote != NetlinkSocketAddr::new_unspecified() {
return_errno_with_message!(
@ -52,9 +66,6 @@ impl BoundNetlinkRoute {
"sending netlink route messages to user space is not supported"
);
}
} else {
// TODO: We should use the connected remote address, if any.
}
let mut nlmsg = {
let sum_lens = reader.sum_lens();
@ -75,7 +86,7 @@ impl BoundNetlinkRoute {
}
};
let local_port = self.addr().port();
let local_port = self.handle.port();
for segment in nlmsg.segments_mut() {
// The header's PID should be the sender's port ID.
// However, the sender can also leave it unspecified.
@ -93,7 +104,7 @@ impl BoundNetlinkRoute {
Ok(nlmsg.total_len())
}
pub(super) fn try_receive(
fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
@ -126,7 +137,7 @@ impl BoundNetlinkRoute {
Ok((len, remote))
}
pub(super) fn check_io_events(&self) -> IoEvents {
fn check_io_events(&self) -> IoEvents {
let mut events = IoEvents::OUT;
let receive_queue = self.receive_queue.lock();

View File

@ -5,15 +5,16 @@
use core::sync::atomic::{AtomicBool, Ordering};
use bound::BoundNetlinkRoute;
use takeable::Takeable;
use unbound::UnboundNetlinkRoute;
use super::NetlinkSocketAddr;
use crate::{
events::IoEvents,
net::socket::{
options::SocketOption, private::SocketPrivate, MessageHeader, SendRecvFlags, Socket,
SocketAddr,
options::SocketOption,
private::SocketPrivate,
util::datagram_common::{select_remote_and_bind, Bound, Inner},
MessageHeader, SendRecvFlags, Socket, SocketAddr,
},
prelude::*,
process::signal::{PollHandle, Pollable, Pollee},
@ -26,96 +27,94 @@ mod message;
mod unbound;
pub struct NetlinkRouteSocket {
inner: RwMutex<Inner<UnboundNetlinkRoute, BoundNetlinkRoute>>,
is_nonblocking: AtomicBool,
pollee: Pollee,
inner: RwMutex<Takeable<Inner>>,
}
enum Inner {
Unbound(UnboundNetlinkRoute),
Bound(BoundNetlinkRoute),
}
impl NetlinkRouteSocket {
pub fn new(is_nonblocking: bool) -> Self {
let unbound = UnboundNetlinkRoute::new();
Self {
inner: RwMutex::new(Inner::Unbound(unbound)),
is_nonblocking: AtomicBool::new(is_nonblocking),
pollee: Pollee::new(),
inner: RwMutex::new(Takeable::new(Inner::Unbound(UnboundNetlinkRoute::new()))),
}
}
fn try_receive(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, NetlinkSocketAddr)> {
let inner = self.inner.read();
let bound = match inner.as_ref() {
Inner::Unbound(_) => {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound")
}
Inner::Bound(bound_netlink_route) => bound_netlink_route,
};
let received = bound.try_receive(writer, flags)?;
self.pollee.invalidate();
Ok(received)
}
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
let inner = self.inner.read();
let bound = match inner.as_ref() {
Inner::Unbound(_) => todo!(),
Inner::Bound(bound) => bound,
};
let sent_bytes = bound.try_send(reader, remote, flags)?;
let sent_bytes = select_remote_and_bind(
&self.inner,
remote,
|| {
self.inner
.write()
.bind_ephemeral(&NetlinkSocketAddr::new_unspecified(), &self.pollee)
},
|bound, remote_endpoint| bound.try_send(reader, remote_endpoint, flags),
)?;
self.pollee.notify(IoEvents::OUT | IoEvents::IN);
Ok(sent_bytes)
}
fn check_io_events(&self) -> IoEvents {
let inner = self.inner.read();
match inner.as_ref() {
Inner::Unbound(unbound) => unbound.check_io_events(),
Inner::Bound(bound) => bound.check_io_events(),
}
fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, SocketAddr)> {
let recv_bytes = self
.inner
.read()
.try_recv(writer, flags)
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?;
self.pollee.invalidate();
Ok(recv_bytes)
}
}
impl Socket for NetlinkRouteSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let SocketAddr::Netlink(netlink_addr) = socket_addr else {
return_errno_with_message!(
Errno::EAFNOSUPPORT,
"the provided address is not netlink address"
);
};
let endpoint = socket_addr.try_into()?;
let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| match owned_inner.bind(&netlink_addr) {
Ok(bound_inner) => (bound_inner, Ok(())),
Err((err, err_inner)) => (err_inner, Err(err)),
})
// FIXME: We need to further check the Linux behavior
// whether we should return error if the socket is bound.
// The socket may call `bind` syscall to join new multicast groups.
self.inner.write().bind(&endpoint, &self.pollee, ())
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.inner.write().connect(&endpoint, &self.pollee)
}
fn addr(&self) -> Result<SocketAddr> {
let netlink_addr = match self.inner.read().as_ref() {
Inner::Unbound(_) => NetlinkSocketAddr::new_unspecified(),
Inner::Bound(bound) => bound.addr(),
};
let endpoint = self
.inner
.read()
.addr()
.unwrap_or(NetlinkSocketAddr::new_unspecified());
Ok(SocketAddr::Netlink(netlink_addr))
Ok(endpoint.into())
}
fn peer_addr(&self) -> Result<SocketAddr> {
let endpoint = self
.inner
.read()
.peer_addr()
.cloned()
.unwrap_or(NetlinkSocketAddr::new_unspecified());
Ok(endpoint.into())
}
fn sendmsg(
@ -148,15 +147,11 @@ impl Socket for NetlinkRouteSocket {
writers: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, MessageHeader)> {
let (received_len, addr) =
self.block_on(IoEvents::IN, || self.try_receive(writers, flags))?;
let (received_len, addr) = self.block_on(IoEvents::IN, || self.try_recv(writers, flags))?;
// TODO: Receive control message
let message_header = {
let addr = SocketAddr::Netlink(addr);
MessageHeader::new(Some(addr), None)
};
let message_header = MessageHeader::new(Some(addr), None);
Ok((received_len, message_header))
}
@ -180,28 +175,6 @@ impl SocketPrivate for NetlinkRouteSocket {
impl Pollable for NetlinkRouteSocket {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee
.poll_with(mask, poller, || self.check_io_events())
}
}
impl Inner {
fn bind(self, addr: &NetlinkSocketAddr) -> core::result::Result<Self, (Error, Self)> {
let unbound = match self {
Inner::Unbound(unbound) => unbound,
Inner::Bound(bound) => {
// FIXME: We need to further check the Linux behavior
// whether we should return error if the socket is bound.
// The socket may call `bind` syscall to join new multicast groups.
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound"),
Self::Bound(bound),
));
}
};
match unbound.bind(addr) {
Ok(bound) => Ok(Self::Bound(bound)),
Err((err, unbound)) => Err((err, Self::Unbound(unbound))),
}
.poll_with(mask, poller, || self.inner.read().check_io_events())
}
}

View File

@ -3,10 +3,12 @@
use super::bound::BoundNetlinkRoute;
use crate::{
events::IoEvents,
net::socket::netlink::{
table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol,
net::socket::{
netlink::{table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol},
util::datagram_common,
},
prelude::*,
process::signal::Pollee,
};
pub(super) struct UnboundNetlinkRoute {
@ -17,19 +19,40 @@ impl UnboundNetlinkRoute {
pub(super) const fn new() -> Self {
Self { _private: () }
}
}
pub(super) fn bind(
self,
addr: &NetlinkSocketAddr,
) -> core::result::Result<BoundNetlinkRoute, (Error, Self)> {
let bound_handle = NETLINK_SOCKET_TABLE
.bind(StandardNetlinkProtocol::ROUTE as _, addr)
.map_err(|err| (err, self))?;
impl datagram_common::Unbound for UnboundNetlinkRoute {
type Endpoint = NetlinkSocketAddr;
type BindOptions = ();
type Bound = BoundNetlinkRoute;
fn bind(
&mut self,
endpoint: &Self::Endpoint,
_pollee: &Pollee,
_options: Self::BindOptions,
) -> Result<BoundNetlinkRoute> {
let bound_handle =
NETLINK_SOCKET_TABLE.bind(StandardNetlinkProtocol::ROUTE as _, endpoint)?;
Ok(BoundNetlinkRoute::new(bound_handle))
}
pub(super) fn check_io_events(&self) -> IoEvents {
fn bind_ephemeral(
&mut self,
_remote_endpoint: &Self::Endpoint,
_pollee: &Pollee,
) -> Result<Self::Bound> {
let bound_handle = NETLINK_SOCKET_TABLE.bind(
StandardNetlinkProtocol::ROUTE as _,
&NetlinkSocketAddr::new_unspecified(),
)?;
Ok(BoundNetlinkRoute::new(bound_handle))
}
fn check_io_events(&self) -> IoEvents {
IoEvents::OUT
}
}

View File

@ -137,6 +137,10 @@ impl BoundHandle {
}
}
pub(super) const fn port(&self) -> PortNum {
self.port
}
pub(super) const fn addr(&self) -> NetlinkSocketAddr {
NetlinkSocketAddr::new(self.port, self.groups)
}