diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 375d364f..cecc737b 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -10,12 +10,13 @@ use ostd::sync::{LocalIrqDisabled, RwLock, SpinLock, SpinLockGuard, WriteIrqDisa use smoltcp::{ iface::Context, socket::{tcp::State, udp::UdpMetadata, PollAt}, - time::Instant, + time::{Duration, Instant}, wire::{IpAddress, IpEndpoint, IpRepr, TcpControl, TcpRepr, UdpRepr}, }; use super::{ event::{SocketEventObserver, SocketEvents}, + option::RawTcpSetOption, RawTcpSocket, RawUdpSocket, TcpStateCheck, }; use crate::{ext::Ext, iface::Iface}; @@ -251,6 +252,7 @@ pub enum ConnectState { pub struct NeedIfacePoll(bool); impl NeedIfacePoll { + pub const TRUE: Self = Self(true); pub const FALSE: Self = Self(false); } @@ -371,6 +373,25 @@ impl BoundTcpSocket { } } +impl RawTcpSetOption for BoundTcpSocket { + fn set_keep_alive(&mut self, interval: Option<Duration>) -> NeedIfacePoll { + let mut socket = self.0.socket.lock(); + socket.set_keep_alive(interval); + + if interval.is_some() { + self.0.update_next_poll_at_ms(PollAt::Now); + NeedIfacePoll::TRUE + } else { + NeedIfacePoll::FALSE + } + } + + fn set_nagle_enabled(&mut self, enabled: bool) { + let mut socket = self.0.socket.lock(); + socket.set_nagle_enabled(enabled); + } +} + impl BoundUdpSocket { /// Binds to a specified endpoint. 
/// diff --git a/kernel/libs/aster-bigtcp/src/socket/mod.rs b/kernel/libs/aster-bigtcp/src/socket/mod.rs index 4495afba..29352671 100644 --- a/kernel/libs/aster-bigtcp/src/socket/mod.rs +++ b/kernel/libs/aster-bigtcp/src/socket/mod.rs @@ -2,13 +2,15 @@ mod bound; mod event; +mod option; mod state; mod unbound; pub use bound::{BoundTcpSocket, BoundUdpSocket, ConnectState, NeedIfacePoll}; pub(crate) use bound::{BoundTcpSocketInner, BoundUdpSocketInner, TcpProcessResult}; pub use event::{SocketEventObserver, SocketEvents}; -pub use state::TcpStateCheck; +pub use option::RawTcpSetOption; +pub use state::{TcpState, TcpStateCheck}; pub use unbound::{ UnboundTcpSocket, UnboundUdpSocket, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN, diff --git a/kernel/libs/aster-bigtcp/src/socket/option.rs b/kernel/libs/aster-bigtcp/src/socket/option.rs new file mode 100644 index 00000000..c5d79b0e --- /dev/null +++ b/kernel/libs/aster-bigtcp/src/socket/option.rs @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MPL-2.0 + +use smoltcp::time::Duration; + +use super::NeedIfacePoll; + +/// A trait that defines how to set socket options on a raw socket. +/// +/// TODO: When `UnboundSocket` is removed, all methods in this trait can accept +/// `&self` instead of `&mut self` as parameter. +pub trait RawTcpSetOption { + /// Sets the keep alive interval. + /// + /// Polling the iface _may_ be required after this method succeeds. + fn set_keep_alive(&mut self, interval: Option<Duration>) -> NeedIfacePoll; + + /// Enables or disables Nagle’s Algorithm. + /// + /// Polling the iface is not required after this method succeeds. 
+ fn set_nagle_enabled(&mut self, enabled: bool); +} diff --git a/kernel/libs/aster-bigtcp/src/socket/state.rs b/kernel/libs/aster-bigtcp/src/socket/state.rs index df56c16f..f8bc4148 100644 --- a/kernel/libs/aster-bigtcp/src/socket/state.rs +++ b/kernel/libs/aster-bigtcp/src/socket/state.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MPL-2.0 -use smoltcp::socket::tcp::State as TcpState; +pub use smoltcp::socket::tcp::State as TcpState; use super::RawTcpSocket; diff --git a/kernel/libs/aster-bigtcp/src/socket/unbound.rs b/kernel/libs/aster-bigtcp/src/socket/unbound.rs index f68804e0..b46dcb0e 100644 --- a/kernel/libs/aster-bigtcp/src/socket/unbound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/unbound.rs @@ -2,7 +2,7 @@ use alloc::{boxed::Box, vec}; -use super::{RawTcpSocket, RawUdpSocket}; +use super::{option::RawTcpSetOption, NeedIfacePoll, RawTcpSocket, RawUdpSocket}; pub struct UnboundSocket<T> { socket: Box<T>, @@ -30,6 +30,17 @@ impl Default for UnboundTcpSocket { } } +impl RawTcpSetOption for UnboundTcpSocket { + fn set_keep_alive(&mut self, interval: Option<smoltcp::time::Duration>) -> NeedIfacePoll { + self.socket.set_keep_alive(interval); + NeedIfacePoll::FALSE + } + + fn set_nagle_enabled(&mut self, enabled: bool) { + self.socket.set_nagle_enabled(enabled); + } +} + impl UnboundUdpSocket { pub fn new() -> Self { let raw_udp_socket = { diff --git a/kernel/libs/aster-bigtcp/src/time.rs b/kernel/libs/aster-bigtcp/src/time.rs index b5ea5295..920eb213 100644 --- a/kernel/libs/aster-bigtcp/src/time.rs +++ b/kernel/libs/aster-bigtcp/src/time.rs @@ -1,3 +1,3 @@ // SPDX-License-Identifier: MPL-2.0 -pub use smoltcp::time::Instant; +pub use smoltcp::time::{Duration, Instant}; diff --git a/kernel/src/net/socket/ip/datagram/mod.rs b/kernel/src/net/socket/ip/datagram/mod.rs index 18b82797..a3b709a2 100644 --- a/kernel/src/net/socket/ip/datagram/mod.rs +++ b/kernel/src/net/socket/ip/datagram/mod.rs @@ -18,7 +18,9 @@ use crate::{ net::socket::{ options::{Error as SocketError, SocketOption}, util::{ - 
options::SocketOptionSet, send_recv_flags::SendRecvFlags, socket_addr::SocketAddr, + options::{SetSocketLevelOption, SocketOptionSet}, + send_recv_flags::SendRecvFlags, + socket_addr::SocketAddr, MessageHeader, }, Socket, @@ -393,6 +395,30 @@ impl Socket for DatagramSocket { } fn set_option(&self, option: &dyn SocketOption) -> Result<()> { - self.options.write().socket.set_option(option) + let mut options = self.options.write(); + let mut inner = self.inner.write(); + + match options.socket.set_option(option, inner.as_mut()) { + Err(e) => Err(e), + Ok(need_iface_poll) => { + let iface_to_poll = need_iface_poll + .then(|| match inner.as_ref() { + Inner::Unbound(_) => None, + Inner::Bound(bound_datagram) => Some(bound_datagram.iface().clone()), + }) + .flatten(); + + drop(inner); + drop(options); + + if let Some(iface) = iface_to_poll { + iface.poll(); + } + + Ok(()) + } + } } } + +impl SetSocketLevelOption for Inner {} diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index 12395b1d..bb673b2a 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -4,7 +4,7 @@ use core::sync::atomic::{AtomicBool, Ordering}; use aster_bigtcp::{ errors::tcp::{RecvError, SendError}, - socket::{NeedIfacePoll, TcpStateCheck}, + socket::{NeedIfacePoll, RawTcpSetOption, RawTcpSocket, TcpStateCheck}, wire::IpEndpoint, }; @@ -205,4 +205,15 @@ impl ConnectedStream { pub(super) fn set_observer(&self, observer: StreamObserver) { self.bound_socket.set_observer(observer) } + + pub(super) fn set_raw_option<R>( + &mut self, + set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, + ) -> R { + set_option(&mut self.bound_socket) + } + + pub(super) fn raw_with<R>(&self, f: impl FnOnce(&RawTcpSocket) -> R) -> R { + self.bound_socket.raw_with(f) + } } diff --git a/kernel/src/net/socket/ip/stream/connecting.rs b/kernel/src/net/socket/ip/stream/connecting.rs index c6b1eba7..e73d1fd0 100644 --- 
a/kernel/src/net/socket/ip/stream/connecting.rs +++ b/kernel/src/net/socket/ip/stream/connecting.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_bigtcp::{socket::ConnectState, wire::IpEndpoint}; +use aster_bigtcp::{ + socket::{ConnectState, RawTcpSetOption}, + wire::IpEndpoint, +}; use super::{connected::ConnectedStream, init::InitStream}; use crate::{ @@ -83,4 +86,11 @@ impl ConnectingStream { pub(super) fn check_io_events(&self) -> IoEvents { IoEvents::empty() } + + pub(super) fn set_raw_option<R>( + &mut self, + set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, + ) -> R { + set_option(&mut self.bound_socket) + } } diff --git a/kernel/src/net/socket/ip/stream/init.rs b/kernel/src/net/socket/ip/stream/init.rs index 0597dc98..17ebdeb0 100644 --- a/kernel/src/net/socket/ip/stream/init.rs +++ b/kernel/src/net/socket/ip/stream/init.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: MPL-2.0 -use aster_bigtcp::{socket::UnboundTcpSocket, wire::IpEndpoint}; +use aster_bigtcp::{ + socket::{RawTcpSetOption, UnboundTcpSocket}, + wire::IpEndpoint, +}; use super::{connecting::ConnectingStream, listen::ListenStream, StreamObserver}; use crate::{ @@ -108,4 +111,14 @@ impl InitStream { // Linux adds OUT and HUP events for a newly created socket IoEvents::OUT | IoEvents::HUP } + + pub(super) fn set_raw_option<R>( + &mut self, + set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, + ) -> R { + match self { + InitStream::Unbound(unbound_socket) => set_option(unbound_socket.as_mut()), + InitStream::Bound(bound_socket) => set_option(bound_socket), + } + } } diff --git a/kernel/src/net/socket/ip/stream/listen.rs b/kernel/src/net/socket/ip/stream/listen.rs index d83c7a69..514b37fe 100644 --- a/kernel/src/net/socket/ip/stream/listen.rs +++ b/kernel/src/net/socket/ip/stream/listen.rs @@ -1,7 +1,10 @@ // SPDX-License-Identifier: MPL-2.0 use aster_bigtcp::{ - errors::tcp::ListenError, iface::BindPortConfig, socket::UnboundTcpSocket, wire::IpEndpoint, + errors::tcp::ListenError, + 
iface::BindPortConfig, + socket::{RawTcpSetOption, TcpState, UnboundTcpSocket}, + wire::IpEndpoint, }; use ostd::sync::PreemptDisabled; @@ -104,6 +107,30 @@ impl ListenStream { IoEvents::empty() } } + + /// Calls `set_option` to set a socket option on the raw sockets. + /// + /// This method will call `set_option` on the bound socket and each backlog socket that is in the `Listen` state. + pub(super) fn set_raw_option<R>( + &mut self, + set_option: impl Fn(&mut dyn RawTcpSetOption) -> R, + ) -> R { + self.backlog_sockets.write().iter_mut().for_each(|socket| { + if socket + .bound_socket + .raw_with(|raw_tcp_socket| raw_tcp_socket.state() != TcpState::Listen) + { + return; + } + + // If the socket receives SYN after above check, + // we will also set keep alive on the socket that is not in `Listen` state. + // But such a race doesn't matter, we just let it happen. + set_option(&mut socket.bound_socket); + }); + + set_option(&mut self.bound_socket) + } } struct BacklogSocket { @@ -119,7 +146,15 @@ impl BacklogSocket { "the socket is not bound", ))?; - let unbound_socket = Box::new(UnboundTcpSocket::new()); + let unbound_socket = { + let mut unbound = UnboundTcpSocket::new(); + unbound.set_keep_alive(bound_socket.raw_with(|socket| socket.keep_alive())); + unbound.set_nagle_enabled(bound_socket.raw_with(|socket| socket.nagle_enabled())); + + // TODO: Inherit other options that can be set via `setsockopt` from bound socket + + Box::new(unbound) + }; let bound_socket = { let iface = bound_socket.iface(); let bind_port_config = BindPortConfig::new(local_endpoint.port, true); diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 7b8995fa..0f5ade33 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -2,7 +2,10 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use aster_bigtcp::wire::IpEndpoint; +use aster_bigtcp::{ + socket::{NeedIfacePoll, RawTcpSetOption}, + wire::IpEndpoint, +}; use 
connected::ConnectedStream; use connecting::{ConnResult, ConnectingStream}; use init::InitStream; @@ -20,13 +23,19 @@ use crate::{ utils::{InodeMode, Metadata, StatusFlags}, }, match_sock_option_mut, match_sock_option_ref, - net::socket::{ - options::{Error as SocketError, SocketOption}, - util::{ - options::SocketOptionSet, send_recv_flags::SendRecvFlags, - shutdown_cmd::SockShutdownCmd, socket_addr::SocketAddr, MessageHeader, + net::{ + iface::Iface, + socket::{ + options::{Error as SocketError, SocketOption}, + util::{ + options::{SetSocketLevelOption, SocketOptionSet}, + send_recv_flags::SendRecvFlags, + shutdown_cmd::SockShutdownCmd, + socket_addr::SocketAddr, + MessageHeader, + }, + Socket, }, - Socket, }, prelude::*, process::signal::{PollHandle, Pollable, Pollee}, @@ -87,11 +96,28 @@ impl StreamSocket { }) } - fn new_connected(connected_stream: ConnectedStream) -> Arc<Self> { + fn new_accepted(connected_stream: ConnectedStream) -> Arc<Self> { + let options = connected_stream.raw_with(|raw_tcp_socket| { + let mut options = OptionSet::new(); + + if raw_tcp_socket.keep_alive().is_some() { + options.socket.set_keep_alive(true); + } + + if !raw_tcp_socket.nagle_enabled() { + options.tcp.set_no_delay(true); + } + + // TODO: Update other options for a newly-accepted socket + + options + }); + let pollee = Pollee::new(); connected_stream.set_observer(StreamObserver::new(pollee.clone())); + Arc::new(Self { - options: RwLock::new(OptionSet::new()), + options: RwLock::new(options), state: RwLock::new(Takeable::new(State::Connected(connected_stream))), is_nonblocking: AtomicBool::new(false), pollee, @@ -276,7 +302,7 @@ impl StreamSocket { .try_accept(&self.pollee) .map(|connected_stream| { let remote_endpoint = connected_stream.remote_endpoint(); - let accepted_socket = Self::new_connected(connected_stream); + let accepted_socket = Self::new_accepted(connected_stream); (accepted_socket as _, remote_endpoint.into()) }); let iface_to_poll = listen_stream.iface().clone(); @@ -650,11 
+676,23 @@ impl Socket for StreamSocket { } fn set_option(&self, option: &dyn SocketOption) -> Result<()> { - let mut options = self.options.write(); + let (mut options, mut state) = self.update_connecting(); - match options.socket.set_option(option) { + match options.socket.set_option(option, state.as_mut()) { Err(err) if err.error() == Errno::ENOPROTOOPT => (), - res => return res, + Err(err) => return Err(err), + Ok(need_iface_poll) => { + let iface_to_poll = need_iface_poll.then(|| state.iface().cloned()).flatten(); + + drop(state); + drop(options); + + if let Some(iface) = iface_to_poll { + iface.poll(); + } + + return Ok(()); + } } // FIXME: Here we have only set the value of the option, without actually @@ -663,6 +701,7 @@ impl Socket for StreamSocket { tcp_no_delay: NoDelay => { let no_delay = tcp_no_delay.get().unwrap(); options.tcp.set_no_delay(*no_delay); + state.set_raw_option(|raw_socket: &mut dyn RawTcpSetOption| raw_socket.set_nagle_enabled(!no_delay)); }, tcp_congestion: Congestion => { let congestion = tcp_congestion.get().unwrap(); @@ -694,18 +733,62 @@ impl Socket for StreamSocket { } } +impl State { + /// Calls `f` to set raw socket option. + /// + /// Note that for listening socket, `f` is called on all backlog sockets in `Listen` State. + /// That is to say, `f` won't be called on backlog sockets in `SynReceived` or `Established` state. 
+ fn set_raw_option<R>(&mut self, set_option: impl Fn(&mut dyn RawTcpSetOption) -> R) -> R { + match self { + State::Init(init_stream) => init_stream.set_raw_option(set_option), + State::Connecting(connecting_stream) => connecting_stream.set_raw_option(set_option), + State::Connected(connected_stream) => connected_stream.set_raw_option(set_option), + State::Listen(listen_stream) => listen_stream.set_raw_option(set_option), + } + } + + fn iface(&self) -> Option<&Arc<Iface>> { + match self { + State::Init(_) => None, + State::Connecting(ref connecting_stream) => Some(connecting_stream.iface()), + State::Connected(ref connected_stream) => Some(connected_stream.iface()), + State::Listen(ref listen_stream) => Some(listen_stream.iface()), + } + } +} + +impl SetSocketLevelOption for State { + fn set_keep_alive(&mut self, keep_alive: bool) -> NeedIfacePoll { + /// The keepalive interval. + /// + /// The Linux value can be found at `/proc/sys/net/ipv4/tcp_keepalive_intvl`, + /// which is by default 75 seconds for most Linux distributions. + const KEEPALIVE_INTERVAL: aster_bigtcp::time::Duration = + aster_bigtcp::time::Duration::from_secs(75); + + let interval = if keep_alive { + Some(KEEPALIVE_INTERVAL) + } else { + None + }; + + let set_keepalive = + |raw_socket: &mut dyn RawTcpSetOption| raw_socket.set_keep_alive(interval); + + self.set_raw_option(set_keepalive) + } +} + impl Drop for StreamSocket { fn drop(&mut self) { let state = self.state.get_mut().take(); - let iface_to_poll = match state { - State::Init(_) => None, - State::Connecting(ref connecting_stream) => Some(connecting_stream.iface().clone()), - State::Connected(ref connected_stream) => Some(connected_stream.iface().clone()), - State::Listen(ref listen_stream) => Some(listen_stream.iface().clone()), - }; + let iface_to_poll = state.iface().cloned(); + // Dropping the state will drop the sockets. This will trigger the socket close process (if + // needed) and require immediate iface polling afterwards. 
drop(state); + if let Some(iface) = iface_to_poll { iface.poll(); } diff --git a/kernel/src/net/socket/util/options.rs b/kernel/src/net/socket/util/options.rs index be3a0541..2015c7a6 100644 --- a/kernel/src/net/socket/util/options.rs +++ b/kernel/src/net/socket/util/options.rs @@ -3,13 +3,14 @@ use core::time::Duration; use aster_bigtcp::socket::{ - TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN, + NeedIfacePoll, TCP_RECV_BUF_LEN, TCP_SEND_BUF_LEN, UDP_RECV_PAYLOAD_LEN, UDP_SEND_PAYLOAD_LEN, }; use crate::{ match_sock_option_mut, match_sock_option_ref, net::socket::options::{ - Error as SocketError, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, SocketOption, + Error as SocketError, KeepAlive, Linger, RecvBuf, ReuseAddr, ReusePort, SendBuf, + SocketOption, }, prelude::*, }; @@ -24,6 +25,7 @@ pub struct SocketOptionSet { send_buf: u32, recv_buf: u32, linger: LingerOption, + keep_alive: bool, } impl SocketOptionSet { @@ -36,6 +38,7 @@ impl SocketOptionSet { send_buf: TCP_SEND_BUF_LEN as u32, recv_buf: TCP_RECV_BUF_LEN as u32, linger: LingerOption::default(), + keep_alive: false, } } @@ -48,6 +51,7 @@ impl SocketOptionSet { send_buf: UDP_SEND_PAYLOAD_LEN as u32, recv_buf: UDP_RECV_PAYLOAD_LEN as u32, linger: LingerOption::default(), + keep_alive: false, } } @@ -87,13 +91,21 @@ impl SocketOptionSet { let linger = self.linger(); socket_linger.set(linger); }, + socket_keepalive: KeepAlive => { + let keep_alive = self.keep_alive(); + socket_keepalive.set(keep_alive); + }, _ => return_errno_with_message!(Errno::ENOPROTOOPT, "the socket option to get is unknown") }); Ok(()) } /// Sets socket-level options. 
- pub fn set_option(&mut self, option: &dyn SocketOption) -> Result<()> { + pub fn set_option( + &mut self, + option: &dyn SocketOption, + socket: &mut dyn SetSocketLevelOption, + ) -> Result<NeedIfacePoll> { match_sock_option_ref!(option, { socket_recv_buf: RecvBuf => { let recv_buf = socket_recv_buf.get().unwrap(); @@ -123,10 +135,15 @@ impl SocketOptionSet { let linger = socket_linger.get().unwrap(); self.set_linger(*linger); }, + socket_keepalive: KeepAlive => { + let keep_alive = socket_keepalive.get().unwrap(); + self.set_keep_alive(*keep_alive); + return Ok(socket.set_keep_alive(*keep_alive)); + }, _ => return_errno_with_message!(Errno::ENOPROTOOPT, "the socket option to be set is unknown") }); - Ok(()) + Ok(NeedIfacePoll::FALSE) } } @@ -152,3 +169,11 @@ impl LingerOption { self.timeout } } + +/// A trait used for setting socket-level options on actual sockets. +pub(in crate::net) trait SetSocketLevelOption { + /// Sets whether keepalive messages are enabled. + fn set_keep_alive(&mut self, _keep_alive: bool) -> NeedIfacePoll { + NeedIfacePoll::FALSE + } +}