fix: tcp poll没有正确处理posix socket的listen状态的问题 (#859)

This commit is contained in:
LoGin 2024-07-24 18:21:39 +08:00 committed by GitHub
parent 79ad6e5ba4
commit 634349e0eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 222 additions and 122 deletions

View File

@ -436,6 +436,7 @@ impl EventPoll {
} }
// 判断epoll上有没有就绪事件 // 判断epoll上有没有就绪事件
let mut available = epoll_guard.ep_events_available(); let mut available = epoll_guard.ep_events_available();
drop(epoll_guard); drop(epoll_guard);
loop { loop {
if available { if available {
@ -759,6 +760,7 @@ impl EventPoll {
/// 与C兼容的Epoll事件结构体 /// 与C兼容的Epoll事件结构体
#[derive(Copy, Clone, Default)] #[derive(Copy, Clone, Default)]
#[repr(packed)] #[repr(packed)]
#[repr(C)]
pub struct EPollEvent { pub struct EPollEvent {
/// 表示触发的事件 /// 表示触发的事件
events: u32, events: u32,
@ -870,5 +872,8 @@ bitflags! {
/// 表示epoll已经被释放但是在目前的设计中未用到 /// 表示epoll已经被释放但是在目前的设计中未用到
const POLLFREE = 0x4000; const POLLFREE = 0x4000;
/// listen状态的socket可以接受连接
const EPOLL_LISTEN_CAN_ACCEPT = Self::EPOLLIN.bits | Self::EPOLLRDNORM.bits;
} }
} }

View File

@ -191,25 +191,25 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
for (handle, socket_type) in sockets.iter() { for (handle, socket_type) in sockets.iter() {
let handle_guard = HANDLE_MAP.read_irqsave(); let handle_guard = HANDLE_MAP.read_irqsave();
let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle); let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle);
let item = handle_guard.get(&global_handle); let item: Option<&super::socket::SocketHandleItem> = handle_guard.get(&global_handle);
if item.is_none() { if item.is_none() {
continue; continue;
} }
let handle_item = item.unwrap(); let handle_item = item.unwrap();
let posix_item = handle_item.posix_item();
if posix_item.is_none() {
continue;
}
let posix_item = posix_item.unwrap();
// 获取socket上的事件 // 获取socket上的事件
let mut events = let mut events = SocketPollMethod::poll(socket_type, handle_item).bits() as u64;
SocketPollMethod::poll(socket_type, handle_item.shutdown_type()).bits() as u64;
// 分发到相应类型socket处理 // 分发到相应类型socket处理
match socket_type { match socket_type {
smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => { smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => {
handle_guard posix_item.wakeup_any(events);
.get(&global_handle)
.unwrap()
.wait_queue
.wakeup_any(events);
} }
smoltcp::socket::Socket::Icmp(_) => unimplemented!("Icmp socket hasn't unimplemented"), smoltcp::socket::Socket::Icmp(_) => unimplemented!("Icmp socket hasn't unimplemented"),
smoltcp::socket::Socket::Tcp(inner_socket) => { smoltcp::socket::Socket::Tcp(inner_socket) => {
@ -222,17 +222,14 @@ fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> {
if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait { if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait {
events |= EPollEventType::EPOLLHUP.bits() as u64; events |= EPollEventType::EPOLLHUP.bits() as u64;
} }
handle_guard
.get(&global_handle) posix_item.wakeup_any(events);
.unwrap()
.wait_queue
.wakeup_any(events);
} }
smoltcp::socket::Socket::Dhcpv4(_) => {} smoltcp::socket::Socket::Dhcpv4(_) => {}
smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"), smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"),
} }
EventPoll::wakeup_epoll( EventPoll::wakeup_epoll(
&handle_item.epitems, &posix_item.epitems,
EPollEventType::from_bits_truncate(events as u32), EPollEventType::from_bits_truncate(events as u32),
)?; )?;
drop(handle_guard); drop(handle_guard);

View File

@ -16,8 +16,8 @@ use crate::{
}; };
use super::{ use super::{
handle::GlobalSocketHandle, Socket, SocketHandleItem, SocketMetadata, SocketOptions, handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketHandleItem, SocketMetadata,
SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, SocketOptions, SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET,
}; };
/// @brief 表示原始的socket。原始套接字绕过传输层协议如 TCP 或 UDP并提供对网络层协议如 IP的直接访问。 /// @brief 表示原始的socket。原始套接字绕过传输层协议如 TCP 或 UDP并提供对网络层协议如 IP的直接访问。
@ -32,6 +32,7 @@ pub struct RawSocket {
header_included: bool, header_included: bool,
/// socket的metadata /// socket的metadata
metadata: SocketMetadata, metadata: SocketMetadata,
posix_item: Arc<PosixSocketHandleItem>,
} }
impl RawSocket { impl RawSocket {
@ -76,15 +77,22 @@ impl RawSocket {
options, options,
); );
let posix_item = Arc::new(PosixSocketHandleItem::new(None));
return Self { return Self {
handle, handle,
header_included: false, header_included: false,
metadata, metadata,
posix_item,
}; };
} }
} }
impl Socket for RawSocket { impl Socket for RawSocket {
fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
self.posix_item.clone()
}
fn close(&mut self) { fn close(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
if let smoltcp::socket::Socket::Udp(mut sock) = if let smoltcp::socket::Socket::Udp(mut sock) =
@ -123,11 +131,7 @@ impl Socket for RawSocket {
} }
} }
drop(socket_set_guard); drop(socket_set_guard);
SocketHandleItem::sleep( self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64);
self.socket_handle(),
EPollEventType::EPOLLIN.bits() as u64,
HANDLE_MAP.read_irqsave(),
);
} }
} }
@ -240,6 +244,7 @@ pub struct UdpSocket {
pub handle: GlobalSocketHandle, pub handle: GlobalSocketHandle,
remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect() 应该使用IP地址。 remote_endpoint: Option<Endpoint>, // 记录远程endpoint提供给connect() 应该使用IP地址。
metadata: SocketMetadata, metadata: SocketMetadata,
posix_item: Arc<PosixSocketHandleItem>,
} }
impl UdpSocket { impl UdpSocket {
@ -278,10 +283,13 @@ impl UdpSocket {
options, options,
); );
let posix_item = Arc::new(PosixSocketHandleItem::new(None));
return Self { return Self {
handle, handle,
remote_endpoint: None, remote_endpoint: None,
metadata, metadata,
posix_item,
}; };
} }
@ -311,6 +319,10 @@ impl UdpSocket {
} }
impl Socket for UdpSocket { impl Socket for UdpSocket {
fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
self.posix_item.clone()
}
fn close(&mut self) { fn close(&mut self) {
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
if let smoltcp::socket::Socket::Udp(mut sock) = if let smoltcp::socket::Socket::Udp(mut sock) =
@ -344,11 +356,7 @@ impl Socket for UdpSocket {
// return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
} }
drop(socket_set_guard); drop(socket_set_guard);
SocketHandleItem::sleep( self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64);
self.socket_handle(),
EPollEventType::EPOLLIN.bits() as u64,
HANDLE_MAP.read_irqsave(),
);
} }
} }
@ -484,6 +492,7 @@ pub struct TcpSocket {
local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind() local_endpoint: Option<wire::IpEndpoint>, // save local endpoint for bind()
is_listening: bool, is_listening: bool,
metadata: SocketMetadata, metadata: SocketMetadata,
posix_item: Arc<PosixSocketHandleItem>,
} }
impl TcpSocket { impl TcpSocket {
@ -516,6 +525,7 @@ impl TcpSocket {
Self::DEFAULT_METADATA_BUF_SIZE, Self::DEFAULT_METADATA_BUF_SIZE,
options, options,
); );
let posix_item = Arc::new(PosixSocketHandleItem::new(None));
// debug!("when there's a new tcp socket,its'len: {}",handles.len()); // debug!("when there's a new tcp socket,its'len: {}",handles.len());
return Self { return Self {
@ -523,6 +533,7 @@ impl TcpSocket {
local_endpoint: None, local_endpoint: None,
is_listening: false, is_listening: false,
metadata, metadata,
posix_item,
}; };
} }
@ -532,10 +543,8 @@ impl TcpSocket {
local_endpoint: wire::IpEndpoint, local_endpoint: wire::IpEndpoint,
) -> Result<(), SystemError> { ) -> Result<(), SystemError> {
let listen_result = if local_endpoint.addr.is_unspecified() { let listen_result = if local_endpoint.addr.is_unspecified() {
// debug!("Tcp Socket Listen on port {}", local_endpoint.port);
socket.listen(local_endpoint.port) socket.listen(local_endpoint.port)
} else { } else {
// debug!("Tcp Socket Listen on {local_endpoint}");
socket.listen(local_endpoint) socket.listen(local_endpoint)
}; };
return match listen_result { return match listen_result {
@ -561,9 +570,33 @@ impl TcpSocket {
let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]);
tcp::Socket::new(rx_buffer, tx_buffer) tcp::Socket::new(rx_buffer, tx_buffer)
} }
/// listening状态的posix socket是需要特殊处理的
fn tcp_poll_listening(&self) -> EPollEventType {
let socketset_guard = SOCKET_SET.lock_irqsave();
let can_accept = self.handles.iter().any(|h| {
if let Some(sh) = h.smoltcp_handle() {
let socket = socketset_guard.get::<tcp::Socket>(sh);
socket.is_active()
} else {
false
}
});
if can_accept {
return EPollEventType::EPOLL_LISTEN_CAN_ACCEPT;
} else {
return EPollEventType::empty();
}
}
} }
impl Socket for TcpSocket { impl Socket for TcpSocket {
fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
self.posix_item.clone()
}
fn close(&mut self) { fn close(&mut self) {
for handle in self.handles.iter() { for handle in self.handles.iter() {
{ {
@ -641,11 +674,8 @@ impl Socket for TcpSocket {
return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); return (Err(SystemError::ENOTCONN), Endpoint::Ip(None));
} }
drop(socket_set_guard); drop(socket_set_guard);
SocketHandleItem::sleep( self.posix_item
self.socket_handle(), .sleep((EPollEventType::EPOLLIN | EPollEventType::EPOLLHUP).bits() as u64);
(EPollEventType::EPOLLIN.bits() | EPollEventType::EPOLLHUP.bits()) as u64,
HANDLE_MAP.read_irqsave(),
);
} }
} }
@ -688,24 +718,31 @@ impl Socket for TcpSocket {
} }
fn poll(&self) -> EPollEventType { fn poll(&self) -> EPollEventType {
// 处理listen的快速路径
if self.is_listening {
return self.tcp_poll_listening();
}
// 由于上面处理了listening状态所以这里只处理非listening状态这种情况下只有一个handle
assert!(self.handles.len() == 1);
let mut socket_set_guard = SOCKET_SET.lock_irqsave(); let mut socket_set_guard = SOCKET_SET.lock_irqsave();
// debug!("tcp socket:poll, socket'len={}",self.handle.len()); // debug!("tcp socket:poll, socket'len={}",self.handle.len());
let socket = socket_set_guard let socket = socket_set_guard
.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); .get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
return SocketPollMethod::tcp_poll( let handle_map_guard = HANDLE_MAP.read_irqsave();
socket, let handle_item = handle_map_guard.get(&self.socket_handle()).unwrap();
HANDLE_MAP let shutdown_type = handle_item.shutdown_type();
.read_irqsave() let is_posix_listen = handle_item.is_posix_listen;
.get(&self.socket_handle()) drop(handle_map_guard);
.unwrap()
.shutdown_type(), return SocketPollMethod::tcp_poll(socket, shutdown_type, is_posix_listen);
);
} }
fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> {
let mut sockets = SOCKET_SET.lock_irqsave(); let mut sockets = SOCKET_SET.lock_irqsave();
// debug!("tcp socket:connect, socket'len={}",self.handle.len()); // debug!("tcp socket:connect, socket'len={}", self.handles.len());
let socket = let socket =
sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap()); sockets.get_mut::<tcp::Socket>(self.handles.get(0).unwrap().smoltcp_handle().unwrap());
@ -739,11 +776,7 @@ impl Socket for TcpSocket {
} }
tcp::State::SynSent => { tcp::State::SynSent => {
drop(sockets); drop(sockets);
SocketHandleItem::sleep( self.posix_item.sleep(Self::CAN_CONNECT);
self.socket_handle(),
Self::CAN_CONNECT,
HANDLE_MAP.read_irqsave(),
);
} }
_ => { _ => {
return Err(SystemError::ECONNREFUSED); return Err(SystemError::ECONNREFUSED);
@ -772,6 +805,11 @@ impl Socket for TcpSocket {
return Ok(()); return Ok(());
} }
// debug!(
// "tcp socket:listen, socket'len={}, backlog = {backlog}",
// self.handles.len()
// );
let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?;
let mut sockets = SOCKET_SET.lock_irqsave(); let mut sockets = SOCKET_SET.lock_irqsave();
// 获取handle的数量 // 获取handle的数量
@ -781,16 +819,19 @@ impl Socket for TcpSocket {
// 添加剩余需要构建的socket // 添加剩余需要构建的socket
// debug!("tcp socket:before listen, socket'len={}", self.handle_list.len()); // debug!("tcp socket:before listen, socket'len={}", self.handle_list.len());
let mut handle_guard = HANDLE_MAP.write_irqsave(); let mut handle_guard = HANDLE_MAP.write_irqsave();
let wait_queue = Arc::clone(&handle_guard.get(&self.socket_handle()).unwrap().wait_queue); let socket_handle_item_0 = handle_guard.get_mut(&self.socket_handle()).unwrap();
socket_handle_item_0.is_posix_listen = true;
self.handles.extend((handlen..backlog).map(|_| { self.handles.extend((handlen..backlog).map(|_| {
let socket = Self::create_new_socket(); let socket = Self::create_new_socket();
let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket));
let handle_item = SocketHandleItem::new(Some(wait_queue.clone())); let mut handle_item = SocketHandleItem::new(Arc::downgrade(&self.posix_item));
handle_item.is_posix_listen = true;
handle_guard.insert(handle, handle_item); handle_guard.insert(handle, handle_item);
handle handle
})); }));
// debug!("tcp socket:listen, socket'len={}",self.handle.len());
// debug!("tcp socket:listen, socket'len={}", self.handles.len());
// debug!("tcp socket:listen, backlog={backlog}"); // debug!("tcp socket:listen, backlog={backlog}");
// 监听所有的socket // 监听所有的socket
@ -805,6 +846,7 @@ impl Socket for TcpSocket {
} }
// debug!("Tcp Socket before listen, open={}", socket.is_open()); // debug!("Tcp Socket before listen, open={}", socket.is_open());
} }
return Ok(()); return Ok(());
} }
@ -820,6 +862,7 @@ impl Socket for TcpSocket {
self.local_endpoint = Some(ip); self.local_endpoint = Some(ip);
self.is_listening = false; self.is_listening = false;
return Ok(()); return Ok(());
} }
return Err(SystemError::EINVAL); return Err(SystemError::EINVAL);
@ -862,8 +905,7 @@ impl Socket for TcpSocket {
.remote_endpoint() .remote_endpoint()
.ok_or(SystemError::ENOTCONN)?; .ok_or(SystemError::ENOTCONN)?;
let mut tcp_socket = Self::create_new_socket(); let tcp_socket = Self::create_new_socket();
self.do_listen(&mut tcp_socket, endpoint)?;
let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket)); let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket));
@ -883,31 +925,40 @@ impl Socket for TcpSocket {
local_endpoint: self.local_endpoint, local_endpoint: self.local_endpoint,
is_listening: false, is_listening: false,
metadata, metadata,
posix_item: Arc::new(PosixSocketHandleItem::new(None)),
}); });
{ {
let mut handle_guard = HANDLE_MAP.write_irqsave(); let mut handle_guard = HANDLE_MAP.write_irqsave();
// 先删除原来的 // 先删除原来的
let item = handle_guard.remove(&old_handle).unwrap(); let item = handle_guard.remove(&old_handle).unwrap();
item.reset_shutdown_type();
assert!(item.is_posix_listen);
// 按照smoltcp行为将新的handle绑定到原来的item // 按照smoltcp行为将新的handle绑定到原来的item
let new_item = SocketHandleItem::new(None); let new_item = SocketHandleItem::new(Arc::downgrade(&sock_ret.posix_item));
handle_guard.insert(old_handle, new_item); handle_guard.insert(old_handle, new_item);
// 插入新的item // 插入新的item
handle_guard.insert(new_handle, item); handle_guard.insert(new_handle, item);
let socket = sockset.get_mut::<tcp::Socket>(
self.handles[handle_index].smoltcp_handle().unwrap(),
);
if !socket.is_listening() {
self.do_listen(socket, endpoint)?;
}
drop(handle_guard); drop(handle_guard);
} }
return Ok((sock_ret, Endpoint::Ip(Some(remote_ep)))); return Ok((sock_ret, Endpoint::Ip(Some(remote_ep))));
} }
drop(sockset); drop(sockset);
// debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap()); // debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.get(0).unwrap().smoltcp_handle().unwrap());
SocketHandleItem::sleep( self.posix_item.sleep(Self::CAN_ACCPET);
self.socket_handle(), // NOTICE
Self::CAN_ACCPET,
HANDLE_MAP.read_irqsave(),
);
// debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); // debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len());
} }
} }

