From 68cf99993e053e3d9c32f7ecd59b170930b7747d Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Wed, 19 Feb 2025 15:33:03 +0800 Subject: [PATCH] Remove socket errors from `SocketOptionSet` --- kernel/src/net/socket/ip/datagram/mod.rs | 3 +- kernel/src/net/socket/ip/stream/connected.rs | 2 +- kernel/src/net/socket/ip/stream/connecting.rs | 2 +- kernel/src/net/socket/ip/stream/init.rs | 161 +++++++++++++----- kernel/src/net/socket/ip/stream/mod.rs | 134 ++++++--------- kernel/src/net/socket/util/options.rs | 15 +- test/apps/network/tcp_err.c | 55 ++++-- 7 files changed, 221 insertions(+), 151 deletions(-) diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 9ec95683..301c3cd6 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -362,7 +362,8 @@ impl Socket for DatagramSocket { fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { match_sock_option_mut!(option, { socket_errors: SocketError => { - self.options.write().socket.get_and_clear_sock_errors(socket_errors); + // TODO: Support socket errors for UDP sockets + socket_errors.set(None); return Ok(()); }, _ => () diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index 2b9932bc..fa98c7d3 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -160,7 +160,7 @@ impl ConnectedStream { self.tcp_conn.iface() } - pub fn check_new(&mut self) -> Result<()> { + pub fn finish_last_connect(&mut self) -> Result<()> { if !self.is_new_connection { return_errno_with_message!(Errno::EISCONN, "the socket is already connected"); } diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index 1efc9eae..cf5477a3 100644 --- a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -82,7 +82,7 @@ impl ConnectingStream { self.remote_endpoint, true, )), - ConnectState::Refused => ConnResult::Refused(InitStream::new_bound( + ConnectState::Refused => ConnResult::Refused(InitStream::new_refused( self.tcp_conn.into_bound_port().unwrap(), )), } diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index 6c37b3d2..cb1018d2 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +use core::sync::atomic::{AtomicBool, Ordering}; + use aster_bigtcp::{socket::RawTcpOption, wire::IpEndpoint}; use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver}; @@ -12,49 +14,62 @@ use crate::{ prelude::*, }; -pub enum InitStream { - Unbound, - Bound(BoundPort), +pub struct InitStream { + bound_port: Option, + /// Indicates if the last `connect()` is considered to be done. + /// + /// If `connect()` is called but we're still in the `InitStream`, this means that the + /// connection is refused. + /// + /// * If the connection is refused synchronously, the error code is returned by the + /// `connect()` system call, and after that we always consider the `connect()` to be already + /// done. + /// + /// * If the connection is refused asynchronously (e.g., non-blocking sockets or interrupted + /// `connect()`), the last `connect()` is not considered to have been done until another + /// `connect()`, which checks and resets the boolean value and returns an appropriate error + /// code. + is_connect_done: bool, + /// Indicates whether the socket error is `ECONNREFUSED`. + /// + /// This boolean value is set to true when the connection is refused and set to false when the + /// error code is reported via either `getsockopt(SOL_SOCKET, SO_ERROR)` or `connect()`. + is_conn_refused: AtomicBool, } impl InitStream { pub fn new() -> Self { - InitStream::Unbound + Self { + bound_port: None, + is_connect_done: true, + is_conn_refused: AtomicBool::new(false), + } } pub fn new_bound(bound_port: BoundPort) -> Self { - InitStream::Bound(bound_port) + Self { + bound_port: Some(bound_port), + is_connect_done: true, + is_conn_refused: AtomicBool::new(false), + } } - pub fn bind( - self, - endpoint: &IpEndpoint, - can_reuse: bool, - ) -> core::result::Result { - match self { - InitStream::Unbound => (), - InitStream::Bound(bound_socket) => { - return Err(( - Error::with_message(Errno::EINVAL, "the socket is already bound to an address"), - InitStream::Bound(bound_socket), - )); - } - }; - - let bound_port = match bind_port(endpoint, can_reuse) { - Ok(bound_port) => bound_port, - Err(err) => return Err((err, Self::Unbound)), - }; - - Ok(bound_port) + pub fn new_refused(bound_port: BoundPort) -> Self { + Self { + bound_port: Some(bound_port), + is_connect_done: false, + is_conn_refused: AtomicBool::new(true), + } } - fn bind_to_ephemeral_endpoint( - self, - remote_endpoint: &IpEndpoint, - ) -> core::result::Result { - let endpoint = get_ephemeral_endpoint(remote_endpoint); - self.bind(&endpoint, false) + pub fn bind(&mut self, endpoint: &IpEndpoint, can_reuse: bool) -> Result<()> { + if self.bound_port.is_some() { + return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address"); + } + + self.bound_port = Some(bind_port(endpoint, can_reuse)?); + + Ok(()) } pub fn connect( @@ -63,13 +78,49 @@ impl InitStream { option: &RawTcpOption, observer: StreamObserver, ) -> core::result::Result { - let bound_port = match self { - InitStream::Bound(bound_port) => bound_port, - InitStream::Unbound => self.bind_to_ephemeral_endpoint(remote_endpoint)?, + debug_assert!( + self.is_connect_done, + "`finish_last_connect()` should be called before calling `connect()`" + ); + + let bound_port = if let Some(bound_port) = self.bound_port { + bound_port + } else { + let endpoint = get_ephemeral_endpoint(remote_endpoint); + match bind_port(&endpoint, false) { + Ok(bound_port) => bound_port, + Err(err) => return Err((err, self)), + } }; - ConnectingStream::new(bound_port, *remote_endpoint, option, observer) - .map_err(|(err, bound_port)| (err, InitStream::Bound(bound_port))) + ConnectingStream::new(bound_port, *remote_endpoint, option, observer).map_err( + |(err, bound_port)| { + if err.error() == Errno::ECONNREFUSED { + (err, InitStream::new_refused(bound_port)) + } else { + (err, InitStream::new_bound(bound_port)) + } + }, + ) + } + + pub fn finish_last_connect(&mut self) -> Result<()> { + if self.is_connect_done { + return Ok(()); + } + + self.is_connect_done = true; + + let is_conn_refused = self.is_conn_refused.get_mut(); + if *is_conn_refused { + *is_conn_refused = false; + return_errno_with_message!(Errno::ECONNREFUSED, "the connection is refused"); + } else { + return_errno_with_message!( + Errno::ECONNABORTED, + "the error code for the connection failure is not available" + ); + } } pub fn listen( @@ -78,10 +129,22 @@ impl InitStream { option: &RawTcpOption, observer: StreamObserver, ) -> core::result::Result { - let InitStream::Bound(bound_port) = self else { + if !self.is_connect_done { + // See the comments of `is_connect_done`. + // `listen()` is also not allowed until the second `connect()`. + return Err(( + Error::with_message( + Errno::EINVAL, + "the connection is refused, but the connecting phase is not done", + ), + self, + )); + } + + let Some(bound_port) = self.bound_port else { // FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral // port. However, INADDR_ANY is not yet supported, so we need to return an error first. - debug_assert!(false, "listen() without bind() is not implemented"); + warn!("listen() without bind() is not implemented"); return Err(( Error::with_message(Errno::EINVAL, "listen() without bind() is not implemented"), self, @@ -90,19 +153,29 @@ impl InitStream { match ListenStream::new(bound_port, backlog, option, observer) { Ok(listen_stream) => Ok(listen_stream), - Err((bound_port, error)) => Err((error, Self::Bound(bound_port))), + Err((bound_port, error)) => Err((error, Self::new_bound(bound_port))), } } pub fn local_endpoint(&self) -> Option { - match self { - InitStream::Unbound => None, - InitStream::Bound(bound_port) => Some(bound_port.endpoint().unwrap()), - } + self.bound_port + .as_ref() + .map(|bound_port| bound_port.endpoint().unwrap()) } pub(super) fn check_io_events(&self) -> IoEvents { // Linux adds OUT and HUP events for a newly created socket IoEvents::OUT | IoEvents::HUP } + + pub(super) fn test_and_clear_error(&self) -> Option { + if self.is_conn_refused.swap(false, Ordering::Relaxed) { + Some(Error::with_message( + Errno::ECONNREFUSED, + "the connection is refused", + )) + } else { + None + } + } } diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index abfd3e2a..0879a135 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -134,7 +134,7 @@ impl StreamSocket { /// Ensures that the socket state is up to date and obtains a read lock on it. /// - /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. + /// For a description of what "up-to-date" means, see [`Self::write_updated_state`]. fn read_updated_state(&self) -> RwLockReadGuard, PreemptDisabled> { loop { let state = self.state.read(); @@ -144,39 +144,25 @@ impl StreamSocket { }; drop(state); - self.update_connecting(); + self.write_updated_state(); } } /// Ensures that the socket state is up to date and obtains a write lock on it. /// - /// For a description of what "up-to-date" means, see [`Self::update_connecting`]. - fn write_updated_state(&self) -> RwLockWriteGuard, PreemptDisabled> { - self.update_connecting().1 - } - - /// Updates the socket state if the socket is an obsolete connecting socket. - /// /// A connecting socket can become obsolete because some network events can set the socket to /// connected state (if the connection succeeds) or initial state (if the connection is - /// refused) in [`Self::update_io_events`], but the state transition is delayed until the user + /// refused) in [`Self::check_io_events`], but the state transition is delayed until the user /// operates on the socket to avoid too many locks in the interrupt handler. /// /// This method performs the delayed state transition to ensure that the state is up to date - /// and returns the guards of the write-locked options and state. - fn update_connecting( - &self, - ) -> ( - RwLockWriteGuard, - RwLockWriteGuard, PreemptDisabled>, - ) { - // Hold the lock in advance to avoid race conditions. - let mut options = self.options.write(); + /// and returns the guard of the write-locked state. + fn write_updated_state(&self) -> RwLockWriteGuard, PreemptDisabled> { let mut state = self.state.write(); match state.as_ref() { State::Connecting(connection_stream) if connection_stream.has_result() => (), - _ => return (options, state), + _ => return state, } state.borrow(|owned_state| { @@ -186,33 +172,26 @@ impl StreamSocket { match connecting_stream.into_result() { ConnResult::Connecting(connecting_stream) => State::Connecting(connecting_stream), - ConnResult::Connected(connected_stream) => { - options.socket.set_sock_errors(None); - State::Connected(connected_stream) - } - ConnResult::Refused(init_stream) => { - options.socket.set_sock_errors(Some(Error::with_message( - Errno::ECONNREFUSED, - "the connection is refused", - ))); - State::Init(init_stream) - } + ConnResult::Connected(connected_stream) => State::Connected(connected_stream), + ConnResult::Refused(init_stream) => State::Init(init_stream), } }); - (options, state) + state } // Returns `None` to block the task and wait for the connection to be established, and returns // `Some(_)` if blocking is not necessary or not allowed. fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option> { let is_nonblocking = self.is_nonblocking(); - let (mut options, mut state) = self.update_connecting(); + let options = self.options.read(); let raw_option = options.raw(); + let mut state = self.write_updated_state(); + let (result_or_block, iface_to_poll) = state.borrow_result(|mut owned_state| { - let init_stream = match owned_state { + let mut init_stream = match owned_state { State::Init(init_stream) => init_stream, State::Connecting(_) if is_nonblocking => { return ( @@ -228,7 +207,7 @@ impl StreamSocket { } State::Connecting(_) => return (owned_state, (None, None)), State::Connected(ref mut connected_stream) => { - let err = connected_stream.check_new(); + let err = connected_stream.finish_last_connect(); return (owned_state, (Some(err), None)); } State::Listen(_) => { @@ -245,19 +224,25 @@ impl StreamSocket { } }; - let connecting_stream = match init_stream.connect( + if let Err(err) = init_stream.finish_last_connect() { + return (State::Init(init_stream), (Some(Err(err)), None)); + } + + let (target_state, iface_to_poll) = match init_stream.connect( remote_endpoint, &raw_option, StreamObserver::new(self.pollee.clone()), ) { - Ok(connecting_stream) => connecting_stream, - Err((mut err, init_stream)) => { - // If the socket is nonblocking, we should return EINPROGRESS instead. - if is_nonblocking { - options.socket.set_sock_errors(Some(err)); - err = Error::new(Errno::EINPROGRESS); - } - + Ok(connecting_stream) => { + let iface_to_poll = connecting_stream.iface().clone(); + (State::Connecting(connecting_stream), Some(iface_to_poll)) + } + Err((err, init_stream)) if err.error() == Errno::ECONNREFUSED => { + // `ECONNREFUSED` should be reported asynchronously, i.e., we need to return + // `EINPROGRESS` first for non-blocking sockets. + (State::Init(init_stream), None) + } + Err((err, init_stream)) => { return (State::Init(init_stream), (Some(Err(err)), None)); } }; @@ -270,12 +255,8 @@ impl StreamSocket { } else { None }; - let iface_to_poll = connecting_stream.iface().clone(); - ( - State::Connecting(connecting_stream), - (result_or_block, Some(iface_to_poll)), - ) + (target_state, (result_or_block, iface_to_poll)) }); drop(state); @@ -288,17 +269,16 @@ impl StreamSocket { } fn check_connect(&self) -> Result<()> { - let (mut options, mut state) = self.update_connecting(); + let mut state = self.write_updated_state(); match state.as_mut() { + State::Init(init_stream) => init_stream.finish_last_connect(), State::Connecting(_) => { return_errno_with_message!(Errno::EAGAIN, "the connection is pending") } - State::Connected(connected_stream) => connected_stream.check_new(), - State::Init(_) | State::Listen(_) => { - let sock_errors = options.socket.sock_errors(); - options.socket.set_sock_errors(None); - sock_errors.map(Err).unwrap_or(Ok(())) + State::Connected(connected_stream) => connected_stream.finish_last_connect(), + State::Listen(_) => { + return_errno_with_message!(Errno::EISCONN, "the socket is listening") } } } @@ -392,6 +372,15 @@ impl StreamSocket { State::Connected(connected_stream) => connected_stream.check_io_events(), } } + + fn test_and_clear_error(&self) -> Option { + let state = self.read_updated_state(); + + match state.as_ref() { + State::Init(init_stream) => init_stream.test_and_clear_error(), + State::Connecting(_) | State::Listen(_) | State::Connected(_) => None, + } + } } impl Pollable for StreamSocket { @@ -418,26 +407,10 @@ impl Socket for StreamSocket { let can_reuse = self.options.read().socket.reuse_addr(); let mut state = self.write_updated_state(); - state.borrow_result(|owned_state| { - let State::Init(init_stream) = owned_state else { - return ( - owned_state, - Err(Error::with_message( - Errno::EINVAL, - "the socket is already bound to an address", - )), - ); - }; - - let bound_port = match init_stream.bind(&endpoint, can_reuse) { - Ok(bound_port) => bound_port, - Err((err, init_stream)) => { - return (State::Init(init_stream), Err(err)); - } - }; - - (State::Init(InitStream::new_bound(bound_port)), Ok(())) - }) + let State::Init(init_stream) = state.as_mut() else { + return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address"); + }; + init_stream.bind(&endpoint, can_reuse) } fn connect(&self, socket_addr: SocketAddr) -> Result<()> { @@ -451,10 +424,11 @@ impl Socket for StreamSocket { } fn listen(&self, backlog: usize) -> Result<()> { - let (options, mut state) = self.update_connecting(); - + let options = self.options.read(); let raw_option = options.raw(); + let mut state = self.write_updated_state(); + state.borrow_result(|owned_state| { let init_stream = match owned_state { State::Init(init_stream) => init_stream, @@ -588,8 +562,7 @@ impl Socket for StreamSocket { fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> { match_sock_option_mut!(option, { socket_errors: SocketError => { - let mut options = self.update_connecting().0; - options.socket.get_and_clear_sock_errors(socket_errors); + socket_errors.set(self.test_and_clear_error()); return Ok(()); }, _ => () @@ -660,7 +633,8 @@ impl Socket for StreamSocket { } fn set_option(&self, option: &dyn SocketOption) -> Result<()> { - let (mut options, mut state) = self.update_connecting(); + let mut options = self.options.write(); + let mut state = self.write_updated_state(); let need_iface_poll = match options.socket.set_option(option, state.as_mut()) { Err(err) if err.error() == Errno::ENOPROTOOPT => { diff --git a/kernel/src/net/socket/util/options.rs b/kernel/src/net/socket/util/options.rs index a2d720d0..6bd71a92 100644 --- a/kernel/src/net/socket/util/options.rs +++ b/kernel/src/net/socket/util/options.rs @@ -9,8 +9,7 @@ use aster_bigtcp::socket::{ use crate::{ match_sock_option_mut, match_sock_option_ref, net::socket::options::{ - Error as SocketError, KeepAlive, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, - SocketOption, + KeepAlive, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption, }, prelude::*, }; @@ -19,7 +18,6 @@ use crate::{ #[get_copy = "pub"] #[set = "pub"] pub struct SocketOptionSet { - sock_errors: Option, reuse_addr: bool, reuse_port: bool, send_buf: u32, @@ -32,7 +30,6 @@ impl SocketOptionSet { /// Return the default socket level options for tcp socket. pub fn new_tcp() -> Self { Self { - sock_errors: None, reuse_addr: false, reuse_port: false, send_buf: TCP_SEND_BUF_LEN as u32, @@ -45,7 +42,6 @@ impl SocketOptionSet { /// Return the default socket level options for udp socket. pub fn new_udp() -> Self { Self { - sock_errors: None, reuse_addr: false, reuse_port: false, send_buf: UDP_SEND_PAYLOAD_LEN as u32, @@ -55,15 +51,6 @@ impl SocketOptionSet { } } - /// Gets and clears the socket error. - /// - /// When processing the `getsockopt` system call, the socket error is automatically cleared - /// after reading. This method should be called to provide this behavior. - pub fn get_and_clear_sock_errors(&mut self, option: &mut SocketError) { - option.set(self.sock_errors()); - self.set_sock_errors(None); - } - /// Gets socket-level options. /// /// Note that the socket error has to be handled separately, because it is automatically diff --git a/test/apps/network/tcp_err.c b/test/apps/network/tcp_err.c index 8c42c0e0..4ce9b3e5 100644 --- a/test/apps/network/tcp_err.c +++ b/test/apps/network/tcp_err.c @@ -323,31 +323,58 @@ FN_TEST(async_connect) socklen_t errlen = sizeof(err); sk_addr.sin_port = 0xbeef; + +#define ASYNC_CONNECT \ + TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, \ + sizeof(sk_addr)), \ + EINPROGRESS); \ + TEST_RES(poll(&pfd, 1, 60), pfd.revents &POLLOUT); + + ASYNC_CONNECT; + + // The second `connect` will fail with `ECONNREFUSED`. TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, sizeof(sk_addr)), - EINPROGRESS); + ECONNREFUSED); - TEST_RES(poll(&pfd, 1, 60), pfd.revents & POLLOUT); - - TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen), - errlen == sizeof(err) && err == ECONNREFUSED); + ASYNC_CONNECT; // Reading the socket error will cause it to be cleared + TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen), + errlen == sizeof(err) && err == ECONNREFUSED); TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen), errlen == sizeof(err) && err == 0); + + // `listen` won't succeed until the second `connect`. + TEST_ERRNO(listen(sk_bound, 10), EINVAL); + + // The second `connect` will fail with `ECONNABORTED` if the socket + // error is cleared. + TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, + sizeof(sk_addr)), + ECONNABORTED); + +#undef ASYNC_CONNECT } END_TEST() -void set_blocking(int sockfd) +static void set_blocking(int sockfd, int is_blocking) { int flags = CHECK(fcntl(sockfd, F_GETFL, 0)); - CHECK(fcntl(sockfd, F_SETFL, flags & (~O_NONBLOCK))); + + if (is_blocking) { + flags &= ~O_NONBLOCK; + } else { + flags |= O_NONBLOCK; + } + + CHECK(fcntl(sockfd, F_SETFL, flags)); } FN_SETUP(enter_blocking_mode) { - set_blocking(sk_connected); - set_blocking(sk_bound); + set_blocking(sk_connected, 1); + set_blocking(sk_bound, 1); } END_SETUP() @@ -536,6 +563,7 @@ FN_TEST(bind_and_connect_same_address) TEST_SUCC(listen(sk_listen, 3)); + // For blocking sockets, conflict addresses result in `EADDRNOTAVAIL`. sk_addr.sin_port = htons(listen_port); TEST_SUCC(connect(sk_connect1, (struct sockaddr *)&sk_addr, sizeof(sk_addr))); @@ -543,8 +571,15 @@ FN_TEST(bind_and_connect_same_address) sizeof(sk_addr)), EADDRNOTAVAIL); + // For non-blocking sockets, conflict addresses also result in `EADDRNOTAVAIL`. + // (`EINPROGRESS` should _not_ be returned in this case.) + set_blocking(sk_connect2, 0); + TEST_ERRNO(connect(sk_connect2, (struct sockaddr *)&sk_addr, + sizeof(sk_addr)), + EADDRNOTAVAIL); + TEST_SUCC(close(sk_listen)); TEST_SUCC(close(sk_connect1)); TEST_SUCC(close(sk_connect2)); } -END_TEST() \ No newline at end of file +END_TEST()