Implement tcp&udp socket

This commit is contained in:
Jianfeng Jiang
2023-05-31 10:48:16 +08:00
committed by Tate, Hongliang Tian
parent f437dc6244
commit 8719234dc2
14 changed files with 1216 additions and 0 deletions

View File

@ -0,0 +1,51 @@
use crate::prelude::*;
use core::ops::{Deref, DerefMut};
/// AlwaysSome is a wrapper for Option.
///
/// AlwaysSome should always be Some(T), so we can treat it as a smart pointer.
/// If it becomes None, the AlwaysSome should be viewed invalid and cannot be used anymore.
pub struct AlwaysSome<T>(Option<T>);
impl<T> AlwaysSome<T> {
pub fn new(value: T) -> Self {
AlwaysSome(Some(value))
}
pub fn try_take_with<R, E: Into<Error>, F: FnOnce(T) -> core::result::Result<R, (E, T)>>(
&mut self,
f: F,
) -> Result<R> {
let value = if let Some(value) = self.0.take() {
value
} else {
return_errno_with_message!(Errno::EINVAL, "the take cell is none");
};
match f(value) {
Ok(res) => Ok(res),
Err((err, t)) => {
self.0 = Some(t);
Err(err.into())
}
}
}
/// Takes inner value
pub fn take(&mut self) -> T {
debug_assert!(self.0.is_some());
self.0.take().unwrap()
}
}
impl<T> Deref for AlwaysSome<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.0.as_ref().unwrap()
}
}
impl<T> DerefMut for AlwaysSome<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut().unwrap()
}
}

View File

@ -0,0 +1,65 @@
use crate::net::iface::BindPortConfig;
use crate::net::iface::Iface;
use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket};
use crate::net::iface::{IpAddress, IpEndpoint};
use crate::net::IFACES;
use crate::prelude::*;
pub fn get_iface_to_bind(ip_addr: &IpAddress) -> Option<Arc<dyn Iface>> {
let ifaces = IFACES.get().unwrap();
let IpAddress::Ipv4(ipv4_addr) = ip_addr;
ifaces
.iter()
.find(|iface| {
if let Some(iface_ipv4_addr) = iface.ipv4_addr() {
iface_ipv4_addr == *ipv4_addr
} else {
false
}
})
.map(Clone::clone)
}
/// Get a suitable iface to deal with sendto/connect request if the socket is not bound to an iface.
/// If the remote address is the same as that of some iface, we will use the iface.
/// Otherwise, we will use a default interface.
fn get_ephemeral_iface(remote_ip_addr: &IpAddress) -> Arc<dyn Iface> {
let ifaces = IFACES.get().unwrap();
let IpAddress::Ipv4(remote_ipv4_addr) = remote_ip_addr;
if let Some(iface) = ifaces.iter().find(|iface| {
if let Some(iface_ipv4_addr) = iface.ipv4_addr() {
iface_ipv4_addr == *remote_ipv4_addr
} else {
false
}
}) {
return iface.clone();
}
// FIXME: use the virtio-net as the default interface
ifaces[0].clone()
}
pub(super) fn bind_socket(
unbound_socket: AnyUnboundSocket,
endpoint: IpEndpoint,
can_reuse: bool,
) -> core::result::Result<Arc<AnyBoundSocket>, (Error, AnyUnboundSocket)> {
let iface = match get_iface_to_bind(&endpoint.addr) {
Some(iface) => iface,
None => {
let err = Error::with_message(Errno::EADDRNOTAVAIL, "Request iface is not available");
return Err((err, unbound_socket));
}
};
let bind_port_config = match BindPortConfig::new(endpoint.port, can_reuse) {
Ok(config) => config,
Err(e) => return Err((e, unbound_socket)),
};
iface.bind_socket(unbound_socket, bind_port_config)
}
pub fn get_ephemeral_endpoint(remote_endpoint: &IpEndpoint) -> IpEndpoint {
let iface = get_ephemeral_iface(&remote_endpoint.addr);
let ip_addr = iface.ipv4_addr().unwrap();
IpEndpoint::new(IpAddress::Ipv4(ip_addr), 0)
}

View File

