Fix error codes in bind

This commit is contained in:
Ruihan Li 2024-09-03 07:14:39 +08:00 committed by Tate, Hongliang Tian
parent a345e11b96
commit 6f915133b5
6 changed files with 140 additions and 21 deletions

View File

@ -30,6 +30,14 @@ impl UnixSocketAddr {
Ok(bound) Ok(bound)
} }
pub(super) fn bind_unnamed(&self) -> Result<()> {
if matches!(self, UnixSocketAddr::Unnamed) {
Ok(())
} else {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
}
}
pub(super) fn connect(&self) -> Result<UnixSocketAddrKey> { pub(super) fn connect(&self) -> Result<UnixSocketAddrKey> {
let bound = match self { let bound = match self {
Self::Unnamed => return_errno_with_message!( Self::Unnamed => return_errno_with_message!(

View File

@ -1,16 +1,22 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use core::ops::Deref;
use ostd::sync::PreemptDisabled;
use crate::{ use crate::{
events::{IoEvents, Observer}, events::{IoEvents, Observer},
fs::utils::{Channel, Consumer, Producer}, fs::utils::{Channel, Consumer, Producer},
net::socket::{unix::addr::UnixSocketAddrBound, SockShutdownCmd}, net::socket::{
unix::{addr::UnixSocketAddrBound, UnixSocketAddr},
SockShutdownCmd,
},
prelude::*, prelude::*,
process::signal::{Pollee, Poller}, process::signal::{Pollee, Poller},
}; };
pub(super) struct Connected { pub(super) struct Connected {
addr: Option<UnixSocketAddrBound>, addr: AddrView,
peer_addr: Option<UnixSocketAddrBound>,
reader: Consumer<u8>, reader: Consumer<u8>,
writer: Producer<u8>, writer: Producer<u8>,
} }
@ -27,15 +33,15 @@ impl Connected {
let (writer_this, reader_peer) = let (writer_this, reader_peer) =
Channel::with_capacity_and_pollees(DEFAULT_BUF_SIZE, writer_pollee, None).split(); Channel::with_capacity_and_pollees(DEFAULT_BUF_SIZE, writer_pollee, None).split();
let (addr_this, addr_peer) = AddrView::new_pair(addr, peer_addr);
let this = Connected { let this = Connected {
addr: addr.clone(), addr: addr_this,
peer_addr: peer_addr.clone(),
reader: reader_this, reader: reader_this,
writer: writer_this, writer: writer_this,
}; };
let peer = Connected { let peer = Connected {
addr: peer_addr, addr: addr_peer,
peer_addr: addr,
reader: reader_peer, reader: reader_peer,
writer: writer_peer, writer: writer_peer,
}; };
@ -43,12 +49,25 @@ impl Connected {
(this, peer) (this, peer)
} }
pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
self.addr.as_ref() self.addr.addr().deref().as_ref().cloned()
} }
pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> { pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
self.peer_addr.as_ref() self.addr.peer_addr()
}
pub(super) fn bind(&self, addr_to_bind: UnixSocketAddr) -> Result<()> {
let mut addr = self.addr.addr();
if addr.is_some() {
return addr_to_bind.bind_unnamed();
}
let bound_addr = addr_to_bind.bind()?;
*addr = Some(bound_addr);
Ok(())
} }
pub(super) fn try_read(&self, buf: &mut [u8]) -> Result<usize> { pub(super) fn try_read(&self, buf: &mut [u8]) -> Result<usize> {
@ -123,4 +142,38 @@ pub(super) fn combine_io_events(
events & (mask | IoEvents::ALWAYS_POLL) events & (mask | IoEvents::ALWAYS_POLL)
} }
struct AddrView {
addr: Arc<SpinLock<Option<UnixSocketAddrBound>>>,
peer: Arc<SpinLock<Option<UnixSocketAddrBound>>>,
}
impl AddrView {
fn new_pair(
first: Option<UnixSocketAddrBound>,
second: Option<UnixSocketAddrBound>,
) -> (AddrView, AddrView) {
let first = Arc::new(SpinLock::new(first));
let second = Arc::new(SpinLock::new(second));
let view1 = AddrView {
addr: first.clone(),
peer: second.clone(),
};
let view2 = AddrView {
addr: second,
peer: first,
};
(view1, view2)
}
fn addr(&self) -> SpinLockGuard<Option<UnixSocketAddrBound>, PreemptDisabled> {
self.addr.lock()
}
fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
self.peer.lock().as_ref().cloned()
}
}
const DEFAULT_BUF_SIZE: usize = 65536; const DEFAULT_BUF_SIZE: usize = 65536;

View File

@ -37,7 +37,7 @@ impl Init {
pub(super) fn bind(&mut self, addr_to_bind: UnixSocketAddr) -> Result<()> { pub(super) fn bind(&mut self, addr_to_bind: UnixSocketAddr) -> Result<()> {
if self.addr.is_some() { if self.addr.is_some() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); return addr_to_bind.bind_unnamed();
} }
let bound_addr = addr_to_bind.bind()?; let bound_addr = addr_to_bind.bind()?;

View File

@ -50,7 +50,7 @@ impl Listener {
pub(super) fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> { pub(super) fn try_accept(&self) -> Result<(Arc<dyn FileLike>, SocketAddr)> {
let connected = self.backlog.pop_incoming()?; let connected = self.backlog.pop_incoming()?;
let peer_addr = connected.peer_addr().cloned().into(); let peer_addr = connected.peer_addr().into();
let socket = UnixStreamSocket::new_connected(connected, false); let socket = UnixStreamSocket::new_connected(connected, false);
Ok((socket, peer_addr)) Ok((socket, peer_addr))

View File

@ -229,11 +229,11 @@ impl Socket for UnixStreamSocket {
match self.state.write().as_mut() { match self.state.write().as_mut() {
State::Init(init) => init.bind(addr), State::Init(init) => init.bind(addr),
_ => return_errno_with_message!( State::Connected(connected) => connected.bind(addr),
Errno::EINVAL, State::Listen(_) => {
"cannot bind a listening or connected socket" // Listening sockets are always already bound.
), addr.bind_unnamed()
// FIXME: Maybe binding a connected socket should also be allowed? }
} }
} }
@ -315,7 +315,7 @@ impl Socket for UnixStreamSocket {
let addr = match self.state.read().as_ref() { let addr = match self.state.read().as_ref() {
State::Init(init) => init.addr().cloned(), State::Init(init) => init.addr().cloned(),
State::Listen(listen) => Some(listen.addr().clone()), State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr().cloned(), State::Connected(connected) => connected.addr(),
}; };
Ok(addr.into()) Ok(addr.into())
@ -323,8 +323,10 @@ impl Socket for UnixStreamSocket {
fn peer_addr(&self) -> Result<SocketAddr> { fn peer_addr(&self) -> Result<SocketAddr> {
let peer_addr = match self.state.read().as_ref() { let peer_addr = match self.state.read().as_ref() {
State::Connected(connected) => connected.peer_addr().cloned(), State::Connected(connected) => connected.peer_addr(),
_ => return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected"), State::Init(_) | State::Listen(_) => {
return_errno_with_message!(Errno::ENOTCONN, "the socket is not connected")
}
}; };
Ok(peer_addr.into()) Ok(peer_addr.into())

View File

@ -211,6 +211,62 @@ FN_TEST(getpeername)
} }
END_TEST() END_TEST()
FN_TEST(bind)
{
TEST_ERRNO(bind(sk_bound, (struct sockaddr *)&UNIX_ADDR("\0Z"),
PATH_OFFSET + 1),
EINVAL);
TEST_ERRNO(bind(sk_listen, (struct sockaddr *)&UNIX_ADDR("\0Z"),
PATH_OFFSET + 1),
EINVAL);
TEST_SUCC(bind(sk_bound, (struct sockaddr *)&UNNAMED_ADDR,
UNNAMED_ADDRLEN));
TEST_SUCC(bind(sk_listen, (struct sockaddr *)&UNNAMED_ADDR,
UNNAMED_ADDRLEN));
}
END_TEST()
FN_TEST(bind_connected)
{
int fildes[2];
struct sockaddr_un addr;
socklen_t addrlen;
TEST_SUCC(socketpair(PF_UNIX, SOCK_STREAM, 0, fildes));
TEST_SUCC(bind(fildes[0], (struct sockaddr *)&UNIX_ADDR("\0X"),
PATH_OFFSET + 2));
addrlen = sizeof(addr);
TEST_RES(getpeername(fildes[1], (struct sockaddr *)&addr, &addrlen),
addrlen == PATH_OFFSET + 2 && memcmp(&addr, &UNIX_ADDR("\0X"),
PATH_OFFSET + 2) == 0);
TEST_SUCC(bind(fildes[1], (struct sockaddr *)&UNIX_ADDR("\0Y"),
PATH_OFFSET + 2));
addrlen = sizeof(addr);
TEST_RES(getpeername(fildes[0], (struct sockaddr *)&addr, &addrlen),
addrlen == PATH_OFFSET + 2 && memcmp(&addr, &UNIX_ADDR("\0Y"),
PATH_OFFSET + 2) == 0);
TEST_ERRNO(bind(fildes[0], (struct sockaddr *)&UNIX_ADDR("\0Z"),
PATH_OFFSET + 2),
EINVAL);
TEST_ERRNO(bind(fildes[1], (struct sockaddr *)&UNIX_ADDR("\0Z"),
PATH_OFFSET + 2),
EINVAL);
TEST_SUCC(bind(fildes[0], (struct sockaddr *)&UNNAMED_ADDR,
UNNAMED_ADDRLEN));
TEST_SUCC(bind(fildes[1], (struct sockaddr *)&UNNAMED_ADDR,
UNNAMED_ADDRLEN));
TEST_SUCC(close(fildes[0]));
TEST_SUCC(close(fildes[1]));
}
END_TEST()
FN_TEST(connect) FN_TEST(connect)
{ {
TEST_ERRNO(connect(sk_unbound, (struct sockaddr *)&BOUND_ADDR, TEST_ERRNO(connect(sk_unbound, (struct sockaddr *)&BOUND_ADDR,