diff --git a/kernel/aster-nix/src/net/socket/unix/stream/init.rs b/kernel/aster-nix/src/net/socket/unix/stream/init.rs index e5378f966..ac9d83c4a 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/init.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/init.rs @@ -14,21 +14,20 @@ use crate::{ }; pub(super) struct Init { - addr: Mutex>, + addr: Option, pollee: Pollee, } impl Init { pub(super) fn new() -> Self { Self { - addr: Mutex::new(None), + addr: None, pollee: Pollee::new(IoEvents::empty()), } } - pub(super) fn bind(&self, addr_to_bind: &UnixSocketAddr) -> Result<()> { - let mut addr = self.addr.lock(); - if addr.is_some() { + pub(super) fn bind(&mut self, addr_to_bind: &UnixSocketAddr) -> Result<()> { + if self.addr.is_some() { return_errno_with_message!(Errno::EINVAL, "the socket is already bound"); } @@ -40,28 +39,28 @@ impl Init { UnixSocketAddrBound::Path(dentry) } }; + self.addr = Some(bound_addr); - *addr = Some(bound_addr); Ok(()) } pub(super) fn connect(&self, remote_addr: &UnixSocketAddrBound) -> Result { let addr = self.addr(); - if let Some(ref addr) = addr { + if let Some(addr) = addr { if *addr == *remote_addr { return_errno_with_message!(Errno::EINVAL, "try to connect to self is invalid"); } } - let (this_end, remote_end) = Endpoint::new_pair(addr, Some(remote_addr.clone())); + let (this_end, remote_end) = Endpoint::new_pair(addr.cloned(), Some(remote_addr.clone())); push_incoming(remote_addr, remote_end)?; Ok(Connected::new(this_end)) } - pub(super) fn addr(&self) -> Option { - self.addr.lock().clone() + pub(super) fn addr(&self) -> Option<&UnixSocketAddrBound> { + self.addr.as_ref() } pub(super) fn poll(&self, mask: IoEvents, poller: Option<&mut Poller>) -> IoEvents { diff --git a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs index e1e85b7ae..930933bf3 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs @@ -74,7 +74,7 @@ impl UnixStreamSocket { fn bound_addr(&self) -> Option { let state = self.state.read(); match &*state { - State::Init(init) => init.addr(), + State::Init(init) => init.addr().cloned(), State::Listen(listen) => Some(listen.addr().clone()), State::Connected(connected) => connected.addr().cloned(), } @@ -197,7 +197,7 @@ impl Socket for UnixStreamSocket { fn bind(&self, socket_addr: SocketAddr) -> Result<()> { let addr = UnixSocketAddr::try_from(socket_addr)?; - match &*self.state.read() { + match &mut *self.state.write() { State::Init(init) => init.bind(&addr), _ => return_errno_with_message!( Errno::EINVAL, @@ -273,7 +273,7 @@ impl Socket for UnixStreamSocket { fn addr(&self) -> Result { let addr = match &*self.state.read() { - State::Init(init) => init.addr(), + State::Init(init) => init.addr().cloned(), State::Listen(listen) => Some(listen.addr().clone()), State::Connected(connected) => connected.addr().cloned(), };