Seperate ConnectingStream from InitStream

For TCP streams we used to have three states, e.g. `InitStream`,
`ListenStream`, `ConnectedStream`. If the socket is not bound, it is in
the `InitStream` state. If the socket is bound, it is still in that
state. Most seriously, if the socket is connecting to the remote peer,
but the connection has not been established, it is also in the
`InitStream` state.

While the socket is trying to connect to its peer, it needs to handle
iface events to update its internal state. So it is expected to
implement the trait that observes such events for this later. However,
the reality that sockets connecting to their peers are mixed in with
other unbound and bound sockets in the `InitStream` will complicate
things.

In fact, the connecting socket should belong to an independent state. It
does not share too much logic with unbound and bound sockets in the
`InitStream`. So in this commit we will decouple that and create a new
`ConnectingStream` state.
This commit is contained in:
Ruihan Li
2023-11-30 01:39:04 +08:00
committed by Tate, Hongliang Tian
parent 58948d498c
commit 6b903d0c10
3 changed files with 165 additions and 111 deletions

View File

@ -0,0 +1,83 @@
use core::sync::atomic::{AtomicBool, Ordering};
use alloc::sync::Arc;
use crate::events::IoEvents;
use crate::net::poll_ifaces;
use crate::prelude::*;
use crate::net::iface::{AnyBoundSocket, IpEndpoint};
use crate::process::signal::Poller;
use super::connected::ConnectedStream;
use super::init::InitStream;
pub struct ConnectingStream {
nonblocking: AtomicBool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
}
impl ConnectingStream {
pub fn new(
nonblocking: bool,
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
) -> Result<Self> {
bound_socket.do_connect(remote_endpoint)?;
Ok(Self {
nonblocking: AtomicBool::new(nonblocking),
bound_socket,
remote_endpoint,
})
}
pub fn wait_conn(&self) -> core::result::Result<ConnectedStream, (Error, InitStream)> {
debug_assert!(!self.is_nonblocking());
let poller = Poller::new();
loop {
poll_ifaces();
let events = self.poll(IoEvents::OUT | IoEvents::IN, Some(&poller));
if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) {
return Ok(ConnectedStream::new(
self.is_nonblocking(),
self.bound_socket.clone(),
self.remote_endpoint,
));
} else if !events.is_empty() {
return Err((
Error::with_message(Errno::ECONNREFUSED, "connection refused"),
InitStream::new_bound(self.is_nonblocking(), self.bound_socket.clone()),
));
} else {
// FIXME: deal with nonblocking mode & connecting timeout
poller.wait().expect("async connect() not implemented");
}
}
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint"))
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
Ok(self.remote_endpoint)
}
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.bound_socket.poll(mask, poller)
}
pub fn is_nonblocking(&self) -> bool {
self.nonblocking.load(Ordering::Relaxed)
}
pub fn set_nonblocking(&self, nonblocking: bool) {
self.nonblocking.store(nonblocking, Ordering::Relaxed);
}
}

View File

