diff --git a/Cargo.lock b/Cargo.lock index f9de43cca..da2f21dd5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1383,7 +1383,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "smoltcp" version = "0.11.0" -source = "git+https://github.com/smoltcp-rs/smoltcp?rev=469dccd#469dccdbaece66696494f8540615152fc5123b17" +source = "git+https://github.com/asterinas/smoltcp?rev=37716bf#37716bff5ed5b16aba1b8a37e788ee5a6bf32cab" dependencies = [ "bitflags 1.3.2", "byteorder", diff --git a/kernel/libs/aster-bigtcp/Cargo.toml b/kernel/libs/aster-bigtcp/Cargo.toml index f426a97ce..2f40952bc 100644 --- a/kernel/libs/aster-bigtcp/Cargo.toml +++ b/kernel/libs/aster-bigtcp/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] keyable-arc = { path = "../keyable-arc" } ostd = { path = "../../../ostd" } -smoltcp = { git = "https://github.com/smoltcp-rs/smoltcp", rev = "469dccd", default-features = false, features = [ +smoltcp = { git = "https://github.com/asterinas/smoltcp", rev = "37716bf", default-features = false, features = [ "alloc", "log", "medium-ethernet", diff --git a/kernel/libs/aster-bigtcp/src/errors.rs b/kernel/libs/aster-bigtcp/src/errors.rs index 57d7e75fc..5f058b1cc 100644 --- a/kernel/libs/aster-bigtcp/src/errors.rs +++ b/kernel/libs/aster-bigtcp/src/errors.rs @@ -14,5 +14,15 @@ pub mod tcp { } pub mod udp { - pub use smoltcp::socket::udp::{RecvError, SendError}; + pub use smoltcp::socket::udp::RecvError; + + /// An error returned when sending data through a [`BoundUdpSocket`]. 
+ /// + /// [`BoundUdpSocket`]: crate::socket::BoundUdpSocket + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum SendError { + TooLarge, + Unaddressable, + BufferFull, + } } diff --git a/kernel/libs/aster-bigtcp/src/iface/common.rs b/kernel/libs/aster-bigtcp/src/iface/common.rs index 25fc3e548..660649a91 100644 --- a/kernel/libs/aster-bigtcp/src/iface/common.rs +++ b/kernel/libs/aster-bigtcp/src/iface/common.rs @@ -7,42 +7,45 @@ use alloc::{ btree_set::BTreeSet, }, sync::Arc, - vec::Vec, }; use keyable_arc::KeyableArc; -use ostd::sync::{LocalIrqDisabled, PreemptDisabled, RwLock, SpinLock, SpinLockGuard}; +use ostd::sync::{LocalIrqDisabled, PreemptDisabled, SpinLock, SpinLockGuard}; use smoltcp::{ - iface::{SocketHandle, SocketSet}, + iface::{packet::Packet, Context}, phy::Device, - wire::Ipv4Address, + wire::{Ipv4Address, Ipv4Packet}, }; -use super::{port::BindPortConfig, time::get_network_timestamp, Iface}; +use super::{ + poll::{FnHelper, PollContext}, + port::BindPortConfig, + time::get_network_timestamp, + Iface, +}; use crate::{ errors::BindError, - socket::{AnyBoundSocket, AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily}, + socket::{ + BoundTcpSocket, BoundTcpSocketInner, BoundUdpSocket, BoundUdpSocketInner, UnboundTcpSocket, + UnboundUdpSocket, + }, }; pub struct IfaceCommon { interface: SpinLock, - sockets: SpinLock, LocalIrqDisabled>, used_ports: SpinLock, PreemptDisabled>, - bound_sockets: RwLock>>>, - closing_sockets: SpinLock>>, LocalIrqDisabled>, + tcp_sockets: SpinLock>>, LocalIrqDisabled>, + udp_sockets: SpinLock>>, LocalIrqDisabled>, ext: E, } impl IfaceCommon { pub(super) fn new(interface: smoltcp::iface::Interface, ext: E) -> Self { - let socket_set = SocketSet::new(Vec::new()); - let used_ports = BTreeMap::new(); Self { interface: SpinLock::new(interface), - sockets: SpinLock::new(socket_set), - used_ports: SpinLock::new(used_ports), - bound_sockets: RwLock::new(BTreeSet::new()), - closing_sockets: 
SpinLock::new(BTreeSet::new()), + used_ports: SpinLock::new(BTreeMap::new()), + tcp_sockets: SpinLock::new(BTreeSet::new()), + udp_sockets: SpinLock::new(BTreeSet::new()), ext, } } @@ -58,58 +61,57 @@ impl IfaceCommon { impl IfaceCommon { /// Acquires the lock to the interface. - /// - /// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second. pub(crate) fn interface(&self) -> SpinLockGuard { self.interface.lock() } - - /// Acuqires the lock to the sockets. - /// - /// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second. - pub(crate) fn sockets( - &self, - ) -> SpinLockGuard, LocalIrqDisabled> { - self.sockets.lock() - } } const IP_LOCAL_PORT_START: u16 = 32768; const IP_LOCAL_PORT_END: u16 = 60999; impl IfaceCommon { - pub(super) fn bind_socket( + pub(super) fn bind_tcp( &self, iface: Arc>, - socket: Box, + socket: Box, config: BindPortConfig, - ) -> core::result::Result, (BindError, Box)> { - let port = if let Some(port) = config.port() { - port - } else { - match self.alloc_ephemeral_port() { - Some(port) => port, - None => return Err((BindError::Exhausted, socket)), - } + ) -> core::result::Result, (BindError, Box)> { + let port = match self.bind_port(config) { + Ok(port) => port, + Err(err) => return Err((err, socket)), }; - if !self.bind_port(port, config.can_reuse()) { - return Err((BindError::InUse, socket)); - } - let (handle, socket_family, observer) = match socket.into_raw() { - (AnyRawSocket::Tcp(tcp_socket), observer) => ( - self.sockets.lock().add(tcp_socket), - SocketFamily::Tcp, - observer, - ), - (AnyRawSocket::Udp(udp_socket), observer) => ( - self.sockets.lock().add(udp_socket), - SocketFamily::Udp, - observer, - ), + let (raw_socket, observer) = socket.into_raw(); + let bound_socket = BoundTcpSocket::new(iface, port, raw_socket, observer); + + let inserted = self + .tcp_sockets + .lock() + .insert(KeyableArc::from(bound_socket.inner().clone())); + assert!(inserted); + + Ok(bound_socket) + } + + pub(super) fn 
bind_udp( + &self, + iface: Arc>, + socket: Box, + config: BindPortConfig, + ) -> core::result::Result, (BindError, Box)> { + let port = match self.bind_port(config) { + Ok(port) => port, + Err(err) => return Err((err, socket)), }; - let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer); - self.insert_bound_socket(bound_socket.inner()); + + let (raw_socket, observer) = socket.into_raw(); + let bound_socket = BoundUdpSocket::new(iface, port, raw_socket, observer); + + let inserted = self + .udp_sockets + .lock() + .insert(KeyableArc::from(bound_socket.inner().clone())); + assert!(inserted); Ok(bound_socket) } @@ -130,36 +132,57 @@ impl IfaceCommon { None } - #[must_use] - fn bind_port(&self, port: u16, can_reuse: bool) -> bool { + fn bind_port(&self, config: BindPortConfig) -> Result { + let port = if let Some(port) = config.port() { + port + } else { + match self.alloc_ephemeral_port() { + Some(port) => port, + None => return Err(BindError::Exhausted), + } + }; + let mut used_ports = self.used_ports.lock(); + if let Some(used_times) = used_ports.get_mut(&port) { - if *used_times == 0 || can_reuse { + if *used_times == 0 || config.can_reuse() { // FIXME: Check if the previous socket was bound with SO_REUSEADDR. 
*used_times += 1; } else { - return false; + return Err(BindError::InUse); } } else { used_ports.insert(port, 1); } - true - } - fn insert_bound_socket(&self, socket: &Arc>) { - let keyable_socket = KeyableArc::from(socket.clone()); - - let inserted = self - .bound_sockets - .write_irq_disabled() - .insert(keyable_socket); - assert!(inserted); + Ok(port) } } impl IfaceCommon { + #[allow(clippy::mutable_key_type)] + fn remove_dead_tcp_sockets(&self, sockets: &mut BTreeSet>>) { + sockets.retain(|socket| { + if socket.is_dead() { + self.release_port(socket.port()); + false + } else { + true + } + }); + } + + pub(crate) fn remove_udp_socket(&self, socket: &Arc>) { + let keyable_socket = KeyableArc::from(socket.clone()); + + let removed = self.udp_sockets.lock().remove(&keyable_socket); + assert!(removed); + + self.release_port(keyable_socket.port()); + } + /// Releases the port so that it can be used again (if it is not being reused). - pub(crate) fn release_port(&self, port: u16) { + fn release_port(&self, port: u16) { let mut used_ports = self.used_ports.lock(); if let Some(used_times) = used_ports.remove(&port) { if used_times != 1 { @@ -167,96 +190,63 @@ impl IfaceCommon { } } } - - /// Removes a socket from the interface. - pub(crate) fn remove_socket(&self, handle: SocketHandle) { - self.sockets.lock().remove(handle); - } - - pub(crate) fn remove_bound_socket_now(&self, socket: &Arc>) { - let keyable_socket = KeyableArc::from(socket.clone()); - - let removed = self - .bound_sockets - .write_irq_disabled() - .remove(&keyable_socket); - assert!(removed); - } - - pub(crate) fn remove_bound_socket_when_closed(&self, socket: &Arc>) { - let keyable_socket = KeyableArc::from(socket.clone()); - - let removed = self - .bound_sockets - .write_irq_disabled() - .remove(&keyable_socket); - assert!(removed); - - let mut closing_sockets = self.closing_sockets.lock(); - - // Check `is_closed` after holding the lock to avoid race conditions. 
- if keyable_socket.is_closed() { - return; - } - - let inserted = closing_sockets.insert(keyable_socket); - assert!(inserted); - } } impl IfaceCommon { - #[must_use] - pub(super) fn poll(&self, device: &mut D) -> Option { - let mut sockets = self.sockets.lock(); - let mut interface = self.interface.lock(); + pub(super) fn poll( + &self, + device: &mut D, + process_phy: P, + mut dispatch_phy: Q, + ) -> Option + where + D: Device + ?Sized, + P: for<'pkt, 'cx, 'tx> FnHelper< + &'pkt [u8], + &'cx mut Context, + D::TxToken<'tx>, + Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>, + >, + Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), + { + let mut interface = self.interface(); + interface.context().now = get_network_timestamp(); - let timestamp = get_network_timestamp(); - let (has_events, poll_at) = { - let mut has_events = false; - let mut poll_at; + let mut tcp_sockets = self.tcp_sockets.lock(); + let udp_sockets = self.udp_sockets.lock(); - loop { - // `poll` transmits and receives a bounded number of packets. This loop ensures - // that all packets are transmitted and received. For details, see - // . - while interface.poll(timestamp, device, &mut sockets) { - has_events = true; - } + let mut context = PollContext::new(interface.context(), &tcp_sockets, &udp_sockets); + context.poll_ingress(device, process_phy, &mut dispatch_phy); + context.poll_egress(device, dispatch_phy); - // `poll_at` can return `Some(Instant::from_millis(0))`, which means `PollAt::Now`. - // For details, see - // . 
- poll_at = interface.poll_at(timestamp, &sockets); - let Some(instant) = poll_at else { - break; - }; - if instant > timestamp { - break; - } + tcp_sockets.iter().for_each(|socket| { + if socket.has_new_events() { + socket.on_iface_events(); } + }); + udp_sockets.iter().for_each(|socket| { + if socket.has_new_events() { + socket.on_iface_events(); + } + }); - (has_events, poll_at) - }; + self.remove_dead_tcp_sockets(&mut tcp_sockets); - // drop sockets here to avoid deadlock - drop(interface); - drop(sockets); - - if has_events { - // We never try to hold the write lock in the IRQ context, and we disable IRQ when - // holding the write lock. So we don't need to disable IRQ when holding the read lock. - self.bound_sockets.read().iter().for_each(|bound_socket| { - bound_socket.on_iface_events(); - }); - - let closed_sockets = self - .closing_sockets - .lock() - .extract_if(|closing_socket| closing_socket.is_closed()) - .collect::>(); - drop(closed_sockets); + match ( + tcp_sockets + .iter() + .map(|socket| socket.next_poll_at_ms()) + .min(), + udp_sockets + .iter() + .map(|socket| socket.next_poll_at_ms()) + .min(), + ) { + (Some(tcp_poll_at), Some(udp_poll_at)) if tcp_poll_at <= udp_poll_at => { + Some(tcp_poll_at) + } + (tcp_poll_at, None) => tcp_poll_at, + (_, udp_poll_at) => udp_poll_at, } - - poll_at.map(|at| smoltcp::time::Instant::total_millis(&at) as u64) } } diff --git a/kernel/libs/aster-bigtcp/src/iface/iface.rs b/kernel/libs/aster-bigtcp/src/iface/iface.rs index 01d06762f..da21b0a0d 100644 --- a/kernel/libs/aster-bigtcp/src/iface/iface.rs +++ b/kernel/libs/aster-bigtcp/src/iface/iface.rs @@ -7,7 +7,7 @@ use smoltcp::wire::Ipv4Address; use super::port::BindPortConfig; use crate::{ errors::BindError, - socket::{AnyBoundSocket, AnyUnboundSocket}, + socket::{BoundTcpSocket, BoundUdpSocket, UnboundTcpSocket, UnboundUdpSocket}, }; /// A network interface. 
@@ -42,13 +42,22 @@ impl dyn Iface { /// FIXME: The reason for binding the socket and the iface together is because there are /// limitations inside smoltcp. See discussion at /// . - pub fn bind_socket( + pub fn bind_tcp( self: &Arc, - socket: Box, + socket: Box, config: BindPortConfig, - ) -> core::result::Result, (BindError, Box)> { + ) -> core::result::Result, (BindError, Box)> { let common = self.common(); - common.bind_socket(self.clone(), socket, config) + common.bind_tcp(self.clone(), socket, config) + } + + pub fn bind_udp( + self: &Arc, + socket: Box, + config: BindPortConfig, + ) -> core::result::Result, (BindError, Box)> { + let common = self.common(); + common.bind_udp(self.clone(), socket, config) } /// Gets the IPv4 address of the iface, if any. diff --git a/kernel/libs/aster-bigtcp/src/iface/mod.rs b/kernel/libs/aster-bigtcp/src/iface/mod.rs index 744699cbf..62a93c25e 100644 --- a/kernel/libs/aster-bigtcp/src/iface/mod.rs +++ b/kernel/libs/aster-bigtcp/src/iface/mod.rs @@ -4,6 +4,7 @@ mod common; #[allow(clippy::module_inception)] mod iface; mod phy; +mod poll; mod port; mod time; diff --git a/kernel/libs/aster-bigtcp/src/iface/phy/ether.rs b/kernel/libs/aster-bigtcp/src/iface/phy/ether.rs index be1a94e18..cfe749e4a 100644 --- a/kernel/libs/aster-bigtcp/src/iface/phy/ether.rs +++ b/kernel/libs/aster-bigtcp/src/iface/phy/ether.rs @@ -1,10 +1,15 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::sync::Arc; +use alloc::{collections::btree_map::BTreeMap, sync::Arc}; +use ostd::sync::{LocalIrqDisabled, SpinLock}; use smoltcp::{ - iface::Config, - wire::{self, EthernetAddress, Ipv4Address, Ipv4Cidr}, + iface::{packet::Packet, Config, Context}, + phy::{DeviceCapabilities, TxToken}, + wire::{ + self, ArpOperation, ArpPacket, ArpRepr, EthernetAddress, EthernetFrame, EthernetProtocol, + EthernetRepr, IpAddress, Ipv4Address, Ipv4Cidr, Ipv4Packet, + }, }; use crate::{ @@ -14,9 +19,11 @@ use crate::{ }, }; -pub struct EtherIface { +pub struct EtherIface { 
driver: D, common: IfaceCommon, + ether_addr: EthernetAddress, + arp_table: SpinLock, LocalIrqDisabled>, } impl EtherIface { @@ -45,21 +52,224 @@ impl EtherIface { let common = IfaceCommon::new(interface, ext); - Arc::new(Self { driver, common }) + Arc::new(Self { + driver, + common, + ether_addr, + arp_table: SpinLock::new(BTreeMap::new()), + }) } } -impl IfaceInternal for EtherIface { +impl IfaceInternal for EtherIface { fn common(&self) -> &IfaceCommon { &self.common } } -impl Iface for EtherIface { +impl Iface for EtherIface { fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)) { self.driver.with(|device| { - let next_poll = self.common.poll(&mut *device); + let next_poll = self.common.poll( + &mut *device, + |data, iface_cx, tx_token| self.process(data, iface_cx, tx_token), + |pkt, iface_cx, tx_token| self.dispatch(pkt, iface_cx, tx_token), + ); schedule_next_poll(next_poll); }); } } + +impl EtherIface { + fn process<'pkt, T: TxToken>( + &self, + data: &'pkt [u8], + iface_cx: &mut Context, + tx_token: T, + ) -> Option<(Ipv4Packet<&'pkt [u8]>, T)> { + match self.parse_ip_or_process_arp(data, iface_cx) { + Ok(pkt) => Some((pkt, tx_token)), + Err(Some(arp)) => { + Self::emit_arp(&arp, tx_token); + None + } + Err(None) => None, + } + } + + fn parse_ip_or_process_arp<'pkt>( + &self, + data: &'pkt [u8], + iface_cx: &mut Context, + ) -> Result, Option> { + // Parse the Ethernet header. Ignore the packet if the header is ill-formed. + let frame = EthernetFrame::new_checked(data).map_err(|_| None)?; + let repr = EthernetRepr::parse(&frame).map_err(|_| None)?; + + // Ignore the Ethernet frame if it is not sent to us. + if !repr.dst_addr.is_broadcast() && repr.dst_addr != self.ether_addr { + return Err(None); + } + + // Ignore the Ethernet frame if the protocol is not supported. + match repr.ethertype { + EthernetProtocol::Ipv4 => { + Ok(Ipv4Packet::new_checked(frame.payload()).map_err(|_| None)?) 
+ } + EthernetProtocol::Arp => { + let pkt = ArpPacket::new_checked(frame.payload()).map_err(|_| None)?; + let arp = ArpRepr::parse(&pkt).map_err(|_| None)?; + Err(self.process_arp(&arp, iface_cx)) + } + _ => Err(None), + } + } + + fn process_arp(&self, arp_repr: &ArpRepr, iface_cx: &mut Context) -> Option { + match arp_repr { + ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr, + source_protocol_addr, + .. + } => { + // Ignore the ARP packet if the source addresses are not unicast or not local. + if !source_hardware_addr.is_unicast() + || !iface_cx.in_same_network(&IpAddress::Ipv4(*source_protocol_addr)) + { + return None; + } + + // Insert the mapping between the Ethernet address and the IP address. + // + // TODO: Remove the mapping if it expires. + self.arp_table + .lock() + .insert(*source_protocol_addr, *source_hardware_addr); + + None + } + ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr, + source_protocol_addr, + target_protocol_addr, + .. + } => { + // Ignore the ARP packet if the source addresses are not unicast. + if !source_hardware_addr.is_unicast() || !source_protocol_addr.is_unicast() { + return None; + } + + // Ignore the ARP packet if we do not own the target address. 
+ if !iface_cx + .ipv4_addr() + .is_some_and(|addr| addr == *target_protocol_addr) + { + return None; + } + + Some(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: self.ether_addr, + source_protocol_addr: *target_protocol_addr, + target_hardware_addr: *source_hardware_addr, + target_protocol_addr: *source_protocol_addr, + }) + } + _ => None, + } + } + + fn dispatch<T: TxToken>(&self, pkt: &Packet, iface_cx: &mut Context, tx_token: T) { + match self.resolve_ether_or_generate_arp(pkt, iface_cx) { + Ok(ether) => Self::emit_ip(&ether, pkt, &iface_cx.caps, tx_token), + Err(Some(arp)) => Self::emit_arp(&arp, tx_token), + Err(None) => (), + } + } + + fn resolve_ether_or_generate_arp( + &self, + pkt: &Packet, + iface_cx: &mut Context, + ) -> Result<EthernetRepr, Option<ArpRepr>> { + // Resolve the next-hop IP address. + let next_hop_ip = match iface_cx.route(&pkt.ip_repr().dst_addr(), iface_cx.now()) { + Some(IpAddress::Ipv4(next_hop_ip)) => next_hop_ip, + None => return Err(None), + }; + + // Resolve the next-hop Ethernet address. + let next_hop_ether = if next_hop_ip.is_broadcast() { + EthernetAddress::BROADCAST + } else if let Some(next_hop_ether) = self.arp_table.lock().get(&next_hop_ip) { + *next_hop_ether + } else { + // If the next-hop Ethernet address cannot be resolved, we drop the original packet and + // send an ARP packet instead. The upper layer should be responsible for detecting the + // packet loss and retrying later to see if the Ethernet address is ready. + return Err(Some(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: self.ether_addr, + source_protocol_addr: iface_cx.ipv4_addr().unwrap_or(Ipv4Address::UNSPECIFIED), + target_hardware_addr: EthernetAddress::BROADCAST, + target_protocol_addr: next_hop_ip, + })); + }; + + Ok(EthernetRepr { + src_addr: self.ether_addr, + dst_addr: next_hop_ether, + ethertype: EthernetProtocol::Ipv4, + }) + } + + /// Consumes the token and emits an IP packet. 
+ fn emit_ip( + ether_repr: &EthernetRepr, + ip_pkt: &Packet, + caps: &DeviceCapabilities, + tx_token: T, + ) { + tx_token.consume( + ether_repr.buffer_len() + ip_pkt.ip_repr().buffer_len(), + |buffer| { + let mut frame = EthernetFrame::new_unchecked(buffer); + ether_repr.emit(&mut frame); + + let ip_repr = ip_pkt.ip_repr(); + ip_repr.emit(frame.payload_mut(), &caps.checksum); + ip_pkt.emit_payload( + &ip_repr, + &mut frame.payload_mut()[ip_repr.header_len()..], + caps, + ); + }, + ); + } + + /// Consumes the token and emits an ARP packet. + fn emit_arp(arp_repr: &ArpRepr, tx_token: T) { + let ether_repr = match arp_repr { + ArpRepr::EthernetIpv4 { + source_hardware_addr, + target_hardware_addr, + .. + } => EthernetRepr { + src_addr: *source_hardware_addr, + dst_addr: *target_hardware_addr, + ethertype: EthernetProtocol::Arp, + }, + _ => return, + }; + + tx_token.consume(ether_repr.buffer_len() + arp_repr.buffer_len(), |buffer| { + let mut frame = EthernetFrame::new_unchecked(buffer); + ether_repr.emit(&mut frame); + + let mut pkt = ArpPacket::new_unchecked(frame.payload_mut()); + arp_repr.emit(&mut pkt); + }); + } +} diff --git a/kernel/libs/aster-bigtcp/src/iface/phy/ip.rs b/kernel/libs/aster-bigtcp/src/iface/phy/ip.rs index 7ec2f1fb2..ef99ac102 100644 --- a/kernel/libs/aster-bigtcp/src/iface/phy/ip.rs +++ b/kernel/libs/aster-bigtcp/src/iface/phy/ip.rs @@ -4,7 +4,8 @@ use alloc::sync::Arc; use smoltcp::{ iface::Config, - wire::{self, Ipv4Cidr}, + phy::TxToken, + wire::{self, Ipv4Cidr, Ipv4Packet}, }; use crate::{ @@ -14,7 +15,7 @@ use crate::{ }, }; -pub struct IpIface { +pub struct IpIface { driver: D, common: IfaceCommon, } @@ -39,16 +40,30 @@ impl IpIface { } } -impl IfaceInternal for IpIface { +impl IfaceInternal for IpIface { fn common(&self) -> &IfaceCommon { &self.common } } -impl Iface for IpIface { +impl Iface for IpIface { fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)) { self.driver.with(|device| { - let next_poll = self.common.poll(device); + 
let next_poll = self.common.poll( + device, + |data, _iface_cx, tx_token| Some((Ipv4Packet::new_checked(data).ok()?, tx_token)), + |pkt, iface_cx, tx_token| { + let ip_repr = pkt.ip_repr(); + tx_token.consume(ip_repr.buffer_len(), |buffer| { + ip_repr.emit(&mut buffer[..], &iface_cx.checksum_caps()); + pkt.emit_payload( + &ip_repr, + &mut buffer[ip_repr.header_len()..], + &iface_cx.caps, + ); + }); + }, + ); schedule_next_poll(next_poll); }); } diff --git a/kernel/libs/aster-bigtcp/src/iface/poll.rs b/kernel/libs/aster-bigtcp/src/iface/poll.rs new file mode 100644 index 000000000..e189b7f97 --- /dev/null +++ b/kernel/libs/aster-bigtcp/src/iface/poll.rs @@ -0,0 +1,476 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::{collections::btree_set::BTreeSet, vec}; + +use keyable_arc::KeyableArc; +use smoltcp::{ + iface::{ + packet::{icmp_reply_payload_len, IpPayload, Packet}, + Context, + }, + phy::{ChecksumCapabilities, Device, RxToken, TxToken}, + wire::{ + Icmpv4DstUnreachable, Icmpv4Repr, IpAddress, IpProtocol, IpRepr, Ipv4Address, Ipv4Packet, + Ipv4Repr, TcpControl, TcpPacket, TcpRepr, UdpPacket, UdpRepr, IPV4_HEADER_LEN, + IPV4_MIN_MTU, + }, +}; + +use crate::socket::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; + +pub(super) struct PollContext<'a, E> { + iface_cx: &'a mut Context, + tcp_sockets: &'a BTreeSet>>, + udp_sockets: &'a BTreeSet>>, +} + +impl<'a, E> PollContext<'a, E> { + #[allow(clippy::mutable_key_type)] + pub(super) fn new( + iface_cx: &'a mut Context, + tcp_sockets: &'a BTreeSet>>, + udp_sockets: &'a BTreeSet>>, + ) -> Self { + Self { + iface_cx, + tcp_sockets, + udp_sockets, + } + } +} + +// This works around . +// See the issue above for details. 
+pub(super) trait FnHelper<A, B, C, O>: FnMut(A, B, C) -> O {} +impl<A, B, C, O, F> FnHelper<A, B, C, O> for F where F: FnMut(A, B, C) -> O {} + +impl<'a, E> PollContext<'a, E> { + pub(super) fn poll_ingress<D, P, Q>( + &mut self, + device: &mut D, + mut process_phy: P, + dispatch_phy: &mut Q, + ) where + D: Device + ?Sized, + P: for<'pkt, 'cx, 'tx> FnHelper< + &'pkt [u8], + &'cx mut Context, + D::TxToken<'tx>, + Option<(Ipv4Packet<&'pkt [u8]>, D::TxToken<'tx>)>, + >, + Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), + { + while let Some((rx_token, tx_token)) = device.receive(self.iface_cx.now()) { + rx_token.consume(|data| { + let Some((pkt, tx_token)) = process_phy(data, self.iface_cx, tx_token) else { + return; + }; + + let Some(reply) = self.parse_and_process_ipv4(pkt) else { + return; + }; + + dispatch_phy(&reply, self.iface_cx, tx_token); + }); + } + } + + fn parse_and_process_ipv4<'pkt>( + &mut self, + pkt: Ipv4Packet<&'pkt [u8]>, + ) -> Option<Packet<'pkt>> { + // Parse the IP header. Ignore the packet if the header is ill-formed. + let repr = Ipv4Repr::parse(&pkt, &self.iface_cx.checksum_caps()).ok()?; + + if !repr.dst_addr.is_broadcast() && !self.is_unicast_local(IpAddress::Ipv4(repr.dst_addr)) { + return self.generate_icmp_unreachable( + &IpRepr::Ipv4(repr), + pkt.payload(), + Icmpv4DstUnreachable::HostUnreachable, + ); + } + + match repr.next_header { + IpProtocol::Tcp => self.parse_and_process_tcp( + &IpRepr::Ipv4(repr), + pkt.payload(), + &self.iface_cx.checksum_caps(), + ), + IpProtocol::Udp => self.parse_and_process_udp( + &IpRepr::Ipv4(repr), + pkt.payload(), + &self.iface_cx.checksum_caps(), + ), + _ => None, + } + } + + fn parse_and_process_tcp<'pkt>( + &mut self, + ip_repr: &IpRepr, + ip_payload: &'pkt [u8], + checksum_caps: &ChecksumCapabilities, + ) -> Option<Packet<'pkt>> { + // TCP connections can only be established between unicast addresses. Ignore the packet if + // this is not the case. See + // . 
+ if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() { + return None; + } + + // Parse the TCP header. Ignore the packet if the header is ill-formed. + let tcp_pkt = TcpPacket::new_checked(ip_payload).ok()?; + let tcp_repr = TcpRepr::parse( + &tcp_pkt, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + checksum_caps, + ) + .ok()?; + + self.process_tcp_until_outgoing(ip_repr, &tcp_repr) + .map(|(ip_repr, tcp_repr)| Packet::new(ip_repr, IpPayload::Tcp(tcp_repr))) + } + + fn process_tcp_until_outgoing( + &mut self, + ip_repr: &IpRepr, + tcp_repr: &TcpRepr, + ) -> Option<(IpRepr, TcpRepr<'static>)> { + let (mut ip_repr, mut tcp_repr) = self.process_tcp(ip_repr, tcp_repr)?; + + loop { + if !self.is_unicast_local(ip_repr.dst_addr()) { + return Some((ip_repr, tcp_repr)); + } + + let (new_ip_repr, new_tcp_repr) = self.process_tcp(&ip_repr, &tcp_repr)?; + ip_repr = new_ip_repr; + tcp_repr = new_tcp_repr; + } + } + + fn process_tcp( + &mut self, + ip_repr: &IpRepr, + tcp_repr: &TcpRepr, + ) -> Option<(IpRepr, TcpRepr<'static>)> { + for socket in self.tcp_sockets.iter() { + if !socket.can_process(tcp_repr.dst_port) { + continue; + } + + match socket.process(self.iface_cx, ip_repr, tcp_repr) { + TcpProcessResult::NotProcessed => continue, + TcpProcessResult::Processed => return None, + TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr) => { + return Some((ip_repr, tcp_repr)) + } + } + } + + // "In no case does receipt of a segment containing RST give rise to a RST in response." + // See . + if tcp_repr.control == TcpControl::Rst { + return None; + } + + Some(smoltcp::socket::tcp::Socket::rst_reply(ip_repr, tcp_repr)) + } + + fn parse_and_process_udp<'pkt>( + &mut self, + ip_repr: &IpRepr, + ip_payload: &'pkt [u8], + checksum_caps: &ChecksumCapabilities, + ) -> Option> { + // Parse the UDP header. Ignore the packet if the header is ill-formed. 
+ let udp_pkt = UdpPacket::new_checked(ip_payload).ok()?; + let udp_repr = UdpRepr::parse( + &udp_pkt, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + checksum_caps, + ) + .ok()?; + + if !self.process_udp(ip_repr, &udp_repr, udp_pkt.payload()) { + return self.generate_icmp_unreachable( + ip_repr, + ip_payload, + Icmpv4DstUnreachable::PortUnreachable, + ); + } + + None + } + + fn process_udp(&mut self, ip_repr: &IpRepr, udp_repr: &UdpRepr, udp_payload: &[u8]) -> bool { + let mut processed = false; + + for socket in self.udp_sockets.iter() { + if !socket.can_process(udp_repr.dst_port) { + continue; + } + + processed |= socket.process(self.iface_cx, ip_repr, udp_repr, udp_payload); + if processed && ip_repr.dst_addr().is_unicast() { + break; + } + } + + processed + } + + fn generate_icmp_unreachable<'pkt>( + &self, + ip_repr: &IpRepr, + ip_payload: &'pkt [u8], + reason: Icmpv4DstUnreachable, + ) -> Option> { + if !ip_repr.src_addr().is_unicast() || !ip_repr.dst_addr().is_unicast() { + return None; + } + + if self.is_unicast_local(ip_repr.src_addr()) { + // In this case, the generating ICMP message will have a local IP address as the + // destination. However, since we don't have the ability to handle ICMP messages, we'll + // just skip the generation. + // + // TODO: Generate the ICMP message here once we're able to handle incoming ICMP + // messages. 
+ return None; + } + + let IpRepr::Ipv4(ipv4_repr) = ip_repr; + + let reply_len = icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, IPV4_HEADER_LEN); + let icmp_repr = Icmpv4Repr::DstUnreachable { + reason, + header: *ipv4_repr, + data: &ip_payload[..reply_len], + }; + + Some(Packet::new_ipv4( + Ipv4Repr { + src_addr: self + .iface_cx + .ipv4_addr() + .unwrap_or(Ipv4Address::UNSPECIFIED), + dst_addr: ipv4_repr.src_addr, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }, + IpPayload::Icmpv4(icmp_repr), + )) + } + + /// Returns whether the destination address is the unicast address of a local interface. + /// + /// Note: "local" means that the IP address belongs to the local interface, not to be confused + /// with the localhost IP (127.0.0.1). + fn is_unicast_local(&self, dst_addr: IpAddress) -> bool { + match dst_addr { + IpAddress::Ipv4(dst_addr) => self + .iface_cx + .ipv4_addr() + .is_some_and(|addr| addr == dst_addr), + } + } +} + +impl<'a, E> PollContext<'a, E> { + pub(super) fn poll_egress(&mut self, device: &mut D, mut dispatch_phy: Q) + where + D: Device + ?Sized, + Q: FnMut(&Packet, &mut Context, D::TxToken<'_>), + { + while let Some(tx_token) = device.transmit(self.iface_cx.now()) { + if !self.dispatch_ipv4(tx_token, &mut dispatch_phy) { + break; + } + } + } + + fn dispatch_ipv4(&mut self, tx_token: T, dispatch_phy: &mut Q) -> bool + where + T: TxToken, + Q: FnMut(&Packet, &mut Context, T), + { + let (did_something_tcp, tx_token) = self.dispatch_tcp(tx_token, dispatch_phy); + + let Some(tx_token) = tx_token else { + return did_something_tcp; + }; + + let (did_something_udp, _tx_token) = self.dispatch_udp(tx_token, dispatch_phy); + + did_something_tcp || did_something_udp + } + + fn dispatch_tcp(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option) + where + T: TxToken, + Q: FnMut(&Packet, &mut Context, T), + { + let mut tx_token = Some(tx_token); + let mut did_something = false; + + for socket in 
self.tcp_sockets.iter() { + if !socket.need_dispatch(self.iface_cx.now()) { + continue; + } + + // We set `did_something` even if no packets are actually generated. This is because a + // timer can expire, but no packets are actually generated. + did_something = true; + + let mut deferred = None; + + let reply = socket.dispatch(self.iface_cx, |cx, ip_repr, tcp_repr| { + let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets); + + if !this.is_unicast_local(ip_repr.dst_addr()) { + dispatch_phy( + &Packet::new(ip_repr.clone(), IpPayload::Tcp(*tcp_repr)), + this.iface_cx, + tx_token.take().unwrap(), + ); + return None; + } + + if !socket.can_process(tcp_repr.dst_port) { + return this.process_tcp(ip_repr, tcp_repr); + } + + // We cannot call `process_tcp` now because it may cause deadlocks. We will copy + // the packet and call `process_tcp` after releasing the socket lock. + deferred = Some((ip_repr.clone(), { + let mut data = vec![0; tcp_repr.buffer_len()]; + tcp_repr.emit( + &mut TcpPacket::new_unchecked(data.as_mut_slice()), + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + &ChecksumCapabilities::ignored(), + ); + data + })); + + None + }); + + match (deferred, reply) { + (None, None) => (), + (Some((ip_repr, ip_payload)), None) => { + if let Some(reply) = self.parse_and_process_tcp( + &ip_repr, + &ip_payload, + &ChecksumCapabilities::ignored(), + ) { + dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap()); + } + } + (None, Some((ip_repr, tcp_repr))) if !self.is_unicast_local(ip_repr.dst_addr()) => { + dispatch_phy( + &Packet::new(ip_repr, IpPayload::Tcp(tcp_repr)), + self.iface_cx, + tx_token.take().unwrap(), + ); + } + (None, Some((ip_repr, tcp_repr))) => { + if let Some((new_ip_repr, new_tcp_repr)) = + self.process_tcp_until_outgoing(&ip_repr, &tcp_repr) + { + dispatch_phy( + &Packet::new(new_ip_repr, IpPayload::Tcp(new_tcp_repr)), + self.iface_cx, + tx_token.take().unwrap(), + ); + } + } + (Some(_), Some(_)) => unreachable!(), + } + + if 
tx_token.is_none() { + break; + } + } + + (did_something, tx_token) + } + + fn dispatch_udp(&mut self, tx_token: T, dispatch_phy: &mut Q) -> (bool, Option) + where + T: TxToken, + Q: FnMut(&Packet, &mut Context, T), + { + let mut tx_token = Some(tx_token); + let mut did_something = false; + + for socket in self.udp_sockets.iter() { + if !socket.need_dispatch(self.iface_cx.now()) { + continue; + } + + // We set `did_something` even if no packets are actually generated. This is because a + // timer can expire, but no packets are actually generated. + did_something = true; + + let mut deferred = None; + + socket.dispatch(self.iface_cx, |cx, ip_repr, udp_repr, udp_payload| { + let mut this = PollContext::new(cx, self.tcp_sockets, self.udp_sockets); + + if ip_repr.dst_addr().is_broadcast() || !this.is_unicast_local(ip_repr.dst_addr()) { + dispatch_phy( + &Packet::new(ip_repr.clone(), IpPayload::Udp(*udp_repr, udp_payload)), + this.iface_cx, + tx_token.take().unwrap(), + ); + if !ip_repr.dst_addr().is_broadcast() { + return; + } + } + + if !socket.can_process(udp_repr.dst_port) { + // TODO: Generate the ICMP message here once we're able to handle incoming ICMP + // messages. + let _ = this.process_udp(ip_repr, udp_repr, udp_payload); + return; + } + + // We cannot call `process_udp` now because it may cause deadlocks. We will copy + // the packet and call `process_udp` after releasing the socket lock. 
+ deferred = Some((ip_repr.clone(), { + let mut data = vec![0; udp_repr.header_len() + udp_payload.len()]; + udp_repr.emit( + &mut UdpPacket::new_unchecked(&mut data), + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + udp_payload.len(), + |payload| payload.copy_from_slice(udp_payload), + &ChecksumCapabilities::ignored(), + ); + data + })); + }); + + if let Some((ip_repr, ip_payload)) = deferred { + if let Some(reply) = self.parse_and_process_udp( + &ip_repr, + &ip_payload, + &ChecksumCapabilities::ignored(), + ) { + dispatch_phy(&reply, self.iface_cx, tx_token.take().unwrap()); + } + } + + if tx_token.is_none() { + break; + } + } + + (did_something, tx_token) + } +} diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 8ca0f53e6..775d48520 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -1,44 +1,188 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::sync::{Arc, Weak}; +use alloc::{ + boxed::Box, + sync::{Arc, Weak}, +}; +use core::{ + ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, +}; -use ostd::sync::RwLock; +use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard}; use smoltcp::{ - socket::tcp::ConnectError, - wire::{IpAddress, IpEndpoint}, + iface::Context, + socket::{udp::UdpMetadata, PollAt}, + time::Instant, + wire::{IpAddress, IpEndpoint, IpRepr, TcpRepr, UdpRepr}, }; use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket}; use crate::iface::Iface; -pub(crate) enum SocketFamily { - Tcp, - Udp, +pub struct BoundSocket(Arc>); + +/// [`TcpSocket`] or [`UdpSocket`]. +pub trait AnySocket { + type RawSocket; + + /// Called by [`BoundSocket::new`]. + fn new(socket: Box) -> Self; + + /// Called by [`BoundSocket::drop`]. 
+ fn on_drop(this: &Arc>) + where + Self: Sized; } -pub struct AnyBoundSocket(Arc>); +pub type BoundTcpSocket = BoundSocket; +pub type BoundUdpSocket = BoundSocket; -impl AnyBoundSocket { +/// Common states shared by [`BoundTcpSocketInner`] and [`BoundUdpSocketInner`]. +pub struct BoundSocketInner { + iface: Arc>, + port: u16, + socket: T, + observer: RwLock>, + next_poll_at_ms: AtomicU64, + has_new_events: AtomicBool, +} + +/// States needed by [`BoundTcpSocketInner`] but not [`BoundUdpSocketInner`]. +pub struct TcpSocket { + socket: SpinLock, + is_dead: AtomicBool, +} + +struct RawTcpSocketExt { + socket: Box, + /// Whether the socket is in the background. + /// + /// A background socket is a socket with its corresponding [`BoundSocket`] dropped. This means + /// that no more user events (like `send`/`recv`) can reach the socket, but it can be in a + /// state of waiting for certain network events (e.g., remote FIN/ACK packets), so + /// [`BoundSocketInner`] may still be alive for a while. + in_background: bool, +} + +impl Deref for RawTcpSocketExt { + type Target = RawTcpSocket; + + fn deref(&self) -> &Self::Target { + &self.socket + } +} + +impl DerefMut for RawTcpSocketExt { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.socket + } +} + +impl TcpSocket { + fn lock(&self) -> SpinLockGuard { + self.socket.lock() + } + + /// Returns whether the TCP socket is dead. + /// + /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. + fn is_dead(&self) -> bool { + self.is_dead.load(Ordering::Relaxed) + } + + /// Updates whether the TCP socket is dead. + /// + /// See [`BoundTcpSocketInner::is_dead`] for the definition of dead TCP sockets. + /// + /// This method must be called after handling network events. However, it is not necessary to + /// call this method after handling non-closing user events, because the socket can never be + /// dead if user events can reach the socket. 
+ fn update_dead(&self, socket: &RawTcpSocketExt) { + if socket.in_background && socket.state() == smoltcp::socket::tcp::State::Closed { + self.is_dead.store(true, Ordering::Relaxed); + } + } +} + +impl AnySocket for TcpSocket { + type RawSocket = RawTcpSocket; + + fn new(socket: Box) -> Self { + let socket_ext = RawTcpSocketExt { + socket, + in_background: false, + }; + + Self { + socket: SpinLock::new(socket_ext), + is_dead: AtomicBool::new(false), + } + } + + fn on_drop(this: &Arc>) { + let mut socket = this.socket.lock(); + + socket.in_background = true; + socket.close(); + + // A TCP socket may not be appropriate for immediate removal. We leave the removal decision + // to the polling logic. + this.update_next_poll_at_ms(PollAt::Now); + this.socket.update_dead(&socket); + } +} + +/// States needed by [`BoundUdpSocketInner`] but not [`BoundTcpSocketInner`]. +type UdpSocket = SpinLock, LocalIrqDisabled>; + +impl AnySocket for UdpSocket { + type RawSocket = RawUdpSocket; + + fn new(socket: Box) -> Self { + Self::new(socket) + } + + fn on_drop(this: &Arc>) { + this.socket.lock().close(); + + // A UDP socket can be removed immediately. 
+ this.iface.common().remove_udp_socket(this); + } +} + +impl Drop for BoundSocket { + fn drop(&mut self) { + T::on_drop(&self.0); + } +} + +pub(crate) type BoundTcpSocketInner = BoundSocketInner; +pub(crate) type BoundUdpSocketInner = BoundSocketInner; + +impl BoundSocket { pub(crate) fn new( iface: Arc>, - handle: smoltcp::iface::SocketHandle, port: u16, - socket_family: SocketFamily, + socket: Box, observer: Weak, ) -> Self { - Self(Arc::new(AnyBoundSocketInner { + Self(Arc::new(BoundSocketInner { iface, - handle, port, - socket_family, + socket: T::new(socket), observer: RwLock::new(observer), + next_poll_at_ms: AtomicU64::new(u64::MAX), + has_new_events: AtomicBool::new(false), })) } - pub(crate) fn inner(&self) -> &Arc> { + pub(crate) fn inner(&self) -> &Arc> { &self.0 } +} +impl BoundSocket { /// Sets the observer whose `on_events` will be called when certain iface events happen. After /// setting, the new observer will fire once immediately to avoid missing any events. /// @@ -51,6 +195,15 @@ impl AnyBoundSocket { self.0.on_iface_events(); } + /// Returns the observer. + /// + /// See also [`Self::set_observer`]. + pub fn observer(&self) -> Weak { + // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we + // get the read lock. + self.0.observer.read().clone() + } + pub fn local_endpoint(&self) -> Option { let ip_addr = { let ipv4_addr = self.0.iface.ipv4_addr()?; @@ -59,58 +212,155 @@ impl AnyBoundSocket { Some(IpEndpoint::new(ip_addr, self.0.port)) } - pub fn raw_with, R, F: FnMut(&mut T) -> R>( - &self, - f: F, - ) -> R { - self.0.raw_with(f) - } - - /// Connects to a remote endpoint. - /// - /// # Panics - /// - /// This method will panic if the socket is not a TCP socket. 
- pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<(), ConnectError> { - let common = self.iface().common(); - - let mut sockets = common.sockets(); - let socket = sockets.get_mut::(self.0.handle); - - let mut iface = common.interface(); - let cx = iface.context(); - - socket.connect(cx, remote_endpoint, self.0.port) - } - pub fn iface(&self) -> &Arc> { &self.0.iface } } -impl Drop for AnyBoundSocket { - fn drop(&mut self) { - if self.0.start_closing() { - self.0.iface.common().remove_bound_socket_now(&self.0); - } else { - self.0 - .iface - .common() - .remove_bound_socket_when_closed(&self.0); - } +impl BoundTcpSocket { + /// Connects to a remote endpoint. + pub fn connect( + &self, + remote_endpoint: IpEndpoint, + ) -> Result<(), smoltcp::socket::tcp::ConnectError> { + let common = self.iface().common(); + let mut iface = common.interface(); + + let mut socket = self.0.socket.lock(); + + let result = socket.connect(iface.context(), remote_endpoint, self.0.port); + self.0 + .update_next_poll_at_ms(socket.poll_at(iface.context())); + + result + } + + /// Listens at a specified endpoint. + pub fn listen( + &self, + local_endpoint: IpEndpoint, + ) -> Result<(), smoltcp::socket::tcp::ListenError> { + let mut socket = self.0.socket.lock(); + + socket.listen(local_endpoint) + } + + pub fn send(&self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + let mut socket = self.0.socket.lock(); + + let result = socket.send(f); + self.0.update_next_poll_at_ms(PollAt::Now); + + result + } + + pub fn recv(&self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + let mut socket = self.0.socket.lock(); + + let result = socket.recv(f); + self.0.update_next_poll_at_ms(PollAt::Now); + + result + } + + pub fn close(&self) { + let mut socket = self.0.socket.lock(); + + socket.close(); + self.0.update_next_poll_at_ms(PollAt::Now); + } + + /// Calls `f` with an immutable reference to the associated [`RawTcpSocket`]. 
+ // + // NOTE: If a mutable reference is required, add a method above that correctly updates the next + // polling time. + pub fn raw_with(&self, f: F) -> R + where + F: FnOnce(&RawTcpSocket) -> R, + { + let socket = self.0.socket.lock(); + f(&socket) } } -pub(crate) struct AnyBoundSocketInner { - iface: Arc>, - handle: smoltcp::iface::SocketHandle, - port: u16, - socket_family: SocketFamily, - observer: RwLock>, +impl BoundUdpSocket { + /// Binds to a specified endpoint. + pub fn bind(&self, local_endpoint: IpEndpoint) -> Result<(), smoltcp::socket::udp::BindError> { + let mut socket = self.0.socket.lock(); + + socket.bind(local_endpoint) + } + + pub fn send( + &self, + size: usize, + meta: impl Into, + f: F, + ) -> Result + where + F: FnOnce(&mut [u8]) -> R, + { + use smoltcp::socket::udp::SendError as SendErrorInner; + + use crate::errors::udp::SendError; + + let mut socket = self.0.socket.lock(); + + if size > socket.packet_send_capacity() { + return Err(SendError::TooLarge); + } + + let buffer = match socket.send(size, meta) { + Ok(data) => data, + Err(SendErrorInner::Unaddressable) => return Err(SendError::Unaddressable), + Err(SendErrorInner::BufferFull) => return Err(SendError::BufferFull), + }; + let result = f(buffer); + self.0.update_next_poll_at_ms(PollAt::Now); + + Ok(result) + } + + pub fn recv(&self, f: F) -> Result + where + F: FnOnce(&[u8], UdpMetadata) -> R, + { + let mut socket = self.0.socket.lock(); + + let (data, meta) = socket.recv()?; + let result = f(data, meta); + self.0.update_next_poll_at_ms(PollAt::Now); + + Ok(result) + } + + /// Calls `f` with an immutable reference to the associated [`RawUdpSocket`]. + // + // NOTE: If a mutable reference is required, add a method above that correctly updates the next + // polling time. 
+ pub fn raw_with(&self, f: F) -> R + where + F: FnOnce(&RawUdpSocket) -> R, + { + let socket = self.0.socket.lock(); + f(&socket) + } } -impl AnyBoundSocketInner { +impl BoundSocketInner { + pub(crate) fn has_new_events(&self) -> bool { + self.has_new_events.load(Ordering::Relaxed) + } + pub(crate) fn on_iface_events(&self) { + self.has_new_events.store(false, Ordering::Relaxed); + // We never hold the write lock in IRQ handlers, so we don't need to disable IRQs when we // get the read lock. let observer = Weak::upgrade(&*self.observer.read()); @@ -120,51 +370,173 @@ impl AnyBoundSocketInner { } } - pub(crate) fn is_closed(&self) -> bool { - match self.socket_family { - SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| { - socket.state() == smoltcp::socket::tcp::State::Closed - }), - SocketFamily::Udp => true, - } + /// Returns the next polling time. + /// + /// Note: a zero means polling should be done now and a `u64::MAX` means no polling is required + /// before new network or user events. + pub(crate) fn next_poll_at_ms(&self) -> u64 { + self.next_poll_at_ms.load(Ordering::Relaxed) } - /// Starts closing the socket and returns whether the socket is closed. + /// Updates the next polling time according to `poll_at`. /// - /// For sockets that can be closed immediately, such as UDP sockets and TCP listening sockets, - /// this method will always return `true`. - /// - /// For other sockets, such as TCP connected sockets, they cannot be closed immediately because - /// we at least need to send the FIN packet and wait for the remote end to send an ACK packet. - /// In this case, this method will return `false` and [`Self::is_closed`] can be used to - /// determine if the closing process is complete. 
- fn start_closing(&self) -> bool { - match self.socket_family { - SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| { - socket.close(); - socket.state() == smoltcp::socket::tcp::State::Closed - }), - SocketFamily::Udp => { - self.raw_with(|socket: &mut RawUdpSocket| socket.close()); - true - } + /// The update is typically needed after new network or user events have been handled, so this + /// method also marks that there may be new events, so that the event observer provided by + /// [`BoundSocket::set_observer`] can be notified later. + fn update_next_poll_at_ms(&self, poll_at: PollAt) { + self.has_new_events.store(true, Ordering::Relaxed); + + match poll_at { + PollAt::Now => self.next_poll_at_ms.store(0, Ordering::Relaxed), + PollAt::Time(instant) => self + .next_poll_at_ms + .store(instant.total_millis() as u64, Ordering::Relaxed), + PollAt::Ingress => self.next_poll_at_ms.store(u64::MAX, Ordering::Relaxed), } } +} - pub fn raw_with, R, F: FnMut(&mut T) -> R>( +impl BoundSocketInner { + pub(crate) fn port(&self) -> u16 { + self.port + } +} + +impl BoundTcpSocketInner { + /// Returns whether the TCP socket is dead. + /// + /// A TCP socket is considered dead if and only if the following two conditions are met: + /// 1. The TCP connection is closed, so this socket cannot process any network events. + /// 2. The socket handle [`BoundTcpSocket`] is dropped, which means that this + /// [`BoundSocketInner`] is in background and no more user events can reach it. + pub(crate) fn is_dead(&self) -> bool { + self.socket.is_dead() + } +} + +impl BoundSocketInner { + /// Returns whether an incoming packet _may_ be processed by the socket. + /// + /// The check is intended to be lock-free and fast, but may have false positives. + pub(crate) fn can_process(&self, dst_port: u16) -> bool { + self.port == dst_port + } + + /// Returns whether the socket _may_ generate an outgoing packet. 
+ /// + /// The check is intended to be lock-free and fast, but may have false positives. + pub(crate) fn need_dispatch(&self, now: Instant) -> bool { + now.total_millis() as u64 >= self.next_poll_at_ms.load(Ordering::Relaxed) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) enum TcpProcessResult { + NotProcessed, + Processed, + ProcessedWithReply(IpRepr, TcpRepr<'static>), +} + +impl BoundTcpSocketInner { + /// Tries to process an incoming packet and returns whether the packet is processed. + pub(crate) fn process( &self, - mut f: F, - ) -> R { - let mut sockets = self.iface.common().sockets(); - let socket = sockets.get_mut::(self.handle); - f(socket) + cx: &mut Context, + ip_repr: &IpRepr, + tcp_repr: &TcpRepr, + ) -> TcpProcessResult { + let mut socket = self.socket.lock(); + + if !socket.accepts(cx, ip_repr, tcp_repr) { + return TcpProcessResult::NotProcessed; + } + + let result = match socket.process(cx, ip_repr, tcp_repr) { + None => TcpProcessResult::Processed, + Some((ip_repr, tcp_repr)) => TcpProcessResult::ProcessedWithReply(ip_repr, tcp_repr), + }; + + self.update_next_poll_at_ms(socket.poll_at(cx)); + self.socket.update_dead(&socket); + + result + } + + /// Tries to generate an outgoing packet and dispatches the generated packet. + pub(crate) fn dispatch( + &self, + cx: &mut Context, + dispatch: D, + ) -> Option<(IpRepr, TcpRepr<'static>)> + where + D: FnOnce(&mut Context, &IpRepr, &TcpRepr) -> Option<(IpRepr, TcpRepr<'static>)>, + { + let mut socket = self.socket.lock(); + + let mut reply = None; + socket + .dispatch(cx, |cx, (ip_repr, tcp_repr)| { + reply = dispatch(cx, &ip_repr, &tcp_repr); + Ok::<(), ()>(()) + }) + .unwrap(); + + // `dispatch` can return a packet in response to the generated packet. If the socket + // accepts the packet, we can process it directly. 
+ while let Some((ref ip_repr, ref tcp_repr)) = reply { + if !socket.accepts(cx, ip_repr, tcp_repr) { + break; + } + reply = socket.process(cx, ip_repr, tcp_repr); + } + + self.update_next_poll_at_ms(socket.poll_at(cx)); + self.socket.update_dead(&socket); + + reply } } -impl Drop for AnyBoundSocketInner { - fn drop(&mut self) { - let iface_common = self.iface.common(); - iface_common.remove_socket(self.handle); - iface_common.release_port(self.port); +impl BoundUdpSocketInner { + /// Tries to process an incoming packet and returns whether the packet is processed. + pub(crate) fn process( + &self, + cx: &mut Context, + ip_repr: &IpRepr, + udp_repr: &UdpRepr, + udp_payload: &[u8], + ) -> bool { + let mut socket = self.socket.lock(); + + if !socket.accepts(cx, ip_repr, udp_repr) { + return false; + } + + socket.process( + cx, + smoltcp::phy::PacketMeta::default(), + ip_repr, + udp_repr, + udp_payload, + ); + self.update_next_poll_at_ms(socket.poll_at(cx)); + + true + } + + /// Tries to generate an outgoing packet and dispatches the generated packet. 
+ pub(crate) fn dispatch(&self, cx: &mut Context, dispatch: D) + where + D: FnOnce(&mut Context, &IpRepr, &UdpRepr, &[u8]), + { + let mut socket = self.socket.lock(); + + socket + .dispatch(cx, |cx, _meta, (ip_repr, udp_repr, udp_payload)| { + dispatch(cx, &ip_repr, &udp_repr, udp_payload); + Ok::<(), ()>(()) + }) + .unwrap(); + self.update_next_poll_at_ms(socket.poll_at(cx)); } } diff --git a/kernel/libs/aster-bigtcp/src/socket/mod.rs b/kernel/libs/aster-bigtcp/src/socket/mod.rs index 4c5929a32..a5b9de697 100644 --- a/kernel/libs/aster-bigtcp/src/socket/mod.rs +++ b/kernel/libs/aster-bigtcp/src/socket/mod.rs @@ -4,12 +4,11 @@ mod bound; mod event; mod unbound; -pub use bound::AnyBoundSocket; -pub(crate) use bound::{AnyBoundSocketInner, SocketFamily}; +pub use bound::{BoundTcpSocket, BoundUdpSocket}; +pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; pub use event::SocketEventObserver; -pub(crate) use unbound::AnyRawSocket; pub use unbound::{ - AnyUnboundSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, + UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN, }; diff --git a/kernel/libs/aster-bigtcp/src/socket/unbound.rs b/kernel/libs/aster-bigtcp/src/socket/unbound.rs index 9fbba4e08..97de13c6e 100644 --- a/kernel/libs/aster-bigtcp/src/socket/unbound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/unbound.rs @@ -1,34 +1,33 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::{sync::Weak, vec}; +use alloc::{boxed::Box, sync::Weak, vec}; use super::{event::SocketEventObserver, RawTcpSocket, RawUdpSocket}; -pub struct AnyUnboundSocket { - socket_family: AnyRawSocket, +pub struct UnboundSocket { + socket: Box, observer: Weak, } -#[allow(clippy::large_enum_variant)] -pub(crate) enum AnyRawSocket { - Tcp(RawTcpSocket), - Udp(RawUdpSocket), -} +pub type UnboundTcpSocket = UnboundSocket; +pub type UnboundUdpSocket = UnboundSocket; -impl AnyUnboundSocket { - pub fn 
new_tcp(observer: Weak) -> Self { +impl UnboundTcpSocket { + pub fn new(observer: Weak) -> Self { let raw_tcp_socket = { let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_RECV_BUF_LEN]); let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; TCP_SEND_BUF_LEN]); RawTcpSocket::new(rx_buffer, tx_buffer) }; - AnyUnboundSocket { - socket_family: AnyRawSocket::Tcp(raw_tcp_socket), + Self { + socket: Box::new(raw_tcp_socket), observer, } } +} - pub fn new_udp(observer: Weak) -> Self { +impl UnboundUdpSocket { + pub fn new(observer: Weak) -> Self { let raw_udp_socket = { let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY; let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( @@ -41,14 +40,16 @@ impl AnyUnboundSocket { ); RawUdpSocket::new(rx_buffer, tx_buffer) }; - AnyUnboundSocket { - socket_family: AnyRawSocket::Udp(raw_udp_socket), + Self { + socket: Box::new(raw_udp_socket), observer, } } +} - pub(crate) fn into_raw(self) -> (AnyRawSocket, Weak) { - (self.socket_family, self.observer) +impl UnboundSocket { + pub(crate) fn into_raw(self) -> (Box, Weak) { + (self.socket, self.observer) } } diff --git a/kernel/src/net/iface/mod.rs b/kernel/src/net/iface/mod.rs index ce5674fea..d7a411d4e 100644 --- a/kernel/src/net/iface/mod.rs +++ b/kernel/src/net/iface/mod.rs @@ -8,4 +8,5 @@ pub use init::{init, IFACES}; pub use poll::{lazy_init, poll_ifaces}; pub type Iface = dyn aster_bigtcp::iface::Iface; -pub type AnyBoundSocket = aster_bigtcp::socket::AnyBoundSocket; +pub type BoundTcpSocket = aster_bigtcp::socket::BoundTcpSocket; +pub type BoundUdpSocket = aster_bigtcp::socket::BoundUdpSocket; diff --git a/kernel/src/net/socket/ip/common.rs b/kernel/src/net/socket/ip/common.rs index 113acea8f..bd081ad20 100644 --- a/kernel/src/net/socket/ip/common.rs +++ b/kernel/src/net/socket/ip/common.rs @@ -3,12 +3,11 @@ use aster_bigtcp::{ errors::BindError, iface::BindPortConfig, - socket::AnyUnboundSocket, wire::{IpAddress, IpEndpoint}, }; use 
crate::{ - net::iface::{AnyBoundSocket, Iface, IFACES}, + net::iface::{Iface, IFACES}, prelude::*, }; @@ -46,11 +45,16 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc { ifaces[0].clone() } -pub(super) fn bind_socket( - unbound_socket: Box, +pub(super) fn bind_socket( + unbound_socket: Box, endpoint: &IpEndpoint, can_reuse: bool, -) -> core::result::Result)> { + bind: impl FnOnce( + Arc, + Box, + BindPortConfig, + ) -> core::result::Result)>, +) -> core::result::Result)> { let iface = match get_iface_to_bind(&endpoint.addr) { Some(iface) => iface, None => { @@ -64,9 +68,7 @@ pub(super) fn bind_socket( let bind_port_config = BindPortConfig::new(endpoint.port, can_reuse); - iface - .bind_socket(unbound_socket, bind_port_config) - .map_err(|(err, unbound)| (err.into(), unbound)) + bind(iface, unbound_socket, bind_port_config).map_err(|(err, unbound)| (err.into(), unbound)) } impl From for Error { diff --git a/kernel/src/net/socket/ip/datagram/bound.rs b/kernel/src/net/socket/ip/datagram/bound.rs index ddd9da673..d4b36f0ee 100644 --- a/kernel/src/net/socket/ip/datagram/bound.rs +++ b/kernel/src/net/socket/ip/datagram/bound.rs @@ -2,25 +2,24 @@ use aster_bigtcp::{ errors::udp::{RecvError, SendError}, - socket::RawUdpSocket, wire::IpEndpoint, }; use crate::{ events::IoEvents, - net::{iface::AnyBoundSocket, socket::util::send_recv_flags::SendRecvFlags}, + net::{iface::BoundUdpSocket, socket::util::send_recv_flags::SendRecvFlags}, prelude::*, process::signal::Pollee, util::{MultiRead, MultiWrite}, }; pub struct BoundDatagram { - bound_socket: AnyBoundSocket, + bound_socket: BoundUdpSocket, remote_endpoint: Option, } impl BoundDatagram { - pub fn new(bound_socket: AnyBoundSocket) -> Self { + pub fn new(bound_socket: BoundUdpSocket) -> Self { Self { bound_socket, remote_endpoint: None, @@ -44,12 +43,10 @@ impl BoundDatagram { writer: &mut dyn MultiWrite, _flags: SendRecvFlags, ) -> Result<(usize, IpEndpoint)> { - let result = self.bound_socket.raw_with(|socket: 
&mut RawUdpSocket| { - socket.recv().map(|(packet, udp_metadata)| { - let copied_res = writer.write(&mut VmReader::from(packet)); - let endpoint = udp_metadata.endpoint; - (copied_res, endpoint) - }) + let result = self.bound_socket.recv(|packet, udp_metadata| { + let copied_res = writer.write(&mut VmReader::from(packet)); + let endpoint = udp_metadata.endpoint; + (copied_res, endpoint) }); match result { @@ -59,7 +56,7 @@ impl BoundDatagram { return_errno_with_message!(Errno::EAGAIN, "the receive buffer is empty") } Err(RecvError::Truncated) => { - unreachable!("`Socket::recv` should never fail with `RecvError::Truncated`") + unreachable!("`recv` should never fail with `RecvError::Truncated`") } } } @@ -70,32 +67,31 @@ impl BoundDatagram { remote: &IpEndpoint, _flags: SendRecvFlags, ) -> Result { - let reader_len = reader.sum_lens(); + let result = self + .bound_socket + .send(reader.sum_lens(), *remote, |socket_buffer| { + // FIXME: If copy failed, we should not send any packet. + // But current smoltcp API seems not to support this behavior. + reader + .read(&mut VmWriter::from(socket_buffer)) + .map_err(|e| { + warn!("unexpected UDP packet will be sent"); + e + }) + }); - self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { - if socket.payload_send_capacity() < reader_len { + match result { + Ok(inner) => inner, + Err(SendError::TooLarge) => { return_errno_with_message!(Errno::EMSGSIZE, "the message is too large"); } - - let socket_buffer = match socket.send(reader_len, *remote) { - Ok(socket_buffer) => socket_buffer, - Err(SendError::BufferFull) => { - return_errno_with_message!(Errno::EAGAIN, "the send buffer is full") - } - Err(SendError::Unaddressable) => { - return_errno_with_message!(Errno::EINVAL, "the destination address is invalid") - } - }; - - // FIXME: If copy failed, we should not send any packet. - // But current smoltcp API seems not to support this behavior. 
- reader - .read(&mut VmWriter::from(socket_buffer)) - .map_err(|e| { - warn!("unexpected UDP packet will be sent"); - e - }) - }) + Err(SendError::Unaddressable) => { + return_errno_with_message!(Errno::EINVAL, "the destination address is invalid"); + } + Err(SendError::BufferFull) => { + return_errno_with_message!(Errno::EAGAIN, "the send buffer is full"); + } + } } pub(super) fn init_pollee(&self, pollee: &Pollee) { @@ -104,7 +100,7 @@ impl BoundDatagram { } pub(super) fn update_io_events(&self, pollee: &Pollee) { - self.bound_socket.raw_with(|socket: &mut RawUdpSocket| { + self.bound_socket.raw_with(|socket| { if socket.can_recv() { pollee.add_events(IoEvents::IN); } else { diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 61bc9dc33..b1989bc08 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -154,13 +154,9 @@ impl DatagramSocket { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound"); }; - let received = - bound_datagram - .try_recv(writer, flags) - .map(|(recv_bytes, remote_endpoint)| { - bound_datagram.update_io_events(&self.pollee); - (recv_bytes, remote_endpoint.into()) - }); + let received = bound_datagram + .try_recv(writer, flags) + .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into())); drop(inner); poll_ifaces(); @@ -192,12 +188,7 @@ impl DatagramSocket { return_errno_with_message!(Errno::EAGAIN, "the socket is not bound") }; - let sent_bytes = bound_datagram - .try_send(reader, remote, flags) - .map(|sent_bytes| { - bound_datagram.update_io_events(&self.pollee); - sent_bytes - }); + let sent_bytes = bound_datagram.try_send(reader, remote, flags); drop(inner); poll_ifaces(); diff --git a/kernel/src/net/socket/ip/datagram/unbound.rs b/kernel/src/net/socket/ip/datagram/unbound.rs index 365233f36..7415a2b2e 100644 --- a/kernel/src/net/socket/ip/datagram/unbound.rs +++ b/kernel/src/net/socket/ip/datagram/unbound.rs 
@@ -3,7 +3,7 @@ use alloc::sync::Weak; use aster_bigtcp::{ - socket::{AnyUnboundSocket, RawUdpSocket, SocketEventObserver}, + socket::{SocketEventObserver, UnboundUdpSocket}, wire::IpEndpoint, }; @@ -13,13 +13,13 @@ use crate::{ }; pub struct UnboundDatagram { - unbound_socket: Box, + unbound_socket: Box, } impl UnboundDatagram { pub fn new(observer: Weak) -> Self { Self { - unbound_socket: Box::new(AnyUnboundSocket::new_udp(observer)), + unbound_socket: Box::new(UnboundUdpSocket::new(observer)), } } @@ -28,15 +28,18 @@ impl UnboundDatagram { endpoint: &IpEndpoint, can_reuse: bool, ) -> core::result::Result { - let bound_socket = match bind_socket(self.unbound_socket, endpoint, can_reuse) { + let bound_socket = match bind_socket( + self.unbound_socket, + endpoint, + can_reuse, + |iface, socket, config| iface.bind_udp(socket, config), + ) { Ok(bound_socket) => bound_socket, Err((err, unbound_socket)) => return Err((err, Self { unbound_socket })), }; let bound_endpoint = bound_socket.local_endpoint().unwrap(); - bound_socket.raw_with(|socket: &mut RawUdpSocket| { - socket.bind(bound_endpoint).unwrap(); - }); + bound_socket.bind(bound_endpoint).unwrap(); Ok(BoundDatagram::new(bound_socket)) } diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index c8ae4b40b..9482487db 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -4,14 +4,14 @@ use alloc::sync::Weak; use aster_bigtcp::{ errors::tcp::{RecvError, SendError}, - socket::{RawTcpSocket, SocketEventObserver}, + socket::SocketEventObserver, wire::IpEndpoint, }; use crate::{ events::IoEvents, net::{ - iface::AnyBoundSocket, + iface::BoundTcpSocket, socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd}, }, prelude::*, @@ -20,7 +20,7 @@ use crate::{ }; pub struct ConnectedStream { - bound_socket: AnyBoundSocket, + bound_socket: BoundTcpSocket, remote_endpoint: IpEndpoint, /// Indicates 
whether this connection is "new" in a `connect()` system call. /// @@ -37,7 +37,7 @@ pub struct ConnectedStream { impl ConnectedStream { pub fn new( - bound_socket: AnyBoundSocket, + bound_socket: BoundTcpSocket, remote_endpoint: IpEndpoint, is_new_connection: bool, ) -> Self { @@ -50,20 +50,16 @@ impl ConnectedStream { pub fn shutdown(&self, _cmd: SockShutdownCmd) -> Result<()> { // TODO: deal with cmd - self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { - socket.close(); - }); + self.bound_socket.close(); Ok(()) } pub fn try_recv(&self, writer: &mut dyn MultiWrite, _flags: SendRecvFlags) -> Result { - let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { - socket.recv( - |socket_buffer| match writer.write(&mut VmReader::from(&*socket_buffer)) { - Ok(len) => (len, Ok(len)), - Err(e) => (0, Err(e)), - }, - ) + let result = self.bound_socket.recv(|socket_buffer| { + match writer.write(&mut VmReader::from(&*socket_buffer)) { + Ok(len) => (len, Ok(len)), + Err(e) => (0, Err(e)), + } }); match result { @@ -78,13 +74,11 @@ impl ConnectedStream { } pub fn try_send(&self, reader: &mut dyn MultiRead, _flags: SendRecvFlags) -> Result { - let result = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { - socket.send( - |socket_buffer| match reader.read(&mut VmWriter::from(socket_buffer)) { - Ok(len) => (len, Ok(len)), - Err(e) => (0, Err(e)), - }, - ) + let result = self.bound_socket.send(|socket_buffer| { + match reader.read(&mut VmWriter::from(socket_buffer)) { + Ok(len) => (len, Ok(len)), + Err(e) => (0, Err(e)), + } }); match result { @@ -123,7 +117,7 @@ impl ConnectedStream { } pub(super) fn update_io_events(&self, pollee: &Pollee) { - self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { + self.bound_socket.raw_with(|socket| { if socket.can_recv() { pollee.add_events(IoEvents::IN); } else { diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index 413e1cb9c..8a3141208 100644 --- 
a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_bigtcp::{socket::RawTcpSocket, wire::IpEndpoint}; +use aster_bigtcp::wire::IpEndpoint; use super::{connected::ConnectedStream, init::InitStream}; -use crate::{net::iface::AnyBoundSocket, prelude::*, process::signal::Pollee}; +use crate::{net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; pub struct ConnectingStream { - bound_socket: AnyBoundSocket, + bound_socket: BoundTcpSocket, remote_endpoint: IpEndpoint, conn_result: RwLock>, } @@ -24,15 +24,15 @@ pub enum NonConnectedStream { impl ConnectingStream { pub fn new( - bound_socket: AnyBoundSocket, + bound_socket: BoundTcpSocket, remote_endpoint: IpEndpoint, - ) -> core::result::Result { + ) -> core::result::Result { // The only reason this method might fail is because we're trying to connect to an // unspecified address (i.e. 0.0.0.0). We currently have no support for binding to, // listening on, or connecting to the unspecified address. // // We assume the remote will just refuse to connect, so we return `ECONNREFUSED`. 
- if bound_socket.do_connect(remote_endpoint).is_err() { + if bound_socket.connect(remote_endpoint).is_err() { return Err(( Error::with_message( Errno::ECONNREFUSED, @@ -91,7 +91,7 @@ impl ConnectingStream { return false; } - self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { + self.bound_socket.raw_with(|socket| { let mut result = self.conn_result.write(); if result.is_some() { return false; diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index 0f312793d..1650318fc 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -3,7 +3,7 @@ use alloc::sync::Weak; use aster_bigtcp::{ - socket::{AnyUnboundSocket, SocketEventObserver}, + socket::{SocketEventObserver, UnboundTcpSocket}, wire::IpEndpoint, }; @@ -11,7 +11,7 @@ use super::{connecting::ConnectingStream, listen::ListenStream}; use crate::{ events::IoEvents, net::{ - iface::AnyBoundSocket, + iface::BoundTcpSocket, socket::ip::common::{bind_socket, get_ephemeral_endpoint}, }, prelude::*, @@ -19,16 +19,16 @@ use crate::{ }; pub enum InitStream { - Unbound(Box), - Bound(AnyBoundSocket), + Unbound(Box), + Bound(BoundTcpSocket), } impl InitStream { pub fn new(observer: Weak) -> Self { - InitStream::Unbound(Box::new(AnyUnboundSocket::new_tcp(observer))) + InitStream::Unbound(Box::new(UnboundTcpSocket::new(observer))) } - pub fn new_bound(bound_socket: AnyBoundSocket) -> Self { + pub fn new_bound(bound_socket: BoundTcpSocket) -> Self { InitStream::Bound(bound_socket) } @@ -36,7 +36,7 @@ impl InitStream { self, endpoint: &IpEndpoint, can_reuse: bool, - ) -> core::result::Result { + ) -> core::result::Result { let unbound_socket = match self { InitStream::Unbound(unbound_socket) => unbound_socket, InitStream::Bound(bound_socket) => { @@ -46,7 +46,12 @@ impl InitStream { )); } }; - let bound_socket = match bind_socket(unbound_socket, endpoint, can_reuse) { + let bound_socket = match bind_socket( + unbound_socket, + endpoint, + 
can_reuse, + |iface, socket, config| iface.bind_tcp(socket, config), + ) { Ok(bound_socket) => bound_socket, Err((err, unbound_socket)) => return Err((err, InitStream::Unbound(unbound_socket))), }; @@ -56,7 +61,7 @@ impl InitStream { fn bind_to_ephemeral_endpoint( self, remote_endpoint: &IpEndpoint, - ) -> core::result::Result { + ) -> core::result::Result { let endpoint = get_ephemeral_endpoint(remote_endpoint); self.bind(&endpoint, false) } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index 45b6f7cfa..51489d7f5 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -1,28 +1,25 @@ // SPDX-License-Identifier: MPL-2.0 use aster_bigtcp::{ - errors::tcp::ListenError, - iface::BindPortConfig, - socket::{AnyUnboundSocket, RawTcpSocket}, - wire::IpEndpoint, + errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint, }; use super::connected::ConnectedStream; -use crate::{events::IoEvents, net::iface::AnyBoundSocket, prelude::*, process::signal::Pollee}; +use crate::{events::IoEvents, net::iface::BoundTcpSocket, prelude::*, process::signal::Pollee}; pub struct ListenStream { backlog: usize, /// A bound socket held to ensure the TCP port cannot be released - bound_socket: AnyBoundSocket, + bound_socket: BoundTcpSocket, /// Backlog sockets listening at the local endpoint backlog_sockets: RwLock>, } impl ListenStream { pub fn new( - bound_socket: AnyBoundSocket, + bound_socket: BoundTcpSocket, backlog: usize, - ) -> core::result::Result { + ) -> core::result::Result { const SOMAXCONN: usize = 4096; let somaxconn = SOMAXCONN.min(backlog); @@ -102,30 +99,28 @@ impl ListenStream { } struct BacklogSocket { - bound_socket: AnyBoundSocket, + bound_socket: BoundTcpSocket, } impl BacklogSocket { // FIXME: All of the error codes below seem to have no Linux equivalents, and I see no reason // why the error may occur. 
Perhaps it is better to call `unwrap()` directly? - fn new(bound_socket: &AnyBoundSocket) -> Result { + fn new(bound_socket: &BoundTcpSocket) -> Result { let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message( Errno::EINVAL, "the socket is not bound", ))?; - let unbound_socket = Box::new(AnyUnboundSocket::new_tcp(Weak::<()>::new())); + let unbound_socket = Box::new(UnboundTcpSocket::new(bound_socket.observer())); let bound_socket = { let iface = bound_socket.iface(); let bind_port_config = BindPortConfig::new(local_endpoint.port, true); iface - .bind_socket(unbound_socket, bind_port_config) + .bind_tcp(unbound_socket, bind_port_config) .map_err(|(err, _)| err)? }; - let result = bound_socket - .raw_with(|raw_tcp_socket: &mut RawTcpSocket| raw_tcp_socket.listen(local_endpoint)); - match result { + match bound_socket.listen(local_endpoint) { Ok(()) => Ok(Self { bound_socket }), Err(ListenError::Unaddressable) => { return_errno_with_message!(Errno::EINVAL, "the listening address is invalid") @@ -137,16 +132,15 @@ impl BacklogSocket { } fn is_active(&self) -> bool { - self.bound_socket - .raw_with(|socket: &mut RawTcpSocket| socket.is_active()) + self.bound_socket.raw_with(|socket| socket.is_active()) } fn remote_endpoint(&self) -> Option { self.bound_socket - .raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint()) + .raw_with(|socket| socket.remote_endpoint()) } - fn into_bound_socket(self) -> AnyBoundSocket { + fn into_bound_socket(self) -> BoundTcpSocket { self.bound_socket } } diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 812a506af..0ff58487a 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -230,18 +230,13 @@ impl StreamSocket { return_errno_with_message!(Errno::EINVAL, "the socket is not listening"); }; - let accepted = listen_stream.try_accept().map(|connected_stream| { + listen_stream.try_accept().map(|connected_stream| { 
listen_stream.update_io_events(&self.pollee); let remote_endpoint = connected_stream.remote_endpoint(); let accepted_socket = Self::new_connected(connected_stream); (accepted_socket as _, remote_endpoint.into()) - }); - - drop(state); - poll_ifaces(); - - accepted + }) } fn try_recv( @@ -262,8 +257,6 @@ impl StreamSocket { }; let received = connected_stream.try_recv(writer, flags).map(|recv_bytes| { - connected_stream.update_io_events(&self.pollee); - let remote_endpoint = connected_stream.remote_endpoint(); (recv_bytes, remote_endpoint.into()) }); @@ -302,10 +295,7 @@ impl StreamSocket { } }; - let sent_bytes = connected_stream.try_send(reader, flags).map(|sent_bytes| { - connected_stream.update_io_events(&self.pollee); - sent_bytes - }); + let sent_bytes = connected_stream.try_send(reader, flags); drop(state); poll_ifaces(); @@ -658,3 +648,11 @@ impl SocketEventObserver for StreamSocket { } } } + +impl Drop for StreamSocket { + fn drop(&mut self) { + self.state.write().take(); + + poll_ifaces(); + } +}