mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-17 12:47:16 +00:00
Unpack Arc<Endpoint>
in UNIX sockets
This commit is contained in:
parent
7ee728e2c7
commit
31d99c66c2
@ -9,19 +9,19 @@ use crate::{
|
||||
};
|
||||
|
||||
pub(super) struct Connected {
|
||||
local_endpoint: Arc<Endpoint>,
|
||||
local_endpoint: Endpoint,
|
||||
}
|
||||
|
||||
impl Connected {
|
||||
pub(super) fn new(local_endpoint: Arc<Endpoint>) -> Self {
|
||||
pub(super) fn new(local_endpoint: Endpoint) -> Self {
|
||||
Connected { local_endpoint }
|
||||
}
|
||||
|
||||
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
|
||||
self.local_endpoint.addr()
|
||||
}
|
||||
|
||||
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> {
|
||||
self.local_endpoint.peer_addr()
|
||||
}
|
||||
|
||||
|
@ -8,118 +8,114 @@ use crate::{
|
||||
process::signal::Poller,
|
||||
};
|
||||
|
||||
pub(super) struct Endpoint(Inner);
|
||||
|
||||
struct Inner {
|
||||
addr: RwLock<Option<UnixSocketAddrBound>>,
|
||||
pub(super) struct Endpoint {
|
||||
addr: Option<UnixSocketAddrBound>,
|
||||
peer_addr: Option<UnixSocketAddrBound>,
|
||||
reader: Consumer<u8>,
|
||||
writer: Producer<u8>,
|
||||
peer: Weak<Endpoint>,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub(super) fn new_pair(is_nonblocking: bool) -> Result<(Arc<Endpoint>, Arc<Endpoint>)> {
|
||||
pub(super) fn new_pair(
|
||||
addr: Option<UnixSocketAddrBound>,
|
||||
peer_addr: Option<UnixSocketAddrBound>,
|
||||
is_nonblocking: bool,
|
||||
) -> Result<(Endpoint, Endpoint)> {
|
||||
let flags = if is_nonblocking {
|
||||
StatusFlags::O_NONBLOCK
|
||||
} else {
|
||||
StatusFlags::empty()
|
||||
};
|
||||
let (writer_a, reader_b) =
|
||||
|
||||
let (writer_this, reader_peer) =
|
||||
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
|
||||
let (writer_b, reader_a) =
|
||||
let (writer_peer, reader_this) =
|
||||
Channel::with_capacity_and_flags(DAFAULT_BUF_SIZE, flags)?.split();
|
||||
let mut endpoint_b = None;
|
||||
let endpoint_a = Arc::new_cyclic(|endpoint_a_ref| {
|
||||
let peer = Arc::new(Endpoint::new(reader_b, writer_b, endpoint_a_ref.clone()));
|
||||
let endpoint_a = Endpoint::new(reader_a, writer_a, Arc::downgrade(&peer));
|
||||
endpoint_b = Some(peer);
|
||||
endpoint_a
|
||||
});
|
||||
Ok((endpoint_a, endpoint_b.unwrap()))
|
||||
|
||||
let this = Endpoint {
|
||||
addr: addr.clone(),
|
||||
peer_addr: peer_addr.clone(),
|
||||
reader: reader_this,
|
||||
writer: writer_this,
|
||||
};
|
||||
let peer = Endpoint {
|
||||
addr: peer_addr,
|
||||
peer_addr: addr,
|
||||
reader: reader_peer,
|
||||
writer: writer_peer,
|
||||
};
|
||||
|
||||
Ok((this, peer))
|
||||
}
|
||||
|
||||
fn new(reader: Consumer<u8>, writer: Producer<u8>, peer: Weak<Endpoint>) -> Self {
|
||||
Self(Inner {
|
||||
addr: RwLock::new(None),
|
||||
reader,
|
||||
writer,
|
||||
peer,
|
||||
})
|
||||
pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
|
||||
self.addr.as_ref()
|
||||
}
|
||||
|
||||
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
self.0.addr.read().clone()
|
||||
}
|
||||
|
||||
pub(super) fn set_addr(&self, addr: UnixSocketAddrBound) {
|
||||
*self.0.addr.write() = Some(addr);
|
||||
}
|
||||
|
||||
pub(super) fn peer_addr(&self) -> Option<UnixSocketAddrBound> {
|
||||
self.0.peer.upgrade().and_then(|peer| peer.addr())
|
||||
pub(super) fn peer_addr(&self) -> Option<&UnixSocketAddrBound> {
|
||||
self.peer_addr.as_ref()
|
||||
}
|
||||
|
||||
pub(super) fn is_nonblocking(&self) -> bool {
|
||||
let reader_status = self.0.reader.is_nonblocking();
|
||||
let writer_status = self.0.writer.is_nonblocking();
|
||||
let reader_status = self.reader.is_nonblocking();
|
||||
let writer_status = self.writer.is_nonblocking();
|
||||
|
||||
debug_assert!(reader_status == writer_status);
|
||||
|
||||
reader_status
|
||||
}
|
||||
|
||||
pub(super) fn set_nonblocking(&self, is_nonblocking: bool) -> Result<()> {
|
||||
let mut reader_flags = self.0.reader.status_flags();
|
||||
let mut reader_flags = self.reader.status_flags();
|
||||
reader_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
|
||||
self.0.reader.set_status_flags(reader_flags)?;
|
||||
self.reader.set_status_flags(reader_flags)?;
|
||||
|
||||
let mut writer_flags = self.0.writer.status_flags();
|
||||
let mut writer_flags = self.writer.status_flags();
|
||||
writer_flags.set(StatusFlags::O_NONBLOCK, is_nonblocking);
|
||||
self.0.writer.set_status_flags(writer_flags)?;
|
||||
self.writer.set_status_flags(writer_flags)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn read(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
self.0.reader.read(buf)
|
||||
self.reader.read(buf)
|
||||
}
|
||||
|
||||
pub(super) fn write(&self, buf: &[u8]) -> Result<usize> {
|
||||
self.0.writer.write(buf)
|
||||
self.writer.write(buf)
|
||||
}
|
||||
|
||||
pub(super) fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> {
|
||||
if !self.is_connected() {
|
||||
return_errno_with_message!(Errno::ENOTCONN, "The socket is not connected.");
|
||||
}
|
||||
// FIXME: If the socket has already been shut down, should we return an error code?
|
||||
|
||||
if cmd.shut_read() {
|
||||
self.0.reader.shutdown();
|
||||
self.reader.shutdown();
|
||||
}
|
||||
|
||||
if cmd.shut_write() {
|
||||
self.0.writer.shutdown();
|
||||
self.writer.shutdown();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn is_connected(&self) -> bool {
|
||||
self.0.peer.upgrade().is_some()
|
||||
}
|
||||
|
||||
pub(super) fn poll(&self, mask: IoEvents, mut poller: Option<&mut Poller>) -> IoEvents {
|
||||
let mut events = IoEvents::empty();
|
||||
// FIXME: should reader and writer use the same mask?
|
||||
let reader_events = self.0.reader.poll(mask, poller.as_deref_mut());
|
||||
let writer_events = self.0.writer.poll(mask, poller);
|
||||
|
||||
if reader_events.contains(IoEvents::HUP) || self.0.reader.is_shutdown() {
|
||||
// FIXME: should reader and writer use the same mask?
|
||||
let reader_events = self.reader.poll(mask, poller.as_deref_mut());
|
||||
let writer_events = self.writer.poll(mask, poller);
|
||||
|
||||
// FIXME: Check this logic later.
|
||||
if reader_events.contains(IoEvents::HUP) || self.reader.is_shutdown() {
|
||||
events |= IoEvents::RDHUP | IoEvents::IN;
|
||||
if writer_events.contains(IoEvents::ERR) || self.0.writer.is_shutdown() {
|
||||
if writer_events.contains(IoEvents::ERR) || self.writer.is_shutdown() {
|
||||
events |= IoEvents::HUP | IoEvents::OUT;
|
||||
}
|
||||
}
|
||||
|
||||
events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT);
|
||||
|
||||
events
|
||||
}
|
||||
}
|
||||
|
@ -57,11 +57,8 @@ impl Init {
|
||||
}
|
||||
}
|
||||
|
||||
let (this_end, remote_end) = Endpoint::new_pair(self.is_nonblocking())?;
|
||||
remote_end.set_addr(remote_addr.clone());
|
||||
if let Some(addr) = addr {
|
||||
this_end.set_addr(addr.clone());
|
||||
};
|
||||
let (this_end, remote_end) =
|
||||
Endpoint::new_pair(addr, Some(remote_addr.clone()), self.is_nonblocking())?;
|
||||
|
||||
push_incoming(remote_addr, remote_end)?;
|
||||
Ok(Connected::new(this_end))
|
||||
|
@ -118,7 +118,7 @@ impl BacklogTable {
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "the socket is not listened"))
|
||||
}
|
||||
|
||||
fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Arc<Endpoint>> {
|
||||
fn pop_incoming(&self, nonblocking: bool, addr: &UnixSocketAddrBound) -> Result<Endpoint> {
|
||||
let mut poller = Poller::new();
|
||||
loop {
|
||||
let backlog = self.get_backlog(addr)?;
|
||||
@ -147,7 +147,7 @@ impl BacklogTable {
|
||||
}
|
||||
}
|
||||
|
||||
fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Arc<Endpoint>) -> Result<()> {
|
||||
fn push_incoming(&self, addr: &UnixSocketAddrBound, endpoint: Endpoint) -> Result<()> {
|
||||
let backlog = self.get_backlog(addr).map_err(|_| {
|
||||
Error::with_message(
|
||||
Errno::ECONNREFUSED,
|
||||
@ -171,7 +171,7 @@ impl BacklogTable {
|
||||
struct Backlog {
|
||||
pollee: Pollee,
|
||||
backlog: usize,
|
||||
incoming_endpoints: Mutex<VecDeque<Arc<Endpoint>>>,
|
||||
incoming_endpoints: Mutex<VecDeque<Endpoint>>,
|
||||
}
|
||||
|
||||
impl Backlog {
|
||||
@ -183,7 +183,7 @@ impl Backlog {
|
||||
}
|
||||
}
|
||||
|
||||
fn push_incoming(&self, endpoint: Arc<Endpoint>) -> Result<()> {
|
||||
fn push_incoming(&self, endpoint: Endpoint) -> Result<()> {
|
||||
let mut endpoints = self.incoming_endpoints.lock();
|
||||
if endpoints.len() >= self.backlog {
|
||||
return_errno_with_message!(Errno::ECONNREFUSED, "incoming_endpoints is full");
|
||||
@ -193,7 +193,7 @@ impl Backlog {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pop_incoming(&self) -> Option<Arc<Endpoint>> {
|
||||
fn pop_incoming(&self) -> Option<Endpoint> {
|
||||
let mut incoming_endpoints = self.incoming_endpoints.lock();
|
||||
let endpoint = incoming_endpoints.pop_front();
|
||||
if incoming_endpoints.is_empty() {
|
||||
@ -218,9 +218,6 @@ pub(super) fn unregister_backlog(addr: &UnixSocketAddrBound) {
|
||||
BACKLOG_TABLE.remove_backlog(addr);
|
||||
}
|
||||
|
||||
pub(super) fn push_incoming(
|
||||
remote_addr: &UnixSocketAddrBound,
|
||||
remote_end: Arc<Endpoint>,
|
||||
) -> Result<()> {
|
||||
pub(super) fn push_incoming(remote_addr: &UnixSocketAddrBound, remote_end: Endpoint) -> Result<()> {
|
||||
BACKLOG_TABLE.push_incoming(remote_addr, remote_end)
|
||||
}
|
||||
|
@ -52,7 +52,8 @@ impl UnixStreamSocket {
|
||||
}
|
||||
|
||||
pub fn new_pair(nonblocking: bool) -> Result<(Arc<Self>, Arc<Self>)> {
|
||||
let (end_a, end_b) = Endpoint::new_pair(nonblocking)?;
|
||||
let (end_a, end_b) = Endpoint::new_pair(None, None, nonblocking)?;
|
||||
|
||||
let connected_a = {
|
||||
let connected = Connected::new(end_a);
|
||||
Self::new_connected(connected)
|
||||
@ -61,6 +62,7 @@ impl UnixStreamSocket {
|
||||
let connected = Connected::new(end_b);
|
||||
Self::new_connected(connected)
|
||||
};
|
||||
|
||||
Ok((Arc::new(connected_a), Arc::new(connected_b)))
|
||||
}
|
||||
|
||||
@ -69,7 +71,7 @@ impl UnixStreamSocket {
|
||||
match &*status {
|
||||
State::Init(init) => init.addr(),
|
||||
State::Listen(listen) => Some(listen.addr().clone()),
|
||||
State::Connected(connected) => connected.addr(),
|
||||
State::Connected(connected) => connected.addr().cloned(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -250,7 +252,7 @@ impl Socket for UnixStreamSocket {
|
||||
let addr = match &*self.0.read() {
|
||||
State::Init(init) => init.addr(),
|
||||
State::Listen(listen) => Some(listen.addr().clone()),
|
||||
State::Connected(connected) => connected.addr(),
|
||||
State::Connected(connected) => connected.addr().cloned(),
|
||||
};
|
||||
|
||||
addr.map(Into::<SocketAddr>::into)
|
||||
|
Loading…
x
Reference in New Issue
Block a user