Implement netlink uevent socket

This commit is contained in:
jiangjianfeng 2025-05-30 09:10:31 +00:00 committed by Ruihan Li
parent d35888c817
commit f946f09ee4
27 changed files with 1251 additions and 364 deletions

View File

@ -43,6 +43,11 @@ impl NetlinkSocketAddr {
pub const fn groups(&self) -> GroupIdSet {
self.groups
}
/// Adds some new groups to the address.
pub fn add_groups(&mut self, groups: GroupIdSet) {
self.groups.add_groups(groups);
}
}
impl TryFrom<SocketAddr> for NetlinkSocketAddr {

View File

@ -24,21 +24,14 @@ impl GroupIdSet {
GroupIdIter::new(self)
}
/// Adds a new group.
///
/// If the group already exists, this method will return an error.
pub fn add_group(&mut self, group_id: GroupId) -> Result<()> {
if group_id >= 32 {
return_errno_with_message!(Errno::EINVAL, "the group ID is invalid");
/// Adds some new groups.
pub fn add_groups(&mut self, groups: GroupIdSet) {
self.0 |= groups.0;
}
let mask = 1u32 << group_id;
if self.0 & mask != 0 {
return_errno_with_message!(Errno::EINVAL, "the group ID already exists");
}
self.0 |= mask;
Ok(())
/// Drops some groups.
pub fn drop_groups(&mut self, groups: GroupIdSet) {
self.0 &= !groups.0;
}
/// Sets new groups.

View File

@ -0,0 +1,61 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{
events::IoEvents,
net::socket::netlink::{
receiver::MessageQueue, table::BoundHandle, GroupIdSet, NetlinkSocketAddr,
},
prelude::*,
};
pub struct BoundNetlink<Message: 'static> {
pub(in crate::net::socket::netlink) handle: BoundHandle<Message>,
pub(in crate::net::socket::netlink) remote_addr: NetlinkSocketAddr,
pub(in crate::net::socket::netlink) receive_queue: MessageQueue<Message>,
}
impl<Message: 'static> BoundNetlink<Message> {
pub(super) fn new(handle: BoundHandle<Message>, message_queue: MessageQueue<Message>) -> Self {
Self {
handle,
remote_addr: NetlinkSocketAddr::new_unspecified(),
receive_queue: message_queue,
}
}
pub(in crate::net::socket::netlink) fn bind_common(
&mut self,
endpoint: &NetlinkSocketAddr,
) -> Result<()> {
if endpoint.port() != self.handle.port() {
return_errno_with_message!(
Errno::EINVAL,
"the socket cannot be bound to a different port"
);
}
let groups = endpoint.groups();
self.handle.bind_groups(groups);
Ok(())
}
pub(in crate::net::socket::netlink) fn check_io_events_common(&self) -> IoEvents {
let mut events = IoEvents::OUT;
let receive_queue = self.receive_queue.0.lock();
if !receive_queue.is_empty() {
events |= IoEvents::IN;
}
events
}
pub(super) fn add_groups(&mut self, groups: GroupIdSet) {
self.handle.add_groups(groups);
}
pub(super) fn drop_groups(&mut self, groups: GroupIdSet) {
self.handle.drop_groups(groups);
}
}

View File

@ -0,0 +1,230 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
pub(super) use bound::BoundNetlink;
use unbound::UnboundNetlink;
use super::{GroupIdSet, NetlinkSocketAddr};
use crate::{
events::IoEvents,
match_sock_option_ref,
net::socket::{
netlink::{table::SupportedNetlinkProtocol, AddMembership, DropMembership},
options::SocketOption,
private::SocketPrivate,
util::datagram_common::{select_remote_and_bind, Bound, Inner},
MessageHeader, SendRecvFlags, Socket, SocketAddr,
},
prelude::*,
process::signal::{PollHandle, Pollable, Pollee},
util::{MultiRead, MultiWrite},
};
mod bound;
mod unbound;
pub struct NetlinkSocket<P: SupportedNetlinkProtocol> {
inner: RwMutex<Inner<UnboundNetlink<P>, BoundNetlink<P::Message>>>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
impl<P: SupportedNetlinkProtocol> NetlinkSocket<P>
where
BoundNetlink<P::Message>: Bound<Endpoint = NetlinkSocketAddr>,
{
pub fn new(is_nonblocking: bool) -> Arc<Self> {
let unbound = UnboundNetlink::new();
Arc::new(Self {
inner: RwMutex::new(Inner::Unbound(unbound)),
is_nonblocking: AtomicBool::new(is_nonblocking),
pollee: Pollee::new(),
})
}
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
let sent_bytes = select_remote_and_bind(
&self.inner,
remote,
|| {
self.inner
.write()
.bind_ephemeral(&NetlinkSocketAddr::new_unspecified(), &self.pollee)
},
|bound, remote_endpoint| bound.try_send(reader, remote_endpoint, flags),
)?;
self.pollee.invalidate();
Ok(sent_bytes)
}
// FIXME: This method is marked as `pub(super)` because it's invoked during kernel mode testing.
pub(super) fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, SocketAddr)> {
let recv_bytes = self
.inner
.read()
.try_recv(writer, flags)
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?;
self.pollee.invalidate();
Ok(recv_bytes)
}
}
impl<P: SupportedNetlinkProtocol> Socket for NetlinkSocket<P>
where
BoundNetlink<P::Message>: Bound<Endpoint = NetlinkSocketAddr>,
{
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.inner.write().bind(&endpoint, &self.pollee, ())
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.inner.write().connect(&endpoint, &self.pollee)
}
fn addr(&self) -> Result<SocketAddr> {
let endpoint = match &*self.inner.read() {
Inner::Unbound(unbound) => unbound.addr(),
Inner::Bound(bound) => bound.local_endpoint(),
};
Ok(endpoint.into())
}
fn peer_addr(&self) -> Result<SocketAddr> {
let endpoint = self
.inner
.read()
.peer_addr()
.cloned()
.unwrap_or(NetlinkSocketAddr::new_unspecified());
Ok(endpoint.into())
}
fn sendmsg(
&self,
reader: &mut dyn MultiRead,
message_header: MessageHeader,
flags: SendRecvFlags,
) -> Result<usize> {
let MessageHeader {
addr,
control_message,
} = message_header;
let remote = match addr {
None => None,
Some(addr) => Some(addr.try_into()?),
};
if control_message.is_some() {
// TODO: Support sending control message
warn!("sending control message is not supported");
}
// TODO: Make sure our blocking behavior matches that of Linux
self.try_send(reader, remote.as_ref(), flags)
}
fn recvmsg(
&self,
writers: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, MessageHeader)> {
let (received_len, addr) = self.block_on(IoEvents::IN, || self.try_recv(writers, flags))?;
// TODO: Receive control message
let message_header = MessageHeader::new(Some(addr), None);
Ok((received_len, message_header))
}
fn set_option(&self, option: &dyn SocketOption) -> Result<()> {
match do_set_netlink_option(&self.inner, option) {
Ok(()) => Ok(()),
Err(e) => {
warn!(
"We currently ignore set option errors to pass libnl test: {:?}",
e
);
Ok(())
}
}
}
}
impl<P: SupportedNetlinkProtocol> SocketPrivate for NetlinkSocket<P>
where
BoundNetlink<P::Message>: Bound<Endpoint = NetlinkSocketAddr>,
{
fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl<P: SupportedNetlinkProtocol> Pollable for NetlinkSocket<P>
where
BoundNetlink<P::Message>: Bound<Endpoint = NetlinkSocketAddr>,
{
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee
.poll_with(mask, poller, || self.inner.read().check_io_events())
}
}
impl<P: SupportedNetlinkProtocol> Inner<UnboundNetlink<P>, BoundNetlink<P::Message>> {
fn add_groups(&mut self, groups: GroupIdSet) {
match self {
Inner::Unbound(unbound_socket) => unbound_socket.add_groups(groups),
Inner::Bound(bound_socket) => bound_socket.add_groups(groups),
}
}
fn drop_groups(&mut self, groups: GroupIdSet) {
match self {
Inner::Unbound(unbound_socket) => unbound_socket.drop_groups(groups),
Inner::Bound(bound_socket) => bound_socket.drop_groups(groups),
}
}
}
fn do_set_netlink_option<P: SupportedNetlinkProtocol>(
inner: &RwMutex<Inner<UnboundNetlink<P>, BoundNetlink<P::Message>>>,
option: &dyn SocketOption,
) -> Result<()> {
match_sock_option_ref!(option, {
add_membership: AddMembership => {
let groups = add_membership.get().unwrap();
inner.write().add_groups(GroupIdSet::new(*groups));
},
drop_membership: DropMembership => {
let groups = drop_membership.get().unwrap();
inner.write().drop_groups(GroupIdSet::new(*groups));
},
_ => return_errno_with_message!(Errno::ENOPROTOOPT, "the socket option to be set is unknown")
});
Ok(())
}

