mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-26 10:53:25 +00:00
Use Arc
to pack bound datagram structure
For TCP streams we have packed their different states with `Arc`, e.g. `InitStream`, `ConnectedStream`, and `ListenStream`. Later we will implement an trait that observes iface events for some of these states. For UDP datagrams, they can be in the unbound state and the bound state. If we want to implement the observer trait for the bound state, we need to wrap it with `Arc` so that it can be registered with `AnyBoundSocket`. Alternatively, we can implement the observer trait directly on the UDP datagram structure (i.e. `DatagramSocket`). However, there are literally no events to handle if the socket is not bound at all (i.e. it is in the unbound state). So a more efficient way is to implement the observer trait only for the bound state, which motivates changes in this commit.
This commit is contained in:
committed by
Tate, Hongliang Tian
parent
246d8521f2
commit
02e10705af
@ -28,10 +28,78 @@ pub struct DatagramSocket {
|
||||
|
||||
enum Inner {
|
||||
Unbound(AlwaysSome<AnyUnboundSocket>),
|
||||
Bound {
|
||||
Bound(Arc<BoundDatagram>),
|
||||
}
|
||||
|
||||
struct BoundDatagram {
|
||||
bound_socket: Arc<AnyBoundSocket>,
|
||||
remote_endpoint: Option<IpEndpoint>,
|
||||
},
|
||||
remote_endpoint: RwLock<Option<IpEndpoint>>,
|
||||
}
|
||||
|
||||
impl BoundDatagram {
|
||||
fn new(bound_socket: Arc<AnyBoundSocket>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
bound_socket,
|
||||
remote_endpoint: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
fn remote_endpoint(&self) -> Result<IpEndpoint> {
|
||||
if let Some(endpoint) = *self.remote_endpoint.read() {
|
||||
Ok(endpoint)
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "remote endpoint is not specified")
|
||||
}
|
||||
}
|
||||
|
||||
fn set_remote_endpoint(&self, endpoint: IpEndpoint) {
|
||||
*self.remote_endpoint.write() = Some(endpoint);
|
||||
}
|
||||
|
||||
fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
if let Some(endpoint) = self.bound_socket.local_endpoint() {
|
||||
Ok(endpoint)
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint")
|
||||
}
|
||||
}
|
||||
|
||||
fn try_recvfrom(&self, buf: &mut [u8], flags: &SendRecvFlags) -> Result<(usize, IpEndpoint)> {
|
||||
poll_ifaces();
|
||||
let recv_slice = |socket: &mut RawUdpSocket| match socket.recv_slice(buf) {
|
||||
Err(smoltcp::socket::udp::RecvError::Exhausted) => {
|
||||
return_errno_with_message!(Errno::EAGAIN, "recv buf is empty")
|
||||
}
|
||||
Ok((len, remote_endpoint)) => Ok((len, remote_endpoint)),
|
||||
};
|
||||
self.bound_socket.raw_with(recv_slice)
|
||||
}
|
||||
|
||||
fn try_sendto(
|
||||
&self,
|
||||
buf: &[u8],
|
||||
remote: Option<IpEndpoint>,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
let remote_endpoint = remote
|
||||
.or_else(|| self.remote_endpoint().ok())
|
||||
.ok_or_else(|| Error::with_message(Errno::EINVAL, "udp should provide remote addr"))?;
|
||||
let send_slice = |socket: &mut RawUdpSocket| match socket.send_slice(buf, remote_endpoint) {
|
||||
Err(_) => return_errno_with_message!(Errno::ENOBUFS, "send udp packet fails"),
|
||||
Ok(()) => Ok(buf.len()),
|
||||
};
|
||||
let len = self.bound_socket.raw_with(send_slice)?;
|
||||
poll_ifaces();
|
||||
Ok(len)
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
self.bound_socket.poll(mask, poller)
|
||||
}
|
||||
|
||||
fn update_socket_state(&self) {
|
||||
self.bound_socket.update_socket_state();
|
||||
}
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
@ -39,7 +107,7 @@ impl Inner {
|
||||
matches!(self, Inner::Bound { .. })
|
||||
}
|
||||
|
||||
fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> {
|
||||
fn bind(&mut self, endpoint: IpEndpoint) -> Result<Arc<BoundDatagram>> {
|
||||
if self.is_bound() {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address");
|
||||
}
|
||||
@ -55,69 +123,25 @@ impl Inner {
|
||||
.bind(bound_endpoint)
|
||||
.map_err(|_| Error::with_message(Errno::EINVAL, "cannot bind socket"))
|
||||
})?;
|
||||
*self = Inner::Bound {
|
||||
bound_socket,
|
||||
remote_endpoint: None,
|
||||
};
|
||||
let bound = BoundDatagram::new(bound_socket);
|
||||
*self = Inner::Bound(bound.clone());
|
||||
// Once the socket is bound, we should update the socket state at once.
|
||||
self.update_socket_state();
|
||||
Ok(())
|
||||
bound.update_socket_state();
|
||||
Ok(bound)
|
||||
}
|
||||
|
||||
fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> {
|
||||
fn bind_to_ephemeral_endpoint(
|
||||
&mut self,
|
||||
remote_endpoint: &IpEndpoint,
|
||||
) -> Result<Arc<BoundDatagram>> {
|
||||
let endpoint = get_ephemeral_endpoint(remote_endpoint);
|
||||
self.bind(endpoint)
|
||||
}
|
||||
|
||||
fn set_remote_endpoint(&mut self, endpoint: IpEndpoint) -> Result<()> {
|
||||
if let Inner::Bound {
|
||||
remote_endpoint, ..
|
||||
} = self
|
||||
{
|
||||
*remote_endpoint = Some(endpoint);
|
||||
Ok(())
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "the socket is not bound");
|
||||
}
|
||||
}
|
||||
|
||||
fn remote_endpoint(&self) -> Option<IpEndpoint> {
|
||||
if let Inner::Bound {
|
||||
remote_endpoint, ..
|
||||
} = self
|
||||
{
|
||||
*remote_endpoint
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn local_endpoint(&self) -> Option<IpEndpoint> {
|
||||
if let Inner::Bound { bound_socket, .. } = self {
|
||||
bound_socket.local_endpoint()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn bound_socket(&self) -> Option<Arc<AnyBoundSocket>> {
|
||||
if let Inner::Bound { bound_socket, .. } = self {
|
||||
Some(bound_socket.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents {
|
||||
match self {
|
||||
Inner::Unbound(unbound_socket) => unbound_socket.poll(mask, poller),
|
||||
Inner::Bound { bound_socket, .. } => bound_socket.poll(mask, poller),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_socket_state(&self) {
|
||||
if let Inner::Bound { bound_socket, .. } = self {
|
||||
bound_socket.update_socket_state();
|
||||
Inner::Bound(bound) => bound.poll(mask, poller),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -135,28 +159,6 @@ impl DatagramSocket {
|
||||
self.inner.read().is_bound()
|
||||
}
|
||||
|
||||
fn try_recvfrom(&self, buf: &mut [u8], flags: &SendRecvFlags) -> Result<(usize, IpEndpoint)> {
|
||||
poll_ifaces();
|
||||
let bound_socket = self.inner.read().bound_socket().unwrap();
|
||||
let recv_slice = |socket: &mut RawUdpSocket| match socket.recv_slice(buf) {
|
||||
Err(smoltcp::socket::udp::RecvError::Exhausted) => {
|
||||
return_errno_with_message!(Errno::EAGAIN, "recv buf is empty")
|
||||
}
|
||||
Ok((len, remote_endpoint)) => Ok((len, remote_endpoint)),
|
||||
};
|
||||
bound_socket.raw_with(recv_slice)
|
||||
}
|
||||
|
||||
fn remote_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.inner
|
||||
.read()
|
||||
.remote_endpoint()
|
||||
.ok_or(Error::with_message(
|
||||
Errno::EINVAL,
|
||||
"udp should provide remote addr",
|
||||
))
|
||||
}
|
||||
|
||||
pub fn is_nonblocking(&self) -> bool {
|
||||
self.nonblocking.load(Ordering::SeqCst)
|
||||
}
|
||||
@ -164,6 +166,27 @@ impl DatagramSocket {
|
||||
pub fn set_nonblocking(&self, nonblocking: bool) {
|
||||
self.nonblocking.store(nonblocking, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
fn bound(&self) -> Result<Arc<BoundDatagram>> {
|
||||
if let Inner::Bound(bound) = &*self.inner.read() {
|
||||
Ok(bound.clone())
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint")
|
||||
}
|
||||
}
|
||||
|
||||
fn try_bind_empheral(&self, remote_endpoint: &IpEndpoint) -> Result<Arc<BoundDatagram>> {
|
||||
if let Inner::Bound(bound) = &*self.inner.read() {
|
||||
return Ok(bound.clone());
|
||||
}
|
||||
|
||||
let mut inner = self.inner.write();
|
||||
if let Inner::Bound(bound) = &*inner {
|
||||
Ok(bound.clone())
|
||||
} else {
|
||||
inner.bind_to_ephemeral_endpoint(remote_endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FileLike for DatagramSocket {
|
||||
@ -209,50 +232,36 @@ impl FileLike for DatagramSocket {
|
||||
impl Socket for DatagramSocket {
|
||||
fn bind(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||
let endpoint = sockaddr.try_into()?;
|
||||
self.inner.write().bind(endpoint)
|
||||
self.inner.write().bind(endpoint)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn connect(&self, sockaddr: SocketAddr) -> Result<()> {
|
||||
let remote_endpoint: IpEndpoint = sockaddr.try_into()?;
|
||||
let mut inner = self.inner.write();
|
||||
if !self.is_bound() {
|
||||
inner.bind_to_ephemeral_endpoint(&remote_endpoint)?
|
||||
}
|
||||
inner.set_remote_endpoint(remote_endpoint)?;
|
||||
inner.update_socket_state();
|
||||
let endpoint = sockaddr.try_into()?;
|
||||
let bound = self.try_bind_empheral(&endpoint)?;
|
||||
bound.set_remote_endpoint(endpoint);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn addr(&self) -> Result<SocketAddr> {
|
||||
if let Some(local_endpoint) = self.inner.read().local_endpoint() {
|
||||
local_endpoint.try_into()
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint");
|
||||
}
|
||||
self.bound()?.local_endpoint()?.try_into()
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
if let Some(remote_endpoint) = self.inner.read().remote_endpoint() {
|
||||
remote_endpoint.try_into()
|
||||
} else {
|
||||
return_errno_with_message!(Errno::EINVAL, "remote endpoint is not specified");
|
||||
}
|
||||
self.bound()?.remote_endpoint()?.try_into()
|
||||
}
|
||||
|
||||
// FIXME: respect RecvFromFlags
|
||||
fn recvfrom(&self, buf: &mut [u8], flags: SendRecvFlags) -> Result<(usize, SocketAddr)> {
|
||||
debug_assert!(flags.is_all_supported());
|
||||
if !self.is_bound() {
|
||||
return_errno_with_message!(Errno::EINVAL, "socket does not bind to local endpoint");
|
||||
}
|
||||
let bound = self.bound()?;
|
||||
let poller = Poller::new();
|
||||
let bound_socket = self.inner.read().bound_socket().unwrap();
|
||||
loop {
|
||||
if let Ok((recv_len, remote_endpoint)) = self.try_recvfrom(buf, &flags) {
|
||||
if let Ok((recv_len, remote_endpoint)) = bound.try_recvfrom(buf, &flags) {
|
||||
let remote_addr = remote_endpoint.try_into()?;
|
||||
return Ok((recv_len, remote_addr));
|
||||
}
|
||||
let events = self.inner.read().poll(IoEvents::IN, Some(&poller));
|
||||
let events = bound.poll(IoEvents::IN, Some(&poller));
|
||||
if !events.contains(IoEvents::IN) {
|
||||
if self.is_nonblocking() {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try to receive again");
|
||||
@ -269,23 +278,14 @@ impl Socket for DatagramSocket {
|
||||
remote: Option<SocketAddr>,
|
||||
flags: SendRecvFlags,
|
||||
) -> Result<usize> {
|
||||
let remote_endpoint: IpEndpoint = if let Some(remote_addr) = remote {
|
||||
remote_addr.try_into()?
|
||||
debug_assert!(flags.is_all_supported());
|
||||
let (bound, remote_endpoint) = if let Some(addr) = remote {
|
||||
let endpoint = addr.try_into()?;
|
||||
(self.try_bind_empheral(&endpoint)?, Some(endpoint))
|
||||
} else {
|
||||
self.remote_endpoint()?
|
||||
let bound = self.bound()?;
|
||||
(bound, None)
|
||||
};
|
||||
if !self.is_bound() {
|
||||
self.inner
|
||||
.write()
|
||||
.bind_to_ephemeral_endpoint(&remote_endpoint)?;
|
||||
}
|
||||
let bound_socket = self.inner.read().bound_socket().unwrap();
|
||||
let send_slice = |socket: &mut RawUdpSocket| match socket.send_slice(buf, remote_endpoint) {
|
||||
Err(_) => return_errno_with_message!(Errno::ENOBUFS, "send udp packet fails"),
|
||||
Ok(()) => Ok(buf.len()),
|
||||
};
|
||||
let len = bound_socket.raw_with(send_slice)?;
|
||||
poll_ifaces();
|
||||
Ok(len)
|
||||
bound.try_sendto(buf, remote_endpoint, flags)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user