From f6c230f7567cbc515d65d1ba66805b6cd6088b04 Mon Sep 17 00:00:00 2001 From: Ruihan Li Date: Thu, 2 Nov 2023 14:00:50 +0800 Subject: [PATCH] Fix endless loops when send buffer is full --- regression/apps/network/send_buf_full.c | 276 ++++++++++++++++++ .../src/net/socket/ip/stream/connected.rs | 40 ++- .../jinux-std/src/net/socket/ip/stream/mod.rs | 11 +- 3 files changed, 313 insertions(+), 14 deletions(-) create mode 100644 regression/apps/network/send_buf_full.c diff --git a/regression/apps/network/send_buf_full.c b/regression/apps/network/send_buf_full.c new file mode 100644 index 000000000..f1f431f5a --- /dev/null +++ b/regression/apps/network/send_buf_full.c @@ -0,0 +1,276 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int new_bound_socket(struct sockaddr_in *addr) +{ + int sockfd; + + sockfd = socket(PF_INET, SOCK_STREAM, 0); + if (sockfd < 0) { + perror("new_bound_socket: socket"); + return -1; + } + + if (bind(sockfd, (struct sockaddr *)addr, sizeof(*addr)) < 0) { + perror("new_bound_socket: bind"); + close(sockfd); + return -1; + } + + return sockfd; +} + +static int new_connected_socket(struct sockaddr_in *addr) +{ + int sockfd; + + sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (sockfd < 0) { + perror("new_connected_socket: socket"); + return -1; + } + + if (connect(sockfd, (struct sockaddr *)addr, sizeof(*addr)) < 0) { + perror("new_connected_socket: connect"); + close(sockfd); + return -1; + } + + return sockfd; +} + +static int accept_without_addr(int sockfd) +{ + struct sockaddr addr; + socklen_t addrlen = sizeof(addr); + int acceptfd; + + acceptfd = accept(sockfd, &addr, &addrlen); + if (acceptfd < 0) { + perror("accept_without_addr: accept"); + return -1; + } + + return acceptfd; +} + +static int mark_filde_nonblock(int flide) +{ + int flags; + + flags = fcntl(flide, F_GETFL, 0); + if (flags < 0) { + perror("mark_filde_nonblock: fcntl(F_GETFL)"); + return -1; + } + + if (fcntl(flide, F_SETFL, flags | O_NONBLOCK) < 0) { + perror("mark_filde_nonblock: fcntl(F_SETFL)"); + return -1; + } + + return 0; +} + +static int mark_filde_mayblock(int flide) +{ + int flags; + + flags = fcntl(flide, F_GETFL, 0); + if (flags < 0) { + perror("mark_filde_mayblock: fcntl(F_GETFL)"); + return -1; + } + + if (fcntl(flide, F_SETFL, flags & ~O_NONBLOCK) < 0) { + perror("mark_filde_mayblock: fcntl(F_SETFL)"); + return -1; + } + + return 0; +} + +static char buffer[4096] = "Hello, world"; + +static ssize_t receive_all(int sockfd) +{ + size_t recv_len = 0; + ssize_t ret; + + if (mark_filde_nonblock(sockfd) < 0) { + perror("receive_all: mark_filde_nonblock"); + return -1; + } + + for (;;) { + ret = recv(sockfd, buffer, sizeof(buffer), 0); + if (ret < 0 && errno == EAGAIN) + break; + + if (ret < 0) { + perror("receive_all: recv"); + return -1; + } + + recv_len += ret; + } + + return recv_len; +} + +int test_full_send_buffer(struct sockaddr_in *addr) +{ + int listenfd, sendfd, recvfd; + int ret = -1, wstatus; + size_t sent_len = 0; + ssize_t sent; + int pid; + + listenfd = new_bound_socket(addr); + if (listenfd < 0) { + fprintf(stderr, + "Test failed: Error occurs in new_bound_socket\n"); + return -1; + } + + if (listen(listenfd, 2) < 0) { + perror("listen"); + fprintf(stderr, "Test failed: Error occurs in listen\n"); + goto out_listen; + } + + sendfd = new_connected_socket(addr); + if (sendfd < 0) { + fprintf(stderr, + "Test failed: Error occurs in new_connected_socket\n"); + goto out_listen; + } + + recvfd = accept_without_addr(listenfd); + if (recvfd < 0) { + fprintf(stderr, + "Test failed: Error occurs in accept_without_addr\n"); + goto out_send; + } + + if (mark_filde_nonblock(sendfd) < 0) { + fprintf(stderr, + "Test failed: Error occurs in mark_filde_nonblock\n"); + goto out; + } + + for (;;) { + sent = send(sendfd, buffer, sizeof(buffer), 0); + if (sent < 0 && errno == EAGAIN) + break; + + if (sent < 0) { + perror("send"); + fprintf(stderr, "Test failed: Error occurs in send\n"); + goto out; + } + + sent_len += sent; + } + + if (mark_filde_mayblock(sendfd) < 0) { + fprintf(stderr, + "Test failed: Error occurs in mark_filde_mayblock\n"); + goto out; + } + + pid = fork(); + if (pid < 0) { + perror("fork"); + fprintf(stderr, "Test failed: Error occurs in fork\n"); + goto out; + } + + if (pid == 0) { + int i; + ssize_t recv_len; + + // Ensure that the parent executes send() first, then the child + // executes recv(). + for (i = 0; i < 10; ++i) + sched_yield(); + + fprintf(stderr, "Start receiving...\n"); + recv_len = receive_all(recvfd); + if (recv_len < 0) { + fprintf(stderr, + "Test failed: Error occurs in receive_all\n"); + goto out; + } + + fprintf(stderr, "Received bytes: %lu\n", recv_len); + if (recv_len != sent_len + 1) { + fprintf(stderr, + "Test failed: Mismatched sent bytes and received bytes\n"); + goto out; + } + + ret = 0; + goto out; + } + + sent = send(sendfd, buffer, 1, 0); + if (sent < 0) { + perror("send"); + fprintf(stderr, "Test failed: Error occurs in send\n"); + goto wait; + } + + sent_len += 1; + fprintf(stderr, "Sent bytes: %lu\n", sent_len); + + ret = 0; + +wait: + if (wait(&wstatus) < 0) { + perror("wait"); + fprintf(stderr, "Test failed: Error occurs in wait\n"); + ret = -1; + } else if (WEXITSTATUS(wstatus) != 0) { + fprintf(stderr, "Test failed: Error occurs in child process\n"); + ret = -1; + } + + if (ret == 0) + fprintf(stderr, + "Test passed: Equal sent bytes and received bytes\n"); + +out: + close(recvfd); + +out_send: + close(sendfd); + +out_listen: + close(listenfd); + + return ret; +} + +int main(void) +{ + struct sockaddr_in addr; + int backlog; + int err = 0; + + addr.sin_family = AF_INET; + addr.sin_port = htons(8080); + if (inet_aton("127.0.0.1", &addr.sin_addr) < 0) { + fprintf(stderr, "inet_aton cannot parse 127.0.0.1\n"); + return -1; + } + + return test_full_send_buffer(&addr); +} diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs index 14d21d762..55b8aceeb 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/connected.rs @@ -77,21 +77,43 @@ impl ConnectedStream { pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { debug_assert!(flags.is_all_supported()); - let mut sent_len = 0; - let buf_len = buf.len(); + + let poller = Poller::new(); loop { - let len = self - .bound_socket - .raw_with(|socket: &mut RawTcpSocket| socket.send_slice(&buf[sent_len..])) - .map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?; - poll_ifaces(); - sent_len += len; - if sent_len == buf_len { + let sent_len = self.try_sendto(buf, flags)?; + if sent_len > 0 { return Ok(sent_len); } + let events = self.bound_socket.poll(IoEvents::OUT, Some(&poller)); + if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) { + return_errno_with_message!(Errno::ENOBUFS, "fail to send packets"); + } + if !events.contains(IoEvents::OUT) { + if self.is_nonblocking() { + return_errno_with_message!(Errno::EAGAIN, "try to send again"); + } + // FIXME: deal with send timeout + poller.wait()?; + } } } + fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result { + let res = self + .bound_socket + .raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf)) + .map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet")); + match res { + // We have to explicitly invoke `update_socket_state` when the send buffer becomes + // full. Note that smoltcp does not think it is an interface event, so calling + // `poll_ifaces` alone is not enough. + Ok(0) => self.bound_socket.update_socket_state(), + Ok(_) => poll_ifaces(), + _ => (), + }; + res + } + pub fn local_endpoint(&self) -> Result { self.bound_socket .local_endpoint() diff --git a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs index 2abb76fa2..b34af7488 100644 --- a/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs +++ b/services/libs/jinux-std/src/net/socket/ip/stream/mod.rs @@ -228,10 +228,11 @@ impl Socket for StreamSocket { if remote.is_some() { return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr"); } - let state = self.state.read(); - match &*state { - State::Connected(connected_stream) => connected_stream.sendto(buf, flags), - _ => return_errno_with_message!(Errno::EINVAL, "cannot send"), - } + + let connected_stream = match &*self.state.read() { + State::Connected(connected_stream) => connected_stream.clone(), + _ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"), + }; + connected_stream.sendto(buf, flags) } }