Fix endless loops when send buffer is full

This commit is contained in:
Ruihan Li 2023-11-02 14:00:50 +08:00 committed by Tate, Hongliang Tian
parent c5d04c41a2
commit f6c230f756
3 changed files with 313 additions and 14 deletions

View File

@ -0,0 +1,276 @@
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <sched.h>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>
static int new_bound_socket(struct sockaddr_in *addr)
{
int sockfd;
sockfd = socket(PF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
perror("new_bound_socket: socket");
return -1;
}
if (bind(sockfd, (struct sockaddr *)addr, sizeof(*addr)) < 0) {
perror("new_bound_socket: bind");
close(sockfd);
return -1;
}
return sockfd;
}
static int new_connected_socket(struct sockaddr_in *addr)
{
int sockfd;
sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
perror("new_connected_socket: socket");
return -1;
}
if (connect(sockfd, (struct sockaddr *)addr, sizeof(*addr)) < 0) {
perror("new_connected_socket: connect");
close(sockfd);
return -1;
}
return sockfd;
}
static int accept_without_addr(int sockfd)
{
struct sockaddr addr;
socklen_t addrlen = sizeof(addr);
int acceptfd;
acceptfd = accept(sockfd, &addr, &addrlen);
if (acceptfd < 0) {
perror("accept_without_addr: accept");
return -1;
}
return acceptfd;
}
static int mark_filde_nonblock(int flide)
{
int flags;
flags = fcntl(flide, F_GETFL, 0);
if (flags < 0) {
perror("mark_filde_nonblock: fcntl(F_GETFL)");
return -1;
}
if (fcntl(flide, F_SETFL, flags | O_NONBLOCK) < 0) {
perror("mark_filde_nonblock: fcntl(F_SETFL)");
return -1;
}
return 0;
}
static int mark_filde_mayblock(int flide)
{
int flags;
flags = fcntl(flide, F_GETFL, 0);
if (flags < 0) {
perror("mark_filde_mayblock: fcntl(F_GETFL)");
return -1;
}
if (fcntl(flide, F_SETFL, flags & ~O_NONBLOCK) < 0) {
perror("mark_filde_mayblock: fcntl(F_SETFL)");
return -1;
}
return 0;
}
static char buffer[4096] = "Hello, world";
static ssize_t receive_all(int sockfd)
{
size_t recv_len = 0;
ssize_t ret;
if (mark_filde_nonblock(sockfd) < 0) {
perror("receive_all: mark_filde_nonblock");
return -1;
}
for (;;) {
ret = recv(sockfd, buffer, sizeof(buffer), 0);
if (ret < 0 && errno == EAGAIN)
break;
if (ret < 0) {
perror("receive_all: recv");
return -1;
}
recv_len += ret;
}
return recv_len;
}
int test_full_send_buffer(struct sockaddr_in *addr)
{
int listenfd, sendfd, recvfd;
int ret = -1, wstatus;
size_t sent_len = 0;
ssize_t sent;
int pid;
listenfd = new_bound_socket(addr);
if (listenfd < 0) {
fprintf(stderr,
"Test failed: Error occurs in new_bound_socket\n");
return -1;
}
if (listen(listenfd, 2) < 0) {
perror("listen");
fprintf(stderr, "Test failed: Error occurs in listen\n");
goto out_listen;
}
sendfd = new_connected_socket(addr);
if (sendfd < 0) {
fprintf(stderr,
"Test failed: Error occurs in new_connected_socket\n");
goto out_listen;
}
recvfd = accept_without_addr(listenfd);
if (recvfd < 0) {
fprintf(stderr,
"Test failed: Error occurs in accept_without_addr\n");
goto out_send;
}
if (mark_filde_nonblock(sendfd) < 0) {
fprintf(stderr,
"Test failed: Error occurs in mark_filde_nonblock\n");
goto out;
}
for (;;) {
sent = send(sendfd, buffer, sizeof(buffer), 0);
if (sent < 0 && errno == EAGAIN)
break;
if (sent < 0) {
perror("send");
fprintf(stderr, "Test failed: Error occurs in send\n");
goto out;
}
sent_len += sent;
}
if (mark_filde_mayblock(sendfd) < 0) {
fprintf(stderr,
"Test failed: Error occurs in mark_filde_mayblock\n");
goto out;
}
pid = fork();
if (pid < 0) {
perror("fork");
fprintf(stderr, "Test failed: Error occurs in fork\n");
goto out;
}
if (pid == 0) {
int i;
ssize_t recv_len;
// Ensure that the parent executes send() first, then the child
// executes recv().
for (i = 0; i < 10; ++i)
sched_yield();
fprintf(stderr, "Start receiving...\n");
recv_len = receive_all(recvfd);
if (recv_len < 0) {
fprintf(stderr,
"Test failed: Error occurs in receive_all\n");
goto out;
}
fprintf(stderr, "Received bytes: %lu\n", recv_len);
if (recv_len != sent_len + 1) {
fprintf(stderr,
"Test failed: Mismatched sent bytes and received bytes\n");
goto out;
}
ret = 0;
goto out;
}
sent = send(sendfd, buffer, 1, 0);
if (sent < 0) {
perror("send");
fprintf(stderr, "Test failed: Error occurs in send\n");
goto wait;
}
sent_len += 1;
fprintf(stderr, "Sent bytes: %lu\n", sent_len);
ret = 0;
wait:
if (wait(&wstatus) < 0) {
perror("wait");
fprintf(stderr, "Test failed: Error occurs in wait\n");
ret = -1;
} else if (WEXITSTATUS(wstatus) != 0) {
fprintf(stderr, "Test failed: Error occurs in child process\n");
ret = -1;
}
if (ret == 0)
fprintf(stderr,
"Test passed: Equal sent bytes and received bytes\n");
out:
close(recvfd);
out_send:
close(sendfd);
out_listen:
close(listenfd);
return ret;
}
int main(void)
{
struct sockaddr_in addr;
int backlog;
int err = 0;
addr.sin_family = AF_INET;
addr.sin_port = htons(8080);
if (inet_aton("127.0.0.1", &addr.sin_addr) < 0) {
fprintf(stderr, "inet_aton cannot parse 127.0.0.1\n");
return -1;
}
return test_full_send_buffer(&addr);
}

