// SPDX-License-Identifier: MPL-2.0 use core::sync::atomic::{AtomicBool, Ordering}; use crate::events::IoEvents; use crate::net::iface::Iface; use crate::net::iface::IpEndpoint; use crate::net::iface::{AnyBoundSocket, AnyUnboundSocket}; use crate::net::socket::ip::always_some::AlwaysSome; use crate::net::socket::ip::common::{bind_socket, get_ephemeral_endpoint}; use crate::prelude::*; use crate::process::signal::Pollee; use crate::process::signal::Poller; use super::connecting::ConnectingStream; use super::listen::ListenStream; pub struct InitStream { inner: RwLock, is_nonblocking: AtomicBool, pollee: Pollee, } enum Inner { Unbound(AlwaysSome>), Bound(AlwaysSome>), } impl Inner { fn new() -> Inner { let unbound_socket = Box::new(AnyUnboundSocket::new_tcp()); Inner::Unbound(AlwaysSome::new(unbound_socket)) } fn is_bound(&self) -> bool { match self { Self::Unbound(_) => false, Self::Bound(_) => true, } } fn bind(&mut self, endpoint: IpEndpoint) -> Result<()> { let unbound_socket = if let Inner::Unbound(unbound_socket) = self { unbound_socket } else { return_errno_with_message!(Errno::EINVAL, "the socket is already bound to an address"); }; let bound_socket = unbound_socket.try_take_with(|raw_socket| bind_socket(raw_socket, endpoint, false))?; *self = Inner::Bound(AlwaysSome::new(bound_socket)); Ok(()) } fn bind_to_ephemeral_endpoint(&mut self, remote_endpoint: &IpEndpoint) -> Result<()> { let endpoint = get_ephemeral_endpoint(remote_endpoint); self.bind(endpoint) } fn bound_socket(&self) -> Option<&Arc> { match self { Inner::Bound(bound_socket) => Some(bound_socket), Inner::Unbound(_) => None, } } fn iface(&self) -> Option> { match self { Inner::Bound(bound_socket) => Some(bound_socket.iface().clone()), Inner::Unbound(_) => None, } } fn local_endpoint(&self) -> Option { self.bound_socket() .and_then(|socket| socket.local_endpoint()) } } impl InitStream { // FIXME: In Linux we have the `POLLOUT` event for a newly created socket, while calling // `write()` on it triggers `SIGPIPE`/`EPIPE`. No documentation found yet, but confirmed by // experimentation and Linux source code. pub fn new(nonblocking: bool) -> Arc { Arc::new(Self { inner: RwLock::new(Inner::new()), is_nonblocking: AtomicBool::new(nonblocking), pollee: Pollee::new(IoEvents::empty()), }) } pub fn new_bound( nonblocking: bool, bound_socket: Arc, pollee: Pollee, ) -> Arc { bound_socket.set_observer(Weak::<()>::new()); let inner = Inner::Bound(AlwaysSome::new(bound_socket)); Arc::new(Self { is_nonblocking: AtomicBool::new(nonblocking), inner: RwLock::new(inner), pollee, }) } pub fn bind(&self, endpoint: IpEndpoint) -> Result<()> { self.inner.write().bind(endpoint) } pub fn connect(&self, remote_endpoint: &IpEndpoint) -> Result> { if !self.inner.read().is_bound() { self.inner .write() .bind_to_ephemeral_endpoint(remote_endpoint)? } ConnectingStream::new( self.is_nonblocking(), self.inner.read().bound_socket().unwrap().clone(), *remote_endpoint, self.pollee.clone(), ) } pub fn listen(&self, backlog: usize) -> Result> { let bound_socket = if let Some(bound_socket) = self.inner.read().bound_socket() { bound_socket.clone() } else { return_errno_with_message!(Errno::EINVAL, "cannot listen without bound") }; ListenStream::new( self.is_nonblocking(), bound_socket, backlog, self.pollee.clone(), ) } pub fn local_endpoint(&self) -> Result { self.inner .read() .local_endpoint() .ok_or_else(|| Error::with_message(Errno::EINVAL, "does not has local endpoint")) } pub fn poll(&self, mask: IoEvents, poller: Option<&Poller>) -> IoEvents { self.pollee.poll(mask, poller) } pub fn is_nonblocking(&self) -> bool { self.is_nonblocking.load(Ordering::Relaxed) } pub fn set_nonblocking(&self, nonblocking: bool) { self.is_nonblocking.store(nonblocking, Ordering::Relaxed); } }