From ac19a7e0e7c55551df1410be0a19422b1c36a72b Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Thu, 15 Aug 2024 13:17:36 +0800 Subject: [PATCH] Fix race conditions in UNIX socket `listen()` --- .../src/net/socket/unix/stream/listener.rs | 15 ++++++++------- .../src/net/socket/unix/stream/socket.rs | 9 ++++++--- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs index c479de8ac..3d746c870 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/listener.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/listener.rs @@ -16,9 +16,9 @@ pub(super) struct Listener { } impl Listener { - pub(super) fn new(addr: UnixSocketAddrBound, backlog: usize) -> Result { - let backlog = BACKLOG_TABLE.add_backlog(addr, backlog)?; - Ok(Self { backlog }) + pub(super) fn new(addr: UnixSocketAddrBound, backlog: usize) -> Self { + let backlog = BACKLOG_TABLE.add_backlog(addr, backlog).unwrap(); + Self { backlog } } pub(super) fn addr(&self) -> &UnixSocketAddrBound { @@ -78,21 +78,22 @@ impl BacklogTable { } } - fn add_backlog(&self, addr: UnixSocketAddrBound, backlog: usize) -> Result> { + fn add_backlog(&self, addr: UnixSocketAddrBound, backlog: usize) -> Option> { let inode = { let UnixSocketAddrBound::Path(_, ref dentry) = addr else { todo!() }; create_keyable_inode(dentry) }; + let new_backlog = Arc::new(Backlog::new(addr, backlog)); let mut backlog_sockets = self.backlog_sockets.write(); if backlog_sockets.contains_key(&inode) { - return_errno_with_message!(Errno::EADDRINUSE, "the addr is already used"); + return None; } - let new_backlog = Arc::new(Backlog::new(addr, backlog)); backlog_sockets.insert(inode, new_backlog.clone()); - Ok(new_backlog) + + Some(new_backlog) } fn get_backlog(&self, addr: &UnixSocketAddrBound) -> Result> { diff --git a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs index eb9337296..53c51ff26 100644 --- a/kernel/aster-nix/src/net/socket/unix/stream/socket.rs +++ b/kernel/aster-nix/src/net/socket/unix/stream/socket.rs @@ -232,7 +232,9 @@ impl Socket for UnixStreamSocket { } fn listen(&self, backlog: usize) -> Result<()> { - let addr = match &*self.state.read() { + let mut state = self.state.write(); + + let addr = match &*state { State::Init(init) => init .addr() .ok_or(Error::with_message( @@ -248,8 +250,9 @@ impl Socket for UnixStreamSocket { } }; - let listener = Listener::new(addr, backlog)?; - *self.state.write() = State::Listen(listener); + let listener = Listener::new(addr, backlog); + *state = State::Listen(listener); + Ok(()) }