Add basic structure for netlink route socket

This commit is contained in:
jiangjianfeng 2025-04-11 07:31:21 +00:00 committed by Tate, Hongliang Tian
parent 2c41055470
commit ac42e83387
18 changed files with 1001 additions and 44 deletions

View File

@ -5,6 +5,7 @@ pub mod socket;
pub fn init() {
iface::init();
socket::netlink::init();
socket::vsock::init();
}

View File

@ -15,6 +15,7 @@ use crate::{
};
pub mod ip;
pub mod netlink;
pub mod options;
pub mod unix;
mod util;

View File

@ -0,0 +1,71 @@
// SPDX-License-Identifier: MPL-2.0
mod multicast;
pub use multicast::{GroupIdSet, MAX_GROUPS};
use crate::{net::socket::SocketAddr, prelude::*};
/// The socket address of a netlink socket.
///
/// The address contains the port number for unicast
/// and the group IDs for multicast.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NetlinkSocketAddr {
port: PortNum,
groups: GroupIdSet,
}
impl NetlinkSocketAddr {
/// Creates a new netlink address.
pub const fn new(port: PortNum, groups: GroupIdSet) -> Self {
Self { port, groups }
}
/// Creates a new unspecified address.
///
/// Both the port ID and group numbers are left unspecified.
///
/// Note that an unspecified address can also represent the kernel socket address.
pub const fn new_unspecified() -> Self {
Self {
port: UNSPECIFIED_PORT,
groups: GroupIdSet::new_empty(),
}
}
/// Returns the port number.
pub const fn port(&self) -> PortNum {
self.port
}
/// Returns the group ID set.
pub const fn groups(&self) -> GroupIdSet {
self.groups
}
}
impl TryFrom<SocketAddr> for NetlinkSocketAddr {
type Error = Error;
fn try_from(value: SocketAddr) -> Result<Self> {
match value {
SocketAddr::Netlink(addr) => Ok(addr),
_ => return_errno_with_message!(
Errno::EAFNOSUPPORT,
"the address is in an unsupported address family"
),
}
}
}
impl From<NetlinkSocketAddr> for SocketAddr {
fn from(value: NetlinkSocketAddr) -> Self {
SocketAddr::Netlink(value)
}
}
pub type NetlinkProtocolId = u32;
pub type PortNum = u32;
pub const UNSPECIFIED_PORT: PortNum = 0;

View File

@ -0,0 +1,91 @@
// SPDX-License-Identifier: MPL-2.0
use crate::prelude::*;
/// A set of group IDs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GroupIdSet(u32);
impl GroupIdSet {
/// Creates a new empty `GroupIdSet`.
pub const fn new_empty() -> Self {
Self(0)
}
/// Creates a new `GroupIdSet` with multiple groups.
///
/// Each 1 bit in `groups` represent a group.
pub const fn new(groups: u32) -> Self {
Self(groups)
}
/// Creates an iterator over all group IDs.
pub const fn ids_iter(&self) -> GroupIdIter {
GroupIdIter::new(self)
}
/// Adds a new group.
///
/// If the group already exists, this method will return an error.
pub fn add_group(&mut self, group_id: GroupId) -> Result<()> {
if group_id >= 32 {
return_errno_with_message!(Errno::EINVAL, "the group ID is invalid");
}
let mask = 1u32 << group_id;
if self.0 & mask != 0 {
return_errno_with_message!(Errno::EINVAL, "the group ID already exists");
}
self.0 |= mask;
Ok(())
}
/// Sets new groups.
pub fn set_groups(&mut self, new_groups: u32) {
self.0 = new_groups;
}
/// Clears all groups.
pub fn clear(&mut self) {
self.0 = 0;
}
/// Checks if the set of group IDs is empty.
pub fn is_empty(&self) -> bool {
self.0 == 0
}
/// Returns the group IDs as a u32.
pub fn as_u32(&self) -> u32 {
self.0
}
}
/// Iterator over a set of group IDs.
pub struct GroupIdIter {
groups: u32,
}
impl GroupIdIter {
const fn new(groups: &GroupIdSet) -> Self {
Self { groups: groups.0 }
}
}
impl Iterator for GroupIdIter {
type Item = GroupId;
fn next(&mut self) -> Option<Self::Item> {
if self.groups > 0 {
let group_id = self.groups.trailing_zeros();
self.groups &= self.groups - 1;
return Some(group_id);
}
None
}
}
pub const MAX_GROUPS: u32 = 32;
pub type GroupId = u32;

