From 9b37918841128519bf56d053fd72bed39aa1e578 Mon Sep 17 00:00:00 2001 From: Samuka007 Date: Wed, 28 May 2025 12:31:13 +0800 Subject: [PATCH] fix status in connect, add connecting features --- Cargo.toml | 1 + src/main.rs | 83 ++++++++++++++++++++++- src/socket/endpoint.rs | 1 + src/socket/inet/common/mod.rs | 1 + src/socket/inet/stream/inner.rs | 112 ++++++++++++++++++++++---------- src/socket/inet/stream/mod.rs | 48 ++++++++++---- 6 files changed, 196 insertions(+), 50 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 63df627..f52db6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ smoltcp = { version = "0.12.0", default-features = false, features = [ "medium-ethernet", "medium-ip", "proto-ipv4", + "proto-ipv6", "socket-udp", "socket-tcp", ]} diff --git a/src/main.rs b/src/main.rs index 1fef4f7..ab025ce 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{net::Ipv4Addr, sync::Arc}; +use std::{io::{self, Read}, net::Ipv4Addr, sync::Arc}; use berkeley_socket::{ driver::{irq::start_network_polling_thread, tap::TapDevice}, @@ -80,6 +80,69 @@ fn make_tcp_echo() { } } +fn make_request() { + log::info!("Input a valid IP address and port to connect to:"); + let mut input = String::new(); + io::stdin().read_line(&mut input).unwrap(); + let parts: Vec<&str> = input.trim().split(':').collect(); + if parts.len() != 2 { + log::error!("Invalid input format. Use :."); + return; + } + let ip: Ipv4Addr = match parts[0].parse() { + Ok(ip) => ip, + Err(_) => { + log::error!("Invalid IP address."); + return; + } + }; + let port: u16 = match parts[1].parse() { + Ok(port) => port, + Err(_) => { + log::error!("Invalid port number."); + return; + } + }; + let endpoint = Endpoint::Ip(IpEndpoint::new(IpAddress::Ipv4(Ipv4Addr::from(ip)), port)); + + let socket = Inet::socket(SOCK::Stream, 0).unwrap(); + match socket.connect(endpoint) { + Ok(_) => { + log::info!("Connected to {}:{}", ip, port); + let mut buffer = [0u8; 1024]; + loop { + let len = io::stdin().read(&mut buffer).unwrap(); + if len == 0 { + break; // EOF + } + let sent_len = socket.write(&buffer[..len]).unwrap(); + log::info!("Sent {} bytes", sent_len); + match socket.read(&mut buffer) { + Ok(received_len) => { + if received_len == 0 { + log::info!("Socket closed by remote peer."); + break; + } + log::info!( + "Received {} bytes: {}", + received_len, + String::from_utf8_lossy(&buffer[..received_len]) + ); + } + Err(e) => { + log::error!("Socket read error: {}", e); + break; + } + } + } + } + Err(e) => { + log::error!("Failed to connect: {}", e); + } + } + log::info!("Connection closed."); +} + fn main() { env_logger::init(); let device = TapDevice::new("tap0", smoltcp::phy::Medium::Ethernet).unwrap(); @@ -105,6 +168,24 @@ fn main() { let tcp = std::thread::spawn(move || { make_tcp_echo(); }); + + loop { + let char = io::stdin().bytes().next().unwrap().unwrap(); + match char { + b'q' | b'Q' => { + log::info!("Exiting..."); + break; + } + b'r' => { + make_request(); + } + _ => { + log::info!("Press 'q' to exit."); + } + } + } + + // Optionally join threads before exiting udp.join().unwrap(); tcp.join().unwrap(); } diff --git a/src/socket/endpoint.rs b/src/socket/endpoint.rs index fba1b17..832169c 100644 --- a/src/socket/endpoint.rs +++ b/src/socket/endpoint.rs @@ -11,6 +11,7 @@ pub enum Endpoint { // LinkLayer(LinkLayerEndpoint), /// 网络层端点 Ip(IpEndpoint), + Other, // /// inode端点,Unix实际保存的端点 // Inode((Arc, String)), // /// Unix传递id索引和path所用的端点 diff --git a/src/socket/inet/common/mod.rs b/src/socket/inet/common/mod.rs index fa4ff6b..4b5abac 100644 --- a/src/socket/inet/common/mod.rs +++ b/src/socket/inet/common/mod.rs @@ -78,6 +78,7 @@ impl BoundInner { T: smoltcp::socket::AnySocket<'static>, { let (iface, address) = get_ephemeral_iface(&remote); + log::debug!("bind_ephemeral address: {}", address); // let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?; let handle = iface.sockets().lock().add(socket); // let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port); diff --git a/src/socket/inet/stream/inner.rs b/src/socket/inet/stream/inner.rs index 9e5e44b..2360132 100644 --- a/src/socket/inet/stream/inner.rs +++ b/src/socket/inet/stream/inner.rs @@ -45,16 +45,6 @@ impl Init { Init::Unbound((Box::new(new_smoltcp_socket()), ver)) } - /// 传入一个已经绑定的socket - pub(super) fn new_bound(inner: socket::inet::BoundInner) -> Self { - let endpoint = inner.with::(|socket| { - socket - .local_endpoint() - .expect("A Bound Socket Must Have A Local Endpoint") - }); - Init::Bound((inner, endpoint)) - } - pub(super) fn bind( self, local_endpoint: smoltcp::wire::IpEndpoint, @@ -116,7 +106,7 @@ impl Init { .map_err(|_| SystemError::ECONNREFUSED) }); match result { - Ok(_) => Ok(Connecting::new(inner)), + Ok(_) => Ok(Connecting::new(inner, local.addr.version())), Err(err) => Err((Init::Bound((inner, local)), err)), } } @@ -193,13 +183,18 @@ enum ConnectResult { #[derive(Debug)] pub struct Connecting { inner: socket::inet::BoundInner, + version: smoltcp::wire::IpVersion, result: RwLock, } impl Connecting { - fn new(inner: socket::inet::BoundInner) -> Self { + fn new( + inner: socket::inet::BoundInner, + version: smoltcp::wire::IpVersion, + ) -> Self { Connecting { inner, + version, result: RwLock::new(ConnectResult::Connecting), } } @@ -212,6 +207,7 @@ impl Connecting { } pub fn into_result(self) -> (Inner, Result<(), SystemError>) { + // log::debug!("Into_result {:?}", self.inner); let result = *self.result.read(); match result { ConnectResult::Connecting => (Inner::Connecting(self), Err(SystemError::EAGAIN)), @@ -220,7 +216,7 @@ impl Connecting { Ok(()), ), ConnectResult::Refused => ( - Inner::Init(Init::new_bound(self.inner)), + Inner::Init(Init::new(self.version)), Err(SystemError::ECONNREFUSED), ), } @@ -237,26 +233,25 @@ impl Connecting { /// _exactly_ once. The caller is responsible for not missing this event. #[must_use] pub(super) fn update_io_events(&self) -> bool { - // if matches!(*self.result.read_irqsave(), ConnectResult::Connecting) { - // return false; - // } - self.inner .with_mut(|socket: &mut smoltcp::socket::tcp::Socket| { let mut result = self.result.write(); if matches!(*result, ConnectResult::Refused | ConnectResult::Connected) { - return false; // Already connected or refused + log::warn!( + "update_io_events called on a Connecting socket that is already {:?}", + *result + ); + return true; // Already connected or refused, so shouldn't in this state, trigger update! } // Connected - if socket.can_send() { + if socket.may_send() { log::debug!("can send"); *result = ConnectResult::Connected; return true; } // Connecting if socket.is_open() { - log::debug!("connecting"); *result = ConnectResult::Connecting; return false; } @@ -416,17 +411,39 @@ impl Established { pub fn recv_slice(&self, buf: &mut [u8]) -> Result { self.inner .with_mut::(|socket| { - if socket.can_send() { - match socket.recv_slice(buf) { - Ok(size) => Ok(size), - Err(tcp::RecvError::InvalidState) => { - log::error!("TcpSocket::try_recv: InvalidState"); - Err(SystemError::ENOTCONN) + match socket.recv_slice(buf) { + Ok(size) => Ok(size), + Err(tcp::RecvError::InvalidState) => { + socket.may_recv(); + use smoltcp::socket::tcp::State; + match socket.state() { + // Not ENOTCONN since the socket is in established state + State::Closed => Err(SystemError::ECONNRESET), + + // remote sent FIN + State::Closing + | State::LastAck + | State::TimeWait + | State::CloseWait => { + log::debug!("TCP state: {:?}, recv will return 0", socket.state()); + Ok(0) // return 0 to indicate EOF + } + + // Socket should not be in these state + State::Listen | State::SynReceived | State::SynSent => { + log::error!("Unexpected TCP state: {:?}", socket.state()); + Err(SystemError::ECONNRESET) // return reset to drop this error socket, not stadard behavior + }, + + // already checked in `can_recv()` + State::Established + | State::FinWait1 + | State::FinWait2 => { + unreachable!("Should be able to recv: {:?}", socket.state()) + } } - Err(tcp::RecvError::Finished) => Ok(0), } - } else { - Err(SystemError::ENOBUFS) + Err(tcp::RecvError::Finished) => Ok(0), } }) } @@ -434,12 +451,37 @@ impl Established { pub fn send_slice(&self, buf: &[u8]) -> Result { self.inner .with_mut::(|socket| { - if socket.can_send() { - socket - .send_slice(buf) - .map_err(|_| SystemError::ECONNABORTED) - } else { - Err(SystemError::ENOBUFS) + match socket.send_slice(buf) { + Ok(0) => Err(SystemError::EAGAIN), + Ok(size) => Ok(size), + Err(tcp::SendError::InvalidState) => { + use smoltcp::socket::tcp::State; + match socket.state() { + // Not ENOTCONN since the socket is in established state + State::Closed => Err(SystemError::ECONNRESET), + + // Socket is already closed by us + State::LastAck + | State::TimeWait + | State::Closing + | State::FinWait1 + | State::FinWait2 => Err(SystemError::EPIPE), + + // Socket should not be in these state + State::Listen | State::SynReceived | State::SynSent => { + log::error!("Unexpected TCP state: {:?}", socket.state()); + Err(SystemError::ECONNRESET) // return reset to drop this error socket, not stadard behavior + }, + + // these states are already checked in `can_send()` + State::Established + // In CLOSE-WAIT, the remote endpoint has closed our receive half of the connection + // but we still can transmit indefinitely. + | State::CloseWait => { + unreachable!("Should be able to send: {:?}", socket.state()) + } + } + } } }) } diff --git a/src/socket/inet/stream/mod.rs b/src/socket/inet/stream/mod.rs index 26c1cbd..8852c8e 100644 --- a/src/socket/inet/stream/mod.rs +++ b/src/socket/inet/stream/mod.rs @@ -250,9 +250,13 @@ impl TcpSocket { } } - fn incoming(&self) -> bool { + fn is_epoll_in(&self) -> bool { EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN) } + + fn is_epoll_out(&self) -> bool { + EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLOUT) + } } impl Socket for TcpSocket { @@ -269,7 +273,7 @@ impl Socket for TcpSocket { { inner::Inner::Init(inner::Init::Unbound((_, ver))) => Ok(Endpoint::Ip(match ver { smoltcp::wire::IpVersion::Ipv4 => UNSPECIFIED_LOCAL_ENDPOINT_V4, - // smoltcp::wire::IpVersion::Ipv6 => UNSPECIFIED_LOCAL_ENDPOINT_V6, + smoltcp::wire::IpVersion::Ipv6 => todo!("UNSPECIFIED_LOCAL_ENDPOINT_V6"), })), inner::Inner::Init(inner::Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)), inner::Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())), @@ -324,27 +328,43 @@ impl Socket for TcpSocket { } fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { - if self.is_nonblock() { - self.try_accept() - } else { - loop { - match self.try_accept() { - Err(SystemError::EAGAIN) => { - wq_wait_event_interruptible(&self.wait_queue, || self.incoming(), None)?; - } - result => break result, + loop { + match self.try_accept() { + Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN), + Err(SystemError::EAGAIN) => { + wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_in(), None)?; + } + result => { + break result.map(|(inner, endpoint)| { + (inner as Arc, Endpoint::Ip(endpoint)) + }) } } } - .map(|(inner, endpoint)| (inner as Arc, Endpoint::Ip(endpoint))) } fn recv(&self, buffer: &mut [u8], _flags: PMSG) -> Result { - self.try_recv(buffer) + loop { + match self.try_recv(buffer) { + Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN), + Err(SystemError::EAGAIN) => { + wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_in(), None)?; + } + result => break result, + } + } } fn send(&self, buffer: &[u8], _flags: PMSG) -> Result { - self.try_send(buffer) + loop { + match self.try_send(buffer) { + Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN), + Err(SystemError::EAGAIN) => { + wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_out(), None)?; + } + result => break result, + } + } } fn send_buffer_size(&self) -> usize {