diff --git a/kernel/src/net/iface/any_socket.rs b/kernel/src/net/iface/any_socket.rs index cce36b625..97379fc5e 100644 --- a/kernel/src/net/iface/any_socket.rs +++ b/kernel/src/net/iface/any_socket.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use super::Iface; +use super::{ext::IfaceExt, Iface}; use crate::{ events::Observer, net::socket::ip::{IpAddress, IpEndpoint}, @@ -63,11 +63,11 @@ impl AnyUnboundSocket { } } -pub struct AnyBoundSocket(Arc); +pub struct AnyBoundSocket(Arc>); -impl AnyBoundSocket { +impl AnyBoundSocket { pub(super) fn new( - iface: Arc, + iface: Arc>, handle: smoltcp::iface::SocketHandle, port: u16, socket_family: SocketFamily, @@ -82,7 +82,7 @@ impl AnyBoundSocket { })) } - pub(super) fn inner(&self) -> &Arc { + pub(super) fn inner(&self) -> &Arc> { &self.0 } @@ -144,12 +144,12 @@ impl AnyBoundSocket { Ok(()) } - pub fn iface(&self) -> &Arc { + pub fn iface(&self) -> &Arc> { &self.0.iface } } -impl Drop for AnyBoundSocket { +impl Drop for AnyBoundSocket { fn drop(&mut self) { if self.0.start_closing() { self.0.iface.common().remove_bound_socket_now(&self.0); @@ -162,15 +162,15 @@ impl Drop for AnyBoundSocket { } } -pub(super) struct AnyBoundSocketInner { - iface: Arc, +pub(super) struct AnyBoundSocketInner { + iface: Arc>, handle: smoltcp::iface::SocketHandle, port: u16, socket_family: SocketFamily, observer: RwLock>>, } -impl AnyBoundSocketInner { +impl AnyBoundSocketInner { pub(super) fn on_iface_events(&self) { if let Some(observer) = Weak::upgrade(&*self.observer.read()) { observer.on_events(&()) @@ -218,7 +218,7 @@ impl AnyBoundSocketInner { } } -impl Drop for AnyBoundSocketInner { +impl Drop for AnyBoundSocketInner { fn drop(&mut self) { let iface_common = self.iface.common(); iface_common.remove_socket(self.handle); diff --git a/kernel/src/net/iface/common.rs b/kernel/src/net/iface/common.rs index 02125123e..492e73294 100644 --- a/kernel/src/net/iface/common.rs +++ b/kernel/src/net/iface/common.rs @@ -1,10 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 use alloc::collections::btree_map::Entry; -use core::sync::atomic::{AtomicU64, Ordering}; use keyable_arc::KeyableArc; -use ostd::sync::{LocalIrqDisabled, WaitQueue}; +use ostd::sync::LocalIrqDisabled; use smoltcp::{ iface::{SocketHandle, SocketSet}, phy::Device, @@ -12,36 +11,33 @@ use smoltcp::{ use super::{ any_socket::{AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily}, + ext::IfaceExt, time::get_network_timestamp, util::BindPortConfig, AnyBoundSocket, Iface, }; use crate::{net::socket::ip::Ipv4Address, prelude::*}; -pub struct IfaceCommon { +pub struct IfaceCommon { interface: SpinLock, sockets: SpinLock>, used_ports: RwLock>, - /// The time should do next poll. We stores the total milliseconds since system boots up. - next_poll_at_ms: AtomicU64, - bound_sockets: RwLock>>, - closing_sockets: SpinLock>>, - /// The wait queue that background polling thread will sleep on - polling_wait_queue: WaitQueue, + bound_sockets: RwLock>>>, + closing_sockets: SpinLock>>>, + ext: E, } -impl IfaceCommon { - pub(super) fn new(interface: smoltcp::iface::Interface) -> Self { +impl IfaceCommon { + pub(super) fn new(interface: smoltcp::iface::Interface, ext: E) -> Self { let socket_set = SocketSet::new(Vec::new()); let used_ports = BTreeMap::new(); Self { interface: SpinLock::new(interface), sockets: SpinLock::new(socket_set), used_ports: RwLock::new(used_ports), - next_poll_at_ms: AtomicU64::new(0), bound_sockets: RwLock::new(BTreeSet::new()), closing_sockets: SpinLock::new(BTreeSet::new()), - polling_wait_queue: WaitQueue::new(), + ext, } } @@ -65,10 +61,6 @@ impl IfaceCommon { self.interface.disable_irq().lock().ipv4_addr() } - pub(super) fn polling_wait_queue(&self) -> &WaitQueue { - &self.polling_wait_queue - } - /// Alloc an unused port range from 49152 ~ 65535 (According to smoltcp docs) fn alloc_ephemeral_port(&self) -> Result { let mut used_ports = self.used_ports.write(); @@ -108,10 +100,10 @@ impl IfaceCommon { pub(super) fn bind_socket( &self, - iface: Arc, + iface: Arc>, socket: Box, config: BindPortConfig, - ) -> core::result::Result)> { + ) -> core::result::Result, (Error, Box)> { let port = if let Some(port) = config.port() { port } else { @@ -147,7 +139,8 @@ impl IfaceCommon { self.sockets.disable_irq().lock().remove(handle); } - pub(super) fn poll(&self, device: &mut D) { + #[must_use] + pub(super) fn poll(&self, device: &mut D) -> Option { let mut sockets = self.sockets.disable_irq().lock(); let mut interface = self.interface.disable_irq().lock(); @@ -183,18 +176,6 @@ impl IfaceCommon { drop(interface); drop(sockets); - if let Some(instant) = poll_at { - let old_instant = self.next_poll_at_ms.load(Ordering::Relaxed); - let new_instant = instant.total_millis() as u64; - self.next_poll_at_ms.store(new_instant, Ordering::Relaxed); - - if old_instant == 0 || new_instant < old_instant { - self.polling_wait_queue.wake_all(); - } - } else { - self.next_poll_at_ms.store(0, Ordering::Relaxed); - } - if has_events { // We never try to hold the write lock in the IRQ context, and we disable IRQ when // holding the write lock. So we don't need to disable IRQ when holding the read lock. @@ -210,18 +191,15 @@ impl IfaceCommon { .collect::>(); drop(closed_sockets); } + + poll_at.map(|at| smoltcp::time::Instant::total_millis(&at) as u64) } - pub(super) fn next_poll_at_ms(&self) -> Option { - let millis = self.next_poll_at_ms.load(Ordering::Relaxed); - if millis == 0 { - None - } else { - Some(millis) - } + pub(super) fn ext(&self) -> &E { + &self.ext } - fn insert_bound_socket(&self, socket: &Arc) { + fn insert_bound_socket(&self, socket: &Arc>) { let keyable_socket = KeyableArc::from(socket.clone()); let inserted = self @@ -231,7 +209,7 @@ impl IfaceCommon { assert!(inserted); } - pub(super) fn remove_bound_socket_now(&self, socket: &Arc) { + pub(super) fn remove_bound_socket_now(&self, socket: &Arc>) { let keyable_socket = KeyableArc::from(socket.clone()); let removed = self @@ -241,7 +219,7 @@ impl IfaceCommon { assert!(removed); } - pub(super) fn remove_bound_socket_when_closed(&self, socket: &Arc) { + pub(super) fn remove_bound_socket_when_closed(&self, socket: &Arc>) { let keyable_socket = KeyableArc::from(socket.clone()); let removed = self diff --git a/kernel/src/net/iface/ext.rs b/kernel/src/net/iface/ext.rs new file mode 100644 index 000000000..48ad7da95 --- /dev/null +++ b/kernel/src/net/iface/ext.rs @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::string::String; +use core::sync::atomic::{AtomicU64, Ordering}; + +use ostd::sync::WaitQueue; + +use super::Iface; + +/// The iface extension. +pub struct IfaceExt { + /// The name of the iface. + name: String, + /// The time when we should do the next poll. + /// We store the total number of milliseconds since the system booted. + next_poll_at_ms: AtomicU64, + /// The wait queue that the background polling thread will sleep on. + polling_wait_queue: WaitQueue, +} + +impl IfaceExt { + pub(super) fn new(name: String) -> Self { + Self { + name, + next_poll_at_ms: AtomicU64::new(0), + polling_wait_queue: WaitQueue::new(), + } + } + + pub(super) fn next_poll_at_ms(&self) -> Option { + let millis = self.next_poll_at_ms.load(Ordering::Relaxed); + if millis == 0 { + None + } else { + Some(millis) + } + } + + pub(super) fn polling_wait_queue(&self) -> &WaitQueue { + &self.polling_wait_queue + } + + fn schedule_next_poll(&self, poll_at: Option) { + let Some(new_instant) = poll_at else { + self.next_poll_at_ms.store(0, Ordering::Relaxed); + return; + }; + + let old_instant = self.next_poll_at_ms.load(Ordering::Relaxed); + self.next_poll_at_ms.store(new_instant, Ordering::Relaxed); + + if old_instant == 0 || new_instant < old_instant { + self.polling_wait_queue.wake_all(); + } + } +} + +pub trait IfaceEx { + /// Gets the name of the iface. + /// + /// In Linux, the name is usually the driver name followed by a unit number. + fn name(&self) -> &str; + + /// Transmits or receives packets queued in the iface, and updates socket status accordingly. + /// + /// The background polling thread is woken up to perform the next poll if necessary. + fn poll(&self); +} + +impl IfaceEx for dyn Iface { + fn name(&self) -> &str { + &self.ext().name + } + + fn poll(&self) { + self.raw_poll(&|next_poll| self.ext().schedule_next_poll(next_poll)); + } +} diff --git a/kernel/src/net/iface/init.rs b/kernel/src/net/iface/init.rs new file mode 100644 index 000000000..f7572617e --- /dev/null +++ b/kernel/src/net/iface/init.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MPL-2.0 + +use alloc::sync::Arc; + +use spin::Once; + +use super::{spawn_background_poll_thread, Iface}; +use crate::{ + net::iface::{ext::IfaceEx, IfaceLoopback, IfaceVirtio}, + prelude::*, +}; + +pub static IFACES: Once>> = Once::new(); + +pub fn init() { + IFACES.call_once(|| { + let iface_virtio = IfaceVirtio::new(); + let iface_loopback = IfaceLoopback::new(); + vec![iface_virtio, iface_loopback] + }); + + for (name, _) in aster_network::all_devices() { + aster_network::register_recv_callback(&name, || { + // TODO: further check that the irq num is the same as iface's irq num + let iface_virtio = &IFACES.get().unwrap()[0]; + iface_virtio.poll(); + }) + } + + poll_ifaces(); +} + +pub fn lazy_init() { + for iface in IFACES.get().unwrap() { + spawn_background_poll_thread(iface.clone()); + } +} + +pub fn poll_ifaces() { + let ifaces = IFACES.get().unwrap(); + + for iface in ifaces.iter() { + iface.poll(); + } +} diff --git a/kernel/src/net/iface/loopback.rs b/kernel/src/net/iface/loopback.rs index eba160a69..57296821b 100644 --- a/kernel/src/net/iface/loopback.rs +++ b/kernel/src/net/iface/loopback.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +use alloc::borrow::ToOwned; + use smoltcp::{ iface::Config, phy::{Loopback, Medium}, @@ -9,7 +11,7 @@ use smoltcp::{ use super::{common::IfaceCommon, internal::IfaceInternal, Iface}; use crate::{ net::{ - iface::time::get_network_timestamp, + iface::{ext::IfaceExt, time::get_network_timestamp}, socket::ip::{IpAddress, Ipv4Address}, }, prelude::*, @@ -29,6 +31,7 @@ pub struct IfaceLoopback { impl IfaceLoopback { pub fn new() -> Arc { let mut loopback = Loopback::new(Medium::Ip); + let interface = { let config = Config::new(smoltcp::wire::HardwareAddress::Ip); let now = get_network_timestamp(); @@ -41,28 +44,27 @@ impl IfaceLoopback { }); interface }; + println!("Loopback ipaddr: {}", interface.ipv4_addr().unwrap()); - let common = IfaceCommon::new(interface); + Arc::new(Self { driver: Mutex::new(loopback), - common, + common: IfaceCommon::new(interface, IfaceExt::new("lo".to_owned())), }) } } -impl IfaceInternal for IfaceLoopback { +impl IfaceInternal for IfaceLoopback { fn common(&self) -> &IfaceCommon { &self.common } } impl Iface for IfaceLoopback { - fn name(&self) -> &str { - "lo" - } - - fn poll(&self) { + fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)) { let mut device = self.driver.lock(); - self.common.poll(&mut *device); + + let next_poll = self.common.poll(&mut *device); + schedule_next_poll(next_poll); } } diff --git a/kernel/src/net/iface/mod.rs b/kernel/src/net/iface/mod.rs index 552fbb3bc..d0baf3c29 100644 --- a/kernel/src/net/iface/mod.rs +++ b/kernel/src/net/iface/mod.rs @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MPL-2.0 -use ostd::sync::WaitQueue; - use self::common::IfaceCommon; use crate::prelude::*; mod any_socket; mod common; +mod ext; +mod init; mod loopback; mod time; mod util; @@ -16,6 +16,7 @@ pub use any_socket::{ AnyBoundSocket, AnyUnboundSocket, RawTcpSocket, RawUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN, }; +pub use init::{init, lazy_init, poll_ifaces, IFACES}; pub use loopback::IfaceLoopback; pub use smoltcp::wire::EthernetAddress; pub use util::{spawn_background_poll_thread, BindPortConfig}; @@ -29,17 +30,21 @@ use crate::net::socket::ip::Ipv4Address; /// computer to a network. Network interfaces can be physical components like Ethernet ports or /// wireless adapters. They can also be virtual interfaces created by software, such as virtual /// private network (VPN) connections. -pub trait Iface: internal::IfaceInternal + Send + Sync { - /// Gets the name of the iface. - /// - /// In Linux, the name is usually the driver name followed by a unit number. - fn name(&self) -> &str; - +pub trait Iface: internal::IfaceInternal + Send + Sync { /// Transmits or receives packets queued in the iface, and updates socket status accordingly. - fn poll(&self); + /// + /// The `schedule_next_poll` callback is invoked with the time at which the next poll should be + /// performed, or `None` if no next poll is required. It's up to the caller to determine the + /// mechanism to ensure that the next poll happens at the right time (e.g. by setting a timer). + fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)); } -impl dyn Iface { +impl dyn Iface { + /// Gets the extension of the iface. + pub fn ext(&self) -> &E { + self.common().ext() + } + /// Binds a socket to the iface. /// /// After binding the socket to the iface, the iface will handle all packets to and from the @@ -55,7 +60,7 @@ impl dyn Iface { self: &Arc, socket: Box, config: BindPortConfig, - ) -> core::result::Result)> { + ) -> core::result::Result, (Error, Box)> { let common = self.common(); common.bind_socket(self.clone(), socket, config) } @@ -66,23 +71,13 @@ impl dyn Iface { pub fn ipv4_addr(&self) -> Option { self.common().ipv4_addr() } - - /// Gets the wait queue that the background polling thread will sleep on. - fn polling_wait_queue(&self) -> &WaitQueue { - self.common().polling_wait_queue() - } - - /// Gets the time when we should perform another poll. - fn next_poll_at_ms(&self) -> Option { - self.common().next_poll_at_ms() - } } mod internal { use super::*; /// An internal trait that abstracts the common part of different ifaces. - pub trait IfaceInternal { - fn common(&self) -> &IfaceCommon; + pub trait IfaceInternal { + fn common(&self) -> &IfaceCommon; } } diff --git a/kernel/src/net/iface/util.rs b/kernel/src/net/iface/util.rs index bc73a6886..8416bc023 100644 --- a/kernel/src/net/iface/util.rs +++ b/kernel/src/net/iface/util.rs @@ -4,14 +4,13 @@ use core::time::Duration; use ostd::{arch::timer::Jiffies, task::Priority}; -use super::Iface; +use super::{ext::IfaceEx, Iface}; use crate::{ prelude::*, thread::{ kernel_thread::{KernelThreadExt, ThreadOptions}, Thread, }, - time::wait::WaitTimeout, }; pub enum BindPortConfig { @@ -51,12 +50,15 @@ impl BindPortConfig { pub fn spawn_background_poll_thread(iface: Arc) { let task_fn = move || { trace!("spawn background poll thread for {}", iface.name()); - let wait_queue = iface.polling_wait_queue(); + + let iface_ext = iface.ext(); + let wait_queue = iface_ext.polling_wait_queue(); + loop { - let next_poll_at_ms = if let Some(next_poll_at_ms) = iface.next_poll_at_ms() { + let next_poll_at_ms = if let Some(next_poll_at_ms) = iface_ext.next_poll_at_ms() { next_poll_at_ms } else { - wait_queue.wait_until(|| iface.next_poll_at_ms()) + wait_queue.wait_until(|| iface_ext.next_poll_at_ms()) }; let now_as_ms = Jiffies::elapsed().as_duration().as_millis() as u64; @@ -76,8 +78,9 @@ pub fn spawn_background_poll_thread(iface: Arc) { let duration = Duration::from_millis(next_poll_at_ms - now_as_ms); wait_queue.wait_until_or_timeout( - // If `iface.next_poll_at_ms()` changes to an earlier time, we will end the waiting. - || (iface.next_poll_at_ms()? < next_poll_at_ms).then_some(()), + // If `iface_ext.next_poll_at_ms()` changes to an earlier time, we will end the + // waiting. + || (iface_ext.next_poll_at_ms()? < next_poll_at_ms).then_some(()), &duration, ); } diff --git a/kernel/src/net/iface/virtio.rs b/kernel/src/net/iface/virtio.rs index 4d47fcbb2..5afd84d2c 100644 --- a/kernel/src/net/iface/virtio.rs +++ b/kernel/src/net/iface/virtio.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +use alloc::borrow::ToOwned; + use aster_network::AnyNetworkDevice; use aster_virtio::device::network::DEVICE_NAME; use ostd::sync::PreemptDisabled; @@ -9,7 +11,9 @@ use smoltcp::{ wire::{self, IpCidr}, }; -use super::{common::IfaceCommon, internal::IfaceInternal, time::get_network_timestamp, Iface}; +use super::{ + common::IfaceCommon, ext::IfaceExt, internal::IfaceInternal, time::get_network_timestamp, Iface, +}; use crate::prelude::*; pub struct IfaceVirtio { @@ -21,6 +25,7 @@ pub struct IfaceVirtio { impl IfaceVirtio { pub fn new() -> Arc { let virtio_net = aster_network::get_device(DEVICE_NAME).unwrap(); + let interface = { let mac_addr = virtio_net.lock().mac_addr(); let ip_addr = IpCidr::new(wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED), 0); @@ -37,10 +42,13 @@ impl IfaceVirtio { }); interface }; - let common = IfaceCommon::new(interface); + + let common = IfaceCommon::new(interface, IfaceExt::new("virtio".to_owned())); + let mut socket_set = common.sockets(); let dhcp_handle = init_dhcp_client(&mut socket_set); drop(socket_set); + Arc::new(Self { driver: virtio_net, common, @@ -87,20 +95,19 @@ impl IfaceVirtio { } } -impl IfaceInternal for IfaceVirtio { +impl IfaceInternal for IfaceVirtio { fn common(&self) -> &IfaceCommon { &self.common } } impl Iface for IfaceVirtio { - fn name(&self) -> &str { - "virtio" - } - - fn poll(&self) { + fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)) { let mut driver = self.driver.disable_irq().lock(); - self.common.poll(&mut *driver); + + let next_poll = self.common.poll(&mut *driver); + schedule_next_poll(next_poll); + self.process_dhcp(); } } diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 6a7d4b665..88e105905 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -1,47 +1,14 @@ // SPDX-License-Identifier: MPL-2.0 -use spin::Once; - -use self::{iface::spawn_background_poll_thread, socket::vsock}; -use crate::{ - net::iface::{Iface, IfaceLoopback, IfaceVirtio}, - prelude::*, -}; - -pub static IFACES: Once>> = Once::new(); - pub mod iface; pub mod socket; pub fn init() { - IFACES.call_once(|| { - let iface_virtio = IfaceVirtio::new(); - let iface_loopback = IfaceLoopback::new(); - vec![iface_virtio, iface_loopback] - }); - - for (name, _) in aster_network::all_devices() { - aster_network::register_recv_callback(&name, || { - // TODO: further check that the irq num is the same as iface's irq num - let iface_virtio = &IFACES.get().unwrap()[0]; - iface_virtio.poll(); - }) - } - poll_ifaces(); - vsock::init(); + iface::init(); + socket::vsock::init(); } /// Lazy init should be called after spawning init thread. pub fn lazy_init() { - for iface in IFACES.get().unwrap() { - spawn_background_poll_thread(iface.clone()); - } -} - -/// Poll iface -pub fn poll_ifaces() { - let ifaces = IFACES.get().unwrap(); - for iface in ifaces.iter() { - iface.poll(); - } + iface::lazy_init(); } diff --git a/kernel/src/net/socket/ip/common.rs b/kernel/src/net/socket/ip/common.rs index 48be647d6..c60d33644 100644 --- a/kernel/src/net/socket/ip/common.rs +++ b/kernel/src/net/socket/ip/common.rs @@ -2,10 +2,7 @@ use super::{IpAddress, IpEndpoint}; use crate::{ - net::{ - iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, Iface}, - IFACES, - }, + net::iface::{AnyBoundSocket, AnyUnboundSocket, BindPortConfig, Iface, IFACES}, prelude::*, }; diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index b2d562a99..6420919ac 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -11,7 +11,7 @@ use crate::{ fs::{file_handle::FileLike, utils::StatusFlags}, match_sock_option_mut, net::{ - poll_ifaces, + iface::poll_ifaces, socket::{ options::{Error as SocketError, SocketOption}, util::{ diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 5ca0d552c..bedd310bb 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -17,7 +17,7 @@ use crate::{ fs::{file_handle::FileLike, utils::StatusFlags}, match_sock_option_mut, match_sock_option_ref, net::{ - poll_ifaces, + iface::poll_ifaces, socket::{ options::{Error as SocketError, SocketOption}, util::{