From 02e10705afae1edfb41e925bf6a9725018daccb7 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Sun, 19 Nov 2023 23:54:49 +0800 Subject: [PATCH] Use `Arc` to pack bound datagram structure For TCP streams we have packed their different states with `Arc`, e.g. `InitStream`, `ConnectedStream`, and `ListenStream`. Later we will implement an trait that observes iface events for some of these states. For UDP datagrams, they can be in the unbound state and the bound state. If we want to implement the observer trait for the bound state, we need to wrap it with `Arc` so that it can be registered with `AnyBoundSocket`. Alternatively, we can implement the observer trait directly on the UDP datagram structure (i.e. `DatagramSocket`). However, there are literally no events to handle if the socket is not bound at all (i.e. it is in the unbound state). So a more efficient way is to implement the observer trait only for the bound state, which motivates changes in this commit. --- .../jinux-std/src/net/socket/ip/datagram.rs | 240 +++++++++--------- 1 file changed, 120 insertions(+), 120 deletions(-) diff --git a/services/libs/jinux-std/src/net/socket/ip/datagram.rs b/services/libs/jinux-std/src/net/socket/ip/datagram.rs index 939743ab8..718e29795 100644 --- a/services/libs/jinux-std/src/net/socket/ip/datagram.rs +++ b/services/libs/jinux-std/src/net/socket/ip/datagram.rs @@ -28,10 +28,78 @@ pub struct DatagramSocket { enum Inner { Unbound(AlwaysSome), - Bound { - bound_socket: Arc, - remote_endpoint: Option, - }, + Bound(Arc), +} + +struct BoundDatagram { + bound_socket: Arc, + remote_endpoint: RwLock>, +} + +impl BoundDatagram { + fn new(bound_socket: Arc) -> Arc { + Arc::new(Self { + bound_socket, + remote_endpoint: RwLock::new(None), + }) + } + + fn remote_endpoint(&self) -> Result { + if let Some(endpoint) = *self.remote_endpoint.read() { + Ok(endpoint) + } else { + return_errno_with_message!(Errno::EINVAL, "remote endpoint is not specified") + } + } + + fn set_remote_endpoint(&self, endpoint: IpEndpoint) { + *self.remote_endpoint.write() = Some(endpoint); + } + + fn local_endpoint(&self) -> Result { + if let Some(endpoint) = self.bound_socket.local_endpoint() { + Ok(endpoint) + } else { + return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint") + } + } + + fn try_recvfrom(&self, buf: &mut [u8], flags: &SendRecvFlags) -> Result<(usize, IpEndpoint)> { + poll_ifaces(); + let recv_slice = |socket: &mut RawUdpSocket| match socket.recv_slice(buf) { + Err(smoltcp::socket::udp::RecvError::Exhausted) => { + return_errno_with_message!(Errno::EAGAIN, "recv buf is empty") + } + Ok((len, remote_endpoint)) => Ok((len, remote_endpoint)), + }; + self.bound_socket.raw_with(recv_slice) + } + + fn try_sendto( + &self, + buf: &[u8], + remote: Option, + flags: SendRecvFlags, + ) -> Result { + let remote_endpoint = remote + .or_else(|| self.remote_endpoint().ok()) + .ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?; + let send_slice = |socket: &mut RawUdpSocket| match socket.send_slice(buf, remote_endpoint) { + Err(_) => return_errno_with_message!(Errno::ENOBUFS, "send udp packet fails"), + Ok(()) => Ok(buf.len()), + }; + let len = self.bound_socket.raw_with(send_slice)?; + poll_ifaces(); + Ok(len) + } + + fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { + self.bound_socket.poll(mask, poller) + } + + fn update_socket_state(&self) { + self.bound_socket.update_socket_state(); + } } impl Inner { @@ -39,7 +107,7 @@ impl Inner { matches!(self, Inner::Bound { .. }) } - fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> { + fn bind(&mut self, endpoint: IpEndpoint) -> Result> { if self.is_bound() { return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address"); } @@ -55,69 +123,25 @@ impl Inner { .bind(bound_endpoint) .map_err(|_| Error::with_message(Errno::EINVAL, "cannot bind socket")) })?; - *self = Inner::Bound { - bound_socket, - remote_endpoint: None, - }; + let bound = BoundDatagram::new(bound_socket); + *self = Inner::Bound(bound.clone()); // Once the socket is bound, we should update the socket state at once. - self.update_socket_state(); - Ok(()) + bound.update_socket_state(); + Ok(bound) } - fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> { + fn bind_to_ephemeral_endpoint( + &mut self, + remote_endpoint: &IpEndpoint, + ) -> Result> { let endpoint = get_ephemeral_endpoint(remote_endpoint); self.bind(endpoint) } - fn set_remote_endpoint(&mut self, endpoint: IpEndpoint) -> Result<()> { - if let Inner::Bound { - remote_endpoint, .. - } = self - { - *remote_endpoint = Some(endpoint); - Ok(()) - } else { - return_errno_with_message!(Errno::EINVAL, "the socket is not bound"); - } - } - - fn remote_endpoint(&self) -> Option { - if let Inner::Bound { - remote_endpoint, .. - } = self - { - *remote_endpoint - } else { - None - } - } - - fn local_endpoint(&self) -> Option { - if let Inner::Bound { bound_socket, .. } = self { - bound_socket.local_endpoint() - } else { - None - } - } - - fn bound_socket(&self) -> Option> { - if let Inner::Bound { bound_socket, .. } = self { - Some(bound_socket.clone()) - } else { - None - } - } - fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { match self { Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller), - Inner::Bound { bound_socket, .. } => bound_socket.poll(mask, poller), - } - } - - fn update_socket_state(&self) { - if let Inner::Bound { bound_socket, .. } = self { - bound_socket.update_socket_state(); + Inner::Bound(bound) => bound.poll(mask, poller), } } } @@ -135,28 +159,6 @@ impl DatagramSocket { self.inner.read().is_bound() } - fn try_recvfrom(&self, buf: &mut [u8], flags: &SendRecvFlags) -> Result<(usize, IpEndpoint)> { - poll_ifaces(); - let bound_socket = self.inner.read().bound_socket().unwrap(); - let recv_slice = |socket: &mut RawUdpSocket| match socket.recv_slice(buf) { - Err(smoltcp::socket::udp::RecvError::Exhausted) => { - return_errno_with_message!(Errno::EAGAIN, "recv buf is empty") - } - Ok((len, remote_endpoint)) => Ok((len, remote_endpoint)), - }; - bound_socket.raw_with(recv_slice) - } - - fn remote_endpoint(&self) -> Result { - self.inner - .read() - .remote_endpoint() - .ok_or(Error::with_message( - Errno::EINVAL, - "udp should provide remote addr", - )) - } - pub fn is_nonblocking(&self) -> bool { self.nonblocking.load(Ordering::SeqCst) } @@ -164,6 +166,27 @@ impl DatagramSocket { pub fn set_nonblocking(&self, nonblocking: bool) { self.nonblocking.store(nonblocking, Ordering::SeqCst); } + + fn bound(&self) -> Result> { + if let Inner::Bound(bound) = &*self.inner.read() { + Ok(bound.clone()) + } else { + return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint") + } + } + + fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result> { + if let Inner::Bound(bound) = &*self.inner.read() { + return Ok(bound.clone()); + } + + let mut inner = self.inner.write(); + if let Inner::Bound(bound) = &*inner { + Ok(bound.clone()) + } else { + inner.bind_to_ephemeral_endpoint(remote_endpoint) + } + } } impl FileLike for DatagramSocket { @@ -209,50 +232,36 @@ impl FileLike for DatagramSocket { impl Socket for DatagramSocket { fn bind(&self, sockaddr: SocketAddr) -> Result<()> { let endpoint = sockaddr.try_into()?; - self.inner.write().bind(endpoint) + self.inner.write().bind(endpoint)?; + Ok(()) } fn connect(&self, sockaddr: SocketAddr) -> Result<()> { - let remote_endpoint: IpEndpoint = sockaddr.try_into()?; - let mut inner = self.inner.write(); - if !self.is_bound() { - inner.bind_to_ephemeral_endpoint(&remote_endpoint)? - } - inner.set_remote_endpoint(remote_endpoint)?; - inner.update_socket_state(); + let endpoint = sockaddr.try_into()?; + let bound = self.try_bind_empheral(&endpoint)?; + bound.set_remote_endpoint(endpoint); Ok(()) } fn addr(&self) -> Result { - if let Some(local_endpoint) = self.inner.read().local_endpoint() { - local_endpoint.try_into() - } else { - return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint"); - } + self.bound()?.local_endpoint()?.try_into() } fn peer_addr(&self) -> Result { - if let Some(remote_endpoint) = self.inner.read().remote_endpoint() { - remote_endpoint.try_into() - } else { - return_errno_with_message!(Errno::EINVAL, "remote endpoint is not specified"); - } + self.bound()?.remote_endpoint()?.try_into() } // FIXME: respect RecvFromFlags fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> { debug_assert!(flags.is_all_supported()); - if !self.is_bound() { - return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint"); - } + let bound = self.bound()?; let poller = Poller::new(); - let bound_socket = self.inner.read().bound_socket().unwrap(); loop { - if let Ok((recv_len, remote_endpoint)) = self.try_recvfrom(buf, &flags) { + if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) { let remote_addr = remote_endpoint.try_into()?; return Ok((recv_len, remote_addr)); } - let events = self.inner.read().poll(IoEvents::IN, Some(&poller)); + let events = bound.poll(IoEvents::IN, Some(&poller)); if !events.contains(IoEvents::IN) { if self.is_nonblocking() { return_errno_with_message!(Errno::EAGAIN, "try to receive again"); @@ -269,23 +278,14 @@ impl Socket for DatagramSocket { remote: Option, flags: SendRecvFlags, ) -> Result { - let remote_endpoint: IpEndpoint = if let Some(remote_addr) = remote { - remote_addr.try_into()? + debug_assert!(flags.is_all_supported()); + let (bound, remote_endpoint) = if let Some(addr) = remote { + let endpoint = addr.try_into()?; + (self.try_bind_empheral(&endpoint)?, Some(endpoint)) } else { - self.remote_endpoint()? + let bound = self.bound()?; + (bound, None) }; - if !self.is_bound() { - self.inner - .write() - .bind_to_ephemeral_endpoint(&remote_endpoint)?; - } - let bound_socket = self.inner.read().bound_socket().unwrap(); - let send_slice = |socket: &mut RawUdpSocket| match socket.send_slice(buf, remote_endpoint) { - Err(_) => return_errno_with_message!(Errno::ENOBUFS, "send udp packet fails"), - Ok(()) => Ok(buf.len()), - }; - let len = bound_socket.raw_with(send_slice)?; - poll_ifaces(); - Ok(len) + bound.try_sendto(buf, remote_endpoint, flags) } }