@ -0,0 +1,269 @@
use crate::net::iface::IpEndpoint;
use crate::{
fs::{
file_handle::FileLike,
utils::{IoEvents, Poller},
},
net::{
iface::{AnyBoundSocket, AnyUnboundSocket, RawUdpSocket},
poll_ifaces,
socket::{
util::{
send_recv_flags::SendRecvFlags, sock_options::SockOptionName, sockaddr::SocketAddr,
},
Socket,
},
},
prelude::*,
};
use super::always_some::AlwaysSome;
use super::common::{bind_socket, get_ephemeral_endpoint};
pub struct DatagramSocket {
inner: RwLock<Inner>,
}
enum Inner {
Unbound(AlwaysSome<AnyUnboundSocket>),
Bound {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: Option<IpEndpoint>,
},
}
impl Inner {
fn is_bound(&self) -> bool {
if let Inner::Bound { .. } = self {
true
} else {
false
}
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> {
if self.is_bound() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
}
let unbound_socket = match self {
Inner::Unbound(unbound_socket) => unbound_socket,
_ => unreachable!(),
};
let bound_socket =
unbound_socket.try_take_with(|socket| bind_socket(socket, endpoint, false))?;
let bound_endpoint = bound_socket.local_endpoint().unwrap();
bound_socket.raw_with(|socket: &mut RawUdpSocket| {
socket
.bind(bound_endpoint)
.map_err(|_| Error::with_message(Errno::EINVAL, "cannot bind socket"))
})?;
*self = Inner::Bound {
bound_socket,
remote_endpoint: None,
};
// Once the socket is bound, we should update the socket state at once.
self.update_socket_state();
Ok(())
}
fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(endpoint)
}
fn set_remote_endpoint(&mut self, endpoint: IpEndpoint) -> Result<()> {
if let Inner::Bound {
remote_endpoint, ..
} = self
{
*remote_endpoint = Some(endpoint);
Ok(())
} else {
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
}
}
fn remote_endpoint(&self) -> Option<IpEndpoint> {
if let Inner::Bound {
remote_endpoint, ..
} = self
{
remote_endpoint.clone()
} else {
None
}
}
fn local_endpoint(&self) -> Option<IpEndpoint> {
if let Inner::Bound { bound_socket, .. } = self {
bound_socket.local_endpoint()
} else {
None
}
}
fn bound_socket(&self) -> Option<Arc<AnyBoundSocket>> {
if let Inner::Bound { bound_socket, .. } = self {
Some(bound_socket.clone())
} else {
None
}
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
match self {
Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller),
Inner::Bound { bound_socket, .. } => bound_socket.poll(mask, poller),
}
}
fn update_socket_state(&self) {
if let Inner::Bound { bound_socket, .. } = self {
bound_socket.update_socket_state();
}
}
}
impl DatagramSocket {
pub fn new() -> Self {
let udp_socket = AnyUnboundSocket::new_udp();
Self {
inner: RwLock::new(Inner::Unbound(AlwaysSome::new(udp_socket))),
}
}
pub fn is_bound(&self) -> bool {
self.inner.read().is_bound()
}
fn try_recvfrom(&self, buf: &mut [u8], flags: &SendRecvFlags) -> Result<(usize, IpEndpoint)> {
poll_ifaces();
let bound_socket = self.inner.read().bound_socket().unwrap();
let recv_slice = |socket: &mut RawUdpSocket| match socket.recv_slice(buf) {
Err(smoltcp::socket::udp::RecvError::Exhausted) => {
return_errno_with_message!(Errno::EAGAIN, "recv buf is empty")
}
Ok((len, remote_endpoint)) => Ok((len, remote_endpoint)),
};
bound_socket.raw_with(recv_slice)
}
fn remote_endpoint(&self) -> Result<IpEndpoint> {
self.inner
.read()
.remote_endpoint()
.ok_or(Error::with_message(
Errno::EINVAL,
"udp should provide remote addr",
))
}
}
impl FileLike for DatagramSocket {
fn read(&self, buf: &mut [u8]) -> Result<usize> {
// FIXME: respect flags
let flags = SendRecvFlags::empty();
let (recv_len, _) = self.recvfrom(buf, flags)?;
Ok(recv_len)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
// FIXME: set correct flags
let flags = SendRecvFlags::empty();
self.sendto(buf, None, flags)
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.inner.read().poll(mask, poller)
}
fn as_socket(&self) -> Option<&dyn Socket> {
Some(self)
}
}
impl Socket for DatagramSocket {
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
let endpoint = sockaddr.try_into()?;
self.inner.write().bind(endpoint)
}
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let remote_endpoint: IpEndpoint = sockaddr.try_into()?;
let mut inner = self.inner.write();
if !self.is_bound() {
inner.bind_to_ephemeral_endpoint(&remote_endpoint)?
}
inner.set_remote_endpoint(remote_endpoint)?;
inner.update_socket_state();
Ok(())
}
fn addr(&self) -> Result<SocketAddr> {
if let Some(local_endpoint) = self.inner.read().local_endpoint() {
local_endpoint.try_into()
} else {
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint");
}
}
fn peer_addr(&self) -> Result<SocketAddr> {
if let Some(remote_endpoint) = self.inner.read().remote_endpoint() {
remote_endpoint.try_into()
} else {
return_errno_with_message!(Errno::EINVAL, "remote endpoint is not specified");
}
}
fn set_sock_option(&self, optname: SockOptionName, option_val: &[u8]) -> Result<()> {
// FIXME: deal with sock options here
Ok(())
}
// FIXME: respect RecvFromFlags
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
debug_assert!(flags.is_all_supported());
if !self.is_bound() {
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint");
}
let poller = Poller::new();
let bound_socket = self.inner.read().bound_socket().unwrap();
loop {
if let Ok((recv_len, remote_endpoint)) = self.try_recvfrom(buf, &flags) {
let remote_addr = remote_endpoint.try_into()?;
return Ok((recv_len, remote_addr));
}
let events = self.inner.read().poll(IoEvents::IN, Some(&poller));
if !events.contains(IoEvents::IN) {
poller.wait();
}
}
}
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
let remote_endpoint: IpEndpoint = if let Some(remote_addr) = remote {
remote_addr.try_into()?
} else {
self.remote_endpoint()?
};
if !self.is_bound() {
self.inner
.write()
.bind_to_ephemeral_endpoint(&remote_endpoint)?;
}
let bound_socket = self.inner.read().bound_socket().unwrap();
let send_slice = |socket: &mut RawUdpSocket| match socket.send_slice(buf, remote_endpoint) {
Err(_) => return_errno_with_message!(Errno::ENOBUFS, "send udp packet fails"),
Ok(()) => Ok(buf.len()),
};
let len = bound_socket.raw_with(send_slice)?;
poll_ifaces();
Ok(len)
}
}

