2024-10-17 06:45:35 +00:00

477 lines
16 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

pub mod inner;
use alloc::{
string::String,
sync::{Arc, Weak},
};
use core::sync::atomic::{AtomicBool, Ordering};
use crate::sched::SchedMode;
use crate::{libs::rwlock::RwLock, net::socket::*};
use inner::*;
use system_error::SystemError;
use super::INODE_MAP;
/// Shorthand for the epoll event bit set that `poll` builds and returns as bits.
type EP = EPollEventType;
/// A sequenced-packet socket addressed through `Endpoint::Inode` / `Endpoint::Unixpath`
/// endpoints (Unix-domain style).
#[derive(Debug)]
pub struct SeqpacketSocket {
    // Connection state machine: `Inner::Init`, `Inner::Listen` or `Inner::Connected`.
    inner: RwLock<Inner>,
    // Local shutdown flags; `close` marks both the receive and send directions shut.
    shutdown: Shutdown,
    // Socket-level non-blocking flag (accessed with `Ordering::Relaxed`).
    is_nonblocking: AtomicBool,
    // Queue that blocking `accept`/`recv`/`recv_from` wait on; `connect` wakes the
    // listener's queue after queueing a new connection.
    wait_queue: WaitQueue,
    // Weak handle to the owning `Arc`, filled in by `Arc::new_cyclic` in the constructors.
    self_ref: Weak<Self>,
}
impl SeqpacketSocket {
/// 默认的元数据缓冲区大小
pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024;
/// 默认的缓冲区大小
pub const DEFAULT_BUF_SIZE: usize = 64 * 1024;
pub fn new(is_nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Inner::Init(Init::new())),
shutdown: Shutdown::new(),
is_nonblocking: AtomicBool::new(is_nonblocking),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
})
}
pub fn new_inode(is_nonblocking: bool) -> Result<Arc<Inode>, SystemError> {
let socket = SeqpacketSocket::new(is_nonblocking);
let inode = Inode::new(socket.clone());
// 建立时绑定自身为后续能正常获取本端地址
let _ = match &mut *socket.inner.write() {
Inner::Init(init) => init.bind(Endpoint::Inode((inode.clone(), String::from("")))),
_ => return Err(SystemError::EINVAL),
};
return Ok(inode);
}
pub fn new_connected(connected: Connected, is_nonblocking: bool) -> Arc<Self> {
Arc::new_cyclic(|me| Self {
inner: RwLock::new(Inner::Connected(connected)),
shutdown: Shutdown::new(),
is_nonblocking: AtomicBool::new(is_nonblocking),
wait_queue: WaitQueue::default(),
self_ref: me.clone(),
})
}
pub fn new_pairs() -> Result<(Arc<Inode>, Arc<Inode>), SystemError> {
let socket0 = SeqpacketSocket::new(false);
let socket1 = SeqpacketSocket::new(false);
let inode0 = Inode::new(socket0.clone());
let inode1 = Inode::new(socket1.clone());
let (conn_0, conn_1) = Connected::new_pair(
Some(Endpoint::Inode((inode0.clone(), String::from("")))),
Some(Endpoint::Inode((inode1.clone(), String::from("")))),
);
*socket0.inner.write() = Inner::Connected(conn_0);
*socket1.inner.write() = Inner::Connected(conn_1);
return Ok((inode0, inode1));
}
fn try_accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
match &*self.inner.read() {
Inner::Listen(listen) => listen.try_accept() as _,
_ => {
log::error!("the socket is not listening");
return Err(SystemError::EINVAL);
}
}
}
fn is_acceptable(&self) -> bool {
match &*self.inner.read() {
Inner::Listen(listen) => listen.is_acceptable(),
_ => {
panic!("the socket is not listening");
}
}
}
fn is_peer_shutdown(&self) -> Result<bool, SystemError> {
let peer_shutdown = match self.get_peer_name()? {
Endpoint::Inode((inode, _)) => Arc::downcast::<SeqpacketSocket>(inode.inner())
.map_err(|_| SystemError::EINVAL)?
.shutdown
.get()
.is_both_shutdown(),
_ => return Err(SystemError::EINVAL),
};
Ok(peer_shutdown)
}
fn can_recv(&self) -> Result<bool, SystemError> {
let can = match &*self.inner.read() {
Inner::Connected(connected) => connected.can_recv(),
_ => return Err(SystemError::ENOTCONN),
};
Ok(can)
}
fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}
#[allow(dead_code)]
fn set_nonblocking(&self, nonblocking: bool) {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
}
impl Socket for SeqpacketSocket {
fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> {
let peer_inode = match endpoint {
Endpoint::Inode((inode, _)) => inode,
Endpoint::Unixpath((inode_id, _)) => {
let inode_guard = INODE_MAP.read_irqsave();
let inode = inode_guard.get(&inode_id).unwrap();
match inode {
Endpoint::Inode((inode, _)) => inode.clone(),
_ => return Err(SystemError::EINVAL),
}
}
_ => return Err(SystemError::EINVAL),
};
// 远端为服务端
let remote_socket = Arc::downcast::<SeqpacketSocket>(peer_inode.inner())
.map_err(|_| SystemError::EINVAL)?;
let client_epoint = match &mut *self.inner.write() {
Inner::Init(init) => match init.endpoint().cloned() {
Some(end) => {
log::debug!("bind when connect");
Some(end)
}
None => {
log::debug!("not bind when connect");
let inode = Inode::new(self.self_ref.upgrade().unwrap().clone());
let epoint = Endpoint::Inode((inode.clone(), String::from("")));
let _ = init.bind(epoint.clone());
Some(epoint)
}
},
Inner::Listen(_) => return Err(SystemError::EINVAL),
Inner::Connected(_) => return Err(SystemError::EISCONN),
};
// ***阻塞与非阻塞处理还未实现
// 客户端与服务端建立连接将服务端inode推入到自身的listen_incom队列中
// accept时从中获取推出对应的socket
match &*remote_socket.inner.read() {
Inner::Listen(listener) => match listener.push_incoming(client_epoint) {
Ok(connected) => {
*self.inner.write() = Inner::Connected(connected);
log::debug!("try to wake up");
remote_socket.wait_queue.wakeup(None);
return Ok(());
}
// ***错误处理
Err(_) => todo!(),
},
Inner::Init(_) => {
log::debug!("init einval");
return Err(SystemError::EINVAL);
}
Inner::Connected(_) => return Err(SystemError::EISCONN),
};
}
fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {
// 将自身socket的inode与用户端提供路径的文件indoe_id进行绑定
match endpoint {
Endpoint::Unixpath((inodeid, path)) => {
let inode = match &mut *self.inner.write() {
Inner::Init(init) => init.bind_path(path)?,
_ => {
log::error!("socket has listen or connected");
return Err(SystemError::EINVAL);
}
};
INODE_MAP.write_irqsave().insert(inodeid, inode);
Ok(())
}
_ => return Err(SystemError::EINVAL),
}
}
fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> {
log::debug!("seqpacket shutdown");
match &*self.inner.write() {
Inner::Connected(connected) => connected.shutdown(how),
_ => Err(SystemError::EINVAL),
}
}
fn listen(&self, backlog: usize) -> Result<(), SystemError> {
let mut state = self.inner.write();
log::debug!("listen into socket");
let epoint = match &*state {
Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(),
Inner::Listen(listener) => return listener.listen(backlog),
Inner::Connected(_) => {
log::error!("the socket is connected");
return Err(SystemError::EINVAL);
}
};
let listener = Listener::new(epoint, backlog);
*state = Inner::Listen(listener);
Ok(())
}
fn accept(&self) -> Result<(Arc<Inode>, Endpoint), SystemError> {
if !self.is_nonblocking() {
loop {
wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?;
match self.try_accept() {
Ok((socket, epoint)) => return Ok((socket, epoint)),
Err(_) => continue,
}
}
} else {
// ***非阻塞状态
todo!()
}
}
fn set_option(
&self,
_level: crate::net::socket::PSOL,
_optname: usize,
_optval: &[u8],
) -> Result<(), SystemError> {
log::warn!("setsockopt is not implemented");
Ok(())
}
fn wait_queue(&self) -> &WaitQueue {
return &self.wait_queue;
}
fn close(&self) -> Result<(), SystemError> {
// log::debug!("seqpacket close");
self.shutdown.recv_shutdown();
self.shutdown.send_shutdown();
Ok(())
}
fn get_peer_name(&self) -> Result<Endpoint, SystemError> {
// 获取对端地址
let endpoint = match &*self.inner.read() {
Inner::Connected(connected) => connected.peer_endpoint().cloned(),
_ => return Err(SystemError::ENOTCONN),
};
if let Some(endpoint) = endpoint {
return Ok(endpoint);
} else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
}
fn get_name(&self) -> Result<Endpoint, SystemError> {
// 获取本端地址
let endpoint = match &*self.inner.read() {
Inner::Init(init) => init.endpoint().cloned(),
Inner::Listen(listener) => Some(listener.endpoint().clone()),
Inner::Connected(connected) => connected.endpoint().cloned(),
};
if let Some(endpoint) = endpoint {
return Ok(endpoint);
} else {
return Err(SystemError::EAGAIN_OR_EWOULDBLOCK);
}
}
fn get_option(
&self,
_level: crate::net::socket::PSOL,
_name: usize,
_value: &mut [u8],
) -> Result<usize, SystemError> {
log::warn!("getsockopt is not implemented");
Ok(0)
}
fn read(&self, buffer: &mut [u8]) -> Result<usize, SystemError> {
self.recv(buffer, crate::net::socket::PMSG::empty())
}
fn recv(
&self,
buffer: &mut [u8],
flags: crate::net::socket::PMSG,
) -> Result<usize, SystemError> {
if flags.contains(PMSG::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
}
if !flags.contains(PMSG::DONTWAIT) {
loop {
wq_wait_event_interruptible!(
self.wait_queue,
self.can_recv()? || self.is_peer_shutdown()?,
{}
)?;
// connect锁和flag判断顺序不正确应该先判断在
match &*self.inner.write() {
Inner::Connected(connected) => match connected.try_read(buffer) {
Ok(usize) => {
log::debug!("recv from successfully");
return Ok(usize);
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
}
fn recv_msg(
&self,
_msg: &mut crate::net::syscall::MsgHdr,
_flags: crate::net::socket::PMSG,
) -> Result<usize, SystemError> {
Err(SystemError::ENOSYS)
}
fn send(&self, buffer: &[u8], flags: crate::net::socket::PMSG) -> Result<usize, SystemError> {
if flags.contains(PMSG::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
}
if self.is_peer_shutdown()? {
return Err(SystemError::EPIPE);
}
if !flags.contains(PMSG::DONTWAIT) {
loop {
match &*self.inner.write() {
Inner::Connected(connected) => match connected.try_write(buffer) {
Ok(usize) => {
log::debug!("send successfully");
return Ok(usize);
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
}
fn send_msg(
&self,
_msg: &crate::net::syscall::MsgHdr,
_flags: crate::net::socket::PMSG,
) -> Result<usize, SystemError> {
Err(SystemError::ENOSYS)
}
fn write(&self, buffer: &[u8]) -> Result<usize, SystemError> {
self.send(buffer, crate::net::socket::PMSG::empty())
}
fn recv_from(
&self,
buffer: &mut [u8],
flags: PMSG,
_address: Option<Endpoint>,
) -> Result<(usize, Endpoint), SystemError> {
// log::debug!("recvfrom flags {:?}", flags);
if flags.contains(PMSG::OOB) {
return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP);
}
if !flags.contains(PMSG::DONTWAIT) {
loop {
wq_wait_event_interruptible!(
self.wait_queue,
self.can_recv()? || self.is_peer_shutdown()?,
{}
)?;
// connect锁和flag判断顺序不正确应该先判断在
match &*self.inner.write() {
Inner::Connected(connected) => match connected.recv_slice(buffer) {
Ok(usize) => {
// log::debug!("recvs from successfully");
return Ok((usize, connected.peer_endpoint().unwrap().clone()));
}
Err(_) => continue,
},
_ => {
log::error!("the socket is not connected");
return Err(SystemError::ENOTCONN);
}
}
}
} else {
unimplemented!("unimplemented non_block")
}
//Err(SystemError::ENOSYS)
}
fn send_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
SeqpacketSocket::DEFAULT_BUF_SIZE
}
fn recv_buffer_size(&self) -> usize {
log::warn!("using default buffer size");
SeqpacketSocket::DEFAULT_BUF_SIZE
}
fn poll(&self) -> usize {
let mut mask = EP::empty();
let shutdown = self.shutdown.get();
// 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152
// 用关闭读写端表示连接断开
if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() {
mask |= EP::EPOLLHUP;
}
if shutdown.is_recv_shutdown() {
mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM;
}
match &*self.inner.read() {
Inner::Connected(connected) => {
if connected.can_recv() {
mask |= EP::EPOLLIN | EP::EPOLLRDNORM;
}
// if (sk_is_readable(sk))
// mask |= EPOLLIN | EPOLLRDNORM;
// TODO:处理紧急情况 EPOLLPRI
// TODO:处理连接是否关闭 EPOLLHUP
if !shutdown.is_send_shutdown() {
if connected.can_send().unwrap() {
mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND;
} else {
todo!("poll: buffer space not enough");
}
}
}
Inner::Listen(_) => mask |= EP::EPOLLIN,
Inner::Init(_) => mask |= EP::EPOLLOUT,
}
mask.bits() as usize
}
}