From 6f915133b5b2ae7dcf657976aa3b1b5ac2ea709c Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Tue, 3 Sep 2024 07:14:39 +0800 Subject: [PATCH] Fix error codes in `bind` --- kernel/src/net/socket/unix/addr.rs | 8 ++ .../src/net/socket/unix/stream/connected.rs | 75 ++++++++++++++++--- kernel/src/net/socket/unix/stream/init.rs | 2 +- kernel/src/net/socket/unix/stream/listener.rs | 2 +- kernel/src/net/socket/unix/stream/socket.rs | 18 +++-- test/apps/network/unix_err.c | 56 ++++++++++++++ 6 files changed, 140 insertions(+), 21 deletions(-) diff --git a/kernel/src/net/socket/unix/addr.rs b/kernel/src/net/socket/unix/addr.rs index 0e8ef226d..4a173ca6d 100644 --- a/kernel/src/net/socket/unix/addr.rs +++ b/kernel/src/net/socket/unix/addr.rs @@ -30,6 +30,14 @@ impl UnixSocketAddr { Ok(bound) } + pub(super) fn bind_unnamed(&self) -> Result<()> { + if matches!(self, UnixSocketAddr::Unnamed) { + Ok(()) + } else { + return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); + } + } + pub(super) fn connect(&self) -> Result { let bound = match self { Self::Unnamed => return_errno_with_message!( diff --git a/kernel/src/net/socket/unix/stream/connected.rs b/kernel/src/net/socket/unix/stream/connected.rs index 30a7e5319..90dbf3a92 100644 --- a/kernel/src/net/socket/unix/stream/connected.rs +++ b/kernel/src/net/socket/unix/stream/connected.rs @@ -1,16 +1,22 @@ // SPDX-License-Identifier: MPL-2.0 +use core::ops::Deref; + +use ostd::sync::PreemptDisabled; + use crate::{ events::{IoEvents, Observer}, fs::utils::{Channel, Consumer, Producer}, - net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, + net::socket::{ + unix::{addr::UnixSocketAddrBound, UnixSocketAddr}, + SockShutdownCmd, + }, prelude::*, process::signal::{Pollee, Poller}, }; pub(super) struct Connected { - addr: Option, - peer_addr: Option, + addr: AddrView, reader: Consumer, writer: Producer, } @@ -27,15 +33,15 @@ impl Connected { let (writer_this, reader_peer) = Channel::with_capacity_and_pollees(DEFAULT_BUF_SIZE, writer_pollee, None).split(); + let (addr_this, addr_peer) = AddrView::new_pair(addr, peer_addr); + let this = Connected { - addr: addr.clone(), - peer_addr: peer_addr.clone(), + addr: addr_this, reader: reader_this, writer: writer_this, }; let peer = Connected { - addr: peer_addr, - peer_addr: addr, + addr: addr_peer, reader: reader_peer, writer: writer_peer, }; @@ -43,12 +49,25 @@ impl Connected { (this, peer) } - pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { - self.addr.as_ref() + pub(super) fn addr(&self) -> Option { + self.addr.addr().deref().as_ref().cloned() } - pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> { - self.peer_addr.as_ref() + pub(super) fn peer_addr(&self) -> Option { + self.addr.peer_addr() + } + + pub(super) fn bind(&self, addr_to_bind: UnixSocketAddr) -> Result<()> { + let mut addr = self.addr.addr(); + + if addr.is_some() { + return addr_to_bind.bind_unnamed(); + } + + let bound_addr = addr_to_bind.bind()?; + *addr = Some(bound_addr); + + Ok(()) } pub(super) fn try_read(&self, buf: &mut [u8]) -> Result { @@ -123,4 +142,38 @@ pub(super) fn combine_io_events( events & (mask | IoEvents::ALWAYS_POLL) } +struct AddrView { + addr: Arc>>, + peer: Arc>>, +} + +impl AddrView { + fn new_pair( + first: Option, + second: Option, + ) -> (AddrView, AddrView) { + let first = Arc::new(SpinLock::new(first)); + let second = Arc::new(SpinLock::new(second)); + + let view1 = AddrView { + addr: first.clone(), + peer: second.clone(), + }; + let view2 = AddrView { + addr: second, + peer: first, + }; + + (view1, view2) + } + + fn addr(&self) -> SpinLockGuard, PreemptDisabled> { + self.addr.lock() + } + + fn peer_addr(&self) -> Option { + self.peer.lock().as_ref().cloned() + } +} + const DEFAULT_BUF_SIZE: usize = 65536; diff --git a/kernel/src/net/socket/unix/stream/init.rs b/kernel/src/net/socket/unix/stream/init.rs index 3493c3f85..b67a047fc 100644 --- a/kernel/src/net/socket/unix/stream/init.rs +++ b/kernel/src/net/socket/unix/stream/init.rs @@ -37,7 +37,7 @@ impl Init { pub(super) fn bind(&mut self, addr_to_bind: UnixSocketAddr) -> Result<()> { if self.addr.is_some() { - return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); + return addr_to_bind.bind_unnamed(); } let bound_addr = addr_to_bind.bind()?; diff --git a/kernel/src/net/socket/unix/stream/listener.rs b/kernel/src/net/socket/unix/stream/listener.rs index 3d8e3108b..05a425547 100644 --- a/kernel/src/net/socket/unix/stream/listener.rs +++ b/kernel/src/net/socket/unix/stream/listener.rs @@ -50,7 +50,7 @@ impl Listener { pub(super) fn try_accept(&self) -> Result<(Arc, SocketAddr)> { let connected = self.backlog.pop_incoming()?; - let peer_addr = connected.peer_addr().cloned().into(); + let peer_addr = connected.peer_addr().into(); let socket = UnixStreamSocket::new_connected(connected, false); Ok((socket, peer_addr)) diff --git a/kernel/src/net/socket/unix/stream/socket.rs b/kernel/src/net/socket/unix/stream/socket.rs index 016f54ae8..34aeddf13 100644 --- a/kernel/src/net/socket/unix/stream/socket.rs +++ b/kernel/src/net/socket/unix/stream/socket.rs @@ -229,11 +229,11 @@ impl Socket for UnixStreamSocket { match self.state.write().as_mut() { State::Init(init) => init.bind(addr), - _ => return_errno_with_message!( - Errno::EINVAL, - "cannot bind a listening or connected socket" - ), - // FIXME: Maybe binding a connected socket should also be allowed? + State::Connected(connected) => connected.bind(addr), + State::Listen(_) => { + // Listening sockets are always already bound. + addr.bind_unnamed() + } } } @@ -315,7 +315,7 @@ impl Socket for UnixStreamSocket { let addr = match self.state.read().as_ref() { State::Init(init) => init.addr().cloned(), State::Listen(listen) => Some(listen.addr().clone()), - State::Connected(connected) => connected.addr().cloned(), + State::Connected(connected) => connected.addr(), }; Ok(addr.into()) @@ -323,8 +323,10 @@ impl Socket for UnixStreamSocket { fn peer_addr(&self) -> Result { let peer_addr = match self.state.read().as_ref() { - State::Connected(connected) => connected.peer_addr().cloned(), - _ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), + State::Connected(connected) => connected.peer_addr(), + State::Init(_) | State::Listen(_) => { + return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected") + } }; Ok(peer_addr.into()) diff --git a/test/apps/network/unix_err.c b/test/apps/network/unix_err.c index 79df46d4e..c25b1c978 100644 --- a/test/apps/network/unix_err.c +++ b/test/apps/network/unix_err.c @@ -211,6 +211,62 @@ FN_TEST(getpeername) } END_TEST() +FN_TEST(bind) +{ + TEST_ERRNO(bind(sk_bound, (struct sockaddr *)&UNIX_ADDR("\0Z"), + PATH_OFFSET + 1), + EINVAL); + + TEST_ERRNO(bind(sk_listen, (struct sockaddr *)&UNIX_ADDR("\0Z"), + PATH_OFFSET + 1), + EINVAL); + + TEST_SUCC(bind(sk_bound, (struct sockaddr *)&UNNAMED_ADDR, + UNNAMED_ADDRLEN)); + + TEST_SUCC(bind(sk_listen, (struct sockaddr *)&UNNAMED_ADDR, + UNNAMED_ADDRLEN)); +} +END_TEST() + +FN_TEST(bind_connected) +{ + int fildes[2]; + struct sockaddr_un addr; + socklen_t addrlen; + + TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes)); + + TEST_SUCC(bind(fildes[0], (struct sockaddr *)&UNIX_ADDR("\0X"), + PATH_OFFSET + 2)); + addrlen = sizeof(addr); + TEST_RES(getpeername(fildes[1], (struct sockaddr *)&addr, &addrlen), + addrlen == PATH_OFFSET + 2 && memcmp(&addr, &UNIX_ADDR("\0X"), + PATH_OFFSET + 2) == 0); + + TEST_SUCC(bind(fildes[1], (struct sockaddr *)&UNIX_ADDR("\0Y"), + PATH_OFFSET + 2)); + addrlen = sizeof(addr); + TEST_RES(getpeername(fildes[0], (struct sockaddr *)&addr, &addrlen), + addrlen == PATH_OFFSET + 2 && memcmp(&addr, &UNIX_ADDR("\0Y"), + PATH_OFFSET + 2) == 0); + + TEST_ERRNO(bind(fildes[0], (struct sockaddr *)&UNIX_ADDR("\0Z"), + PATH_OFFSET + 2), + EINVAL); + TEST_ERRNO(bind(fildes[1], (struct sockaddr *)&UNIX_ADDR("\0Z"), + PATH_OFFSET + 2), + EINVAL); + TEST_SUCC(bind(fildes[0], (struct sockaddr *)&UNNAMED_ADDR, + UNNAMED_ADDRLEN)); + TEST_SUCC(bind(fildes[1], (struct sockaddr *)&UNNAMED_ADDR, + UNNAMED_ADDRLEN)); + + TEST_SUCC(close(fildes[0])); + TEST_SUCC(close(fildes[1])); +} +END_TEST() + FN_TEST(connect) { TEST_ERRNO(connect(sk_unbound, (struct sockaddr *)&BOUND_ADDR,