View File

@ -0,0 +1,7 @@
mod always_some;
mod common;
mod datagram;
mod stream;
pub use datagram::DatagramSocket;
pub use stream::StreamSocket;

View File

@ -0,0 +1,95 @@
use crate::net::iface::IpEndpoint;
use crate::{
fs::utils::{IoEvents, Poller},
net::{
iface::{AnyBoundSocket, RawTcpSocket},
poll_ifaces,
socket::util::{send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd},
},
prelude::*,
};
pub struct ConnectedStream {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
}
impl ConnectedStream {
pub fn new(bound_socket: Arc<AnyBoundSocket>, remote_endpoint: IpEndpoint) -> Self {
Self {
bound_socket,
remote_endpoint,
}
}
pub fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
// TODO: deal with cmd
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket.close();
});
poll_ifaces();
Ok(())
}
pub fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, IpEndpoint)> {
debug_assert!(flags.is_all_supported());
let poller = Poller::new();
loop {
let recv_len = self.try_recvfrom(buf, flags)?;
if recv_len > 0 {
let remote_endpoint = self.remote_endpoint()?;
return Ok((recv_len, remote_endpoint));
}
let events = self.bound_socket.poll(IoEvents::IN, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOTCONN, "recv packet fails");
}
if !events.contains(IoEvents::IN) {
poller.wait();
}
}
}
fn try_recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<usize> {
poll_ifaces();
let res = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
socket
.recv_slice(buf)
.map_err(|_| Error::with_message(Errno::ENOTCONN, "fail to recv packet"))
});
self.bound_socket.update_socket_state();
res
}
pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
debug_assert!(flags.is_all_supported());
let mut sent_len = 0;
let buf_len = buf.len();
loop {
let len = self
.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(&buf[sent_len..]))
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
poll_ifaces();
sent_len += len;
if sent_len == buf_len {
return Ok(sent_len);
}
}
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
Ok(self.remote_endpoint)
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.bound_socket.poll(mask, poller)
}
}