View File

@ -0,0 +1,96 @@
// SPDX-License-Identifier: MPL-2.0
use core::marker::PhantomData;
use crate::{
events::IoEvents,
net::socket::{
netlink::{
common::bound::BoundNetlink,
receiver::{MessageQueue, MessageReceiver},
table::SupportedNetlinkProtocol,
GroupIdSet, NetlinkSocketAddr,
},
util::datagram_common,
},
prelude::*,
process::signal::Pollee,
};
pub(super) struct UnboundNetlink<P: SupportedNetlinkProtocol> {
groups: GroupIdSet,
phantom: PhantomData<BoundNetlink<P::Message>>,
}
impl<P: SupportedNetlinkProtocol> UnboundNetlink<P> {
pub(super) const fn new() -> Self {
Self {
groups: GroupIdSet::new_empty(),
phantom: PhantomData,
}
}
pub(super) fn addr(&self) -> NetlinkSocketAddr {
NetlinkSocketAddr::new(0, self.groups)
}
pub(super) fn add_groups(&mut self, groups: GroupIdSet) {
self.groups.add_groups(groups);
}
pub(super) fn drop_groups(&mut self, groups: GroupIdSet) {
self.groups.drop_groups(groups);
}
}
impl<P: SupportedNetlinkProtocol> datagram_common::Unbound for UnboundNetlink<P> {
type Endpoint = NetlinkSocketAddr;
type BindOptions = ();
type Bound = BoundNetlink<P::Message>;
fn bind(
&mut self,
endpoint: &Self::Endpoint,
pollee: &Pollee,
_options: Self::BindOptions,
) -> Result<Self::Bound> {
let message_queue = MessageQueue::<P::Message>::new();
let bound_handle = {
let endpoint = {
let mut endpoint = endpoint.clone();
endpoint.add_groups(self.groups);
endpoint
};
let receiver = MessageReceiver::new(message_queue.clone(), pollee.clone());
<P as SupportedNetlinkProtocol>::bind(&endpoint, receiver)?
};
Ok(BoundNetlink::new(bound_handle, message_queue))
}
fn bind_ephemeral(
&mut self,
_remote_endpoint: &Self::Endpoint,
pollee: &Pollee,
) -> Result<Self::Bound> {
let message_queue = MessageQueue::<P::Message>::new();
let bound_handle = {
let endpoint = {
let mut endpoint = NetlinkSocketAddr::new_unspecified();
endpoint.add_groups(self.groups);
endpoint
};
let receiver = MessageReceiver::new(message_queue.clone(), pollee.clone());
<P as SupportedNetlinkProtocol>::bind(&endpoint, receiver)?
};
Ok(BoundNetlink::new(bound_handle, message_queue))
}
fn check_io_events(&self) -> IoEvents {
IoEvents::OUT
}
}

View File

@ -0,0 +1,96 @@
// SPDX-License-Identifier: MPL-2.0
use core::ops::Sub;
use super::message::UeventMessage;
use crate::{
events::IoEvents,
net::socket::{
netlink::{common::BoundNetlink, NetlinkSocketAddr},
util::datagram_common,
SendRecvFlags,
},
prelude::*,
util::{MultiRead, MultiWrite},
};
pub(super) type BoundNetlinkUevent = BoundNetlink<UeventMessage>;
impl datagram_common::Bound for BoundNetlinkUevent {
type Endpoint = NetlinkSocketAddr;
fn local_endpoint(&self) -> Self::Endpoint {
self.handle.addr()
}
fn bind(&mut self, endpoint: &Self::Endpoint) -> Result<()> {
self.bind_common(endpoint)
}
fn remote_endpoint(&self) -> Option<&Self::Endpoint> {
Some(&self.remote_addr)
}
fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) {
self.remote_addr = *endpoint;
}
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: &Self::Endpoint,
flags: SendRecvFlags,
) -> Result<usize> {
// TODO: Deal with flags
if !flags.is_all_supported() {
warn!("unsupported flags: {:?}", flags);
}
if *remote != NetlinkSocketAddr::new_unspecified() {
return_errno_with_message!(
Errno::ECONNREFUSED,
"sending uevent messages to user space is not supported"
);
}
// FIXME: How to deal with sending message to kernel socket?
// Here we simply ignore the message and return the message length.
Ok(reader.sum_lens())
}
fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, Self::Endpoint)> {
// TODO: Deal with other flags. Only MSG_PEEK is handled here.
if !flags.sub(SendRecvFlags::MSG_PEEK).is_all_supported() {
warn!("unsupported flags: {:?}", flags);
}
let mut receive_queue = self.receive_queue.0.lock();
let Some(response) = receive_queue.front() else {
return_errno_with_message!(Errno::EAGAIN, "nothing to receive");
};
let len = {
let max_len = writer.sum_lens();
response.total_len().min(max_len)
};
response.write_to(writer)?;
let remote = response.src_addr().clone();
if !flags.contains(SendRecvFlags::MSG_PEEK) {
receive_queue.pop_front().unwrap();
}
Ok((len, remote))
}
fn check_io_events(&self) -> IoEvents {
self.check_io_events_common()
}
}

