diff --git a/kernel/src/net/iface/device.rs b/kernel/src/net/iface/device.rs new file mode 100644 index 000000000..25725de9c --- /dev/null +++ b/kernel/src/net/iface/device.rs @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MPL-2.0 + +use smoltcp::phy::Device; + +/// A trait that allows to obtain a mutable reference of [`Device`]. +/// +/// A [`Device`] is usually protected by a lock (e.g., a spin lock or a mutex), and it may be +/// stored behind a shared type (e.g., an `Arc`). This property abstracts this fact by providing a +/// method that the caller can use to get the mutable reference without worrying about how the +/// reference is obtained. +pub trait WithDevice: Send + Sync { + type Device: Device + ?Sized; + + /// Calls the closure with a mutable reference of [`Device`]. + fn with(&self, f: F) -> R + where + F: FnOnce(&mut Self::Device) -> R; +} diff --git a/kernel/src/net/iface/virtio.rs b/kernel/src/net/iface/ether.rs similarity index 67% rename from kernel/src/net/iface/virtio.rs rename to kernel/src/net/iface/ether.rs index 5afd84d2c..045dcf615 100644 --- a/kernel/src/net/iface/virtio.rs +++ b/kernel/src/net/iface/ether.rs @@ -1,10 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::borrow::ToOwned; - -use aster_network::AnyNetworkDevice; -use aster_virtio::device::network::DEVICE_NAME; -use ostd::sync::PreemptDisabled; +pub use smoltcp::wire::EthernetAddress; use smoltcp::{ iface::{Config, SocketHandle, SocketSet}, socket::dhcpv4, @@ -12,45 +8,40 @@ use smoltcp::{ }; use super::{ - common::IfaceCommon, ext::IfaceExt, internal::IfaceInternal, time::get_network_timestamp, Iface, + common::IfaceCommon, device::WithDevice, internal::IfaceInternal, time::get_network_timestamp, + Iface, }; use crate::prelude::*; -pub struct IfaceVirtio { - driver: Arc>, - common: IfaceCommon, +pub struct EtherIface { + driver: D, + common: IfaceCommon, dhcp_handle: SocketHandle, } -impl IfaceVirtio { - pub fn new() -> Arc { - let virtio_net = aster_network::get_device(DEVICE_NAME).unwrap(); - - let interface = { - let mac_addr = virtio_net.lock().mac_addr(); +impl EtherIface { + pub fn new(driver: D, ether_addr: EthernetAddress, ext: E) -> Arc { + let interface = driver.with(|device| { let ip_addr = IpCidr::new(wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED), 0); - let config = Config::new(wire::HardwareAddress::Ethernet(wire::EthernetAddress( - mac_addr.0, - ))); + let config = Config::new(wire::HardwareAddress::Ethernet(ether_addr)); let now = get_network_timestamp(); - let mut interface = - smoltcp::iface::Interface::new(config, &mut *virtio_net.lock(), now); + let mut interface = smoltcp::iface::Interface::new(config, device, now); interface.update_ip_addrs(|ip_addrs| { debug_assert!(ip_addrs.is_empty()); ip_addrs.push(ip_addr).unwrap(); }); interface - }; + }); - let common = IfaceCommon::new(interface, IfaceExt::new("virtio".to_owned())); + let common = IfaceCommon::new(interface, ext); let mut socket_set = common.sockets(); let dhcp_handle = init_dhcp_client(&mut socket_set); drop(socket_set); Arc::new(Self { - driver: virtio_net, + driver, common, dhcp_handle, }) @@ -95,20 +86,20 @@ impl IfaceVirtio { } } -impl IfaceInternal for IfaceVirtio { - fn common(&self) -> &IfaceCommon { +impl IfaceInternal for EtherIface { + fn common(&self) -> &IfaceCommon { &self.common } } -impl Iface for IfaceVirtio { +impl Iface for EtherIface { fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)) { - let mut driver = self.driver.disable_irq().lock(); + self.driver.with(|device| { + let next_poll = self.common.poll(&mut *device); + schedule_next_poll(next_poll); - let next_poll = self.common.poll(&mut *driver); - schedule_next_poll(next_poll); - - self.process_dhcp(); + self.process_dhcp(); + }); } } diff --git a/kernel/src/net/iface/init.rs b/kernel/src/net/iface/init.rs index f7572617e..bbf406617 100644 --- a/kernel/src/net/iface/init.rs +++ b/kernel/src/net/iface/init.rs @@ -1,12 +1,16 @@ // SPDX-License-Identifier: MPL-2.0 -use alloc::sync::Arc; +use alloc::{borrow::ToOwned, sync::Arc}; +use ostd::sync::PreemptDisabled; use spin::Once; use super::{spawn_background_poll_thread, Iface}; use crate::{ - net::iface::{ext::IfaceEx, IfaceLoopback, IfaceVirtio}, + net::iface::{ + device::WithDevice, + ext::{IfaceEx, IfaceExt}, + }, prelude::*, }; @@ -14,8 +18,8 @@ pub static IFACES: Once>> = Once::new(); pub fn init() { IFACES.call_once(|| { - let iface_virtio = IfaceVirtio::new(); - let iface_loopback = IfaceLoopback::new(); + let iface_virtio = new_virtio(); + let iface_loopback = new_loopback(); vec![iface_virtio, iface_loopback] }); @@ -30,6 +34,69 @@ pub fn init() { poll_ifaces(); } +fn new_virtio() -> Arc { + use aster_network::AnyNetworkDevice; + use aster_virtio::device::network::DEVICE_NAME; + + use super::ether::{EtherIface, EthernetAddress}; + + let virtio_net = aster_network::get_device(DEVICE_NAME).unwrap(); + + let ether_addr = virtio_net.disable_irq().lock().mac_addr().0; + + struct Wrapper(Arc>); + + impl WithDevice for Wrapper { + type Device = dyn AnyNetworkDevice; + + fn with(&self, f: F) -> R + where + F: FnOnce(&mut Self::Device) -> R, + { + let mut device = self.0.disable_irq().lock(); + f(&mut *device) + } + } + + EtherIface::new( + Wrapper(virtio_net), + EthernetAddress(ether_addr), + IfaceExt::new("virtio".to_owned()), + ) +} + +fn new_loopback() -> Arc { + use smoltcp::phy::{Loopback, Medium}; + + use super::ip::{IpAddress, IpCidr, IpIface, Ipv4Address}; + + const LOOPBACK_ADDRESS: IpAddress = { + let ipv4_addr = Ipv4Address::new(127, 0, 0, 1); + IpAddress::Ipv4(ipv4_addr) + }; + const LOOPBACK_ADDRESS_PREFIX_LEN: u8 = 8; // mask: 255.0.0.0 + + struct Wrapper(Mutex); + + impl WithDevice for Wrapper { + type Device = Loopback; + + fn with(&self, f: F) -> R + where + F: FnOnce(&mut Self::Device) -> R, + { + let mut device = self.0.lock(); + f(&mut device) + } + } + + IpIface::new( + Wrapper(Mutex::new(Loopback::new(Medium::Ip))), + IpCidr::new(LOOPBACK_ADDRESS, LOOPBACK_ADDRESS_PREFIX_LEN), + IfaceExt::new("lo".to_owned()), + ) as _ +} + pub fn lazy_init() { for iface in IFACES.get().unwrap() { spawn_background_poll_thread(iface.clone()); diff --git a/kernel/src/net/iface/ip.rs b/kernel/src/net/iface/ip.rs new file mode 100644 index 000000000..950da42c8 --- /dev/null +++ b/kernel/src/net/iface/ip.rs @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MPL-2.0 + +use smoltcp::iface::Config; +pub use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address}; + +use super::{common::IfaceCommon, device::WithDevice, internal::IfaceInternal, Iface}; +use crate::{net::iface::time::get_network_timestamp, prelude::*}; + +pub struct IpIface { + driver: D, + common: IfaceCommon, +} + +impl IpIface { + pub fn new(driver: D, ip_cidr: IpCidr, ext: E) -> Arc { + let interface = driver.with(|device| { + let config = Config::new(smoltcp::wire::HardwareAddress::Ip); + let now = get_network_timestamp(); + + let mut interface = smoltcp::iface::Interface::new(config, device, now); + interface.update_ip_addrs(|ip_addrs| { + debug_assert!(ip_addrs.is_empty()); + ip_addrs.push(ip_cidr).unwrap(); + }); + interface + }); + + let common = IfaceCommon::new(interface, ext); + + Arc::new(Self { driver, common }) + } +} + +impl IfaceInternal for IpIface { + fn common(&self) -> &IfaceCommon { + &self.common + } +} + +impl Iface for IpIface { + fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)) { + self.driver.with(|device| { + let next_poll = self.common.poll(device); + schedule_next_poll(next_poll); + }); + } +} diff --git a/kernel/src/net/iface/loopback.rs b/kernel/src/net/iface/loopback.rs deleted file mode 100644 index 57296821b..000000000 --- a/kernel/src/net/iface/loopback.rs +++ /dev/null @@ -1,70 +0,0 @@ -// SPDX-License-Identifier: MPL-2.0 - -use alloc::borrow::ToOwned; - -use smoltcp::{ - iface::Config, - phy::{Loopback, Medium}, - wire::IpCidr, -}; - -use super::{common::IfaceCommon, internal::IfaceInternal, Iface}; -use crate::{ - net::{ - iface::{ext::IfaceExt, time::get_network_timestamp}, - socket::ip::{IpAddress, Ipv4Address}, - }, - prelude::*, -}; - -pub const LOOPBACK_ADDRESS: IpAddress = { - let ipv4_addr = Ipv4Address::new(127, 0, 0, 1); - IpAddress::Ipv4(ipv4_addr) -}; -pub const LOOPBACK_ADDRESS_PREFIX_LEN: u8 = 8; // mask: 255.0.0.0 - -pub struct IfaceLoopback { - driver: Mutex, - common: IfaceCommon, -} - -impl IfaceLoopback { - pub fn new() -> Arc { - let mut loopback = Loopback::new(Medium::Ip); - - let interface = { - let config = Config::new(smoltcp::wire::HardwareAddress::Ip); - let now = get_network_timestamp(); - - let mut interface = smoltcp::iface::Interface::new(config, &mut loopback, now); - interface.update_ip_addrs(|ip_addrs| { - debug_assert!(ip_addrs.is_empty()); - let ip_addr = IpCidr::new(LOOPBACK_ADDRESS, LOOPBACK_ADDRESS_PREFIX_LEN); - ip_addrs.push(ip_addr).unwrap(); - }); - interface - }; - - println!("Loopback ipaddr: {}", interface.ipv4_addr().unwrap()); - - Arc::new(Self { - driver: Mutex::new(loopback), - common: IfaceCommon::new(interface, IfaceExt::new("lo".to_owned())), - }) - } -} - -impl IfaceInternal for IfaceLoopback { - fn common(&self) -> &IfaceCommon { - &self.common - } -} - -impl Iface for IfaceLoopback { - fn raw_poll(&self, schedule_next_poll: &dyn Fn(Option)) { - let mut device = self.driver.lock(); - - let next_poll = self.common.poll(&mut *device); - schedule_next_poll(next_poll); - } -} diff --git a/kernel/src/net/iface/mod.rs b/kernel/src/net/iface/mod.rs index d0baf3c29..72c709f60 100644 --- a/kernel/src/net/iface/mod.rs +++ b/kernel/src/net/iface/mod.rs @@ -5,22 +5,21 @@ use crate::prelude::*; mod any_socket; mod common; +mod device; +mod ether; mod ext; mod init; -mod loopback; +mod ip; mod time; mod util; -mod virtio; pub use any_socket::{ AnyBoundSocket, AnyUnboundSocket, RawTcpSocket, RawUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN, }; pub use init::{init, lazy_init, poll_ifaces, IFACES}; -pub use loopback::IfaceLoopback; pub use smoltcp::wire::EthernetAddress; pub use util::{spawn_background_poll_thread, BindPortConfig}; -pub use virtio::IfaceVirtio; use crate::net::socket::ip::Ipv4Address;