From 26253829bbd9ee37b64bd2616851b628dd133c9c Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 20 Apr 2025 17:30:39 +0800 Subject: [PATCH] Use `datagram_common` in netlink --- kernel/src/net/socket/netlink/route/bound.rs | 43 +++-- kernel/src/net/socket/netlink/route/mod.rs | 151 +++++++----------- .../src/net/socket/netlink/route/unbound.rs | 43 +++-- kernel/src/net/socket/netlink/table/mod.rs | 4 + 4 files changed, 126 insertions(+), 115 deletions(-) diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs index 3cd59d63..9331ab4d 100644 --- a/kernel/src/net/socket/netlink/route/bound.rs +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -10,6 +10,7 @@ use crate::{ message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle, NetlinkSocketAddr, }, + util::datagram_common, SendRecvFlags, }, prelude::*, @@ -18,6 +19,7 @@ use crate::{ pub(super) struct BoundNetlinkRoute { handle: BoundHandle, + remote_addr: NetlinkSocketAddr, receive_queue: Mutex>, } @@ -25,18 +27,31 @@ impl BoundNetlinkRoute { pub(super) const fn new(handle: BoundHandle) -> Self { Self { handle, + remote_addr: NetlinkSocketAddr::new_unspecified(), receive_queue: Mutex::new(VecDeque::new()), } } +} - pub(super) const fn addr(&self) -> NetlinkSocketAddr { +impl datagram_common::Bound for BoundNetlinkRoute { + type Endpoint = NetlinkSocketAddr; + + fn local_endpoint(&self) -> Self::Endpoint { self.handle.addr() } - pub(super) fn try_send( + fn remote_endpoint(&self) -> Option<&Self::Endpoint> { + Some(&self.remote_addr) + } + + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) { + self.remote_addr = *endpoint; + } + + fn try_send( &self, reader: &mut dyn MultiRead, - remote: Option<&NetlinkSocketAddr>, + remote: &Self::Endpoint, flags: SendRecvFlags, ) -> Result { // TODO: Deal with flags @@ -44,16 +59,12 @@ impl BoundNetlinkRoute { warn!("unsupported flags: {:?}", flags); } - if let Some(remote) = remote { - // TODO: Further check whether other socket address can be supported. - if *remote != NetlinkSocketAddr::new_unspecified() { - return_errno_with_message!( - Errno::ECONNREFUSED, - "sending netlink route messages to user space is not supported" - ); - } - } else { - // TODO: We should use the connected remote address, if any. + // TODO: Further check whether other socket address can be supported. + if *remote != NetlinkSocketAddr::new_unspecified() { + return_errno_with_message!( + Errno::ECONNREFUSED, + "sending netlink route messages to user space is not supported" + ); } let mut nlmsg = { @@ -75,7 +86,7 @@ impl BoundNetlinkRoute { } }; - let local_port = self.addr().port(); + let local_port = self.handle.port(); for segment in nlmsg.segments_mut() { // The header's PID should be the sender's port ID. // However, the sender can also leave it unspecified. @@ -93,7 +104,7 @@ impl BoundNetlinkRoute { Ok(nlmsg.total_len()) } - pub(super) fn try_receive( + fn try_recv( &self, writer: &mut dyn MultiWrite, flags: SendRecvFlags, @@ -126,7 +137,7 @@ impl BoundNetlinkRoute { Ok((len, remote)) } - pub(super) fn check_io_events(&self) -> IoEvents { + fn check_io_events(&self) -> IoEvents { let mut events = IoEvents::OUT; let receive_queue = self.receive_queue.lock(); diff --git a/kernel/src/net/socket/netlink/route/mod.rs b/kernel/src/net/socket/netlink/route/mod.rs index b6a2e958..1313449b 100644 --- a/kernel/src/net/socket/netlink/route/mod.rs +++ b/kernel/src/net/socket/netlink/route/mod.rs @@ -5,15 +5,16 @@ use core::sync::atomic::{AtomicBool, Ordering}; use bound::BoundNetlinkRoute; -use takeable::Takeable; use unbound::UnboundNetlinkRoute; use super::NetlinkSocketAddr; use crate::{ events::IoEvents, net::socket::{ - options::SocketOption, private::SocketPrivate, MessageHeader, SendRecvFlags, Socket, - SocketAddr, + options::SocketOption, + private::SocketPrivate, + util::datagram_common::{select_remote_and_bind, Bound, Inner}, + MessageHeader, SendRecvFlags, Socket, SocketAddr, }, prelude::*, process::signal::{PollHandle, Pollable, Pollee}, @@ -26,96 +27,94 @@ mod message; mod unbound; pub struct NetlinkRouteSocket { + inner: RwMutex>, + is_nonblocking: AtomicBool, pollee: Pollee, - inner: RwMutex>, -} - -enum Inner { - Unbound(UnboundNetlinkRoute), - Bound(BoundNetlinkRoute), } impl NetlinkRouteSocket { pub fn new(is_nonblocking: bool) -> Self { + let unbound = UnboundNetlinkRoute::new(); Self { + inner: RwMutex::new(Inner::Unbound(unbound)), is_nonblocking: AtomicBool::new(is_nonblocking), pollee: Pollee::new(), - inner: RwMutex::new(Takeable::new(Inner::Unbound(UnboundNetlinkRoute::new()))), } } - fn try_receive( - &self, - writer: &mut dyn MultiWrite, - flags: SendRecvFlags, - ) -> Result<(usize, NetlinkSocketAddr)> { - let inner = self.inner.read(); - - let bound = match inner.as_ref() { - Inner::Unbound(_) => { - return_errno_with_message!(Errno::EAGAIN, "the socket is not bound") - } - Inner::Bound(bound_netlink_route) => bound_netlink_route, - }; - - let received = bound.try_receive(writer, flags)?; - self.pollee.invalidate(); - - Ok(received) - } - fn try_send( &self, reader: &mut dyn MultiRead, remote: Option<&NetlinkSocketAddr>, flags: SendRecvFlags, ) -> Result { - let inner = self.inner.read(); - - let bound = match inner.as_ref() { - Inner::Unbound(_) => todo!(), - Inner::Bound(bound) => bound, - }; - - let sent_bytes = bound.try_send(reader, remote, flags)?; + let sent_bytes = select_remote_and_bind( + &self.inner, + remote, + || { + self.inner + .write() + .bind_ephemeral(&NetlinkSocketAddr::new_unspecified(), &self.pollee) + }, + |bound, remote_endpoint| bound.try_send(reader, remote_endpoint, flags), + )?; self.pollee.notify(IoEvents::OUT | IoEvents::IN); Ok(sent_bytes) } - fn check_io_events(&self) -> IoEvents { - let inner = self.inner.read(); - match inner.as_ref() { - Inner::Unbound(unbound) => unbound.check_io_events(), - Inner::Bound(bound) => bound.check_io_events(), - } + fn try_recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { + let recv_bytes = self + .inner + .read() + .try_recv(writer, flags) + .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; + self.pollee.invalidate(); + + Ok(recv_bytes) } } impl Socket for NetlinkRouteSocket { fn bind(&self, socket_addr: SocketAddr) -> Result<()> { - let SocketAddr::Netlink(netlink_addr) = socket_addr else { - return_errno_with_message!( - Errno::EAFNOSUPPORT, - "the provided address is not netlink address" - ); - }; + let endpoint = socket_addr.try_into()?; - let mut inner = self.inner.write(); - inner.borrow_result(|owned_inner| match owned_inner.bind(&netlink_addr) { - Ok(bound_inner) => (bound_inner, Ok(())), - Err((err, err_inner)) => (err_inner, Err(err)), - }) + // FIXME: We need to further check the Linux behavior + // whether we should return error if the socket is bound. + // The socket may call `bind` syscall to join new multicast groups. + self.inner.write().bind(&endpoint, &self.pollee, ()) + } + + fn connect(&self, socket_addr: SocketAddr) -> Result<()> { + let endpoint = socket_addr.try_into()?; + + self.inner.write().connect(&endpoint, &self.pollee) } fn addr(&self) -> Result { - let netlink_addr = match self.inner.read().as_ref() { - Inner::Unbound(_) => NetlinkSocketAddr::new_unspecified(), - Inner::Bound(bound) => bound.addr(), - }; + let endpoint = self + .inner + .read() + .addr() + .unwrap_or(NetlinkSocketAddr::new_unspecified()); - Ok(SocketAddr::Netlink(netlink_addr)) + Ok(endpoint.into()) + } + + fn peer_addr(&self) -> Result { + let endpoint = self + .inner + .read() + .peer_addr() + .cloned() + .unwrap_or(NetlinkSocketAddr::new_unspecified()); + + Ok(endpoint.into()) } fn sendmsg( @@ -148,15 +147,11 @@ impl Socket for NetlinkRouteSocket { writers: &mut dyn MultiWrite, flags: SendRecvFlags, ) -> Result<(usize, MessageHeader)> { - let (received_len, addr) = - self.block_on(IoEvents::IN, || self.try_receive(writers, flags))?; + let (received_len, addr) = self.block_on(IoEvents::IN, || self.try_recv(writers, flags))?; // TODO: Receive control message - let message_header = { - let addr = SocketAddr::Netlink(addr); - MessageHeader::new(Some(addr), None) - }; + let message_header = MessageHeader::new(Some(addr), None); Ok((received_len, message_header)) } @@ -180,28 +175,6 @@ impl SocketPrivate for NetlinkRouteSocket { impl Pollable for NetlinkRouteSocket { fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { self.pollee - .poll_with(mask, poller, || self.check_io_events()) - } -} - -impl Inner { - fn bind(self, addr: &NetlinkSocketAddr) -> core::result::Result { - let unbound = match self { - Inner::Unbound(unbound) => unbound, - Inner::Bound(bound) => { - // FIXME: We need to further check the Linux behavior - // whether we should return error if the socket is bound. - // The socket may call `bind` syscall to join new multicast groups. - return Err(( - Error::with_message(Errno::EINVAL, "the socket is already bound"), - Self::Bound(bound), - )); - } - }; - - match unbound.bind(addr) { - Ok(bound) => Ok(Self::Bound(bound)), - Err((err, unbound)) => Err((err, Self::Unbound(unbound))), - } + .poll_with(mask, poller, || self.inner.read().check_io_events()) } } diff --git a/kernel/src/net/socket/netlink/route/unbound.rs b/kernel/src/net/socket/netlink/route/unbound.rs index 27a80fab..9a7dad37 100644 --- a/kernel/src/net/socket/netlink/route/unbound.rs +++ b/kernel/src/net/socket/netlink/route/unbound.rs @@ -3,10 +3,12 @@ use super::bound::BoundNetlinkRoute; use crate::{ events::IoEvents, - net::socket::netlink::{ - table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol, + net::socket::{ + netlink::{table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol}, + util::datagram_common, }, prelude::*, + process::signal::Pollee, }; pub(super) struct UnboundNetlinkRoute { @@ -17,19 +19,40 @@ impl UnboundNetlinkRoute { pub(super) const fn new() -> Self { Self { _private: () } } +} - pub(super) fn bind( - self, - addr: &NetlinkSocketAddr, - ) -> core::result::Result { - let bound_handle = NETLINK_SOCKET_TABLE - .bind(StandardNetlinkProtocol::ROUTE as _, addr) - .map_err(|err| (err, self))?; +impl datagram_common::Unbound for UnboundNetlinkRoute { + type Endpoint = NetlinkSocketAddr; + type BindOptions = (); + + type Bound = BoundNetlinkRoute; + + fn bind( + &mut self, + endpoint: &Self::Endpoint, + _pollee: &Pollee, + _options: Self::BindOptions, + ) -> Result { + let bound_handle = + NETLINK_SOCKET_TABLE.bind(StandardNetlinkProtocol::ROUTE as _, endpoint)?; Ok(BoundNetlinkRoute::new(bound_handle)) } - pub(super) fn check_io_events(&self) -> IoEvents { + fn bind_ephemeral( + &mut self, + _remote_endpoint: &Self::Endpoint, + _pollee: &Pollee, + ) -> Result { + let bound_handle = NETLINK_SOCKET_TABLE.bind( + StandardNetlinkProtocol::ROUTE as _, + &NetlinkSocketAddr::new_unspecified(), + )?; + + Ok(BoundNetlinkRoute::new(bound_handle)) + } + + fn check_io_events(&self) -> IoEvents { IoEvents::OUT } } diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs index cd45c7a8..81f15ed8 100644 --- a/kernel/src/net/socket/netlink/table/mod.rs +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -137,6 +137,10 @@ impl BoundHandle { } } + pub(super) const fn port(&self) -> PortNum { + self.port + } + pub(super) const fn addr(&self) -> NetlinkSocketAddr { NetlinkSocketAddr::new(self.port, self.groups) }