fix(net): misc of resources release (#1096)

* fix: TCP socket miss activation after close

* fix: TCP socket miss activation after close (#1085)

* fix: loopback, udp resource aquire
- remove tcp useless status update
- enable smoltcp medium-ip feature
- change loopback device use ip for addressing, avoid arp procedure
- fix udp couldn't close bug
- fix udp resource aquire didn't lock port
- remove useless Timer in network initialization

* fmt: format

* fix: loopback and udp resource problem (#1086)

* fix: loopback, udp resource aquire
- remove tcp useless status update
- enable smoltcp medium-ip feature
- change loopback device use ip for addressing, avoid arp procedure
- fix udp couldn't close bug
- fix udp resource aquire didn't lock port
- remove useless Timer in network initialization

* fix(net): Unix 资源释放 (#1087)

* unix socket 相关资源释放 #991
* 完善streamsocket资源释放
* 解决inode和id不匹配

* fix TCP socketset release (#1095)

* fix: TCP socket miss activation after close

* fix: loopback, udp resource aquire
- remove tcp useless status update
- enable smoltcp medium-ip feature
- change loopback device use ip for addressing, avoid arp procedure
- fix udp couldn't close bug
- fix udp resource aquire didn't lock port
- remove useless Timer in network initialization

---------

Co-authored-by: YuLong Huang <139891737+LINGLUO00@users.noreply.github.com>
This commit is contained in:
Samuel Dai 2025-03-10 12:58:39 +08:00 committed by GitHub
parent c4c35ed0cc
commit 69dde46586
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 193 additions and 101 deletions

View File

@ -52,7 +52,7 @@ linkme = "=0.3.27"
num = { version = "=0.4.0", default-features = false }
num-derive = "=0.3"
num-traits = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/num-traits.git", rev="1597c1c", default-features = false }
smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc", "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6"]}
smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc", "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6", "medium-ip"]}
system_error = { path = "crates/system_error" }
uefi = { version = "=0.26.0", features = ["alloc"] }
uefi-raw = "=0.5.0"

View File

@ -204,7 +204,7 @@ impl phy::Device for LoopbackDriver {
let mut result = phy::DeviceCapabilities::default();
result.max_transmission_unit = 65535;
result.max_burst_size = Some(1);
result.medium = smoltcp::phy::Medium::Ethernet;
result.medium = smoltcp::phy::Medium::Ip;
return result;
}
/// ## Loopback驱动处理接受数据事件
@ -284,9 +284,11 @@ impl LoopbackInterface {
pub fn new(mut driver: LoopbackDriver) -> Arc<Self> {
let iface_id = generate_iface_id();
let hardware_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress([
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
]));
// let hardware_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress([
// 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// ]));
let hardware_addr = HardwareAddress::Ip;
let mut iface_config = smoltcp::iface::Config::new(hardware_addr);

View File

@ -285,6 +285,14 @@ impl IfaceCommon {
self.bounds.write().push(socket);
}
pub fn unbind_socket(&self, socket: Arc<dyn InetSocket>) {
let mut bounds = self.bounds.write();
if let Some(index) = bounds.iter().position(|s| Arc::ptr_eq(s, &socket)) {
bounds.remove(index);
log::debug!("unbind socket success");
}
}
// TODO: 需要在inet实现多网卡监听或路由子系统实现后移除
pub fn is_default_iface(&self) -> bool {
self.default_iface

View File

@ -1,4 +1,4 @@
use alloc::{boxed::Box, collections::BTreeMap, sync::Arc};
use alloc::{collections::BTreeMap, sync::Arc};
use log::{debug, info, warn};
use smoltcp::{socket::dhcpv4, wire};
use system_error::SystemError;
@ -7,45 +7,23 @@ use crate::{
driver::net::{Iface, Operstate},
libs::rwlock::RwLockReadGuard,
net::NET_DEVICES,
time::{
sleep::nanosleep,
timer::{next_n_ms_timer_jiffies, Timer, TimerFunction},
PosixTimeSpec,
},
time::{sleep::nanosleep, PosixTimeSpec},
};
/// The network poll function, which will be called by timer.
///
/// The main purpose of this function is to poll all network interfaces.
#[derive(Debug)]
#[allow(dead_code)]
struct NetWorkPollFunc;
impl TimerFunction for NetWorkPollFunc {
fn run(&mut self) -> Result<(), SystemError> {
poll_ifaces();
let next_time = next_n_ms_timer_jiffies(10);
let timer = Timer::new(Box::new(NetWorkPollFunc), next_time);
timer.activate();
return Ok(());
}
}
pub fn net_init() -> Result<(), SystemError> {
dhcp_query()?;
// Init poll timer function
// let next_time = next_n_ms_timer_jiffies(5);
// let timer = Timer::new(Box::new(NetWorkPollFunc), next_time);
// timer.activate();
return Ok(());
dhcp_query()
}
fn dhcp_query() -> Result<(), SystemError> {
let binding = NET_DEVICES.write_irqsave();
// log::debug!("binding: {:?}", *binding);
//由于现在os未实现在用户态为网卡动态分配内存而lo网卡的id最先分配且ip固定不能被分配
//所以特判取用id为1的网卡也就是virtio_net
let net_face = binding.get(&1).ok_or(SystemError::ENODEV)?.clone();
// Default iface, misspelled to net_face
let net_face = binding
.iter()
.find(|(_, iface)| iface.common().is_default_iface())
.unwrap()
.1
.clone();
drop(binding);
@ -60,8 +38,10 @@ fn dhcp_query() -> Result<(), SystemError> {
let sockets = || net_face.sockets().lock_irqsave();
// let dhcp_handle = SOCKET_SET.lock_irqsave().add(dhcp_socket);
let dhcp_handle = sockets().add(dhcp_socket);
defer::defer!({
sockets().remove(dhcp_handle);
});
const DHCP_TRY_ROUND: u8 = 100;
for i in 0..DHCP_TRY_ROUND {
@ -147,7 +127,7 @@ fn dhcp_query() -> Result<(), SystemError> {
}
pub fn poll_ifaces() {
log::debug!("poll_ifaces");
// log::debug!("poll_ifaces");
let guard: RwLockReadGuard<BTreeMap<usize, Arc<dyn Iface>>> = NET_DEVICES.read_irqsave();
if guard.len() == 0 {
warn!("poll_ifaces: No net driver found!");

View File

@ -124,7 +124,7 @@ impl TryFrom<usize> for ShutdownTemp {
fn try_from(value: usize) -> Result<Self, Self::Error> {
match value {
0 | 1 | 2 => Ok(ShutdownTemp {
0..2 => Ok(ShutdownTemp {
bit: value as u8 + 1,
}),
_ => Err(SystemError::EINVAL),

View File

@ -53,11 +53,11 @@ impl BoundInner {
})
.expect("No default interface");
let handle = iface.sockets().lock_no_preempt().add(socket);
let handle = iface.sockets().lock_irqsave().add(socket);
return Ok(Self { handle, iface });
} else {
let iface = get_iface_to_bind(address).ok_or(ENODEV)?;
let handle = iface.sockets().lock_no_preempt().add(socket);
let handle = iface.sockets().lock_irqsave().add(socket);
return Ok(Self { handle, iface });
}
}

View File

@ -33,16 +33,14 @@ impl UnboundUdp {
}
pub fn bind(self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<BoundUdp, SystemError> {
// let (addr, port) = (local_endpoint.addr, local_endpoint.port);
// if self.socket.bind(local_endpoint).is_err() {
// log::debug!("bind failed!");
// return Err(EINVAL);
// }
let inner = BoundInner::bind(self.socket, &local_endpoint.addr)?;
let bind_addr = local_endpoint.addr;
let bind_port = if local_endpoint.port == 0 {
inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?
} else {
inner
.port_manager()
.bind_port(InetTypes::Udp, local_endpoint.port)?;
local_endpoint.port
};
@ -77,10 +75,6 @@ impl UnboundUdp {
remote: SpinLock::new(Some(endpoint)),
})
}
pub fn close(&mut self) {
self.socket.close();
}
}
#[derive(Debug)]

View File

@ -78,28 +78,31 @@ impl UdpSocket {
bound.close();
inner.take();
}
// unbound socket just drop (only need to free memory)
}
pub fn try_recv(
&self,
buf: &mut [u8],
) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> {
let received = match self.inner.read().as_ref().expect("Udp Inner is None") {
UdpInner::Bound(bound) => bound.try_recv(buf),
match self.inner.read().as_ref().expect("Udp Inner is None") {
UdpInner::Bound(bound) => {
let ret = bound.try_recv(buf);
poll_ifaces();
ret
}
_ => Err(ENOTCONN),
};
poll_ifaces();
return received;
}
}
#[inline]
pub fn can_recv(&self) -> bool {
self.on_events().contains(EP::EPOLLIN)
self.event().contains(EP::EPOLLIN)
}
#[inline]
pub fn can_send(&self) -> bool {
self.on_events().contains(EP::EPOLLOUT)
self.event().contains(EP::EPOLLOUT)
}
pub fn try_send(
@ -138,7 +141,7 @@ impl UdpSocket {
}
}
pub fn on_events(&self) -> EPollEventType {
pub fn event(&self) -> EPollEventType {
let mut event = EPollEventType::empty();
match self.inner.read().as_ref().unwrap() {
UdpInner::Unbound(_) => {
@ -154,8 +157,6 @@ impl UdpSocket {
if can_send {
event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND);
} else {
todo!("缓冲区空间不够,需要使用信号处理");
}
}
}
@ -169,7 +170,7 @@ impl Socket for UdpSocket {
}
fn poll(&self) -> usize {
self.on_events().bits() as usize
self.event().bits() as usize
}
fn bind(&self, local_endpoint: Endpoint) -> Result<(), SystemError> {
@ -195,7 +196,9 @@ impl Socket for UdpSocket {
fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(remote) = endpoint {
self.bind_emphemeral(remote.addr)?;
if !self.is_bound() {
self.bind_emphemeral(remote.addr)?;
}
if let UdpInner::Bound(inner) = self.inner.read().as_ref().expect("UDP Inner disappear")
{
inner.connect(remote);
@ -272,6 +275,11 @@ impl Socket for UdpSocket {
}
.map(|(len, remote)| (len, Endpoint::Ip(remote)));
}
fn close(&self) -> Result<(), SystemError> {
self.close();
Ok(())
}
}
impl InetSocket for UdpSocket {

View File

@ -268,6 +268,15 @@ impl Connecting {
.expect("A Connecting Tcp With No Local Endpoint")
})
}
pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint {
self.inner
.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
socket
.remote_endpoint()
.expect("A Connecting Tcp With No Remote Endpoint")
})
}
}
#[derive(Debug)]
@ -355,6 +364,13 @@ impl Listening {
.port_manager()
.unbind_port(Types::Tcp, port);
}
pub fn release(&self) {
// log::debug!("Release Listening Socket");
for inner in self.inners.iter() {
inner.release();
}
}
}
#[derive(Debug)]
@ -370,10 +386,6 @@ impl Established {
self.inner.with_mut(f)
}
pub fn with<R, F: Fn(&smoltcp::socket::tcp::Socket<'static>) -> R>(&self, f: F) -> R {
self.inner.with(f)
}
pub fn close(&self) {
self.inner
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.close());
@ -384,13 +396,13 @@ impl Established {
self.inner.release();
}
pub fn local_endpoint(&self) -> smoltcp::wire::IpEndpoint {
pub fn get_name(&self) -> smoltcp::wire::IpEndpoint {
self.inner
.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.local_endpoint())
.unwrap()
}
pub fn remote_endpoint(&self) -> smoltcp::wire::IpEndpoint {
pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint {
self.inner
.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| socket.remote_endpoint().unwrap())
}

View File

@ -99,7 +99,6 @@ impl TcpSocket {
}
pub fn try_accept(&self) -> Result<(Arc<TcpSocket>, smoltcp::wire::IpEndpoint), SystemError> {
// poll_ifaces();
match self.inner.write().as_mut().expect("Tcp Inner is None") {
Inner::Listening(listening) => listening.accept().map(|(stream, remote)| {
(
@ -227,16 +226,9 @@ impl TcpSocket {
}
}
fn in_notify(&self) -> bool {
self.update_events();
// shouldn't pollee but just get the status of the socket
fn incoming(&self) -> bool {
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN)
}
fn out_notify(&self) -> bool {
self.update_events();
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLOUT)
}
}
impl Socket for TcpSocket {
@ -252,16 +244,25 @@ impl Socket for TcpSocket {
})),
Inner::Init(Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)),
Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())),
Inner::Established(established) => Ok(Endpoint::Ip(established.local_endpoint())),
Inner::Established(established) => Ok(Endpoint::Ip(established.get_name())),
Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())),
}
}
fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
match self.inner.read().as_ref().expect("Tcp Inner is None") {
Inner::Init(_) => Err(ENOTCONN),
Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_peer_name())),
Inner::Established(established) => Ok(Endpoint::Ip(established.get_peer_name())),
Inner::Listening(_) => Err(ENOTCONN),
}
}
fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
if let Endpoint::Ip(addr) = endpoint {
return self.do_bind(addr);
}
log::warn!("TcpSocket::bind: invalid endpoint");
log::debug!("TcpSocket::bind: invalid endpoint");
return Err(EINVAL);
}
@ -295,7 +296,7 @@ impl Socket for TcpSocket {
loop {
match self.try_accept() {
Err(EAGAIN_OR_EWOULDBLOCK) => {
wq_wait_event_interruptible!(self.wait_queue, self.in_notify(), {})?;
wq_wait_event_interruptible!(self.wait_queue, self.incoming(), {})?;
}
result => break result,
}
@ -348,7 +349,15 @@ impl Socket for TcpSocket {
}
fn close(&self) -> Result<(), SystemError> {
let inner = self.inner.write().take().unwrap();
let Some(inner) = self.inner.write().take() else {
log::warn!("TcpSocket::close: already closed, unexpected");
return Ok(());
};
if let Some(iface) = inner.iface() {
iface
.common()
.unbind_socket(self.self_ref.upgrade().unwrap());
}
match inner {
// complete connecting socket close logic
@ -356,22 +365,21 @@ impl Socket for TcpSocket {
let conn = unsafe { conn.into_established() };
conn.close();
conn.release();
Ok(())
}
Inner::Established(es) => {
es.close();
es.release();
Ok(())
}
Inner::Listening(ls) => {
ls.close();
Ok(())
ls.release();
}
Inner::Init(init) => {
init.close();
Ok(())
}
}
};
Ok(())
}
fn set_option(&self, level: PSOL, name: usize, val: &[u8]) -> Result<(), SystemError> {

View File

@ -290,19 +290,58 @@ impl Socket for SeqpacketSocket {
self.shutdown.recv_shutdown();
self.shutdown.send_shutdown();
let path = match self.get_name()? {
let endpoint = self.get_name()?;
let path = match &endpoint {
Endpoint::Inode((_, path)) => path,
Endpoint::Unixpath((_, path)) => path,
Endpoint::Abspath((_, path)) => path,
_ => return Err(SystemError::EINVAL),
};
//如果path是空的说明没有bind不用释放相关映射资源
if path.is_empty() {
return Ok(());
}
// TODO: 释放INODE_MAP相关资源
// 尝试释放相关抽象地址资源
let _ = remove_abs_addr(&path);
match &endpoint {
Endpoint::Unixpath((inode_id, _)) => {
let mut inode_guard = INODE_MAP.write_irqsave();
inode_guard.remove(inode_id);
}
Endpoint::Inode((current_inode, current_path)) => {
let mut inode_guard = INODE_MAP.write_irqsave();
// 遍历查找匹配的条目
let target_entry = inode_guard
.iter()
.find(|(_, ep)| {
if let Endpoint::Inode((map_inode, map_path)) = ep {
// 通过指针相等性比较确保是同一对象
Arc::ptr_eq(map_inode, current_inode) && map_path == current_path
} else {
log::debug!("not match");
false
}
})
.map(|(id, _)| *id);
if let Some(id) = target_entry {
inode_guard.remove(&id).ok_or(SystemError::EINVAL)?;
}
}
Endpoint::Abspath((abshandle, _)) => {
let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave();
abs_inode_map.remove(&abshandle.name());
}
_ => {
log::error!("invalid endpoint type");
return Err(SystemError::EINVAL);
}
}
*self.inner.write() = Inner::Init(Init::new());
self.wait_queue.wakeup(None);
let _ = remove_abs_addr(path);
return Ok(());
}
@ -471,12 +510,12 @@ impl Socket for SeqpacketSocket {
}
fn send_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
// log::warn!("using default buffer size");
SeqpacketSocket::DEFAULT_BUF_SIZE
}
fn recv_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
// log::warn!("using default buffer size");
SeqpacketSocket::DEFAULT_BUF_SIZE
}

View File

@ -322,20 +322,59 @@ impl Socket for StreamSocket {
self.shutdown.recv_shutdown();
self.shutdown.send_shutdown();
let path = match self.get_name()? {
let endpoint = self.get_name()?;
let path = match &endpoint {
Endpoint::Inode((_, path)) => path,
Endpoint::Unixpath((_, path)) => path,
Endpoint::Abspath((_, path)) => path,
_ => return Err(SystemError::EINVAL),
};
//如果path是空的说明没有bind不用释放相关映射资源
if path.is_empty() {
return Ok(());
}
// TODO: 释放INODE_MAP相关资源
// 尝试释放相关抽象地址资源
let _ = remove_abs_addr(&path);
return Ok(());
match &endpoint {
Endpoint::Unixpath((inode_id, _)) => {
let mut inode_guard = INODE_MAP.write_irqsave();
inode_guard.remove(inode_id);
}
Endpoint::Inode((current_inode, current_path)) => {
let mut inode_guard = INODE_MAP.write_irqsave();
// 遍历查找匹配的条目
let target_entry = inode_guard
.iter()
.find(|(_, ep)| {
if let Endpoint::Inode((map_inode, map_path)) = ep {
// 通过指针相等性比较确保是同一对象
Arc::ptr_eq(map_inode, current_inode) && map_path == current_path
} else {
log::debug!("not match");
false
}
})
.map(|(id, _)| *id);
if let Some(id) = target_entry {
inode_guard.remove(&id).ok_or(SystemError::EINVAL)?;
}
}
Endpoint::Abspath((abshandle, _)) => {
let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave();
abs_inode_map.remove(&abshandle.name());
}
_ => {
log::error!("invalid endpoint type");
return Err(SystemError::EINVAL);
}
}
*self.inner.write() = Inner::Init(Init::new());
self.wait_queue.wakeup(None);
let _ = remove_abs_addr(path);
Ok(())
}
fn get_peer_name(&self) -> Result<Endpoint, SystemError> {

View File

@ -138,7 +138,9 @@ fn test_stream() -> Result<(), Error> {
send_message(client_fd, MSG2).expect("Failed to send message");
println!("Server send finish");
println!("Server begin close!");
unsafe { close(server_fd) };
println!("Server close finish!");
});
let client_fd = create_stream_socket()?;