@ -4,12 +4,14 @@ use crate::events::IoEvents;
use crate::net::iface::Iface;
use crate::net::iface::IpEndpoint;
use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket};
use crate::net::poll_ifaces;
use crate::net::socket::ip::always_some::AlwaysSome;
use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint};
use crate::prelude::*;
use crate::process::signal::Poller;
use super::connecting::ConnectingStream;
use super::listen::ListenStream;
pub struct InitStream {
inner: RwLock<Inner>,
is_nonblocking: AtomicBool,
@ -18,17 +20,13 @@ pub struct InitStream {
enum Inner {
Unbound(AlwaysSome<Box<AnyUnboundSocket>>),
Bound(AlwaysSome<Arc<AnyBoundSocket>>),
Connecting {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
},
}
impl Inner {
fn is_bound(&self) -> bool {
match self {
Self::Unbound(_) => false,
Self::Bound(..) | Self::Connecting { .. } => true,
Self::Bound(_) => true,
}
}
@ -50,39 +48,16 @@ impl Inner {
self.bind(endpoint)
}
fn do_connect(&mut self, new_remote_endpoint: IpEndpoint) -> Result<()> {
match self {
Inner::Unbound(_) => return_errno_with_message!(Errno::EINVAL, "the socket is invalid"),
Inner::Connecting {
bound_socket,
remote_endpoint,
} => {
*remote_endpoint = new_remote_endpoint;
bound_socket.do_connect(new_remote_endpoint)?;
}
Inner::Bound(bound_socket) => {
bound_socket.do_connect(new_remote_endpoint)?;
*self = Inner::Connecting {
bound_socket: bound_socket.take(),
remote_endpoint: new_remote_endpoint,
};
}
}
Ok(())
}
fn bound_socket(&self) -> Option<&Arc<AnyBoundSocket>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket),
Inner::Connecting { bound_socket, .. } => Some(bound_socket),
_ => None,
Inner::Unbound(_) => None,
}
}
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
match self {
Inner::Bound(bound_socket) => bound_socket.poll(mask, poller),
Inner::Connecting { bound_socket, .. } => bound_socket.poll(mask, poller),
Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller),
}
}
@ -90,8 +65,7 @@ impl Inner {
fn iface(&self) -> Option<Arc<dyn Iface>> {
match self {
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
Inner::Connecting { bound_socket, .. } => Some(bound_socket.iface().clone()),
_ => None,
Inner::Unbound(_) => None,
}
}
@ -99,17 +73,6 @@ impl Inner {
self.bound_socket()
.and_then(|socket| socket.local_endpoint())
}
fn remote_endpoint(&self) -> Option<IpEndpoint> {
if let Inner::Connecting {
remote_endpoint, ..
} = self
{
Some(*remote_endpoint)
} else {
None
}
}
}
impl InitStream {
@ -122,40 +85,38 @@ impl InitStream {
}
}
pub fn is_bound(&self) -> bool {
self.inner.read().is_bound()
pub fn new_bound(nonblocking: bool, bound_socket: Arc<AnyBoundSocket>) -> Self {
let inner = Inner::Bound(AlwaysSome::new(bound_socket));
Self {
is_nonblocking: AtomicBool::new(nonblocking),
inner: RwLock::new(inner),
}
}
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> {
self.inner.write().bind(endpoint)
}
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
if !self.is_bound() {
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<ConnectingStream> {
if !self.inner.read().is_bound() {
self.inner
.write()
.bind_to_ephemeral_endpoint(remote_endpoint)?
}
self.inner.write().do_connect(*remote_endpoint)?;
// Wait until building connection
let poller = Poller::new();
loop {
poll_ifaces();
let events = self
.inner
.read()
.poll(IoEvents::OUT | IoEvents::IN, Some(&poller));
if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) {
return Ok(());
} else if !events.is_empty() {
return_errno_with_message!(Errno::ECONNREFUSED, "connect refused");
} else if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try connect again");
ConnectingStream::new(
self.is_nonblocking(),
self.inner.read().bound_socket().unwrap().clone(),
*remote_endpoint,
)
}
pub fn listen(&self, backlog: usize) -> Result<ListenStream> {
let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() {
bound_socket.clone()
} else {
// FIXME: deal with connecting timeout
poller.wait()?;
}
}
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound")
};
ListenStream::new(self.is_nonblocking(), bound_socket, backlog)
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
@ -165,21 +126,10 @@ impl InitStream {
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint"))
}
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
self.inner
.read()
.remote_endpoint()
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
}
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
self.inner.read().poll(mask, poller)
}
pub fn bound_socket(&self) -> Option<Arc<AnyBoundSocket>> {
self.inner.read().bound_socket().map(Clone::clone)
}
pub fn is_nonblocking(&self) -> bool {
self.is_nonblocking.load(Ordering::Relaxed)
}

View File

