mirror of
https://github.com/DragonOS-Community/dragonos-berkeley-socket.git
synced 2025-06-08 07:35:03 +00:00
fix status in connect, add connecting features
This commit is contained in:
parent
a13e8a2da1
commit
9b37918841
@ -11,6 +11,7 @@ smoltcp = { version = "0.12.0", default-features = false, features = [
|
||||
"medium-ethernet",
|
||||
"medium-ip",
|
||||
"proto-ipv4",
|
||||
"proto-ipv6",
|
||||
"socket-udp",
|
||||
"socket-tcp",
|
||||
]}
|
||||
|
83
src/main.rs
83
src/main.rs
@ -1,4 +1,4 @@
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
use std::{io::{self, Read}, net::Ipv4Addr, sync::Arc};
|
||||
|
||||
use berkeley_socket::{
|
||||
driver::{irq::start_network_polling_thread, tap::TapDevice},
|
||||
@ -80,6 +80,69 @@ fn make_tcp_echo() {
|
||||
}
|
||||
}
|
||||
|
||||
fn make_request() {
|
||||
log::info!("Input a valid IP address and port to connect to:");
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input).unwrap();
|
||||
let parts: Vec<&str> = input.trim().split(':').collect();
|
||||
if parts.len() != 2 {
|
||||
log::error!("Invalid input format. Use <IP>:<port>.");
|
||||
return;
|
||||
}
|
||||
let ip: Ipv4Addr = match parts[0].parse() {
|
||||
Ok(ip) => ip,
|
||||
Err(_) => {
|
||||
log::error!("Invalid IP address.");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let port: u16 = match parts[1].parse() {
|
||||
Ok(port) => port,
|
||||
Err(_) => {
|
||||
log::error!("Invalid port number.");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let endpoint = Endpoint::Ip(IpEndpoint::new(IpAddress::Ipv4(Ipv4Addr::from(ip)), port));
|
||||
|
||||
let socket = Inet::socket(SOCK::Stream, 0).unwrap();
|
||||
match socket.connect(endpoint) {
|
||||
Ok(_) => {
|
||||
log::info!("Connected to {}:{}", ip, port);
|
||||
let mut buffer = [0u8; 1024];
|
||||
loop {
|
||||
let len = io::stdin().read(&mut buffer).unwrap();
|
||||
if len == 0 {
|
||||
break; // EOF
|
||||
}
|
||||
let sent_len = socket.write(&buffer[..len]).unwrap();
|
||||
log::info!("Sent {} bytes", sent_len);
|
||||
match socket.read(&mut buffer) {
|
||||
Ok(received_len) => {
|
||||
if received_len == 0 {
|
||||
log::info!("Socket closed by remote peer.");
|
||||
break;
|
||||
}
|
||||
log::info!(
|
||||
"Received {} bytes: {}",
|
||||
received_len,
|
||||
String::from_utf8_lossy(&buffer[..received_len])
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Socket read error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect: {}", e);
|
||||
}
|
||||
}
|
||||
log::info!("Connection closed.");
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
let device = TapDevice::new("tap0", smoltcp::phy::Medium::Ethernet).unwrap();
|
||||
@ -105,6 +168,24 @@ fn main() {
|
||||
let tcp = std::thread::spawn(move || {
|
||||
make_tcp_echo();
|
||||
});
|
||||
|
||||
loop {
|
||||
let char = io::stdin().bytes().next().unwrap().unwrap();
|
||||
match char {
|
||||
b'q' | b'Q' => {
|
||||
log::info!("Exiting...");
|
||||
break;
|
||||
}
|
||||
b'r' => {
|
||||
make_request();
|
||||
}
|
||||
_ => {
|
||||
log::info!("Press 'q' to exit.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Optionally join threads before exiting
|
||||
udp.join().unwrap();
|
||||
tcp.join().unwrap();
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ pub enum Endpoint {
|
||||
// LinkLayer(LinkLayerEndpoint),
|
||||
/// 网络层端点
|
||||
Ip(IpEndpoint),
|
||||
Other,
|
||||
// /// inode端点,Unix实际保存的端点
|
||||
// Inode((Arc<socket::SocketInode>, String)),
|
||||
// /// Unix传递id索引和path所用的端点
|
||||
|
@ -78,6 +78,7 @@ impl BoundInner {
|
||||
T: smoltcp::socket::AnySocket<'static>,
|
||||
{
|
||||
let (iface, address) = get_ephemeral_iface(&remote);
|
||||
log::debug!("bind_ephemeral address: {}", address);
|
||||
// let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?;
|
||||
let handle = iface.sockets().lock().add(socket);
|
||||
// let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port);
|
||||
|
@ -45,16 +45,6 @@ impl Init {
|
||||
Init::Unbound((Box::new(new_smoltcp_socket()), ver))
|
||||
}
|
||||
|
||||
/// 传入一个已经绑定的socket
|
||||
pub(super) fn new_bound(inner: socket::inet::BoundInner) -> Self {
|
||||
let endpoint = inner.with::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
|
||||
socket
|
||||
.local_endpoint()
|
||||
.expect("A Bound Socket Must Have A Local Endpoint")
|
||||
});
|
||||
Init::Bound((inner, endpoint))
|
||||
}
|
||||
|
||||
pub(super) fn bind(
|
||||
self,
|
||||
local_endpoint: smoltcp::wire::IpEndpoint,
|
||||
@ -116,7 +106,7 @@ impl Init {
|
||||
.map_err(|_| SystemError::ECONNREFUSED)
|
||||
});
|
||||
match result {
|
||||
Ok(_) => Ok(Connecting::new(inner)),
|
||||
Ok(_) => Ok(Connecting::new(inner, local.addr.version())),
|
||||
Err(err) => Err((Init::Bound((inner, local)), err)),
|
||||
}
|
||||
}
|
||||
@ -193,13 +183,18 @@ enum ConnectResult {
|
||||
#[derive(Debug)]
|
||||
pub struct Connecting {
|
||||
inner: socket::inet::BoundInner,
|
||||
version: smoltcp::wire::IpVersion,
|
||||
result: RwLock<ConnectResult>,
|
||||
}
|
||||
|
||||
impl Connecting {
|
||||
fn new(inner: socket::inet::BoundInner) -> Self {
|
||||
fn new(
|
||||
inner: socket::inet::BoundInner,
|
||||
version: smoltcp::wire::IpVersion,
|
||||
) -> Self {
|
||||
Connecting {
|
||||
inner,
|
||||
version,
|
||||
result: RwLock::new(ConnectResult::Connecting),
|
||||
}
|
||||
}
|
||||
@ -212,6 +207,7 @@ impl Connecting {
|
||||
}
|
||||
|
||||
pub fn into_result(self) -> (Inner, Result<(), SystemError>) {
|
||||
// log::debug!("Into_result {:?}", self.inner);
|
||||
let result = *self.result.read();
|
||||
match result {
|
||||
ConnectResult::Connecting => (Inner::Connecting(self), Err(SystemError::EAGAIN)),
|
||||
@ -220,7 +216,7 @@ impl Connecting {
|
||||
Ok(()),
|
||||
),
|
||||
ConnectResult::Refused => (
|
||||
Inner::Init(Init::new_bound(self.inner)),
|
||||
Inner::Init(Init::new(self.version)),
|
||||
Err(SystemError::ECONNREFUSED),
|
||||
),
|
||||
}
|
||||
@ -237,26 +233,25 @@ impl Connecting {
|
||||
/// _exactly_ once. The caller is responsible for not missing this event.
|
||||
#[must_use]
|
||||
pub(super) fn update_io_events(&self) -> bool {
|
||||
// if matches!(*self.result.read_irqsave(), ConnectResult::Connecting) {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
self.inner
|
||||
.with_mut(|socket: &mut smoltcp::socket::tcp::Socket| {
|
||||
let mut result = self.result.write();
|
||||
if matches!(*result, ConnectResult::Refused | ConnectResult::Connected) {
|
||||
return false; // Already connected or refused
|
||||
log::warn!(
|
||||
"update_io_events called on a Connecting socket that is already {:?}",
|
||||
*result
|
||||
);
|
||||
return true; // Already connected or refused, so shouldn't in this state, trigger update!
|
||||
}
|
||||
|
||||
// Connected
|
||||
if socket.can_send() {
|
||||
if socket.may_send() {
|
||||
log::debug!("can send");
|
||||
*result = ConnectResult::Connected;
|
||||
return true;
|
||||
}
|
||||
// Connecting
|
||||
if socket.is_open() {
|
||||
log::debug!("connecting");
|
||||
*result = ConnectResult::Connecting;
|
||||
return false;
|
||||
}
|
||||
@ -416,30 +411,77 @@ impl Established {
|
||||
pub fn recv_slice(&self, buf: &mut [u8]) -> Result<usize, SystemError> {
|
||||
self.inner
|
||||
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
|
||||
if socket.can_send() {
|
||||
match socket.recv_slice(buf) {
|
||||
Ok(size) => Ok(size),
|
||||
Err(tcp::RecvError::InvalidState) => {
|
||||
log::error!("TcpSocket::try_recv: InvalidState");
|
||||
Err(SystemError::ENOTCONN)
|
||||
socket.may_recv();
|
||||
use smoltcp::socket::tcp::State;
|
||||
match socket.state() {
|
||||
// Not ENOTCONN since the socket is in established state
|
||||
State::Closed => Err(SystemError::ECONNRESET),
|
||||
|
||||
// remote sent FIN
|
||||
State::Closing
|
||||
| State::LastAck
|
||||
| State::TimeWait
|
||||
| State::CloseWait => {
|
||||
log::debug!("TCP state: {:?}, recv will return 0", socket.state());
|
||||
Ok(0) // return 0 to indicate EOF
|
||||
}
|
||||
|
||||
// Socket should not be in these state
|
||||
State::Listen | State::SynReceived | State::SynSent => {
|
||||
log::error!("Unexpected TCP state: {:?}", socket.state());
|
||||
Err(SystemError::ECONNRESET) // return reset to drop this error socket, not stadard behavior
|
||||
},
|
||||
|
||||
// already checked in `can_recv()`
|
||||
State::Established
|
||||
| State::FinWait1
|
||||
| State::FinWait2 => {
|
||||
unreachable!("Should be able to recv: {:?}", socket.state())
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(tcp::RecvError::Finished) => Ok(0),
|
||||
}
|
||||
} else {
|
||||
Err(SystemError::ENOBUFS)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn send_slice(&self, buf: &[u8]) -> Result<usize, SystemError> {
|
||||
self.inner
|
||||
.with_mut::<smoltcp::socket::tcp::Socket, _, _>(|socket| {
|
||||
if socket.can_send() {
|
||||
socket
|
||||
.send_slice(buf)
|
||||
.map_err(|_| SystemError::ECONNABORTED)
|
||||
} else {
|
||||
Err(SystemError::ENOBUFS)
|
||||
match socket.send_slice(buf) {
|
||||
Ok(0) => Err(SystemError::EAGAIN),
|
||||
Ok(size) => Ok(size),
|
||||
Err(tcp::SendError::InvalidState) => {
|
||||
use smoltcp::socket::tcp::State;
|
||||
match socket.state() {
|
||||
// Not ENOTCONN since the socket is in established state
|
||||
State::Closed => Err(SystemError::ECONNRESET),
|
||||
|
||||
// Socket is already closed by us
|
||||
State::LastAck
|
||||
| State::TimeWait
|
||||
| State::Closing
|
||||
| State::FinWait1
|
||||
| State::FinWait2 => Err(SystemError::EPIPE),
|
||||
|
||||
// Socket should not be in these state
|
||||
State::Listen | State::SynReceived | State::SynSent => {
|
||||
log::error!("Unexpected TCP state: {:?}", socket.state());
|
||||
Err(SystemError::ECONNRESET) // return reset to drop this error socket, not stadard behavior
|
||||
},
|
||||
|
||||
// these states are already checked in `can_send()`
|
||||
State::Established
|
||||
// In CLOSE-WAIT, the remote endpoint has closed our receive half of the connection
|
||||
// but we still can transmit indefinitely.
|
||||
| State::CloseWait => {
|
||||
unreachable!("Should be able to send: {:?}", socket.state())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -250,9 +250,13 @@ impl TcpSocket {
|
||||
}
|
||||
}
|
||||
|
||||
fn incoming(&self) -> bool {
|
||||
fn is_epoll_in(&self) -> bool {
|
||||
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN)
|
||||
}
|
||||
|
||||
fn is_epoll_out(&self) -> bool {
|
||||
EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLOUT)
|
||||
}
|
||||
}
|
||||
|
||||
impl Socket for TcpSocket {
|
||||
@ -269,7 +273,7 @@ impl Socket for TcpSocket {
|
||||
{
|
||||
inner::Inner::Init(inner::Init::Unbound((_, ver))) => Ok(Endpoint::Ip(match ver {
|
||||
smoltcp::wire::IpVersion::Ipv4 => UNSPECIFIED_LOCAL_ENDPOINT_V4,
|
||||
// smoltcp::wire::IpVersion::Ipv6 => UNSPECIFIED_LOCAL_ENDPOINT_V6,
|
||||
smoltcp::wire::IpVersion::Ipv6 => todo!("UNSPECIFIED_LOCAL_ENDPOINT_V6"),
|
||||
})),
|
||||
inner::Inner::Init(inner::Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)),
|
||||
inner::Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())),
|
||||
@ -324,27 +328,43 @@ impl Socket for TcpSocket {
|
||||
}
|
||||
|
||||
fn accept(&self) -> Result<(Arc<dyn Socket>, Endpoint), SystemError> {
|
||||
if self.is_nonblock() {
|
||||
self.try_accept()
|
||||
} else {
|
||||
loop {
|
||||
match self.try_accept() {
|
||||
Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN),
|
||||
Err(SystemError::EAGAIN) => {
|
||||
wq_wait_event_interruptible(&self.wait_queue, || self.incoming(), None)?;
|
||||
wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_in(), None)?;
|
||||
}
|
||||
result => {
|
||||
break result.map(|(inner, endpoint)| {
|
||||
(inner as Arc<dyn Socket>, Endpoint::Ip(endpoint))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn recv(&self, buffer: &mut [u8], _flags: PMSG) -> Result<usize, SystemError> {
|
||||
loop {
|
||||
match self.try_recv(buffer) {
|
||||
Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN),
|
||||
Err(SystemError::EAGAIN) => {
|
||||
wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_in(), None)?;
|
||||
}
|
||||
result => break result,
|
||||
}
|
||||
}
|
||||
}
|
||||
.map(|(inner, endpoint)| (inner as Arc<dyn Socket>, Endpoint::Ip(endpoint)))
|
||||
}
|
||||
|
||||
fn recv(&self, buffer: &mut [u8], _flags: PMSG) -> Result<usize, SystemError> {
|
||||
self.try_recv(buffer)
|
||||
}
|
||||
|
||||
fn send(&self, buffer: &[u8], _flags: PMSG) -> Result<usize, SystemError> {
|
||||
self.try_send(buffer)
|
||||
loop {
|
||||
match self.try_send(buffer) {
|
||||
Err(SystemError::EAGAIN) if self.is_nonblock() => break Err(SystemError::EAGAIN),
|
||||
Err(SystemError::EAGAIN) => {
|
||||
wq_wait_event_interruptible(&self.wait_queue, || self.is_epoll_out(), None)?;
|
||||
}
|
||||
result => break result,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_buffer_size(&self) -> usize {
|
||||
|
Loading…
x
Reference in New Issue
Block a user