diff --git a/kernel/src/net/iface/init.rs b/kernel/src/net/iface/init.rs index b12ed22f..1de4ec47 100644 --- a/kernel/src/net/iface/init.rs +++ b/kernel/src/net/iface/init.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 use alloc::{borrow::ToOwned, sync::Arc}; +use core::slice::Iter; use aster_bigtcp::{ device::WithDevice, @@ -12,7 +13,19 @@ use spin::Once; use super::{poll::poll_ifaces, Iface}; use crate::{net::iface::sched::PollScheduler, prelude::*}; -pub static IFACES: Once>> = Once::new(); +static IFACES: Once>> = Once::new(); + +pub fn loopback_iface() -> &'static Arc { + &IFACES.get().unwrap()[0] +} + +pub fn virtio_iface() -> Option<&'static Arc> { + IFACES.get().unwrap().get(1) +} + +pub fn iter_all_ifaces() -> Iter<'static, Arc> { + IFACES.get().unwrap().iter() +} pub fn init() { IFACES.call_once(|| { @@ -29,14 +42,13 @@ pub fn init() { ifaces }); - for (name, _) in aster_network::all_devices() { - let callback = || { + if let Some(iface_virtio) = virtio_iface() { + for (name, _) in aster_network::all_devices() { // TODO: further check that the irq num is the same as iface's irq num - let iface_virtio = &IFACES.get().unwrap()[0]; - iface_virtio.poll(); - }; - aster_network::register_recv_callback(&name, callback); - aster_network::register_send_callback(&name, callback); + let callback = || iface_virtio.poll(); + aster_network::register_recv_callback(&name, callback); + aster_network::register_send_callback(&name, callback); + } } poll_ifaces(); diff --git a/kernel/src/net/iface/mod.rs b/kernel/src/net/iface/mod.rs index 504738ae..6c9831fe 100644 --- a/kernel/src/net/iface/mod.rs +++ b/kernel/src/net/iface/mod.rs @@ -5,7 +5,7 @@ mod init; mod poll; mod sched; -pub use init::{init, IFACES}; +pub use init::{init, iter_all_ifaces, loopback_iface, virtio_iface}; pub use poll::lazy_init; pub type Iface = dyn aster_bigtcp::iface::Iface; diff --git a/kernel/src/net/iface/poll.rs b/kernel/src/net/iface/poll.rs index 33bcf38c..36fec958 100644 --- a/kernel/src/net/iface/poll.rs +++ b/kernel/src/net/iface/poll.rs @@ -6,7 +6,7 @@ use core::time::Duration; use log::trace; use ostd::timer::Jiffies; -use super::{Iface, IFACES}; +use super::{iter_all_ifaces, Iface}; use crate::{ sched::{Nice, SchedPolicy}, thread::kernel_thread::ThreadOptions, @@ -14,15 +14,13 @@ use crate::{ }; pub fn lazy_init() { - for iface in IFACES.get().unwrap() { + for iface in iter_all_ifaces() { spawn_background_poll_thread(iface.clone()); } } pub(super) fn poll_ifaces() { - let ifaces = IFACES.get().unwrap(); - - for iface in ifaces.iter() { + for iface in iter_all_ifaces() { iface.poll(); } } diff --git a/kernel/src/net/socket/ip/common.rs b/kernel/src/net/socket/ip/common.rs index ce3761a3..91970be2 100644 --- a/kernel/src/net/socket/ip/common.rs +++ b/kernel/src/net/socket/ip/common.rs @@ -7,15 +7,13 @@ use aster_bigtcp::{ }; use crate::{ - net::iface::{BoundPort, Iface, IFACES}, + net::iface::{iter_all_ifaces, loopback_iface, virtio_iface, BoundPort, Iface}, prelude::*, }; pub(super) fn get_iface_to_bind(ip_addr: &IpAddress) -> Option> { - let ifaces = IFACES.get().unwrap(); let IpAddress::Ipv4(ipv4_addr) = ip_addr; - ifaces - .iter() + iter_all_ifaces() .find(|iface| { if let Some(iface_ipv4_addr) = iface.ipv4_addr() { iface_ipv4_addr == *ipv4_addr @@ -30,9 +28,8 @@ pub(super) fn get_iface_to_bind(ip_addr: &IpAddress) -> Option> { /// If the remote address is the same as that of some iface, we will use the iface. /// Otherwise, we will use a default interface. fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc { - let ifaces = IFACES.get().unwrap(); let IpAddress::Ipv4(remote_ipv4_addr) = remote_ip_addr; - if let Some(iface) = ifaces.iter().find(|iface| { + if let Some(iface) = iter_all_ifaces().find(|iface| { if let Some(iface_ipv4_addr) = iface.ipv4_addr() { iface_ipv4_addr == *remote_ipv4_addr } else { @@ -41,8 +38,14 @@ fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc { }) { return iface.clone(); } - // FIXME: use the virtio-net as the default interface - ifaces[0].clone() + + // FIXME: Instead of hardcoding the rules here, we should choose the + // default interface according to the routing table. + if let Some(virtio_iface) = virtio_iface() { + virtio_iface.clone() + } else { + loopback_iface().clone() + } } pub(super) fn bind_port(endpoint: &IpEndpoint, can_reuse: bool) -> Result { diff --git a/kernel/src/net/socket/netlink/route/kernel/addr.rs b/kernel/src/net/socket/netlink/route/kernel/addr.rs index e3c4cd21..c97da97c 100644 --- a/kernel/src/net/socket/netlink/route/kernel/addr.rs +++ b/kernel/src/net/socket/netlink/route/kernel/addr.rs @@ -7,7 +7,7 @@ use core::num::NonZeroU32; use super::util::finish_response; use crate::{ net::{ - iface::{Iface, IFACES}, + iface::{iter_all_ifaces, Iface}, socket::netlink::{ message::{CMsgSegHdr, CSegmentType, GetRequestFlags, SegHdrCommonFlags}, route::message::{ @@ -28,9 +28,7 @@ pub(super) fn do_get_addr(request_segment: &AddrSegment) -> Result = ifaces - .iter() + let mut response_segments: Vec = iter_all_ifaces() // GETADDR only supports dump mode, so we're going to report all addresses. .filter_map(|iface| iface_to_new_addr(request_segment.header(), iface)) .map(RtnlSegment::NewAddr) diff --git a/kernel/src/net/socket/netlink/route/kernel/link.rs b/kernel/src/net/socket/netlink/route/kernel/link.rs index 786ac178..a070784b 100644 --- a/kernel/src/net/socket/netlink/route/kernel/link.rs +++ b/kernel/src/net/socket/netlink/route/kernel/link.rs @@ -9,7 +9,7 @@ use aster_bigtcp::iface::InterfaceType; use super::util::finish_response; use crate::{ net::{ - iface::{Iface, IFACES}, + iface::{iter_all_ifaces, Iface}, socket::netlink::{ message::{CMsgSegHdr, CSegmentType, GetRequestFlags, SegHdrCommonFlags}, route::message::{LinkAttr, LinkSegment, LinkSegmentBody, RtnlSegment}, @@ -22,9 +22,7 @@ use crate::{ pub(super) fn do_get_link(request_segment: &LinkSegment) -> Result> { let filter_by = FilterBy::from_request(request_segment)?; - let ifaces = IFACES.get().unwrap(); - let mut response_segments: Vec = ifaces - .iter() + let mut response_segments: Vec = iter_all_ifaces() // Filter to include only requested links. .filter(|iface| match &filter_by { FilterBy::Index(index) => *index == iface.index(),