@ -1,5 +1,6 @@
use crate::events::IoEvents;
use crate::fs::{file_handle::FileLike, utils::StatusFlags};
use crate::net::iface::IpEndpoint;
use crate::net::socket::{
util::{
send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd,
@ -10,9 +11,13 @@ use crate::net::socket::{
use crate::prelude::*;
use crate::process::signal::Poller;
use self::{connected::ConnectedStream, init::InitStream, listen::ListenStream};
use self::{
connected::ConnectedStream, connecting::ConnectingStream, init::InitStream,
listen::ListenStream,
};
mod connected;
mod connecting;
mod init;
mod listen;
@ -23,6 +28,8 @@ pub struct StreamSocket {
enum State {
// Start state
Init(Arc<InitStream>),
// Intermediate state
Connecting(Arc<ConnectingStream>),
// Final State 1
Connected(Arc<ConnectedStream>),
// Final State 2
@ -40,6 +47,7 @@ impl StreamSocket {
fn is_nonblocking(&self) -> bool {
match &*self.state.read() {
State::Init(init) => init.is_nonblocking(),
State::Connecting(connecting) => connecting.is_nonblocking(),
State::Connected(connected) => connected.is_nonblocking(),
State::Listen(listen) => listen.is_nonblocking(),
}
@ -48,10 +56,25 @@ impl StreamSocket {
fn set_nonblocking(&self, nonblocking: bool) {
match &*self.state.read() {
State::Init(init) => init.set_nonblocking(nonblocking),
State::Connecting(connecting) => connecting.set_nonblocking(nonblocking),
State::Connected(connected) => connected.set_nonblocking(nonblocking),
State::Listen(listen) => listen.set_nonblocking(nonblocking),
}
}
fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
let mut state = self.state.write();
let init_stream = match &*state {
State::Init(init_stream) => init_stream,
State::Listen(_) | State::Connecting(_) | State::Connected(_) => {
return_errno_with_message!(Errno::EINVAL, "cannot connect")
}
};
let connecting = Arc::new(init_stream.connect(remote_endpoint)?);
*state = State::Connecting(connecting.clone());
Ok(connecting)
}
}
impl FileLike for StreamSocket {
@ -72,6 +95,7 @@ impl FileLike for StreamSocket {
let state = self.state.read();
match &*state {
State::Init(init) => init.poll(mask, poller),
State::Connecting(connecting) => connecting.poll(mask, poller),
State::Connected(connected) => connected.poll(mask, poller),
State::Listen(listen) => listen.poll(mask, poller),
}
@ -112,44 +136,37 @@ impl Socket for StreamSocket {
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
let remote_endpoint = sockaddr.try_into()?;
let init_stream = match &*self.state.read() {
State::Init(init_stream) => init_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "cannot connect"),
};
init_stream.connect(&remote_endpoint)?;
let connected_stream = {
let nonblocking = init_stream.is_nonblocking();
let bound_socket = init_stream.bound_socket().unwrap();
Arc::new(ConnectedStream::new(
nonblocking,
bound_socket,
remote_endpoint,
))
};
let connecting_stream = self.do_connect(&remote_endpoint)?;
match connecting_stream.wait_conn() {
Ok(connected_stream) => {
let connected_stream = Arc::new(connected_stream);
*self.state.write() = State::Connected(connected_stream);
Ok(())
}
Err((err, init_stream)) => {
let init_stream = Arc::new(init_stream);
*self.state.write() = State::Init(init_stream);
Err(err)
}
}
}
fn listen(&self, backlog: usize) -> Result<()> {
let mut state = self.state.write();
match &*state {
State::Init(init_stream) => {
if !init_stream.is_bound() {
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound");
}
let nonblocking = init_stream.is_nonblocking();
let bound_socket = init_stream.bound_socket().unwrap();
let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?);
*state = State::Listen(listener);
Ok(())
let init_stream = match &*state {
State::Init(init_stream) => init_stream,
State::Connecting(connecting_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream")
}
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream")
}
_ => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
}
State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
};
let listener = Arc::new(init_stream.listen(backlog)?);
*state = State::Listen(listener);
Ok(())
}
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
@ -185,6 +202,7 @@ impl Socket for StreamSocket {
let state = self.state.read();
let local_endpoint = match &*state {
State::Init(init_stream) => init_stream.local_endpoint(),
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
State::Listen(listen_stream) => listen_stream.local_endpoint(),
State::Connected(connected_stream) => connected_stream.local_endpoint(),
}?;
@ -194,7 +212,10 @@ impl Socket for StreamSocket {
fn peer_addr(&self) -> Result<SocketAddr> {
let state = self.state.read();
let remote_endpoint = match &*state {
State::Init(init_stream) => init_stream.remote_endpoint(),
State::Init(init_stream) => {
return_errno_with_message!(Errno::EINVAL, "init socket does not have peer")
}
State::Connecting(connecting_stream) => connecting_stream.remote_endpoint(),
State::Listen(listen_stream) => {
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
}