Remove the shim kernel crate

This commit is contained in:
Zhang Junyang
2024-08-19 19:15:22 +08:00
committed by Tate, Hongliang Tian
parent d76c7a5b1e
commit dafd16075f
416 changed files with 231 additions and 273 deletions

View File

@ -0,0 +1,218 @@
// SPDX-License-Identifier: MPL-2.0
use super::Iface;
use crate::{
events::Observer,
net::socket::ip::{IpAddress, IpEndpoint},
prelude::*,
};
pub type RawTcpSocket = smoltcp::socket::tcp::Socket<'static>;
pub type RawUdpSocket = smoltcp::socket::udp::Socket<'static>;
pub struct AnyUnboundSocket {
socket_family: AnyRawSocket,
observer: Weak<dyn Observer<()>>,
}
#[allow(clippy::large_enum_variant)]
pub(super) enum AnyRawSocket {
Tcp(RawTcpSocket),
Udp(RawUdpSocket),
}
pub(super) enum SocketFamily {
Tcp,
Udp,
}
impl AnyUnboundSocket {
pub fn new_tcp(observer: Weak<dyn Observer<()>>) -> Self {
let raw_tcp_socket = {
let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; RECV_BUF_LEN]);
let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0u8; SEND_BUF_LEN]);
RawTcpSocket::new(rx_buffer, tx_buffer)
};
AnyUnboundSocket {
socket_family: AnyRawSocket::Tcp(raw_tcp_socket),
observer,
}
}
pub fn new_udp(observer: Weak<dyn Observer<()>>) -> Self {
let raw_udp_socket = {
let metadata = smoltcp::socket::udp::PacketMetadata::EMPTY;
let rx_buffer = smoltcp::socket::udp::PacketBuffer::new(
vec![metadata; UDP_METADATA_LEN],
vec![0u8; UDP_RECEIVE_PAYLOAD_LEN],
);
let tx_buffer = smoltcp::socket::udp::PacketBuffer::new(
vec![metadata; UDP_METADATA_LEN],
vec![0u8; UDP_SEND_PAYLOAD_LEN],
);
RawUdpSocket::new(rx_buffer, tx_buffer)
};
AnyUnboundSocket {
socket_family: AnyRawSocket::Udp(raw_udp_socket),
observer,
}
}
pub(super) fn into_raw(self) -> (AnyRawSocket, Weak<dyn Observer<()>>) {
(self.socket_family, self.observer)
}
}
pub struct AnyBoundSocket(Arc<AnyBoundSocketInner>);
impl AnyBoundSocket {
pub(super) fn new(
iface: Arc<dyn Iface>,
handle: smoltcp::iface::SocketHandle,
port: u16,
socket_family: SocketFamily,
observer: Weak<dyn Observer<()>>,
) -> Self {
Self(Arc::new(AnyBoundSocketInner {
iface,
handle,
port,
socket_family,
observer: RwLock::new(observer),
}))
}
pub(super) fn inner(&self) -> &Arc<AnyBoundSocketInner> {
&self.0
}
/// Set the observer whose `on_events` will be called when certain iface events happen. After
/// setting, the new observer will fire once immediately to avoid missing any events.
///
/// If there is an existing observer, due to race conditions, this function does not guarentee
/// that the old observer will never be called after the setting. Users should be aware of this
/// and proactively handle the race conditions if necessary.
pub fn set_observer(&self, handler: Weak<dyn Observer<()>>) {
*self.0.observer.write() = handler;
self.0.on_iface_events();
}
pub fn local_endpoint(&self) -> Option<IpEndpoint> {
let ip_addr = {
let ipv4_addr = self.0.iface.ipv4_addr()?;
IpAddress::Ipv4(ipv4_addr)
};
Some(IpEndpoint::new(ip_addr, self.0.port))
}
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
&self,
f: F,
) -> R {
self.0.raw_with(f)
}
/// Try to connect to a remote endpoint. Tcp socket only.
pub fn do_connect(&self, remote_endpoint: IpEndpoint) -> Result<()> {
let mut sockets = self.0.iface.sockets();
let socket = sockets.get_mut::<RawTcpSocket>(self.0.handle);
let port = self.0.port;
let mut iface_inner = self.0.iface.iface_inner();
let cx = iface_inner.context();
socket
.connect(cx, remote_endpoint, port)
.map_err(|_| Error::with_message(Errno::ENOBUFS, "send connection request failed"))?;
Ok(())
}
pub fn iface(&self) -> &Arc<dyn Iface> {
&self.0.iface
}
}
impl Drop for AnyBoundSocket {
fn drop(&mut self) {
if self.0.start_closing() {
self.0.iface.common().remove_bound_socket_now(&self.0);
} else {
self.0
.iface
.common()
.remove_bound_socket_when_closed(&self.0);
}
}
}
pub(super) struct AnyBoundSocketInner {
iface: Arc<dyn Iface>,
handle: smoltcp::iface::SocketHandle,
port: u16,
socket_family: SocketFamily,
observer: RwLock<Weak<dyn Observer<()>>>,
}
impl AnyBoundSocketInner {
pub(super) fn on_iface_events(&self) {
if let Some(observer) = Weak::upgrade(&*self.observer.read()) {
observer.on_events(&())
}
}
pub(super) fn is_closed(&self) -> bool {
match self.socket_family {
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
socket.state() == smoltcp::socket::tcp::State::Closed
}),
SocketFamily::Udp => true,
}
}
/// Starts closing the socket and returns whether the socket is closed.
///
/// For sockets that can be closed immediately, such as UDP sockets and TCP listening sockets,
/// this method will always return `true`.
///
/// For other sockets, such as TCP connected sockets, they cannot be closed immediately because
/// we at least need to send the FIN packet and wait for the remote end to send an ACK packet.
/// In this case, this method will return `false` and [`Self::is_closed`] can be used to
/// determine if the closing process is complete.
fn start_closing(&self) -> bool {
match self.socket_family {
SocketFamily::Tcp => self.raw_with(|socket: &mut RawTcpSocket| {
socket.close();
socket.state() == smoltcp::socket::tcp::State::Closed
}),
SocketFamily::Udp => {
self.raw_with(|socket: &mut RawUdpSocket| socket.close());
true
}
}
}
pub fn raw_with<T: smoltcp::socket::AnySocket<'static>, R, F: FnMut(&mut T) -> R>(
&self,
mut f: F,
) -> R {
let mut sockets = self.iface.sockets();
let socket = sockets.get_mut::<T>(self.handle);
f(socket)
}
}
impl Drop for AnyBoundSocketInner {
fn drop(&mut self) {
let iface_common = self.iface.common();
iface_common.remove_socket(self.handle);
iface_common.release_port(self.port);
}
}
// For TCP
pub const RECV_BUF_LEN: usize = 65536;
pub const SEND_BUF_LEN: usize = 65536;
// For UDP
const UDP_METADATA_LEN: usize = 256;
const UDP_SEND_PAYLOAD_LEN: usize = 65536;
const UDP_RECEIVE_PAYLOAD_LEN: usize = 65536;