View File

@ -0,0 +1,51 @@
// SPDX-License-Identifier: MPL-2.0
//! This module defines netlink sockets.
//!
//! Netlink provides a standardized, socket-based interface,
//! typically used for communication between user space and kernel space.
//! It can also be used for interaction between two user processes.
//!
//! Each netlink socket corresponds to
//! a netlink protocol identified by a protocol ID (u32).
//! Protocols are generally defined to serve specific functions.
//! For instance, the NETLINK_ROUTE protocol is employed
//! to retrieve or modify network device settings.
//! Only sockets associated with the same protocol can communicate with each other.
//! Some protocols are pre-defined by the kernel and serve fixed purposes,
//! but users can also establish custom protocols by specifying new protocol IDs.
//!
//! Before initiating communication,
//! a netlink socket must be bound to an address,
//! which consists of a port number and a multicast group number.
//!
//! The port number is used for unicast communication,
//! whereas the multicast group number is meant for multicast communication.
//!
//! In terms of unicast communication within each protocol,
//! a port number can only be bound to one socket.
//! However, the same port number can be utilized across different protocols.
//! Typically, the port number corresponds to the process ID of the running process.
//!
//! Multicast communication allows a message
//! to be sent to one or multiple multicast groups simultaneously.
//! Each protocol can support up to 32 multicast groups,
//! and a socket can belong to zero or multiple multicast groups.
//!
//! Netlink communication is akin to UDP in that
//! it does not require a connection to be established before sending messages.
//! The destination address must be specified when dispatching a message.
//!
mod addr;
mod message;
mod route;
mod table;
pub use addr::{GroupIdSet, NetlinkSocketAddr};
pub use route::NetlinkRouteSocket;
pub use table::{is_valid_protocol, StandardNetlinkProtocol};
pub(in crate::net) fn init() {
table::init();
}

View File

@ -0,0 +1,139 @@
// SPDX-License-Identifier: MPL-2.0
use core::ops::Sub;
use super::message::RtnlMessage;
use crate::{
events::IoEvents,
net::socket::{
netlink::{
message::ProtocolSegment, route::kernel::get_netlink_route_kernel, table::BoundHandle,
NetlinkSocketAddr,
},
SendRecvFlags,
},
prelude::*,
util::{MultiRead, MultiWrite},
};
pub(super) struct BoundNetlinkRoute {
handle: BoundHandle,
receive_queue: Mutex<VecDeque<RtnlMessage>>,
}
impl BoundNetlinkRoute {
pub(super) const fn new(handle: BoundHandle) -> Self {
Self {
handle,
receive_queue: Mutex::new(VecDeque::new()),
}
}
pub(super) const fn addr(&self) -> NetlinkSocketAddr {
self.handle.addr()
}
pub(super) fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
// TODO: Deal with flags
if !flags.is_all_supported() {
warn!("unsupported flags: {:?}", flags);
}
if let Some(remote) = remote {
// TODO: Further check whether other socket address can be supported.
if *remote != NetlinkSocketAddr::new_unspecified() {
return_errno_with_message!(
Errno::ECONNREFUSED,
"sending netlink route messages to user space is not supported"
);
}
} else {
// TODO: We should use the connected remote address, if any.
}
let mut nlmsg = {
let sum_lens = reader.sum_lens();
match RtnlMessage::read_from(reader) {
Ok(nlmsg) => nlmsg,
Err(e) if e.error() == Errno::EFAULT => {
// EFAULT indicates an error occurred while copying data from user space,
// and this error should be returned back to user space.
return Err(e);
}
Err(e) => {
// Errors other than EFAULT indicate a failure in parsing the netlink message.
// These errors should be silently ignored.
warn!("failed to send netlink message: {:?}", e);
return Ok(sum_lens);
}
}
};
let local_port = self.addr().port();
for segment in nlmsg.segments_mut() {
// The header's PID should be the sender's port ID.
// However, the sender can also leave it unspecified.
// In such cases, we will manually set the PID to the sender's port ID.
let header = segment.header_mut();
if header.pid == 0 {
header.pid = local_port;
}
}
get_netlink_route_kernel().request(&nlmsg, |response| {
self.receive_queue.lock().push_back(response);
});
Ok(nlmsg.total_len())
}
pub(super) fn try_receive(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, NetlinkSocketAddr)> {
// TODO: Deal with other flags. Only MSG_PEEK is handled here.
if !flags.sub(SendRecvFlags::MSG_PEEK).is_all_supported() {
warn!("unsupported flags: {:?}", flags);
}
let mut receive_queue = self.receive_queue.lock();
let Some(response) = receive_queue.front() else {
return_errno_with_message!(Errno::EAGAIN, "nothing to receive");
};
let len = {
let max_len = writer.sum_lens();
response.total_len().min(max_len)
};
response.write_to(writer)?;
if !flags.contains(SendRecvFlags::MSG_PEEK) {
receive_queue.pop_front().unwrap();
}
// TODO: The message can only come from kernel socket currently.
let remote = NetlinkSocketAddr::new_unspecified();
Ok((len, remote))
}
pub(super) fn check_io_events(&self) -> IoEvents {
let mut events = IoEvents::OUT;
let receive_queue = self.receive_queue.lock();
if !receive_queue.is_empty() {
events |= IoEvents::IN;
}
events
}
}