View File

@ -0,0 +1,175 @@
use crate::fs::utils::{IoEvents, Poller};
use crate::net::iface::Iface;
use crate::net::iface::IpEndpoint;
use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket};
use crate::net::poll_ifaces;
use crate::net::socket::ip::always_some::AlwaysSome;
use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint};
use crate::prelude::*;
pub struct InitStream {
inner: RwLock<Inner>,
}
enum Inner {
Unbound(AlwaysSome<AnyUnboundSocket>),
Bound(AlwaysSome<Arc<AnyBoundSocket>>),
Connecting {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
},
}
impl Inner {
fn is_bound(&self) -> bool {
match self {
Self::Unbound(_) => false,
Self::Bound(..) | Self::Connecting { .. } => true,
}
}
fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> {
let unbound_socket = if let Inner::Unbound(unbound_socket) = self {
unbound_socket
} else {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
};
let bound_socket =
unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?;
bound_socket.update_socket_state();
*self = Inner::Bound(AlwaysSome::new(bound_socket));
Ok(())
}
fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(endpoint)
}
fn do_connect(&mut self, new_remote_endpoint: IpEndpoint) -> Result<()> {
match self {
Inner::Unbound(_) => return_errno_with_message!(Errno::EINVAL, "the socket is invalid"),
Inner::Connecting {
bound_socket,
remote_endpoint,
} => {
*remote_endpoint = new_remote_endpoint;
bound_socket.do_connect(new_remote_endpoint)?;
}
Inner::Bound(bound_socket) => {
bound_socket.do_connect(new_remote_endpoint)?;
*self = Inner::Connecting {
bound_socket: bound_socket.take(),
remote_endpoint: new_remote_endpoint,
};
}
}
Ok(())
}
fn bound_socket(&self) -> Option<&Arc<AnyBoundSocket>> {
match self {
Inner::Bound(bound_socket) => Some(&bound_socket),
Inner::Connecting { bound_socket, .. } => Some(bound_socket),
_ => None,
}
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
match self {
Inner::Bound(bound_socket) => bound_socket.poll(mask, poller),
Inner::Connecting { bound_socket, .. } => bound_socket.poll(mask, poller),
Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller),
}
}
fn iface(&self) -> Option<Arc<dyn Iface>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
Inner::Connecting { bound_socket, .. } => Some(bound_socket.iface().clone()),
_ => None,
}
}
fn local_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket()
.map(|socket| socket.local_endpoint())
.flatten()
}
fn remote_endpoint(&self) -> Option<IpEndpoint> {
if let Inner::Connecting {
remote_endpoint, ..
} = self
{
Some(*remote_endpoint)
} else {
None
}
}
}
impl InitStream {
pub fn new() -> Self {
let socket = AnyUnboundSocket::new_tcp();
let inner = Inner::Unbound(AlwaysSome::new(socket));
Self {
inner: RwLock::new(inner),
}
}
pub fn is_bound(&self) -> bool {
self.inner.read().is_bound()
}
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> {
self.inner.write().bind(endpoint)
}
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
if !self.is_bound() {
self.inner
.write()
.bind_to_ephemeral_endpoint(remote_endpoint)?
}
self.inner.write().do_connect(*remote_endpoint)?;
// Wait until building connection
let poller = Poller::new();
loop {
poll_ifaces();
let events = self
.inner
.read()
.poll(IoEvents::OUT | IoEvents::IN, Some(&poller));
if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) {
return Ok(());
} else if !events.is_empty() {
return_errno_with_message!(Errno::ECONNREFUSED, "connect refused")
} else {
poller.wait();
}
}
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.inner
.read()
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint"))
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
self.inner
.read()
.remote_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
}
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.inner.read().poll(mask, poller)
}
pub fn bound_socket(&self) -> Option<Arc<AnyBoundSocket>> {
self.inner.read().bound_socket().map(Clone::clone)
}
}

View File

