Report POLLNVAL in poll for invalid FDs

This commit is contained in:
Ruihan Li 2024-10-25 18:03:51 +08:00 committed by Tate, Hongliang Tian
parent 390aa411bd
commit b5610f3034
4 changed files with 84 additions and 11 deletions

View File

@ -56,7 +56,13 @@ pub fn sys_poll(fds: Vaddr, nfds: u64, timeout: i32, ctx: &Context) -> Result<Sy
}
pub fn do_poll(poll_fds: &[PollFd], timeout: Option<Duration>, ctx: &Context) -> Result<usize> {
let files = hold_files(poll_fds, ctx)?;
let (result, files) = hold_files(poll_fds, ctx);
match result {
FileResult::AllValid => (),
FileResult::SomeInvalid => {
return Ok(count_all_events(poll_fds, &files));
}
}
let poller = match register_poller(poll_fds, files.as_ref()) {
PollerResult::AllRegistered(poller) => poller,
@ -89,11 +95,17 @@ pub fn do_poll(poll_fds: &[PollFd], timeout: Option<Duration>, ctx: &Context) ->
}
}
enum FileResult {
AllValid,
SomeInvalid,
}
/// Holds all the files we're going to poll.
fn hold_files(poll_fds: &[PollFd], ctx: &Context) -> Result<Vec<Option<Arc<dyn FileLike>>>> {
fn hold_files(poll_fds: &[PollFd], ctx: &Context) -> (FileResult, Vec<Option<Arc<dyn FileLike>>>) {
let file_table = ctx.process.file_table().lock();
let mut files = Vec::with_capacity(poll_fds.len());
let mut result = FileResult::AllValid;
for poll_fd in poll_fds.iter() {
let Some(fd) = poll_fd.fd() else {
@ -101,10 +113,18 @@ fn hold_files(poll_fds: &[PollFd], ctx: &Context) -> Result<Vec<Option<Arc<dyn F
continue;
};
files.push(Some(file_table.get_file(fd)?.clone()));
let Ok(file) = file_table.get_file(fd) else {
poll_fd.revents.set(IoEvents::NVAL);
result = FileResult::SomeInvalid;
files.push(None);
continue;
};
files.push(Some(file.clone()));
}
Ok(files)
(result, files)
}
enum PollerResult {
@ -139,6 +159,10 @@ fn count_all_events(poll_fds: &[PollFd], files: &[Option<Arc<dyn FileLike>>]) ->
for (poll_fd, file) in poll_fds.iter().zip(files.iter()) {
let Some(file) = file else {
if !poll_fd.revents.get().is_empty() {
// This is only possible for POLLNVAL.
counter += 1;
}
continue;
};

View File

@ -146,7 +146,7 @@ fn do_select(
for poll_fd in &poll_fds {
let fd = poll_fd.fd().unwrap();
let revents = poll_fd.revents().get();
let (readable, writable, except) = convert_events_to_rwe(&revents);
let (readable, writable, except) = convert_events_to_rwe(revents)?;
if let Some(ref mut fds) = readfds
&& readable
{
@ -169,8 +169,8 @@ fn do_select(
Ok(total_revents)
}
// Convert select's rwe input to poll's IoEvents input according to Linux's
// behavior.
/// Converts `select` RWE input to `poll` I/O event input
/// according to Linux's behavior.
fn convert_rwe_to_events(readable: bool, writable: bool, except: bool) -> IoEvents {
let mut events = IoEvents::empty();
if readable {
@ -185,13 +185,17 @@ fn convert_rwe_to_events(readable: bool, writable: bool, except: bool) -> IoEven
events
}
// Convert poll's IoEvents results to select's rwe results according to Linux's
// behavior.
fn convert_events_to_rwe(events: &IoEvents) -> (bool, bool, bool) {
/// Converts `poll` I/O event results to `select` RWE results
/// according to Linux's behavior.
fn convert_events_to_rwe(events: IoEvents) -> Result<(bool, bool, bool)> {
if events.contains(IoEvents::NVAL) {
return_errno_with_message!(Errno::EBADF, "the file descriptor is invalid");
}
let readable = events.intersects(IoEvents::IN | IoEvents::HUP | IoEvents::ERR);
let writable = events.intersects(IoEvents::OUT | IoEvents::ERR);
let except = events.contains(IoEvents::PRI);
(readable, writable, except)
Ok((readable, writable, except))
}
const FD_SETSIZE: usize = 1024;

View File

@ -0,0 +1,44 @@
// SPDX-License-Identifier: MPL-2.0
#include "../network/test.h"
#include <unistd.h>
#include <sys/poll.h>
FN_TEST(poll_nval)
{
int fildes[2];
int rfd, wfd;
struct pollfd fds[3];
TEST_SUCC(pipe(fildes));
rfd = fildes[0];
wfd = fildes[1];
TEST_SUCC(write(wfd, "", 1));
fds[0].fd = rfd;
fds[1].fd = 1000;
fds[2].fd = wfd;
fds[0].events = POLLIN | POLLOUT;
fds[1].events = POLLIN | POLLOUT;
fds[2].events = POLLIN | POLLOUT;
TEST_RES(poll(fds, 3, 0), _ret == 3 && fds[0].revents == POLLIN &&
fds[1].revents == POLLNVAL &&
fds[2].revents == POLLOUT);
TEST_SUCC(close(rfd));
TEST_SUCC(close(wfd));
}
END_TEST()
FN_TEST(select_bafd)
{
fd_set rfds;
FD_ZERO(&rfds);
FD_SET(100, &rfds);
TEST_ERRNO(select(200, &rfds, NULL, NULL, NULL), EBADF);
}
END_TEST()

View File

@ -64,3 +64,4 @@ echo "All fdatasync test passed."
pipe/pipe_err
pipe/short_rw
epoll/epoll_err
epoll/poll_err