View File

@ -0,0 +1,207 @@
// SPDX-License-Identifier: MPL-2.0
//! Netlink Route Socket.
use core::sync::atomic::{AtomicBool, Ordering};
use bound::BoundNetlinkRoute;
use takeable::Takeable;
use unbound::UnboundNetlinkRoute;
use super::NetlinkSocketAddr;
use crate::{
events::IoEvents,
net::socket::{
options::SocketOption, private::SocketPrivate, MessageHeader, SendRecvFlags, Socket,
SocketAddr,
},
prelude::*,
process::signal::{PollHandle, Pollable, Pollee},
util::{MultiRead, MultiWrite},
};
mod bound;
mod kernel;
mod message;
mod unbound;
pub struct NetlinkRouteSocket {
is_nonblocking: AtomicBool,
pollee: Pollee,
inner: RwMutex<Takeable<Inner>>,
}
enum Inner {
Unbound(UnboundNetlinkRoute),
Bound(BoundNetlinkRoute),
}
impl NetlinkRouteSocket {
pub fn new(is_nonblocking: bool) -> Self {
Self {
is_nonblocking: AtomicBool::new(is_nonblocking),
pollee: Pollee::new(),
inner: RwMutex::new(Takeable::new(Inner::Unbound(UnboundNetlinkRoute::new()))),
}
}
fn try_receive(
&self,
writer: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, NetlinkSocketAddr)> {
let inner = self.inner.read();
let bound = match inner.as_ref() {
Inner::Unbound(_) => {
return_errno_with_message!(Errno::EAGAIN, "the socket is not bound")
}
Inner::Bound(bound_netlink_route) => bound_netlink_route,
};
let received = bound.try_receive(writer, flags)?;
self.pollee.invalidate();
Ok(received)
}
fn try_send(
&self,
reader: &mut dyn MultiRead,
remote: Option<&NetlinkSocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
let inner = self.inner.read();
let bound = match inner.as_ref() {
Inner::Unbound(_) => todo!(),
Inner::Bound(bound) => bound,
};
let sent_bytes = bound.try_send(reader, remote, flags)?;
self.pollee.notify(IoEvents::OUT | IoEvents::IN);
Ok(sent_bytes)
}
fn check_io_events(&self) -> IoEvents {
let inner = self.inner.read();
match inner.as_ref() {
Inner::Unbound(unbound) => unbound.check_io_events(),
Inner::Bound(bound) => bound.check_io_events(),
}
}
}
impl Socket for NetlinkRouteSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let SocketAddr::Netlink(netlink_addr) = socket_addr else {
return_errno_with_message!(
Errno::EAFNOSUPPORT,
"the provided address is not netlink address"
);
};
let mut inner = self.inner.write();
inner.borrow_result(|owned_inner| match owned_inner.bind(&netlink_addr) {
Ok(bound_inner) => (bound_inner, Ok(())),
Err((err, err_inner)) => (err_inner, Err(err)),
})
}
fn addr(&self) -> Result<SocketAddr> {
let netlink_addr = match self.inner.read().as_ref() {
Inner::Unbound(_) => NetlinkSocketAddr::new_unspecified(),
Inner::Bound(bound) => bound.addr(),
};
Ok(SocketAddr::Netlink(netlink_addr))
}
fn sendmsg(
&self,
reader: &mut dyn MultiRead,
message_header: MessageHeader,
flags: SendRecvFlags,
) -> Result<usize> {
let MessageHeader {
addr,
control_message,
} = message_header;
let remote = match addr {
None => None,
Some(addr) => Some(addr.try_into()?),
};
if control_message.is_some() {
// TODO: Support sending control message
warn!("sending control message is not supported");
}
// TODO: Make sure our blocking behavior matches that of Linux
self.try_send(reader, remote.as_ref(), flags)
}
fn recvmsg(
&self,
writers: &mut dyn MultiWrite,
flags: SendRecvFlags,
) -> Result<(usize, MessageHeader)> {
let (received_len, addr) =
self.block_on(IoEvents::IN, || self.try_receive(writers, flags))?;
// TODO: Receive control message
let message_header = {
let addr = SocketAddr::Netlink(addr);
MessageHeader::new(Some(addr), None)
};
Ok((received_len, message_header))
}
fn set_option(&self, _option: &dyn SocketOption) -> Result<()> {
// TODO: This dummy option is added to pass the libnl test
Ok(())
}
}
impl SocketPrivate for NetlinkRouteSocket {
fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Pollable for NetlinkRouteSocket {
fn poll(&self, mask: IoEvents, poller: Option<&mut PollHandle>) -> IoEvents {
self.pollee
.poll_with(mask, poller, || self.check_io_events())
}
}
impl Inner {
fn bind(self, addr: &NetlinkSocketAddr) -> core::result::Result<Self, (Error, Self)> {
let unbound = match self {
Inner::Unbound(unbound) => unbound,
Inner::Bound(bound) => {
// FIXME: We need to further check the Linux behavior
// whether we should return error if the socket is bound.
// The socket may call `bind` syscall to join new multicast groups.
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound"),
Self::Bound(bound),
));
}
};
match unbound.bind(addr) {
Ok(bound) => Ok(Self::Bound(bound)),
Err((err, unbound)) => Err((err, Self::Unbound(unbound))),
}
}
}