@ -0,0 +1,145 @@
use crate::net::iface::{AnyUnboundSocket, BindPortConfig, IpEndpoint};
use crate::fs::utils::{IoEvents, Poller};
use crate::net::iface::{AnyBoundSocket, RawTcpSocket};
use crate::{net::poll_ifaces, prelude::*};
use super::connected::ConnectedStream;
pub struct ListenStream {
backlog: usize,
/// Sockets also listening at LocalEndPoint when called `listen`
backlog_sockets: RwLock<Vec<BacklogSocket>>,
}
impl ListenStream {
pub fn new(bound_socket: Arc<AnyBoundSocket>, backlog: usize) -> Result<Self> {
debug_assert!(backlog >= 1);
let backlog_socket = BacklogSocket::new(&bound_socket)?;
let listen_stream = Self {
backlog,
backlog_sockets: RwLock::new(vec![backlog_socket]),
};
listen_stream.fill_backlog_sockets()?;
Ok(listen_stream)
}
pub fn accept(&self) -> Result<(ConnectedStream, IpEndpoint)> {
// wait to accept
let poller = Poller::new();
loop {
poll_ifaces();
let accepted_socket = if let Some(accepted_socket) = self.try_accept() {
accepted_socket
} else {
let events = self.poll(IoEvents::IN | IoEvents::OUT, Some(&poller));
if !events.contains(IoEvents::IN) && !events.contains(IoEvents::OUT) {
poller.wait();
}
continue;
};
let remote_endpoint = accepted_socket.remote_endpoint().unwrap();
let connected_stream = {
let BacklogSocket {
bound_socket: backlog_socket,
} = accepted_socket;
ConnectedStream::new(backlog_socket, remote_endpoint)
};
return Ok((connected_stream, remote_endpoint));
}
}
/// Append sockets listening at LocalEndPoint to support backlog
fn fill_backlog_sockets(&self) -> Result<()> {
let backlog = self.backlog;
let mut backlog_sockets = self.backlog_sockets.write();
let current_backlog_len = backlog_sockets.len();
debug_assert!(backlog >= current_backlog_len);
if backlog == current_backlog_len {
return Ok(());
}
let bound_socket = backlog_sockets[0].bound_socket.clone();
for _ in current_backlog_len..backlog {
let backlog_socket = BacklogSocket::new(&bound_socket)?;
backlog_sockets.push(backlog_socket);
}
Ok(())
}
fn try_accept(&self) -> Option<BacklogSocket> {
let backlog_socket = {
let mut backlog_sockets = self.backlog_sockets.write();
let index = backlog_sockets
.iter()
.position(|backlog_socket| backlog_socket.is_active())?;
backlog_sockets.remove(index)
};
self.fill_backlog_sockets().unwrap();
Some(backlog_socket)
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket()
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let backlog_sockets = self.backlog_sockets.read();
for backlog_socket in backlog_sockets.iter() {
if backlog_socket.is_active() {
return IoEvents::IN;
} else {
// regiser poller to the backlog socket
backlog_socket.poll(mask, poller);
}
}
return IoEvents::empty();
}
fn bound_socket(&self) -> Arc<AnyBoundSocket> {
self.backlog_sockets.read()[0].bound_socket.clone()
}
}
struct BacklogSocket {
bound_socket: Arc<AnyBoundSocket>,
}
impl BacklogSocket {
fn new(bound_socket: &AnyBoundSocket) -> Result<Self> {
let local_endpoint = bound_socket.local_endpoint().ok_or(Error::with_message(
Errno::EINVAL,
"the socket is not bound",
))?;
let unbound_socket = AnyUnboundSocket::new_tcp();
let bound_socket = {
let iface = bound_socket.iface();
let bind_port_config = BindPortConfig::new(local_endpoint.port, true)?;
iface
.bind_socket(unbound_socket, bind_port_config)
.map_err(|(e, _)| e)?
};
bound_socket.raw_with(|raw_tcp_socket: &mut RawTcpSocket| {
raw_tcp_socket
.listen(local_endpoint)
.map_err(|_| Error::with_message(Errno::EINVAL, "fail to listen"))
})?;
bound_socket.update_socket_state();
Ok(Self { bound_socket })
}
fn is_active(&self) -> bool {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.is_active())
}
fn remote_endpoint(&self) -> Option<IpEndpoint> {
self.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.remote_endpoint())
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.bound_socket.poll(mask, poller)
}
}

