Unpack Arc<Endpoint> in UNIX sockets

This commit is contained in:
Ruihan Li
2024-06-30 21:58:38 +08:00
committed by Tate, Hongliang Tian
parent 7ee728e2c7
commit 31d99c66c2
5 changed files with 68 additions and 76 deletions

View File

@ -9,19 +9,19 @@ use crate::{
}; };
pub(super) struct Connected { pub(super) struct Connected {
local_endpoint: Arc<Endpoint>, local_endpoint: Endpoint,
} }
impl Connected { impl Connected {
pub(super) fn new(local_endpoint: Arc<Endpoint>) -> Self { pub(super) fn new(local_endpoint: Endpoint) -> Self {
Connected { local_endpoint } Connected { local_endpoint }
} }
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> { pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
self.local_endpoint.addr() self.local_endpoint.addr()
} }
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> { pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> {
self.local_endpoint.peer_addr() self.local_endpoint.peer_addr()
} }

View File

@ -8,118 +8,114 @@ use crate::{
process::signal::Poller, process::signal::Poller,
}; };
pub(super) struct Endpoint(Inner); pub(super) struct Endpoint {
addr: Option<UnixSocketAddrBound>,
struct Inner { peer_addr: Option<UnixSocketAddrBound>,
addr: RwLock<Option<UnixSocketAddrBound>>,
reader: Consumer<u8>, reader: Consumer<u8>,
writer: Producer<u8>, writer: Producer<u8>,
peer: Weak<Endpoint>,
} }
impl Endpoint { impl Endpoint {
pub(super) fn new_pair(is_nonblocking: bool) -> Result<(Arc<Endpoint>, Arc<Endpoint>)> { pub(super) fn new_pair(
addr: Option<UnixSocketAddrBound>,
peer_addr: Option<UnixSocketAddrBound>,
is_nonblocking: bool,
) -> Result<(Endpoint, Endpoint)> {
let flags = if is_nonblocking { let flags = if is_nonblocking {
StatusFlags::O_NONBLOCK StatusFlags::O_NONBLOCK
} else { } else {
StatusFlags::empty() StatusFlags::empty()
}; };
let (writer_a, reader_b) =
let (writer_this, reader_peer) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split(); Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let (writer_b, reader_a) = let (writer_peer, reader_this) =
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split(); Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
let mut endpoint_b = None;
let endpoint_a = Arc::new_cyclic(|endpoint_a_ref| { let this = Endpoint {
let peer = Arc::new(Endpoint::new(reader_b, writer_b, endpoint_a_ref.clone())); addr: addr.clone(),
let endpoint_a = Endpoint::new(reader_a, writer_a, Arc::downgrade(&peer)); peer_addr: peer_addr.clone(),
endpoint_b = Some(peer); reader: reader_this,
endpoint_a writer: writer_this,
}); };
Ok((endpoint_a, endpoint_b.unwrap())) let peer = Endpoint {
addr: peer_addr,
peer_addr: addr,
reader: reader_peer,
writer: writer_peer,
};
Ok((this, peer))
} }
fn new(reader: Consumer<u8>, writer: Producer<u8>, peer: Weak<Endpoint>) -> Self { pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
Self(Inner { self.addr.as_ref()
addr: RwLock::new(None),
reader,
writer,
peer,
})
} }
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> { pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> {
self.0.addr.read().clone() self.peer_addr.as_ref()
}
pub(super) fn set_addr(&self, addr: UnixSocketAddrBound) {
*self.0.addr.write() = Some(addr);
}
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
self.0.peer.upgrade().and_then(|peer| peer.addr())
} }
pub(super) fn is_nonblocking(&self) -> bool { pub(super) fn is_nonblocking(&self) -> bool {
let reader_status = self.0.reader.is_nonblocking(); let reader_status = self.reader.is_nonblocking();
let writer_status = self.0.writer.is_nonblocking(); let writer_status = self.writer.is_nonblocking();
debug_assert!(reader_status == writer_status); debug_assert!(reader_status == writer_status);
reader_status reader_status
} }
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> { pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> {
let mut reader_flags = self.0.reader.status_flags(); let mut reader_flags = self.reader.status_flags();
reader_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking); reader_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
self.0.reader.set_status_flags(reader_flags)?; self.reader.set_status_flags(reader_flags)?;
let mut writer_flags = self.0.writer.status_flags(); let mut writer_flags = self.writer.status_flags();
writer_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking); writer_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
self.0.writer.set_status_flags(writer_flags)?; self.writer.set_status_flags(writer_flags)?;
Ok(()) Ok(())
} }
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> { pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
self.0.reader.read(buf) self.reader.read(buf)
} }
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> { pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
self.0.writer.write(buf) self.writer.write(buf)
} }
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
if !self.is_connected() { // FIXME: If the socket has already been shut down, should we return an error code?
return_errno_with_message!(Errno::ENOTCONN, "The socket is not connected.");
}
if cmd.shut_read() { if cmd.shut_read() {
self.0.reader.shutdown(); self.reader.shutdown();
} }
if cmd.shut_write() { if cmd.shut_write() {
self.0.writer.shutdown(); self.writer.shutdown();
} }
Ok(()) Ok(())
} }
pub(super) fn is_connected(&self) -> bool {
self.0.peer.upgrade().is_some()
}
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents { pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
let mut events = IoEvents::empty(); let mut events = IoEvents::empty();
// FIXME: should reader and writer use the same mask?
let reader_events = self.0.reader.poll(mask, poller.as_deref_mut());
let writer_events = self.0.writer.poll(mask, poller);
if reader_events.contains(IoEvents::HUP) || self.0.reader.is_shutdown() { // FIXME: should reader and writer use the same mask?
let reader_events = self.reader.poll(mask, poller.as_deref_mut());
let writer_events = self.writer.poll(mask, poller);
// FIXME: Check this logic later.
if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() {
events |= IoEvents::RDHUP | IoEvents::IN; events |= IoEvents::RDHUP | IoEvents::IN;
if writer_events.contains(IoEvents::ERR) || self.0.writer.is_shutdown() { if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() {
events |= IoEvents::HUP | IoEvents::OUT; events |= IoEvents::HUP | IoEvents::OUT;
} }
} }
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
events events
} }
} }