View File

@ -0,0 +1,35 @@
// SPDX-License-Identifier: MPL-2.0
use super::bound::BoundNetlinkRoute;
use crate::{
events::IoEvents,
net::socket::netlink::{
table::NETLINK_SOCKET_TABLE, NetlinkSocketAddr, StandardNetlinkProtocol,
},
prelude::*,
};
pub(super) struct UnboundNetlinkRoute {
_private: (),
}
impl UnboundNetlinkRoute {
pub(super) const fn new() -> Self {
Self { _private: () }
}
pub(super) fn bind(
self,
addr: &NetlinkSocketAddr,
) -> core::result::Result<BoundNetlinkRoute, (Error, Self)> {
let bound_handle = NETLINK_SOCKET_TABLE
.bind(StandardNetlinkProtocol::ROUTE as _, addr)
.map_err(|err| (err, self))?;
Ok(BoundNetlinkRoute::new(bound_handle))
}
pub(super) fn check_io_events(&self) -> IoEvents {
IoEvents::OUT
}
}

View File

@ -0,0 +1,228 @@
// SPDX-License-Identifier: MPL-2.0
use multicast::MulticastGroup;
use super::addr::{GroupIdSet, NetlinkProtocolId, NetlinkSocketAddr, PortNum, MAX_GROUPS};
use crate::{net::socket::netlink::addr::UNSPECIFIED_PORT, prelude::*, util::random::getrandom};
mod multicast;
pub(super) static NETLINK_SOCKET_TABLE: NetlinkSocketTable = NetlinkSocketTable::new();
/// All bound netlink sockets.
pub(super) struct NetlinkSocketTable {
protocols: [Mutex<Option<ProtocolSocketTable>>; MAX_ALLOWED_PROTOCOL_ID as usize],
}
impl NetlinkSocketTable {
pub(super) const fn new() -> Self {
Self {
protocols: [const { Mutex::new(None) }; MAX_ALLOWED_PROTOCOL_ID as usize],
}
}
/// Adds a new netlink protocol.
fn add_new_protocol(&self, protocol_id: NetlinkProtocolId) {
if protocol_id >= MAX_ALLOWED_PROTOCOL_ID {
return;
}
let mut protocol = self.protocols[protocol_id as usize].lock();
if protocol.is_some() {
return;
}
let new_protocol = ProtocolSocketTable::new(protocol_id);
*protocol = Some(new_protocol);
}
pub(super) fn bind(
&self,
protocol: NetlinkProtocolId,
addr: &NetlinkSocketAddr,
) -> Result<BoundHandle> {
if protocol >= MAX_ALLOWED_PROTOCOL_ID {
return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist");
}
let mut protocol = self.protocols[protocol as usize].lock();
let Some(protocol_sockets) = protocol.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the netlink protocol does not exist")
};
protocol_sockets.bind(addr)
}
}
/// Bound socket table of a single netlink protocol.
///
/// Each table can have bound sockets for unicast
/// and at most 32 groups for multicast.
struct ProtocolSocketTable {
id: NetlinkProtocolId,
// TODO: This table should maintain the port number-to-socket relationship
// to support both unicast and multicast effectively.
unicast_sockets: BTreeSet<PortNum>,
multicast_groups: Box<[MulticastGroup]>,
}
impl ProtocolSocketTable {
/// Creates a new table.
fn new(id: NetlinkProtocolId) -> Self {
let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect();
Self {
id,
unicast_sockets: BTreeSet::new(),
multicast_groups,
}
}
/// Binds a socket to the table.
/// Returns the bound handle.
///
/// The socket will be bound to a port specified by `addr.port()`.
/// If `addr.port()` is zero, the kernel will assign a port,
/// typically corresponding to the process ID of the current process.
/// If the assigned port is already in use,
/// this function will try to allocate a random unused port.
///
/// Additionally, this socket can join one or more multicast groups,
/// as specified in `addr.groups()`.
fn bind(&mut self, addr: &NetlinkSocketAddr) -> Result<BoundHandle> {
let port = if addr.port() != UNSPECIFIED_PORT {
addr.port()
} else {
let mut random_port = current!().pid();
while random_port == UNSPECIFIED_PORT || self.unicast_sockets.contains(&random_port) {
getrandom(random_port.as_bytes_mut()).unwrap();
}
random_port
};
if self.unicast_sockets.contains(&port) {
return_errno_with_message!(Errno::EADDRINUSE, "the netlink port is already in use");
}
self.unicast_sockets.insert(port);
for group_id in addr.groups().ids_iter() {
let group = &mut self.multicast_groups[group_id as usize];
group.add_member(port);
}
Ok(BoundHandle::new(self.id, port, addr.groups()))
}
}
/// A bound netlink socket address.
///
/// When dropping a `BoundHandle`,
/// the port will be automatically released.
#[derive(Debug)]
pub(super) struct BoundHandle {
protocol: NetlinkProtocolId,
port: PortNum,
groups: GroupIdSet,
}
impl BoundHandle {
fn new(protocol: NetlinkProtocolId, port: PortNum, groups: GroupIdSet) -> Self {
debug_assert_ne!(port, UNSPECIFIED_PORT);
Self {
protocol,
port,
groups,
}
}
pub(super) const fn addr(&self) -> NetlinkSocketAddr {
NetlinkSocketAddr::new(self.port, self.groups)
}
}
impl Drop for BoundHandle {
fn drop(&mut self) {
let mut protocol_sockets = NETLINK_SOCKET_TABLE.protocols[self.protocol as usize].lock();
let Some(protocol_sockets) = protocol_sockets.as_mut() else {
return;
};
protocol_sockets.unicast_sockets.remove(&self.port);
for group_id in self.groups.ids_iter() {
let group = &mut protocol_sockets.multicast_groups[group_id as usize];
group.remove_member(self.port);
}
}
}
pub(super) fn init() {
for protocol in 0..MAX_ALLOWED_PROTOCOL_ID {
if is_standard_protocol(protocol) {
NETLINK_SOCKET_TABLE.add_new_protocol(protocol);
}
}
}
/// Returns whether the `protocol` is valid.
pub fn is_valid_protocol(protocol: NetlinkProtocolId) -> bool {
protocol < MAX_ALLOWED_PROTOCOL_ID
}
/// Returns whether the `protocol` is reserved for system use.
fn is_standard_protocol(protocol: NetlinkProtocolId) -> bool {
StandardNetlinkProtocol::try_from(protocol).is_ok()
}
/// Netlink protocols that are assigned for specific usage.
///
/// Reference: <https://elixir.bootlin.com/linux/v6.0.9/source/include/uapi/linux/netlink.h#L9>.
#[allow(non_camel_case_types)]
#[repr(u32)]
#[derive(Debug, Clone, Copy, TryFromInt)]
pub enum StandardNetlinkProtocol {
/// Routing/device hook
ROUTE = 0,
/// Unused number
UNUSED = 1,
/// Reserved for user mode socket protocols
USERSOCK = 2,
/// Unused number, formerly ip_queue
FIREWALL = 3,
/// Socket monitoring
SOCK_DIAG = 4,
/// Netfilter/iptables ULOG
NFLOG = 5,
/// IPsec
XFRM = 6,
/// SELinux event notifications
SELINUX = 7,
/// Open-iSCSI
ISCSI = 8,
/// Auditing
AUDIT = 9,
FIB_LOOKUP = 10,
CONNECTOR = 11,
/// Netfilter subsystem
NETFILTER = 12,
IP6_FW = 13,
/// DECnet routing messages
DNRTMSG = 14,
/// Kernel messages to userspace
KOBJECT_UEVENT = 15,
GENERIC = 16,
/// Leave room for NETLINK_DM (DM Events)
/// SCSI Transports
SCSITRANSPORT = 18,
ECRYPTFS = 19,
RDMA = 20,
/// Crypto layer
CRYPTO = 21,
/// SMC monitoring
SMC = 22,
}
const MAX_ALLOWED_PROTOCOL_ID: NetlinkProtocolId = 32;