View File

@ -0,0 +1,58 @@
// SPDX-License-Identifier: MPL-2.0
#![cfg_attr(not(ktest), expect(dead_code))]
use uevent::Uevent;
use crate::{
net::socket::netlink::{table::MulticastMessage, NetlinkSocketAddr},
prelude::*,
util::MultiWrite,
};
mod syn_uevent;
#[cfg(ktest)]
mod test;
mod uevent;
/// A uevent message.
///
/// Note that uevent messages are not the same as common netlink messages.
/// It does not have a netlink header.
#[derive(Debug, Clone)]
pub struct UeventMessage {
uevent: String,
src_addr: NetlinkSocketAddr,
}
impl UeventMessage {
/// Creates a new uevent message.
fn new(uevent: Uevent, src_addr: NetlinkSocketAddr) -> Self {
Self {
uevent: uevent.to_string(),
src_addr,
}
}
/// Returns the source address of the uevent message.
pub(super) fn src_addr(&self) -> &NetlinkSocketAddr {
&self.src_addr
}
/// Returns the total length of the uevent.
pub(super) fn total_len(&self) -> usize {
self.uevent.len()
}
/// Writes the uevent to the given `writer`.
pub(super) fn write_to(&self, writer: &mut dyn MultiWrite) -> Result<()> {
// FIXME: If the message can be truncated, we should avoid returning an error.
if self.uevent.len() > writer.sum_lens() {
return_errno_with_message!(Errno::EFAULT, "the writer length is too small");
}
writer.write(&mut VmReader::from(self.uevent.as_bytes()))?;
Ok(())
}
}
impl MulticastMessage for UeventMessage {}

View File

@ -0,0 +1,102 @@
// SPDX-License-Identifier: MPL-2.0
//! The synthetic uevent.
//!
//! The event is triggered when someone writes some content to `/sys/.../uevent` file.
//! It differs from the event triggers by devices.
//!
//! Reference: <https://elixir.bootlin.com/linux/v6.0.9/source/Documentation/ABI/testing/sysfs-uevent>.
use alloc::format;
use core::str::FromStr;
use super::uevent::SysObjAction;
use crate::prelude::*;
/// The synthetic uevent.
pub(super) struct SyntheticUevent {
pub(super) action: SysObjAction,
pub(super) uuid: Option<Uuid>,
pub(super) envs: Vec<(String, String)>,
}
impl FromStr for SyntheticUevent {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let mut split = s.split(" ");
let action = {
let Some(action_str) = split.next() else {
return_errno_with_message!(Errno::EINVAL, "the string is empty");
};
SysObjAction::from_str(action_str)?
};
let uuid = if let Some(uuid_str) = split.next() {
Some(Uuid::from_str(uuid_str)?)
} else {
None
};
let mut envs = Vec::new();
for env_str in split.into_iter() {
let (key, value) = {
// Each string should be in the `KEY=VALUE` format.
match env_str.split_once('=') {
Some(key_value) => key_value,
None => return_errno_with_message!(Errno::EINVAL, "invalid key value pairs"),
}
};
// Both `KEY` and `VALUE` can contain alphanumeric characters only.
for byte in key.as_bytes().iter().chain(value.as_bytes()) {
if !byte.is_ascii_alphanumeric() {
return_errno_with_message!(
Errno::EINVAL,
"invalid character in key value pairs"
);
}
}
// The `KEY` name gains `SYNTH_ARG_` prefix to avoid possible collisions
// with existing variables.
let key = format!("SYNTH_ARG_{}", key);
let value = value.to_string();
envs.push((key, value));
}
Ok(Self { action, uuid, envs })
}
}
pub(super) struct Uuid(pub(super) String);
impl FromStr for Uuid {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
/// The allowed UUID pattern, where each `x` is a hex digit.
///
/// Reference: <https://elixir.bootlin.com/linux/v6.0.9/source/Documentation/ABI/testing/sysfs-uevent#L19>.
const UUID_PATTERN: &str = "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx";
let bytes = s.as_bytes();
if bytes.len() != UUID_PATTERN.len() {
return_errno_with_message!(Errno::EINVAL, "the UUID length is invalid");
}
for (byte, pattern) in bytes.into_iter().zip(UUID_PATTERN.as_bytes()) {
if *pattern == b'x' && byte.is_ascii_hexdigit() {
continue;
} else if *pattern == b'-' && *byte == b'-' {
continue;
} else {
return_errno_with_message!(Errno::EINVAL, "the UUID content is invalid");
}
}
Ok(Self(s.to_string()))
}
}

View File

@ -0,0 +1,89 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::vec;
use core::str::FromStr;
use ostd::{mm::VmWriter, prelude::*};
use crate::{
net::socket::{
netlink::{
kobject_uevent::{
message::{
syn_uevent::{SyntheticUevent, Uuid},
uevent::Uevent,
},
UeventMessage,
},
table::{NetlinkUeventProtocol, SupportedNetlinkProtocol},
GroupIdSet, NetlinkSocketAddr, NetlinkUeventSocket,
},
SendRecvFlags, Socket, SocketAddr,
},
prelude::*,
};
#[ktest]
fn uuid() {
let uuid = Uuid::from_str("12345678-1234-1234-1234-123456789012");
assert!(uuid.is_ok());
let uuid = Uuid::from_str("12345678-1234-1234-1234-12345678901");
assert!(uuid.is_err());
let uuid = Uuid::from_str("12345678-1234-1234-1234-1234567890g");
assert!(uuid.is_err());
}
#[ktest]
fn synthetic_uevent() {
let uevent = SyntheticUevent::from_str("add");
assert!(uevent.is_ok());
let uevent = SyntheticUevent::from_str("add 12345678-1234-1234-1234-123456789012");
assert!(uevent.is_ok());
let uevent = SyntheticUevent::from_str("add 12345678-1234-1234-1234-123456789012 NAME=lo");
assert!(uevent.is_ok());
}
#[ktest]
fn multicast_synthetic_uevent() {
crate::net::socket::netlink::init();
// Creates a new netlink uevent socket and joins the group for kobject uevents.
let socket = NetlinkUeventSocket::new(true);
let socket_addr = SocketAddr::Netlink(NetlinkSocketAddr::new(100, GroupIdSet::new(0x1)));
socket.bind(socket_addr).unwrap();
// Tries to receive and returns EAGAIN if no message is available.
let mut buffer = vec![0u8; 1024];
let mut writer = VmWriter::from(buffer.as_mut_slice()).to_fallible();
let res = socket.try_recv(&mut writer, SendRecvFlags::empty());
assert!(res.is_err_and(|err| err.error() == Errno::EAGAIN));
// Broadcasts a uevent message.
let uevent = {
let lo_infos = vec![
("INTERFACE".to_string(), "lo".to_string()),
("IFINDEX".to_string(), "1".to_string()),
];
let synth_uevent = SyntheticUevent::from_str("add").unwrap();
Uevent::new_from_syn(
synth_uevent,
"/devices/virtual/net/lo".to_string(),
"net".to_string(),
lo_infos,
)
};
let uevent_message =
UeventMessage::new(uevent, NetlinkSocketAddr::new(0, GroupIdSet::new(0x1)));
NetlinkUeventProtocol::multicast(GroupIdSet::new(0x1), uevent_message).unwrap();
let (len, _) = socket
.try_recv(&mut writer, SendRecvFlags::empty())
.unwrap();
let s = core::str::from_utf8(&buffer[..len]).unwrap();
assert_eq!(s, "add@/devices/virtual/net/lo\0ACTION=add\0DEVPATH=/devices/virtual/net/lo\0SUBSYSTEM=net\0SYNTH_UUID=0\0INTERFACE=lo\0IFINDEX=1\0SEQNUM=1\0");
}