View File

@ -0,0 +1,196 @@
use crate::fs::{
file_handle::FileLike,
utils::{IoEvents, Poller},
};
use crate::net::socket::{
util::{
send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd,
sock_options::SockOptionName, sockaddr::SocketAddr,
},
Socket,
};
use crate::prelude::*;
use self::{connected::ConnectedStream, init::InitStream, listen::ListenStream};
mod connected;
mod init;
mod listen;
pub struct StreamSocket {
state: RwLock<State>,
}
enum State {
// Start state
Init(Arc<InitStream>),
// Final State 1
Connected(Arc<ConnectedStream>),
// Final State 2
Listen(Arc<ListenStream>),
}
impl StreamSocket {
pub fn new() -> Self {
let state = State::Init(Arc::new(InitStream::new()));
Self {
state: RwLock::new(state),
}
}
}
impl FileLike for StreamSocket {
fn read(&self, buf: &mut [u8]) -> Result<usize> {
// FIXME: set correct flags
let flags = SendRecvFlags::empty();
let (recv_len, _) = self.recvfrom(buf, flags)?;
Ok(recv_len)
}
fn write(&self, buf: &[u8]) -> Result<usize> {
// FIXME: set correct flags
let flags = SendRecvFlags::empty();
self.sendto(buf, None, flags)
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
let state = self.state.read();
match &*state {
State::Init(init) => init.poll(mask, poller),
State::Connected(connected) => connected.poll(mask, poller),
State::Listen(listen) => listen.poll(mask, poller),
}
}
fn as_socket(&self) -> Option<&dyn Socket> {
Some(self)
}
}
impl Socket for StreamSocket {
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
let endpoint = sockaddr.try_into()?;
let state = self.state.read();
match &*state {
State::Init(init_stream) => init_stream.bind(endpoint),
_ => return_errno_with_message!(Errno::EINVAL, "cannot bind"),
}
}
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let remote_endpoint = sockaddr.try_into()?;
let mut state = self.state.write();
// FIXME: The rwlock is held when trying to connect, which may cause dead lock.
match &*state {
State::Init(init_stream) => {
init_stream.connect(&remote_endpoint)?;
let bound_socket = init_stream.bound_socket().unwrap();
let connected_stream =
Arc::new(ConnectedStream::new(bound_socket, remote_endpoint));
*state = State::Connected(connected_stream);
Ok(())
}
_ => return_errno_with_message!(Errno::EINVAL, "cannot connect"),
}
}
fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write();
match &*state {
State::Init(init_stream) => {
if !init_stream.is_bound() {
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound");
}
let bound_socket = init_stream.bound_socket().unwrap();
let listener = Arc::new(ListenStream::new(bound_socket, backlog)?);
*state = State::Listen(listener);
Ok(())
}
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream")
}
_ => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
}
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let state = self.state.read();
match &*state {
State::Listen(listen_stream) => {
let (connected_stream, remote_endpoint) = listen_stream.accept()?;
let state = RwLock::new(State::Connected(Arc::new(connected_stream)));
let accepted_socket = Arc::new(StreamSocket { state });
let socket_addr = remote_endpoint.try_into()?;
Ok((accepted_socket, socket_addr))
}
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"),
}
}
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
let state = self.state.read();
match &*state {
State::Connected(connected_stream) => connected_stream.shutdown(cmd),
// TDOD: shutdown listening stream
_ => return_errno_with_message!(Errno::EINVAL, "cannot shutdown"),
}
}
fn addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let local_endpoint = match &*state {
State::Init(init_stream) => init_stream.local_endpoint(),
State::Listen(listen_stream) => listen_stream.local_endpoint(),
State::Connected(connected_stream) => connected_stream.local_endpoint(),
}?;
local_endpoint.try_into()
}
fn peer_addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let remote_endpoint = match &*state {
State::Init(init_stream) => init_stream.remote_endpoint(),
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
}
State::Connected(connected_stream) => connected_stream.remote_endpoint(),
}?;
remote_endpoint.try_into()
}
fn sock_option(&self, optname: &SockOptionName) -> Result<&[u8]> {
return_errno_with_message!(Errno::EINVAL, "getsockopt not implemented");
}
fn set_sock_option(&self, optname: SockOptionName, option_val: &[u8]) -> Result<()> {
// TODO: implement setsockopt
Ok(())
}
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
let state = self.state.read();
let (recv_size, remote_endpoint) = match &*state {
State::Connected(connected_stream) => connected_stream.recvfrom(buf, flags),
_ => return_errno_with_message!(Errno::EINVAL, "cannot recv"),
}?;
let socket_addr = remote_endpoint.try_into()?;
Ok((recv_size, socket_addr))
}
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
debug_assert!(remote.is_none());
if remote.is_some() {
return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr");
}
let state = self.state.read();
match &*state {
State::Connected(connected_stream) => connected_stream.sendto(buf, flags),
_ => return_errno_with_message!(Errno::EINVAL, "cannot send"),
}
}
}