View File

@ -0,0 +1,274 @@
// SPDX-License-Identifier: MPL-2.0
use alloc::collections::btree_map::Entry;
use core::sync::atomic::{AtomicU64, Ordering};
use keyable_arc::KeyableArc;
use ostd::sync::{LocalIrqDisabled, WaitQueue};
use smoltcp::{
iface::{SocketHandle, SocketSet},
phy::Device,
wire::IpCidr,
};
use super::{
any_socket::{AnyBoundSocketInner, AnyRawSocket, AnyUnboundSocket, SocketFamily},
time::get_network_timestamp,
util::BindPortConfig,
AnyBoundSocket, Iface,
};
use crate::{net::socket::ip::Ipv4Address, prelude::*};
pub struct IfaceCommon {
interface: SpinLock<smoltcp::iface::Interface>,
sockets: SpinLock<SocketSet<'static>>,
used_ports: RwLock<BTreeMap<u16, usize>>,
/// The time should do next poll. We stores the total milliseconds since system boots up.
next_poll_at_ms: AtomicU64,
bound_sockets: RwLock<BTreeSet<KeyableArc<AnyBoundSocketInner>>>,
closing_sockets: SpinLock<BTreeSet<KeyableArc<AnyBoundSocketInner>>>,
/// The wait queue that background polling thread will sleep on
polling_wait_queue: WaitQueue,
}
impl IfaceCommon {
pub(super) fn new(interface: smoltcp::iface::Interface) -> Self {
let socket_set = SocketSet::new(Vec::new());
let used_ports = BTreeMap::new();
Self {
interface: SpinLock::new(interface),
sockets: SpinLock::new(socket_set),
used_ports: RwLock::new(used_ports),
next_poll_at_ms: AtomicU64::new(0),
bound_sockets: RwLock::new(BTreeSet::new()),
closing_sockets: SpinLock::new(BTreeSet::new()),
polling_wait_queue: WaitQueue::new(),
}
}
/// Acquires the lock to the interface.
///
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
pub(super) fn interface(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
self.interface.disable_irq().lock()
}
/// Acuqires the lock to the sockets.
///
/// *Lock ordering:* [`Self::sockets`] first, [`Self::interface`] second.
pub(super) fn sockets(
&self,
) -> SpinLockGuard<smoltcp::iface::SocketSet<'static>, LocalIrqDisabled> {
self.sockets.disable_irq().lock()
}
pub(super) fn ipv4_addr(&self) -> Option<Ipv4Address> {
self.interface.disable_irq().lock().ipv4_addr()
}
pub(super) fn netmask(&self) -> Option<Ipv4Address> {
let interface = self.interface.disable_irq().lock();
let ip_addrs = interface.ip_addrs();
ip_addrs.first().map(|cidr| match cidr {
IpCidr::Ipv4(ipv4_cidr) => ipv4_cidr.netmask(),
})
}
pub(super) fn polling_wait_queue(&self) -> &WaitQueue {
&self.polling_wait_queue
}
/// Alloc an unused port range from 49152 ~ 65535 (According to smoltcp docs)
fn alloc_ephemeral_port(&self) -> Result<u16> {
let mut used_ports = self.used_ports.write();
for port in IP_LOCAL_PORT_START..=IP_LOCAL_PORT_END {
if let Entry::Vacant(e) = used_ports.entry(port) {
e.insert(0);
return Ok(port);
}
}
return_errno_with_message!(Errno::EAGAIN, "no ephemeral port is available");
}
fn bind_port(&self, port: u16, can_reuse: bool) -> Result<()> {
let mut used_ports = self.used_ports.write();
if let Some(used_times) = used_ports.get_mut(&port) {
if *used_times == 0 || can_reuse {
*used_times += 1;
} else {
return_errno_with_message!(Errno::EADDRINUSE, "the address is already in use");
}
} else {
used_ports.insert(port, 1);
}
Ok(())
}
/// Release port number so the port can be used again. For reused port, the port may still be in use.
pub(super) fn release_port(&self, port: u16) {
let mut used_ports = self.used_ports.write();
if let Some(used_times) = used_ports.remove(&port) {
if used_times != 1 {
used_ports.insert(port, used_times - 1);
}
}
}
pub(super) fn bind_socket(
&self,
iface: Arc<dyn Iface>,
socket: Box<AnyUnboundSocket>,
config: BindPortConfig,
) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
let port = if let Some(port) = config.port() {
port
} else {
match self.alloc_ephemeral_port() {
Ok(port) => port,
Err(err) => return Err((err, socket)),
}
};
if let Some(err) = self.bind_port(port, config.can_reuse()).err() {
return Err((err, socket));
}
let (handle, socket_family, observer) = match socket.into_raw() {
(AnyRawSocket::Tcp(tcp_socket), observer) => (
self.sockets.disable_irq().lock().add(tcp_socket),
SocketFamily::Tcp,
observer,
),
(AnyRawSocket::Udp(udp_socket), observer) => (
self.sockets.disable_irq().lock().add(udp_socket),
SocketFamily::Udp,
observer,
),
};
let bound_socket = AnyBoundSocket::new(iface, handle, port, socket_family, observer);
self.insert_bound_socket(bound_socket.inner());
Ok(bound_socket)
}
/// Remove a socket from the interface
pub(super) fn remove_socket(&self, handle: SocketHandle) {
self.sockets.disable_irq().lock().remove(handle);
}
pub(super) fn poll<D: Device + ?Sized>(&self, device: &mut D) {
let mut sockets = self.sockets.disable_irq().lock();
let mut interface = self.interface.disable_irq().lock();
let timestamp = get_network_timestamp();
let (has_events, poll_at) = {
let mut has_events = false;
let mut poll_at;
loop {
// `poll` transmits and receives a bounded number of packets. This loop ensures
// that all packets are transmitted and received. For details, see
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L400-L405>.
while interface.poll(timestamp, device, &mut sockets) {
has_events = true;
}
// `poll_at` can return `Some(Instant::from_millis(0))`, which means `PollAt::Now`.
// For details, see
// <https://github.com/smoltcp-rs/smoltcp/blob/8e3ea5c7f09a76f0a4988fda20cadc74eacdc0d8/src/iface/interface/mod.rs#L478>.
poll_at = interface.poll_at(timestamp, &sockets);
let Some(instant) = poll_at else {
break;
};
if instant > timestamp {
break;
}
}
(has_events, poll_at)
};
// drop sockets here to avoid deadlock
drop(interface);
drop(sockets);
if let Some(instant) = poll_at {
let old_instant = self.next_poll_at_ms.load(Ordering::Relaxed);
let new_instant = instant.total_millis() as u64;
self.next_poll_at_ms.store(new_instant, Ordering::Relaxed);
if old_instant == 0 || new_instant < old_instant {
self.polling_wait_queue.wake_all();
}
} else {
self.next_poll_at_ms.store(0, Ordering::Relaxed);
}
if has_events {
// We never try to hold the write lock in the IRQ context, and we disable IRQ when
// holding the write lock. So we don't need to disable IRQ when holding the read lock.
self.bound_sockets.read().iter().for_each(|bound_socket| {
bound_socket.on_iface_events();
});
let closed_sockets = self
.closing_sockets
.disable_irq()
.lock()
.extract_if(|closing_socket| closing_socket.is_closed())
.collect::<Vec<_>>();
drop(closed_sockets);
}
}
pub(super) fn next_poll_at_ms(&self) -> Option<u64> {
let millis = self.next_poll_at_ms.load(Ordering::Relaxed);
if millis == 0 {
None
} else {
Some(millis)
}
}
fn insert_bound_socket(&self, socket: &Arc<AnyBoundSocketInner>) {
let keyable_socket = KeyableArc::from(socket.clone());
let inserted = self
.bound_sockets
.write_irq_disabled()
.insert(keyable_socket);
assert!(inserted);
}
pub(super) fn remove_bound_socket_now(&self, socket: &Arc<AnyBoundSocketInner>) {
let keyable_socket = KeyableArc::from(socket.clone());
let removed = self
.bound_sockets
.write_irq_disabled()
.remove(&keyable_socket);
assert!(removed);
}
pub(super) fn remove_bound_socket_when_closed(&self, socket: &Arc<AnyBoundSocketInner>) {
let keyable_socket = KeyableArc::from(socket.clone());
let removed = self
.bound_sockets
.write_irq_disabled()
.remove(&keyable_socket);
assert!(removed);
let mut closing_sockets = self.closing_sockets.disable_irq().lock();
// Check `is_closed` after holding the lock to avoid race conditions.
if keyable_socket.is_closed() {
return;
}
let inserted = closing_sockets.insert(keyable_socket);
assert!(inserted);
}
}
const IP_LOCAL_PORT_START: u16 = 49152;
const IP_LOCAL_PORT_END: u16 = 65535;

