Unpack Arc<Endpoint> in UNIX sockets

This commit is contained in:
Ruihan Li 2024-06-30 21:58:38 +08:00 committed by Tate, Hongliang Tian
parent 7ee728e2c7
commit 31d99c66c2
5 changed files with 68 additions and 76 deletions

View File

@ -9,19 +9,19 @@ use crate::{
};
pub(super) struct Connected {
local_endpoint: Arc<Endpoint>,
local_endpoint: Endpoint,
}
impl Connected {
pub(super) fn new(local_endpoint: Arc<Endpoint>) -> Self {
pub(super) fn new(local_endpoint: Endpoint) -> Self {
Connected { local_endpoint }
}
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
self.local_endpoint.addr()
}
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> {
self.local_endpoint.peer_addr()
}

View File

@ -8,118 +8,114 @@ use crate::{
process::signal::Poller,
};
pub(super) struct Endpoint(Inner);
struct Inner {
addr: RwLock<Option<UnixSocketAddrBound>>,
pub(super) struct Endpoint {
addr: Option<UnixSocketAddrBound>,
peer_addr: Option<UnixSocketAddrBound>,
reader: Consumer<u8>,
writer: Producer<u8>,
peer: Weak<Endpoint>,
}
impl Endpoint {
pub(super) fn new_pair(is_nonblocking: bool) -> Result<(Arc<Endpoint>, Arc<Endpoint>)> {
pub(super) fn new_pair(
addr: Option<UnixSocketAddrBound>,
peer_addr: Option<UnixSocketAddrBound>,
is_nonblocking: bool,
) -> Result<(Endpoint, Endpoint)> {
let flags = if is_nonblocking {
StatusFlags::O_NONBLOCK
} else {
StatusFlags::empty()
};
let (writer_a, reader_b) =
let (writer_this, reader_peer) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let (writer_b, reader_a) =
let (writer_peer, reader_this) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let mut endpoint_b = None;
let endpoint_a = Arc::new_cyclic(|endpoint_a_ref| {
let peer = Arc::new(Endpoint::new(reader_b, writer_b, endpoint_a_ref.clone()));
let endpoint_a = Endpoint::new(reader_a, writer_a, Arc::downgrade(&peer));
endpoint_b = Some(peer);
endpoint_a
});
Ok((endpoint_a, endpoint_b.unwrap()))
let this = Endpoint {
addr: addr.clone(),
peer_addr: peer_addr.clone(),
reader: reader_this,
writer: writer_this,
};
let peer = Endpoint {
addr: peer_addr,
peer_addr: addr,
reader: reader_peer,
writer: writer_peer,
};
Ok((this, peer))
}
fn new(reader: Consumer<u8>, writer: Producer<u8>, peer: Weak<Endpoint>) -> Self {
Self(Inner {
addr: RwLock::new(None),
reader,
writer,
peer,
})
pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
self.addr.as_ref()
}
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
self.0.addr.read().clone()
}
pub(super) fn set_addr(&self, addr: UnixSocketAddrBound) {
*self.0.addr.write() = Some(addr);
}
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
self.0.peer.upgrade().and_then(|peer| peer.addr())
pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> {
self.peer_addr.as_ref()
}
pub(super) fn is_nonblocking(&self) -> bool {
let reader_status = self.0.reader.is_nonblocking();
let writer_status = self.0.writer.is_nonblocking();
let reader_status = self.reader.is_nonblocking();
let writer_status = self.writer.is_nonblocking();
debug_assert!(reader_status == writer_status);
reader_status
}
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> {
let mut reader_flags = self.0.reader.status_flags();
let mut reader_flags = self.reader.status_flags();
reader_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
self.0.reader.set_status_flags(reader_flags)?;
self.reader.set_status_flags(reader_flags)?;
let mut writer_flags = self.0.writer.status_flags();
let mut writer_flags = self.writer.status_flags();
writer_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
self.0.writer.set_status_flags(writer_flags)?;
self.writer.set_status_flags(writer_flags)?;
Ok(())
}
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.0.reader.read(buf)
self.reader.read(buf)
}
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
self.0.writer.write(buf)
self.writer.write(buf)
}
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
if !self.is_connected() {
return_errno_with_message!(Errno::ENOTCONN, "The socket is not connected.");
}
// FIXME: If the socket has already been shut down, should we return an error code?
if cmd.shut_read() {
self.0.reader.shutdown();
self.reader.shutdown();
}
if cmd.shut_write() {
self.0.writer.shutdown();
self.writer.shutdown();
}
Ok(())
}
pub(super) fn is_connected(&self) -> bool {
self.0.peer.upgrade().is_some()
}
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
let mut events = IoEvents::empty();
// FIXME: should reader and writer use the same mask?
let reader_events = self.0.reader.poll(mask, poller.as_deref_mut());
let writer_events = self.0.writer.poll(mask, poller);
if reader_events.contains(IoEvents::HUP) || self.0.reader.is_shutdown() {
// FIXME: should reader and writer use the same mask?
let reader_events = self.reader.poll(mask, poller.as_deref_mut());
let writer_events = self.writer.poll(mask, poller);
// FIXME: Check this logic later.
if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() {
events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) || self.0.writer.is_shutdown() {
if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() {
events |= IoEvents::HUP | IoEvents::OUT;
}
}
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events
}
}

