Remove socket errors from SocketOptionSet

This commit is contained in:
Ruihan Li 2025-02-19 15:33:03 +08:00 committed by Tate, Hongliang Tian
parent 76e9694dd0
commit 68cf99993e
7 changed files with 221 additions and 151 deletions

View File

@ -362,7 +362,8 @@ impl Socket for DatagramSocket {
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
match_sock_option_mut!(option, {
socket_errors: SocketError => {
self.options.write().socket.get_and_clear_sock_errors(socket_errors);
// TODO: Support socket errors for UDP sockets
socket_errors.set(None);
return Ok(());
},
_ => ()

View File

@ -160,7 +160,7 @@ impl ConnectedStream {
self.tcp_conn.iface()
}
pub fn check_new(&mut self) -> Result<()> {
pub fn finish_last_connect(&mut self) -> Result<()> {
if !self.is_new_connection {
return_errno_with_message!(Errno::EISCONN, "the socket is already connected");
}

View File

@ -82,7 +82,7 @@ impl ConnectingStream {
self.remote_endpoint,
true,
)),
ConnectState::Refused => ConnResult::Refused(InitStream::new_bound(
ConnectState::Refused => ConnResult::Refused(InitStream::new_refused(
self.tcp_conn.into_bound_port().unwrap(),
)),
}

View File

@ -1,5 +1,7 @@
// SPDX-License-Identifier: MPL-2.0
use core::sync::atomic::{AtomicBool, Ordering};
use aster_bigtcp::{socket::RawTcpOption, wire::IpEndpoint};
use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver};
@ -12,49 +14,62 @@ use crate::{
prelude::*,
};
pub enum InitStream {
Unbound,
Bound(BoundPort),
pub struct InitStream {
bound_port: Option<BoundPort>,
/// Indicates if the last `connect()` is considered to be done.
///
/// If `connect()` is called but we're still in the `InitStream`, this means that the
/// connection is refused.
///
/// * If the connection is refused synchronously, the error code is returned by the
/// `connect()` system call, and after that we always consider the `connect()` to be already
/// done.
///
/// * If the connection is refused asynchronously (e.g., non-blocking sockets or interrupted
/// `connect()`), the last `connect()` is not considered to have been done until another
/// `connect()`, which checks and resets the boolean value and returns an appropriate error
/// code.
is_connect_done: bool,
/// Indicates whether the socket error is `ECONNREFUSED`.
///
/// This boolean value is set to true when the connection is refused and set to false when the
/// error code is reported via either `getsockopt(SOL_SOCKET, SO_ERROR)` or `connect()`.
is_conn_refused: AtomicBool,
}
impl InitStream {
pub fn new() -> Self {
InitStream::Unbound
Self {
bound_port: None,
is_connect_done: true,
is_conn_refused: AtomicBool::new(false),
}
}
pub fn new_bound(bound_port: BoundPort) -> Self {
InitStream::Bound(bound_port)
Self {
bound_port: Some(bound_port),
is_connect_done: true,
is_conn_refused: AtomicBool::new(false),
}
}
pub fn bind(
self,
endpoint: &IpEndpoint,
can_reuse: bool,
) -> core::result::Result<BoundPort, (Error, Self)> {
match self {
InitStream::Unbound => (),
InitStream::Bound(bound_socket) => {
return Err((
Error::with_message(Errno::EINVAL, "the socket is already bound to an address"),
InitStream::Bound(bound_socket),
));
}
};
let bound_port = match bind_port(endpoint, can_reuse) {
Ok(bound_port) => bound_port,
Err(err) => return Err((err, Self::Unbound)),
};
Ok(bound_port)
pub fn new_refused(bound_port: BoundPort) -> Self {
Self {
bound_port: Some(bound_port),
is_connect_done: false,
is_conn_refused: AtomicBool::new(true),
}
}
fn bind_to_ephemeral_endpoint(
self,
remote_endpoint: &IpEndpoint,
) -> core::result::Result<BoundPort, (Error, Self)> {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
self.bind(&endpoint, false)
pub fn bind(&mut self, endpoint: &IpEndpoint, can_reuse: bool) -> Result<()> {
if self.bound_port.is_some() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
}
self.bound_port = Some(bind_port(endpoint, can_reuse)?);
Ok(())
}
pub fn connect(
@ -63,13 +78,49 @@ impl InitStream {
option: &RawTcpOption,
observer: StreamObserver,
) -> core::result::Result<ConnectingStream, (Error, Self)> {
let bound_port = match self {
InitStream::Bound(bound_port) => bound_port,
InitStream::Unbound => self.bind_to_ephemeral_endpoint(remote_endpoint)?,
debug_assert!(
self.is_connect_done,
"`finish_last_connect()` should be called before calling `connect()`"
);
let bound_port = if let Some(bound_port) = self.bound_port {
bound_port
} else {
let endpoint = get_ephemeral_endpoint(remote_endpoint);
match bind_port(&endpoint, false) {
Ok(bound_port) => bound_port,
Err(err) => return Err((err, self)),
}
};
ConnectingStream::new(bound_port, *remote_endpoint, option, observer)
.map_err(|(err, bound_port)| (err, InitStream::Bound(bound_port)))
ConnectingStream::new(bound_port, *remote_endpoint, option, observer).map_err(
|(err, bound_port)| {
if err.error() == Errno::ECONNREFUSED {
(err, InitStream::new_refused(bound_port))
} else {
(err, InitStream::new_bound(bound_port))
}
},
)
}
pub fn finish_last_connect(&mut self) -> Result<()> {
if self.is_connect_done {
return Ok(());
}
self.is_connect_done = true;
let is_conn_refused = self.is_conn_refused.get_mut();
if *is_conn_refused {
*is_conn_refused = false;
return_errno_with_message!(Errno::ECONNREFUSED, "the connection is refused");
} else {
return_errno_with_message!(
Errno::ECONNABORTED,
"the error code for the connection failure is not available"
);
}
}
pub fn listen(
@ -78,10 +129,22 @@ impl InitStream {
option: &RawTcpOption,
observer: StreamObserver,
) -> core::result::Result<ListenStream, (Error, Self)> {
let InitStream::Bound(bound_port) = self else {
if !self.is_connect_done {
// See the comments of `is_connect_done`.
// `listen()` is also not allowed until the second `connect()`.
return Err((
Error::with_message(
Errno::EINVAL,
"the connection is refused, but the connecting phase is not done",
),
self,
));
}
let Some(bound_port) = self.bound_port else {
// FIXME: The socket should be bound to INADDR_ANY (i.e., 0.0.0.0) with an ephemeral
// port. However, INADDR_ANY is not yet supported, so we need to return an error first.
debug_assert!(false, "listen() without bind() is not implemented");
warn!("listen() without bind() is not implemented");
return Err((
Error::with_message(Errno::EINVAL, "listen() without bind() is not implemented"),
self,
@ -90,19 +153,29 @@ impl InitStream {
match ListenStream::new(bound_port, backlog, option, observer) {
Ok(listen_stream) => Ok(listen_stream),
Err((bound_port, error)) => Err((error, Self::Bound(bound_port))),
Err((bound_port, error)) => Err((error, Self::new_bound(bound_port))),
}
}
pub fn local_endpoint(&self) -> Option<IpEndpoint> {
match self {
InitStream::Unbound => None,
InitStream::Bound(bound_port) => Some(bound_port.endpoint().unwrap()),
}
self.bound_port
.as_ref()
.map(|bound_port| bound_port.endpoint().unwrap())
}
pub(super) fn check_io_events(&self) -> IoEvents {
// Linux adds OUT and HUP events for a newly created socket
IoEvents::OUT | IoEvents::HUP
}
pub(super) fn test_and_clear_error(&self) -> Option<Error> {
if self.is_conn_refused.swap(false, Ordering::Relaxed) {
Some(Error::with_message(
Errno::ECONNREFUSED,
"the connection is refused",
))
} else {
None
}
}
}

View File

@ -134,7 +134,7 @@ impl StreamSocket {
/// Ensures that the socket state is up to date and obtains a read lock on it.
///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
/// For a description of what "up-to-date" means, see [`Self::write_updated_state`].
fn read_updated_state(&self) -> RwLockReadGuard<Takeable<State>, PreemptDisabled> {
loop {
let state = self.state.read();
@ -144,39 +144,25 @@ impl StreamSocket {
};
drop(state);
self.update_connecting();
self.write_updated_state();
}
}
/// Ensures that the socket state is up to date and obtains a write lock on it.
///
/// For a description of what "up-to-date" means, see [`Self::update_connecting`].
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>, PreemptDisabled> {
self.update_connecting().1
}
/// Updates the socket state if the socket is an obsolete connecting socket.
///
/// A connecting socket can become obsolete because some network events can set the socket to
/// connected state (if the connection succeeds) or initial state (if the connection is
/// refused) in [`Self::update_io_events`], but the state transition is delayed until the user
/// refused) in [`Self::check_io_events`], but the state transition is delayed until the user
/// operates on the socket to avoid too many locks in the interrupt handler.
///
/// This method performs the delayed state transition to ensure that the state is up to date
/// and returns the guards of the write-locked options and state.
fn update_connecting(
&self,
) -> (
RwLockWriteGuard<OptionSet, PreemptDisabled>,
RwLockWriteGuard<Takeable<State>, PreemptDisabled>,
) {
// Hold the lock in advance to avoid race conditions.
let mut options = self.options.write();
/// and returns the guard of the write-locked state.
fn write_updated_state(&self) -> RwLockWriteGuard<Takeable<State>, PreemptDisabled> {
let mut state = self.state.write();
match state.as_ref() {
State::Connecting(connection_stream) if connection_stream.has_result() => (),
_ => return (options, state),
_ => return state,
}
state.borrow(|owned_state| {
@ -186,33 +172,26 @@ impl StreamSocket {
match connecting_stream.into_result() {
ConnResult::Connecting(connecting_stream) => State::Connecting(connecting_stream),
ConnResult::Connected(connected_stream) => {
options.socket.set_sock_errors(None);
State::Connected(connected_stream)
}
ConnResult::Refused(init_stream) => {
options.socket.set_sock_errors(Some(Error::with_message(
Errno::ECONNREFUSED,
"the connection is refused",
)));
State::Init(init_stream)
}
ConnResult::Connected(connected_stream) => State::Connected(connected_stream),
ConnResult::Refused(init_stream) => State::Init(init_stream),
}
});
(options, state)
state
}
// Returns `None` to block the task and wait for the connection to be established, and returns
// `Some(_)` if blocking is not necessary or not allowed.
fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option<Result<()>> {
let is_nonblocking = self.is_nonblocking();
let (mut options, mut state) = self.update_connecting();
let options = self.options.read();
let raw_option = options.raw();
let mut state = self.write_updated_state();
let (result_or_block, iface_to_poll) = state.borrow_result(|mut owned_state| {
let init_stream = match owned_state {
let mut init_stream = match owned_state {
State::Init(init_stream) => init_stream,
State::Connecting(_) if is_nonblocking => {
return (
@ -228,7 +207,7 @@ impl StreamSocket {
}
State::Connecting(_) => return (owned_state, (None, None)),
State::Connected(ref mut connected_stream) => {
let err = connected_stream.check_new();
let err = connected_stream.finish_last_connect();
return (owned_state, (Some(err), None));
}
State::Listen(_) => {
@ -245,19 +224,25 @@ impl StreamSocket {
}
};
let connecting_stream = match init_stream.connect(
if let Err(err) = init_stream.finish_last_connect() {
return (State::Init(init_stream), (Some(Err(err)), None));
}
let (target_state, iface_to_poll) = match init_stream.connect(
remote_endpoint,
&raw_option,
StreamObserver::new(self.pollee.clone()),
) {
Ok(connecting_stream) => connecting_stream,
Err((mut err, init_stream)) => {
// If the socket is nonblocking, we should return EINPROGRESS instead.
if is_nonblocking {
options.socket.set_sock_errors(Some(err));
err = Error::new(Errno::EINPROGRESS);
}
Ok(connecting_stream) => {
let iface_to_poll = connecting_stream.iface().clone();
(State::Connecting(connecting_stream), Some(iface_to_poll))
}
Err((err, init_stream)) if err.error() == Errno::ECONNREFUSED => {
// `ECONNREFUSED` should be reported asynchronously, i.e., we need to return
// `EINPROGRESS` first for non-blocking sockets.
(State::Init(init_stream), None)
}
Err((err, init_stream)) => {
return (State::Init(init_stream), (Some(Err(err)), None));
}
};
@ -270,12 +255,8 @@ impl StreamSocket {
} else {
None
};
let iface_to_poll = connecting_stream.iface().clone();
(
State::Connecting(connecting_stream),
(result_or_block, Some(iface_to_poll)),
)
(target_state, (result_or_block, iface_to_poll))
});
drop(state);
@ -288,17 +269,16 @@ impl StreamSocket {
}
fn check_connect(&self) -> Result<()> {
let (mut options, mut state) = self.update_connecting();
let mut state = self.write_updated_state();
match state.as_mut() {
State::Init(init_stream) => init_stream.finish_last_connect(),
State::Connecting(_) => {
return_errno_with_message!(Errno::EAGAIN, "the connection is pending")
}
State::Connected(connected_stream) => connected_stream.check_new(),
State::Init(_) | State::Listen(_) => {
let sock_errors = options.socket.sock_errors();
options.socket.set_sock_errors(None);
sock_errors.map(Err).unwrap_or(Ok(()))
State::Connected(connected_stream) => connected_stream.finish_last_connect(),
State::Listen(_) => {
return_errno_with_message!(Errno::EISCONN, "the socket is listening")
}
}
}
@ -392,6 +372,15 @@ impl StreamSocket {
State::Connected(connected_stream) => connected_stream.check_io_events(),
}
}
fn test_and_clear_error(&self) -> Option<Error> {
let state = self.read_updated_state();
match state.as_ref() {
State::Init(init_stream) => init_stream.test_and_clear_error(),
State::Connecting(_) | State::Listen(_) | State::Connected(_) => None,
}
}
}
impl Pollable for StreamSocket {
@ -418,26 +407,10 @@ impl Socket for StreamSocket {
let can_reuse = self.options.read().socket.reuse_addr();
let mut state = self.write_updated_state();
state.borrow_result(|owned_state| {
let State::Init(init_stream) = owned_state else {
return (
owned_state,
Err(Error::with_message(
Errno::EINVAL,
"the socket is already bound to an address",
)),
);
};
let bound_port = match init_stream.bind(&endpoint, can_reuse) {
Ok(bound_port) => bound_port,
Err((err, init_stream)) => {
return (State::Init(init_stream), Err(err));
}
};
(State::Init(InitStream::new_bound(bound_port)), Ok(()))
})
let State::Init(init_stream) = state.as_mut() else {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
};
init_stream.bind(&endpoint, can_reuse)
}
fn connect(&self, socket_addr: SocketAddr) -> Result<()> {
@ -451,10 +424,11 @@ impl Socket for StreamSocket {
}
fn listen(&self, backlog: usize) -> Result<()> {
let (options, mut state) = self.update_connecting();
let options = self.options.read();
let raw_option = options.raw();
let mut state = self.write_updated_state();
state.borrow_result(|owned_state| {
let init_stream = match owned_state {
State::Init(init_stream) => init_stream,
@ -588,8 +562,7 @@ impl Socket for StreamSocket {
fn get_option(&self, option: &mut dyn SocketOption) -> Result<()> {
match_sock_option_mut!(option, {
socket_errors: SocketError => {
let mut options = self.update_connecting().0;
options.socket.get_and_clear_sock_errors(socket_errors);
socket_errors.set(self.test_and_clear_error());
return Ok(());
},
_ => ()
@ -660,7 +633,8 @@ impl Socket for StreamSocket {
}
fn set_option(&self, option: &dyn SocketOption) -> Result<()> {
let (mut options, mut state) = self.update_connecting();
let mut options = self.options.write();
let mut state = self.write_updated_state();
let need_iface_poll = match options.socket.set_option(option, state.as_mut()) {
Err(err) if err.error() == Errno::ENOPROTOOPT => {

View File

@ -9,8 +9,7 @@ use aster_bigtcp::socket::{
use crate::{
match_sock_option_mut, match_sock_option_ref,
net::socket::options::{
Error as SocketError, KeepAlive, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf,
SocketOption,
KeepAlive, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption,
},
prelude::*,
};
@ -19,7 +18,6 @@ use crate::{
#[get_copy = "pub"]
#[set = "pub"]
pub struct SocketOptionSet {
sock_errors: Option<Error>,
reuse_addr: bool,
reuse_port: bool,
send_buf: u32,
@ -32,7 +30,6 @@ impl SocketOptionSet {
/// Return the default socket level options for tcp socket.
pub fn new_tcp() -> Self {
Self {
sock_errors: None,
reuse_addr: false,
reuse_port: false,
send_buf: TCP_SEND_BUF_LEN as u32,
@ -45,7 +42,6 @@ impl SocketOptionSet {
/// Return the default socket level options for udp socket.
pub fn new_udp() -> Self {
Self {
sock_errors: None,
reuse_addr: false,
reuse_port: false,
send_buf: UDP_SEND_PAYLOAD_LEN as u32,
@ -55,15 +51,6 @@ impl SocketOptionSet {
}
}
/// Gets and clears the socket error.
///
/// When processing the `getsockopt` system call, the socket error is automatically cleared
/// after reading. This method should be called to provide this behavior.
pub fn get_and_clear_sock_errors(&mut self, option: &mut SocketError) {
option.set(self.sock_errors());
self.set_sock_errors(None);
}
/// Gets socket-level options.
///
/// Note that the socket error has to be handled separately, because it is automatically

View File

@ -323,31 +323,58 @@ FN_TEST(async_connect)
socklen_t errlen = sizeof(err);
sk_addr.sin_port = 0xbeef;
#define ASYNC_CONNECT \
TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr, \
sizeof(sk_addr)), \
EINPROGRESS); \
TEST_RES(poll(&pfd, 1, 60), pfd.revents &POLLOUT);
ASYNC_CONNECT;
// The second `connect` will fail with `ECONNREFUSED`.
TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr,
sizeof(sk_addr)),
EINPROGRESS);
ECONNREFUSED);
TEST_RES(poll(&pfd, 1, 60), pfd.revents & POLLOUT);
TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen),
errlen == sizeof(err) && err == ECONNREFUSED);
ASYNC_CONNECT;
// Reading the socket error will cause it to be cleared
TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen),
errlen == sizeof(err) && err == ECONNREFUSED);
TEST_RES(getsockopt(sk_bound, SOL_SOCKET, SO_ERROR, &err, &errlen),
errlen == sizeof(err) && err == 0);
// `listen` won't succeed until the second `connect`.
TEST_ERRNO(listen(sk_bound, 10), EINVAL);
// The second `connect` will fail with `ECONNABORTED` if the socket
// error is cleared.
TEST_ERRNO(connect(sk_bound, (struct sockaddr *)&sk_addr,
sizeof(sk_addr)),
ECONNABORTED);
#undef ASYNC_CONNECT
}
END_TEST()
void set_blocking(int sockfd)
static void set_blocking(int sockfd, int is_blocking)
{
int flags = CHECK(fcntl(sockfd, F_GETFL, 0));
CHECK(fcntl(sockfd, F_SETFL, flags & (~O_NONBLOCK)));
if (is_blocking) {
flags &= ~O_NONBLOCK;
} else {
flags |= O_NONBLOCK;
}
CHECK(fcntl(sockfd, F_SETFL, flags));
}
FN_SETUP(enter_blocking_mode)
{
set_blocking(sk_connected);
set_blocking(sk_bound);
set_blocking(sk_connected, 1);
set_blocking(sk_bound, 1);
}
END_SETUP()
@ -536,6 +563,7 @@ FN_TEST(bind_and_connect_same_address)
TEST_SUCC(listen(sk_listen, 3));
// For blocking sockets, conflict addresses result in `EADDRNOTAVAIL`.
sk_addr.sin_port = htons(listen_port);
TEST_SUCC(connect(sk_connect1, (struct sockaddr *)&sk_addr,
sizeof(sk_addr)));
@ -543,8 +571,15 @@ FN_TEST(bind_and_connect_same_address)
sizeof(sk_addr)),
EADDRNOTAVAIL);
// For non-blocking sockets, conflict addresses also result in `EADDRNOTAVAIL`.
// (`EINPROGRESS` should _not_ be returned in this case.)
set_blocking(sk_connect2, 0);
TEST_ERRNO(connect(sk_connect2, (struct sockaddr *)&sk_addr,
sizeof(sk_addr)),
EADDRNOTAVAIL);
TEST_SUCC(close(sk_listen));
TEST_SUCC(close(sk_connect1));
TEST_SUCC(close(sk_connect2));
}
END_TEST()
END_TEST()