View File

@ -0,0 +1,78 @@
// SPDX-License-Identifier: MPL-2.0
use smoltcp::{
iface::Config,
phy::{Loopback, Medium},
wire::IpCidr,
};
use super::{common::IfaceCommon, internal::IfaceInternal, Iface};
use crate::{
net::{
iface::time::get_network_timestamp,
socket::ip::{IpAddress, Ipv4Address},
},
prelude::*,
};
pub const LOOPBACK_ADDRESS: IpAddress = {
let ipv4_addr = Ipv4Address::new(127, 0, 0, 1);
IpAddress::Ipv4(ipv4_addr)
};
pub const LOOPBACK_ADDRESS_PREFIX_LEN: u8 = 8; // mask: 255.0.0.0
pub struct IfaceLoopback {
driver: Mutex<Loopback>,
common: IfaceCommon,
weak_self: Weak<Self>,
}
impl IfaceLoopback {
pub fn new() -> Arc<Self> {
let mut loopback = Loopback::new(Medium::Ip);
let interface = {
let config = Config::new(smoltcp::wire::HardwareAddress::Ip);
let now = get_network_timestamp();
let mut interface = smoltcp::iface::Interface::new(config, &mut loopback, now);
interface.update_ip_addrs(|ip_addrs| {
debug_assert!(ip_addrs.is_empty());
let ip_addr = IpCidr::new(LOOPBACK_ADDRESS, LOOPBACK_ADDRESS_PREFIX_LEN);
ip_addrs.push(ip_addr).unwrap();
});
interface
};
println!("Loopback ipaddr: {}", interface.ipv4_addr().unwrap());
let common = IfaceCommon::new(interface);
Arc::new_cyclic(|weak| Self {
driver: Mutex::new(loopback),
common,
weak_self: weak.clone(),
})
}
}
impl IfaceInternal for IfaceLoopback {
fn common(&self) -> &IfaceCommon {
&self.common
}
fn arc_self(&self) -> Arc<dyn Iface> {
self.weak_self.upgrade().unwrap()
}
}
impl Iface for IfaceLoopback {
fn name(&self) -> &str {
"lo"
}
fn mac_addr(&self) -> Option<smoltcp::wire::EthernetAddress> {
None
}
fn poll(&self) {
let mut device = self.driver.lock();
self.common.poll(&mut *device);
}
}

