diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs index 20e204cc7..6a889b69b 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connected.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connected.rs @@ -17,13 +17,29 @@ use crate::{ pub struct ConnectedStream { bound_socket: Arc, remote_endpoint: IpEndpoint, + /// Indicates whether this connection is "new" in a `connect()` system call. + /// + /// If the connection is not new, `connect()` will fail with the error code `EISCONN`, + /// otherwise it will succeed. This means that `connect()` will succeed _exactly_ once, + /// regardless of whether the connection is established synchronously or asynchronously. + /// + /// If the connection is established synchronously, the synchronous `connect()` will succeed + /// and any subsequent `connect()` will fail; otherwise, the first `connect()` after the + /// connection is established asynchronously will succeed and any subsequent `connect()` will + /// fail. + is_new_connection: bool, } impl ConnectedStream { - pub fn new(bound_socket: Arc, remote_endpoint: IpEndpoint) -> Self { + pub fn new( + bound_socket: Arc, + remote_endpoint: IpEndpoint, + is_new_connection: bool, + ) -> Self { Self { bound_socket, remote_endpoint, + is_new_connection, } } @@ -73,6 +89,15 @@ impl ConnectedStream { self.remote_endpoint } + pub fn check_new(&mut self) -> Result<()> { + if !self.is_new_connection { + return_errno_with_message!(Errno::EISCONN, "the socket is already connected"); + } + + self.is_new_connection = false; + Ok(()) + } + pub(super) fn init_pollee(&self, pollee: &Pollee) { pollee.reset_events(); self.update_io_events(pollee); diff --git a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs index d1ac85383..3503a9afe 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/connecting.rs @@ -2,7 +2,6 @@ use super::{connected::ConnectedStream, init::InitStream}; use crate::{ - events::IoEvents, net::iface::{AnyBoundSocket, IpEndpoint, RawTcpSocket}, prelude::*, process::signal::Pollee, @@ -46,6 +45,7 @@ impl ConnectingStream { Some(ConnResult::Connected) => Ok(ConnectedStream::new( self.bound_socket, self.remote_endpoint, + true, )), Some(ConnResult::Refused) => Err(( Error::with_message(Errno::ECONNREFUSED, "the connection is refused"), @@ -68,15 +68,20 @@ impl ConnectingStream { pub(super) fn init_pollee(&self, pollee: &Pollee) { pollee.reset_events(); - self.update_io_events(pollee); } - pub(super) fn update_io_events(&self, pollee: &Pollee) { + /// Returns `true` when `conn_result` becomes ready, which indicates that the caller should + /// invoke the `into_result()` method as soon as possible. + /// + /// Since `into_result()` needs to be called only once, this method will return `true` + /// _exactly_ once. The caller is responsible for not missing this event. + #[must_use] + pub(super) fn update_io_events(&self) -> bool { if self.conn_result.read().is_some() { - return; + return false; } - let became_writable = self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { + self.bound_socket.raw_with(|socket: &mut RawTcpSocket| { let mut result = self.conn_result.write(); if result.is_some() { return false; @@ -94,17 +99,6 @@ impl ConnectingStream { // Refused *result = Some(ConnResult::Refused); true - }); - - // Either when the connection is established, or when the connection fails, the socket - // shall indicate that it is writable. - // - // TODO: Find a way to turn `ConnectingStream` into `ConnectedStream` or `InitStream` - // here, so non-blocking `connect()` can work correctly. Meanwhile, the latter should - // be responsible to initialize all the I/O events including `IoEvents::OUT`, so the - // following hard-coded event addition can be removed. - if became_writable { - pollee.add_events(IoEvents::OUT); - } + }) } } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs index dceed257c..8551b668b 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/listen.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/listen.rs @@ -72,6 +72,7 @@ impl ListenStream { Ok(ConnectedStream::new( active_backlog_socket.into_bound_socket(), remote_endpoint, + false, )) } diff --git a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs index c8b126c98..158fd0704 100644 --- a/kernel/aster-nix/src/net/socket/ip/stream/mod.rs +++ b/kernel/aster-nix/src/net/socket/ip/stream/mod.rs @@ -114,26 +114,61 @@ impl StreamSocket { self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } - fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Result<()> { + // Returns `None` to block the task and wait for the connection to be established, and returns + // `Some(_)` if blocking is not necessary or not allowed. + fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option> { + let is_nonblocking = self.is_nonblocking(); let mut state = self.state.write(); - state.borrow_result(|owned_state| { - let State::Init(init_stream) = owned_state else { - return ( - owned_state, - Err(Error::with_message(Errno::EINVAL, "cannot connect")), - ); + state.borrow_result(|mut owned_state| { + let init_stream = match owned_state { + State::Init(init_stream) => init_stream, + State::Connecting(_) if is_nonblocking => { + return ( + owned_state, + Some(Err(Error::with_message( + Errno::EALREADY, + "the socket is connecting", + ))), + ); + } + State::Connecting(_) => { + return (owned_state, None); + } + State::Connected(ref mut connected_stream) => { + let err = connected_stream.check_new(); + return (owned_state, Some(err)); + } + State::Listen(_) => { + return ( + owned_state, + Some(Err(Error::with_message( + Errno::EISCONN, + "the socket is listening", + ))), + ); + } }; let connecting_stream = match init_stream.connect(remote_endpoint) { Ok(connecting_stream) => connecting_stream, Err((err, init_stream)) => { - return (State::Init(init_stream), Err(err)); + return (State::Init(init_stream), Some(Err(err))); } }; connecting_stream.init_pollee(&self.pollee); - (State::Connecting(connecting_stream), Ok(())) + ( + State::Connecting(connecting_stream), + if is_nonblocking { + Some(Err(Error::with_message( + Errno::EINPROGRESS, + "the socket is connecting", + ))) + } else { + None + }, + ) }) } @@ -168,6 +203,21 @@ impl StreamSocket { }) } + fn check_connect(&self) -> Result<()> { + let mut state = self.state.write(); + + match state.as_mut() { + State::Connecting(_) => { + return_errno_with_message!(Errno::EAGAIN, "the connection is pending") + } + State::Connected(connected_stream) => connected_stream.check_new(), + State::Init(_) | State::Listen(_) => { + // FIXME: The error code should be retrieved via the `SO_ERROR` socket option + return_errno_with_message!(Errno::ECONNREFUSED, "the connection is refused"); + } + } + } + fn try_accept(&self) -> Result<(Arc, SocketAddr)> { let state = self.state.read(); @@ -244,15 +294,20 @@ impl StreamSocket { } } - fn update_io_events(&self) { + #[must_use] + fn update_io_events(&self) -> bool { let state = self.state.read(); match state.as_ref() { - State::Init(_) => (), - State::Connecting(connecting_stream) => { - connecting_stream.update_io_events(&self.pollee) + State::Init(_) => false, + State::Connecting(connecting_stream) => connecting_stream.update_io_events(), + State::Listen(listen_stream) => { + listen_stream.update_io_events(&self.pollee); + false + } + State::Connected(connected_stream) => { + connected_stream.update_io_events(&self.pollee); + false } - State::Listen(listen_stream) => listen_stream.update_io_events(&self.pollee), - State::Connected(connected_stream) => connected_stream.update_io_events(&self.pollee), } } } @@ -343,13 +398,16 @@ impl Socket for StreamSocket { }) } - // TODO: Support nonblocking mode fn connect(&self, socket_addr: SocketAddr) -> Result<()> { let remote_endpoint = socket_addr.try_into()?; - self.start_connect(&remote_endpoint)?; + if let Some(result) = self.start_connect(&remote_endpoint) { + poll_ifaces(); + return result; + } poll_ifaces(); - self.wait_events(IoEvents::OUT, || self.finish_connect()) + + self.wait_events(IoEvents::OUT, || self.check_connect()) } fn listen(&self, backlog: usize) -> Result<()> { @@ -583,6 +641,12 @@ impl Socket for StreamSocket { impl Observer<()> for StreamSocket { fn on_events(&self, events: &()) { - self.update_io_events(); + let conn_ready = self.update_io_events(); + + if conn_ready { + // FIXME: The error code should be stored as the `SO_ERROR` socket option. Since it + // does not exist yet, we ignore the return value below. + let _ = self.finish_connect(); + } } } diff --git a/regression/apps/network/tcp_err.c b/regression/apps/network/tcp_err.c index 19f28d1c8..0c30e928f 100644 --- a/regression/apps/network/tcp_err.c +++ b/regression/apps/network/tcp_err.c @@ -63,7 +63,7 @@ FN_SETUP(connected) sk_addr.sin_port = S_PORT; CHECK_WITH(connect(sk_connected, (struct sockaddr *)&sk_addr, sizeof(sk_addr)), - _ret == 0 || errno == EINPROGRESS); + _ret < 0 && errno == EINPROGRESS); } END_SETUP() @@ -253,3 +253,18 @@ FN_TEST(poll) (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); } END_TEST() + +FN_TEST(connect) +{ + struct sockaddr *psaddr = (struct sockaddr *)&sk_addr; + socklen_t addrlen = sizeof(sk_addr); + + TEST_ERRNO(connect(sk_listen, psaddr, addrlen), EISCONN); + + TEST_ERRNO(connect(sk_connected, psaddr, addrlen), 0); + + TEST_ERRNO(connect(sk_connected, psaddr, addrlen), EISCONN); + + TEST_ERRNO(connect(sk_accepted, psaddr, addrlen), EISCONN); +} +END_TEST() diff --git a/regression/apps/network/udp_err.c b/regression/apps/network/udp_err.c index 75ec36701..9769d30ab 100644 --- a/regression/apps/network/udp_err.c +++ b/regression/apps/network/udp_err.c @@ -189,3 +189,12 @@ FN_TEST(poll) (pfd.revents & (POLLIN | POLLOUT)) == POLLOUT); } END_TEST() + +FN_TEST(connect) +{ + struct sockaddr *psaddr = (struct sockaddr *)&sk_addr; + socklen_t addrlen = sizeof(sk_addr); + + TEST_SUCC(connect(sk_connected, psaddr, addrlen)); +} +END_TEST()