diff --git a/kernel/src/net/socket/netlink/addr/mod.rs b/kernel/src/net/socket/netlink/addr/mod.rs index b26cee6c..276ce223 100644 --- a/kernel/src/net/socket/netlink/addr/mod.rs +++ b/kernel/src/net/socket/netlink/addr/mod.rs @@ -43,6 +43,11 @@ impl NetlinkSocketAddr { pub const fn groups(&self) -> GroupIdSet { self.groups } + + /// Adds some new groups to the address. + pub fn add_groups(&mut self, groups: GroupIdSet) { + self.groups.add_groups(groups); + } } impl TryFrom for NetlinkSocketAddr { diff --git a/kernel/src/net/socket/netlink/addr/multicast.rs b/kernel/src/net/socket/netlink/addr/multicast.rs index 172a5475..22044305 100644 --- a/kernel/src/net/socket/netlink/addr/multicast.rs +++ b/kernel/src/net/socket/netlink/addr/multicast.rs @@ -24,21 +24,14 @@ impl GroupIdSet { GroupIdIter::new(self) } - /// Adds a new group. - /// - /// If the group already exists, this method will return an error. - pub fn add_group(&mut self, group_id: GroupId) -> Result<()> { - if group_id >= 32 { - return_errno_with_message!(Errno::EINVAL, "the group ID is invalid"); - } + /// Adds some new groups. + pub fn add_groups(&mut self, groups: GroupIdSet) { + self.0 |= groups.0; + } - let mask = 1u32 << group_id; - if self.0 & mask != 0 { - return_errno_with_message!(Errno::EINVAL, "the group ID already exists"); - } - self.0 |= mask; - - Ok(()) + /// Drops some groups. + pub fn drop_groups(&mut self, groups: GroupIdSet) { + self.0 &= !groups.0; } /// Sets new groups. diff --git a/kernel/src/net/socket/netlink/common/bound.rs b/kernel/src/net/socket/netlink/common/bound.rs new file mode 100644 index 00000000..24aaf211 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/bound.rs @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::{ + events::IoEvents, + net::socket::netlink::{ + receiver::MessageQueue, table::BoundHandle, GroupIdSet, NetlinkSocketAddr, + }, + prelude::*, +}; + +pub struct BoundNetlink { + pub(in crate::net::socket::netlink) handle: BoundHandle, + pub(in crate::net::socket::netlink) remote_addr: NetlinkSocketAddr, + pub(in crate::net::socket::netlink) receive_queue: MessageQueue, +} + +impl BoundNetlink { + pub(super) fn new(handle: BoundHandle, message_queue: MessageQueue) -> Self { + Self { + handle, + remote_addr: NetlinkSocketAddr::new_unspecified(), + receive_queue: message_queue, + } + } + + pub(in crate::net::socket::netlink) fn bind_common( + &mut self, + endpoint: &NetlinkSocketAddr, + ) -> Result<()> { + if endpoint.port() != self.handle.port() { + return_errno_with_message!( + Errno::EINVAL, + "the socket cannot be bound to a different port" + ); + } + + let groups = endpoint.groups(); + self.handle.bind_groups(groups); + + Ok(()) + } + + pub(in crate::net::socket::netlink) fn check_io_events_common(&self) -> IoEvents { + let mut events = IoEvents::OUT; + + let receive_queue = self.receive_queue.0.lock(); + if !receive_queue.is_empty() { + events |= IoEvents::IN; + } + + events + } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + self.handle.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + self.handle.drop_groups(groups); + } +} diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs new file mode 100644 index 00000000..4ac2c817 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::sync::atomic::{AtomicBool, Ordering}; + +pub(super) use bound::BoundNetlink; +use unbound::UnboundNetlink; + +use super::{GroupIdSet, NetlinkSocketAddr}; +use crate::{ + events::IoEvents, + match_sock_option_ref, + net::socket::{ + netlink::{table::SupportedNetlinkProtocol, AddMembership, DropMembership}, + options::SocketOption, + private::SocketPrivate, + util::datagram_common::{select_remote_and_bind, Bound, Inner}, + MessageHeader, SendRecvFlags, Socket, SocketAddr, + }, + prelude::*, + process::signal::{PollHandle, Pollable, Pollee}, + util::{MultiRead, MultiWrite}, +}; + +mod bound; +mod unbound; + +pub struct NetlinkSocket { + inner: RwMutex, BoundNetlink>>, + + is_nonblocking: AtomicBool, + pollee: Pollee, +} + +impl NetlinkSocket

+where + BoundNetlink: Bound, +{ + pub fn new(is_nonblocking: bool) -> Arc { + let unbound = UnboundNetlink::new(); + Arc::new(Self { + inner: RwMutex::new(Inner::Unbound(unbound)), + is_nonblocking: AtomicBool::new(is_nonblocking), + pollee: Pollee::new(), + }) + } + + fn try_send( + &self, + reader: &mut dyn MultiRead, + remote: Option<&NetlinkSocketAddr>, + flags: SendRecvFlags, + ) -> Result { + let sent_bytes = select_remote_and_bind( + &self.inner, + remote, + || { + self.inner + .write() + .bind_ephemeral(&NetlinkSocketAddr::new_unspecified(), &self.pollee) + }, + |bound, remote_endpoint| bound.try_send(reader, remote_endpoint, flags), + )?; + self.pollee.invalidate(); + + Ok(sent_bytes) + } + + // FIXME: This method is marked as `pub(super)` because it's invoked during kernel mode testing. + pub(super) fn try_recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, SocketAddr)> { + let recv_bytes = self + .inner + .read() + .try_recv(writer, flags) + .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; + self.pollee.invalidate(); + + Ok(recv_bytes) + } +} + +impl Socket for NetlinkSocket

+where + BoundNetlink: Bound, +{ + fn bind(&self, socket_addr: SocketAddr) -> Result<()> { + let endpoint = socket_addr.try_into()?; + + self.inner.write().bind(&endpoint, &self.pollee, ()) + } + + fn connect(&self, socket_addr: SocketAddr) -> Result<()> { + let endpoint = socket_addr.try_into()?; + + self.inner.write().connect(&endpoint, &self.pollee) + } + + fn addr(&self) -> Result { + let endpoint = match &*self.inner.read() { + Inner::Unbound(unbound) => unbound.addr(), + Inner::Bound(bound) => bound.local_endpoint(), + }; + + Ok(endpoint.into()) + } + + fn peer_addr(&self) -> Result { + let endpoint = self + .inner + .read() + .peer_addr() + .cloned() + .unwrap_or(NetlinkSocketAddr::new_unspecified()); + + Ok(endpoint.into()) + } + + fn sendmsg( + &self, + reader: &mut dyn MultiRead, + message_header: MessageHeader, + flags: SendRecvFlags, + ) -> Result { + let MessageHeader { + addr, + control_message, + } = message_header; + + let remote = match addr { + None => None, + Some(addr) => Some(addr.try_into()?), + }; + + if control_message.is_some() { + // TODO: Support sending control message + warn!("sending control message is not supported"); + } + + // TODO: Make sure our blocking behavior matches that of Linux + self.try_send(reader, remote.as_ref(), flags) + } + + fn recvmsg( + &self, + writers: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, MessageHeader)> { + let (received_len, addr) = self.block_on(IoEvents::IN, || self.try_recv(writers, flags))?; + + // TODO: Receive control message + + let message_header = MessageHeader::new(Some(addr), None); + + Ok((received_len, message_header)) + } + + fn set_option(&self, option: &dyn SocketOption) -> Result<()> { + match do_set_netlink_option(&self.inner, option) { + Ok(()) => Ok(()), + Err(e) => { + warn!( + "We currently ignore set option errors to pass libnl test: {:?}", + e + ); + Ok(()) + } + } + } +} + +impl SocketPrivate for NetlinkSocket

+where + BoundNetlink: Bound, +{ + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) + } + + fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); + } +} + +impl Pollable for NetlinkSocket

+where + BoundNetlink: Bound, +{ + fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { + self.pollee + .poll_with(mask, poller, || self.inner.read().check_io_events()) + } +} + +impl Inner, BoundNetlink> { + fn add_groups(&mut self, groups: GroupIdSet) { + match self { + Inner::Unbound(unbound_socket) => unbound_socket.add_groups(groups), + Inner::Bound(bound_socket) => bound_socket.add_groups(groups), + } + } + + fn drop_groups(&mut self, groups: GroupIdSet) { + match self { + Inner::Unbound(unbound_socket) => unbound_socket.drop_groups(groups), + Inner::Bound(bound_socket) => bound_socket.drop_groups(groups), + } + } +} + +fn do_set_netlink_option( + inner: &RwMutex, BoundNetlink>>, + option: &dyn SocketOption, +) -> Result<()> { + match_sock_option_ref!(option, { + add_membership: AddMembership => { + let groups = add_membership.get().unwrap(); + inner.write().add_groups(GroupIdSet::new(*groups)); + }, + drop_membership: DropMembership => { + let groups = drop_membership.get().unwrap(); + inner.write().drop_groups(GroupIdSet::new(*groups)); + }, + _ => return_errno_with_message!(Errno::ENOPROTOOPT, "the socket option to be set is unknown") + }); + + Ok(()) +} diff --git a/kernel/src/net/socket/netlink/common/unbound.rs b/kernel/src/net/socket/netlink/common/unbound.rs new file mode 100644 index 00000000..8656bf98 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/unbound.rs @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::marker::PhantomData; + +use crate::{ + events::IoEvents, + net::socket::{ + netlink::{ + common::bound::BoundNetlink, + receiver::{MessageQueue, MessageReceiver}, + table::SupportedNetlinkProtocol, + GroupIdSet, NetlinkSocketAddr, + }, + util::datagram_common, + }, + prelude::*, + process::signal::Pollee, +}; + +pub(super) struct UnboundNetlink { + groups: GroupIdSet, + phantom: PhantomData>, +} + +impl UnboundNetlink

{ + pub(super) const fn new() -> Self { + Self { + groups: GroupIdSet::new_empty(), + phantom: PhantomData, + } + } + + pub(super) fn addr(&self) -> NetlinkSocketAddr { + NetlinkSocketAddr::new(0, self.groups) + } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + self.groups.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + self.groups.drop_groups(groups); + } +} + +impl datagram_common::Unbound for UnboundNetlink