View File

@ -77,21 +77,43 @@ impl ConnectedStream {
pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
debug_assert!(flags.is_all_supported());
let mut sent_len = 0;
let buf_len = buf.len();
let poller = Poller::new();
loop {
let len = self
.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(&buf[sent_len..]))
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
poll_ifaces();
sent_len += len;
if sent_len == buf_len {
let sent_len = self.try_sendto(buf, flags)?;
if sent_len > 0 {
return Ok(sent_len);
}
let events = self.bound_socket.poll(IoEvents::OUT, Some(&poller));
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
return_errno_with_message!(Errno::ENOBUFS, "fail to send packets");
}
if !events.contains(IoEvents::OUT) {
if self.is_nonblocking() {
return_errno_with_message!(Errno::EAGAIN, "try to send again");
}
// FIXME: deal with send timeout
poller.wait()?;
}
}
}
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
let res = self
.bound_socket
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf))
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"));
match res {
// We have to explicitly invoke `update_socket_state` when the send buffer becomes
// full. Note that smoltcp does not think it is an interface event, so calling
// `poll_ifaces` alone is not enough.
Ok(0) => self.bound_socket.update_socket_state(),
Ok(_) => poll_ifaces(),
_ => (),
};
res
}
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
self.bound_socket
.local_endpoint()

View File

@ -228,10 +228,11 @@ impl Socket for StreamSocket {
if remote.is_some() {
return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr");
}
let state = self.state.read();
match &*state {
State::Connected(connected_stream) => connected_stream.sendto(buf, flags),
_ => return_errno_with_message!(Errno::EINVAL, "cannot send"),
}
let connected_stream = match &*self.state.read() {
State::Connected(connected_stream) => connected_stream.clone(),
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
};
connected_stream.sendto(buf, flags)
}
}