Implement non-blocking connect

This commit is contained in:
Ruihan Li
2024-01-09 23:42:26 +08:00
committed by Tate, Hongliang Tian
parent 9211061181
commit 27c5c27fd0
6 changed files with 146 additions and 38 deletions

View File

@ -17,13 +17,29 @@ use crate::{
pub struct ConnectedStream {
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
/// Indicates whether this connection is "new" in a `connect()` system call.
///
/// If the connection is not new, `connect()` will fail with the error code `EISCONN`,
/// otherwise it will succeed. This means that `connect()` will succeed _exactly_ once,
/// regardless of whether the connection is established synchronously or asynchronously.
///
/// If the connection is established synchronously, the synchronous `connect()` will succeed
/// and any subsequent `connect()` will fail; otherwise, the first `connect()` after the
/// connection is established asynchronously will succeed and any subsequent `connect()` will
/// fail.
is_new_connection: bool,
}
impl ConnectedStream {
pub fn new(bound_socket: Arc<AnyBoundSocket>, remote_endpoint: IpEndpoint) -> Self {
pub fn new(
bound_socket: Arc<AnyBoundSocket>,
remote_endpoint: IpEndpoint,
is_new_connection: bool,
) -> Self {
Self {
bound_socket,
remote_endpoint,
is_new_connection,
}
}
@ -73,6 +89,15 @@ impl ConnectedStream {
self.remote_endpoint
}
pub fn check_new(&mut self) -> Result<()> {
if !self.is_new_connection {
return_errno_with_message!(Errno::EISCONN, "the socket is already connected");
}
self.is_new_connection = false;
Ok(())
}
pub(super) fn init_pollee(&self, pollee: &Pollee) {
pollee.reset_events();
self.update_io_events(pollee);

View File

@ -2,7 +2,6 @@
use super::{connected::ConnectedStream, init::InitStream};
use crate::{
events::IoEvents,
net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
prelude::*,
process::signal::Pollee,
@ -46,6 +45,7 @@ impl ConnectingStream {
Some(ConnResult::Connected) => Ok(ConnectedStream::new(
self.bound_socket,
self.remote_endpoint,
true,
)),
Some(ConnResult::Refused) => Err((
Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
@ -68,15 +68,20 @@ impl ConnectingStream {
pub(super) fn init_pollee(&self, pollee: &Pollee) {
pollee.reset_events();
self.update_io_events(pollee);
}
pub(super) fn update_io_events(&self, pollee: &Pollee) {
/// Returns `true` when `conn_result` becomes ready, which indicates that the caller should
/// invoke the `into_result()` method as soon as possible.
///
/// Since `into_result()` needs to be called only once, this method will return `true`
/// _exactly_ once. The caller is responsible for not missing this event.
#[must_use]
pub(super) fn update_io_events(&self) -> bool {
if self.conn_result.read().is_some() {
return;
return false;
}
let became_writable = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
let mut result = self.conn_result.write();
if result.is_some() {
return false;
@ -94,17 +99,6 @@ impl ConnectingStream {
// Refused
*result = Some(ConnResult::Refused);
true
});
// Either when the connection is established, or when the connection fails, the socket
// shall indicate that it is writable.
//
// TODO: Find a way to turn `ConnectingStream` into `ConnectedStream` or `InitStream`
// here, so non-blocking `connect()` can work correctly. Meanwhile, the latter should
// be responsible to initialize all the I/O events including `IoEvents::OUT`, so the
// following hard-coded event addition can be removed.
if became_writable {
pollee.add_events(IoEvents::OUT);
}
})
}
}

View File

@ -72,6 +72,7 @@ impl ListenStream {
Ok(ConnectedStream::new(
active_backlog_socket.into_bound_socket(),
remote_endpoint,
false,
))
}

View File

@ -114,26 +114,61 @@ impl StreamSocket {
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
}
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
// Returns `None` to block the task and wait for the connection to be established, and returns
// `Some(_)` if blocking is not necessary or not allowed.
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option<Result<()>> {
let is_nonblocking = self.is_nonblocking();
let mut state = self.state.write();
state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else {
return (
owned_state,
Err(Error::with_message(Errno::EINVAL, "cannot connect")),
);
state.borrow_result(|mut owned_state| {
let init_stream = match owned_state {
State::Init(init_stream) => init_stream,
State::Connecting(_) if is_nonblocking => {
return (
owned_state,
Some(Err(Error::with_message(
Errno::EALREADY,
"the socket is connecting",
))),
);
}
State::Connecting(_) => {
return (owned_state, None);
}
State::Connected(ref mut connected_stream) => {
let err = connected_stream.check_new();
return (owned_state, Some(err));
}
State::Listen(_) => {
return (
owned_state,
Some(Err(Error::with_message(
Errno::EISCONN,
"the socket is listening",
))),
);
}
};
let connecting_stream = match init_stream.connect(remote_endpoint) {
Ok(connecting_stream) => connecting_stream,
Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err));
return (State::Init(init_stream), Some(Err(err)));
}
};
connecting_stream.init_pollee(&self.pollee);
(State::Connecting(connecting_stream), Ok(()))
(
State::Connecting(connecting_stream),
if is_nonblocking {
Some(Err(Error::with_message(
Errno::EINPROGRESS,
"the socket is connecting",
)))
} else {
None
},
)
})
}
@ -168,6 +203,21 @@ impl StreamSocket {
})
}
fn check_connect(&self) -> Result<()> {
let mut state = self.state.write();
match state.as_mut() {
State::Connecting(_) => {
return_errno_with_message!(Errno::EAGAIN, "the connection is pending")
}
State::Connected(connected_stream) => connected_stream.check_new(),
State::Init(_) | State::Listen(_) => {
// FIXME: The error code should be retrieved via the `SO_ERROR` socket option
return_errno_with_message!(Errno::ECONNREFUSED, "the connection is refused");
}
}
}
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let state = self.state.read();
@ -244,15 +294,20 @@ impl StreamSocket {
}
}
fn update_io_events(&self) {
#[must_use]
fn update_io_events(&self) -> bool {
let state = self.state.read();
match state.as_ref() {
State::Init(_) => (),
State::Connecting(connecting_stream) => {
connecting_stream.update_io_events(&self.pollee)
State::Init(_) => false,
State::Connecting(connecting_stream) => connecting_stream.update_io_events(),
State::Listen(listen_stream) => {
listen_stream.update_io_events(&self.pollee);
false
}
State::Connected(connected_stream) => {
connected_stream.update_io_events(&self.pollee);
false
}
State::Listen(listen_stream) => listen_stream.update_io_events(&self.pollee),
State::Connected(connected_stream) => connected_stream.update_io_events(&self.pollee),
}
}
}
@ -343,13 +398,16 @@ impl Socket for StreamSocket {
})
}
// TODO: Support nonblocking mode
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
let remote_endpoint = socket_addr.try_into()?;
self.start_connect(&remote_endpoint)?;
if let Some(result) = self.start_connect(&remote_endpoint) {
poll_ifaces();
return result;
}
poll_ifaces();
self.wait_events(IoEvents::OUT, || self.finish_connect())
self.wait_events(IoEvents::OUT, || self.check_connect())
}
fn listen(&self, backlog: usize) -> Result<()> {
@ -583,6 +641,12 @@ impl Socket for StreamSocket {
impl Observer<()> for StreamSocket {
fn on_events(&self, events: &()) {
self.update_io_events();
let conn_ready = self.update_io_events();
if conn_ready {
// FIXME: The error code should be stored as the `SO_ERROR` socket option. Since it
// does not exist yet, we ignore the return value below.
let _ = self.finish_connect();
}
}
}

View File

@ -63,7 +63,7 @@ FN_SETUP(connected)
sk_addr.sin_port = S_PORT;
CHECK_WITH(connect(sk_connected, (struct sockaddr *)&sk_addr,
sizeof(sk_addr)),
_ret == 0 || errno == EINPROGRESS);
_ret < 0 && errno == EINPROGRESS);
}
END_SETUP()
@ -253,3 +253,18 @@ FN_TEST(poll)
(pfd.revents & (POLLIN | POLLOUT)) == POLLOUT);
}
END_TEST()
FN_TEST(connect)
{
struct sockaddr *psaddr = (struct sockaddr *)&sk_addr;
socklen_t addrlen = sizeof(sk_addr);
TEST_ERRNO(connect(sk_listen, psaddr, addrlen), EISCONN);
TEST_ERRNO(connect(sk_connected, psaddr, addrlen), 0);
TEST_ERRNO(connect(sk_connected, psaddr, addrlen), EISCONN);
TEST_ERRNO(connect(sk_accepted, psaddr, addrlen), EISCONN);
}
END_TEST()

View File

@ -189,3 +189,12 @@ FN_TEST(poll)
(pfd.revents & (POLLIN | POLLOUT)) == POLLOUT);
}
END_TEST()
FN_TEST(connect)
{
struct sockaddr *psaddr = (struct sockaddr *)&sk_addr;
socklen_t addrlen = sizeof(sk_addr);
TEST_SUCC(connect(sk_connected, psaddr, addrlen));
}
END_TEST()