{ + type Endpoint = NetlinkSocketAddr; + type BindOptions = (); + + type Bound = BoundNetlink; + + fn bind( + &mut self, + endpoint: &Self::Endpoint, + pollee: &Pollee, + _options: Self::BindOptions, + ) -> Result { + let message_queue = MessageQueue::::new(); + + let bound_handle = { + let endpoint = { + let mut endpoint = endpoint.clone(); + endpoint.add_groups(self.groups); + endpoint + }; + let receiver = MessageReceiver::new(message_queue.clone(), pollee.clone()); +

::bind(&endpoint, receiver)? + }; + + Ok(BoundNetlink::new(bound_handle, message_queue)) + } + + fn bind_ephemeral( + &mut self, + _remote_endpoint: &Self::Endpoint, + pollee: &Pollee, + ) -> Result { + let message_queue = MessageQueue::::new(); + + let bound_handle = { + let endpoint = { + let mut endpoint = NetlinkSocketAddr::new_unspecified(); + endpoint.add_groups(self.groups); + endpoint + }; + let receiver = MessageReceiver::new(message_queue.clone(), pollee.clone()); +

::bind(&endpoint, receiver)? + }; + + Ok(BoundNetlink::new(bound_handle, message_queue)) + } + + fn check_io_events(&self) -> IoEvents { + IoEvents::OUT + } +} diff --git a/kernel/src/net/socket/netlink/kobject_uevent/bound.rs b/kernel/src/net/socket/netlink/kobject_uevent/bound.rs new file mode 100644 index 00000000..0cd2df9b --- /dev/null +++ b/kernel/src/net/socket/netlink/kobject_uevent/bound.rs @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::ops::Sub; + +use super::message::UeventMessage; +use crate::{ + events::IoEvents, + net::socket::{ + netlink::{common::BoundNetlink, NetlinkSocketAddr}, + util::datagram_common, + SendRecvFlags, + }, + prelude::*, + util::{MultiRead, MultiWrite}, +}; + +pub(super) type BoundNetlinkUevent = BoundNetlink; + +impl datagram_common::Bound for BoundNetlinkUevent { + type Endpoint = NetlinkSocketAddr; + + fn local_endpoint(&self) -> Self::Endpoint { + self.handle.addr() + } + + fn bind(&mut self, endpoint: &Self::Endpoint) -> Result<()> { + self.bind_common(endpoint) + } + + fn remote_endpoint(&self) -> Option<&Self::Endpoint> { + Some(&self.remote_addr) + } + + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) { + self.remote_addr = *endpoint; + } + + fn try_send( + &self, + reader: &mut dyn MultiRead, + remote: &Self::Endpoint, + flags: SendRecvFlags, + ) -> Result { + // TODO: Deal with flags + if !flags.is_all_supported() { + warn!("unsupported flags: {:?}", flags); + } + + if *remote != NetlinkSocketAddr::new_unspecified() { + return_errno_with_message!( + Errno::ECONNREFUSED, + "sending uevent messages to user space is not supported" + ); + } + + // FIXME: How to deal with sending message to kernel socket? + // Here we simply ignore the message and return the message length. + Ok(reader.sum_lens()) + } + + fn try_recv( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, Self::Endpoint)> { + // TODO: Deal with other flags. Only MSG_PEEK is handled here. + if !flags.sub(SendRecvFlags::MSG_PEEK).is_all_supported() { + warn!("unsupported flags: {:?}", flags); + } + + let mut receive_queue = self.receive_queue.0.lock(); + + let Some(response) = receive_queue.front() else { + return_errno_with_message!(Errno::EAGAIN, "nothing to receive"); + }; + + let len = { + let max_len = writer.sum_lens(); + response.total_len().min(max_len) + }; + + response.write_to(writer)?; + + let remote = response.src_addr().clone(); + + if !flags.contains(SendRecvFlags::MSG_PEEK) { + receive_queue.pop_front().unwrap(); + } + + Ok((len, remote)) + } + + fn check_io_events(&self) -> IoEvents { + self.check_io_events_common() + } +} diff --git a/kernel/src/net/socket/netlink/kobject_uevent/message/mod.rs b/kernel/src/net/socket/netlink/kobject_uevent/message/mod.rs new file mode 100644 index 00000000..0e5cf207 --- /dev/null +++ b/kernel/src/net/socket/netlink/kobject_uevent/message/mod.rs @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MPL-2.0 + +#![cfg_attr(not(ktest), expect(dead_code))] + +use uevent::Uevent; + +use crate::{ + net::socket::netlink::{table::MulticastMessage, NetlinkSocketAddr}, + prelude::*, + util::MultiWrite, +}; + +mod syn_uevent; +#[cfg(ktest)] +mod test; +mod uevent; + +/// A uevent message. +/// +/// Note that uevent messages are not the same as common netlink messages. +/// It does not have a netlink header. +#[derive(Debug, Clone)] +pub struct UeventMessage { + uevent: String, + src_addr: NetlinkSocketAddr, +} + +impl UeventMessage { + /// Creates a new uevent message. + fn new(uevent: Uevent, src_addr: NetlinkSocketAddr) -> Self { + Self { + uevent: uevent.to_string(), + src_addr, + } + } + + /// Returns the source address of the uevent message. + pub(super) fn src_addr(&self) -> &NetlinkSocketAddr { + &self.src_addr + } + + /// Returns the total length of the uevent. + pub(super) fn total_len(&self) -> usize { + self.uevent.len() + } + + /// Writes the uevent to the given `writer`. + pub(super) fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> { + // FIXME: If the message can be truncated, we should avoid returning an error. + if self.uevent.len() > writer.sum_lens() { + return_errno_with_message!(Errno::EFAULT, "the writer length is too small"); + } + writer.write(&mut VmReader::from(self.uevent.as_bytes()))?; + Ok(()) + } +} + +impl MulticastMessage for UeventMessage {} diff --git a/kernel/src/net/socket/netlink/kobject_uevent/message/syn_uevent.rs b/kernel/src/net/socket/netlink/kobject_uevent/message/syn_uevent.rs new file mode 100644 index 00000000..516f8474 --- /dev/null +++ b/kernel/src/net/socket/netlink/kobject_uevent/message/syn_uevent.rs @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! The synthetic uevent. +//! +//! The event is triggered when someone writes some content to `/sys/.../uevent` file. +//! It differs from the event triggers by devices. +//! +//! Reference: . + +use alloc::format; +use core::str::FromStr; + +use super::uevent::SysObjAction; +use crate::prelude::*; + +/// The synthetic uevent. +pub(super) struct SyntheticUevent { + pub(super) action: SysObjAction, + pub(super) uuid: Option, + pub(super) envs: Vec<(String, String)>, +} + +impl FromStr for SyntheticUevent { + type Err = Error; + + fn from_str(s: &str) -> Result { + let mut split = s.split(" "); + + let action = { + let Some(action_str) = split.next() else { + return_errno_with_message!(Errno::EINVAL, "the string is empty"); + }; + SysObjAction::from_str(action_str)? + }; + + let uuid = if let Some(uuid_str) = split.next() { + Some(Uuid::from_str(uuid_str)?) + } else { + None + }; + + let mut envs = Vec::new(); + for env_str in split.into_iter() { + let (key, value) = { + // Each string should be in the `KEY=VALUE` format. + match env_str.split_once('=') { + Some(key_value) => key_value, + None => return_errno_with_message!(Errno::EINVAL, "invalid key value pairs"), + } + }; + + // Both `KEY` and `VALUE` can contain alphanumeric characters only. + for byte in key.as_bytes().iter().chain(value.as_bytes()) { + if !byte.is_ascii_alphanumeric() { + return_errno_with_message!( + Errno::EINVAL, + "invalid character in key value pairs" + ); + } + } + + // The `KEY` name gains `SYNTH_ARG_` prefix to avoid possible collisions + // with existing variables. + let key = format!("SYNTH_ARG_{}", key); + let value = value.to_string(); + envs.push((key, value)); + } + + Ok(Self { action, uuid, envs }) + } +} + +pub(super) struct Uuid(pub(super) String); + +impl FromStr for Uuid { + type Err = Error; + + fn from_str(s: &str) -> Result { + /// The allowed UUID pattern, where each `x` is a hex digit. + /// + /// Reference: . + const UUID_PATTERN: &str = "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"; + + let bytes = s.as_bytes(); + + if bytes.len() != UUID_PATTERN.len() { + return_errno_with_message!(Errno::EINVAL, "the UUID length is invalid"); + } + + for (byte, pattern) in bytes.into_iter().zip(UUID_PATTERN.as_bytes()) { + if *pattern == b'x' && byte.is_ascii_hexdigit() { + continue; + } else if *pattern == b'-' && *byte == b'-' { + continue; + } else { + return_errno_with_message!(Errno::EINVAL, "the UUID content is invalid"); + } + } + + Ok(Self(s.to_string())) + } +} diff --git a/kernel/src/net/socket/netlink/kobject_uevent/message/test.rs b/kernel/src/net/socket/netlink/kobject_uevent/message/test.rs new file mode 100644 index 00000000..8a848a00 --- /dev/null +++ b/kernel/src/net/socket/netlink/kobject_uevent/message/test.rs @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::vec; +use core::str::FromStr; + +use ostd::{mm::VmWriter, prelude::*}; + +use crate::{ + net::socket::{ + netlink::{ + kobject_uevent::{ + message::{ + syn_uevent::{SyntheticUevent, Uuid}, + uevent::Uevent, + }, + UeventMessage, + }, + table::{NetlinkUeventProtocol, SupportedNetlinkProtocol}, + GroupIdSet, NetlinkSocketAddr, NetlinkUeventSocket, + }, + SendRecvFlags, Socket, SocketAddr, + }, + prelude::*, +}; + +#[ktest] +fn uuid() { + let uuid = Uuid::from_str("12345678-1234-1234-1234-123456789012"); + assert!(uuid.is_ok()); + + let uuid = Uuid::from_str("12345678-1234-1234-1234-12345678901"); + assert!(uuid.is_err()); + + let uuid = Uuid::from_str("12345678-1234-1234-1234-1234567890g"); + assert!(uuid.is_err()); +} + +#[ktest] +fn synthetic_uevent() { + let uevent = SyntheticUevent::from_str("add"); + assert!(uevent.is_ok()); + + let uevent = SyntheticUevent::from_str("add 12345678-1234-1234-1234-123456789012"); + assert!(uevent.is_ok()); + + let uevent = SyntheticUevent::from_str("add 12345678-1234-1234-1234-123456789012 NAME=lo"); + assert!(uevent.is_ok()); +} + +#[ktest] +fn multicast_synthetic_uevent() { + crate::net::socket::netlink::init(); + + // Creates a new netlink uevent socket and joins the group for kobject uevents. + let socket = NetlinkUeventSocket::new(true); + let socket_addr = SocketAddr::Netlink(NetlinkSocketAddr::new(100, GroupIdSet::new(0x1))); + socket.bind(socket_addr).unwrap(); + + // Tries to receive and returns EAGAIN if no message is available. + let mut buffer = vec![0u8; 1024]; + let mut writer = VmWriter::from(buffer.as_mut_slice()).to_fallible(); + let res = socket.try_recv(&mut writer, SendRecvFlags::empty()); + assert!(res.is_err_and(|err| err.error() == Errno::EAGAIN)); + + // Broadcasts a uevent message. + let uevent = { + let lo_infos = vec![ + ("INTERFACE".to_string(), "lo".to_string()), + ("IFINDEX".to_string(), "1".to_string()), + ]; + let synth_uevent = SyntheticUevent::from_str("add").unwrap(); + Uevent::new_from_syn( + synth_uevent, + "/devices/virtual/net/lo".to_string(), + "net".to_string(), + lo_infos, + ) + }; + let uevent_message = + UeventMessage::new(uevent, NetlinkSocketAddr::new(0, GroupIdSet::new(0x1))); + NetlinkUeventProtocol::multicast(GroupIdSet::new(0x1), uevent_message).unwrap(); + + let (len, _) = socket + .try_recv(&mut writer, SendRecvFlags::empty()) + .unwrap(); + let s = core::str::from_utf8(&buffer[..len]).unwrap(); + + assert_eq!(s, "add@/devices/virtual/net/lo\0ACTION=add\0DEVPATH=/devices/virtual/net/lo\0SUBSYSTEM=net\0SYNTH_UUID=0\0INTERFACE=lo\0IFINDEX=1\0SEQNUM=1\0"); +} diff --git a/kernel/src/net/socket/netlink/kobject_uevent/message/uevent.rs b/kernel/src/net/socket/netlink/kobject_uevent/message/uevent.rs new file mode 100644 index 00000000..1bd8e079 --- /dev/null +++ b/kernel/src/net/socket/netlink/kobject_uevent/message/uevent.rs @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::format; +use core::{ + str::FromStr, + sync::atomic::{AtomicU64, Ordering}, +}; + +use super::syn_uevent::{SyntheticUevent, Uuid}; +use crate::prelude::*; + +/// `SysObj` action type. +/// +/// Reference: . +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromInt)] +#[repr(u8)] +pub(super) enum SysObjAction { + /// Indicates the addition of a new `SysObj` to the system. + /// + /// Triggered when a device is discovered or registered. + Add = 0, + + /// Signals the removal of a `SysObj` from the system. + /// + /// Typically occurs during device disconnection or deregistration. + Remove = 1, + + /// Denotes a modification to the `SysObj`'s properties or state. + /// + /// Used for attribute changes that don't involve structural modifications. + Change = 2, + + /// Represents hierarchical relocation of a `SysObj`. + /// + /// Occurs when a device moves within the device tree topology. + Move = 3, + + /// Marks a device returning to operational status after being offlined. + /// + /// Common in hot-pluggable device scenarios. + Online = 4, + + /// Indicates a device entering non-operational status. + /// + /// Typically precedes safe removal of hot-pluggable hardware. + Offline = 5, + + /// Signifies successful driver-device binding. + /// + /// Occurs after successful driver probe sequence. + Bind = 6, + + /// Indicates driver-device binding termination. + /// + /// Precedes driver unload or device removal. + Unbind = 7, +} + +const SYSOBJ_ACTION_STRS: [&str; SysObjAction::Unbind as usize + 1] = [ + "add", "remove", "change", "move", "online", "offline", "bind", "unbind", +]; + +impl FromStr for SysObjAction { + type Err = Error; + + fn from_str(s: &str) -> Result { + let Some(index) = SYSOBJ_ACTION_STRS + .iter() + .position(|action_str| s == *action_str) + else { + return_errno_with_message!(Errno::EINVAL, "the string is not a valid `SysObj` action"); + }; + + Ok(SysObjAction::try_from(index as u8).unwrap()) + } +} + +impl SysObjAction { + fn as_str(&self) -> &'static str { + SYSOBJ_ACTION_STRS[*self as usize] + } +} + +/// Userspace event. +pub(super) struct Uevent { + /// The `SysObj` action. + action: SysObjAction, + /// The absolute `SysObj` path under sysfs. + devpath: String, + /// The subsystem the event originates from + subsystem: String, + /// Other key-value arguments + envs: Vec<(String, String)>, + /// Sequence number. + seq_num: u64, +} + +impl Uevent { + /// Creates a new uevent. + fn new( + action: SysObjAction, + devpath: String, + subsystem: String, + envs: Vec<(String, String)>, + ) -> Self { + debug_assert!(devpath.starts_with('/')); + + let seq_num = SEQ_NUM_ALLOCATOR.fetch_add(1, Ordering::Relaxed); + + Self { + action, + devpath, + subsystem, + envs, + seq_num, + } + } + + /// Creates a new uevent from synthetic uevent. + pub(super) fn new_from_syn( + synth_uevent: SyntheticUevent, + devpath: String, + subsystem: String, + mut other_envs: Vec<(String, String)>, + ) -> Self { + let SyntheticUevent { + action, + uuid, + mut envs, + } = synth_uevent; + + let uuid_key = "SYNTH_UUID".to_string(); + if let Some(Uuid(uuid)) = uuid { + envs.push((uuid_key, uuid)); + } else { + envs.push((uuid_key, "0".to_string())); + }; + + envs.append(&mut other_envs); + + Self::new(action, devpath, subsystem, envs) + } +} + +impl ToString for Uevent { + fn to_string(&self) -> String { + let mut env_string = { + let len = self + .envs + .iter() + .map(|(key, value)| key.len() + value.len() + 2) + .sum(); + String::with_capacity(len) + }; + + for (key, value) in self.envs.iter() { + env_string.push_str(key); + env_string.push('='); + env_string.push_str(value); + env_string.push('\0'); + } + + format!( + "{}@{}\0ACTION={}\0DEVPATH={}\0SUBSYSTEM={}\0{}SEQNUM={}\0", + self.action.as_str(), + self.devpath, + self.action.as_str(), + self.devpath, + self.subsystem, + env_string, + self.seq_num + ) + } +} + +static SEQ_NUM_ALLOCATOR: AtomicU64 = AtomicU64::new(1); diff --git a/kernel/src/net/socket/netlink/kobject_uevent/mod.rs b/kernel/src/net/socket/netlink/kobject_uevent/mod.rs new file mode 100644 index 00000000..4534239c --- /dev/null +++ b/kernel/src/net/socket/netlink/kobject_uevent/mod.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MPL-2.0 + +pub(super) use message::UeventMessage; + +use crate::net::socket::netlink::{common::NetlinkSocket, table::NetlinkUeventProtocol}; + +mod bound; +mod message; + +pub type NetlinkUeventSocket = NetlinkSocket; diff --git a/kernel/src/net/socket/netlink/message/mod.rs b/kernel/src/net/socket/netlink/message/mod.rs index cd354fab..dd82e9a2 100644 --- a/kernel/src/net/socket/netlink/message/mod.rs +++ b/kernel/src/net/socket/netlink/message/mod.rs @@ -26,7 +26,7 @@ use crate::{ /// A netlink message can be transmitted to and from user space using a single send/receive syscall. /// It consists of one or more [`ProtocolSegment`]s. #[derive(Debug)] -pub(super) struct Message { +pub struct Message { segments: Vec, } @@ -69,7 +69,7 @@ impl Message { } } -pub(super) trait ProtocolSegment: Sized { +pub trait ProtocolSegment: Sized { fn header(&self) -> &CMsgSegHdr; fn header_mut(&mut self) -> &mut CMsgSegHdr; fn read_from(reader: &mut dyn MultiRead) -> Result; diff --git a/kernel/src/net/socket/netlink/message/segment/header.rs b/kernel/src/net/socket/netlink/message/segment/header.rs index 2e442d73..40118f4e 100644 --- a/kernel/src/net/socket/netlink/message/segment/header.rs +++ b/kernel/src/net/socket/netlink/message/segment/header.rs @@ -23,7 +23,7 @@ pub struct CMsgSegHdr { } bitflags! { - /// Common flags used in [`CMsgSegmentHdr`]. + /// Common flags used in [`CMsgSegHdr`]. /// /// Reference: . pub struct SegHdrCommonFlags: u16 { diff --git a/kernel/src/net/socket/netlink/mod.rs b/kernel/src/net/socket/netlink/mod.rs index b789a6b7..34a667a4 100644 --- a/kernel/src/net/socket/netlink/mod.rs +++ b/kernel/src/net/socket/netlink/mod.rs @@ -38,11 +38,17 @@ //! mod addr; +mod common; +mod kobject_uevent; mod message; +mod options; +mod receiver; mod route; mod table; pub use addr::{GroupIdSet, NetlinkSocketAddr}; +pub use kobject_uevent::NetlinkUeventSocket; +pub use options::{AddMembership, DropMembership}; pub use route::NetlinkRouteSocket; pub use table::{is_valid_protocol, StandardNetlinkProtocol}; diff --git a/kernel/src/net/socket/netlink/options.rs b/kernel/src/net/socket/netlink/options.rs new file mode 100644 index 00000000..ac10b8fa --- /dev/null +++ b/kernel/src/net/socket/netlink/options.rs @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::impl_socket_options; + +impl_socket_options!( + pub struct AddMembership(u32); + pub struct DropMembership(u32); +); diff --git a/kernel/src/net/socket/netlink/receiver.rs b/kernel/src/net/socket/netlink/receiver.rs new file mode 100644 index 00000000..88b8b8d5 --- /dev/null +++ b/kernel/src/net/socket/netlink/receiver.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::{events::IoEvents, prelude::*, process::signal::Pollee}; + +pub struct MessageReceiver { + message_queue: MessageQueue, + pollee: Pollee, +} + +pub(super) struct MessageQueue(pub(super) Arc>>); + +impl Clone for MessageQueue { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl MessageQueue { + pub(super) fn new() -> Self { + Self(Arc::new(Mutex::new(VecDeque::new()))) + } + + fn enqueue(&self, message: Message) -> Result<()> { + // FIXME: We should verify the socket buffer length to ensure + // that adding the message doesn't exceed the buffer capacity. + self.0.lock().push_back(message); + Ok(()) + } +} + +impl MessageReceiver { + pub(super) const fn new(message_queue: MessageQueue, pollee: Pollee) -> Self { + Self { + message_queue, + pollee, + } + } + + pub(super) fn enqueue_message(&self, message: Message) -> Result<()> { + self.message_queue.enqueue(message)?; + self.pollee.notify(IoEvents::IN); + + Ok(()) + } +} diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs index 9331ab4d..2bad4990 100644 --- a/kernel/src/net/socket/netlink/route/bound.rs +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -7,8 +7,8 @@ use crate::{ events::IoEvents, net::socket::{ netlink::{ - message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle, - NetlinkSocketAddr, + common::BoundNetlink, message::ProtocolSegment, + route::kernel::get_netlink_route_kernel, NetlinkSocketAddr, }, util::datagram_common, SendRecvFlags, @@ -17,21 +17,7 @@ use crate::{ util::{MultiRead, MultiWrite}, }; -pub(super) struct BoundNetlinkRoute { - handle: BoundHandle, - remote_addr: NetlinkSocketAddr, - receive_queue: Mutex>, -} - -impl BoundNetlinkRoute { - pub(super) const fn new(handle: BoundHandle) -> Self { - Self { - handle, - remote_addr: NetlinkSocketAddr::new_unspecified(), - receive_queue: Mutex::new(VecDeque::new()), - } - } -} +pub(super) type BoundNetlinkRoute = BoundNetlink; impl datagram_common::Bound for BoundNetlinkRoute { type Endpoint = NetlinkSocketAddr; @@ -40,6 +26,10 @@ impl datagram_common::Bound for BoundNetlinkRoute { self.handle.addr() } + fn bind(&mut self, endpoint: &Self::Endpoint) -> Result<()> { + self.bind_common(endpoint) + } + fn remote_endpoint(&self) -> Option<&Self::Endpoint> { Some(&self.remote_addr) } @@ -97,9 +87,7 @@ impl datagram_common::Bound for BoundNetlinkRoute { } } - get_netlink_route_kernel().request(&nlmsg, |response| { - self.receive_queue.lock().push_back(response); - }); + get_netlink_route_kernel().request(&nlmsg, local_port); Ok(nlmsg.total_len()) } @@ -114,7 +102,7 @@ impl datagram_common::Bound for BoundNetlinkRoute { warn!("unsupported flags: {:?}", flags); } - let mut receive_queue = self.receive_queue.lock(); + let mut receive_queue = self.receive_queue.0.lock(); let Some(response) = receive_queue.front() else { return_errno_with_message!(Errno::EAGAIN, "nothing to receive"); @@ -138,13 +126,6 @@ impl datagram_common::Bound for BoundNetlinkRoute { } fn check_io_events(&self) -> IoEvents { - let mut events = IoEvents::OUT; - - let receive_queue = self.receive_queue.lock(); - if !receive_queue.is_empty() { - events |= IoEvents::IN; - } - - events + self.check_io_events_common() } } diff --git a/kernel/src/net/socket/netlink/route/kernel/mod.rs b/kernel/src/net/socket/netlink/route/kernel/mod.rs index 4635ab0b..040406f7 100644 --- a/kernel/src/net/socket/netlink/route/kernel/mod.rs +++ b/kernel/src/net/socket/netlink/route/kernel/mod.rs @@ -7,7 +7,11 @@ use core::marker::PhantomData; use super::message::{RtnlMessage, RtnlSegment}; use crate::{ - net::socket::netlink::message::{CSegmentType, ErrorSegment, ProtocolSegment}, + net::socket::netlink::{ + addr::PortNum, + message::{CSegmentType, ErrorSegment, ProtocolSegment}, + table::{NetlinkRouteProtocol, SupportedNetlinkProtocol}, + }, prelude::*, }; @@ -26,11 +30,7 @@ impl NetlinkRouteKernelSocket { } } - pub(super) fn request( - &self, - request: &RtnlMessage, - mut consume_response: F, - ) { + pub(super) fn request(&self, request: &RtnlMessage, dst_port: PortNum) { debug!("netlink route request: {:?}", request); for segment in request.segments() { @@ -61,7 +61,7 @@ impl NetlinkRouteKernelSocket { debug!("netlink route response: {:?}", response); - consume_response(response); + NetlinkRouteProtocol::unicast(dst_port, response).unwrap(); } } } diff --git a/kernel/src/net/socket/netlink/route/message/mod.rs b/kernel/src/net/socket/netlink/route/message/mod.rs index 8b3d0f39..e8c0585e 100644 --- a/kernel/src/net/socket/netlink/route/message/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/mod.rs @@ -18,4 +18,4 @@ pub(super) use segment::{ use crate::net::socket::netlink::message::Message; /// A netlink route message. -pub(super) type RtnlMessage = Message; +pub(in crate::net::socket::netlink) type RtnlMessage = Message; diff --git a/kernel/src/net/socket/netlink/route/mod.rs b/kernel/src/net/socket/netlink/route/mod.rs index 1313449b..b2c80aad 100644 --- a/kernel/src/net/socket/netlink/route/mod.rs +++ b/kernel/src/net/socket/netlink/route/mod.rs @@ -2,179 +2,12 @@ //! Netlink Route Socket. -use core::sync::atomic::{AtomicBool, Ordering}; +pub(super) use message::RtnlMessage; -use bound::BoundNetlinkRoute; -use unbound::UnboundNetlinkRoute; - -use super::NetlinkSocketAddr; -use crate::{ - events::IoEvents, - net::socket::{ - options::SocketOption, - private::SocketPrivate, - util::datagram_common::{select_remote_and_bind, Bound, Inner}, - MessageHeader, SendRecvFlags, Socket, SocketAddr, - }, - prelude::*, - process::signal::{PollHandle, Pollable, Pollee}, - util::{MultiRead, MultiWrite}, -}; +use crate::net::socket::netlink::{common::NetlinkSocket, table::NetlinkRouteProtocol}; mod bound; mod kernel; mod message; -mod unbound; -pub struct NetlinkRouteSocket { - inner: RwMutex>, - - is_nonblocking: AtomicBool, - pollee: Pollee, -} - -impl NetlinkRouteSocket { - pub fn new(is_nonblocking: bool) -> Self { - let unbound = UnboundNetlinkRoute::new(); - Self { - inner: RwMutex::new(Inner::Unbound(unbound)), - is_nonblocking: AtomicBool::new(is_nonblocking), - pollee: Pollee::new(), - } - } - - fn try_send( - &self, - reader: &mut dyn MultiRead, - remote: Option<&NetlinkSocketAddr>, - flags: SendRecvFlags, - ) -> Result { - let sent_bytes = select_remote_and_bind( - &self.inner, - remote, - || { - self.inner - .write() - .bind_ephemeral(&NetlinkSocketAddr::new_unspecified(), &self.pollee) - }, - |bound, remote_endpoint| bound.try_send(reader, remote_endpoint, flags), - )?; - self.pollee.notify(IoEvents::OUT | IoEvents::IN); - - Ok(sent_bytes) - } - - fn try_recv( - &self, - writer: &mut dyn MultiWrite, - flags: SendRecvFlags, - ) -> Result<(usize, SocketAddr)> { - let recv_bytes = self - .inner - .read() - .try_recv(writer, flags) - .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; - self.pollee.invalidate(); - - Ok(recv_bytes) - } -} - -impl Socket for NetlinkRouteSocket { - fn bind(&self, socket_addr: SocketAddr) -> Result<()> { - let endpoint = socket_addr.try_into()?; - - // FIXME: We need to further check the Linux behavior - // whether we should return error if the socket is bound. - // The socket may call `bind` syscall to join new multicast groups. - self.inner.write().bind(&endpoint, &self.pollee, ()) - } - - fn connect(&self, socket_addr: SocketAddr) -> Result<()> { - let endpoint = socket_addr.try_into()?; - - self.inner.write().connect(&endpoint, &self.pollee) - } - - fn addr(&self) -> Result { - let endpoint = self - .inner - .read() - .addr() - .unwrap_or(NetlinkSocketAddr::new_unspecified()); - - Ok(endpoint.into()) - } - - fn peer_addr(&self) -> Result { - let endpoint = self - .inner - .read() - .peer_addr() - .cloned() - .unwrap_or(NetlinkSocketAddr::new_unspecified()); - - Ok(endpoint.into()) - } - - fn sendmsg( - &self, - reader: &mut dyn MultiRead, - message_header: MessageHeader, - flags: SendRecvFlags, - ) -> Result { - let MessageHeader { - addr, - control_message, - } = message_header; - - let remote = match addr { - None => None, - Some(addr) => Some(addr.try_into()?), - }; - - if control_message.is_some() { - // TODO: Support sending control message - warn!("sending control message is not supported"); - } - - // TODO: Make sure our blocking behavior matches that of Linux - self.try_send(reader, remote.as_ref(), flags) - } - - fn recvmsg( - &self, - writers: &mut dyn MultiWrite, - flags: SendRecvFlags, - ) -> Result<(usize, MessageHeader)> { - let (received_len, addr) = self.block_on(IoEvents::IN, || self.try_recv(writers, flags))?; - - // TODO: Receive control message - - let message_header = MessageHeader::new(Some(addr), None); - - Ok((received_len, message_header)) - } - - fn set_option(&self, _option: &dyn SocketOption) -> Result<()> { - // TODO: This dummy option is added to pass the libnl test - Ok(()) - } -} - -impl SocketPrivate for NetlinkRouteSocket { - fn is_nonblocking(&self) -> bool { - self.is_nonblocking.load(Ordering::Relaxed) - } - - fn set_nonblocking(&self, nonblocking: bool) { - self.is_nonblocking.store(nonblocking, Ordering::Relaxed); - } -} - -impl Pollable for NetlinkRouteSocket { - fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { - self.pollee - .poll_with(mask, poller, || self.inner.read().check_io_events()) - } -} +pub type NetlinkRouteSocket = NetlinkSocket; diff --git a/kernel/src/net/socket/netlink/route/unbound.rs b/kernel/src/net/socket/netlink/route/unbound.rs deleted file mode 100644 index 9a7dad37..00000000 --- a/kernel/src/net/socket/netlink/route/unbound.rs +++ /dev/null @@ -1,58 +0,0 @@ -// SPDX-License-Identifier: MPL-2.0 - -use super::bound::BoundNetlinkRoute; -use crate::{ - events::IoEvents, - net::socket::{ - netlink::{table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol}, - util::datagram_common, - }, - prelude::*, - process::signal::Pollee, -}; - -pub(super) struct UnboundNetlinkRoute { - _private: (), -} - -impl UnboundNetlinkRoute { - pub(super) const fn new() -> Self { - Self { _private: () } - } -} - -impl datagram_common::Unbound for UnboundNetlinkRoute { - type Endpoint = NetlinkSocketAddr; - type BindOptions = (); - - type Bound = BoundNetlinkRoute; - - fn bind( - &mut self, - endpoint: &Self::Endpoint, - _pollee: &Pollee, - _options: Self::BindOptions, - ) -> Result { - let bound_handle = - NETLINK_SOCKET_TABLE.bind(StandardNetlinkProtocol::ROUTE as _, endpoint)?; - - Ok(BoundNetlinkRoute::new(bound_handle)) - } - - fn bind_ephemeral( - &mut self, - _remote_endpoint: &Self::Endpoint, - _pollee: &Pollee, - ) -> Result { - let bound_handle = NETLINK_SOCKET_TABLE.bind( - StandardNetlinkProtocol::ROUTE as _, - &NetlinkSocketAddr::new_unspecified(), - )?; - - Ok(BoundNetlinkRoute::new(bound_handle)) - } - - fn check_io_events(&self) -> IoEvents { - IoEvents::OUT - } -} diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs index 81f15ed8..fe058819 100644 --- a/kernel/src/net/socket/netlink/table/mod.rs +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -1,57 +1,82 @@ // SPDX-License-Identifier: MPL-2.0 use multicast::MulticastGroup; +pub(super) use multicast::MulticastMessage; +use spin::Once; use super::addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS}; -use crate::{net::socket::netlink::addr::UNSPECIFIED_PORT, prelude::*, util::random::getrandom}; +use crate::{ + net::socket::netlink::{ + addr::UNSPECIFIED_PORT, kobject_uevent::UeventMessage, receiver::MessageReceiver, + route::RtnlMessage, + }, + prelude::*, + util::random::getrandom, +}; mod multicast; -pub(super) static NETLINK_SOCKET_TABLE: NetlinkSocketTable = NetlinkSocketTable::new(); +static NETLINK_SOCKET_TABLE: Once = Once::new(); /// All bound netlink sockets. -pub(super) struct NetlinkSocketTable { - protocols: [Mutex>; MAX_ALLOWED_PROTOCOL_ID as usize], +struct NetlinkSocketTable { + route: RwMutex>, + uevent: RwMutex>, } impl NetlinkSocketTable { - pub(super) const fn new() -> Self { + fn new() -> Self { Self { - protocols: [const { Mutex::new(None) }; MAX_ALLOWED_PROTOCOL_ID as usize], + route: RwMutex::new(ProtocolSocketTable::new()), + uevent: RwMutex::new(ProtocolSocketTable::new()), } } +} - /// Adds a new netlink protocol. - fn add_new_protocol(&self, protocol_id: NetlinkProtocolId) { - if protocol_id >= MAX_ALLOWED_PROTOCOL_ID { - return; - } +pub trait SupportedNetlinkProtocol { + type Message: 'static + Send; - let mut protocol = self.protocols[protocol_id as usize].lock(); - if protocol.is_some() { - return; - } + fn socket_table() -> &'static RwMutex>; - let new_protocol = ProtocolSocketTable::new(protocol_id); - *protocol = Some(new_protocol); - } - - pub(super) fn bind( - &self, - protocol: NetlinkProtocolId, + fn bind( addr: &NetlinkSocketAddr, - ) -> Result { - if protocol >= MAX_ALLOWED_PROTOCOL_ID { - return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist"); - } + receiver: MessageReceiver, + ) -> Result> { + let mut socket_table = Self::socket_table().write(); + socket_table.bind(Self::socket_table(), addr, receiver) + } - let mut protocol = self.protocols[protocol as usize].lock(); + fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()> { + let socket_table = Self::socket_table().read(); + socket_table.unicast(dst_port, message) + } - let Some(protocol_sockets) = protocol.as_mut() else { - return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist") - }; + fn multicast(dst_groups: GroupIdSet, message: Self::Message) -> Result<()> + where + Self::Message: MulticastMessage, + { + let socket_table = Self::socket_table().read(); + socket_table.multicast(dst_groups, message) + } +} - protocol_sockets.bind(addr) +pub enum NetlinkRouteProtocol {} + +impl SupportedNetlinkProtocol for NetlinkRouteProtocol { + type Message = RtnlMessage; + + fn socket_table() -> &'static RwMutex> { + &NETLINK_SOCKET_TABLE.get().unwrap().route + } +} + +pub enum NetlinkUeventProtocol {} + +impl SupportedNetlinkProtocol for NetlinkUeventProtocol { + type Message = UeventMessage; + + fn socket_table() -> &'static RwMutex> { + &NETLINK_SOCKET_TABLE.get().unwrap().uevent } } @@ -59,21 +84,17 @@ impl NetlinkSocketTable { /// /// Each table can have bound sockets for unicast /// and at most 32 groups for multicast. -struct ProtocolSocketTable { - id: NetlinkProtocolId, - // TODO: This table should maintain the port number-to-socket relationship - // to support both unicast and multicast effectively. - unicast_sockets: BTreeSet, +pub struct ProtocolSocketTable { + unicast_sockets: BTreeMap>, multicast_groups: Box<[MulticastGroup]>, } -impl ProtocolSocketTable { +impl ProtocolSocketTable { /// Creates a new table. - fn new(id: NetlinkProtocolId) -> Self { + fn new() -> Self { let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect(); Self { - id, - unicast_sockets: BTreeSet::new(), + unicast_sockets: BTreeMap::new(), multicast_groups, } } @@ -89,29 +110,66 @@ impl ProtocolSocketTable { /// /// Additionally, this socket can join one or more multicast groups, /// as specified in `addr.groups()`. - fn bind(&mut self, addr: &NetlinkSocketAddr) -> Result { + fn bind( + &mut self, + socket_table: &'static RwMutex>, + addr: &NetlinkSocketAddr, + receiver: MessageReceiver, + ) -> Result> { let port = if addr.port() != UNSPECIFIED_PORT { addr.port() } else { let mut random_port = current!().pid(); - while random_port == UNSPECIFIED_PORT || self.unicast_sockets.contains(&random_port) { + while random_port == UNSPECIFIED_PORT || self.unicast_sockets.contains_key(&random_port) + { getrandom(random_port.as_bytes_mut()).unwrap(); } random_port }; - if self.unicast_sockets.contains(&port) { + if self.unicast_sockets.contains_key(&port) { return_errno_with_message!(Errno::EADDRINUSE, "the netlink port is already in use"); } - self.unicast_sockets.insert(port); + self.unicast_sockets.insert(port, receiver); for group_id in addr.groups().ids_iter() { let group = &mut self.multicast_groups[group_id as usize]; group.add_member(port); } - Ok(BoundHandle::new(self.id, port, addr.groups())) + Ok(BoundHandle::new(socket_table, port, addr.groups())) + } + + fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()> { + let Some(receiver) = self.unicast_sockets.get(&dst_port) else { + // FIXME: Should we return error here? + return Ok(()); + }; + + receiver.enqueue_message(message) + } + + fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<()> + where + Message: MulticastMessage, + { + for group in dst_groups.ids_iter() { + let Some(group) = self.multicast_groups.get(group as usize) else { + continue; + }; + + for port_num in group.members() { + let Some(receiver) = self.unicast_sockets.get(port_num) else { + continue; + }; + + // FIXME: Should we slightly ignore the error if the socket's buffer has no enough space? + receiver.enqueue_message(message.clone())?; + } + } + + Ok(()) } } @@ -119,19 +177,22 @@ impl ProtocolSocketTable { /// /// When dropping a `BoundHandle`, /// the port will be automatically released. -#[derive(Debug)] -pub(super) struct BoundHandle { - protocol: NetlinkProtocolId, +pub struct BoundHandle { + socket_table: &'static RwMutex>, port: PortNum, groups: GroupIdSet, } -impl BoundHandle { - fn new(protocol: NetlinkProtocolId, port: PortNum, groups: GroupIdSet) -> Self { +impl BoundHandle { + fn new( + socket_table: &'static RwMutex>, + port: PortNum, + groups: GroupIdSet, + ) -> Self { debug_assert_ne!(port, UNSPECIFIED_PORT); Self { - protocol, + socket_table, port, groups, } @@ -144,15 +205,49 @@ impl BoundHandle { pub(super) const fn addr(&self) -> NetlinkSocketAddr { NetlinkSocketAddr::new(self.port, self.groups) } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.add_member(self.port); + } + + self.groups.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + + self.groups.drop_groups(groups); + } + + pub(super) fn bind_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in self.groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.add_member(self.port); + } + + self.groups = groups; + } } -impl Drop for BoundHandle { +impl Drop for BoundHandle { fn drop(&mut self) { - let mut protocol_sockets = NETLINK_SOCKET_TABLE.protocols[self.protocol as usize].lock(); - - let Some(protocol_sockets) = protocol_sockets.as_mut() else { - return; - }; + let mut protocol_sockets = self.socket_table.write(); protocol_sockets.unicast_sockets.remove(&self.port); @@ -164,11 +259,7 @@ impl Drop for BoundHandle { } pub(super) fn init() { - for protocol in 0..MAX_ALLOWED_PROTOCOL_ID { - if is_standard_protocol(protocol) { - NETLINK_SOCKET_TABLE.add_new_protocol(protocol); - } - } + NETLINK_SOCKET_TABLE.call_once(|| NetlinkSocketTable::new()); } /// Returns whether the `protocol` is valid. @@ -176,11 +267,6 @@ pub fn is_valid_protocol(protocol: NetlinkProtocolId) -> bool { protocol < MAX_ALLOWED_PROTOCOL_ID } -/// Returns whether the `protocol` is reserved for system use. -fn is_standard_protocol(protocol: NetlinkProtocolId) -> bool { - StandardNetlinkProtocol::try_from(protocol).is_ok() -} - /// Netlink protocols that are assigned for specific usage. /// /// Reference: . diff --git a/kernel/src/net/socket/netlink/table/multicast.rs b/kernel/src/net/socket/netlink/table/multicast.rs index 93f01f11..3036e86d 100644 --- a/kernel/src/net/socket/netlink/table/multicast.rs +++ b/kernel/src/net/socket/netlink/table/multicast.rs @@ -18,21 +18,20 @@ impl MulticastGroup { } } - /// Returns whether the group contains a member. - #[expect(unused)] - pub fn contains_member(&self, port_num: PortNum) -> bool { - self.members.contains(&port_num) - } - /// Adds a new member to the multicast group. pub fn add_member(&mut self, port_num: PortNum) { - debug_assert!(!self.members.contains(&port_num)); self.members.insert(port_num); } /// Removes a member from the multicast group. pub fn remove_member(&mut self, port_num: PortNum) { - debug_assert!(self.members.contains(&port_num)); self.members.remove(&port_num); } + + /// Returns all members in this group. + pub fn members(&self) -> &BTreeSet { + &self.members + } } + +pub trait MulticastMessage: Clone {} diff --git a/kernel/src/net/socket/util/datagram_common.rs b/kernel/src/net/socket/util/datagram_common.rs index c9bfc971..12329945 100644 --- a/kernel/src/net/socket/util/datagram_common.rs +++ b/kernel/src/net/socket/util/datagram_common.rs @@ -36,6 +36,9 @@ pub trait Bound { type Endpoint; fn local_endpoint(&self) -> Self::Endpoint; + fn bind(&mut self, _endpoint: &Self::Endpoint) -> Result<()> { + return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address") + } fn remote_endpoint(&self) -> Option<&Self::Endpoint>; fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint); @@ -72,11 +75,8 @@ where ) -> Result<()> { let unbound_datagram = match self { Inner::Unbound(unbound_datagram) => unbound_datagram, - Inner::Bound(_) => { - return_errno_with_message!( - Errno::EINVAL, - "the socket is already bound to an address" - ) + Inner::Bound(bound_datagram) => { + return bound_datagram.bind(endpoint); } }; diff --git a/kernel/src/syscall/socket.rs b/kernel/src/syscall/socket.rs index c526bfeb..47e90fe4 100644 --- a/kernel/src/syscall/socket.rs +++ b/kernel/src/syscall/socket.rs @@ -5,7 +5,9 @@ use crate::{ fs::{file_handle::FileLike, file_table::FdFlags}, net::socket::{ ip::{datagram::DatagramSocket, stream::StreamSocket}, - netlink::{is_valid_protocol, NetlinkRouteSocket, StandardNetlinkProtocol}, + netlink::{ + is_valid_protocol, NetlinkRouteSocket, NetlinkUeventSocket, StandardNetlinkProtocol, + }, unix::UnixStreamSocket, vsock::VsockStreamSocket, }, @@ -52,7 +54,10 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32, ctx: &Context) -> Resu debug!("netlink family = {:?}", netlink_family); match netlink_family { Ok(StandardNetlinkProtocol::ROUTE) => { - Arc::new(NetlinkRouteSocket::new(is_nonblocking)) + NetlinkRouteSocket::new(is_nonblocking) as Arc + } + Ok(StandardNetlinkProtocol::KOBJECT_UEVENT) => { + NetlinkUeventSocket::new(is_nonblocking) as Arc } Ok(_) => { return_errno_with_message!( diff --git a/kernel/src/util/net/options/mod.rs b/kernel/src/util/net/options/mod.rs index aedc66e1..da49f828 100644 --- a/kernel/src/util/net/options/mod.rs +++ b/kernel/src/util/net/options/mod.rs @@ -52,10 +52,12 @@ //! use ip::new_ip_option; +use netlink::new_netlink_option; use crate::{net::socket::options::SocketOption, prelude::*}; mod ip; +mod netlink; mod socket; mod tcp; mod utils; @@ -130,6 +132,34 @@ macro_rules! impl_raw_sock_option_get_only { }; } +/// Impl `RawSocketOption` for a struct which is for only `setsockopt` and implements `SocketOption`. +#[macro_export] +macro_rules! impl_raw_sock_option_set_only { + ($option:ty) => { + impl RawSocketOption for $option { + fn read_from_user(&mut self, addr: Vaddr, max_len: u32) -> Result<()> { + use $crate::util::net::options::utils::ReadFromUser; + + let input = ReadFromUser::read_from_user(addr, max_len)?; + self.set(input); + Ok(()) + } + + fn write_to_user(&self, _addr: Vaddr, _max_len: u32) -> Result { + return_errno_with_message!(Errno::ENOPROTOOPT, "the option is setter-only"); + } + + fn as_sock_option_mut(&mut self) -> &mut dyn SocketOption { + self + } + + fn as_sock_option(&self) -> &dyn SocketOption { + self + } + } + }; +} + pub fn new_raw_socket_option( level: CSocketOptionLevel, name: i32, @@ -138,6 +168,7 @@ pub fn new_raw_socket_option( CSocketOptionLevel::SOL_SOCKET => new_socket_option(name), CSocketOptionLevel::SOL_IP => new_ip_option(name), CSocketOptionLevel::SOL_TCP => new_tcp_option(name), + CSocketOptionLevel::SOL_NETLINK => new_netlink_option(name), _ => return_errno_with_message!(Errno::EOPNOTSUPP, "unsupported option level"), } } @@ -153,4 +184,5 @@ pub enum CSocketOptionLevel { SOL_UDP = 17, SOL_IPV6 = 41, SOL_RAW = 255, + SOL_NETLINK = 270, } diff --git a/kernel/src/util/net/options/netlink.rs b/kernel/src/util/net/options/netlink.rs new file mode 100644 index 00000000..5d9bb832 --- /dev/null +++ b/kernel/src/util/net/options/netlink.rs @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::RawSocketOption; +use crate::{ + impl_raw_sock_option_set_only, + net::socket::netlink::{AddMembership, DropMembership}, + prelude::*, + util::net::options::SocketOption, +}; + +/// Socket options for netlink socket. +/// +/// Reference: . +#[repr(i32)] +#[derive(Debug, Clone, Copy, TryFromInt)] +#[expect(non_camel_case_types)] +#[expect(clippy::upper_case_acronyms)] +pub enum CNetlinkOptionName { + ADD_MEMBERSHIP = 1, + DROP_MEMBERSHIP = 2, + PKTINFO = 3, +} + +pub fn new_netlink_option(name: i32) -> Result> { + let name = CNetlinkOptionName::try_from(name).map_err(|_| Errno::ENOPROTOOPT)?; + match name { + CNetlinkOptionName::ADD_MEMBERSHIP => Ok(Box::new(AddMembership::new())), + CNetlinkOptionName::DROP_MEMBERSHIP => Ok(Box::new(DropMembership::new())), + _ => return_errno_with_message!(Errno::ENOPROTOOPT, "unsupported netlink option"), + } +} + +impl_raw_sock_option_set_only!(AddMembership); +impl_raw_sock_option_set_only!(DropMembership);