diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs index 0ad89de04..b250a961a 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs @@ -112,24 +112,25 @@ impl Socket for StreamSocket { fn connect(&self, sockaddr: SocketAddr) -> Result<()> { let remote_endpoint = sockaddr.try_into()?; - let state = self.state.read(); - match &*state { - State::Init(init_stream) => { - let init_stream = init_stream.clone(); - drop(state); - init_stream.connect(&remote_endpoint)?; - let nonblocking = init_stream.is_nonblocking(); - let bound_socket = init_stream.bound_socket().unwrap(); - let connected_stream = Arc::new(ConnectedStream::new( - nonblocking, - bound_socket, - remote_endpoint, - )); - *self.state.write() = State::Connected(connected_stream); - Ok(()) - } + + let init_stream = match &*self.state.read() { + State::Init(init_stream) => init_stream.clone(), _ => return_errno_with_message!(Errno::EINVAL, "cannot connect"), - } + }; + + init_stream.connect(&remote_endpoint)?; + + let connected_stream = { + let nonblocking = init_stream.is_nonblocking(); + let bound_socket = init_stream.bound_socket().unwrap(); + Arc::new(ConnectedStream::new( + nonblocking, + bound_socket, + remote_endpoint, + )) + }; + *self.state.write() = State::Connected(connected_stream); + Ok(()) } fn listen(&self, backlog: usize) -> Result<()> { @@ -153,17 +154,23 @@ impl Socket for StreamSocket { } fn accept(&self) -> Result<(Arc, SocketAddr)> { - let state = self.state.read(); - match &*state { - State::Listen(listen_stream) => { - let (connected_stream, remote_endpoint) = listen_stream.accept()?; - let state = RwLock::new(State::Connected(Arc::new(connected_stream))); - let accepted_socket = Arc::new(StreamSocket { state }); - let socket_addr = remote_endpoint.try_into()?; - Ok((accepted_socket, socket_addr)) - } + let listen_stream = match &*self.state.read() { + State::Listen(listen_stream) => listen_stream.clone(), _ => return_errno_with_message!(Errno::EINVAL, "the socket is not listening"), - } + }; + + let (connected_stream, remote_endpoint) = { + let listen_stream = listen_stream.clone(); + listen_stream.accept()? + }; + + let accepted_socket = { + let state = RwLock::new(State::Connected(Arc::new(connected_stream))); + Arc::new(StreamSocket { state }) + }; + + let socket_addr = remote_endpoint.try_into()?; + Ok((accepted_socket, socket_addr)) } fn shutdown(&self, cmd: SockShutdownCmd) -> Result<()> { @@ -202,11 +209,12 @@ impl Socket for StreamSocket { } fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { - let state = self.state.read(); - let (recv_size, remote_endpoint) = match &*state { - State::Connected(connected_stream) => connected_stream.recvfrom(buf, flags), - _ => return_errno_with_message!(Errno::EINVAL, "cannot recv"), - }?; + let connected_stream = match &*self.state.read() { + State::Connected(connected_stream) => connected_stream.clone(), + _ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"), + }; + + let (recv_size, remote_endpoint) = connected_stream.recvfrom(buf, flags)?; let socket_addr = remote_endpoint.try_into()?; Ok((recv_size, socket_addr)) }