diff --git a/.typos.toml b/.typos.toml index cebc0ca3b..a605e90f5 100644 --- a/.typos.toml +++ b/.typos.toml @@ -16,6 +16,7 @@ Fo = "Fo" Inh = "Inh" DELET = "DELET" wrk = "wrk" +rto = "rto" # Files with svg suffix are ignored to check. [type.svg] diff --git a/kernel/libs/aster-bigtcp/src/socket/bound.rs b/kernel/libs/aster-bigtcp/src/socket/bound.rs index 2351c29c3..db1092e79 100644 --- a/kernel/libs/aster-bigtcp/src/socket/bound.rs +++ b/kernel/libs/aster-bigtcp/src/socket/bound.rs @@ -453,13 +453,22 @@ impl TcpConnection { /// Closes the connection. /// + /// This method returns `false` if the socket is closed _before_ calling this method. + /// /// Polling the iface is _always_ required after this method succeeds. - pub fn close(&self) { + pub fn close(&self) -> bool { let mut socket = self.0.inner.lock(); socket.listener = None; + + if socket.is_closed() { + return false; + } + socket.close(); self.0.update_next_poll_at_ms(PollAt::Now); + + true } /// Calls `f` with an immutable reference to the associated [`RawTcpSocket`]. diff --git a/kernel/src/net/socket/ip/stream/connected.rs b/kernel/src/net/socket/ip/stream/connected.rs index 71f705882..2b9932bc6 100644 --- a/kernel/src/net/socket/ip/stream/connected.rs +++ b/kernel/src/net/socket/ip/stream/connected.rs @@ -70,7 +70,9 @@ impl ConnectedStream { if cmd.shut_write() { self.is_sending_closed.store(true, Ordering::Relaxed); - self.tcp_conn.close(); + if !self.tcp_conn.close() { + return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"); + } events |= IoEvents::OUT | IoEvents::HUP; } @@ -116,6 +118,10 @@ impl ConnectedStream { reader: &mut dyn MultiRead, _flags: SendRecvFlags, ) -> Result<(usize, NeedIfacePoll)> { + if reader.is_empty() { + return Ok((0, NeedIfacePoll::FALSE)); + } + let result = self.tcp_conn.send(|socket_buffer| { match reader.read(&mut VmWriter::from(socket_buffer)) { Ok(len) => (len, Ok(len)), diff --git a/kernel/src/net/socket/ip/stream/mod.rs b/kernel/src/net/socket/ip/stream/mod.rs index 956dd03e5..abfd3e2a4 100644 --- a/kernel/src/net/socket/ip/stream/mod.rs +++ b/kernel/src/net/socket/ip/stream/mod.rs @@ -10,10 +10,13 @@ use connected::ConnectedStream; use connecting::{ConnResult, ConnectingStream}; use init::InitStream; use listen::ListenStream; -use options::{Congestion, KeepIdle, MaxSegment, NoDelay, WindowClamp, KEEPALIVE_INTERVAL}; +use options::{ + Congestion, DeferAccept, Inq, KeepIdle, MaxSegment, NoDelay, SynCnt, UserTimeout, WindowClamp, + KEEPALIVE_INTERVAL, +}; use ostd::sync::{PreemptDisabled, RwLockReadGuard, RwLockWriteGuard}; use takeable::Takeable; -use util::TcpOptionSet; +use util::{Retrans, TcpOptionSet}; use super::UNSPECIFIED_LOCAL_ENDPOINT; use crate::{ @@ -204,7 +207,7 @@ impl StreamSocket { // `Some(_)` if blocking is not necessary or not allowed. fn start_connect(&self, remote_endpoint: &IpEndpoint) -> Option> { let is_nonblocking = self.is_nonblocking(); - let (options, mut state) = self.update_connecting(); + let (mut options, mut state) = self.update_connecting(); let raw_option = options.raw(); @@ -248,7 +251,13 @@ impl StreamSocket { StreamObserver::new(self.pollee.clone()), ) { Ok(connecting_stream) => connecting_stream, - Err((err, init_stream)) => { + Err((mut err, init_stream)) => { + // If the socket is nonblocking, we should return EINPROGRESS instead. + if is_nonblocking { + options.socket.set_sock_errors(Some(err)); + err = Error::new(Errno::EINPROGRESS); + } + return (State::Init(init_stream), (Some(Err(err)), None)); } }; @@ -601,13 +610,33 @@ impl Socket for StreamSocket { tcp_no_delay.set(no_delay); }, tcp_maxseg: MaxSegment => { + const DEFAULT_MAX_SEGMEMT: u32 = 536; + // For an unconnected socket, + // older Linux versions (e.g., v6.0) return + // the default MSS value defined above. + // However, newer Linux versions (e.g., v6.11) + // return the user-set MSS value if it is set. + // Here, we adopt the behavior of the latest Linux versions. let maxseg = options.tcp.maxseg(); - tcp_maxseg.set(maxseg); + if maxseg == 0 { + tcp_maxseg.set(DEFAULT_MAX_SEGMEMT); + } else { + tcp_maxseg.set(maxseg); + } }, tcp_keep_idle: KeepIdle => { let keep_idle = options.tcp.keep_idle(); tcp_keep_idle.set(keep_idle); }, + tcp_syn_cnt: SynCnt => { + let syn_cnt = options.tcp.syn_cnt(); + tcp_syn_cnt.set(syn_cnt); + }, + tcp_defer_accept: DeferAccept => { + let defer_accept = options.tcp.defer_accept(); + let seconds = defer_accept.to_secs(); + tcp_defer_accept.set(seconds); + }, tcp_window_clamp: WindowClamp => { let window_clamp = options.tcp.window_clamp(); tcp_window_clamp.set(window_clamp); @@ -616,6 +645,14 @@ impl Socket for StreamSocket { let congestion = options.tcp.congestion(); tcp_congestion.set(congestion); }, + tcp_user_timeout: UserTimeout => { + let user_timeout = options.tcp.user_timeout(); + tcp_user_timeout.set(user_timeout); + }, + tcp_inq: Inq => { + let inq = options.tcp.receive_inq(); + tcp_inq.set(inq); + }, _ => return_errno_with_message!(Errno::ENOPROTOOPT, "the socket option to get is unknown") }); @@ -679,6 +716,23 @@ fn do_tcp_setsockopt( // TODO: Track when the socket becomes idle to actually support keep idle. }, + tcp_syn_cnt: SynCnt => { + const MAX_TCP_SYN_CNT: u8 = 127; + + let syncnt = tcp_syn_cnt.get().unwrap(); + if *syncnt < 1 || *syncnt > MAX_TCP_SYN_CNT { + return_errno_with_message!(Errno::EINVAL, "the SYN count is out of bounds"); + } + options.tcp.set_syn_cnt(*syncnt); + }, + tcp_defer_accept: DeferAccept => { + let mut seconds = *(tcp_defer_accept.get().unwrap()); + if (seconds as i32) < 0 { + seconds = 0; + } + let retrans = Retrans::from_secs(seconds); + options.tcp.set_defer_accept(retrans); + }, tcp_window_clamp: WindowClamp => { let window_clamp = tcp_window_clamp.get().unwrap(); let half_recv_buf = options.socket.recv_buf() / 2; @@ -692,6 +746,17 @@ fn do_tcp_setsockopt( let congestion = tcp_congestion.get().unwrap(); options.tcp.set_congestion(*congestion); }, + tcp_user_timeout: UserTimeout => { + let user_timeout = tcp_user_timeout.get().unwrap(); + if (*user_timeout as i32) < 0 { + return_errno_with_message!(Errno::EINVAL, "the user timeout cannot be negative"); + } + options.tcp.set_user_timeout(*user_timeout); + }, + tcp_inq: Inq => { + let inq = tcp_inq.get().unwrap(); + options.tcp.set_receive_inq(*inq); + }, _ => return_errno_with_message!(Errno::ENOPROTOOPT, "the socket option to be set is unknown") }); diff --git a/kernel/src/net/socket/ip/stream/options.rs b/kernel/src/net/socket/ip/stream/options.rs index 630df9fe4..f3f3a3f4e 100644 --- a/kernel/src/net/socket/ip/stream/options.rs +++ b/kernel/src/net/socket/ip/stream/options.rs @@ -7,8 +7,12 @@ impl_socket_options!( pub struct NoDelay(bool); pub struct MaxSegment(u32); pub struct KeepIdle(u32); + pub struct SynCnt(u8); + pub struct DeferAccept(u32); pub struct WindowClamp(u32); pub struct Congestion(CongestionControl); + pub struct UserTimeout(u32); + pub struct Inq(bool); ); /// The keepalive interval. diff --git a/kernel/src/net/socket/ip/stream/util.rs b/kernel/src/net/socket/ip/stream/util.rs index 06acc6aae..47ebc21f3 100644 --- a/kernel/src/net/socket/ip/stream/util.rs +++ b/kernel/src/net/socket/ip/stream/util.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: MPL-2.0 +use aster_bigtcp::time::Duration; + use crate::prelude::*; #[derive(Debug, Clone, Copy, CopyGetters, Setters)] @@ -9,12 +11,17 @@ pub struct TcpOptionSet { no_delay: bool, maxseg: u32, keep_idle: u32, + syn_cnt: u8, + defer_accept: Retrans, window_clamp: u32, congestion: CongestionControl, + user_timeout: u32, + receive_inq: bool, } pub const DEFAULT_MAXSEG: u32 = 536; pub const DEFAULT_KEEP_IDLE: u32 = 7200; +pub const DEFAULT_SYN_CNT: u8 = 6; pub const DEFAULT_WINDOW_CLAMP: u32 = 0x8000_0000; impl TcpOptionSet { @@ -23,8 +30,12 @@ impl TcpOptionSet { no_delay: false, maxseg: DEFAULT_MAXSEG, keep_idle: DEFAULT_KEEP_IDLE, + syn_cnt: DEFAULT_SYN_CNT, + defer_accept: Retrans(0), window_clamp: DEFAULT_WINDOW_CLAMP, congestion: CongestionControl::Reno, + user_timeout: 0, + receive_inq: false, } } } @@ -35,6 +46,63 @@ impl Default for TcpOptionSet { } } +/// Initial RTO value +const TCP_TIMEOUT_INIT: Duration = Duration::from_secs(1); +const TCP_RTO_MAX: Duration = Duration::from_secs(120); + +/// The number of retransmits. +#[derive(Debug, Clone, Copy)] +pub struct Retrans(u8); + +impl Retrans { + /// Converts seconds to retransmits. + pub const fn from_secs(seconds: u32) -> Self { + if seconds == 0 { + return Self(0); + } + + let mut timeout = TCP_TIMEOUT_INIT.secs() as u32; + let rto_max = TCP_RTO_MAX.secs() as u32; + let mut period = timeout; + let mut res = 1; + + while seconds > period && res < 255 { + res += 1; + timeout <<= 1; + if timeout > rto_max { + timeout = rto_max; + } + period += timeout; + } + + Self(res) + } + + /// Converts retransmits to seconds. + pub const fn to_secs(self) -> u32 { + let mut retrans = self.0; + + if retrans == 0 { + return 0; + } + + let mut timeout = TCP_TIMEOUT_INIT.secs() as u32; + let rto_max = TCP_RTO_MAX.secs() as u32; + let mut period = timeout; + + while retrans > 1 { + retrans -= 1; + timeout <<= 1; + if timeout > rto_max { + timeout = rto_max; + } + period += timeout; + } + + period + } +} + #[derive(Debug, Clone, Copy)] pub enum CongestionControl { Reno, @@ -49,7 +117,7 @@ impl CongestionControl { let congestion = match name { Self::RENO => Self::Reno, Self::CUBIC => Self::Cubic, - _ => return_errno_with_message!(Errno::EINVAL, "unsupported congestion name"), + _ => return_errno_with_message!(Errno::ENOENT, "unsupported congestion name"), }; Ok(congestion) diff --git a/kernel/src/util/net/options/tcp.rs b/kernel/src/util/net/options/tcp.rs index 9c0b6491c..75a24c143 100644 --- a/kernel/src/util/net/options/tcp.rs +++ b/kernel/src/util/net/options/tcp.rs @@ -3,7 +3,10 @@ use super::RawSocketOption; use crate::{ impl_raw_socket_option, - net::socket::ip::stream::options::{Congestion, KeepIdle, MaxSegment, NoDelay, WindowClamp}, + net::socket::ip::stream::options::{ + Congestion, DeferAccept, Inq, KeepIdle, MaxSegment, NoDelay, SynCnt, UserTimeout, + WindowClamp, + }, prelude::*, util::net::options::SocketOption, }; @@ -16,13 +19,28 @@ use crate::{ #[expect(non_camel_case_types)] #[expect(clippy::upper_case_acronyms)] pub enum CTcpOptionName { - NODELAY = 1, /* Turn off Nagle's algorithm. */ - MAXSEG = 2, /* Limit MSS */ - CORK = 3, /* Never send partially complete segments */ - KEEPIDLE = 4, /* Start keeplives after this period */ - KEEPALIVE = 5, /* Interval between keepalives */ - WINDOW_CLAMP = 10, /* Bound advertised window */ - CONGESTION = 13, /* Congestion control algorithm */ + /// Turn off Nagle's algorithm + NODELAY = 1, + /// Limit MSS + MAXSEG = 2, + /// Never send partially complete segments + CORK = 3, + /// Start keeplives after this period + KEEPIDLE = 4, + /// Interval between keepalives + KEEPALIVE = 5, + /// Number of SYN retransmits + SYNCNT = 7, + /// Wake up listener only when data arriv + DEFER_ACCEPT = 9, + /// Bound advertised window + WINDOW_CLAMP = 10, + /// Congestion control algorithm + CONGESTION = 13, + /// How long for loss retry before timeout + USER_TIMEOUT = 18, + /// Notify bytes available to read as a cmsg on read + INQ = 36, } pub fn new_tcp_option(name: i32) -> Result> { @@ -31,8 +49,12 @@ pub fn new_tcp_option(name: i32) -> Result> { CTcpOptionName::NODELAY => Ok(Box::new(NoDelay::new())), CTcpOptionName::MAXSEG => Ok(Box::new(MaxSegment::new())), CTcpOptionName::KEEPIDLE => Ok(Box::new(KeepIdle::new())), + CTcpOptionName::SYNCNT => Ok(Box::new(SynCnt::new())), + CTcpOptionName::DEFER_ACCEPT => Ok(Box::new(DeferAccept::new())), CTcpOptionName::WINDOW_CLAMP => Ok(Box::new(WindowClamp::new())), CTcpOptionName::CONGESTION => Ok(Box::new(Congestion::new())), + CTcpOptionName::USER_TIMEOUT => Ok(Box::new(UserTimeout::new())), + CTcpOptionName::INQ => Ok(Box::new(Inq::new())), _ => return_errno_with_message!(Errno::ENOPROTOOPT, "unsupported tcp-level option"), } } @@ -40,5 +62,9 @@ pub fn new_tcp_option(name: i32) -> Result> { impl_raw_socket_option!(NoDelay); impl_raw_socket_option!(MaxSegment); impl_raw_socket_option!(KeepIdle); +impl_raw_socket_option!(SynCnt); +impl_raw_socket_option!(DeferAccept); impl_raw_socket_option!(WindowClamp); impl_raw_socket_option!(Congestion); +impl_raw_socket_option!(UserTimeout); +impl_raw_socket_option!(Inq); diff --git a/kernel/src/util/net/options/utils.rs b/kernel/src/util/net/options/utils.rs index 7a6a5e9d4..ab6231417 100644 --- a/kernel/src/util/net/options/utils.rs +++ b/kernel/src/util/net/options/utils.rs @@ -32,15 +32,8 @@ pub trait WriteToUser { fn write_to_user(&self, addr: Vaddr, max_len: u32) -> Result; } -/// This macro is used to implement `ReadFromUser` and `WriteToUser` for types that -/// implement the `Pod` trait. -/// FIXME: The macro is somewhat ugly. Ideally, we would prefer to use -/// ```rust -/// impl ReadFromUser for T -/// ``` -/// instead of this macro. However, using the `impl` statement will result in a compilation -/// error, as it is possible for an upstream crate to implement `Pod` for other types like `bool`, -macro_rules! impl_read_write_for_pod_type { +/// This macro is used to implement `ReadFromUser` and `WriteToUser` for u32 and i32. +macro_rules! impl_read_write_for_32bit_type { ($pod_ty: ty) => { impl ReadFromUser for $pod_ty { fn read_from_user(addr: Vaddr, max_len: u32) -> Result { @@ -66,31 +59,38 @@ macro_rules! impl_read_write_for_pod_type { }; } -impl_read_write_for_pod_type!(u32); +impl_read_write_for_32bit_type!(i32); +impl_read_write_for_32bit_type!(u32); impl ReadFromUser for bool { fn read_from_user(addr: Vaddr, max_len: u32) -> Result { - if (max_len as usize) < core::mem::size_of::() { - return_errno_with_message!(Errno::EINVAL, "max_len is too short"); - } - - let val = current_userspace!().read_val::(addr)?; - + let val = i32::read_from_user(addr, max_len)?; Ok(val != 0) } } impl WriteToUser for bool { fn write_to_user(&self, addr: Vaddr, max_len: u32) -> Result { - let write_len = core::mem::size_of::(); + let val = if *self { 1i32 } else { 0i32 }; + val.write_to_user(addr, max_len) + } +} - if (max_len as usize) < write_len { - return_errno_with_message!(Errno::EINVAL, "max_len is too short"); +impl ReadFromUser for u8 { + fn read_from_user(addr: Vaddr, max_len: u32) -> Result { + let val = i32::read_from_user(addr, max_len)?; + + if val < 0 || val > u8::MAX as i32 { + return_errno_with_message!(Errno::EINVAL, "invalid u8 value"); } - let val = if *self { 1i32 } else { 0i32 }; - current_userspace!().write_val(addr, &val)?; - Ok(write_len) + Ok(val as u8) + } +} + +impl WriteToUser for u8 { + fn write_to_user(&self, addr: Vaddr, max_len: u32) -> Result { + (*self as i32).write_to_user(addr, max_len) } } @@ -138,25 +138,40 @@ impl WriteToUser for LingerOption { } } +const TCP_CONGESTION_NAME_MAX: u32 = 16; + impl ReadFromUser for CongestionControl { fn read_from_user(addr: Vaddr, max_len: u32) -> Result { - let mut bytes = vec![0; max_len as usize]; - current_userspace!().read_bytes(addr, &mut VmWriter::from(bytes.as_mut_slice()))?; - let name = String::from_utf8(bytes).unwrap(); - CongestionControl::new(&name) + let mut bytes = [0; TCP_CONGESTION_NAME_MAX as usize]; + + let dst = { + let read_len = (TCP_CONGESTION_NAME_MAX - 1).min(max_len) as usize; + &mut bytes[..read_len] + }; + + // Clippy warns that `dst.as_mut` is redundant. However, using `dst` directly + // instead of `dst.as_mut` would take the ownership of `dst`. Consequently, + // the subsequent code that constructs `name` from `dst` would fail to compile. + #[expect(clippy::useless_asref)] + current_userspace!().read_bytes(addr, &mut VmWriter::from(dst.as_mut()))?; + + let name = core::str::from_utf8(dst) + .map_err(|_| Error::with_message(Errno::ENOENT, "non-UTF8 congestion name"))?; + CongestionControl::new(name) } } impl WriteToUser for CongestionControl { fn write_to_user(&self, addr: Vaddr, max_len: u32) -> Result { - let name = self.name().as_bytes(); + let mut bytes = [0u8; TCP_CONGESTION_NAME_MAX as usize]; - let write_len = name.len(); - if write_len > max_len as usize { - return_errno_with_message!(Errno::EINVAL, "max_len is too short"); - } + let name_bytes = self.name().as_bytes(); + let name_len = name_bytes.len(); + bytes[..name_len].copy_from_slice(name_bytes); - current_userspace!().write_bytes(addr, &mut VmReader::from(name))?; + let write_len = TCP_CONGESTION_NAME_MAX.min(max_len) as usize; + + current_userspace!().write_bytes(addr, &mut VmReader::from(&bytes[..write_len]))?; Ok(write_len) } diff --git a/test/apps/network/sockoption.c b/test/apps/network/sockoption.c index 73b8a349e..1c355120c 100644 --- a/test/apps/network/sockoption.c +++ b/test/apps/network/sockoption.c @@ -273,4 +273,4 @@ FN_TEST(keepidle) &keepidle_len), keepidle == 200); } -END_TEST() \ No newline at end of file +END_TEST() diff --git a/test/syscall_test/Makefile b/test/syscall_test/Makefile index d32672671..4378c54a9 100644 --- a/test/syscall_test/Makefile +++ b/test/syscall_test/Makefile @@ -47,6 +47,7 @@ TESTS ?= \ symlink_test \ sync_test \ sysinfo_test \ + tcp_socket_test \ timers_test \ truncate_test \ uidgid_test \ diff --git a/test/syscall_test/blocklists/tcp_socket_test b/test/syscall_test/blocklists/tcp_socket_test new file mode 100644 index 000000000..370bf58e6 --- /dev/null +++ b/test/syscall_test/blocklists/tcp_socket_test @@ -0,0 +1,16 @@ +AllInetTests/*/1 +AllInetTests/TcpSocketTest.ZeroWriteAllowed/0 +AllInetTests/TcpSocketTest.FullBuffer/0 +AllInetTests/TcpSocketTest.TcpSCMPriority/0 +AllInetTests/TcpSocketTest.Tiocinq/0 +AllInetTests/TcpSocketTest.TcpInq/0 +AllInetTests/TcpSocketTest.MsgTruncPeek/0 +AllInetTests/TcpSocketTest.MsgTruncLargeSize/0 +AllInetTests/TcpSocketTest.MsgTruncWithCtrunc/0 +AllInetTests/TcpSocketTest.MsgTrunc/0 + +AllInetTests/SimpleTcpSocketTest.SetMaxSeg/0 +AllInetTests/SimpleTcpSocketTest.CleanupOnConnectionRefused/0 +AllInetTests/SimpleTcpSocketTest.SetTCPWindowClampZeroClosedSocket/0 +AllInetTests/SimpleTcpSocketTest.SetSocketAttachDetachFilter/0 +AllInetTests/SimpleTcpSocketTest.SetSocketDetachFilterNoInstalledFilter/0 \ No newline at end of file