View File

@ -0,0 +1,176 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::format;
use core::{
str::FromStr,
sync::atomic::{AtomicU64, Ordering},
};
use super::syn_uevent::{SyntheticUevent, Uuid};
use crate::prelude::*;
/// `SysObj` action type.
///
/// Reference: <https://elixir.bootlin.com/linux/v6.14/source/include/linux/kobject.h#L53>.
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromInt)]
#[repr(u8)]
pub(super) enum SysObjAction {
/// Indicates the addition of a new `SysObj` to the system.
///
/// Triggered when a device is discovered or registered.
Add = 0,
/// Signals the removal of a `SysObj` from the system.
///
/// Typically occurs during device disconnection or deregistration.
Remove = 1,
/// Denotes a modification to the `SysObj`'s properties or state.
///
/// Used for attribute changes that don't involve structural modifications.
Change = 2,
/// Represents hierarchical relocation of a `SysObj`.
///
/// Occurs when a device moves within the device tree topology.
Move = 3,
/// Marks a device returning to operational status after being offlined.
///
/// Common in hot-pluggable device scenarios.
Online = 4,
/// Indicates a device entering non-operational status.
///
/// Typically precedes safe removal of hot-pluggable hardware.
Offline = 5,
/// Signifies successful driver-device binding.
///
/// Occurs after successful driver probe sequence.
Bind = 6,
/// Indicates driver-device binding termination.
///
/// Precedes driver unload or device removal.
Unbind = 7,
}
const SYSOBJ_ACTION_STRS: [&str; SysObjAction::Unbind as usize + 1] = [
"add", "remove", "change", "move", "online", "offline", "bind", "unbind",
];
impl FromStr for SysObjAction {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let Some(index) = SYSOBJ_ACTION_STRS
.iter()
.position(|action_str| s == *action_str)
else {
return_errno_with_message!(Errno::EINVAL, "the string is not a valid `SysObj` action");
};
Ok(SysObjAction::try_from(index as u8).unwrap())
}
}
impl SysObjAction {
fn as_str(&self) -> &'static str {
SYSOBJ_ACTION_STRS[*self as usize]
}
}
/// Userspace event.
pub(super) struct Uevent {
/// The `SysObj` action.
action: SysObjAction,
/// The absolute `SysObj` path under sysfs.
devpath: String,
/// The subsystem the event originates from
subsystem: String,
/// Other key-value arguments
envs: Vec<(String, String)>,
/// Sequence number.
seq_num: u64,
}
impl Uevent {
/// Creates a new uevent.
fn new(
action: SysObjAction,
devpath: String,
subsystem: String,
envs: Vec<(String, String)>,
) -> Self {
debug_assert!(devpath.starts_with('/'));
let seq_num = SEQ_NUM_ALLOCATOR.fetch_add(1, Ordering::Relaxed);
Self {
action,
devpath,
subsystem,
envs,
seq_num,
}
}
/// Creates a new uevent from synthetic uevent.
pub(super) fn new_from_syn(
synth_uevent: SyntheticUevent,
devpath: String,
subsystem: String,
mut other_envs: Vec<(String, String)>,
) -> Self {
let SyntheticUevent {
action,
uuid,
mut envs,
} = synth_uevent;
let uuid_key = "SYNTH_UUID".to_string();
if let Some(Uuid(uuid)) = uuid {
envs.push((uuid_key, uuid));
} else {
envs.push((uuid_key, "0".to_string()));
};
envs.append(&mut other_envs);
Self::new(action, devpath, subsystem, envs)
}
}
impl ToString for Uevent {
fn to_string(&self) -> String {
let mut env_string = {
let len = self
.envs
.iter()
.map(|(key, value)| key.len() + value.len() + 2)
.sum();
String::with_capacity(len)
};
for (key, value) in self.envs.iter() {
env_string.push_str(key);
env_string.push('=');
env_string.push_str(value);
env_string.push('\0');
}
format!(
"{}@{}\0ACTION={}\0DEVPATH={}\0SUBSYSTEM={}\0{}SEQNUM={}\0",
self.action.as_str(),
self.devpath,
self.action.as_str(),
self.devpath,
self.subsystem,
env_string,
self.seq_num
)
}
}
static SEQ_NUM_ALLOCATOR: AtomicU64 = AtomicU64::new(1);

View File

@ -0,0 +1,10 @@
// SPDX-License-Identifier: MPL-2.0
pub(super) use message::UeventMessage;
use crate::net::socket::netlink::{common::NetlinkSocket, table::NetlinkUeventProtocol};
mod bound;
mod message;
pub type NetlinkUeventSocket = NetlinkSocket<NetlinkUeventProtocol>;

View File

