diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 88e105905..afdb1386f 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -5,6 +5,7 @@ pub mod socket; pub fn init() { iface::init(); + socket::netlink::init(); socket::vsock::init(); } diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index adbd36dd4..8a272618c 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -15,6 +15,7 @@ use crate::{ }; pub mod ip; +pub mod netlink; pub mod options; pub mod unix; mod util; diff --git a/kernel/src/net/socket/netlink/addr/mod.rs b/kernel/src/net/socket/netlink/addr/mod.rs new file mode 100644 index 000000000..b26cee6c7 --- /dev/null +++ b/kernel/src/net/socket/netlink/addr/mod.rs @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MPL-2.0 + +mod multicast; + +pub use multicast::{GroupIdSet, MAX_GROUPS}; + +use crate::{net::socket::SocketAddr, prelude::*}; + +/// The socket address of a netlink socket. +/// +/// The address contains the port number for unicast +/// and the group IDs for multicast. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NetlinkSocketAddr { + port: PortNum, + groups: GroupIdSet, +} + +impl NetlinkSocketAddr { + /// Creates a new netlink address. + pub const fn new(port: PortNum, groups: GroupIdSet) -> Self { + Self { port, groups } + } + + /// Creates a new unspecified address. + /// + /// Both the port ID and group numbers are left unspecified. + /// + /// Note that an unspecified address can also represent the kernel socket address. + pub const fn new_unspecified() -> Self { + Self { + port: UNSPECIFIED_PORT, + groups: GroupIdSet::new_empty(), + } + } + + /// Returns the port number. + pub const fn port(&self) -> PortNum { + self.port + } + + /// Returns the group ID set. + pub const fn groups(&self) -> GroupIdSet { + self.groups + } +} + +impl TryFrom for NetlinkSocketAddr { + type Error = Error; + + fn try_from(value: SocketAddr) -> Result { + match value { + SocketAddr::Netlink(addr) => Ok(addr), + _ => return_errno_with_message!( + Errno::EAFNOSUPPORT, + "the address is in an unsupported address family" + ), + } + } +} + +impl From for SocketAddr { + fn from(value: NetlinkSocketAddr) -> Self { + SocketAddr::Netlink(value) + } +} + +pub type NetlinkProtocolId = u32; +pub type PortNum = u32; + +pub const UNSPECIFIED_PORT: PortNum = 0; diff --git a/kernel/src/net/socket/netlink/addr/multicast.rs b/kernel/src/net/socket/netlink/addr/multicast.rs new file mode 100644 index 000000000..172a54756 --- /dev/null +++ b/kernel/src/net/socket/netlink/addr/multicast.rs @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::prelude::*; + +/// A set of group IDs. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GroupIdSet(u32); + +impl GroupIdSet { + /// Creates a new empty `GroupIdSet`. + pub const fn new_empty() -> Self { + Self(0) + } + + /// Creates a new `GroupIdSet` with multiple groups. + /// + /// Each 1 bit in `groups` represent a group. + pub const fn new(groups: u32) -> Self { + Self(groups) + } + + /// Creates an iterator over all group IDs. + pub const fn ids_iter(&self) -> GroupIdIter { + GroupIdIter::new(self) + } + + /// Adds a new group. + /// + /// If the group already exists, this method will return an error. + pub fn add_group(&mut self, group_id: GroupId) -> Result<()> { + if group_id >= 32 { + return_errno_with_message!(Errno::EINVAL, "the group ID is invalid"); + } + + let mask = 1u32 << group_id; + if self.0 & mask != 0 { + return_errno_with_message!(Errno::EINVAL, "the group ID already exists"); + } + self.0 |= mask; + + Ok(()) + } + + /// Sets new groups. + pub fn set_groups(&mut self, new_groups: u32) { + self.0 = new_groups; + } + + /// Clears all groups. + pub fn clear(&mut self) { + self.0 = 0; + } + + /// Checks if the set of group IDs is empty. + pub fn is_empty(&self) -> bool { + self.0 == 0 + } + + /// Returns the group IDs as a u32. + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +/// Iterator over a set of group IDs. +pub struct GroupIdIter { + groups: u32, +} + +impl GroupIdIter { + const fn new(groups: &GroupIdSet) -> Self { + Self { groups: groups.0 } + } +} + +impl Iterator for GroupIdIter { + type Item = GroupId; + + fn next(&mut self) -> Option { + if self.groups > 0 { + let group_id = self.groups.trailing_zeros(); + self.groups &= self.groups - 1; + return Some(group_id); + } + + None + } +} + +pub const MAX_GROUPS: u32 = 32; +pub type GroupId = u32; diff --git a/kernel/src/net/socket/netlink/mod.rs b/kernel/src/net/socket/netlink/mod.rs new file mode 100644 index 000000000..b789a6b7c --- /dev/null +++ b/kernel/src/net/socket/netlink/mod.rs @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! This module defines netlink sockets. +//! +//! Netlink provides a standardized, socket-based interface, +//! typically used for communication between user space and kernel space. +//! It can also be used for interaction between two user processes. +//! +//! Each netlink socket corresponds to +//! a netlink protocol identified by a protocol ID (u32). +//! Protocols are generally defined to serve specific functions. +//! For instance, the NETLINK_ROUTE protocol is employed +//! to retrieve or modify network device settings. +//! Only sockets associated with the same protocol can communicate with each other. +//! Some protocols are pre-defined by the kernel and serve fixed purposes, +//! but users can also establish custom protocols by specifying new protocol IDs. +//! +//! Before initiating communication, +//! a netlink socket must be bound to an address, +//! which consists of a port number and a multicast group number. +//! +//! The port number is used for unicast communication, +//! whereas the multicast group number is meant for multicast communication. +//! +//! In terms of unicast communication within each protocol, +//! a port number can only be bound to one socket. +//! However, the same port number can be utilized across different protocols. +//! Typically, the port number corresponds to the process ID of the running process. +//! +//! Multicast communication allows a message +//! to be sent to one or multiple multicast groups simultaneously. +//! Each protocol can support up to 32 multicast groups, +//! and a socket can belong to zero or multiple multicast groups. +//! +//! Netlink communication is akin to UDP in that +//! it does not require a connection to be established before sending messages. +//! The destination address must be specified when dispatching a message. +//! + +mod addr; +mod message; +mod route; +mod table; + +pub use addr::{GroupIdSet, NetlinkSocketAddr}; +pub use route::NetlinkRouteSocket; +pub use table::{is_valid_protocol, StandardNetlinkProtocol}; + +pub(in crate::net) fn init() { + table::init(); +} diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs new file mode 100644 index 000000000..3cd59d635 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MPL-2.0 + +use core::ops::Sub; + +use super::message::RtnlMessage; +use crate::{ + events::IoEvents, + net::socket::{ + netlink::{ + message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle, + NetlinkSocketAddr, + }, + SendRecvFlags, + }, + prelude::*, + util::{MultiRead, MultiWrite}, +}; + +pub(super) struct BoundNetlinkRoute { + handle: BoundHandle, + receive_queue: Mutex>, +} + +impl BoundNetlinkRoute { + pub(super) const fn new(handle: BoundHandle) -> Self { + Self { + handle, + receive_queue: Mutex::new(VecDeque::new()), + } + } + + pub(super) const fn addr(&self) -> NetlinkSocketAddr { + self.handle.addr() + } + + pub(super) fn try_send( + &self, + reader: &mut dyn MultiRead, + remote: Option<&NetlinkSocketAddr>, + flags: SendRecvFlags, + ) -> Result { + // TODO: Deal with flags + if !flags.is_all_supported() { + warn!("unsupported flags: {:?}", flags); + } + + if let Some(remote) = remote { + // TODO: Further check whether other socket address can be supported. + if *remote != NetlinkSocketAddr::new_unspecified() { + return_errno_with_message!( + Errno::ECONNREFUSED, + "sending netlink route messages to user space is not supported" + ); + } + } else { + // TODO: We should use the connected remote address, if any. + } + + let mut nlmsg = { + let sum_lens = reader.sum_lens(); + + match RtnlMessage::read_from(reader) { + Ok(nlmsg) => nlmsg, + Err(e) if e.error() == Errno::EFAULT => { + // EFAULT indicates an error occurred while copying data from user space, + // and this error should be returned back to user space. + return Err(e); + } + Err(e) => { + // Errors other than EFAULT indicate a failure in parsing the netlink message. + // These errors should be silently ignored. + warn!("failed to send netlink message: {:?}", e); + return Ok(sum_lens); + } + } + }; + + let local_port = self.addr().port(); + for segment in nlmsg.segments_mut() { + // The header's PID should be the sender's port ID. + // However, the sender can also leave it unspecified. + // In such cases, we will manually set the PID to the sender's port ID. + let header = segment.header_mut(); + if header.pid == 0 { + header.pid = local_port; + } + } + + get_netlink_route_kernel().request(&nlmsg, |response| { + self.receive_queue.lock().push_back(response); + }); + + Ok(nlmsg.total_len()) + } + + pub(super) fn try_receive( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, NetlinkSocketAddr)> { + // TODO: Deal with other flags. Only MSG_PEEK is handled here. + if !flags.sub(SendRecvFlags::MSG_PEEK).is_all_supported() { + warn!("unsupported flags: {:?}", flags); + } + + let mut receive_queue = self.receive_queue.lock(); + + let Some(response) = receive_queue.front() else { + return_errno_with_message!(Errno::EAGAIN, "nothing to receive"); + }; + + let len = { + let max_len = writer.sum_lens(); + response.total_len().min(max_len) + }; + + response.write_to(writer)?; + + if !flags.contains(SendRecvFlags::MSG_PEEK) { + receive_queue.pop_front().unwrap(); + } + + // TODO: The message can only come from kernel socket currently. + let remote = NetlinkSocketAddr::new_unspecified(); + + Ok((len, remote)) + } + + pub(super) fn check_io_events(&self) -> IoEvents { + let mut events = IoEvents::OUT; + + let receive_queue = self.receive_queue.lock(); + if !receive_queue.is_empty() { + events |= IoEvents::IN; + } + + events + } +} diff --git a/kernel/src/net/socket/netlink/route/mod.rs b/kernel/src/net/socket/netlink/route/mod.rs new file mode 100644 index 000000000..b6a2e9584 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/mod.rs @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Netlink Route Socket. + +use core::sync::atomic::{AtomicBool, Ordering}; + +use bound::BoundNetlinkRoute; +use takeable::Takeable; +use unbound::UnboundNetlinkRoute; + +use super::NetlinkSocketAddr; +use crate::{ + events::IoEvents, + net::socket::{ + options::SocketOption, private::SocketPrivate, MessageHeader, SendRecvFlags, Socket, + SocketAddr, + }, + prelude::*, + process::signal::{PollHandle, Pollable, Pollee}, + util::{MultiRead, MultiWrite}, +}; + +mod bound; +mod kernel; +mod message; +mod unbound; + +pub struct NetlinkRouteSocket { + is_nonblocking: AtomicBool, + pollee: Pollee, + inner: RwMutex>, +} + +enum Inner { + Unbound(UnboundNetlinkRoute), + Bound(BoundNetlinkRoute), +} + +impl NetlinkRouteSocket { + pub fn new(is_nonblocking: bool) -> Self { + Self { + is_nonblocking: AtomicBool::new(is_nonblocking), + pollee: Pollee::new(), + inner: RwMutex::new(Takeable::new(Inner::Unbound(UnboundNetlinkRoute::new()))), + } + } + + fn try_receive( + &self, + writer: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, NetlinkSocketAddr)> { + let inner = self.inner.read(); + + let bound = match inner.as_ref() { + Inner::Unbound(_) => { + return_errno_with_message!(Errno::EAGAIN, "the socket is not bound") + } + Inner::Bound(bound_netlink_route) => bound_netlink_route, + }; + + let received = bound.try_receive(writer, flags)?; + self.pollee.invalidate(); + + Ok(received) + } + + fn try_send( + &self, + reader: &mut dyn MultiRead, + remote: Option<&NetlinkSocketAddr>, + flags: SendRecvFlags, + ) -> Result { + let inner = self.inner.read(); + + let bound = match inner.as_ref() { + Inner::Unbound(_) => todo!(), + Inner::Bound(bound) => bound, + }; + + let sent_bytes = bound.try_send(reader, remote, flags)?; + self.pollee.notify(IoEvents::OUT | IoEvents::IN); + + Ok(sent_bytes) + } + + fn check_io_events(&self) -> IoEvents { + let inner = self.inner.read(); + match inner.as_ref() { + Inner::Unbound(unbound) => unbound.check_io_events(), + Inner::Bound(bound) => bound.check_io_events(), + } + } +} + +impl Socket for NetlinkRouteSocket { + fn bind(&self, socket_addr: SocketAddr) -> Result<()> { + let SocketAddr::Netlink(netlink_addr) = socket_addr else { + return_errno_with_message!( + Errno::EAFNOSUPPORT, + "the provided address is not netlink address" + ); + }; + + let mut inner = self.inner.write(); + inner.borrow_result(|owned_inner| match owned_inner.bind(&netlink_addr) { + Ok(bound_inner) => (bound_inner, Ok(())), + Err((err, err_inner)) => (err_inner, Err(err)), + }) + } + + fn addr(&self) -> Result { + let netlink_addr = match self.inner.read().as_ref() { + Inner::Unbound(_) => NetlinkSocketAddr::new_unspecified(), + Inner::Bound(bound) => bound.addr(), + }; + + Ok(SocketAddr::Netlink(netlink_addr)) + } + + fn sendmsg( + &self, + reader: &mut dyn MultiRead, + message_header: MessageHeader, + flags: SendRecvFlags, + ) -> Result { + let MessageHeader { + addr, + control_message, + } = message_header; + + let remote = match addr { + None => None, + Some(addr) => Some(addr.try_into()?), + }; + + if control_message.is_some() { + // TODO: Support sending control message + warn!("sending control message is not supported"); + } + + // TODO: Make sure our blocking behavior matches that of Linux + self.try_send(reader, remote.as_ref(), flags) + } + + fn recvmsg( + &self, + writers: &mut dyn MultiWrite, + flags: SendRecvFlags, + ) -> Result<(usize, MessageHeader)> { + let (received_len, addr) = + self.block_on(IoEvents::IN, || self.try_receive(writers, flags))?; + + // TODO: Receive control message + + let message_header = { + let addr = SocketAddr::Netlink(addr); + MessageHeader::new(Some(addr), None) + }; + + Ok((received_len, message_header)) + } + + fn set_option(&self, _option: &dyn SocketOption) -> Result<()> { + // TODO: This dummy option is added to pass the libnl test + Ok(()) + } +} + +impl SocketPrivate for NetlinkRouteSocket { + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) + } + + fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); + } +} + +impl Pollable for NetlinkRouteSocket { + fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents { + self.pollee + .poll_with(mask, poller, || self.check_io_events()) + } +} + +impl Inner { + fn bind(self, addr: &NetlinkSocketAddr) -> core::result::Result { + let unbound = match self { + Inner::Unbound(unbound) => unbound, + Inner::Bound(bound) => { + // FIXME: We need to further check the Linux behavior + // whether we should return error if the socket is bound. + // The socket may call `bind` syscall to join new multicast groups. + return Err(( + Error::with_message(Errno::EINVAL, "the socket is already bound"), + Self::Bound(bound), + )); + } + }; + + match unbound.bind(addr) { + Ok(bound) => Ok(Self::Bound(bound)), + Err((err, unbound)) => Err((err, Self::Unbound(unbound))), + } + } +} diff --git a/kernel/src/net/socket/netlink/route/unbound.rs b/kernel/src/net/socket/netlink/route/unbound.rs new file mode 100644 index 000000000..27a80fab0 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/unbound.rs @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::bound::BoundNetlinkRoute; +use crate::{ + events::IoEvents, + net::socket::netlink::{ + table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol, + }, + prelude::*, +}; + +pub(super) struct UnboundNetlinkRoute { + _private: (), +} + +impl UnboundNetlinkRoute { + pub(super) const fn new() -> Self { + Self { _private: () } + } + + pub(super) fn bind( + self, + addr: &NetlinkSocketAddr, + ) -> core::result::Result { + let bound_handle = NETLINK_SOCKET_TABLE + .bind(StandardNetlinkProtocol::ROUTE as _, addr) + .map_err(|err| (err, self))?; + + Ok(BoundNetlinkRoute::new(bound_handle)) + } + + pub(super) fn check_io_events(&self) -> IoEvents { + IoEvents::OUT + } +} diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs new file mode 100644 index 000000000..cd45c7a8f --- /dev/null +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MPL-2.0 + +use multicast::MulticastGroup; + +use super::addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS}; +use crate::{net::socket::netlink::addr::UNSPECIFIED_PORT, prelude::*, util::random::getrandom}; + +mod multicast; + +pub(super) static NETLINK_SOCKET_TABLE: NetlinkSocketTable = NetlinkSocketTable::new(); + +/// All bound netlink sockets. +pub(super) struct NetlinkSocketTable { + protocols: [Mutex>; MAX_ALLOWED_PROTOCOL_ID as usize], +} + +impl NetlinkSocketTable { + pub(super) const fn new() -> Self { + Self { + protocols: [const { Mutex::new(None) }; MAX_ALLOWED_PROTOCOL_ID as usize], + } + } + + /// Adds a new netlink protocol. + fn add_new_protocol(&self, protocol_id: NetlinkProtocolId) { + if protocol_id >= MAX_ALLOWED_PROTOCOL_ID { + return; + } + + let mut protocol = self.protocols[protocol_id as usize].lock(); + if protocol.is_some() { + return; + } + + let new_protocol = ProtocolSocketTable::new(protocol_id); + *protocol = Some(new_protocol); + } + + pub(super) fn bind( + &self, + protocol: NetlinkProtocolId, + addr: &NetlinkSocketAddr, + ) -> Result { + if protocol >= MAX_ALLOWED_PROTOCOL_ID { + return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist"); + } + + let mut protocol = self.protocols[protocol as usize].lock(); + + let Some(protocol_sockets) = protocol.as_mut() else { + return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist") + }; + + protocol_sockets.bind(addr) + } +} + +/// Bound socket table of a single netlink protocol. +/// +/// Each table can have bound sockets for unicast +/// and at most 32 groups for multicast. +struct ProtocolSocketTable { + id: NetlinkProtocolId, + // TODO: This table should maintain the port number-to-socket relationship + // to support both unicast and multicast effectively. + unicast_sockets: BTreeSet, + multicast_groups: Box<[MulticastGroup]>, +} + +impl ProtocolSocketTable { + /// Creates a new table. + fn new(id: NetlinkProtocolId) -> Self { + let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect(); + Self { + id, + unicast_sockets: BTreeSet::new(), + multicast_groups, + } + } + + /// Binds a socket to the table. + /// Returns the bound handle. + /// + /// The socket will be bound to a port specified by `addr.port()`. + /// If `addr.port()` is zero, the kernel will assign a port, + /// typically corresponding to the process ID of the current process. + /// If the assigned port is already in use, + /// this function will try to allocate a random unused port. + /// + /// Additionally, this socket can join one or more multicast groups, + /// as specified in `addr.groups()`. + fn bind(&mut self, addr: &NetlinkSocketAddr) -> Result { + let port = if addr.port() != UNSPECIFIED_PORT { + addr.port() + } else { + let mut random_port = current!().pid(); + while random_port == UNSPECIFIED_PORT || self.unicast_sockets.contains(&random_port) { + getrandom(random_port.as_bytes_mut()).unwrap(); + } + random_port + }; + + if self.unicast_sockets.contains(&port) { + return_errno_with_message!(Errno::EADDRINUSE, "the netlink port is already in use"); + } + + self.unicast_sockets.insert(port); + + for group_id in addr.groups().ids_iter() { + let group = &mut self.multicast_groups[group_id as usize]; + group.add_member(port); + } + + Ok(BoundHandle::new(self.id, port, addr.groups())) + } +} + +/// A bound netlink socket address. +/// +/// When dropping a `BoundHandle`, +/// the port will be automatically released. +#[derive(Debug)] +pub(super) struct BoundHandle { + protocol: NetlinkProtocolId, + port: PortNum, + groups: GroupIdSet, +} + +impl BoundHandle { + fn new(protocol: NetlinkProtocolId, port: PortNum, groups: GroupIdSet) -> Self { + debug_assert_ne!(port, UNSPECIFIED_PORT); + + Self { + protocol, + port, + groups, + } + } + + pub(super) const fn addr(&self) -> NetlinkSocketAddr { + NetlinkSocketAddr::new(self.port, self.groups) + } +} + +impl Drop for BoundHandle { + fn drop(&mut self) { + let mut protocol_sockets = NETLINK_SOCKET_TABLE.protocols[self.protocol as usize].lock(); + + let Some(protocol_sockets) = protocol_sockets.as_mut() else { + return; + }; + + protocol_sockets.unicast_sockets.remove(&self.port); + + for group_id in self.groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + } +} + +pub(super) fn init() { + for protocol in 0..MAX_ALLOWED_PROTOCOL_ID { + if is_standard_protocol(protocol) { + NETLINK_SOCKET_TABLE.add_new_protocol(protocol); + } + } +} + +/// Returns whether the `protocol` is valid. +pub fn is_valid_protocol(protocol: NetlinkProtocolId) -> bool { + protocol < MAX_ALLOWED_PROTOCOL_ID +} + +/// Returns whether the `protocol` is reserved for system use. +fn is_standard_protocol(protocol: NetlinkProtocolId) -> bool { + StandardNetlinkProtocol::try_from(protocol).is_ok() +} + +/// Netlink protocols that are assigned for specific usage. +/// +/// Reference: . +#[allow(non_camel_case_types)] +#[repr(u32)] +#[derive(Debug, Clone, Copy, TryFromInt)] +pub enum StandardNetlinkProtocol { + /// Routing/device hook + ROUTE = 0, + /// Unused number + UNUSED = 1, + /// Reserved for user mode socket protocols + USERSOCK = 2, + /// Unused number, formerly ip_queue + FIREWALL = 3, + /// Socket monitoring + SOCK_DIAG = 4, + /// Netfilter/iptables ULOG + NFLOG = 5, + /// IPsec + XFRM = 6, + /// SELinux event notifications + SELINUX = 7, + /// Open-iSCSI + ISCSI = 8, + /// Auditing + AUDIT = 9, + FIB_LOOKUP = 10, + CONNECTOR = 11, + /// Netfilter subsystem + NETFILTER = 12, + IP6_FW = 13, + /// DECnet routing messages + DNRTMSG = 14, + /// Kernel messages to userspace + KOBJECT_UEVENT = 15, + GENERIC = 16, + /// Leave room for NETLINK_DM (DM Events) + /// SCSI Transports + SCSITRANSPORT = 18, + ECRYPTFS = 19, + RDMA = 20, + /// Crypto layer + CRYPTO = 21, + /// SMC monitoring + SMC = 22, +} + +const MAX_ALLOWED_PROTOCOL_ID: NetlinkProtocolId = 32; diff --git a/kernel/src/net/socket/netlink/table/multicast.rs b/kernel/src/net/socket/netlink/table/multicast.rs new file mode 100644 index 000000000..93f01f110 --- /dev/null +++ b/kernel/src/net/socket/netlink/table/multicast.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::{net::socket::netlink::addr::PortNum, prelude::*}; + +/// A netlink multicast group. +/// +/// A group can contain multiple sockets, +/// each identified by its bound port number. +pub struct MulticastGroup { + members: BTreeSet, +} + +impl MulticastGroup { + /// Creates a new multicast group. + pub const fn new() -> Self { + Self { + members: BTreeSet::new(), + } + } + + /// Returns whether the group contains a member. + #[expect(unused)] + pub fn contains_member(&self, port_num: PortNum) -> bool { + self.members.contains(&port_num) + } + + /// Adds a new member to the multicast group. + pub fn add_member(&mut self, port_num: PortNum) { + debug_assert!(!self.members.contains(&port_num)); + self.members.insert(port_num); + } + + /// Removes a member from the multicast group. + pub fn remove_member(&mut self, port_num: PortNum) { + debug_assert!(self.members.contains(&port_num)); + self.members.remove(&port_num); + } +} diff --git a/kernel/src/net/socket/util/socket_addr.rs b/kernel/src/net/socket/util/socket_addr.rs index 3c1b994f3..70c15fee0 100644 --- a/kernel/src/net/socket/util/socket_addr.rs +++ b/kernel/src/net/socket/util/socket_addr.rs @@ -3,7 +3,7 @@ use aster_bigtcp::wire::{Ipv4Address, PortNum}; use crate::{ - net::socket::{unix::UnixSocketAddr, vsock::addr::VsockSocketAddr}, + net::socket::{netlink::NetlinkSocketAddr, unix::UnixSocketAddr, vsock::addr::VsockSocketAddr}, prelude::*, }; @@ -11,5 +11,6 @@ use crate::{ pub enum SocketAddr { Unix(UnixSocketAddr), IPv4(Ipv4Address, PortNum), + Netlink(NetlinkSocketAddr), Vsock(VsockSocketAddr), } diff --git a/kernel/src/net/socket/vsock/mod.rs b/kernel/src/net/socket/vsock/mod.rs index 72e43bb32..aee74940a 100644 --- a/kernel/src/net/socket/vsock/mod.rs +++ b/kernel/src/net/socket/vsock/mod.rs @@ -15,7 +15,7 @@ pub use stream::VsockStreamSocket; // init static driver pub static VSOCK_GLOBAL: Once> = Once::new(); -pub fn init() { +pub(in crate::net) fn init() { if let Some(driver) = get_device(DEVICE_NAME) { VSOCK_GLOBAL.call_once(|| Arc::new(VsockSpace::new(driver))); register_recv_callback(DEVICE_NAME, || { diff --git a/kernel/src/syscall/socket.rs b/kernel/src/syscall/socket.rs index 7c4e79509..c526bfeb8 100644 --- a/kernel/src/syscall/socket.rs +++ b/kernel/src/syscall/socket.rs @@ -5,6 +5,7 @@ use crate::{ fs::{file_handle::FileLike, file_table::FdFlags}, net::socket::{ ip::{datagram::DatagramSocket, stream::StreamSocket}, + netlink::{is_valid_protocol, NetlinkRouteSocket, StandardNetlinkProtocol}, unix::UnixStreamSocket, vsock::VsockStreamSocket, }, @@ -16,29 +17,62 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32, ctx: &Context) -> Resu let domain = CSocketAddrFamily::try_from(domain)?; let sock_type = SockType::try_from(type_ & SOCK_TYPE_MASK)?; let sock_flags = SockFlags::from_bits_truncate(type_ & !SOCK_TYPE_MASK); - let protocol = Protocol::try_from(protocol)?; debug!( - "domain = {:?}, sock_type = {:?}, sock_flags = {:?}, protocol = {:?}", - domain, sock_type, sock_flags, protocol + "domain = {:?}, sock_type = {:?}, sock_flags = {:?}", + domain, sock_type, sock_flags ); - let nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK); - let file_like = match (domain, sock_type, protocol) { + let is_nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK); + let file_like = match (domain, sock_type) { // FIXME: SOCK_SEQPACKET is added to run fcntl_test, not supported yet. - (CSocketAddrFamily::AF_UNIX, SockType::SOCK_STREAM | SockType::SOCK_SEQPACKET, _) => { - UnixStreamSocket::new(nonblocking) as Arc + (CSocketAddrFamily::AF_UNIX, SockType::SOCK_STREAM | SockType::SOCK_SEQPACKET) => { + UnixStreamSocket::new(is_nonblocking) as Arc } - ( - CSocketAddrFamily::AF_INET, - SockType::SOCK_STREAM, - Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP, - ) => StreamSocket::new(nonblocking) as Arc, - ( - CSocketAddrFamily::AF_INET, - SockType::SOCK_DGRAM, - Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP, - ) => DatagramSocket::new(nonblocking) as Arc, - (CSocketAddrFamily::AF_VSOCK, SockType::SOCK_STREAM, _) => { - Arc::new(VsockStreamSocket::new(nonblocking)) as Arc + (CSocketAddrFamily::AF_INET, SockType::SOCK_STREAM) => { + let protocol = Protocol::try_from(protocol)?; + debug!("protocol = {:?}", protocol); + match protocol { + Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP => { + StreamSocket::new(is_nonblocking) as Arc + } + _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported protocol"), + } + } + (CSocketAddrFamily::AF_INET, SockType::SOCK_DGRAM) => { + let protocol = Protocol::try_from(protocol)?; + debug!("protocol = {:?}", protocol); + match protocol { + Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP => { + DatagramSocket::new(is_nonblocking) as Arc + } + _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported protocol"), + } + } + (CSocketAddrFamily::AF_NETLINK, SockType::SOCK_RAW | SockType::SOCK_DGRAM) => { + let netlink_family = StandardNetlinkProtocol::try_from(protocol as u32); + debug!("netlink family = {:?}", netlink_family); + match netlink_family { + Ok(StandardNetlinkProtocol::ROUTE) => { + Arc::new(NetlinkRouteSocket::new(is_nonblocking)) + } + Ok(_) => { + return_errno_with_message!( + Errno::EAFNOSUPPORT, + "some standard netlink families are not supported yet" + ); + } + Err(_) => { + if is_valid_protocol(protocol as u32) { + return_errno_with_message!( + Errno::EAFNOSUPPORT, + "user-provided netlink family is not supported" + ) + } + return_errno_with_message!(Errno::EAFNOSUPPORT, "invalid netlink family"); + } + } + } + (CSocketAddrFamily::AF_VSOCK, SockType::SOCK_STREAM) => { + Arc::new(VsockStreamSocket::new(is_nonblocking)) as Arc } _ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"), }; diff --git a/kernel/src/util/net/addr/family.rs b/kernel/src/util/net/addr/family.rs index 4af9a1b21..14a4d0119 100644 --- a/kernel/src/util/net/addr/family.rs +++ b/kernel/src/util/net/addr/family.rs @@ -4,7 +4,7 @@ use core::cmp::min; use ostd::task::Task; -use super::{ip::CSocketAddrInet, unix, vsock::CSocketAddrVm}; +use super::{ip::CSocketAddrInet, netlink::CSocketAddrNetlink, unix, vsock::CSocketAddrVm}; use crate::{current_userspace, net::socket::SocketAddr, prelude::*}; /// Address family. @@ -162,6 +162,13 @@ pub fn read_socket_addr_from_user(addr: Vaddr, addr_len: usize) -> Result { + if addr_len < size_of::() { + return_errno_with_message!(Errno::EINVAL, "the socket address length is too small"); + } + let addr = CSocketAddrNetlink::from_bytes(storage.as_bytes()); + SocketAddr::Netlink(addr.into()) + } Ok(CSocketAddrFamily::AF_VSOCK) => { if addr_len < size_of::() { return_errno_with_message!(Errno::EINVAL, "the socket address length is too small"); @@ -238,36 +245,45 @@ pub fn write_socket_addr_with_max_len( ); } - let current_task = Task::current().unwrap(); - let user_space = CurrentUserSpace::new(¤t_task); - let actual_len = match socket_addr { - SocketAddr::IPv4(addr, port) => { - let socket_addr = CSocketAddrInet::from((*addr, *port)); - let actual_len = size_of::(); - let written_len = min(actual_len, max_len as _); - user_space.write_bytes( - dest, - &mut VmReader::from(&socket_addr.as_bytes()[..written_len]), - )?; - actual_len - } + SocketAddr::IPv4(addr, port) => write_c_socket_address_util::( + (*addr, *port), + dest, + max_len as usize, + )?, SocketAddr::Unix(addr) => unix::into_c_bytes_and(addr, |bytes| { let written_len = min(bytes.len(), max_len as _); - user_space.write_bytes(dest, &mut VmReader::from(&bytes[..written_len]))?; + current_userspace!().write_bytes(dest, &mut VmReader::from(&bytes[..written_len]))?; Ok::(bytes.len()) })?, + SocketAddr::Netlink(addr) => { + write_c_socket_address_util::(*addr, dest, max_len as usize)? + } SocketAddr::Vsock(addr) => { - let socket_addr = CSocketAddrVm::from(*addr); - let actual_len = size_of::(); - let written_len = min(actual_len, max_len as _); - user_space.write_bytes( - dest, - &mut VmReader::from(&socket_addr.as_bytes()[..written_len]), - )?; - actual_len + write_c_socket_address_util::(*addr, dest, max_len as usize)? } }; Ok(actual_len as i32) } + +// Utility function to write a C socket address to user space. +fn write_c_socket_address_util( + addr: TSockAddr, + dest: Vaddr, + max_len: usize, +) -> Result +where + TCSockAddr: From, +{ + let c_socket_addr = TCSockAddr::from(addr); + let actual_len = size_of::(); + let written_len = min(actual_len, max_len); + + current_userspace!().write_bytes( + dest, + &mut VmReader::from(&c_socket_addr.as_bytes()[..written_len]), + )?; + + Ok(actual_len) +} diff --git a/kernel/src/util/net/addr/ip.rs b/kernel/src/util/net/addr/ip.rs index c2f9bb908..847c4f2b6 100644 --- a/kernel/src/util/net/addr/ip.rs +++ b/kernel/src/util/net/addr/ip.rs @@ -38,6 +38,7 @@ impl From<(Ipv4Address, PortNum)> for CSocketAddrInet { impl From for (Ipv4Address, PortNum) { fn from(value: CSocketAddrInet) -> Self { + debug_assert_eq!(value.sin_family, CSocketAddrFamily::AF_INET as u16); (value.sin_addr.into(), value.sin_port.into()) } } diff --git a/kernel/src/util/net/addr/mod.rs b/kernel/src/util/net/addr/mod.rs index f4db56cf7..6bdd64305 100644 --- a/kernel/src/util/net/addr/mod.rs +++ b/kernel/src/util/net/addr/mod.rs @@ -7,5 +7,6 @@ pub use family::{ mod family; mod ip; +mod netlink; mod unix; mod vsock; diff --git a/kernel/src/util/net/addr/netlink.rs b/kernel/src/util/net/addr/netlink.rs new file mode 100644 index 000000000..aba463a76 --- /dev/null +++ b/kernel/src/util/net/addr/netlink.rs @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MPL-2.0 + +use super::CSocketAddrFamily; +use crate::{ + net::socket::netlink::{GroupIdSet, NetlinkSocketAddr}, + prelude::*, +}; + +/// Netlink socket address. +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod)] +pub struct CSocketAddrNetlink { + /// Address family (AF_NETLINK). + nl_family: u16, + /// Pad bytes (always zero). + nl_pad: u16, + /// Port ID. + nl_pid: u32, + /// Multicast groups mask. + nl_groups: u32, +} + +impl From for CSocketAddrNetlink { + fn from(value: NetlinkSocketAddr) -> Self { + Self { + nl_family: CSocketAddrFamily::AF_NETLINK as _, + nl_pad: 0, + nl_pid: value.port(), + nl_groups: value.groups().as_u32(), + } + } +} + +impl From for NetlinkSocketAddr { + fn from(value: CSocketAddrNetlink) -> Self { + debug_assert_eq!(value.nl_family, CSocketAddrFamily::AF_NETLINK as u16); + let port = value.nl_pid; + let groups = GroupIdSet::new(value.nl_groups); + NetlinkSocketAddr::new(port, groups) + } +} diff --git a/kernel/src/util/net/addr/vsock.rs b/kernel/src/util/net/addr/vsock.rs index 8e1dfe88f..3dc52f6e8 100644 --- a/kernel/src/util/net/addr/vsock.rs +++ b/kernel/src/util/net/addr/vsock.rs @@ -33,6 +33,7 @@ impl From for CSocketAddrVm { impl From for VsockSocketAddr { fn from(value: CSocketAddrVm) -> Self { + debug_assert_eq!(value.svm_family, CSocketAddrFamily::AF_VSOCK as u16); Self { cid: value.svm_cid, port: value.svm_port,