View File

@ -0,0 +1,94 @@
// SPDX-License-Identifier: MPL-2.0
use ostd::sync::WaitQueue;
use smoltcp::iface::SocketSet;
use self::common::IfaceCommon;
use crate::prelude::*;
mod any_socket;
mod common;
mod loopback;
mod time;
mod util;
mod virtio;
pub use any_socket::{
AnyBoundSocket, AnyUnboundSocket, RawTcpSocket, RawUdpSocket, RECV_BUF_LEN, SEND_BUF_LEN,
};
pub use loopback::IfaceLoopback;
use ostd::sync::LocalIrqDisabled;
pub use smoltcp::wire::EthernetAddress;
pub use util::{spawn_background_poll_thread, BindPortConfig};
pub use virtio::IfaceVirtio;
use crate::net::socket::ip::Ipv4Address;
/// Network interface.
///
/// A network interface (abbreviated as iface) is a hardware or software component that connects a device or computer to a network.
/// Network interfaces can be physical components like Ethernet ports or wireless adapters,
/// or they can be virtual interfaces created by software such as virtual private network (VPN) connections.
pub trait Iface: internal::IfaceInternal + Send + Sync {
/// The iface name. For linux, usually the driver name followed by a unit number.
fn name(&self) -> &str;
/// The optional mac address
fn mac_addr(&self) -> Option<EthernetAddress>;
/// Transmit packets queued in the iface, and receive packets queued in the iface.
/// It any event happens, this function will also update socket status.
fn poll(&self);
/// Bind a socket to the iface. So the packet for this socket will be dealt with by the interface.
/// If port is None, the iface will pick up an empheral port for the socket.
/// FIXME: The reason for binding socket and interface together is because there are limitations inside smoltcp.
/// See discussion at <https://github.com/smoltcp-rs/smoltcp/issues/779>.
fn bind_socket(
&self,
socket: Box<AnyUnboundSocket>,
config: BindPortConfig,
) -> core::result::Result<AnyBoundSocket, (Error, Box<AnyUnboundSocket>)> {
let common = self.common();
common.bind_socket(self.arc_self(), socket, config)
}
/// The optional ipv4 address
/// FIXME: An interface indeed support multiple addresses
fn ipv4_addr(&self) -> Option<Ipv4Address> {
self.common().ipv4_addr()
}
/// The netmask.
/// FIXME: The netmask and IP address should be one-to-one if there are multiple ip address
fn netmask(&self) -> Option<Ipv4Address> {
self.common().netmask()
}
/// The waitqueue used to background polling thread
fn polling_wait_queue(&self) -> &WaitQueue {
self.common().polling_wait_queue()
}
}
mod internal {
use super::*;
/// A helper trait
pub trait IfaceInternal {
fn common(&self) -> &IfaceCommon;
/// The inner socket set
fn sockets(&self) -> SpinLockGuard<SocketSet<'static>, LocalIrqDisabled> {
self.common().sockets()
}
/// The inner iface.
fn iface_inner(&self) -> SpinLockGuard<smoltcp::iface::Interface, LocalIrqDisabled> {
self.common().interface()
}
/// The time we should do another poll.
fn next_poll_at_ms(&self) -> Option<u64> {
self.common().next_poll_at_ms()
}
fn arc_self(&self) -> Arc<dyn Iface>;
}
}

