fix status in connect, add connecting features

This commit is contained in:
Samuka007 2025-05-28 12:31:13 +08:00
parent a13e8a2da1
commit 9b37918841
6 changed files with 196 additions and 50 deletions

View File

@ -11,6 +11,7 @@ smoltcp = { version = "0.12.0", default-features = false, features = [
"medium-ethernet", "medium-ethernet",
"medium-ip", "medium-ip",
"proto-ipv4", "proto-ipv4",
"proto-ipv6",
"socket-udp", "socket-udp",
"socket-tcp", "socket-tcp",
]} ]}

View File

@ -1,4 +1,4 @@
use std::{net::Ipv4Addr, sync::Arc}; use std::{io::{self, Read}, net::Ipv4Addr, sync::Arc};
use berkeley_socket::{ use berkeley_socket::{
driver::{irq::start_network_polling_thread, tap::TapDevice}, driver::{irq::start_network_polling_thread, tap::TapDevice},
@ -80,6 +80,69 @@ fn make_tcp_echo() {
} }
} }
fn make_request() {
log::info!("Input a valid IP address and port to connect to:");
let mut input = String::new();
io::stdin().read_line(&mut input).unwrap();
let parts: Vec<&str> = input.trim().split(':').collect();
if parts.len() != 2 {
log::error!("Invalid input format. Use <IP>:<port>.");
return;
}
let ip: Ipv4Addr = match parts[0].parse() {
Ok(ip) => ip,
Err(_) => {
log::error!("Invalid IP address.");
return;
}
};
let port: u16 = match parts[1].parse() {
Ok(port) => port,
Err(_) => {
log::error!("Invalid port number.");
return;
}
};
let endpoint = Endpoint::Ip(IpEndpoint::new(IpAddress::Ipv4(Ipv4Addr::from(ip)), port));
let socket = Inet::socket(SOCK::Stream, 0).unwrap();
match socket.connect(endpoint) {
Ok(_) => {
log::info!("Connected to {}:{}", ip, port);
let mut buffer = [0u8; 1024];
loop {
let len = io::stdin().read(&mut buffer).unwrap();
if len == 0 {
break; // EOF
}
let sent_len = socket.write(&buffer[..len]).unwrap();
log::info!("Sent {} bytes", sent_len);
match socket.read(&mut buffer) {
Ok(received_len) => {
if received_len == 0 {
log::info!("Socket closed by remote peer.");
break;
}
log::info!(
"Received {} bytes: {}",
received_len,
String::from_utf8_lossy(&buffer[..received_len])
);
}
Err(e) => {
log::error!("Socket read error: {}", e);
break;
}
}
}
}
Err(e) => {
log::error!("Failed to connect: {}", e);
}
}
log::info!("Connection closed.");
}
fn main() { fn main() {
env_logger::init(); env_logger::init();
let device = TapDevice::new("tap0", smoltcp::phy::Medium::Ethernet).unwrap(); let device = TapDevice::new("tap0", smoltcp::phy::Medium::Ethernet).unwrap();
@ -105,6 +168,24 @@ fn main() {
let tcp = std::thread::spawn(move || { let tcp = std::thread::spawn(move || {
make_tcp_echo(); make_tcp_echo();
}); });
loop {
let char = io::stdin().bytes().next().unwrap().unwrap();
match char {
b'q' | b'Q' => {
log::info!("Exiting...");
break;
}
b'r' => {
make_request();
}
_ => {
log::info!("Press 'q' to exit.");
}
}
}
// Optionally join threads before exiting
udp.join().unwrap(); udp.join().unwrap();
tcp.join().unwrap(); tcp.join().unwrap();
} }

View File

@ -11,6 +11,7 @@ pub enum Endpoint {
// LinkLayer(LinkLayerEndpoint), // LinkLayer(LinkLayerEndpoint),
/// 网络层端点 /// 网络层端点
Ip(IpEndpoint), Ip(IpEndpoint),
Other,
// /// inode端点,Unix实际保存的端点 // /// inode端点,Unix实际保存的端点
// Inode((Arc<socket::SocketInode>, String)), // Inode((Arc<socket::SocketInode>, String)),
// /// Unix传递id索引和path所用的端点 // /// Unix传递id索引和path所用的端点

View File

@ -78,6 +78,7 @@ impl BoundInner {
T: smoltcp::socket::AnySocket<'static>, T: smoltcp::socket::AnySocket<'static>,
{ {
let (iface, address) = get_ephemeral_iface(&remote); let (iface, address) = get_ephemeral_iface(&remote);
log::debug!("bind_ephemeral address: {}", address);
// let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?; // let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?;
let handle = iface.sockets().lock().add(socket); let handle = iface.sockets().lock().add(socket);
// let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port); // let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port);

View File

@ -45,16 +45,6 @@ impl Init {
Init::Unbound((Box::new(new_smoltcp_socket()), ver)) Init::Unbound((Box::new(new_smoltcp_socket()), ver))
} }
/// 传入一个已经绑定的socket
pub(super) fn new_bound(inner: socket::inet::BoundInner) -> Self {
let endpoint = inner.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
socket
.local_endpoint()
.expect("A Bound Socket Must Have A Local Endpoint")
});
Init::Bound((inner, endpoint))
}
pub(super) fn bind( pub(super) fn bind(
self, self,
local_endpoint: smoltcp::wire::IpEndpoint, local_endpoint: smoltcp::wire::IpEndpoint,
@ -116,7 +106,7 @@ impl Init {
.map_err(|_| SystemError::ECONNREFUSED) .map_err(|_| SystemError::ECONNREFUSED)
}); });
match result { match result {
Ok(_) => Ok(Connecting::new(inner)), Ok(_) => Ok(Connecting::new(inner, local.addr.version())),
Err(err) => Err((Init::Bound((inner, local)), err)), Err(err) => Err((Init::Bound((inner, local)), err)),
} }
} }
@ -193,13 +183,18 @@ enum ConnectResult {
#[derive(Debug)] #[derive(Debug)]
pub struct Connecting { pub struct Connecting {
inner: socket::inet::BoundInner, inner: socket::inet::BoundInner,
version: smoltcp::wire::IpVersion,
result: RwLock<ConnectResult>, result: RwLock<ConnectResult>,
} }
impl Connecting { impl Connecting {
fn new(inner: socket::inet::BoundInner) -> Self { fn new(
inner: socket::inet::BoundInner,
version: smoltcp::wire::IpVersion,
) -> Self {
Connecting { Connecting {
inner, inner,
version,
result: RwLock::new(ConnectResult::Connecting), result: RwLock::new(ConnectResult::Connecting),
} }
} }
@ -212,6 +207,7 @@ impl Connecting {
} }
pub fn into_result(self) -> (Inner, Result<(), SystemError>) { pub fn into_result(self) -> (Inner, Result<(), SystemError>) {
// log::debug!("Into_result {:?}", self.inner);
let result = *self.result.read(); let result = *self.result.read();
match result { match result {
ConnectResult::Connecting => (Inner::Connecting(self), Err(SystemError::EAGAIN)), ConnectResult::Connecting => (Inner::Connecting(self), Err(SystemError::EAGAIN)),
@ -220,7 +216,7 @@ impl Connecting {
Ok(()), Ok(()),
), ),
ConnectResult::Refused => ( ConnectResult::Refused => (
Inner::Init(Init::new_bound(self.inner)), Inner::Init(Init::new(self.version)),
Err(SystemError::ECONNREFUSED), Err(SystemError::ECONNREFUSED),
), ),
} }
@ -237,26 +233,25 @@ impl Connecting {
/// _exactly_ once. The caller is responsible for not missing this event. /// _exactly_ once. The caller is responsible for not missing this event.
#[must_use] #[must_use]
pub(super) fn update_io_events(&self) -> bool { pub(super) fn update_io_events(&self) -> bool {
// if matches!(*self.result.read_irqsave(), ConnectResult::Connecting) {
// return false;
// }
self.inner self.inner
.with_mut(|socket: &mut smoltcp::socket::tcp::Socket| { .with_mut(|socket: &mut smoltcp::socket::tcp::Socket| {
let mut result = self.result.write(); let mut result = self.result.write();
if matches!(*result, ConnectResult::Refused | ConnectResult::Connected) { if matches!(*result, ConnectResult::Refused | ConnectResult::Connected) {
return false; // Already connected or refused log::warn!(
"update_io_events called on a Connecting socket that is already {:?}",
*result
);
return true; // Already connected or refused, so shouldn't in this state, trigger update!
} }
// Connected // Connected
if socket.can_send() { if socket.may_send() {
log::debug!("can send"); log::debug!("can send");
*result = ConnectResult::Connected; *result = ConnectResult::Connected;
return true; return true;
} }
// Connecting // Connecting
if socket.is_open() { if socket.is_open() {
log::debug!("connecting");
*result = ConnectResult::Connecting; *result = ConnectResult::Connecting;
return false; return false;
} }
@ -416,30 +411,77 @@ impl Established {
pub fn recv_slice(&self, buf: &mut [u8]) -> Result<usize, SystemError> { pub fn recv_slice(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
self.inner self.inner
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| { .with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
if socket.can_send() {
match socket.recv_slice(buf) { match socket.recv_slice(buf) {
Ok(size) => Ok(size), Ok(size) => Ok(size),
Err(tcp::RecvError::InvalidState) => { Err(tcp::RecvError::InvalidState) => {
log::error!("TcpSocket::try_recv: InvalidState"); socket.may_recv();
Err(SystemError::ENOTCONN) use smoltcp::socket::tcp::State;
match socket.state() {
// Not ENOTCONN since the socket is in established state
State::Closed => Err(SystemError::ECONNRESET),
// remote sent FIN
State::Closing
| State::LastAck
| State::TimeWait
| State::CloseWait => {
log::debug!("TCP state: {:?}, recv will return 0", socket.state());
Ok(0) // return 0 to indicate EOF
}
// Socket should not be in these state
State::Listen | State::SynReceived | State::SynSent => {
log::error!("Unexpected TCP state: {:?}", socket.state());
Err(SystemError::ECONNRESET) // return reset to drop this error socket, not stadard behavior
},
// already checked in `can_recv()`
State::Established
| State::FinWait1
| State::FinWait2 => {
unreachable!("Should be able to recv: {:?}", socket.state())
}
}
} }
Err(tcp::RecvError::Finished) => Ok(0), Err(tcp::RecvError::Finished) => Ok(0),
} }
} else {
Err(SystemError::ENOBUFS)
}
}) })
} }
pub fn send_slice(&self, buf: &[u8]) -> Result<usize, SystemError> { pub fn send_slice(&self, buf: &[u8]) -> Result<usize, SystemError> {
self.inner self.inner
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| { .with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
if socket.can_send() { match socket.send_slice(buf) {
socket Ok(0) => Err(SystemError::EAGAIN),
.send_slice(buf) Ok(size) => Ok(size),
.map_err(|_| SystemError::ECONNABORTED) Err(tcp::SendError::InvalidState) => {
} else { use smoltcp::socket::tcp::State;
Err(SystemError::ENOBUFS) match socket.state() {
// Not ENOTCONN since the socket is in established state
State::Closed => Err(SystemError::ECONNRESET),
// Socket is already closed by us
State::LastAck
| State::TimeWait
| State::Closing
| State::FinWait1
| State::FinWait2 => Err(SystemError::EPIPE),
// Socket should not be in these state
State::Listen | State::SynReceived | State::SynSent => {
log::error!("Unexpected TCP state: {:?}", socket.state());
Err(SystemError::ECONNRESET) // return reset to drop this error socket, not stadard behavior
},
// these states are already checked in `can_send()`
State::Established
// In CLOSE-WAIT, the remote endpoint has closed our receive half of the connection
// but we still can transmit indefinitely.
| State::CloseWait => {
unreachable!("Should be able to send: {:?}", socket.state())
}
}
}
} }
}) })
} }

View File

@ -250,9 +250,13 @@ impl TcpSocket {
} }
} }
fn incoming(&self) -> bool { fn is_epoll_in(&self) -> bool {
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN) EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN)
} }
fn is_epoll_out(&self) -> bool {
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLOUT)
}
} }
impl Socket for TcpSocket { impl Socket for TcpSocket {
@ -269,7 +273,7 @@ impl Socket for TcpSocket {
{ {
inner::Inner::Init(inner::Init::Unbound((_, ver))) => Ok(Endpoint::Ip(match ver { inner::Inner::Init(inner::Init::Unbound((_, ver))) => Ok(Endpoint::Ip(match ver {
smoltcp::wire::IpVersion::Ipv4 => UNSPECIFIED_LOCAL_ENDPOINT_V4, smoltcp::wire::IpVersion::Ipv4 => UNSPECIFIED_LOCAL_ENDPOINT_V4,
// smoltcp::wire::IpVersion::Ipv6 => UNSPECIFIED_LOCAL_ENDPOINT_V6, smoltcp::wire::IpVersion::Ipv6 => todo!("UNSPECIFIED_LOCAL_ENDPOINT_V6"),
})), })),
inner::Inner::Init(inner::Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)), inner::Inner::Init(inner::Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)),
inner::Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())), inner::Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())),
@ -324,27 +328,43 @@ impl Socket for TcpSocket {
} }
fn accept(&self) -> Result<(Arc<dyn Socket>, Endpoint), SystemError> { fn accept(&self) -> Result<(Arc<dyn Socket>, Endpoint), SystemError> {
if self.is_nonblock() {
self.try_accept()
} else {
loop { loop {
match self.try_accept() { match self.try_accept() {
Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN),
Err(SystemError::EAGAIN) => { Err(SystemError::EAGAIN) => {
wq_wait_event_interruptible(&self.wait_queue, || self.incoming(), None)?; wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_in(), None)?;
}
result => {
break result.map(|(inner, endpoint)| {
(inner as Arc<dyn Socket>, Endpoint::Ip(endpoint))
})
}
}
}
}
fn recv(&self, buffer: &mut [u8], _flags: PMSG) -> Result<usize, SystemError> {
loop {
match self.try_recv(buffer) {
Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN),
Err(SystemError::EAGAIN) => {
wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_in(), None)?;
} }
result => break result, result => break result,
} }
} }
} }
.map(|(inner, endpoint)| (inner as Arc<dyn Socket>, Endpoint::Ip(endpoint)))
}
fn recv(&self, buffer: &mut [u8], _flags: PMSG) -> Result<usize, SystemError> {
self.try_recv(buffer)
}
fn send(&self, buffer: &[u8], _flags: PMSG) -> Result<usize, SystemError> { fn send(&self, buffer: &[u8], _flags: PMSG) -> Result<usize, SystemError> {
self.try_send(buffer) loop {
match self.try_send(buffer) {
Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN),
Err(SystemError::EAGAIN) => {
wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_out(), None)?;
}
result => break result,
}
}
} }
fn send_buffer_size(&self) -> usize { fn send_buffer_size(&self) -> usize {