diff --git a/kernel/src/syscall/poll.rs b/kernel/src/syscall/poll.rs index bdbd715c..d7cfc668 100644 --- a/kernel/src/syscall/poll.rs +++ b/kernel/src/syscall/poll.rs @@ -3,29 +3,39 @@ use core::{cell::Cell, time::Duration}; use super::SyscallReturn; -use crate::{events::IoEvents, fs::file_table::FileDesc, prelude::*, process::signal::Poller}; +use crate::{ + events::IoEvents, + fs::{file_handle::FileLike, file_table::FileDesc}, + prelude::*, + process::signal::Poller, +}; pub fn sys_poll(fds: Vaddr, nfds: u64, timeout: i32, ctx: &Context) -> Result { let user_space = ctx.get_user_space(); + let poll_fds = { let mut read_addr = fds; let mut poll_fds = Vec::with_capacity(nfds as _); + for _ in 0..nfds { let c_poll_fd = user_space.read_val::(read_addr)?; + read_addr += core::mem::size_of::(); + let poll_fd = PollFd::from(c_poll_fd); // Always clear the revents fields first poll_fd.revents().set(IoEvents::empty()); poll_fds.push(poll_fd); - // FIXME: do we need to respect align of c_pollfd here? - read_addr += core::mem::size_of::(); } + poll_fds }; + let timeout = if timeout >= 0 { Some(Duration::from_millis(timeout as _)) } else { None }; + debug!( "poll_fds = {:?}, nfds = {}, timeout = {:?}", poll_fds, nfds, timeout @@ -37,8 +47,8 @@ pub fn sys_poll(fds: Vaddr, nfds: u64, timeout: i32, ctx: &Context) -> Result(); } @@ -46,44 +56,18 @@ pub fn sys_poll(fds: Vaddr, nfds: u64, timeout: i32, ctx: &Context) -> Result, ctx: &Context) -> Result { - // The main loop of polling - let mut poller = Poller::new(); + let files = hold_files(poll_fds, ctx)?; + + let poller = match register_poller(poll_fds, files.as_ref()) { + PollerResult::AllRegistered(poller) => poller, + PollerResult::EventFoundAt(index) => { + let next = index + 1; + let remaining_events = count_all_events(&poll_fds[next..], &files[next..]); + return Ok(1 + remaining_events); + } + }; + loop { - let mut num_revents = 0; - - let file_table = ctx.process.file_table().lock(); - for poll_fd in poll_fds { - // Skip poll_fd if it is not given a fd - let fd = match poll_fd.fd() { - Some(fd) => fd, - None => continue, - }; - - // Poll the file - let file = file_table.get_file(fd)?; - let need_poller = if num_revents == 0 { - Some(&mut poller) - } else { - None - }; - let revents = file.poll(poll_fd.events(), need_poller); - if !revents.is_empty() { - poll_fd.revents().set(revents); - num_revents += 1; - } - } - - drop(file_table); - - if num_revents > 0 { - return Ok(num_revents); - } - - // Return immediately if specifying a timeout of zero - if timeout.is_some() && timeout.as_ref().unwrap().is_zero() { - return Ok(0); - } - if let Some(timeout) = timeout.as_ref() { match poller.wait_timeout(timeout) { Ok(_) => {} @@ -97,9 +81,79 @@ pub fn do_poll(poll_fds: &[PollFd], timeout: Option, ctx: &Context) -> } else { poller.wait()?; } + + let num_events = count_all_events(poll_fds, &files); + if num_events > 0 { + return Ok(num_events); + } } } +/// Holds all the files we're going to poll. +fn hold_files(poll_fds: &[PollFd], ctx: &Context) -> Result>>> { + let file_table = ctx.process.file_table().lock(); + + let mut files = Vec::with_capacity(poll_fds.len()); + + for poll_fd in poll_fds.iter() { + let Some(fd) = poll_fd.fd() else { + files.push(None); + continue; + }; + + files.push(Some(file_table.get_file(fd)?.clone())); + } + + Ok(files) +} + +enum PollerResult { + AllRegistered(Poller), + EventFoundAt(usize), +} + +/// Registers the files with a poller, or exits early if some events are detected. +fn register_poller(poll_fds: &[PollFd], files: &[Option>]) -> PollerResult { + let mut poller = Poller::new(); + + for (i, (poll_fd, file)) in poll_fds.iter().zip(files.iter()).enumerate() { + let Some(file) = file else { + continue; + }; + + let events = file.poll(poll_fd.events(), Some(&mut poller)); + if events.is_empty() { + continue; + } + + poll_fd.revents().set(events); + return PollerResult::EventFoundAt(i); + } + + PollerResult::AllRegistered(poller) +} + +/// Counts the number of the ready files. +fn count_all_events(poll_fds: &[PollFd], files: &[Option>]) -> usize { + let mut counter = 0; + + for (poll_fd, file) in poll_fds.iter().zip(files.iter()) { + let Some(file) = file else { + continue; + }; + + let events = file.poll(poll_fd.events(), None); + if events.is_empty() { + continue; + } + + poll_fd.revents().set(events); + counter += 1; + } + + counter +} + // https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/poll.h #[derive(Debug, Clone, Copy, Pod)] #[repr(C)]