View File

@ -0,0 +1,8 @@
// SPDX-License-Identifier: MPL-2.0
use ostd::arch::timer::Jiffies;
pub(super) fn get_network_timestamp() -> smoltcp::time::Instant {
let millis = Jiffies::elapsed().as_duration().as_millis();
smoltcp::time::Instant::from_millis(millis as i64)
}

View File

@ -0,0 +1,88 @@
// SPDX-License-Identifier: MPL-2.0
use core::time::Duration;
use ostd::{arch::timer::Jiffies, task::Priority};
use super::Iface;
use crate::{
prelude::*,
thread::{
kernel_thread::{KernelThreadExt, ThreadOptions},
Thread,
},
time::wait::WaitTimeout,
};
pub enum BindPortConfig {
CanReuse(u16),
Specified(u16),
Ephemeral,
}
impl BindPortConfig {
pub fn new(port: u16, can_reuse: bool) -> Result<Self> {
let config = if port != 0 {
if can_reuse {
Self::CanReuse(port)
} else {
Self::Specified(port)
}
} else if can_reuse {
return_errno_with_message!(Errno::EINVAL, "invalid bind port config");
} else {
Self::Ephemeral
};
Ok(config)
}
pub(super) fn can_reuse(&self) -> bool {
matches!(self, Self::CanReuse(_))
}
pub(super) fn port(&self) -> Option<u16> {
match self {
Self::CanReuse(port) | Self::Specified(port) => Some(*port),
Self::Ephemeral => None,
}
}
}
pub fn spawn_background_poll_thread(iface: Arc<dyn Iface>) {
let task_fn = move || {
trace!("spawn background poll thread for {}", iface.name());
let wait_queue = iface.polling_wait_queue();
loop {
let next_poll_at_ms = if let Some(next_poll_at_ms) = iface.next_poll_at_ms() {
next_poll_at_ms
} else {
wait_queue.wait_until(|| iface.next_poll_at_ms())
};
let now_as_ms = Jiffies::elapsed().as_duration().as_millis() as u64;
// FIXME: Ideally, we should perform the `poll` just before `next_poll_at_ms`.
// However, this approach may result in a spinning busy loop
// if the `poll` operation yields no results.
// To mitigate this issue,
// we have opted to assign a high priority to the polling thread,
// ensuring that the `poll` runs as soon as possible.
// For a more in-depth discussion, please refer to the following link:
// <https://github.com/asterinas/asterinas/pull/630#discussion_r1496817030>.
if now_as_ms >= next_poll_at_ms {
iface.poll();
continue;
}
let duration = Duration::from_millis(next_poll_at_ms - now_as_ms);
wait_queue.wait_until_or_timeout(
// If `iface.next_poll_at_ms()` changes to an earlier time, we will end the waiting.
|| (iface.next_poll_at_ms()? < next_poll_at_ms).then_some(()),
&duration,
);
}
};
let options = ThreadOptions::new(task_fn).priority(Priority::high());
Thread::spawn_kernel_thread(options);
}