View File

@ -22,7 +22,7 @@ use crate::{
Metadata, Metadata,
}, },
libs::{ libs::{
rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}, rwlock::{RwLock, RwLockWriteGuard},
spinlock::{SpinLock, SpinLockGuard}, spinlock::{SpinLock, SpinLockGuard},
wait_queue::EventWaitQueue, wait_queue::EventWaitQueue,
}, },
@ -87,7 +87,7 @@ pub(super) fn new_socket(
} }
}; };
let handle_item = SocketHandleItem::new(None); let handle_item = SocketHandleItem::new(Arc::downgrade(&socket.posix_item()));
HANDLE_MAP HANDLE_MAP
.write_irqsave() .write_irqsave()
.insert(socket.socket_handle(), handle_item); .insert(socket.socket_handle(), handle_item);
@ -243,36 +243,26 @@ pub trait Socket: Sync + Send + Debug + Any {
fn as_any_mut(&mut self) -> &mut dyn Any; fn as_any_mut(&mut self) -> &mut dyn Any;
fn add_epoll(&mut self, epitem: Arc<EPollItem>) -> Result<(), SystemError> { fn add_epoll(&mut self, epitem: Arc<EPollItem>) -> Result<(), SystemError> {
HANDLE_MAP let posix_item = self.posix_item();
.write_irqsave() posix_item.add_epoll(epitem);
.get_mut(&self.socket_handle())
.unwrap()
.add_epoll(epitem);
Ok(()) Ok(())
} }
fn remove_epoll(&mut self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> { fn remove_epoll(&mut self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
HANDLE_MAP let posix_item = self.posix_item();
.write_irqsave() posix_item.remove_epoll(epoll)?;
.get_mut(&self.socket_handle())
.unwrap()
.remove_epoll(epoll)?;
Ok(()) Ok(())
} }
fn clear_epoll(&mut self) -> Result<(), SystemError> { fn clear_epoll(&mut self) -> Result<(), SystemError> {
let mut handle_map_guard = HANDLE_MAP.write_irqsave(); let posix_item = self.posix_item();
let handle_item = handle_map_guard.get_mut(&self.socket_handle()).unwrap();
for epitem in handle_item.epitems.lock_irqsave().iter() { for epitem in posix_item.epitems.lock_irqsave().iter() {
let epoll = epitem.epoll(); let epoll = epitem.epoll();
if epoll.upgrade().is_some() {
EventPoll::ep_remove( if let Some(epoll) = epoll.upgrade() {
&mut epoll.upgrade().unwrap().lock_irqsave(), EventPoll::ep_remove(&mut epoll.lock_irqsave(), epitem.fd(), None)?;
epitem.fd(),
None,
)?;
} }
} }
@ -280,6 +270,8 @@ pub trait Socket: Sync + Send + Debug + Any {
} }
fn close(&mut self); fn close(&mut self);
fn posix_item(&self) -> Arc<PosixSocketHandleItem>;
} }
impl Clone for Box<dyn Socket> { impl Clone for Box<dyn Socket> {
@ -410,54 +402,35 @@ impl IndexNode for SocketInode {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct SocketHandleItem { pub struct PosixSocketHandleItem {
/// shutdown状态
pub shutdown_type: RwLock<ShutdownType>,
/// socket的waitqueue /// socket的waitqueue
pub wait_queue: Arc<EventWaitQueue>, wait_queue: Arc<EventWaitQueue>,
/// epitems考虑写在这是否是最优解
pub epitems: SpinLock<LinkedList<Arc<EPollItem>>>, pub epitems: SpinLock<LinkedList<Arc<EPollItem>>>,
} }
impl SocketHandleItem { impl PosixSocketHandleItem {
pub fn new(wait_queue: Option<Arc<EventWaitQueue>>) -> Self { pub fn new(wait_queue: Option<Arc<EventWaitQueue>>) -> Self {
Self { Self {
shutdown_type: RwLock::new(ShutdownType::empty()),
wait_queue: wait_queue.unwrap_or(Arc::new(EventWaitQueue::new())), wait_queue: wait_queue.unwrap_or(Arc::new(EventWaitQueue::new())),
epitems: SpinLock::new(LinkedList::new()), epitems: SpinLock::new(LinkedList::new()),
} }
} }
/// ## 在socket的等待队列上睡眠 /// ## 在socket的等待队列上睡眠
pub fn sleep( pub fn sleep(&self, events: u64) {
socket_handle: GlobalSocketHandle,
events: u64,
handle_map_guard: RwLockReadGuard<'_, HashMap<GlobalSocketHandle, SocketHandleItem>>,
) {
unsafe { unsafe {
handle_map_guard ProcessManager::preempt_disable();
.get(&socket_handle) self.wait_queue.sleep_without_schedule(events);
.unwrap() ProcessManager::preempt_enable();
.wait_queue }
.sleep_without_schedule(events)
};
drop(handle_map_guard);
schedule(SchedMode::SM_NONE); schedule(SchedMode::SM_NONE);
} }
pub fn shutdown_type(&self) -> ShutdownType { pub fn add_epoll(&self, epitem: Arc<EPollItem>) {
*self.shutdown_type.read()
}
pub fn shutdown_type_writer(&mut self) -> RwLockWriteGuard<ShutdownType> {
self.shutdown_type.write_irqsave()
}
pub fn add_epoll(&mut self, epitem: Arc<EPollItem>) {
self.epitems.lock_irqsave().push_back(epitem) self.epitems.lock_irqsave().push_back(epitem)
} }
pub fn remove_epoll(&mut self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> { pub fn remove_epoll(&self, epoll: &Weak<SpinLock<EventPoll>>) -> Result<(), SystemError> {
let is_remove = !self let is_remove = !self
.epitems .epitems
.lock_irqsave() .lock_irqsave()
@ -471,6 +444,50 @@ impl SocketHandleItem {
Err(SystemError::ENOENT) Err(SystemError::ENOENT)
} }
/// ### 唤醒该队列上等待events的进程
///
/// ### 参数
/// - events: 发生的事件
///
/// 需要注意的是只要触发了events中的任意一件事件进程都会被唤醒
pub fn wakeup_any(&self, events: u64) {
self.wait_queue.wakeup_any(events);
}
}
#[derive(Debug)]
pub struct SocketHandleItem {
/// 对应的posix socket是否为listen的
pub is_posix_listen: bool,
/// shutdown状态
pub shutdown_type: RwLock<ShutdownType>,
pub posix_item: Weak<PosixSocketHandleItem>,
}
impl SocketHandleItem {
pub fn new(posix_item: Weak<PosixSocketHandleItem>) -> Self {
Self {
is_posix_listen: false,
shutdown_type: RwLock::new(ShutdownType::empty()),
posix_item,
}
}
pub fn shutdown_type(&self) -> ShutdownType {
*self.shutdown_type.read()
}
pub fn shutdown_type_writer(&mut self) -> RwLockWriteGuard<ShutdownType> {
self.shutdown_type.write_irqsave()
}
pub fn reset_shutdown_type(&self) {
*self.shutdown_type.write() = ShutdownType::empty();
}
pub fn posix_item(&self) -> Option<Arc<PosixSocketHandleItem>> {
self.posix_item.upgrade()
}
} }
/// # TCP 和 UDP 的端口管理器。 /// # TCP 和 UDP 的端口管理器。
@ -763,33 +780,47 @@ impl TryFrom<u8> for PosixSocketType {
pub struct SocketPollMethod; pub struct SocketPollMethod;
impl SocketPollMethod { impl SocketPollMethod {
pub fn poll(socket: &socket::Socket, shutdown: ShutdownType) -> EPollEventType { pub fn poll(socket: &socket::Socket, handle_item: &SocketHandleItem) -> EPollEventType {
let shutdown = handle_item.shutdown_type();
match socket { match socket {
socket::Socket::Udp(udp) => Self::udp_poll(udp, shutdown), socket::Socket::Udp(udp) => Self::udp_poll(udp, shutdown),
socket::Socket::Tcp(tcp) => Self::tcp_poll(tcp, shutdown), socket::Socket::Tcp(tcp) => Self::tcp_poll(tcp, shutdown, handle_item.is_posix_listen),
socket::Socket::Raw(raw) => Self::raw_poll(raw, shutdown), socket::Socket::Raw(raw) => Self::raw_poll(raw, shutdown),
_ => todo!(), _ => todo!(),
} }
} }
pub fn tcp_poll(socket: &tcp::Socket, shutdown: ShutdownType) -> EPollEventType { pub fn tcp_poll(
socket: &tcp::Socket,
shutdown: ShutdownType,
is_posix_listen: bool,
) -> EPollEventType {
let mut events = EPollEventType::empty(); let mut events = EPollEventType::empty();
if socket.is_listening() && socket.is_active() { // debug!("enter tcp_poll! is_posix_listen:{}", is_posix_listen);
events.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM); // 处理listen的socket
if is_posix_listen {
// 如果是listen的socket那么只有EPOLLIN和EPOLLRDNORM
if socket.is_active() {
events.insert(EPollEventType::EPOLL_LISTEN_CAN_ACCEPT);
}
// debug!("tcp_poll listen socket! events:{:?}", events);
return events; return events;
} }
// socket已经关闭 let state = socket.state();
if !socket.is_open() {
events.insert(EPollEventType::EPOLLHUP) if shutdown == ShutdownType::SHUTDOWN_MASK || state == tcp::State::Closed {
events.insert(EPollEventType::EPOLLHUP);
} }
if shutdown.contains(ShutdownType::RCV_SHUTDOWN) { if shutdown.contains(ShutdownType::RCV_SHUTDOWN) {
events.insert( events.insert(
EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM | EPollEventType::EPOLLRDHUP, EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM | EPollEventType::EPOLLRDHUP,
); );
} }
let state = socket.state(); // Connected or passive Fast Open socket?
if state != tcp::State::SynSent && state != tcp::State::SynReceived { if state != tcp::State::SynSent && state != tcp::State::SynReceived {
// socket有可读数据 // socket有可读数据
if socket.can_recv() { if socket.can_recv() {
@ -797,12 +828,12 @@ impl SocketPollMethod {
} }
if !(shutdown.contains(ShutdownType::SEND_SHUTDOWN)) { if !(shutdown.contains(ShutdownType::SEND_SHUTDOWN)) {
// 缓冲区可写 // 缓冲区可写这里判断可写的逻辑好像跟linux不太一样
if socket.send_queue() < socket.send_capacity() { if socket.send_queue() < socket.send_capacity() {
events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM); events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM);
} else { } else {
// TODO触发缓冲区已满的信号 // TODO触发缓冲区已满的信号SIGIO
todo!("A signal that the buffer is full needs to be sent"); todo!("A signal SIGIO that the buffer is full needs to be sent");
} }
} else { } else {
// 如果我们的socket关闭了SEND_SHUTDOWNepoll事件就是EPOLLOUT // 如果我们的socket关闭了SEND_SHUTDOWNepoll事件就是EPOLLOUT
@ -813,6 +844,7 @@ impl SocketPollMethod {
} }
// socket发生错误 // socket发生错误
// TODO: 这里的逻辑可能有问题需要进一步验证是否is_active()==false就代表socket发生错误
if !socket.is_active() { if !socket.is_active() {
events.insert(EPollEventType::EPOLLERR); events.insert(EPollEventType::EPOLLERR);
} }

View File

@ -4,7 +4,8 @@ use system_error::SystemError;
use crate::{libs::spinlock::SpinLock, net::Endpoint}; use crate::{libs::spinlock::SpinLock, net::Endpoint};
use super::{ use super::{
handle::GlobalSocketHandle, Socket, SocketInode, SocketMetadata, SocketOptions, SocketType, handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketInode, SocketMetadata,
SocketOptions, SocketType,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -13,6 +14,7 @@ pub struct StreamSocket {
buffer: Arc<SpinLock<Vec<u8>>>, buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>, peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle, handle: GlobalSocketHandle,
posix_item: Arc<PosixSocketHandleItem>,
} }
impl StreamSocket { impl StreamSocket {
@ -36,16 +38,22 @@ impl StreamSocket {
options, options,
); );
let posix_item = Arc::new(PosixSocketHandleItem::new(None));
Self { Self {
metadata, metadata,
buffer, buffer,
peer_inode: None, peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(), handle: GlobalSocketHandle::new_kernel_handle(),
posix_item,
} }
} }
} }
impl Socket for StreamSocket { impl Socket for StreamSocket {
fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
self.posix_item.clone()
}
fn socket_handle(&self) -> GlobalSocketHandle { fn socket_handle(&self) -> GlobalSocketHandle {
self.handle self.handle
} }
@ -121,6 +129,7 @@ pub struct SeqpacketSocket {
buffer: Arc<SpinLock<Vec<u8>>>, buffer: Arc<SpinLock<Vec<u8>>>,
peer_inode: Option<Arc<SocketInode>>, peer_inode: Option<Arc<SocketInode>>,
handle: GlobalSocketHandle, handle: GlobalSocketHandle,
posix_item: Arc<PosixSocketHandleItem>,
} }
impl SeqpacketSocket { impl SeqpacketSocket {
@ -144,16 +153,22 @@ impl SeqpacketSocket {
options, options,
); );
let posix_item = Arc::new(PosixSocketHandleItem::new(None));
Self { Self {
metadata, metadata,
buffer, buffer,
peer_inode: None, peer_inode: None,
handle: GlobalSocketHandle::new_kernel_handle(), handle: GlobalSocketHandle::new_kernel_handle(),
posix_item,
} }
} }
} }
impl Socket for SeqpacketSocket { impl Socket for SeqpacketSocket {
fn posix_item(&self) -> Arc<PosixSocketHandleItem> {
self.posix_item.clone()
}
fn close(&mut self) {} fn close(&mut self) {}
fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) { fn read(&self, buf: &mut [u8]) -> (Result<usize, SystemError>, Endpoint) {