diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index efec551c..69db894a 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -52,7 +52,7 @@ linkme = "=0.3.27" num = { version = "=0.4.0", default-features = false } num-derive = "=0.3" num-traits = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/num-traits.git", rev="1597c1c", default-features = false } -smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc", "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6"]} +smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc", "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6", "medium-ip"]} system_error = { path = "crates/system_error" } uefi = { version = "=0.26.0", features = ["alloc"] } uefi-raw = "=0.5.0" diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index 8a229eec..441f2f12 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -204,7 +204,7 @@ impl phy::Device for LoopbackDriver { let mut result = phy::DeviceCapabilities::default(); result.max_transmission_unit = 65535; result.max_burst_size = Some(1); - result.medium = smoltcp::phy::Medium::Ethernet; + result.medium = smoltcp::phy::Medium::Ip; return result; } /// ## Loopback驱动处理接受数据事件 @@ -284,9 +284,11 @@ impl LoopbackInterface { pub fn new(mut driver: LoopbackDriver) -> Arc { let iface_id = generate_iface_id(); - let hardware_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress([ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - ])); + // let hardware_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress([ + // 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // ])); + + let hardware_addr = HardwareAddress::Ip; let mut iface_config = smoltcp::iface::Config::new(hardware_addr); diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 80f09dd5..4b3cfe0f 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -285,6 +285,14 @@ impl IfaceCommon { self.bounds.write().push(socket); } + pub fn unbind_socket(&self, socket: Arc) { + let mut bounds = self.bounds.write(); + if let Some(index) = bounds.iter().position(|s| Arc::ptr_eq(s, &socket)) { + bounds.remove(index); + log::debug!("unbind socket success"); + } + } + // TODO: 需要在inet实现多网卡监听或路由子系统实现后移除 pub fn is_default_iface(&self) -> bool { self.default_iface diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index 142fbf84..9038e7de 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -1,4 +1,4 @@ -use alloc::{boxed::Box, collections::BTreeMap, sync::Arc}; +use alloc::{collections::BTreeMap, sync::Arc}; use log::{debug, info, warn}; use smoltcp::{socket::dhcpv4, wire}; use system_error::SystemError; @@ -7,45 +7,23 @@ use crate::{ driver::net::{Iface, Operstate}, libs::rwlock::RwLockReadGuard, net::NET_DEVICES, - time::{ - sleep::nanosleep, - timer::{next_n_ms_timer_jiffies, Timer, TimerFunction}, - PosixTimeSpec, - }, + time::{sleep::nanosleep, PosixTimeSpec}, }; -/// The network poll function, which will be called by timer. -/// -/// The main purpose of this function is to poll all network interfaces. -#[derive(Debug)] -#[allow(dead_code)] -struct NetWorkPollFunc; - -impl TimerFunction for NetWorkPollFunc { - fn run(&mut self) -> Result<(), SystemError> { - poll_ifaces(); - let next_time = next_n_ms_timer_jiffies(10); - let timer = Timer::new(Box::new(NetWorkPollFunc), next_time); - timer.activate(); - return Ok(()); - } -} - pub fn net_init() -> Result<(), SystemError> { - dhcp_query()?; - // Init poll timer function - // let next_time = next_n_ms_timer_jiffies(5); - // let timer = Timer::new(Box::new(NetWorkPollFunc), next_time); - // timer.activate(); - return Ok(()); + dhcp_query() } fn dhcp_query() -> Result<(), SystemError> { let binding = NET_DEVICES.write_irqsave(); - // log::debug!("binding: {:?}", *binding); - //由于现在os未实现在用户态为网卡动态分配内存,而lo网卡的id最先分配且ip固定不能被分配 - //所以特判取用id为1的网卡(也就是virtio_net) - let net_face = binding.get(&1).ok_or(SystemError::ENODEV)?.clone(); + + // Default iface, misspelled to net_face + let net_face = binding + .iter() + .find(|(_, iface)| iface.common().is_default_iface()) + .unwrap() + .1 + .clone(); drop(binding); @@ -60,8 +38,10 @@ fn dhcp_query() -> Result<(), SystemError> { let sockets = || net_face.sockets().lock_irqsave(); - // let dhcp_handle = SOCKET_SET.lock_irqsave().add(dhcp_socket); let dhcp_handle = sockets().add(dhcp_socket); + defer::defer!({ + sockets().remove(dhcp_handle); + }); const DHCP_TRY_ROUND: u8 = 100; for i in 0..DHCP_TRY_ROUND { @@ -147,7 +127,7 @@ fn dhcp_query() -> Result<(), SystemError> { } pub fn poll_ifaces() { - log::debug!("poll_ifaces"); + // log::debug!("poll_ifaces"); let guard: RwLockReadGuard>> = NET_DEVICES.read_irqsave(); if guard.len() == 0 { warn!("poll_ifaces: No net driver found!"); diff --git a/kernel/src/net/socket/common/shutdown.rs b/kernel/src/net/socket/common/shutdown.rs index 096cb43d..847a8a0d 100644 --- a/kernel/src/net/socket/common/shutdown.rs +++ b/kernel/src/net/socket/common/shutdown.rs @@ -124,7 +124,7 @@ impl TryFrom for ShutdownTemp { fn try_from(value: usize) -> Result { match value { - 0 | 1 | 2 => Ok(ShutdownTemp { + 0..2 => Ok(ShutdownTemp { bit: value as u8 + 1, }), _ => Err(SystemError::EINVAL), diff --git a/kernel/src/net/socket/inet/common/mod.rs b/kernel/src/net/socket/inet/common/mod.rs index 2fa69613..455503fb 100644 --- a/kernel/src/net/socket/inet/common/mod.rs +++ b/kernel/src/net/socket/inet/common/mod.rs @@ -53,11 +53,11 @@ impl BoundInner { }) .expect("No default interface"); - let handle = iface.sockets().lock_no_preempt().add(socket); + let handle = iface.sockets().lock_irqsave().add(socket); return Ok(Self { handle, iface }); } else { let iface = get_iface_to_bind(address).ok_or(ENODEV)?; - let handle = iface.sockets().lock_no_preempt().add(socket); + let handle = iface.sockets().lock_irqsave().add(socket); return Ok(Self { handle, iface }); } } diff --git a/kernel/src/net/socket/inet/datagram/inner.rs b/kernel/src/net/socket/inet/datagram/inner.rs index 8c964a64..cbd49a57 100644 --- a/kernel/src/net/socket/inet/datagram/inner.rs +++ b/kernel/src/net/socket/inet/datagram/inner.rs @@ -33,16 +33,14 @@ impl UnboundUdp { } pub fn bind(self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result { - // let (addr, port) = (local_endpoint.addr, local_endpoint.port); - // if self.socket.bind(local_endpoint).is_err() { - // log::debug!("bind failed!"); - // return Err(EINVAL); - // } let inner = BoundInner::bind(self.socket, &local_endpoint.addr)?; let bind_addr = local_endpoint.addr; let bind_port = if local_endpoint.port == 0 { inner.port_manager().bind_ephemeral_port(InetTypes::Udp)? } else { + inner + .port_manager() + .bind_port(InetTypes::Udp, local_endpoint.port)?; local_endpoint.port }; @@ -77,10 +75,6 @@ impl UnboundUdp { remote: SpinLock::new(Some(endpoint)), }) } - - pub fn close(&mut self) { - self.socket.close(); - } } #[derive(Debug)] diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index 44bb9bb2..dbf6a630 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -78,28 +78,31 @@ impl UdpSocket { bound.close(); inner.take(); } + // unbound socket just drop (only need to free memory) } pub fn try_recv( &self, buf: &mut [u8], ) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> { - let received = match self.inner.read().as_ref().expect("Udp Inner is None") { - UdpInner::Bound(bound) => bound.try_recv(buf), + match self.inner.read().as_ref().expect("Udp Inner is None") { + UdpInner::Bound(bound) => { + let ret = bound.try_recv(buf); + poll_ifaces(); + ret + } _ => Err(ENOTCONN), - }; - poll_ifaces(); - return received; + } } #[inline] pub fn can_recv(&self) -> bool { - self.on_events().contains(EP::EPOLLIN) + self.event().contains(EP::EPOLLIN) } #[inline] pub fn can_send(&self) -> bool { - self.on_events().contains(EP::EPOLLOUT) + self.event().contains(EP::EPOLLOUT) } pub fn try_send( @@ -138,7 +141,7 @@ impl UdpSocket { } } - pub fn on_events(&self) -> EPollEventType { + pub fn event(&self) -> EPollEventType { let mut event = EPollEventType::empty(); match self.inner.read().as_ref().unwrap() { UdpInner::Unbound(_) => { @@ -154,8 +157,6 @@ impl UdpSocket { if can_send { event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); - } else { - todo!("缓冲区空间不够,需要使用信号处理"); } } } @@ -169,7 +170,7 @@ impl Socket for UdpSocket { } fn poll(&self) -> usize { - self.on_events().bits() as usize + self.event().bits() as usize } fn bind(&self, local_endpoint: Endpoint) -> Result<(), SystemError> { @@ -195,7 +196,9 @@ impl Socket for UdpSocket { fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { if let Endpoint::Ip(remote) = endpoint { - self.bind_emphemeral(remote.addr)?; + if !self.is_bound() { + self.bind_emphemeral(remote.addr)?; + } if let UdpInner::Bound(inner) = self.inner.read().as_ref().expect("UDP Inner disappear") { inner.connect(remote); @@ -272,6 +275,11 @@ impl Socket for UdpSocket { } .map(|(len, remote)| (len, Endpoint::Ip(remote))); } + + fn close(&self) -> Result<(), SystemError> { + self.close(); + Ok(()) + } } impl InetSocket for UdpSocket { diff --git a/kernel/src/net/socket/inet/stream/inner.rs b/kernel/src/net/socket/inet/stream/inner.rs index 241f7b3c..5146a913 100644 --- a/kernel/src/net/socket/inet/stream/inner.rs +++ b/kernel/src/net/socket/inet/stream/inner.rs @@ -268,6 +268,15 @@ impl Connecting { .expect("A Connecting Tcp With No Local Endpoint") }) } + + pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint { + self.inner + .with::(|socket| { + socket + .remote_endpoint() + .expect("A Connecting Tcp With No Remote Endpoint") + }) + } } #[derive(Debug)] @@ -355,6 +364,13 @@ impl Listening { .port_manager() .unbind_port(Types::Tcp, port); } + + pub fn release(&self) { + // log::debug!("Release Listening Socket"); + for inner in self.inners.iter() { + inner.release(); + } + } } #[derive(Debug)] @@ -370,10 +386,6 @@ impl Established { self.inner.with_mut(f) } - pub fn with) -> R>(&self, f: F) -> R { - self.inner.with(f) - } - pub fn close(&self) { self.inner .with_mut::(|socket| socket.close()); @@ -384,13 +396,13 @@ impl Established { self.inner.release(); } - pub fn local_endpoint(&self) -> smoltcp::wire::IpEndpoint { + pub fn get_name(&self) -> smoltcp::wire::IpEndpoint { self.inner .with::(|socket| socket.local_endpoint()) .unwrap() } - pub fn remote_endpoint(&self) -> smoltcp::wire::IpEndpoint { + pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint { self.inner .with::(|socket| socket.remote_endpoint().unwrap()) } diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs index 3cde0925..97b56752 100644 --- a/kernel/src/net/socket/inet/stream/mod.rs +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -99,7 +99,6 @@ impl TcpSocket { } pub fn try_accept(&self) -> Result<(Arc, smoltcp::wire::IpEndpoint), SystemError> { - // poll_ifaces(); match self.inner.write().as_mut().expect("Tcp Inner is None") { Inner::Listening(listening) => listening.accept().map(|(stream, remote)| { ( @@ -227,16 +226,9 @@ impl TcpSocket { } } - fn in_notify(&self) -> bool { - self.update_events(); - // shouldn't pollee but just get the status of the socket + fn incoming(&self) -> bool { EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN) } - - fn out_notify(&self) -> bool { - self.update_events(); - EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLOUT) - } } impl Socket for TcpSocket { @@ -252,16 +244,25 @@ impl Socket for TcpSocket { })), Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)), Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())), - Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())), + Inner::Established(established) => Ok(Endpoint::Ip(established.get_name())), Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())), } } + fn get_peer_name(&self) -> Result { + match self.inner.read().as_ref().expect("Tcp Inner is None") { + Inner::Init(_) => Err(ENOTCONN), + Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_peer_name())), + Inner::Established(established) => Ok(Endpoint::Ip(established.get_peer_name())), + Inner::Listening(_) => Err(ENOTCONN), + } + } + fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> { if let Endpoint::Ip(addr) = endpoint { return self.do_bind(addr); } - log::warn!("TcpSocket::bind: invalid endpoint"); + log::debug!("TcpSocket::bind: invalid endpoint"); return Err(EINVAL); } @@ -295,7 +296,7 @@ impl Socket for TcpSocket { loop { match self.try_accept() { Err(EAGAIN_OR_EWOULDBLOCK) => { - wq_wait_event_interruptible!(self.wait_queue, self.in_notify(), {})?; + wq_wait_event_interruptible!(self.wait_queue, self.incoming(), {})?; } result => break result, } @@ -348,7 +349,15 @@ impl Socket for TcpSocket { } fn close(&self) -> Result<(), SystemError> { - let inner = self.inner.write().take().unwrap(); + let Some(inner) = self.inner.write().take() else { + log::warn!("TcpSocket::close: already closed, unexpected"); + return Ok(()); + }; + if let Some(iface) = inner.iface() { + iface + .common() + .unbind_socket(self.self_ref.upgrade().unwrap()); + } match inner { // complete connecting socket close logic @@ -356,22 +365,21 @@ impl Socket for TcpSocket { let conn = unsafe { conn.into_established() }; conn.close(); conn.release(); - Ok(()) } Inner::Established(es) => { es.close(); es.release(); - Ok(()) } Inner::Listening(ls) => { ls.close(); - Ok(()) + ls.release(); } Inner::Init(init) => { init.close(); - Ok(()) } - } + }; + + Ok(()) } fn set_option(&self, level: PSOL, name: usize, val: &[u8]) -> Result<(), SystemError> { diff --git a/kernel/src/net/socket/unix/seqpacket/mod.rs b/kernel/src/net/socket/unix/seqpacket/mod.rs index 1f44f29c..510c8ad7 100644 --- a/kernel/src/net/socket/unix/seqpacket/mod.rs +++ b/kernel/src/net/socket/unix/seqpacket/mod.rs @@ -290,19 +290,58 @@ impl Socket for SeqpacketSocket { self.shutdown.recv_shutdown(); self.shutdown.send_shutdown(); - let path = match self.get_name()? { + let endpoint = self.get_name()?; + let path = match &endpoint { Endpoint::Inode((_, path)) => path, + Endpoint::Unixpath((_, path)) => path, + Endpoint::Abspath((_, path)) => path, _ => return Err(SystemError::EINVAL), }; - //如果path是空的说明没有bind,不用释放相关映射资源 if path.is_empty() { return Ok(()); } - // TODO: 释放INODE_MAP相关资源 - // 尝试释放相关抽象地址资源 - let _ = remove_abs_addr(&path); + match &endpoint { + Endpoint::Unixpath((inode_id, _)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + inode_guard.remove(inode_id); + } + Endpoint::Inode((current_inode, current_path)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + // 遍历查找匹配的条目 + let target_entry = inode_guard + .iter() + .find(|(_, ep)| { + if let Endpoint::Inode((map_inode, map_path)) = ep { + // 通过指针相等性比较确保是同一对象 + Arc::ptr_eq(map_inode, current_inode) && map_path == current_path + } else { + log::debug!("not match"); + false + } + }) + .map(|(id, _)| *id); + + if let Some(id) = target_entry { + inode_guard.remove(&id).ok_or(SystemError::EINVAL)?; + } + } + Endpoint::Abspath((abshandle, _)) => { + let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave(); + abs_inode_map.remove(&abshandle.name()); + } + _ => { + log::error!("invalid endpoint type"); + return Err(SystemError::EINVAL); + } + } + + *self.inner.write() = Inner::Init(Init::new()); + self.wait_queue.wakeup(None); + + let _ = remove_abs_addr(path); + return Ok(()); } @@ -471,12 +510,12 @@ impl Socket for SeqpacketSocket { } fn send_buffer_size(&self) -> usize { - log::warn!("using default buffer size"); + // log::warn!("using default buffer size"); SeqpacketSocket::DEFAULT_BUF_SIZE } fn recv_buffer_size(&self) -> usize { - log::warn!("using default buffer size"); + // log::warn!("using default buffer size"); SeqpacketSocket::DEFAULT_BUF_SIZE } diff --git a/kernel/src/net/socket/unix/stream/mod.rs b/kernel/src/net/socket/unix/stream/mod.rs index b9ebc9dc..8057b295 100644 --- a/kernel/src/net/socket/unix/stream/mod.rs +++ b/kernel/src/net/socket/unix/stream/mod.rs @@ -322,20 +322,59 @@ impl Socket for StreamSocket { self.shutdown.recv_shutdown(); self.shutdown.send_shutdown(); - let path = match self.get_name()? { + let endpoint = self.get_name()?; + let path = match &endpoint { Endpoint::Inode((_, path)) => path, + Endpoint::Unixpath((_, path)) => path, + Endpoint::Abspath((_, path)) => path, _ => return Err(SystemError::EINVAL), }; - //如果path是空的说明没有bind,不用释放相关映射资源 if path.is_empty() { return Ok(()); } - // TODO: 释放INODE_MAP相关资源 - // 尝试释放相关抽象地址资源 - let _ = remove_abs_addr(&path); - return Ok(()); + match &endpoint { + Endpoint::Unixpath((inode_id, _)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + inode_guard.remove(inode_id); + } + Endpoint::Inode((current_inode, current_path)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + // 遍历查找匹配的条目 + let target_entry = inode_guard + .iter() + .find(|(_, ep)| { + if let Endpoint::Inode((map_inode, map_path)) = ep { + // 通过指针相等性比较确保是同一对象 + Arc::ptr_eq(map_inode, current_inode) && map_path == current_path + } else { + log::debug!("not match"); + false + } + }) + .map(|(id, _)| *id); + + if let Some(id) = target_entry { + inode_guard.remove(&id).ok_or(SystemError::EINVAL)?; + } + } + Endpoint::Abspath((abshandle, _)) => { + let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave(); + abs_inode_map.remove(&abshandle.name()); + } + _ => { + log::error!("invalid endpoint type"); + return Err(SystemError::EINVAL); + } + } + + *self.inner.write() = Inner::Init(Init::new()); + self.wait_queue.wakeup(None); + + let _ = remove_abs_addr(path); + + Ok(()) } fn get_peer_name(&self) -> Result { diff --git a/user/apps/test_unix_stream_socket/src/main.rs b/user/apps/test_unix_stream_socket/src/main.rs index 848e2209..ad184070 100644 --- a/user/apps/test_unix_stream_socket/src/main.rs +++ b/user/apps/test_unix_stream_socket/src/main.rs @@ -138,7 +138,9 @@ fn test_stream() -> Result<(), Error> { send_message(client_fd, MSG2).expect("Failed to send message"); println!("Server send finish"); + println!("Server begin close!"); unsafe { close(server_fd) }; + println!("Server close finish!"); }); let client_fd = create_stream_socket()?;