View File

@ -57,11 +57,8 @@ impl Init {
}
}
let (this_end, remote_end) = Endpoint::new_pair(self.is_nonblocking())?;
remote_end.set_addr(remote_addr.clone());
if let Some(addr) = addr {
this_end.set_addr(addr.clone());
};
let (this_end, remote_end) =
Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking())?;
push_incoming(remote_addr, remote_end)?;
Ok(Connected::new(this_end))

View File

@ -118,7 +118,7 @@ impl BacklogTable {
.ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened"))
}
fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Arc<Endpoint>> {
fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Endpoint> {
let mut poller = Poller::new();
loop {
let backlog = self.get_backlog(addr)?;
@ -147,7 +147,7 @@ impl BacklogTable {
}
}
fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Arc<Endpoint>) -> Result<()> {
fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Endpoint) -> Result<()> {
let backlog = self.get_backlog(addr).map_err(|_| {
Error::with_message(
Errno::ECONNREFUSED,
@ -171,7 +171,7 @@ impl BacklogTable {
struct Backlog {
pollee: Pollee,
backlog: usize,
incoming_endpoints: Mutex<VecDeque<Arc<Endpoint>>>,
incoming_endpoints: Mutex<VecDeque<Endpoint>>,
}
impl Backlog {
@ -183,7 +183,7 @@ impl Backlog {
}
}
fn push_incoming(&self, endpoint: Arc<Endpoint>) -> Result<()> {
fn push_incoming(&self, endpoint: Endpoint) -> Result<()> {
let mut endpoints = self.incoming_endpoints.lock();
if endpoints.len() >= self.backlog {
return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full");
@ -193,7 +193,7 @@ impl Backlog {
Ok(())
}
fn pop_incoming(&self) -> Option<Arc<Endpoint>> {
fn pop_incoming(&self) -> Option<Endpoint> {
let mut incoming_endpoints = self.incoming_endpoints.lock();
let endpoint = incoming_endpoints.pop_front();
if incoming_endpoints.is_empty() {
@ -218,9 +218,6 @@ pub(super) fn unregister_backlog(addr: &UnixSocketAddrBound) {
BACKLOG_TABLE.remove_backlog(addr);
}
pub(super) fn push_incoming(
remote_addr: &UnixSocketAddrBound,
remote_end: Arc<Endpoint>,
) -> Result<()> {
pub(super) fn push_incoming(remote_addr: &UnixSocketAddrBound, remote_end: Endpoint) -> Result<()> {
BACKLOG_TABLE.push_incoming(remote_addr, remote_end)
}

View File

@ -52,7 +52,8 @@ impl UnixStreamSocket {
}
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
let (end_a, end_b) = Endpoint::new_pair(nonblocking)?;
let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking)?;
let connected_a = {
let connected = Connected::new(end_a);
Self::new_connected(connected)
@ -61,6 +62,7 @@ impl UnixStreamSocket {
let connected = Connected::new(end_b);
Self::new_connected(connected)
};
Ok((Arc::new(connected_a), Arc::new(connected_b)))
}
@ -69,7 +71,7 @@ impl UnixStreamSocket {
match &*status {
State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(),
State::Connected(connected) => connected.addr().cloned(),
}
}
@ -250,7 +252,7 @@ impl Socket for UnixStreamSocket {
let addr = match &*self.0.read() {
State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(),
State::Connected(connected) => connected.addr().cloned(),
};
addr.map(Into::<SocketAddr>::into)