View File

@ -0,0 +1,38 @@
// SPDX-License-Identifier: MPL-2.0
use crate::{net::socket::netlink::addr::PortNum, prelude::*};
/// A netlink multicast group.
///
/// A group can contain multiple sockets,
/// each identified by its bound port number.
pub struct MulticastGroup {
members: BTreeSet<PortNum>,
}
impl MulticastGroup {
/// Creates a new multicast group.
pub const fn new() -> Self {
Self {
members: BTreeSet::new(),
}
}
/// Returns whether the group contains a member.
#[expect(unused)]
pub fn contains_member(&self, port_num: PortNum) -> bool {
self.members.contains(&port_num)
}
/// Adds a new member to the multicast group.
pub fn add_member(&mut self, port_num: PortNum) {
debug_assert!(!self.members.contains(&port_num));
self.members.insert(port_num);
}
/// Removes a member from the multicast group.
pub fn remove_member(&mut self, port_num: PortNum) {
debug_assert!(self.members.contains(&port_num));
self.members.remove(&port_num);
}
}

View File

@ -3,7 +3,7 @@
use aster_bigtcp::wire::{Ipv4Address, PortNum};
use crate::{
net::socket::{unix::UnixSocketAddr, vsock::addr::VsockSocketAddr},
net::socket::{netlink::NetlinkSocketAddr, unix::UnixSocketAddr, vsock::addr::VsockSocketAddr},
prelude::*,
};
@ -11,5 +11,6 @@ use crate::{
pub enum SocketAddr {
Unix(UnixSocketAddr),
IPv4(Ipv4Address, PortNum),
Netlink(NetlinkSocketAddr),
Vsock(VsockSocketAddr),
}

View File

@ -15,7 +15,7 @@ pub use stream::VsockStreamSocket;
// init static driver
pub static VSOCK_GLOBAL: Once<Arc<VsockSpace>> = Once::new();
pub fn init() {
pub(in crate::net) fn init() {
if let Some(driver) = get_device(DEVICE_NAME) {
VSOCK_GLOBAL.call_once(|| Arc::new(VsockSpace::new(driver)));
register_recv_callback(DEVICE_NAME, || {

View File

@ -5,6 +5,7 @@ use crate::{
fs::{file_handle::FileLike, file_table::FdFlags},
net::socket::{
ip::{datagram::DatagramSocket, stream::StreamSocket},
netlink::{is_valid_protocol, NetlinkRouteSocket, StandardNetlinkProtocol},
unix::UnixStreamSocket,
vsock::VsockStreamSocket,
},
@ -16,29 +17,62 @@ pub fn sys_socket(domain: i32, type_: i32, protocol: i32, ctx: &Context) -> Resu
let domain = CSocketAddrFamily::try_from(domain)?;
let sock_type = SockType::try_from(type_ & SOCK_TYPE_MASK)?;
let sock_flags = SockFlags::from_bits_truncate(type_ & !SOCK_TYPE_MASK);
let protocol = Protocol::try_from(protocol)?;
debug!(
"domain = {:?}, sock_type = {:?}, sock_flags = {:?}, protocol = {:?}",
domain, sock_type, sock_flags, protocol
"domain = {:?}, sock_type = {:?}, sock_flags = {:?}",
domain, sock_type, sock_flags
);
let nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK);
let file_like = match (domain, sock_type, protocol) {
let is_nonblocking = sock_flags.contains(SockFlags::SOCK_NONBLOCK);
let file_like = match (domain, sock_type) {
// FIXME: SOCK_SEQPACKET is added to run fcntl_test, not supported yet.
(CSocketAddrFamily::AF_UNIX, SockType::SOCK_STREAM | SockType::SOCK_SEQPACKET, _) => {
UnixStreamSocket::new(nonblocking) as Arc<dyn FileLike>
(CSocketAddrFamily::AF_UNIX, SockType::SOCK_STREAM | SockType::SOCK_SEQPACKET) => {
UnixStreamSocket::new(is_nonblocking) as Arc<dyn FileLike>
}
(
CSocketAddrFamily::AF_INET,
SockType::SOCK_STREAM,
Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP,
) => StreamSocket::new(nonblocking) as Arc<dyn FileLike>,
(
CSocketAddrFamily::AF_INET,
SockType::SOCK_DGRAM,
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP,
) => DatagramSocket::new(nonblocking) as Arc<dyn FileLike>,
(CSocketAddrFamily::AF_VSOCK, SockType::SOCK_STREAM, _) => {
Arc::new(VsockStreamSocket::new(nonblocking)) as Arc<dyn FileLike>
(CSocketAddrFamily::AF_INET, SockType::SOCK_STREAM) => {
let protocol = Protocol::try_from(protocol)?;
debug!("protocol = {:?}", protocol);
match protocol {
Protocol::IPPROTO_IP | Protocol::IPPROTO_TCP => {
StreamSocket::new(is_nonblocking) as Arc<dyn FileLike>
}
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported protocol"),
}
}
(CSocketAddrFamily::AF_INET, SockType::SOCK_DGRAM) => {
let protocol = Protocol::try_from(protocol)?;
debug!("protocol = {:?}", protocol);
match protocol {
Protocol::IPPROTO_IP | Protocol::IPPROTO_UDP => {
DatagramSocket::new(is_nonblocking) as Arc<dyn FileLike>
}
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported protocol"),
}
}
(CSocketAddrFamily::AF_NETLINK, SockType::SOCK_RAW | SockType::SOCK_DGRAM) => {
let netlink_family = StandardNetlinkProtocol::try_from(protocol as u32);
debug!("netlink family = {:?}", netlink_family);
match netlink_family {
Ok(StandardNetlinkProtocol::ROUTE) => {
Arc::new(NetlinkRouteSocket::new(is_nonblocking))
}
Ok(_) => {
return_errno_with_message!(
Errno::EAFNOSUPPORT,
"some standard netlink families are not supported yet"
);
}
Err(_) => {
if is_valid_protocol(protocol as u32) {
return_errno_with_message!(
Errno::EAFNOSUPPORT,
"user-provided netlink family is not supported"
)
}
return_errno_with_message!(Errno::EAFNOSUPPORT, "invalid netlink family");
}
}
}
(CSocketAddrFamily::AF_VSOCK, SockType::SOCK_STREAM) => {
Arc::new(VsockStreamSocket::new(is_nonblocking)) as Arc<dyn FileLike>
}
_ => return_errno_with_message!(Errno::EAFNOSUPPORT, "unsupported domain"),
};

View File

@ -4,7 +4,7 @@ use core::cmp::min;
use ostd::task::Task;
use super::{ip::CSocketAddrInet, unix, vsock::CSocketAddrVm};
use super::{ip::CSocketAddrInet, netlink::CSocketAddrNetlink, unix, vsock::CSocketAddrVm};
use crate::{current_userspace, net::socket::SocketAddr, prelude::*};
/// Address family.
@ -162,6 +162,13 @@ pub fn read_socket_addr_from_user(addr: Vaddr, addr_len: usize) -> Result<Socket
let addr = unix::from_c_bytes(&storage.as_bytes()[..addr_len])?;
SocketAddr::Unix(addr)
}
Ok(CSocketAddrFamily::AF_NETLINK) => {
if addr_len < size_of::<CSocketAddrNetlink>() {
return_errno_with_message!(Errno::EINVAL, "the socket address length is too small");
}
let addr = CSocketAddrNetlink::from_bytes(storage.as_bytes());
SocketAddr::Netlink(addr.into())
}
Ok(CSocketAddrFamily::AF_VSOCK) => {
if addr_len < size_of::<CSocketAddrVm>() {
return_errno_with_message!(Errno::EINVAL, "the socket address length is too small");
@ -238,36 +245,45 @@ pub fn write_socket_addr_with_max_len(
);
}
let current_task = Task::current().unwrap();
let user_space = CurrentUserSpace::new(&current_task);
let actual_len = match socket_addr {
SocketAddr::IPv4(addr, port) => {
let socket_addr = CSocketAddrInet::from((*addr, *port));
let actual_len = size_of::<CSocketAddrInet>();
let written_len = min(actual_len, max_len as _);
user_space.write_bytes(
SocketAddr::IPv4(addr, port) => write_c_socket_address_util::<CSocketAddrInet, _>(
(*addr, *port),
dest,
&mut VmReader::from(&socket_addr.as_bytes()[..written_len]),
)?;
actual_len
}
max_len as usize,
)?,
SocketAddr::Unix(addr) => unix::into_c_bytes_and(addr, |bytes| {
let written_len = min(bytes.len(), max_len as _);
user_space.write_bytes(dest, &mut VmReader::from(&bytes[..written_len]))?;
current_userspace!().write_bytes(dest, &mut VmReader::from(&bytes[..written_len]))?;
Ok::<usize, Error>(bytes.len())
})?,
SocketAddr::Netlink(addr) => {
write_c_socket_address_util::<CSocketAddrNetlink, _>(*addr, dest, max_len as usize)?
}
SocketAddr::Vsock(addr) => {
let socket_addr = CSocketAddrVm::from(*addr);
let actual_len = size_of::<CSocketAddrVm>();
let written_len = min(actual_len, max_len as _);
user_space.write_bytes(
dest,
&mut VmReader::from(&socket_addr.as_bytes()[..written_len]),
)?;
actual_len
write_c_socket_address_util::<CSocketAddrVm, _>(*addr, dest, max_len as usize)?
}
};
Ok(actual_len as i32)
}
// Utility function to write a C socket address to user space.
fn write_c_socket_address_util<TCSockAddr: Pod, TSockAddr>(
addr: TSockAddr,
dest: Vaddr,
max_len: usize,
) -> Result<usize>
where
TCSockAddr: From<TSockAddr>,
{
let c_socket_addr = TCSockAddr::from(addr);
let actual_len = size_of::<TCSockAddr>();
let written_len = min(actual_len, max_len);
current_userspace!().write_bytes(
dest,
&mut VmReader::from(&c_socket_addr.as_bytes()[..written_len]),
)?;
Ok(actual_len)
}

View File

@ -38,6 +38,7 @@ impl From<(Ipv4Address, PortNum)> for CSocketAddrInet {
impl From<CSocketAddrInet> for (Ipv4Address, PortNum) {
fn from(value: CSocketAddrInet) -> Self {
debug_assert_eq!(value.sin_family, CSocketAddrFamily::AF_INET as u16);
(value.sin_addr.into(), value.sin_port.into())
}
}

View File

@ -7,5 +7,6 @@ pub use family::{
mod family;
mod ip;
mod netlink;
mod unix;
mod vsock;

View File

@ -0,0 +1,41 @@
// SPDX-License-Identifier: MPL-2.0
use super::CSocketAddrFamily;
use crate::{
net::socket::netlink::{GroupIdSet, NetlinkSocketAddr},
prelude::*,
};
/// Netlink socket address.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod)]
pub struct CSocketAddrNetlink {
/// Address family (AF_NETLINK).
nl_family: u16,
/// Pad bytes (always zero).
nl_pad: u16,
/// Port ID.
nl_pid: u32,
/// Multicast groups mask.
nl_groups: u32,
}
impl From<NetlinkSocketAddr> for CSocketAddrNetlink {
fn from(value: NetlinkSocketAddr) -> Self {
Self {
nl_family: CSocketAddrFamily::AF_NETLINK as _,
nl_pad: 0,
nl_pid: value.port(),
nl_groups: value.groups().as_u32(),
}
}
}
impl From<CSocketAddrNetlink> for NetlinkSocketAddr {
fn from(value: CSocketAddrNetlink) -> Self {
debug_assert_eq!(value.nl_family, CSocketAddrFamily::AF_NETLINK as u16);
let port = value.nl_pid;
let groups = GroupIdSet::new(value.nl_groups);
NetlinkSocketAddr::new(port, groups)
}
}

View File

@ -33,6 +33,7 @@ impl From<VsockSocketAddr> for CSocketAddrVm {
impl From<CSocketAddrVm> for VsockSocketAddr {
fn from(value: CSocketAddrVm) -> Self {
debug_assert_eq!(value.svm_family, CSocketAddrFamily::AF_VSOCK as u16);
Self {
cid: value.svm_cid,
port: value.svm_port,