mirror of
https://github.com/asterinas/asterinas.git
synced 2025-06-11 22:36:48 +00:00
Fix endless loops when send buffer is full
This commit is contained in:
parent
c5d04c41a2
commit
f6c230f756
276
regression/apps/network/send_buf_full.c
Normal file
276
regression/apps/network/send_buf_full.c
Normal file
@ -0,0 +1,276 @@
|
||||
#include <arpa/inet.h>
|
||||
#include <netinet/in.h>
|
||||
#include <stdio.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/wait.h>
|
||||
#include <sched.h>
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
#include <errno.h>
|
||||
|
||||
static int new_bound_socket(struct sockaddr_in *addr)
|
||||
{
|
||||
int sockfd;
|
||||
|
||||
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd < 0) {
|
||||
perror("new_bound_socket: socket");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (bind(sockfd, (struct sockaddr *)addr, sizeof(*addr)) < 0) {
|
||||
perror("new_bound_socket: bind");
|
||||
close(sockfd);
|
||||
return -1;
|
||||
}
|
||||
|
||||
return sockfd;
|
||||
}
|
||||
|
||||
static int new_connected_socket(struct sockaddr_in *addr)
|
||||
{
|
||||
int sockfd;
|
||||
|
||||
sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd < 0) {
|
||||
perror("new_connected_socket: socket");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (connect(sockfd, (struct sockaddr *)addr, sizeof(*addr)) < 0) {
|
||||
perror("new_connected_socket: connect");
|
||||
close(sockfd);
|
||||
return -1;
|
||||
}
|
||||
|
||||
return sockfd;
|
||||
}
|
||||
|
||||
static int accept_without_addr(int sockfd)
|
||||
{
|
||||
struct sockaddr addr;
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
int acceptfd;
|
||||
|
||||
acceptfd = accept(sockfd, &addr, &addrlen);
|
||||
if (acceptfd < 0) {
|
||||
perror("accept_without_addr: accept");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return acceptfd;
|
||||
}
|
||||
|
||||
static int mark_filde_nonblock(int flide)
|
||||
{
|
||||
int flags;
|
||||
|
||||
flags = fcntl(flide, F_GETFL, 0);
|
||||
if (flags < 0) {
|
||||
perror("mark_filde_nonblock: fcntl(F_GETFL)");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (fcntl(flide, F_SETFL, flags | O_NONBLOCK) < 0) {
|
||||
perror("mark_filde_nonblock: fcntl(F_SETFL)");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int mark_filde_mayblock(int flide)
|
||||
{
|
||||
int flags;
|
||||
|
||||
flags = fcntl(flide, F_GETFL, 0);
|
||||
if (flags < 0) {
|
||||
perror("mark_filde_mayblock: fcntl(F_GETFL)");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (fcntl(flide, F_SETFL, flags & ~O_NONBLOCK) < 0) {
|
||||
perror("mark_filde_mayblock: fcntl(F_SETFL)");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static char buffer[4096] = "Hello, world";
|
||||
|
||||
static ssize_t receive_all(int sockfd)
|
||||
{
|
||||
size_t recv_len = 0;
|
||||
ssize_t ret;
|
||||
|
||||
if (mark_filde_nonblock(sockfd) < 0) {
|
||||
perror("receive_all: mark_filde_nonblock");
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (;;) {
|
||||
ret = recv(sockfd, buffer, sizeof(buffer), 0);
|
||||
if (ret < 0 && errno == EAGAIN)
|
||||
break;
|
||||
|
||||
if (ret < 0) {
|
||||
perror("receive_all: recv");
|
||||
return -1;
|
||||
}
|
||||
|
||||
recv_len += ret;
|
||||
}
|
||||
|
||||
return recv_len;
|
||||
}
|
||||
|
||||
int test_full_send_buffer(struct sockaddr_in *addr)
|
||||
{
|
||||
int listenfd, sendfd, recvfd;
|
||||
int ret = -1, wstatus;
|
||||
size_t sent_len = 0;
|
||||
ssize_t sent;
|
||||
int pid;
|
||||
|
||||
listenfd = new_bound_socket(addr);
|
||||
if (listenfd < 0) {
|
||||
fprintf(stderr,
|
||||
"Test failed: Error occurs in new_bound_socket\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (listen(listenfd, 2) < 0) {
|
||||
perror("listen");
|
||||
fprintf(stderr, "Test failed: Error occurs in listen\n");
|
||||
goto out_listen;
|
||||
}
|
||||
|
||||
sendfd = new_connected_socket(addr);
|
||||
if (sendfd < 0) {
|
||||
fprintf(stderr,
|
||||
"Test failed: Error occurs in new_connected_socket\n");
|
||||
goto out_listen;
|
||||
}
|
||||
|
||||
recvfd = accept_without_addr(listenfd);
|
||||
if (recvfd < 0) {
|
||||
fprintf(stderr,
|
||||
"Test failed: Error occurs in accept_without_addr\n");
|
||||
goto out_send;
|
||||
}
|
||||
|
||||
if (mark_filde_nonblock(sendfd) < 0) {
|
||||
fprintf(stderr,
|
||||
"Test failed: Error occurs in mark_filde_nonblock\n");
|
||||
goto out;
|
||||
}
|
||||
|
||||
for (;;) {
|
||||
sent = send(sendfd, buffer, sizeof(buffer), 0);
|
||||
if (sent < 0 && errno == EAGAIN)
|
||||
break;
|
||||
|
||||
if (sent < 0) {
|
||||
perror("send");
|
||||
fprintf(stderr, "Test failed: Error occurs in send\n");
|
||||
goto out;
|
||||
}
|
||||
|
||||
sent_len += sent;
|
||||
}
|
||||
|
||||
if (mark_filde_mayblock(sendfd) < 0) {
|
||||
fprintf(stderr,
|
||||
"Test failed: Error occurs in mark_filde_mayblock\n");
|
||||
goto out;
|
||||
}
|
||||
|
||||
pid = fork();
|
||||
if (pid < 0) {
|
||||
perror("fork");
|
||||
fprintf(stderr, "Test failed: Error occurs in fork\n");
|
||||
goto out;
|
||||
}
|
||||
|
||||
if (pid == 0) {
|
||||
int i;
|
||||
ssize_t recv_len;
|
||||
|
||||
// Ensure that the parent executes send() first, then the child
|
||||
// executes recv().
|
||||
for (i = 0; i < 10; ++i)
|
||||
sched_yield();
|
||||
|
||||
fprintf(stderr, "Start receiving...\n");
|
||||
recv_len = receive_all(recvfd);
|
||||
if (recv_len < 0) {
|
||||
fprintf(stderr,
|
||||
"Test failed: Error occurs in receive_all\n");
|
||||
goto out;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Received bytes: %lu\n", recv_len);
|
||||
if (recv_len != sent_len + 1) {
|
||||
fprintf(stderr,
|
||||
"Test failed: Mismatched sent bytes and received bytes\n");
|
||||
goto out;
|
||||
}
|
||||
|
||||
ret = 0;
|
||||
goto out;
|
||||
}
|
||||
|
||||
sent = send(sendfd, buffer, 1, 0);
|
||||
if (sent < 0) {
|
||||
perror("send");
|
||||
fprintf(stderr, "Test failed: Error occurs in send\n");
|
||||
goto wait;
|
||||
}
|
||||
|
||||
sent_len += 1;
|
||||
fprintf(stderr, "Sent bytes: %lu\n", sent_len);
|
||||
|
||||
ret = 0;
|
||||
|
||||
wait:
|
||||
if (wait(&wstatus) < 0) {
|
||||
perror("wait");
|
||||
fprintf(stderr, "Test failed: Error occurs in wait\n");
|
||||
ret = -1;
|
||||
} else if (WEXITSTATUS(wstatus) != 0) {
|
||||
fprintf(stderr, "Test failed: Error occurs in child process\n");
|
||||
ret = -1;
|
||||
}
|
||||
|
||||
if (ret == 0)
|
||||
fprintf(stderr,
|
||||
"Test passed: Equal sent bytes and received bytes\n");
|
||||
|
||||
out:
|
||||
close(recvfd);
|
||||
|
||||
out_send:
|
||||
close(sendfd);
|
||||
|
||||
out_listen:
|
||||
close(listenfd);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int main(void)
|
||||
{
|
||||
struct sockaddr_in addr;
|
||||
int backlog;
|
||||
int err = 0;
|
||||
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = htons(8080);
|
||||
if (inet_aton("127.0.0.1", &addr.sin_addr) < 0) {
|
||||
fprintf(stderr, "inet_aton cannot parse 127.0.0.1\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_full_send_buffer(&addr);
|
||||
}
|
@ -77,21 +77,43 @@ impl ConnectedStream {
|
||||
|
||||
pub fn sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
debug_assert!(flags.is_all_supported());
|
||||
let mut sent_len = 0;
|
||||
let buf_len = buf.len();
|
||||
|
||||
let poller = Poller::new();
|
||||
loop {
|
||||
let len = self
|
||||
.bound_socket
|
||||
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(&buf[sent_len..]))
|
||||
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"))?;
|
||||
poll_ifaces();
|
||||
sent_len += len;
|
||||
if sent_len == buf_len {
|
||||
let sent_len = self.try_sendto(buf, flags)?;
|
||||
if sent_len > 0 {
|
||||
return Ok(sent_len);
|
||||
}
|
||||
let events = self.bound_socket.poll(IoEvents::OUT, Some(&poller));
|
||||
if events.contains(IoEvents::HUP) || events.contains(IoEvents::ERR) {
|
||||
return_errno_with_message!(Errno::ENOBUFS, "fail to send packets");
|
||||
}
|
||||
if !events.contains(IoEvents::OUT) {
|
||||
if self.is_nonblocking() {
|
||||
return_errno_with_message!(Errno::EAGAIN, "try to send again");
|
||||
}
|
||||
// FIXME: deal with send timeout
|
||||
poller.wait()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_sendto(&self, buf: &[u8], flags: SendRecvFlags) -> Result<usize> {
|
||||
let res = self
|
||||
.bound_socket
|
||||
.raw_with(|socket: &mut RawTcpSocket| socket.send_slice(buf))
|
||||
.map_err(|_| Error::with_message(Errno::ENOBUFS, "cannot send packet"));
|
||||
match res {
|
||||
// We have to explicitly invoke `update_socket_state` when the send buffer becomes
|
||||
// full. Note that smoltcp does not think it is an interface event, so calling
|
||||
// `poll_ifaces` alone is not enough.
|
||||
Ok(0) => self.bound_socket.update_socket_state(),
|
||||
Ok(_) => poll_ifaces(),
|
||||
_ => (),
|
||||
};
|
||||
res
|
||||
}
|
||||
|
||||
pub fn local_endpoint(&self) -> Result<IpEndpoint> {
|
||||
self.bound_socket
|
||||
.local_endpoint()
|
||||
|
@ -228,10 +228,11 @@ impl Socket for StreamSocket {
|
||||
if remote.is_some() {
|
||||
return_errno_with_message!(Errno::EINVAL, "tcp socked should not provide remote addr");
|
||||
}
|
||||
let state = self.state.read();
|
||||
match &*state {
|
||||
State::Connected(connected_stream) => connected_stream.sendto(buf, flags),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "cannot send"),
|
||||
}
|
||||
|
||||
let connected_stream = match &*self.state.read() {
|
||||
State::Connected(connected_stream) => connected_stream.clone(),
|
||||
_ => return_errno_with_message!(Errno::EINVAL, "the socket is not connected"),
|
||||
};
|
||||
connected_stream.sendto(buf, flags)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user