diff --git a/services/libs/jinux-std/src/net/socket/ip/datagram.rs b/services/libs/jinux-std/src/net/socket/ip/datagram.rs index 939743ab8..718e29795 100644 --- a/services/libs/jinux-std/src/net/socket/ip/datagram.rs +++ b/services/libs/jinux-std/src/net/socket/ip/datagram.rs @@ -28,10 +28,78 @@ pub struct DatagramSocket { enum Inner { Unbound(AlwaysSome), - Bound { - bound_socket: Arc, - remote_endpoint: Option, - }, + Bound(Arc), +} + +struct BoundDatagram { + bound_socket: Arc, + remote_endpoint: RwLock>, +} + +impl BoundDatagram { + fn new(bound_socket: Arc) -> Arc { + Arc::new(Self { + bound_socket, + remote_endpoint: RwLock::new(None), + }) + } + + fn remote_endpoint(&self) -> Result { + if let Some(endpoint) = *self.remote_endpoint.read() { + Ok(endpoint) + } else { + return_errno_with_message!(Errno::EINVAL, "remote endpoint is not specified") + } + } + + fn set_remote_endpoint(&self, endpoint: IpEndpoint) { + *self.remote_endpoint.write() = Some(endpoint); + } + + fn local_endpoint(&self) -> Result { + if let Some(endpoint) = self.bound_socket.local_endpoint() { + Ok(endpoint) + } else { + return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint") + } + } + + fn try_recvfrom(&self, buf: &mut [u8], flags: &SendRecvFlags) -> Result<(usize, IpEndpoint)> { + poll_ifaces(); + let recv_slice = |socket: &mut RawUdpSocket| match socket.recv_slice(buf) { + Err(smoltcp::socket::udp::RecvError::Exhausted) => { + return_errno_with_message!(Errno::EAGAIN, "recv buf is empty") + } + Ok((len, remote_endpoint)) => Ok((len, remote_endpoint)), + }; + self.bound_socket.raw_with(recv_slice) + } + + fn try_sendto( + &self, + buf: &[u8], + remote: Option, + flags: SendRecvFlags, + ) -> Result { + let remote_endpoint = remote + .or_else(|| self.remote_endpoint().ok()) + .ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?; + let send_slice = |socket: &mut RawUdpSocket| match socket.send_slice(buf, remote_endpoint) { + Err(_) => return_errno_with_message!(Errno::ENOBUFS, "send udp packet fails"), + Ok(()) => Ok(buf.len()), + }; + let len = self.bound_socket.raw_with(send_slice)?; + poll_ifaces(); + Ok(len) + } + + fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.bound_socket.poll(mask, poller) + } + + fn update_socket_state(&self) { + self.bound_socket.update_socket_state(); + } } impl Inner { @@ -39,7 +107,7 @@ impl Inner { matches!(self, Inner::Bound { .. }) } - fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> { + fn bind(&mut self, endpoint: IpEndpoint) -> Result> { if self.is_bound() { return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address"); } @@ -55,69 +123,25 @@ impl Inner { .bind(bound_endpoint) .map_err(|_| Error::with_message(Errno::EINVAL, "cannot bind socket")) })?; - *self = Inner::Bound { - bound_socket, - remote_endpoint: None, - }; + let bound = BoundDatagram::new(bound_socket); + *self = Inner::Bound(bound.clone()); // Once the socket is bound, we should update the socket state at once. - self.update_socket_state(); - Ok(()) + bound.update_socket_state(); + Ok(bound) } - fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> { + fn bind_to_ephemeral_endpoint( + &mut self, + remote_endpoint: &IpEndpoint, + ) -> Result> { let endpoint = get_ephemeral_endpoint(remote_endpoint); self.bind(endpoint) } - fn set_remote_endpoint(&mut self, endpoint: IpEndpoint) -> Result<()> { - if let Inner::Bound { - remote_endpoint, .. - } = self - { - *remote_endpoint = Some(endpoint); - Ok(()) - } else { - return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); - } - } - - fn remote_endpoint(&self) -> Option { - if let Inner::Bound { - remote_endpoint, .. - } = self - { - *remote_endpoint - } else { - None - } - } - - fn local_endpoint(&self) -> Option { - if let Inner::Bound { bound_socket, .. } = self { - bound_socket.local_endpoint() - } else { - None - } - } - - fn bound_socket(&self) -> Option> { - if let Inner::Bound { bound_socket, .. } = self { - Some(bound_socket.clone()) - } else { - None - } - } - fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { match self { Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller), - Inner::Bound { bound_socket, .. } => bound_socket.poll(mask, poller), - } - } - - fn update_socket_state(&self) { - if let Inner::Bound { bound_socket, .. } = self { - bound_socket.update_socket_state(); + Inner::Bound(bound) => bound.poll(mask, poller), } } } @@ -135,28 +159,6 @@ impl DatagramSocket { self.inner.read().is_bound() } - fn try_recvfrom(&self, buf: &mut [u8], flags: &SendRecvFlags) -> Result<(usize, IpEndpoint)> { - poll_ifaces(); - let bound_socket = self.inner.read().bound_socket().unwrap(); - let recv_slice = |socket: &mut RawUdpSocket| match socket.recv_slice(buf) { - Err(smoltcp::socket::udp::RecvError::Exhausted) => { - return_errno_with_message!(Errno::EAGAIN, "recv buf is empty") - } - Ok((len, remote_endpoint)) => Ok((len, remote_endpoint)), - }; - bound_socket.raw_with(recv_slice) - } - - fn remote_endpoint(&self) -> Result { - self.inner - .read() - .remote_endpoint() - .ok_or(Error::with_message( - Errno::EINVAL, - "udp should provide remote addr", - )) - } - pub fn is_nonblocking(&self) -> bool { self.nonblocking.load(Ordering::SeqCst) } @@ -164,6 +166,27 @@ impl DatagramSocket { pub fn set_nonblocking(&self, nonblocking: bool) { self.nonblocking.store(nonblocking, Ordering::SeqCst); } + + fn bound(&self) -> Result> { + if let Inner::Bound(bound) = &*self.inner.read() { + Ok(bound.clone()) + } else { + return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint") + } + } + + fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result> { + if let Inner::Bound(bound) = &*self.inner.read() { + return Ok(bound.clone()); + } + + let mut inner = self.inner.write(); + if let Inner::Bound(bound) = &*inner { + Ok(bound.clone()) + } else { + inner.bind_to_ephemeral_endpoint(remote_endpoint) + } + } } impl FileLike for DatagramSocket { @@ -209,50 +232,36 @@ impl FileLike for DatagramSocket { impl Socket for DatagramSocket { fn bind(&self, sockaddr: SocketAddr) -> Result<()> { let endpoint = sockaddr.try_into()?; - self.inner.write().bind(endpoint) + self.inner.write().bind(endpoint)?; + Ok(()) } fn connect(&self, sockaddr: SocketAddr) -> Result<()> { - let remote_endpoint: IpEndpoint = sockaddr.try_into()?; - let mut inner = self.inner.write(); - if !self.is_bound() { - inner.bind_to_ephemeral_endpoint(&remote_endpoint)? - } - inner.set_remote_endpoint(remote_endpoint)?; - inner.update_socket_state(); + let endpoint = sockaddr.try_into()?; + let bound = self.try_bind_empheral(&endpoint)?; + bound.set_remote_endpoint(endpoint); Ok(()) } fn addr(&self) -> Result { - if let Some(local_endpoint) = self.inner.read().local_endpoint() { - local_endpoint.try_into() - } else { - return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint"); - } + self.bound()?.local_endpoint()?.try_into() } fn peer_addr(&self) -> Result { - if let Some(remote_endpoint) = self.inner.read().remote_endpoint() { - remote_endpoint.try_into() - } else { - return_errno_with_message!(Errno::EINVAL, "remote endpoint is not specified"); - } + self.bound()?.remote_endpoint()?.try_into() } // FIXME: respect RecvFromFlags fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { debug_assert!(flags.is_all_supported()); - if !self.is_bound() { - return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint"); - } + let bound = self.bound()?; let poller = Poller::new(); - let bound_socket = self.inner.read().bound_socket().unwrap(); loop { - if let Ok((recv_len, remote_endpoint)) = self.try_recvfrom(buf, &flags) { + if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) { let remote_addr = remote_endpoint.try_into()?; return Ok((recv_len, remote_addr)); } - let events = self.inner.read().poll(IoEvents::IN, Some(&poller)); + let events = bound.poll(IoEvents::IN, Some(&poller)); if !events.contains(IoEvents::IN) { if self.is_nonblocking() { return_errno_with_message!(Errno::EAGAIN, "try to receive again"); @@ -269,23 +278,14 @@ impl Socket for DatagramSocket { remote: Option, flags: SendRecvFlags, ) -> Result { - let remote_endpoint: IpEndpoint = if let Some(remote_addr) = remote { - remote_addr.try_into()? + debug_assert!(flags.is_all_supported()); + let (bound, remote_endpoint) = if let Some(addr) = remote { + let endpoint = addr.try_into()?; + (self.try_bind_empheral(&endpoint)?, Some(endpoint)) } else { - self.remote_endpoint()? + let bound = self.bound()?; + (bound, None) }; - if !self.is_bound() { - self.inner - .write() - .bind_to_ephemeral_endpoint(&remote_endpoint)?; - } - let bound_socket = self.inner.read().bound_socket().unwrap(); - let send_slice = |socket: &mut RawUdpSocket| match socket.send_slice(buf, remote_endpoint) { - Err(_) => return_errno_with_message!(Errno::ENOBUFS, "send udp packet fails"), - Ok(()) => Ok(buf.len()), - }; - let len = bound_socket.raw_with(send_slice)?; - poll_ifaces(); - Ok(len) + bound.try_sendto(buf, remote_endpoint, flags) } }