mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-26 10:53:25 +00:00
Seperate ConnectingStream
from InitStream
For TCP streams we used to have three states, e.g. `InitStream`, `ListenStream`, `ConnectedStream`. If the socket is not bound, it is in the `InitStream` state. If the socket is bound, it is still in that state. Most seriously, if the socket is connecting to the remote peer, but the connection has not been established, it is also in the `InitStream` state. While the socket is trying to connect to its peer, it needs to handle iface events to update its internal state. So it is expected to implement the trait that observes such events for this later. However, the reality that sockets connecting to their peers are mixed in with other unbound and bound sockets in the `InitStream` will complicate things. In fact, the connecting socket should belong to an independent state. It does not share too much logic with unbound and bound sockets in the `InitStream`. So in this commit we will decouple that and create a new `ConnectingStream` state.
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
58948d498c
commit
6b903d0c10
@ -0,0 +1,83 @@
|
||||
use core::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use alloc::sync::Arc;
|
||||
|
||||
use crate::events::IoEvents;
|
||||
use crate::net::poll_ifaces;
|
||||
use crate::prelude::*;
|
||||
|
||||
use crate::net::iface::{AnyBoundSocket, IpEndpoint};
|
||||
use crate::process::signal::Poller;
|
||||
|
||||
use super::connected::ConnectedStream;
|
||||
use super::init::InitStream;
|
||||
|
||||
pub struct ConnectingStream {
|
||||
nonblocking: AtomicBool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
}
|
||||
|
||||
impl ConnectingStream {
|
||||
pub fn new(
|
||||
nonblocking: bool,
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
) -> Result<Self> {
|
||||
bound_socket.do_connect(remote_endpoint)?;
|
||||
|
||||
Ok(Self {
|
||||
nonblocking: AtomicBool::new(nonblocking),
|
||||
bound_socket,
|
||||
remote_endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn wait_conn(&self) -> core::result::Result<ConnectedStream, (Error, InitStream)> {
|
||||
debug_assert!(!self.is_nonblocking());
|
||||
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
poll_ifaces();
|
||||
|
||||
let events = self.poll(IoEvents::OUT | IoEvents::IN, Some(&poller));
|
||||
if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) {
|
||||
return Ok(ConnectedStream::new(
|
||||
self.is_nonblocking(),
|
||||
self.bound_socket.clone(),
|
||||
self.remote_endpoint,
|
||||
));
|
||||
} else if !events.is_empty() {
|
||||
return Err((
|
||||
Error::with_message(Errno::ECONNREFUSED, "connection refused"),
|
||||
InitStream::new_bound(self.is_nonblocking(), self.bound_socket.clone()),
|
||||
));
|
||||
} else {
|
||||
// FIXME: deal with nonblocking mode & connecting timeout
|
||||
poller.wait().expect("async connect() not implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.bound_socket
|
||||
.local_endpoint()
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "no local endpoint"))
|
||||
}
|
||||
|
||||
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
|
||||
Ok(self.remote_endpoint)
|
||||
}
|
||||
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.bound_socket.poll(mask, poller)
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
self.nonblocking.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn set_nonblocking(&self, nonblocking: bool) {
|
||||
self.nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
}
|
@ -4,12 +4,14 @@ use crate::events::IoEvents;
|
||||
use crate::net::iface::Iface;
|
||||
use crate::net::iface::IpEndpoint;
|
||||
use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket};
|
||||
use crate::net::poll_ifaces;
|
||||
use crate::net::socket::ip::always_some::AlwaysSome;
|
||||
use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint};
|
||||
use crate::prelude::*;
|
||||
use crate::process::signal::Poller;
|
||||
|
||||
use super::connecting::ConnectingStream;
|
||||
use super::listen::ListenStream;
|
||||
|
||||
pub struct InitStream {
|
||||
inner: RwLock<Inner>,
|
||||
is_nonblocking: AtomicBool,
|
||||
@ -18,17 +20,13 @@ pub struct InitStream {
|
||||
enum Inner {
|
||||
Unbound(AlwaysSome<Box<AnyUnboundSocket>>),
|
||||
Bound(AlwaysSome<Arc<AnyBoundSocket>>),
|
||||
Connecting {
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
},
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
fn is_bound(&self) -> bool {
|
||||
match self {
|
||||
Self::Unbound(_) => false,
|
||||
Self::Bound(..) | Self::Connecting { .. } => true,
|
||||
Self::Bound(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,39 +48,16 @@ impl Inner {
|
||||
self.bind(endpoint)
|
||||
}
|
||||
|
||||
fn do_connect(&mut self, new_remote_endpoint: IpEndpoint) -> Result<()> {
|
||||
match self {
|
||||
Inner::Unbound(_) => return_errno_with_message!(Errno::EINVAL, "the socket is invalid"),
|
||||
Inner::Connecting {
|
||||
bound_socket,
|
||||
remote_endpoint,
|
||||
} => {
|
||||
*remote_endpoint = new_remote_endpoint;
|
||||
bound_socket.do_connect(new_remote_endpoint)?;
|
||||
}
|
||||
Inner::Bound(bound_socket) => {
|
||||
bound_socket.do_connect(new_remote_endpoint)?;
|
||||
*self = Inner::Connecting {
|
||||
bound_socket: bound_socket.take(),
|
||||
remote_endpoint: new_remote_endpoint,
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn bound_socket(&self) -> Option<&Arc<AnyBoundSocket>> {
|
||||
match self {
|
||||
Inner::Bound(bound_socket) => Some(bound_socket),
|
||||
Inner::Connecting { bound_socket, .. } => Some(bound_socket),
|
||||
_ => None,
|
||||
Inner::Unbound(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
match self {
|
||||
Inner::Bound(bound_socket) => bound_socket.poll(mask, poller),
|
||||
Inner::Connecting { bound_socket, .. } => bound_socket.poll(mask, poller),
|
||||
Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller),
|
||||
}
|
||||
}
|
||||
@ -90,8 +65,7 @@ impl Inner {
|
||||
fn iface(&self) -> Option<Arc<dyn Iface>> {
|
||||
match self {
|
||||
Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()),
|
||||
Inner::Connecting { bound_socket, .. } => Some(bound_socket.iface().clone()),
|
||||
_ => None,
|
||||
Inner::Unbound(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,17 +73,6 @@ impl Inner {
|
||||
self.bound_socket()
|
||||
.and_then(|socket| socket.local_endpoint())
|
||||
}
|
||||
|
||||
fn remote_endpoint(&self) -> Option<IpEndpoint> {
|
||||
if let Inner::Connecting {
|
||||
remote_endpoint, ..
|
||||
} = self
|
||||
{
|
||||
Some(*remote_endpoint)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InitStream {
|
||||
@ -122,40 +85,38 @@ impl InitStream {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_bound(&self) -> bool {
|
||||
self.inner.read().is_bound()
|
||||
pub fn new_bound(nonblocking: bool, bound_socket: Arc<AnyBoundSocket>) -> Self {
|
||||
let inner = Inner::Bound(AlwaysSome::new(bound_socket));
|
||||
Self {
|
||||
is_nonblocking: AtomicBool::new(nonblocking),
|
||||
inner: RwLock::new(inner),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> {
|
||||
self.inner.write().bind(endpoint)
|
||||
}
|
||||
|
||||
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
if !self.is_bound() {
|
||||
pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result<ConnectingStream> {
|
||||
if !self.inner.read().is_bound() {
|
||||
self.inner
|
||||
.write()
|
||||
.bind_to_ephemeral_endpoint(remote_endpoint)?
|
||||
}
|
||||
self.inner.write().do_connect(*remote_endpoint)?;
|
||||
// Wait until building connection
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
poll_ifaces();
|
||||
let events = self
|
||||
.inner
|
||||
.read()
|
||||
.poll(IoEvents::OUT | IoEvents::IN, Some(&poller));
|
||||
if events.contains(IoEvents::IN) || events.contains(IoEvents::OUT) {
|
||||
return Ok(());
|
||||
} else if !events.is_empty() {
|
||||
return_errno_with_message!(Errno::ECONNREFUSED, "connect refused");
|
||||
} else if self.is_nonblocking() {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try connect again");
|
||||
ConnectingStream::new(
|
||||
self.is_nonblocking(),
|
||||
self.inner.read().bound_socket().unwrap().clone(),
|
||||
*remote_endpoint,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn listen(&self, backlog: usize) -> Result<ListenStream> {
|
||||
let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() {
|
||||
bound_socket.clone()
|
||||
} else {
|
||||
// FIXME: deal with connecting timeout
|
||||
poller.wait()?;
|
||||
}
|
||||
}
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound")
|
||||
};
|
||||
ListenStream::new(self.is_nonblocking(), bound_socket, backlog)
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
@ -165,21 +126,10 @@ impl InitStream {
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint"))
|
||||
}
|
||||
|
||||
pub fn remote_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.inner
|
||||
.read()
|
||||
.remote_endpoint()
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has remote endpoint"))
|
||||
}
|
||||
|
||||
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.inner.read().poll(mask, poller)
|
||||
}
|
||||
|
||||
pub fn bound_socket(&self) -> Option<Arc<AnyBoundSocket>> {
|
||||
self.inner.read().bound_socket().map(Clone::clone)
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
self.is_nonblocking.load(Ordering::Relaxed)
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::events::IoEvents;
|
||||
use crate::fs::{file_handle::FileLike, utils::StatusFlags};
|
||||
use crate::net::iface::IpEndpoint;
|
||||
use crate::net::socket::{
|
||||
util::{
|
||||
send_recv_flags::SendRecvFlags, shutdown_cmd::SockShutdownCmd,
|
||||
@ -10,9 +11,13 @@ use crate::net::socket::{
|
||||
use crate::prelude::*;
|
||||
use crate::process::signal::Poller;
|
||||
|
||||
use self::{connected::ConnectedStream, init::InitStream, listen::ListenStream};
|
||||
use self::{
|
||||
connected::ConnectedStream, connecting::ConnectingStream, init::InitStream,
|
||||
listen::ListenStream,
|
||||
};
|
||||
|
||||
mod connected;
|
||||
mod connecting;
|
||||
mod init;
|
||||
mod listen;
|
||||
|
||||
@ -23,6 +28,8 @@ pub struct StreamSocket {
|
||||
enum State {
|
||||
// Start state
|
||||
Init(Arc<InitStream>),
|
||||
// Intermediate state
|
||||
Connecting(Arc<ConnectingStream>),
|
||||
// Final State 1
|
||||
Connected(Arc<ConnectedStream>),
|
||||
// Final State 2
|
||||
@ -40,6 +47,7 @@ impl StreamSocket {
|
||||
fn is_nonblocking(&self) -> bool {
|
||||
match &*self.state.read() {
|
||||
State::Init(init) => init.is_nonblocking(),
|
||||
State::Connecting(connecting) => connecting.is_nonblocking(),
|
||||
State::Connected(connected) => connected.is_nonblocking(),
|
||||
State::Listen(listen) => listen.is_nonblocking(),
|
||||
}
|
||||
@ -48,10 +56,25 @@ impl StreamSocket {
|
||||
fn set_nonblocking(&self, nonblocking: bool) {
|
||||
match &*self.state.read() {
|
||||
State::Init(init) => init.set_nonblocking(nonblocking),
|
||||
State::Connecting(connecting) => connecting.set_nonblocking(nonblocking),
|
||||
State::Connected(connected) => connected.set_nonblocking(nonblocking),
|
||||
State::Listen(listen) => listen.set_nonblocking(nonblocking),
|
||||
}
|
||||
}
|
||||
|
||||
fn do_connect(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<ConnectingStream>> {
|
||||
let mut state = self.state.write();
|
||||
let init_stream = match &*state {
|
||||
State::Init(init_stream) => init_stream,
|
||||
State::Listen(_) | State::Connecting(_) | State::Connected(_) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot connect")
|
||||
}
|
||||
};
|
||||
|
||||
let connecting = Arc::new(init_stream.connect(remote_endpoint)?);
|
||||
*state = State::Connecting(connecting.clone());
|
||||
Ok(connecting)
|
||||
}
|
||||
}
|
||||
|
||||
impl FileLike for StreamSocket {
|
||||
@ -72,6 +95,7 @@ impl FileLike for StreamSocket {
|
||||
let state = self.state.read();
|
||||
match &*state {
|
||||
State::Init(init) => init.poll(mask, poller),
|
||||
State::Connecting(connecting) => connecting.poll(mask, poller),
|
||||
State::Connected(connected) => connected.poll(mask, poller),
|
||||
State::Listen(listen) => listen.poll(mask, poller),
|
||||
}
|
||||
@ -112,44 +136,37 @@ impl Socket for StreamSocket {
|
||||
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||
let remote_endpoint = sockaddr.try_into()?;
|
||||
|
||||
let init_stream = match &*self.state.read() {
|
||||
State::Init(init_stream) => init_stream.clone(),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "cannot connect"),
|
||||
};
|
||||
|
||||
init_stream.connect(&remote_endpoint)?;
|
||||
|
||||
let connected_stream = {
|
||||
let nonblocking = init_stream.is_nonblocking();
|
||||
let bound_socket = init_stream.bound_socket().unwrap();
|
||||
Arc::new(ConnectedStream::new(
|
||||
nonblocking,
|
||||
bound_socket,
|
||||
remote_endpoint,
|
||||
))
|
||||
};
|
||||
let connecting_stream = self.do_connect(&remote_endpoint)?;
|
||||
match connecting_stream.wait_conn() {
|
||||
Ok(connected_stream) => {
|
||||
let connected_stream = Arc::new(connected_stream);
|
||||
*self.state.write() = State::Connected(connected_stream);
|
||||
Ok(())
|
||||
}
|
||||
Err((err, init_stream)) => {
|
||||
let init_stream = Arc::new(init_stream);
|
||||
*self.state.write() = State::Init(init_stream);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
match &*state {
|
||||
State::Init(init_stream) => {
|
||||
if !init_stream.is_bound() {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen without bound");
|
||||
}
|
||||
let nonblocking = init_stream.is_nonblocking();
|
||||
let bound_socket = init_stream.bound_socket().unwrap();
|
||||
let listener = Arc::new(ListenStream::new(nonblocking, bound_socket, backlog)?);
|
||||
*state = State::Listen(listener);
|
||||
Ok(())
|
||||
let init_stream = match &*state {
|
||||
State::Init(init_stream) => init_stream,
|
||||
State::Connecting(connecting_stream) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen for a connecting stream")
|
||||
}
|
||||
State::Listen(listen_stream) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "cannot listen for a listening stream")
|
||||
}
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
|
||||
}
|
||||
State::Connected(_) => return_errno_with_message!(Errno::EINVAL, "cannot listen"),
|
||||
};
|
||||
|
||||
let listener = Arc::new(init_stream.listen(backlog)?);
|
||||
*state = State::Listen(listener);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
@ -185,6 +202,7 @@ impl Socket for StreamSocket {
|
||||
let state = self.state.read();
|
||||
let local_endpoint = match &*state {
|
||||
State::Init(init_stream) => init_stream.local_endpoint(),
|
||||
State::Connecting(connecting_stream) => connecting_stream.local_endpoint(),
|
||||
State::Listen(listen_stream) => listen_stream.local_endpoint(),
|
||||
State::Connected(connected_stream) => connected_stream.local_endpoint(),
|
||||
}?;
|
||||
@ -194,7 +212,10 @@ impl Socket for StreamSocket {
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
let state = self.state.read();
|
||||
let remote_endpoint = match &*state {
|
||||
State::Init(init_stream) => init_stream.remote_endpoint(),
|
||||
State::Init(init_stream) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "init socket does not have peer")
|
||||
}
|
||||
State::Connecting(connecting_stream) => connecting_stream.remote_endpoint(),
|
||||
State::Listen(listen_stream) => {
|
||||
return_errno_with_message!(Errno::EINVAL, "listening socket does not have peer")
|
||||
}
|
||||
|
Reference in New Issue
Block a user