View File

@ -57,11 +57,8 @@ impl Init {
} }
} }
let (this_end, remote_end) = Endpoint::new_pair(self.is_nonblocking())?; let (this_end, remote_end) =
remote_end.set_addr(remote_addr.clone()); Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking())?;
if let Some(addr) = addr {
this_end.set_addr(addr.clone());
};
push_incoming(remote_addr, remote_end)?; push_incoming(remote_addr, remote_end)?;
Ok(Connected::new(this_end)) Ok(Connected::new(this_end))

View File

@ -118,7 +118,7 @@ impl BacklogTable {
.ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened")) .ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened"))
} }
fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Arc<Endpoint>> { fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Endpoint> {
let mut poller = Poller::new(); let mut poller = Poller::new();
loop { loop {
let backlog = self.get_backlog(addr)?; let backlog = self.get_backlog(addr)?;
@ -147,7 +147,7 @@ impl BacklogTable {
} }
} }
fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Arc<Endpoint>) -> Result<()> { fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Endpoint) -> Result<()> {
let backlog = self.get_backlog(addr).map_err(|_| { let backlog = self.get_backlog(addr).map_err(|_| {
Error::with_message( Error::with_message(
Errno::ECONNREFUSED, Errno::ECONNREFUSED,
@ -171,7 +171,7 @@ impl BacklogTable {
struct Backlog { struct Backlog {
pollee: Pollee, pollee: Pollee,
backlog: usize, backlog: usize,
incoming_endpoints: Mutex<VecDeque<Arc<Endpoint>>>, incoming_endpoints: Mutex<VecDeque<Endpoint>>,
} }
impl Backlog { impl Backlog {
@ -183,7 +183,7 @@ impl Backlog {
} }
} }
fn push_incoming(&self, endpoint: Arc<Endpoint>) -> Result<()> { fn push_incoming(&self, endpoint: Endpoint) -> Result<()> {
let mut endpoints = self.incoming_endpoints.lock(); let mut endpoints = self.incoming_endpoints.lock();
if endpoints.len() >= self.backlog { if endpoints.len() >= self.backlog {
return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full"); return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full");
@ -193,7 +193,7 @@ impl Backlog {
Ok(()) Ok(())
} }
fn pop_incoming(&self) -> Option<Arc<Endpoint>> { fn pop_incoming(&self) -> Option<Endpoint> {
let mut incoming_endpoints = self.incoming_endpoints.lock(); let mut incoming_endpoints = self.incoming_endpoints.lock();
let endpoint = incoming_endpoints.pop_front(); let endpoint = incoming_endpoints.pop_front();
if incoming_endpoints.is_empty() { if incoming_endpoints.is_empty() {
@ -218,9 +218,6 @@ pub(super) fn unregister_backlog(addr: &UnixSocketAddrBound) {
BACKLOG_TABLE.remove_backlog(addr); BACKLOG_TABLE.remove_backlog(addr);
} }
pub(super) fn push_incoming( pub(super) fn push_incoming(remote_addr: &UnixSocketAddrBound, remote_end: Endpoint) -> Result<()> {
remote_addr: &UnixSocketAddrBound,
remote_end: Arc<Endpoint>,
) -> Result<()> {
BACKLOG_TABLE.push_incoming(remote_addr, remote_end) BACKLOG_TABLE.push_incoming(remote_addr, remote_end)
} }

View File

@ -52,7 +52,8 @@ impl UnixStreamSocket {
} }
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> { pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
let (end_a, end_b) = Endpoint::new_pair(nonblocking)?; let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking)?;
let connected_a = { let connected_a = {
let connected = Connected::new(end_a); let connected = Connected::new(end_a);
Self::new_connected(connected) Self::new_connected(connected)
@ -61,6 +62,7 @@ impl UnixStreamSocket {
let connected = Connected::new(end_b); let connected = Connected::new(end_b);
Self::new_connected(connected) Self::new_connected(connected)
}; };
Ok((Arc::new(connected_a), Arc::new(connected_b))) Ok((Arc::new(connected_a), Arc::new(connected_b)))
} }
@ -69,7 +71,7 @@ impl UnixStreamSocket {
match &*status { match &*status {
State::Init(init) => init.addr(), State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()), State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(), State::Connected(connected) => connected.addr().cloned(),
} }
} }
@ -250,7 +252,7 @@ impl Socket for UnixStreamSocket {
let addr = match &*self.0.read() { let addr = match &*self.0.read() {
State::Init(init) => init.addr(), State::Init(init) => init.addr(),
State::Listen(listen) => Some(listen.addr().clone()), State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr(), State::Connected(connected) => connected.addr().cloned(),
}; };
addr.map(Into::<SocketAddr>::into) addr.map(Into::<SocketAddr>::into)