View File

@ -0,0 +1,72 @@
use crate::{fs::file_handle::FileLike, prelude::*};
pub use self::util::send_recv_flags::SendRecvFlags;
pub use self::util::shutdown_cmd::SockShutdownCmd;
pub use self::util::sock_options::SockOptionName;
pub use self::util::sockaddr::SocketAddr;
pub mod ip;
mod util;
/// Operations defined on a socket.
pub trait Socket: FileLike + Send + Sync {
/// Assign the address specified by sockaddr to the socket
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "bind not implemented");
}
/// Build connection for a given address
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "connect not implemented");
}
/// Listen for connections on a socket
fn listen(&self, backlog: usize) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "connect not implemented");
}
/// Accept a connection on a socket
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
return_errno_with_message!(Errno::EINVAL, "accept not implemented");
}
/// Shut down part of a full-duplex connection
fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "shutdown not implemented");
}
/// Get address of this socket.
fn addr(&self) -> Result<SocketAddr> {
return_errno_with_message!(Errno::EINVAL, "getsockname not implemented");
}
/// Get address of peer socket
fn peer_addr(&self) -> Result<SocketAddr> {
return_errno_with_message!(Errno::EINVAL, "getpeername not implemented");
}
/// Get options on the socket
fn sock_option(&self, optname: &SockOptionName) -> Result<&[u8]> {
return_errno_with_message!(Errno::EINVAL, "getsockopt not implemented");
}
/// Set options on the socket
fn set_sock_option(&self, optname: SockOptionName, option_val: &[u8]) -> Result<()> {
return_errno_with_message!(Errno::EINVAL, "setsockopt not implemented");
}
/// Receive a message from a socket
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
return_errno_with_message!(Errno::EINVAL, "recvfrom not implemented");
}
/// Send a message on a socket
fn sendto(
&self,
buf: &[u8],
remote: Option<SocketAddr>,
flags: SendRecvFlags,
) -> Result<usize> {
return_errno_with_message!(Errno::EINVAL, "recvfrom not implemented");
}
}

View File

@ -0,0 +1,4 @@
pub mod send_recv_flags;
pub mod shutdown_cmd;
pub mod sock_options;
pub mod sockaddr;

View File

