mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-20 13:06:33 +00:00
Implement non-blocking connect
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
9211061181
commit
27c5c27fd0
@ -17,13 +17,29 @@ use crate::{
|
||||
pub struct ConnectedStream {
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
/// Indicates whether this connection is "new" in a `connect()` system call.
|
||||
///
|
||||
/// If the connection is not new, `connect()` will fail with the error code `EISCONN`,
|
||||
/// otherwise it will succeed. This means that `connect()` will succeed _exactly_ once,
|
||||
/// regardless of whether the connection is established synchronously or asynchronously.
|
||||
///
|
||||
/// If the connection is established synchronously, the synchronous `connect()` will succeed
|
||||
/// and any subsequent `connect()` will fail; otherwise, the first `connect()` after the
|
||||
/// connection is established asynchronously will succeed and any subsequent `connect()` will
|
||||
/// fail.
|
||||
is_new_connection: bool,
|
||||
}
|
||||
|
||||
impl ConnectedStream {
|
||||
pub fn new(bound_socket: Arc<AnyBoundSocket>, remote_endpoint: IpEndpoint) -> Self {
|
||||
pub fn new(
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: IpEndpoint,
|
||||
is_new_connection: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
bound_socket,
|
||||
remote_endpoint,
|
||||
is_new_connection,
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,6 +89,15 @@ impl ConnectedStream {
|
||||
self.remote_endpoint
|
||||
}
|
||||
|
||||
pub fn check_new(&mut self) -> Result<()> {
|
||||
if !self.is_new_connection {
|
||||
return_errno_with_message!(Errno::EISCONN, "the socket is already connected");
|
||||
}
|
||||
|
||||
self.is_new_connection = false;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn init_pollee(&self, pollee: &Pollee) {
|
||||
pollee.reset_events();
|
||||
self.update_io_events(pollee);
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
use super::{connected::ConnectedStream, init::InitStream};
|
||||
use crate::{
|
||||
events::IoEvents,
|
||||
net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket},
|
||||
prelude::*,
|
||||
process::signal::Pollee,
|
||||
@ -46,6 +45,7 @@ impl ConnectingStream {
|
||||
Some(ConnResult::Connected) => Ok(ConnectedStream::new(
|
||||
self.bound_socket,
|
||||
self.remote_endpoint,
|
||||
true,
|
||||
)),
|
||||
Some(ConnResult::Refused) => Err((
|
||||
Error::with_message(Errno::ECONNREFUSED, "the connection is refused"),
|
||||
@ -68,15 +68,20 @@ impl ConnectingStream {
|
||||
|
||||
pub(super) fn init_pollee(&self, pollee: &Pollee) {
|
||||
pollee.reset_events();
|
||||
self.update_io_events(pollee);
|
||||
}
|
||||
|
||||
pub(super) fn update_io_events(&self, pollee: &Pollee) {
|
||||
/// Returns `true` when `conn_result` becomes ready, which indicates that the caller should
|
||||
/// invoke the `into_result()` method as soon as possible.
|
||||
///
|
||||
/// Since `into_result()` needs to be called only once, this method will return `true`
|
||||
/// _exactly_ once. The caller is responsible for not missing this event.
|
||||
#[must_use]
|
||||
pub(super) fn update_io_events(&self) -> bool {
|
||||
if self.conn_result.read().is_some() {
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
let became_writable = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
self.bound_socket.raw_with(|socket: &mut RawTcpSocket| {
|
||||
let mut result = self.conn_result.write();
|
||||
if result.is_some() {
|
||||
return false;
|
||||
@ -94,17 +99,6 @@ impl ConnectingStream {
|
||||
// Refused
|
||||
*result = Some(ConnResult::Refused);
|
||||
true
|
||||
});
|
||||
|
||||
// Either when the connection is established, or when the connection fails, the socket
|
||||
// shall indicate that it is writable.
|
||||
//
|
||||
// TODO: Find a way to turn `ConnectingStream` into `ConnectedStream` or `InitStream`
|
||||
// here, so non-blocking `connect()` can work correctly. Meanwhile, the latter should
|
||||
// be responsible to initialize all the I/O events including `IoEvents::OUT`, so the
|
||||
// following hard-coded event addition can be removed.
|
||||
if became_writable {
|
||||
pollee.add_events(IoEvents::OUT);
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -72,6 +72,7 @@ impl ListenStream {
|
||||
Ok(ConnectedStream::new(
|
||||
active_backlog_socket.into_bound_socket(),
|
||||
remote_endpoint,
|
||||
false,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -114,26 +114,61 @@ impl StreamSocket {
|
||||
self.is_nonblocking.store(nonblocking, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
// Returns `None` to block the task and wait for the connection to be established, and returns
|
||||
// `Some(_)` if blocking is not necessary or not allowed.
|
||||
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option<Result<()>> {
|
||||
let is_nonblocking = self.is_nonblocking();
|
||||
let mut state = self.state.write();
|
||||
|
||||
state.borrow_result(|owned_state| {
|
||||
let State::Init(init_stream) = owned_state else {
|
||||
return (
|
||||
owned_state,
|
||||
Err(Error::with_message(Errno::EINVAL, "cannot connect")),
|
||||
);
|
||||
state.borrow_result(|mut owned_state| {
|
||||
let init_stream = match owned_state {
|
||||
State::Init(init_stream) => init_stream,
|
||||
State::Connecting(_) if is_nonblocking => {
|
||||
return (
|
||||
owned_state,
|
||||
Some(Err(Error::with_message(
|
||||
Errno::EALREADY,
|
||||
"the socket is connecting",
|
||||
))),
|
||||
);
|
||||
}
|
||||
State::Connecting(_) => {
|
||||
return (owned_state, None);
|
||||
}
|
||||
State::Connected(ref mut connected_stream) => {
|
||||
let err = connected_stream.check_new();
|
||||
return (owned_state, Some(err));
|
||||
}
|
||||
State::Listen(_) => {
|
||||
return (
|
||||
owned_state,
|
||||
Some(Err(Error::with_message(
|
||||
Errno::EISCONN,
|
||||
"the socket is listening",
|
||||
))),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let connecting_stream = match init_stream.connect(remote_endpoint) {
|
||||
Ok(connecting_stream) => connecting_stream,
|
||||
Err((err, init_stream)) => {
|
||||
return (State::Init(init_stream), Err(err));
|
||||
return (State::Init(init_stream), Some(Err(err)));
|
||||
}
|
||||
};
|
||||
connecting_stream.init_pollee(&self.pollee);
|
||||
|
||||
(State::Connecting(connecting_stream), Ok(()))
|
||||
(
|
||||
State::Connecting(connecting_stream),
|
||||
if is_nonblocking {
|
||||
Some(Err(Error::with_message(
|
||||
Errno::EINPROGRESS,
|
||||
"the socket is connecting",
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@ -168,6 +203,21 @@ impl StreamSocket {
|
||||
})
|
||||
}
|
||||
|
||||
fn check_connect(&self) -> Result<()> {
|
||||
let mut state = self.state.write();
|
||||
|
||||
match state.as_mut() {
|
||||
State::Connecting(_) => {
|
||||
return_errno_with_message!(Errno::EAGAIN, "the connection is pending")
|
||||
}
|
||||
State::Connected(connected_stream) => connected_stream.check_new(),
|
||||
State::Init(_) | State::Listen(_) => {
|
||||
// FIXME: The error code should be retrieved via the `SO_ERROR` socket option
|
||||
return_errno_with_message!(Errno::ECONNREFUSED, "the connection is refused");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
|
||||
let state = self.state.read();
|
||||
|
||||
@ -244,15 +294,20 @@ impl StreamSocket {
|
||||
}
|
||||
}
|
||||
|
||||
fn update_io_events(&self) {
|
||||
#[must_use]
|
||||
fn update_io_events(&self) -> bool {
|
||||
let state = self.state.read();
|
||||
match state.as_ref() {
|
||||
State::Init(_) => (),
|
||||
State::Connecting(connecting_stream) => {
|
||||
connecting_stream.update_io_events(&self.pollee)
|
||||
State::Init(_) => false,
|
||||
State::Connecting(connecting_stream) => connecting_stream.update_io_events(),
|
||||
State::Listen(listen_stream) => {
|
||||
listen_stream.update_io_events(&self.pollee);
|
||||
false
|
||||
}
|
||||
State::Connected(connected_stream) => {
|
||||
connected_stream.update_io_events(&self.pollee);
|
||||
false
|
||||
}
|
||||
State::Listen(listen_stream) => listen_stream.update_io_events(&self.pollee),
|
||||
State::Connected(connected_stream) => connected_stream.update_io_events(&self.pollee),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -343,13 +398,16 @@ impl Socket for StreamSocket {
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: Support nonblocking mode
|
||||
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
|
||||
let remote_endpoint = socket_addr.try_into()?;
|
||||
self.start_connect(&remote_endpoint)?;
|
||||
|
||||
if let Some(result) = self.start_connect(&remote_endpoint) {
|
||||
poll_ifaces();
|
||||
return result;
|
||||
}
|
||||
poll_ifaces();
|
||||
self.wait_events(IoEvents::OUT, || self.finish_connect())
|
||||
|
||||
self.wait_events(IoEvents::OUT, || self.check_connect())
|
||||
}
|
||||
|
||||
fn listen(&self, backlog: usize) -> Result<()> {
|
||||
@ -583,6 +641,12 @@ impl Socket for StreamSocket {
|
||||
|
||||
impl Observer<()> for StreamSocket {
|
||||
fn on_events(&self, events: &()) {
|
||||
self.update_io_events();
|
||||
let conn_ready = self.update_io_events();
|
||||
|
||||
if conn_ready {
|
||||
// FIXME: The error code should be stored as the `SO_ERROR` socket option. Since it
|
||||
// does not exist yet, we ignore the return value below.
|
||||
let _ = self.finish_connect();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ FN_SETUP(connected)
|
||||
sk_addr.sin_port = S_PORT;
|
||||
CHECK_WITH(connect(sk_connected, (struct sockaddr *)&sk_addr,
|
||||
sizeof(sk_addr)),
|
||||
_ret == 0 || errno == EINPROGRESS);
|
||||
_ret < 0 && errno == EINPROGRESS);
|
||||
}
|
||||
END_SETUP()
|
||||
|
||||
@ -253,3 +253,18 @@ FN_TEST(poll)
|
||||
(pfd.revents & (POLLIN | POLLOUT)) == POLLOUT);
|
||||
}
|
||||
END_TEST()
|
||||
|
||||
FN_TEST(connect)
|
||||
{
|
||||
struct sockaddr *psaddr = (struct sockaddr *)&sk_addr;
|
||||
socklen_t addrlen = sizeof(sk_addr);
|
||||
|
||||
TEST_ERRNO(connect(sk_listen, psaddr, addrlen), EISCONN);
|
||||
|
||||
TEST_ERRNO(connect(sk_connected, psaddr, addrlen), 0);
|
||||
|
||||
TEST_ERRNO(connect(sk_connected, psaddr, addrlen), EISCONN);
|
||||
|
||||
TEST_ERRNO(connect(sk_accepted, psaddr, addrlen), EISCONN);
|
||||
}
|
||||
END_TEST()
|
||||
|
@ -189,3 +189,12 @@ FN_TEST(poll)
|
||||
(pfd.revents & (POLLIN | POLLOUT)) == POLLOUT);
|
||||
}
|
||||
END_TEST()
|
||||
|
||||
FN_TEST(connect)
|
||||
{
|
||||
struct sockaddr *psaddr = (struct sockaddr *)&sk_addr;
|
||||
socklen_t addrlen = sizeof(sk_addr);
|
||||
|
||||
TEST_SUCC(connect(sk_connected, psaddr, addrlen));
|
||||
}
|
||||
END_TEST()
|
||||
|
Reference in New Issue
Block a user