Avoid nested locks

This commit is contained in:
Ruihan Li
2024-07-26 22:42:49 +08:00
committed by Tate, Hongliang Tian
parent 5445a26ec5
commit 8bcb5a2a5f
2 changed files with 12 additions and 13 deletions

View File

@ -14,21 +14,20 @@ use crate::{
}; };
pub(super) struct Init { pub(super) struct Init {
addr: Mutex<Option<UnixSocketAddrBound>>, addr: Option<UnixSocketAddrBound>,
pollee: Pollee, pollee: Pollee,
} }
impl Init { impl Init {
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
Self { Self {
addr: Mutex::new(None), addr: None,
pollee: Pollee::new(IoEvents::empty()), pollee: Pollee::new(IoEvents::empty()),
} }
} }
pub(super) fn bind(&self, addr_to_bind: &UnixSocketAddr) -> Result<()> { pub(super) fn bind(&mut self, addr_to_bind: &UnixSocketAddr) -> Result<()> {
let mut addr = self.addr.lock(); if self.addr.is_some() {
if addr.is_some() {
return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); return_errno_with_message!(Errno::EINVAL, "the socket is already bound");
} }
@ -40,28 +39,28 @@ impl Init {
UnixSocketAddrBound::Path(dentry) UnixSocketAddrBound::Path(dentry)
} }
}; };
self.addr = Some(bound_addr);
*addr = Some(bound_addr);
Ok(()) Ok(())
} }
pub(super) fn connect(&self, remote_addr: &UnixSocketAddrBound) -> Result<Connected> { pub(super) fn connect(&self, remote_addr: &UnixSocketAddrBound) -> Result<Connected> {
let addr = self.addr(); let addr = self.addr();
if let Some(ref addr) = addr { if let Some(addr) = addr {
if *addr == *remote_addr { if *addr == *remote_addr {
return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid"); return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid");
} }
} }
let (this_end, remote_end) = Endpoint::new_pair(addr, Some(remote_addr.clone())); let (this_end, remote_end) = Endpoint::new_pair(addr.cloned(), Some(remote_addr.clone()));
push_incoming(remote_addr, remote_end)?; push_incoming(remote_addr, remote_end)?;
Ok(Connected::new(this_end)) Ok(Connected::new(this_end))
} }
pub(super) fn addr(&self) -> Option<UnixSocketAddrBound> { pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> {
self.addr.lock().clone() self.addr.as_ref()
} }
pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents {

View File

@ -74,7 +74,7 @@ impl UnixStreamSocket {
fn bound_addr(&self) -> Option<UnixSocketAddrBound> { fn bound_addr(&self) -> Option<UnixSocketAddrBound> {
let state = self.state.read(); let state = self.state.read();
match &*state { match &*state {
State::Init(init) => init.addr(), State::Init(init) => init.addr().cloned(),
State::Listen(listen) => Some(listen.addr().clone()), State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr().cloned(), State::Connected(connected) => connected.addr().cloned(),
} }
@ -197,7 +197,7 @@ impl Socket for UnixStreamSocket {
fn bind(&self, socket_addr: SocketAddr) -> Result<()> { fn bind(&self, socket_addr: SocketAddr) -> Result<()> {
let addr = UnixSocketAddr::try_from(socket_addr)?; let addr = UnixSocketAddr::try_from(socket_addr)?;
match &*self.state.read() { match &mut *self.state.write() {
State::Init(init) => init.bind(&addr), State::Init(init) => init.bind(&addr),
_ => return_errno_with_message!( _ => return_errno_with_message!(
Errno::EINVAL, Errno::EINVAL,
@ -273,7 +273,7 @@ impl Socket for UnixStreamSocket {
fn addr(&self) -> Result<SocketAddr> { fn addr(&self) -> Result<SocketAddr> {
let addr = match &*self.state.read() { let addr = match &*self.state.read() {
State::Init(init) => init.addr(), State::Init(init) => init.addr().cloned(),
State::Listen(listen) => Some(listen.addr().clone()), State::Listen(listen) => Some(listen.addr().clone()),
State::Connected(connected) => connected.addr().cloned(), State::Connected(connected) => connected.addr().cloned(),
}; };