diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 40e443e..46c196c 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -17,11 +17,7 @@ fn ifreq_for(name: &str) -> ifreq { ifreq } -fn ifreq_ioctl( - lower: libc::c_int, - cmd: libc::c_ulong, - ifreq: &mut ifreq, -) -> io::Result { +fn ifreq_ioctl(lower: libc::c_int, cmd: libc::c_ulong, ifreq: &mut ifreq) -> io::Result { unsafe { let res = libc::ioctl(lower, cmd as _, ifreq as *mut ifreq); if res == -1 { diff --git a/src/driver/tap.rs b/src/driver/tap.rs index 58c8ba9..7b7ea3d 100644 --- a/src/driver/tap.rs +++ b/src/driver/tap.rs @@ -92,7 +92,7 @@ impl TapDesc { let mac = smoltcp::wire::EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); unsafe { - ifr.ifr_ifru.ifru_hwaddr.sa_family = libc::ARPHRD_ETHER as u16; + ifr.ifr_ifru.ifru_hwaddr.sa_family = libc::ARPHRD_ETHER; ifr.ifr_ifru.ifru_hwaddr.sa_data[..6] .copy_from_slice(&[0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); } diff --git a/src/libs/wait_queue.rs b/src/libs/wait_queue.rs index d539295..0974de7 100644 --- a/src/libs/wait_queue.rs +++ b/src/libs/wait_queue.rs @@ -1,8 +1,4 @@ -use std::{ - sync::atomic::AtomicBool, - thread::sleep, - time::Duration, -}; +use std::{sync::atomic::AtomicBool, thread::sleep, time::Duration}; use linux_errnos::Errno; diff --git a/src/main.rs b/src/main.rs index b196eb8..9dcdbd0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,91 @@ use std::{net::Ipv4Addr, sync::Arc}; use berkeley_socket::{ - driver::{irq::start_network_polling_thread, tap::TapDevice}, interface::{tap::TapIface, Iface}, posix::SOCK, socket::{endpoint::Endpoint, inet::{common::NET_DEVICES, syscall::Inet}, Family} + driver::{irq::start_network_polling_thread, tap::TapDevice}, + interface::{tap::TapIface, Iface}, + posix::SOCK, + socket::{ + endpoint::Endpoint, + inet::{common::NET_DEVICES, syscall::Inet}, + Family, + }, }; -use smoltcp::wire::{IpAddress, IpEndpoint, Ipv4Cidr, IpCidr}; +use smoltcp::wire::{IpAddress, IpCidr, IpEndpoint, Ipv4Cidr}; use spin::Mutex; +fn make_udp_echo() { + let socket = Inet::socket(SOCK::Datagram, 0).unwrap(); + socket + .bind(Endpoint::Ip(IpEndpoint::new( + IpAddress::v4(192, 168, 213, 2), + 1234, + ))) + .unwrap(); + socket + .connect(Endpoint::Ip(IpEndpoint::new( + IpAddress::v4(192, 168, 213, 3), + 12345, + ))) + .unwrap(); + let mut buffer = [0u8; 1024]; + + loop { + let len = socket.read(&mut buffer).unwrap(); + log::info!( + "Received {} bytes: {}", + len, + String::from_utf8_lossy(&buffer[..len]) + ); + let len = socket.write(&buffer[..len]).unwrap(); + log::info!( + "Sent {} bytes: {}", + len, + String::from_utf8_lossy(&buffer[..len]) + ); + } +} + +fn make_tcp_echo() { + let socket = Inet::socket(SOCK::Stream, 0).unwrap(); + socket + .bind(Endpoint::Ip(IpEndpoint::new( + IpAddress::v4(192, 168, 213, 2), + 4321, + ))) + .unwrap(); + socket.listen(1).unwrap(); + + loop { + let (client_socket, _) = socket.accept().unwrap(); + log::info!("Accepted connection from {:?}", client_socket); + let mut buffer = [0u8; 1024]; + + loop { + let len = client_socket.read(&mut buffer).unwrap(); + if len == 0 { + break; + } + log::info!( + "Received {} bytes: {}", + len, + String::from_utf8_lossy(&buffer[..len]) + ); + let len = client_socket.write(&buffer[..len]).unwrap(); + log::info!( + "Sent {} bytes: {}", + len, + String::from_utf8_lossy(&buffer[..len]) + ); + } + } +} + fn main() { env_logger::init(); let device = TapDevice::new("tap0", smoltcp::phy::Medium::Ethernet).unwrap(); let iface_inner = TapIface::new(Arc::new(Mutex::new(device))); - let ip_cidr = IpCidr::Ipv4(Ipv4Cidr::new( - Ipv4Addr::new(192, 168, 213, 2), - 24 - )); + let ip_cidr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Addr::new(192, 168, 213, 2), 24)); let ip_cidr = vec![ip_cidr]; @@ -28,17 +99,12 @@ fn main() { }); let _ = start_network_polling_thread(); - let socket = Inet::socket(SOCK::Datagram, 0).unwrap(); - socket.bind(Endpoint::Ip( - IpEndpoint::new( - IpAddress::v4(192, 168, 213, 2), - 1234, - ) - )).unwrap(); - let mut buffer = [0u8; 1024]; - - loop { - let len = socket.read(&mut buffer).unwrap(); - log::info!("Received {} bytes: {}", len, String::from_utf8_lossy(&buffer[..len])); - } + let udp = std::thread::spawn(move || { + make_udp_echo(); + }); + let tcp = std::thread::spawn(move || { + make_tcp_echo(); + }); + udp.join().unwrap(); + tcp.join().unwrap(); } diff --git a/src/socket/inet/datagram/inner.rs b/src/socket/inet/datagram/inner.rs index 2d537b9..da21c3b 100644 --- a/src/socket/inet/datagram/inner.rs +++ b/src/socket/inet/datagram/inner.rs @@ -134,7 +134,7 @@ impl BoundUdp { to: Option, ) -> Result { let remote = to.or(*self.remote.lock()).ok_or(SystemError::ENOTCONN)?; - + self.with_mut_socket(|socket| { if socket.can_send() && socket.send_slice(buf, remote).is_ok() { log::debug!("send {} bytes", buf.len()); diff --git a/src/socket/inet/datagram/mod.rs b/src/socket/inet/datagram/mod.rs index 4dd58ae..b952bec 100644 --- a/src/socket/inet/datagram/mod.rs +++ b/src/socket/inet/datagram/mod.rs @@ -273,8 +273,7 @@ impl Socket for UdpSocket { } impl InetSocket for UdpSocket { - fn on_iface_events(&self) { - } + fn on_iface_events(&self) {} } bitflags::bitflags! { diff --git a/src/socket/inet/mod.rs b/src/socket/inet/mod.rs index 45937d6..70fd0c0 100644 --- a/src/socket/inet/mod.rs +++ b/src/socket/inet/mod.rs @@ -4,21 +4,21 @@ use smoltcp; // pub mod icmp; pub mod common; pub mod datagram; -// pub mod stream; pub mod posix; +pub mod stream; pub mod syscall; pub use common::BoundInner; pub use common::Types; // pub use raw::RawSocket; pub use datagram::UdpSocket; +pub use stream::TcpSocket; use smoltcp::wire::IpAddress; use smoltcp::wire::IpEndpoint; use smoltcp::wire::Ipv4Address; // use smoltcp::wire::Ipv6Address; -// pub use stream::TcpSocket; // pub use syscall::Inet; use super::Socket; diff --git a/src/socket/inet/stream/inner.rs b/src/socket/inet/stream/inner.rs index 10d02f5..9e5e44b 100644 --- a/src/socket/inet/stream/inner.rs +++ b/src/socket/inet/stream/inner.rs @@ -1,13 +1,14 @@ use core::sync::atomic::AtomicUsize; +use crate::event_poll::EPollEventType; use crate::libs::rwlock::RwLock; // use crate::net::socket::EPollEventType; use crate::socket::{self, inet::Types}; use alloc::boxed::Box; use alloc::vec::Vec; +use linux_errnos::Errno as SystemError; use smoltcp; use smoltcp::socket::tcp; -use linux_errnos::Errno as SystemError; // pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; @@ -163,11 +164,11 @@ impl Init { } inners.push(inner); - return Ok(Listening { + Ok(Listening { inners, connect: AtomicUsize::new(0), listen_addr, - }); + }) } pub(super) fn close(&self) { @@ -213,10 +214,7 @@ impl Connecting { pub fn into_result(self) -> (Inner, Result<(), SystemError>) { let result = *self.result.read(); match result { - ConnectResult::Connecting => ( - Inner::Connecting(self), - Err(SystemError::EAGAIN_OR_EWOULDBLOCK), - ), + ConnectResult::Connecting => (Inner::Connecting(self), Err(SystemError::EAGAIN)), ConnectResult::Connected => ( Inner::Established(Established { inner: self.inner }), Ok(()), @@ -264,7 +262,7 @@ impl Connecting { } // Refused *result = ConnectResult::Refused; - return true; + true }) } @@ -302,7 +300,7 @@ impl Listening { .unwrap(); if connected.with::(|socket| !socket.is_active()) { - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + return Err(SystemError::EAGAIN); } let remote_endpoint = connected.with::(|socket| { @@ -327,7 +325,7 @@ impl Listening { // TODO is smoltcp socket swappable? core::mem::swap(&mut new_listen, connected); - return Ok((Established { inner: new_listen }, remote_endpoint)); + Ok((Established { inner: new_listen }, remote_endpoint)) } pub fn update_io_events(&self, pollee: &AtomicUsize) { diff --git a/src/socket/inet/stream/mod.rs b/src/socket/inet/stream/mod.rs index 4447079..26c1cbd 100644 --- a/src/socket/inet/stream/mod.rs +++ b/src/socket/inet/stream/mod.rs @@ -1,14 +1,14 @@ use alloc::sync::{Arc, Weak}; use core::sync::atomic::{AtomicBool, AtomicUsize}; -use system_error::SystemError; +use linux_errnos::Errno as SystemError; -use crate::libs::wait_queue::WaitQueue; +use crate::libs::wait_queue::{wq_wait_event_interruptible, WaitQueue}; // use crate::event_poll::EPollEventType; use crate::socket::common::shutdown::{ShutdownBit, ShutdownTemp}; use crate::socket::endpoint::Endpoint; use crate::socket::{Socket, PMSG, PSOL}; // use crate::sched::SchedMode; -use crate::{libs::rwlock::RwLock, net::socket::common::shutdown::Shutdown}; +use crate::{libs::rwlock::RwLock, socket::common::shutdown::Shutdown}; use smoltcp; mod inner; @@ -18,7 +18,7 @@ pub use option::Options as TcpOption; use super::{InetSocket, UNSPECIFIED_LOCAL_ENDPOINT_V4}; -type EP = EPollEventType; +type EP = crate::event_poll::EPollEventType; #[derive(Debug)] pub struct TcpSocket { inner: RwLock>, @@ -98,7 +98,7 @@ impl TcpSocket { if let Some(err) = err { return Err(err); } - return Ok(()); + Ok(()) } pub fn try_accept(&self) -> Result<(Arc, smoltcp::wire::IpEndpoint), SystemError> { @@ -161,7 +161,7 @@ impl TcpSocket { } writer.replace(init); - return result; + result } // for irq use @@ -297,7 +297,7 @@ impl Socket for TcpSocket { return self.do_bind(addr); } log::debug!("TcpSocket::bind: invalid endpoint"); - return Err(SystemError::EINVAL); + Err(SystemError::EINVAL) } fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { @@ -307,12 +307,12 @@ impl Socket for TcpSocket { }; self.start_connect(endpoint)?; // Only Nonblock or error will return error. - return loop { + loop { match self.check_connect() { - Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => {} + Err(SystemError::EAGAIN) => {} result => break result, } - }; + } } fn poll(&self) -> usize { @@ -323,20 +323,20 @@ impl Socket for TcpSocket { self.do_listen(backlog) } - fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { + fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { if self.is_nonblock() { self.try_accept() } else { loop { match self.try_accept() { - Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { - wq_wait_event_interruptible!(self.wait_queue, self.incoming(), {})?; + Err(SystemError::EAGAIN) => { + wq_wait_event_interruptible(&self.wait_queue, || self.incoming(), None)?; } result => break result, } } } - .map(|(inner, endpoint)| (SocketInode::new(inner), Endpoint::Ip(endpoint))) + .map(|(inner, endpoint)| (inner as Arc, Endpoint::Ip(endpoint))) } fn recv(&self, buffer: &mut [u8], _flags: PMSG) -> Result { diff --git a/src/socket/inet/syscall.rs b/src/socket/inet/syscall.rs index f9fd6fa..528feb7 100644 --- a/src/socket/inet/syscall.rs +++ b/src/socket/inet/syscall.rs @@ -5,7 +5,7 @@ use smoltcp::{self, wire::IpProtocol}; use crate::{ posix::SOCK, socket::{ - inet::UdpSocket, + inet::{TcpSocket, UdpSocket}, Family, Socket, // SocketInode, }, @@ -19,28 +19,22 @@ fn create_inet_socket( // log::debug!("type: {:?}, protocol: {:?}", socket_type, protocol); match socket_type { SOCK::Datagram => match protocol { - IpProtocol::HopByHop | IpProtocol::Udp => { - Ok(UdpSocket::new(false)) + IpProtocol::HopByHop | IpProtocol::Udp => Ok(UdpSocket::new(false)), + _ => Err(SystemError::EPROTONOSUPPORT), + }, + SOCK::Stream => match protocol { + IpProtocol::HopByHop | IpProtocol::Tcp => { + log::debug!("create tcp socket"); + Ok(TcpSocket::new(false, version)) } _ => { Err(SystemError::EPROTONOSUPPORT) } }, - // SOCK::Stream => match protocol { - // IpProtocol::HopByHop | IpProtocol::Tcp => { - // log::debug!("create tcp socket"); - // return Ok(TcpSocket::new(false, version)); - // } - // _ => { - // return Err(SystemError::EPROTONOSUPPORT); - // } - // }, SOCK::Raw => { todo!("raw") } - _ => { - Err(SystemError::EPROTONOSUPPORT) - } + _ => Err(SystemError::EPROTONOSUPPORT), } } diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 3d7a7ce..6a81799 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -33,9 +33,9 @@ pub trait Socket: Sync + Send + Debug + Any { // /// 接受连接,仅用于listening stream socket // /// ## Block // /// 如果没有连接到来,会阻塞 - // fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { - // Err(SystemError::ENOSYS) - // } + fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { + Err(SystemError::ENOSYS) + } /// # `bind` /// 对应于POSIX的bind函数,用于绑定到本机指定的端点 fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> {