View File

@ -0,0 +1,127 @@
// SPDX-License-Identifier: MPL-2.0
use aster_network::AnyNetworkDevice;
use aster_virtio::device::network::DEVICE_NAME;
use ostd::sync::PreemptDisabled;
use smoltcp::{
iface::{Config, SocketHandle, SocketSet},
socket::dhcpv4,
wire::{self, IpCidr},
};
use super::{common::IfaceCommon, internal::IfaceInternal, time::get_network_timestamp, Iface};
use crate::prelude::*;
pub struct IfaceVirtio {
driver: Arc<SpinLock<dyn AnyNetworkDevice, PreemptDisabled>>,
common: IfaceCommon,
dhcp_handle: SocketHandle,
weak_self: Weak<Self>,
}
impl IfaceVirtio {
pub fn new() -> Arc<Self> {
let virtio_net = aster_network::get_device(DEVICE_NAME).unwrap();
let interface = {
let mac_addr = virtio_net.lock().mac_addr();
let ip_addr = IpCidr::new(wire::IpAddress::Ipv4(wire::Ipv4Address::UNSPECIFIED), 0);
let config = Config::new(wire::HardwareAddress::Ethernet(wire::EthernetAddress(
mac_addr.0,
)));
let now = get_network_timestamp();
let mut interface =
smoltcp::iface::Interface::new(config, &mut *virtio_net.lock(), now);
interface.update_ip_addrs(|ip_addrs| {
debug_assert!(ip_addrs.is_empty());
ip_addrs.push(ip_addr).unwrap();
});
interface
};
let common = IfaceCommon::new(interface);
let mut socket_set = common.sockets();
let dhcp_handle = init_dhcp_client(&mut socket_set);
drop(socket_set);
Arc::new_cyclic(|weak| Self {
driver: virtio_net,
common,
dhcp_handle,
weak_self: weak.clone(),
})
}
/// FIXME: Once we have user program dhcp client, we may remove dhcp logic from kernel.
pub fn process_dhcp(&self) {
let mut socket_set = self.common.sockets();
let dhcp_socket: &mut dhcpv4::Socket = socket_set.get_mut(self.dhcp_handle);
let config = if let Some(event) = dhcp_socket.poll() {
debug!("event = {:?}", event);
if let dhcpv4::Event::Configured(config) = event {
config
} else {
return;
}
} else {
return;
};
let ip_addr = IpCidr::Ipv4(config.address);
let mut interface = self.common.interface();
interface.update_ip_addrs(|ipaddrs| {
if let Some(addr) = ipaddrs.iter_mut().next() {
// already has ipaddrs
*addr = ip_addr
} else {
// does not has ip addr
ipaddrs.push(ip_addr).unwrap();
}
});
println!(
"DHCP update IP address: {:?}",
interface.ipv4_addr().unwrap()
);
if let Some(router) = config.router {
println!("Default router address: {:?}", router);
interface
.routes_mut()
.add_default_ipv4_route(router)
.unwrap();
}
}
}
impl IfaceInternal for IfaceVirtio {
fn common(&self) -> &IfaceCommon {
&self.common
}
fn arc_self(&self) -> Arc<dyn Iface> {
self.weak_self.upgrade().unwrap()
}
}
impl Iface for IfaceVirtio {
fn name(&self) -> &str {
"virtio"
}
fn mac_addr(&self) -> Option<smoltcp::wire::EthernetAddress> {
let interface = self.common.interface();
let hardware_addr = interface.hardware_addr();
match hardware_addr {
wire::HardwareAddress::Ethernet(ethe_address) => Some(ethe_address),
wire::HardwareAddress::Ip => None,
}
}
fn poll(&self) {
let mut driver = self.driver.disable_irq().lock();
self.common.poll(&mut *driver);
self.process_dhcp();
}
}
/// Register a dhcp socket.
fn init_dhcp_client(socket_set: &mut SocketSet) -> SocketHandle {
let dhcp_socket = dhcpv4::Socket::new();
socket_set.add(dhcp_socket)
}