@ -26,7 +26,7 @@ use crate::{
/// A netlink message can be transmitted to and from user space using a single send/receive syscall.
/// It consists of one or more [`ProtocolSegment`]s.
#[derive(Debug)]
pub(super) struct Message<T: ProtocolSegment> {
pub struct Message<T: ProtocolSegment> {
segments: Vec<T>,
}
@ -69,7 +69,7 @@ impl<T: ProtocolSegment> Message<T> {
}
}
pub(super) trait ProtocolSegment: Sized {
pub trait ProtocolSegment: Sized {
fn header(&self) -> &CMsgSegHdr;
fn header_mut(&mut self) -> &mut CMsgSegHdr;
fn read_from(reader: &mut dyn MultiRead) -> Result<Self>;

View File

@ -23,7 +23,7 @@ pub struct CMsgSegHdr {
}
bitflags! {
/// Common flags used in [`CMsgSegmentHdr`].
/// Common flags used in [`CMsgSegHdr`].
///
/// Reference: <https://elixir.bootlin.com/linux/v6.13/source/include/uapi/linux/netlink.h#L62>.
pub struct SegHdrCommonFlags: u16 {

View File

@ -38,11 +38,17 @@
//!
mod addr;
mod common;
mod kobject_uevent;
mod message;
mod options;
mod receiver;
mod route;
mod table;
pub use addr::{GroupIdSet, NetlinkSocketAddr};
pub use kobject_uevent::NetlinkUeventSocket;
pub use options::{AddMembership, DropMembership};
pub use route::NetlinkRouteSocket;
pub use table::{is_valid_protocol, StandardNetlinkProtocol};

View File

@ -0,0 +1,8 @@
// SPDX-License-Identifier: MPL-2.0
use crate::impl_socket_options;
impl_socket_options!(
pub struct AddMembership(u32);
pub struct DropMembership(u32);
);

View File

@ -0,0 +1,45 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{events::IoEvents, prelude::*, process::signal::Pollee};
pub struct MessageReceiver<Message> {
message_queue: MessageQueue<Message>,
pollee: Pollee,
}
pub(super) struct MessageQueue<Message>(pub(super) Arc<Mutex<VecDeque<Message>>>);
impl<Message> Clone for MessageQueue<Message> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<Message> MessageQueue<Message> {
pub(super) fn new() -> Self {
Self(Arc::new(Mutex::new(VecDeque::new())))
}
fn enqueue(&self, message: Message) -> Result<()> {
// FIXME: We should verify the socket buffer length to ensure
// that adding the message doesn't exceed the buffer capacity.
self.0.lock().push_back(message);
Ok(())
}
}
impl<Message> MessageReceiver<Message> {
pub(super) const fn new(message_queue: MessageQueue<Message>, pollee: Pollee) -> Self {
Self {
message_queue,
pollee,
}
}
pub(super) fn enqueue_message(&self, message: Message) -> Result<()> {
self.message_queue.enqueue(message)?;
self.pollee.notify(IoEvents::IN);
Ok(())
}
}

View File

@ -7,8 +7,8 @@ use crate::{
events::IoEvents,
net::socket::{
netlink::{
message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle,
NetlinkSocketAddr,
common::BoundNetlink, message::ProtocolSegment,
route::kernel::get_netlink_route_kernel, NetlinkSocketAddr,
},
util::datagram_common,
SendRecvFlags,
@ -17,21 +17,7 @@ use crate::{
util::{MultiRead, MultiWrite},
};
pub(super) struct BoundNetlinkRoute {
handle: BoundHandle,
remote_addr: NetlinkSocketAddr,
receive_queue: Mutex<VecDeque<RtnlMessage>>,
}
impl BoundNetlinkRoute {
pub(super) const fn new(handle: BoundHandle) -> Self {
Self {
handle,
remote_addr: NetlinkSocketAddr::new_unspecified(),
receive_queue: Mutex::new(VecDeque::new()),
}
}
}
pub(super) type BoundNetlinkRoute = BoundNetlink<RtnlMessage>;
impl datagram_common::Bound for BoundNetlinkRoute {
type Endpoint = NetlinkSocketAddr;
@ -40,6 +26,10 @@ impl datagram_common::Bound for BoundNetlinkRoute {
self.handle.addr()
}
fn bind(&mut self, endpoint: &Self::Endpoint) -> Result<()> {
self.bind_common(endpoint)
}
fn remote_endpoint(&self) -> Option<&Self::Endpoint> {
Some(&self.remote_addr)
}
@ -97,9 +87,7 @@ impl datagram_common::Bound for BoundNetlinkRoute {
}
}
get_netlink_route_kernel().request(&nlmsg, |response| {
self.receive_queue.lock().push_back(response);
});
get_netlink_route_kernel().request(&nlmsg, local_port);
Ok(nlmsg.total_len())
}
@ -114,7 +102,7 @@ impl datagram_common::Bound for BoundNetlinkRoute {
warn!("unsupported flags: {:?}", flags);
}
let mut receive_queue = self.receive_queue.lock();
let mut receive_queue = self.receive_queue.0.lock();
let Some(response) = receive_queue.front() else {
return_errno_with_message!(Errno::EAGAIN, "nothing to receive");
@ -138,13 +126,6 @@ impl datagram_common::Bound for BoundNetlinkRoute {
}
fn check_io_events(&self) -> IoEvents {
let mut events = IoEvents::OUT;
let receive_queue = self.receive_queue.lock();
if !receive_queue.is_empty() {
events |= IoEvents::IN;
}
events
self.check_io_events_common()
}
}

View File

@ -7,7 +7,11 @@ use core::marker::PhantomData;
use super::message::{RtnlMessage, RtnlSegment};
use crate::{
net::socket::netlink::message::{CSegmentType, ErrorSegment, ProtocolSegment},
net::socket::netlink::{
addr::PortNum,
message::{CSegmentType, ErrorSegment, ProtocolSegment},
table::{NetlinkRouteProtocol, SupportedNetlinkProtocol},
},
prelude::*,
};
@ -26,11 +30,7 @@ impl NetlinkRouteKernelSocket {
}
}
pub(super) fn request<F: FnMut(RtnlMessage)>(
&self,
request: &RtnlMessage,
mut consume_response: F,
) {
pub(super) fn request(&self, request: &RtnlMessage, dst_port: PortNum) {
debug!("netlink route request: {:?}", request);
for segment in request.segments() {
@ -61,7 +61,7 @@ impl NetlinkRouteKernelSocket {
debug!("netlink route response: {:?}", response);
consume_response(response);
NetlinkRouteProtocol::unicast(dst_port, response).unwrap();
}
}
}

View File

@ -18,4 +18,4 @@ pub(super) use segment::{
use crate::net::socket::netlink::message::Message;
/// A netlink route message.
pub(super) type RtnlMessage = Message<RtnlSegment>;
pub(in crate::net::socket::netlink) type RtnlMessage = Message<RtnlSegment>;

View File

@ -2,179 +2,12 @@
//! Netlink Route Socket.
use core::sync::atomic::{AtomicBool, Ordering};
pub(super) use message::RtnlMessage;
use bound::BoundNetlinkRoute;
use unbound::UnboundNetlinkRoute;
use super::NetlinkSocketAddr;
use crate::{
events::IoEvents,
net::socket::{
options::SocketOption,
private::SocketPrivate,
util::datagram_common::{select_remote_and_bind, Bound, Inner},
MessageHeader, SendRecvFlags, Socket, SocketAddr,
},
prelude::*,
process::signal::{PollHandle, Pollable, Pollee},
util::{MultiRead, MultiWrite},
};
use crate::net::socket::netlink::{common::NetlinkSocket, table::NetlinkRouteProtocol};
mod bound;
mod kernel;
mod message;
mod unbound;
pub struct NetlinkRouteSocket {
inner: RwMutex<Inner<UnboundNetlinkRoute, BoundNetlinkRoute>>,
is_nonblocking: AtomicBool,
pollee: Pollee,
}
impl NetlinkRouteSocket {
pub fn new(is_nonblocking: bool) -> Self {
let unbound = UnboundNetlinkRoute::new();
Self {
inner: RwMutex::new(Inner::Unbound(unbound)),
is_nonblocking: AtomicBool::new(is_nonblocking),
pollee: Pollee::new(),
}
}
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
let sent_bytes = select_remote_and_bind(
&self.inner,
remote,
|| {
self.inner
.write()
.bind_ephemeral(&NetlinkSocketAddr::new_unspecified(), &self.pollee)
},
|bound, remote_endpoint| bound.try_send(reader, remote_endpoint, flags),
)?;
self.pollee.notify(IoEvents::OUT | IoEvents::IN);
Ok(sent_bytes)
}
fn try_recv(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, SocketAddr)> {
let recv_bytes = self
.inner
.read()
.try_recv(writer, flags)
.map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?;
self.pollee.invalidate();
Ok(recv_bytes)
}
}
impl Socket for NetlinkRouteSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
// FIXME: We need to further check the Linux behavior
// whether we should return error if the socket is bound.
// The socket may call `bind` syscall to join new multicast groups.
self.inner.write().bind(&endpoint, &self.pollee, ())
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let endpoint = socket_addr.try_into()?;
self.inner.write().connect(&endpoint, &self.pollee)
}
fn addr(&self) -> Result<SocketAddr> {
let endpoint = self
.inner
.read()
.addr()
.unwrap_or(NetlinkSocketAddr::new_unspecified());
Ok(endpoint.into())
}
fn peer_addr(&self) -> Result<SocketAddr> {
let endpoint = self
.inner
.read()
.peer_addr()
.cloned()
.unwrap_or(NetlinkSocketAddr::new_unspecified());
Ok(endpoint.into())
}
fn sendmsg(
&self,
reader: &mut dyn MultiRead,
message_header: MessageHeader,
flags: SendRecvFlags,
) -> Result<usize> {
let MessageHeader {
addr,
control_message,
} = message_header;
let remote = match addr {
None => None,
Some(addr) => Some(addr.try_into()?),
};
if control_message.is_some() {
// TODO: Support sending control message
warn!("sending control message is not supported");
}
// TODO: Make sure our blocking behavior matches that of Linux
self.try_send(reader, remote.as_ref(), flags)
}
fn recvmsg(
&self,
writers: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, MessageHeader)> {
let (received_len, addr) = self.block_on(IoEvents::IN, || self.try_recv(writers, flags))?;
// TODO: Receive control message
let message_header = MessageHeader::new(Some(addr), None);
Ok((received_len, message_header))
}
fn set_option(&self, _option: &dyn SocketOption) -> Result<()> {
// TODO: This dummy option is added to pass the libnl test
Ok(())
}
}
impl SocketPrivate for NetlinkRouteSocket {
fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Pollable for NetlinkRouteSocket {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee
.poll_with(mask, poller, || self.inner.read().check_io_events())
}
}
pub type NetlinkRouteSocket = NetlinkSocket<NetlinkRouteProtocol>;

View File

@ -1,58 +0,0 @@
// SPDX-License-Identifier: MPL-2.0
use super::bound::BoundNetlinkRoute;
use crate::{
events::IoEvents,
net::socket::{
netlink::{table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol},
util::datagram_common,
},
prelude::*,
process::signal::Pollee,
};
pub(super) struct UnboundNetlinkRoute {
_private: (),
}
impl UnboundNetlinkRoute {
pub(super) const fn new() -> Self {
Self { _private: () }
}
}
impl datagram_common::Unbound for UnboundNetlinkRoute {
type Endpoint = NetlinkSocketAddr;
type BindOptions = ();
type Bound = BoundNetlinkRoute;
fn bind(
&mut self,
endpoint: &Self::Endpoint,
_pollee: &Pollee,
_options: Self::BindOptions,
) -> Result<BoundNetlinkRoute> {
let bound_handle =
NETLINK_SOCKET_TABLE.bind(StandardNetlinkProtocol::ROUTE as _, endpoint)?;
Ok(BoundNetlinkRoute::new(bound_handle))
}
fn bind_ephemeral(
&mut self,
_remote_endpoint: &Self::Endpoint,
_pollee: &Pollee,
) -> Result<Self::Bound> {
let bound_handle = NETLINK_SOCKET_TABLE.bind(
StandardNetlinkProtocol::ROUTE as _,
&NetlinkSocketAddr::new_unspecified(),
)?;
Ok(BoundNetlinkRoute::new(bound_handle))
}
fn check_io_events(&self) -> IoEvents {
IoEvents::OUT
}
}

View File

@ -1,57 +1,82 @@
// SPDX-License-Identifier: MPL-2.0
use multicast::MulticastGroup;
pub(super) use multicast::MulticastMessage;
use spin::Once;
use super::addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS};
use crate::{net::socket::netlink::addr::UNSPECIFIED_PORT, prelude::*, util::random::getrandom};
use crate::{
net::socket::netlink::{
addr::UNSPECIFIED_PORT, kobject_uevent::UeventMessage, receiver::MessageReceiver,
route::RtnlMessage,
},
prelude::*,
util::random::getrandom,
};
mod multicast;
pub(super) static NETLINK_SOCKET_TABLE: NetlinkSocketTable = NetlinkSocketTable::new();
static NETLINK_SOCKET_TABLE: Once<NetlinkSocketTable> = Once::new();
/// All bound netlink sockets.
pub(super) struct NetlinkSocketTable {
protocols: [Mutex<Option<ProtocolSocketTable>>; MAX_ALLOWED_PROTOCOL_ID as usize],
struct NetlinkSocketTable {
route: RwMutex<ProtocolSocketTable<RtnlMessage>>,
uevent: RwMutex<ProtocolSocketTable<UeventMessage>>,
}
impl NetlinkSocketTable {
pub(super) const fn new() -> Self {
fn new() -> Self {
Self {
protocols: [const { Mutex::new(None) }; MAX_ALLOWED_PROTOCOL_ID as usize],
route: RwMutex::new(ProtocolSocketTable::new()),
uevent: RwMutex::new(ProtocolSocketTable::new()),
}
}
}
/// Adds a new netlink protocol.
fn add_new_protocol(&self, protocol_id: NetlinkProtocolId) {
if protocol_id >= MAX_ALLOWED_PROTOCOL_ID {
return;
}
pub trait SupportedNetlinkProtocol {
type Message: 'static + Send;
let mut protocol = self.protocols[protocol_id as usize].lock();
if protocol.is_some() {
return;
}
fn socket_table() -> &'static RwMutex<ProtocolSocketTable<Self::Message>>;
let new_protocol = ProtocolSocketTable::new(protocol_id);
*protocol = Some(new_protocol);
}
pub(super) fn bind(
&self,
protocol: NetlinkProtocolId,
fn bind(
addr: &NetlinkSocketAddr,
) -> Result<BoundHandle> {
if protocol >= MAX_ALLOWED_PROTOCOL_ID {
return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist");
receiver: MessageReceiver<Self::Message>,
) -> Result<BoundHandle<Self::Message>> {
let mut socket_table = Self::socket_table().write();
socket_table.bind(Self::socket_table(), addr, receiver)
}
let mut protocol = self.protocols[protocol as usize].lock();
fn unicast(dst_port: PortNum, message: Self::Message) -> Result<()> {
let socket_table = Self::socket_table().read();
socket_table.unicast(dst_port, message)
}
let Some(protocol_sockets) = protocol.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist")
};
fn multicast(dst_groups: GroupIdSet, message: Self::Message) -> Result<()>
where
Self::Message: MulticastMessage,
{
let socket_table = Self::socket_table().read();
socket_table.multicast(dst_groups, message)
}
}
protocol_sockets.bind(addr)
pub enum NetlinkRouteProtocol {}
impl SupportedNetlinkProtocol for NetlinkRouteProtocol {
type Message = RtnlMessage;
fn socket_table() -> &'static RwMutex<ProtocolSocketTable<Self::Message>> {
&NETLINK_SOCKET_TABLE.get().unwrap().route
}
}
pub enum NetlinkUeventProtocol {}
impl SupportedNetlinkProtocol for NetlinkUeventProtocol {
type Message = UeventMessage;
fn socket_table() -> &'static RwMutex<ProtocolSocketTable<Self::Message>> {
&NETLINK_SOCKET_TABLE.get().unwrap().uevent
}
}
@ -59,21 +84,17 @@ impl NetlinkSocketTable {
///
/// Each table can have bound sockets for unicast
/// and at most 32 groups for multicast.
struct ProtocolSocketTable {
id: NetlinkProtocolId,
// TODO: This table should maintain the port number-to-socket relationship
// to support both unicast and multicast effectively.
unicast_sockets: BTreeSet<PortNum>,
pub struct ProtocolSocketTable<Message> {
unicast_sockets: BTreeMap<PortNum, MessageReceiver<Message>>,
multicast_groups: Box<[MulticastGroup]>,
}
impl ProtocolSocketTable {
impl<Message: 'static> ProtocolSocketTable<Message> {
/// Creates a new table.
fn new(id: NetlinkProtocolId) -> Self {
fn new() -> Self {
let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect();
Self {
id,
unicast_sockets: BTreeSet::new(),
unicast_sockets: BTreeMap::new(),
multicast_groups,
}
}
@ -89,29 +110,66 @@ impl ProtocolSocketTable {
///
/// Additionally, this socket can join one or more multicast groups,
/// as specified in `addr.groups()`.
fn bind(&mut self, addr: &NetlinkSocketAddr) -> Result<BoundHandle> {
fn bind(
&mut self,
socket_table: &'static RwMutex<ProtocolSocketTable<Message>>,
addr: &NetlinkSocketAddr,
receiver: MessageReceiver<Message>,
) -> Result<BoundHandle<Message>> {
let port = if addr.port() != UNSPECIFIED_PORT {
addr.port()
} else {
let mut random_port = current!().pid();
while random_port == UNSPECIFIED_PORT || self.unicast_sockets.contains(&random_port) {
while random_port == UNSPECIFIED_PORT || self.unicast_sockets.contains_key(&random_port)
{
getrandom(random_port.as_bytes_mut()).unwrap();
}
random_port
};
if self.unicast_sockets.contains(&port) {
if self.unicast_sockets.contains_key(&port) {
return_errno_with_message!(Errno::EADDRINUSE, "the netlink port is already in use");
}
self.unicast_sockets.insert(port);
self.unicast_sockets.insert(port, receiver);
for group_id in addr.groups().ids_iter() {
let group = &mut self.multicast_groups[group_id as usize];
group.add_member(port);
}
Ok(BoundHandle::new(self.id, port, addr.groups()))
Ok(BoundHandle::new(socket_table, port, addr.groups()))
}
fn unicast(&self, dst_port: PortNum, message: Message) -> Result<()> {
let Some(receiver) = self.unicast_sockets.get(&dst_port) else {
// FIXME: Should we return error here?
return Ok(());
};
receiver.enqueue_message(message)
}
fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<()>
where
Message: MulticastMessage,
{
for group in dst_groups.ids_iter() {
let Some(group) = self.multicast_groups.get(group as usize) else {
continue;
};
for port_num in group.members() {
let Some(receiver) = self.unicast_sockets.get(port_num) else {
continue;
};
// FIXME: Should we slightly ignore the error if the socket's buffer has no enough space?
receiver.enqueue_message(message.clone())?;
}
}
Ok(())
}
}
@ -119,19 +177,22 @@ impl ProtocolSocketTable {
///
/// When dropping a `BoundHandle`,
/// the port will be automatically released.
#[derive(Debug)]
pub(super) struct BoundHandle {
protocol: NetlinkProtocolId,
pub struct BoundHandle<Message: 'static> {
socket_table: &'static RwMutex<ProtocolSocketTable<Message>>,
port: PortNum,
groups: GroupIdSet,
}
impl BoundHandle {
fn new(protocol: NetlinkProtocolId, port: PortNum, groups: GroupIdSet) -> Self {
impl<Message: 'static> BoundHandle<Message> {
fn new(
socket_table: &'static RwMutex<ProtocolSocketTable<Message>>,
port: PortNum,
groups: GroupIdSet,
) -> Self {
debug_assert_ne!(port, UNSPECIFIED_PORT);
Self {
protocol,
socket_table,
port,
groups,
}
@ -144,15 +205,49 @@ impl BoundHandle {
pub(super) const fn addr(&self) -> NetlinkSocketAddr {
NetlinkSocketAddr::new(self.port, self.groups)
}
pub(super) fn add_groups(&mut self, groups: GroupIdSet) {
let mut protocol_sockets = self.socket_table.write();
for group_id in groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.add_member(self.port);
}
self.groups.add_groups(groups);
}
pub(super) fn drop_groups(&mut self, groups: GroupIdSet) {
let mut protocol_sockets = self.socket_table.write();
for group_id in groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.remove_member(self.port);
}
self.groups.drop_groups(groups);
}
pub(super) fn bind_groups(&mut self, groups: GroupIdSet) {
let mut protocol_sockets = self.socket_table.write();
for group_id in self.groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.remove_member(self.port);
}
for group_id in groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.add_member(self.port);
}
self.groups = groups;
}
}
impl Drop for BoundHandle {
impl<Message: 'static> Drop for BoundHandle<Message> {
fn drop(&mut self) {
let mut protocol_sockets = NETLINK_SOCKET_TABLE.protocols[self.protocol as usize].lock();
let Some(protocol_sockets) = protocol_sockets.as_mut() else {
return;
};
let mut protocol_sockets = self.socket_table.write();
protocol_sockets.unicast_sockets.remove(&self.port);
@ -164,11 +259,7 @@ impl Drop for BoundHandle {
}
pub(super) fn init() {
for protocol in 0..MAX_ALLOWED_PROTOCOL_ID {
if is_standard_protocol(protocol) {
NETLINK_SOCKET_TABLE.add_new_protocol(protocol);
}
}
NETLINK_SOCKET_TABLE.call_once(|| NetlinkSocketTable::new());
}
/// Returns whether the `protocol` is valid.
@ -176,11 +267,6 @@ pub fn is_valid_protocol(protocol: NetlinkProtocolId) -> bool {
protocol < MAX_ALLOWED_PROTOCOL_ID
}
/// Returns whether the `protocol` is reserved for system use.
fn is_standard_protocol(protocol: NetlinkProtocolId) -> bool {
StandardNetlinkProtocol::try_from(protocol).is_ok()
}
/// Netlink protocols that are assigned for specific usage.
///
/// Reference: <https://elixir.bootlin.com/linux/v6.0.9/source/include/uapi/linux/netlink.h#L9>.

View File

@ -18,21 +18,20 @@ impl MulticastGroup {
}
}
/// Returns whether the group contains a member.
#[expect(unused)]
pub fn contains_member(&self, port_num: PortNum) -> bool {
self.members.contains(&port_num)
}
/// Adds a new member to the multicast group.
pub fn add_member(&mut self, port_num: PortNum) {
debug_assert!(!self.members.contains(&port_num));
self.members.insert(port_num);
}
/// Removes a member from the multicast group.
pub fn remove_member(&mut self, port_num: PortNum) {
debug_assert!(self.members.contains(&port_num));
self.members.remove(&port_num);
}
/// Returns all members in this group.
pub fn members(&self) -> &BTreeSet<PortNum> {
&self.members
}
}
pub trait MulticastMessage: Clone {}

View File

@ -36,6 +36,9 @@ pub trait Bound {
type Endpoint;
fn local_endpoint(&self) -> Self::Endpoint;
fn bind(&mut self, _endpoint: &Self::Endpoint) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address")
}
fn remote_endpoint(&self) -> Option<&Self::Endpoint>;
fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint);
@ -72,11 +75,8 @@ where
) -> Result<()> {
let unbound_datagram = match self {
Inner::Unbound(unbound_datagram) => unbound_datagram,
Inner::Bound(_) => {
return_errno_with_message!(
Errno::EINVAL,
"the socket is already bound to an address"
)
Inner::Bound(bound_datagram) => {
return bound_datagram.bind(endpoint);
}
};

View File

@ -5,7 +5,9 @@ use crate::{
fs::{file_handle::FileLike, file_table::FdFlags},
net::socket::{
ip::{datagram::DatagramSocket, stream::StreamSocket},
netlink::{is_valid_protocol, NetlinkRouteSocket, StandardNetlinkProtocol},
netlink::{
is_valid_protocol, NetlinkRouteSocket, NetlinkUeventSocket, StandardNetlinkProtocol,
},
unix::UnixStreamSocket,
vsock::VsockStreamSocket,
},
@ -52,7 +54,10 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32, ctx: &Context) -> Resu
debug!("netlink family = {:?}", netlink_family);
match netlink_family {
Ok(StandardNetlinkProtocol::ROUTE) => {
Arc::new(NetlinkRouteSocket::new(is_nonblocking))
NetlinkRouteSocket::new(is_nonblocking) as Arc<dyn FileLike>
}
Ok(StandardNetlinkProtocol::KOBJECT_UEVENT) => {
NetlinkUeventSocket::new(is_nonblocking) as Arc<dyn FileLike>
}
Ok(_) => {
return_errno_with_message!(

View File

@ -52,10 +52,12 @@
//!
use ip::new_ip_option;
use netlink::new_netlink_option;
use crate::{net::socket::options::SocketOption, prelude::*};
mod ip;
mod netlink;
mod socket;
mod tcp;
mod utils;
@ -130,6 +132,34 @@ macro_rules! impl_raw_sock_option_get_only {
};
}
/// Impl `RawSocketOption` for a struct which is for only `setsockopt` and implements `SocketOption`.
#[macro_export]
macro_rules! impl_raw_sock_option_set_only {
($option:ty) => {
impl RawSocketOption for $option {
fn read_from_user(&mut self, addr: Vaddr, max_len: u32) -> Result<()> {
use $crate::util::net::options::utils::ReadFromUser;
let input = ReadFromUser::read_from_user(addr, max_len)?;
self.set(input);
Ok(())
}
fn write_to_user(&self, _addr: Vaddr, _max_len: u32) -> Result<usize> {
return_errno_with_message!(Errno::ENOPROTOOPT, "the option is setter-only");
}
fn as_sock_option_mut(&mut self) -> &mut dyn SocketOption {
self
}
fn as_sock_option(&self) -> &dyn SocketOption {
self
}
}
};
}
pub fn new_raw_socket_option(
level: CSocketOptionLevel,
name: i32,
@ -138,6 +168,7 @@ pub fn new_raw_socket_option(
CSocketOptionLevel::SOL_SOCKET => new_socket_option(name),
CSocketOptionLevel::SOL_IP => new_ip_option(name),
CSocketOptionLevel::SOL_TCP => new_tcp_option(name),
CSocketOptionLevel::SOL_NETLINK => new_netlink_option(name),
_ => return_errno_with_message!(Errno::EOPNOTSUPP, "unsupported option level"),
}
}
@ -153,4 +184,5 @@ pub enum CSocketOptionLevel {
SOL_UDP = 17,
SOL_IPV6 = 41,
SOL_RAW = 255,
SOL_NETLINK = 270,
}

View File

@ -0,0 +1,34 @@
// SPDX-License-Identifier: MPL-2.0
use super::RawSocketOption;
use crate::{
impl_raw_sock_option_set_only,
net::socket::netlink::{AddMembership, DropMembership},
prelude::*,
util::net::options::SocketOption,
};
/// Socket options for netlink socket.
///
/// Reference: <https://elixir.bootlin.com/linux/v6.0.9/source/include/uapi/linux/netlink.h#L149>.
#[repr(i32)]
#[derive(Debug, Clone, Copy, TryFromInt)]
#[expect(non_camel_case_types)]
#[expect(clippy::upper_case_acronyms)]
pub enum CNetlinkOptionName {
ADD_MEMBERSHIP = 1,
DROP_MEMBERSHIP = 2,
PKTINFO = 3,
}
pub fn new_netlink_option(name: i32) -> Result<Box<dyn RawSocketOption>> {
let name = CNetlinkOptionName::try_from(name).map_err(|_| Errno::ENOPROTOOPT)?;
match name {
CNetlinkOptionName::ADD_MEMBERSHIP => Ok(Box::new(AddMembership::new())),
CNetlinkOptionName::DROP_MEMBERSHIP => Ok(Box::new(DropMembership::new())),
_ => return_errno_with_message!(Errno::ENOPROTOOPT, "unsupported netlink option"),
}
}
impl_raw_sock_option_set_only!(AddMembership);
impl_raw_sock_option_set_only!(DropMembership);