@ -0,0 +1,45 @@
use crate::prelude::*;
bitflags! {
#[repr(C)]
#[derive(Pod)]
/// Flags used for send/recv.
/// The definiton is from https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/socket.h
pub struct SendRecvFlags: i32 {
const MSG_OOB = 1;
const MSG_PEEK = 2;
const MSG_DONTROUTE = 4;
// const MSG_TRYHARD = 4; /* Synonym for MSG_DONTROUTE for DECnet */
const MSG_CTRUNC = 8;
const MSG_PROBE = 0x10; /* Do not send. Only probe path f.e. for MTU */
const MSG_TRUNC = 0x20;
const MSG_DONTWAIT = 0x40; /* Nonblocking io */
const MSG_EOR = 0x80; /* End of record */
const MSG_WAITALL = 0x100; /* Wait for a full request */
const MSG_FIN = 0x200;
const MSG_SYN = 0x400;
const MSG_CONFIRM = 0x800; /* Confirm path validity */
const MSG_RST = 0x1000;
const MSG_ERRQUEUE = 0x2000; /* Fetch message from error queue */
const MSG_NOSIGNAL = 0x4000; /* Do not generate SIGPIPE */
const MSG_MORE = 0x8000; /* Sender will send more */
const MSG_WAITFORONE = 0x10000; /* recvmmsg(): block until 1+ packets avail */
const MSG_SENDPAGE_NOPOLICY = 0x10000; /* sendpage() internal : do no apply policy */
const MSG_SENDPAGE_NOTLAST = 0x20000; /* sendpage() internal : not the last page */
const MSG_BATCH = 0x40000; /* sendmmsg(): more messages coming */
// const MSG_EOF MSG_FIN
const MSG_NO_SHARED_FRAGS = 0x80000; /* sendpage() internal : page frags are not shared */
const MSG_SENDPAGE_DECRYPTED = 0x100000; /* sendpage() internal : page may carry plain text and require encryption */
}
}
impl SendRecvFlags {
fn supported_flags() -> Self {
SendRecvFlags::empty()
}
pub fn is_all_supported(&self) -> bool {
let supported_flags = Self::supported_flags();
supported_flags.contains(*self)
}
}

View File

@ -0,0 +1,15 @@
use crate::prelude::*;
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromInt)]
#[allow(non_camel_case_types)]
/// Shutdown types
/// From https://elixir.bootlin.com/linux/v6.0.9/source/include/linux/net.h
pub enum SockShutdownCmd {
/// Shutdown receptions
SHUT_RD = 0,
/// Shutdown transmissions
SHUT_WR = 1,
/// Shutdown receptions and transmissions
SHUT_RDWR = 2,
}

View File

@ -0,0 +1,26 @@
use crate::prelude::*;
#[repr(i32)]
#[derive(Debug, Clone, Copy, TryFromInt, PartialEq, Eq, PartialOrd, Ord)]
#[allow(non_camel_case_types)]
/// The definition is from https://elixir.bootlin.com/linux/v6.0.9/source/include/uapi/asm-generic/socket.h.
/// We do not include all options here
pub enum SockOptionName {
SO_DEBUG = 1,
SO_REUSEADDR = 2,
SO_TYPE = 3,
SO_ERROR = 4,
SO_DONTROUTE = 5,
SO_BROADCAST = 6,
SO_SNDBUF = 7,
SO_RCVBUF = 8,
SO_SNDBUFFORCE = 32,
SO_RCVBUFFORCE = 33,
SO_KEEPALIVE = 9,
SO_OOBINLINE = 10,
SO_NO_CHECK = 11,
SO_PRIORITY = 12,
SO_LINGER = 13,
SO_BSDCOMPAT = 14,
SO_REUSEPORT = 15,
}

View File

@ -0,0 +1,51 @@
use crate::net::iface::{IpAddress, Ipv4Address};
use crate::net::iface::{IpEndpoint, IpListenEndpoint};
use crate::prelude::*;
type PortNum = u16;
#[derive(Debug)]
pub enum SocketAddr {
Unix,
IPv4(Ipv4Address, PortNum),
IPv6,
}
impl TryFrom<SocketAddr> for IpEndpoint {
type Error = Error;
fn try_from(value: SocketAddr) -> Result<Self> {
match value {
SocketAddr::IPv4(addr, port) => Ok(IpEndpoint::new(addr.into_address(), port)),
_ => return_errno_with_message!(
Errno::EINVAL,
"sock addr cannot be converted as IpEndpoint"
),
}
}
}
impl TryFrom<IpEndpoint> for SocketAddr {
type Error = Error;
fn try_from(endpoint: IpEndpoint) -> Result<Self> {
let port = endpoint.port;
let socket_addr = match endpoint.addr {
IpAddress::Ipv4(addr) => SocketAddr::IPv4(addr, port), // TODO: support IPv6
};
Ok(socket_addr)
}
}
impl TryFrom<IpListenEndpoint> for SocketAddr {
type Error = Error;
fn try_from(value: IpListenEndpoint) -> Result<Self> {
let port = value.port;
let socket_addr = match value.addr {
None => return_errno_with_message!(Errno::EINVAL, "address is unspecified"),
Some(IpAddress::Ipv4(address)) => SocketAddr::IPv4(address, port),
};
Ok(socket_addr)
}
}