diff --git a/.cirrus.yml b/.cirrus.yml index d1e8a7a..f34e6ab 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -28,11 +28,6 @@ freebsd_task: - sudo sysctl net.inet.tcp.blackhole=0 - . $HOME/.cargo/env - cargo test --target $TARGET - # Test async-io - - git clone https://github.com/smol-rs/async-io.git - - echo '[patch.crates-io]' >> async-io/Cargo.toml - - echo 'polling = { path = ".." }' >> async-io/Cargo.toml - - cargo test --target $TARGET --manifest-path=async-io/Cargo.toml netbsd_task: name: test ($TARGET) @@ -49,11 +44,6 @@ netbsd_task: test_script: - . $HOME/.cargo/env - cargo test - # Test async-io - - git clone https://github.com/smol-rs/async-io.git - - echo '[patch.crates-io]' >> async-io/Cargo.toml - - echo 'polling = { path = ".." }' >> async-io/Cargo.toml - - cargo test --manifest-path=async-io/Cargo.toml openbsd_task: name: test ($TARGET) @@ -69,8 +59,3 @@ openbsd_task: - pkg_add git rust test_script: - cargo test - # Test async-io - - git clone https://github.com/smol-rs/async-io.git - - echo '[patch.crates-io]' >> async-io/Cargo.toml - - echo 'polling = { path = ".." }' >> async-io/Cargo.toml - - cargo test --manifest-path=async-io/Cargo.toml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ff52e57..0500ba8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,14 +56,6 @@ jobs: RUSTFLAGS: ${{ env.RUSTFLAGS }} --cfg polling_test_poll_backend if: startsWith(matrix.os, 'ubuntu') - run: cargo hack build --feature-powerset --no-dev-deps - - name: Clone async-io - run: git clone https://github.com/smol-rs/async-io.git - - name: Add patch section - run: echo '[patch.crates-io]' >> async-io/Cargo.toml - - name: Patch polling - run: echo 'polling = { path = ".." }' >> async-io/Cargo.toml - - name: Test async-io - run: cargo test --manifest-path=async-io/Cargo.toml cross: runs-on: ${{ matrix.os }} diff --git a/Cargo.toml b/Cargo.toml index b4885d2..4ff2de5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ tracing = { version = "0.1.37", default-features = false } [target.'cfg(any(unix, target_os = "fuchsia", target_os = "vxworks"))'.dependencies] libc = "0.2.77" -rustix = { version = "0.37.11", features = ["process", "time", "fs", "std"], default-features = false } +rustix = { version = "0.38", features = ["event", "fs", "pipe", "process", "std", "time"], default-features = false } [target.'cfg(windows)'.dependencies] concurrent-queue = "2.2.0" diff --git a/examples/two-listeners.rs b/examples/two-listeners.rs index 3dcbc4b..02b2339 100644 --- a/examples/two-listeners.rs +++ b/examples/two-listeners.rs @@ -10,8 +10,10 @@ fn main() -> io::Result<()> { l2.set_nonblocking(true)?; let poller = Poller::new()?; - poller.add(&l1, Event::readable(1))?; - poller.add(&l2, Event::readable(2))?; + unsafe { + poller.add(&l1, Event::readable(1))?; + poller.add(&l2, Event::readable(2))?; + } println!("You can connect to the server using `nc`:"); println!(" $ nc 127.0.0.1 8001"); diff --git a/src/epoll.rs b/src/epoll.rs index 91931e2..4d21bbe 100644 --- a/src/epoll.rs +++ b/src/epoll.rs @@ -5,8 +5,9 @@ use std::io; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd}; use std::time::Duration; +use rustix::event::{epoll, eventfd, EventfdFlags}; use rustix::fd::OwnedFd; -use rustix::io::{epoll, eventfd, read, write, EventfdFlags}; +use rustix::io::{read, write}; use rustix::time::{ timerfd_create, timerfd_settime, Itimerspec, TimerfdClockId, TimerfdFlags, TimerfdTimerFlags, Timespec, @@ -31,7 +32,7 @@ impl Poller { // Create an epoll instance. // // Use `epoll_create1` with `EPOLL_CLOEXEC`. - let epoll_fd = epoll::epoll_create(epoll::CreateFlags::CLOEXEC)?; + let epoll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)?; // Set up eventfd and timerfd. let event_fd = eventfd(0, EventfdFlags::CLOEXEC | EventfdFlags::NONBLOCK)?; @@ -47,24 +48,26 @@ impl Poller { timer_fd, }; - if let Some(ref timer_fd) = poller.timer_fd { + unsafe { + if let Some(ref timer_fd) = poller.timer_fd { + poller.add( + timer_fd.as_raw_fd(), + Event::none(crate::NOTIFY_KEY), + PollMode::Oneshot, + )?; + } + poller.add( - timer_fd.as_raw_fd(), - Event::none(crate::NOTIFY_KEY), + poller.event_fd.as_raw_fd(), + Event { + key: crate::NOTIFY_KEY, + readable: true, + writable: false, + }, PollMode::Oneshot, )?; } - poller.add( - poller.event_fd.as_raw_fd(), - Event { - key: crate::NOTIFY_KEY, - readable: true, - writable: false, - }, - PollMode::Oneshot, - )?; - tracing::trace!( epoll_fd = ?poller.epoll_fd.as_raw_fd(), event_fd = ?poller.event_fd.as_raw_fd(), @@ -85,7 +88,12 @@ impl Poller { } /// Adds a new file descriptor. - pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + /// + /// # Safety + /// + /// The `fd` must be a valid file descriptor. The usual condition of remaining registered in + /// the `Poller` doesn't apply to `epoll`. + pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { let span = tracing::trace_span!( "add", epoll_fd = ?self.epoll_fd.as_raw_fd(), @@ -94,10 +102,10 @@ impl Poller { ); let _enter = span.enter(); - epoll::epoll_add( + epoll::add( &self.epoll_fd, unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) }, - ev.key as u64, + epoll::EventData::new_u64(ev.key as u64), epoll_flags(&ev, mode), )?; @@ -105,7 +113,7 @@ impl Poller { } /// Modifies an existing file descriptor. - pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> { let span = tracing::trace_span!( "modify", epoll_fd = ?self.epoll_fd.as_raw_fd(), @@ -114,10 +122,10 @@ impl Poller { ); let _enter = span.enter(); - epoll::epoll_mod( + epoll::modify( &self.epoll_fd, - unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) }, - ev.key as u64, + fd, + epoll::EventData::new_u64(ev.key as u64), epoll_flags(&ev, mode), )?; @@ -125,7 +133,7 @@ impl Poller { } /// Deletes a file descriptor. - pub fn delete(&self, fd: RawFd) -> io::Result<()> { + pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> { let span = tracing::trace_span!( "delete", epoll_fd = ?self.epoll_fd.as_raw_fd(), @@ -133,9 +141,7 @@ impl Poller { ); let _enter = span.enter(); - epoll::epoll_del(&self.epoll_fd, unsafe { - rustix::fd::BorrowedFd::borrow_raw(fd) - })?; + epoll::delete(&self.epoll_fd, fd)?; Ok(()) } @@ -170,7 +176,7 @@ impl Poller { // Set interest in timerfd. self.modify( - timer_fd.as_raw_fd(), + timer_fd.as_fd(), Event { key: crate::NOTIFY_KEY, readable: true, @@ -195,7 +201,7 @@ impl Poller { }; // Wait for I/O events. - epoll::epoll_wait(&self.epoll_fd, &mut events.list, timeout_ms)?; + epoll::wait(&self.epoll_fd, &mut events.list, timeout_ms)?; tracing::trace!( epoll_fd = ?self.epoll_fd.as_raw_fd(), res = ?events.list.len(), @@ -206,7 +212,7 @@ impl Poller { let mut buf = [0u8; 8]; let _ = read(&self.event_fd, &mut buf); self.modify( - self.event_fd.as_raw_fd(), + self.event_fd.as_fd(), Event { key: crate::NOTIFY_KEY, readable: true, @@ -255,9 +261,9 @@ impl Drop for Poller { let _enter = span.enter(); if let Some(timer_fd) = self.timer_fd.take() { - let _ = self.delete(timer_fd.as_raw_fd()); + let _ = self.delete(timer_fd.as_fd()); } - let _ = self.delete(self.event_fd.as_raw_fd()); + let _ = self.delete(self.event_fd.as_fd()); } } @@ -310,10 +316,13 @@ impl Events { /// Iterates over I/O events. pub fn iter(&self) -> impl Iterator + '_ { - self.list.iter().map(|(flags, data)| Event { - key: data as usize, - readable: flags.intersects(read_flags()), - writable: flags.intersects(write_flags()), + self.list.iter().map(|ev| { + let flags = ev.flags; + Event { + key: ev.data.u64() as usize, + readable: flags.intersects(read_flags()), + writable: flags.intersects(write_flags()), + } }) } } diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index d230e5c..d695090 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -42,7 +42,9 @@ use std::collections::hash_map::{Entry, HashMap}; use std::fmt; use std::io; use std::marker::PhantomPinned; -use std::os::windows::io::{AsHandle, AsRawHandle, BorrowedHandle, RawHandle, RawSocket}; +use std::os::windows::io::{ + AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket, +}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; @@ -134,7 +136,16 @@ impl Poller { } /// Add a new source to the poller. - pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> { + /// + /// # Safety + /// + /// The socket must be a valid socket and must last until it is deleted. + pub(super) unsafe fn add( + &self, + socket: RawSocket, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { let span = tracing::trace_span!( "add", handle = ?self.port, @@ -192,7 +203,7 @@ impl Poller { /// Update a source in the poller. pub(super) fn modify( &self, - socket: RawSocket, + socket: BorrowedSocket<'_>, interest: Event, mode: PollMode, ) -> io::Result<()> { @@ -217,7 +228,7 @@ impl Poller { let sources = lock!(self.sources.read()); sources - .get(&socket) + .get(&socket.as_raw_socket()) .cloned() .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? }; @@ -231,7 +242,7 @@ impl Poller { } /// Delete a source from the poller. - pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> { + pub(super) fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> { let span = tracing::trace_span!( "remove", handle = ?self.port, @@ -243,7 +254,7 @@ impl Poller { let source = { let mut sources = lock!(self.sources.write()); - match sources.remove(&socket) { + match sources.remove(&socket.as_raw_socket()) { Some(s) => s, None => { // If the source has already been removed, then we can just return. diff --git a/src/kqueue.rs b/src/kqueue.rs index 85ec04f..62b7ea3 100644 --- a/src/kqueue.rs +++ b/src/kqueue.rs @@ -1,11 +1,11 @@ //! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD). use std::io; -use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd}; +use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}; use std::time::Duration; -use rustix::fd::OwnedFd; -use rustix::io::{fcntl_setfd, kqueue, Errno, FdFlags}; +use rustix::event::kqueue; +use rustix::io::{fcntl_setfd, Errno, FdFlags}; use crate::{Event, PollMode}; @@ -55,13 +55,17 @@ impl Poller { } /// Adds a new file descriptor. - pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + /// + /// # Safety + /// + /// The file descriptor must be valid and it must last until it is deleted. + pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { // File descriptors don't need to be added explicitly, so just modify the interest. - self.modify(fd, ev, mode) + self.modify(BorrowedFd::borrow_raw(fd), ev, mode) } /// Modifies an existing file descriptor. - pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> { let span = if !self.notify.has_fd(fd) { let span = tracing::trace_span!( "add", @@ -91,12 +95,12 @@ impl Poller { // A list of changes for kqueue. let changelist = [ kqueue::Event::new( - kqueue::EventFilter::Read(fd), + kqueue::EventFilter::Read(fd.as_raw_fd()), read_flags | kqueue::EventFlags::RECEIPT, ev.key as _, ), kqueue::Event::new( - kqueue::EventFilter::Write(fd), + kqueue::EventFilter::Write(fd.as_raw_fd()), write_flags | kqueue::EventFlags::RECEIPT, ev.key as _, ), @@ -141,7 +145,7 @@ impl Poller { } /// Deletes a file descriptor. - pub fn delete(&self, fd: RawFd) -> io::Result<()> { + pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> { // Simply delete interest in the file descriptor. self.modify(fd, Event::none(0), PollMode::Oneshot) } @@ -268,9 +272,9 @@ pub(crate) fn mode_to_flags(mode: PollMode) -> kqueue::EventFlags { ))] mod notify { use super::Poller; - use rustix::io::kqueue; + use rustix::event::kqueue; use std::io; - use std::os::unix::io::RawFd; + use std::os::unix::io::BorrowedFd; /// A notification pipe. /// @@ -335,7 +339,7 @@ mod notify { } /// Whether this raw file descriptor is associated with this pipe. - pub(super) fn has_fd(&self, _fd: RawFd) -> bool { + pub(super) fn has_fd(&self, _fd: BorrowedFd<'_>) -> bool { false } } @@ -354,7 +358,7 @@ mod notify { use crate::{Event, PollMode, NOTIFY_KEY}; use std::io::{self, prelude::*}; use std::os::unix::{ - io::{AsRawFd, RawFd}, + io::{AsFd, AsRawFd, BorrowedFd}, net::UnixStream, }; @@ -386,11 +390,13 @@ mod notify { /// Registers this notification pipe in the `Poller`. pub(super) fn register(&self, poller: &Poller) -> io::Result<()> { // Register the read end of this pipe. - poller.add( - self.read_stream.as_raw_fd(), - Event::readable(NOTIFY_KEY), - PollMode::Oneshot, - ) + unsafe { + poller.add( + self.read_stream.as_raw_fd(), + Event::readable(NOTIFY_KEY), + PollMode::Oneshot, + ) + } } /// Reregister this notification pipe in the `Poller`. @@ -400,7 +406,7 @@ mod notify { // Reregister the read end of this pipe. poller.modify( - self.read_stream.as_raw_fd(), + self.read_stream.as_fd(), Event::readable(NOTIFY_KEY), PollMode::Oneshot, ) @@ -418,12 +424,12 @@ mod notify { /// Deregisters this notification pipe from the `Poller`. pub(super) fn deregister(&self, poller: &Poller) -> io::Result<()> { // Deregister the read end of the pipe. - poller.delete(self.read_stream.as_raw_fd()) + poller.delete(self.read_stream.as_fd()) } /// Whether this raw file descriptor is associated with this pipe. - pub(super) fn has_fd(&self, fd: RawFd) -> bool { - self.read_stream.as_raw_fd() == fd + pub(super) fn has_fd(&self, fd: BorrowedFd<'_>) -> bool { + self.read_stream.as_raw_fd() == fd.as_raw_fd() } } } diff --git a/src/lib.rs b/src/lib.rs index 0c3a301..188ab87 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,9 @@ //! //! // Create a poller and register interest in readability on the socket. //! let poller = Poller::new()?; -//! poller.add(&socket, Event::readable(key))?; +//! unsafe { +//! poller.add(&socket, Event::readable(key))?; +//! } //! //! // The event loop. //! let mut events = Vec::new(); @@ -46,13 +48,15 @@ //! } //! } //! } +//! +//! poller.delete(&socket)?; //! # std::io::Result::Ok(()) //! ``` #![cfg(feature = "std")] #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] -#![allow(clippy::useless_conversion, clippy::unnecessary_cast)] +#![allow(clippy::useless_conversion, clippy::unnecessary_cast, unused_unsafe)] #![cfg_attr(docsrs, feature(doc_cfg))] #![doc( html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png" @@ -273,8 +277,11 @@ impl Poller { /// [`modify()`][`Poller::modify()`] again after an event is delivered if we're interested in /// the next event of the same kind. /// - /// Don't forget to [`delete()`][`Poller::delete()`] the file descriptor or socket when it is - /// no longer used! + /// # Safety + /// + /// The source must be [`delete()`]d from this `Poller` before it is dropped. + /// + /// [`delete()`]: Poller::delete /// /// # Errors /// @@ -295,10 +302,13 @@ impl Poller { /// let key = 7; /// /// let poller = Poller::new()?; - /// poller.add(&source, Event::all(key))?; + /// unsafe { + /// poller.add(&source, Event::all(key))?; + /// } + /// poller.delete(&source)?; /// # std::io::Result::Ok(()) /// ``` - pub fn add(&self, source: impl Source, interest: Event) -> io::Result<()> { + pub unsafe fn add(&self, source: impl AsRawSource, interest: Event) -> io::Result<()> { self.add_with_mode(source, interest, PollMode::Oneshot) } @@ -307,13 +317,19 @@ impl Poller { /// This is identical to the `add()` function, but allows specifying the /// polling mode to use for this socket. /// + /// # Safety + /// + /// The source must be [`delete()`]d from this `Poller` before it is dropped. + /// + /// [`delete()`]: Poller::delete + /// /// # Errors /// /// If the operating system does not support the specified mode, this function /// will return an error. - pub fn add_with_mode( + pub unsafe fn add_with_mode( &self, - source: impl Source, + source: impl AsRawSource, interest: Event, mode: PollMode, ) -> io::Result<()> { @@ -354,7 +370,7 @@ impl Poller { /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; /// # let key = 7; /// # let poller = Poller::new()?; - /// # poller.add(&source, Event::none(key))?; + /// # unsafe { poller.add(&source, Event::none(key))?; } /// poller.modify(&source, Event::all(key))?; /// # std::io::Result::Ok(()) /// ``` @@ -366,8 +382,9 @@ impl Poller { /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; /// # let key = 7; /// # let poller = Poller::new()?; - /// # poller.add(&source, Event::none(key))?; + /// # unsafe { poller.add(&source, Event::none(key))?; } /// poller.modify(&source, Event::readable(key))?; + /// # poller.delete(&source)?; /// # std::io::Result::Ok(()) /// ``` /// @@ -378,8 +395,9 @@ impl Poller { /// # let poller = Poller::new()?; /// # let key = 7; /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; - /// # poller.add(&source, Event::none(key))?; + /// # unsafe { poller.add(&source, Event::none(key))? }; /// poller.modify(&source, Event::writable(key))?; + /// # poller.delete(&source)?; /// # std::io::Result::Ok(()) /// ``` /// @@ -390,11 +408,12 @@ impl Poller { /// # let source = std::net::TcpListener::bind("127.0.0.1:0")?; /// # let key = 7; /// # let poller = Poller::new()?; - /// # poller.add(&source, Event::none(key))?; + /// # unsafe { poller.add(&source, Event::none(key))?; } /// poller.modify(&source, Event::none(key))?; + /// # poller.delete(&source)?; /// # std::io::Result::Ok(()) /// ``` - pub fn modify(&self, source: impl Source, interest: Event) -> io::Result<()> { + pub fn modify(&self, source: impl AsSource, interest: Event) -> io::Result<()> { self.modify_with_mode(source, interest, PollMode::Oneshot) } @@ -415,7 +434,7 @@ impl Poller { /// an error. pub fn modify_with_mode( &self, - source: impl Source, + source: impl AsSource, interest: Event, mode: PollMode, ) -> io::Result<()> { @@ -425,7 +444,7 @@ impl Poller { "the key is not allowed to be `usize::MAX`", )); } - self.poller.modify(source.raw(), interest, mode) + self.poller.modify(source.source(), interest, mode) } /// Removes a file descriptor or socket from the poller. @@ -444,12 +463,12 @@ impl Poller { /// let key = 7; /// /// let poller = Poller::new()?; - /// poller.add(&socket, Event::all(key))?; + /// unsafe { poller.add(&socket, Event::all(key))?; } /// poller.delete(&socket)?; /// # std::io::Result::Ok(()) /// ``` - pub fn delete(&self, source: impl Source) -> io::Result<()> { - self.poller.delete(source.raw()) + pub fn delete(&self, source: impl AsSource) -> io::Result<()> { + self.poller.delete(source.source()) } /// Waits for at least one I/O event and returns the number of new events. @@ -482,10 +501,13 @@ impl Poller { /// let key = 7; /// /// let poller = Poller::new()?; - /// poller.add(&socket, Event::all(key))?; + /// unsafe { + /// poller.add(&socket, Event::all(key))?; + /// } /// /// let mut events = Vec::new(); /// let n = poller.wait(&mut events, Some(Duration::from_secs(1)))?; + /// poller.delete(&socket)?; /// # std::io::Result::Ok(()) /// ``` pub fn wait(&self, events: &mut Vec, timeout: Option) -> io::Result { @@ -624,45 +646,65 @@ impl fmt::Debug for Poller { cfg_if! { if #[cfg(unix)] { - use std::os::unix::io::{AsRawFd, RawFd}; + use std::os::unix::io::{AsRawFd, RawFd, AsFd, BorrowedFd}; - /// A [`RawFd`] or a reference to a type implementing [`AsRawFd`]. - pub trait Source { - /// Returns the [`RawFd`] for this I/O object. + /// A resource with a raw file descriptor. + pub trait AsRawSource { + /// Returns the raw file descriptor. fn raw(&self) -> RawFd; } - impl Source for RawFd { + impl AsRawSource for &T { fn raw(&self) -> RawFd { - *self + self.as_raw_fd() } } - impl Source for &T { + impl AsRawSource for RawFd { fn raw(&self) -> RawFd { - self.as_raw_fd() + *self } } + + /// A resource with a borrowed file descriptor. + pub trait AsSource: AsFd { + /// Returns the borrowed file descriptor. + fn source(&self) -> BorrowedFd<'_> { + self.as_fd() + } + } + + impl AsSource for T {} } else if #[cfg(windows)] { - use std::os::windows::io::{AsRawSocket, RawSocket}; + use std::os::windows::io::{AsRawSocket, RawSocket, AsSocket, BorrowedSocket}; - /// A [`RawSocket`] or a reference to a type implementing [`AsRawSocket`]. - pub trait Source { - /// Returns the [`RawSocket`] for this I/O object. + /// A resource with a raw socket. + pub trait AsRawSource { + /// Returns the raw socket. fn raw(&self) -> RawSocket; } - impl Source for RawSocket { + impl AsRawSource for &T { fn raw(&self) -> RawSocket { - *self + self.as_raw_socket() } } - impl Source for &T { + impl AsRawSource for RawSocket { fn raw(&self) -> RawSocket { - self.as_raw_socket() + *self + } + } + + /// A resource with a borrowed socket. + pub trait AsSource: AsSocket { + /// Returns the borrowed socket. + fn source(&self) -> BorrowedSocket<'_> { + self.as_socket() } } + + impl AsSource for T {} } } diff --git a/src/os/kqueue.rs b/src/os/kqueue.rs index 8ea8729..684bd3e 100644 --- a/src/os/kqueue.rs +++ b/src/os/kqueue.rs @@ -7,7 +7,7 @@ use std::io; use std::process::Child; use std::time::Duration; -use rustix::io::kqueue; +use rustix::event::kqueue; use super::__private::PollerSealed; use __private::FilterSealed; @@ -238,7 +238,7 @@ unsafe impl FilterSealed for Timer { impl Filter for Timer {} mod __private { - use rustix::io::kqueue; + use rustix::event::kqueue; #[doc(hidden)] pub unsafe trait FilterSealed { diff --git a/src/poll.rs b/src/poll.rs index 846de80..1ecf278 100644 --- a/src/poll.rs +++ b/src/poll.rs @@ -7,12 +7,11 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Condvar, Mutex}; use std::time::{Duration, Instant}; +use rustix::event::{poll, PollFd, PollFlags}; use rustix::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd}; use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags}; -use rustix::io::{ - fcntl_getfd, fcntl_setfd, pipe, pipe_with, poll, read, write, FdFlags, PipeFlags, PollFd, - PollFlags, -}; +use rustix::io::{fcntl_getfd, fcntl_setfd, read, write, FdFlags}; +use rustix::pipe::{pipe, pipe_with, PipeFlags}; // std::os::unix doesn't exist on Fuchsia type RawFd = std::os::raw::c_int; @@ -158,7 +157,7 @@ impl Poller { } /// Modifies an existing file descriptor. - pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> { let span = tracing::trace_span!( "modify", notify_read = ?self.notify_read, @@ -168,11 +167,19 @@ impl Poller { let _enter = span.enter(); self.modify_fds(|fds| { - let data = fds.fd_data.get_mut(&fd).ok_or(io::ErrorKind::NotFound)?; + let data = fds + .fd_data + .get_mut(&fd.as_raw_fd()) + .ok_or(io::ErrorKind::NotFound)?; data.key = ev.key; let poll_fds_index = data.poll_fds_index; - fds.poll_fds[poll_fds_index] = - PollFd::from_borrowed_fd(unsafe { BorrowedFd::borrow_raw(fd) }, poll_events(ev)); + + // SAFETY: This is essentially transmuting a `PollFd<'a>` to a `PollFd<'static>`, which + // only works if it's removed in time with `delete()`. + fds.poll_fds[poll_fds_index] = PollFd::from_borrowed_fd( + unsafe { BorrowedFd::borrow_raw(fd.as_raw_fd()) }, + poll_events(ev), + ); data.remove = cvt_mode_as_remove(mode)?; Ok(()) @@ -180,7 +187,7 @@ impl Poller { } /// Deletes a file descriptor. - pub fn delete(&self, fd: RawFd) -> io::Result<()> { + pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> { let span = tracing::trace_span!( "delete", notify_read = ?self.notify_read, @@ -189,7 +196,10 @@ impl Poller { let _enter = span.enter(); self.modify_fds(|fds| { - let data = fds.fd_data.remove(&fd).ok_or(io::ErrorKind::NotFound)?; + let data = fds + .fd_data + .remove(&fd.as_raw_fd()) + .ok_or(io::ErrorKind::NotFound)?; fds.poll_fds.swap_remove(data.poll_fds_index); if let Some(swapped_pollfd) = fds.poll_fds.get(data.poll_fds_index) { fds.fd_data diff --git a/src/port.rs b/src/port.rs index 1149bfb..bd55b15 100644 --- a/src/port.rs +++ b/src/port.rs @@ -4,8 +4,9 @@ use std::io; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd}; use std::time::Duration; +use rustix::event::{port, PollFlags}; use rustix::fd::OwnedFd; -use rustix::io::{fcntl_getfd, fcntl_setfd, port, FdFlags, PollFlags}; +use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags}; use crate::{Event, PollMode}; @@ -42,13 +43,17 @@ impl Poller { } /// Adds a file descriptor. - pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + /// + /// # Safety + /// + /// The `fd` must be a valid file descriptor and it must last until it is deleted. + pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { // File descriptors don't need to be added explicitly, so just modify the interest. - self.modify(fd, ev, mode) + self.modify(BorrowedFd::borrow_raw(fd), ev, mode) } /// Modifies an existing file descriptor. - pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> { let span = tracing::trace_span!( "modify", port_fd = ?self.port_fd.as_raw_fd(), @@ -79,7 +84,7 @@ impl Poller { } /// Deletes a file descriptor. - pub fn delete(&self, fd: RawFd) -> io::Result<()> { + pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> { let span = tracing::trace_span!( "delete", port_fd = ?self.port_fd.as_raw_fd(), diff --git a/tests/concurrent_modification.rs b/tests/concurrent_modification.rs index 7f31f05..8cf6691 100644 --- a/tests/concurrent_modification.rs +++ b/tests/concurrent_modification.rs @@ -13,20 +13,25 @@ fn concurrent_add() -> io::Result<()> { let mut events = Vec::new(); - Parallel::new() + let result = Parallel::new() .add(|| { poller.wait(&mut events, None)?; Ok(()) }) .add(|| { thread::sleep(Duration::from_millis(100)); - poller.add(&reader, Event::readable(0))?; + unsafe { + poller.add(&reader, Event::readable(0))?; + } writer.write_all(&[1])?; Ok(()) }) .run() .into_iter() - .collect::>()?; + .collect::>(); + + poller.delete(&reader)?; + result?; assert_eq!(events, [Event::readable(0)]); @@ -37,7 +42,9 @@ fn concurrent_add() -> io::Result<()> { fn concurrent_modify() -> io::Result<()> { let (reader, mut writer) = tcp_pair()?; let poller = Poller::new()?; - poller.add(&reader, Event::none(0))?; + unsafe { + poller.add(&reader, Event::none(0))?; + } let mut events = Vec::new(); diff --git a/tests/io.rs b/tests/io.rs index ab0c8a8..10b3d48 100644 --- a/tests/io.rs +++ b/tests/io.rs @@ -7,7 +7,9 @@ use std::time::Duration; fn basic_io() { let poller = Poller::new().unwrap(); let (read, mut write) = tcp_pair().unwrap(); - poller.add(&read, Event::readable(1)).unwrap(); + unsafe { + poller.add(&read, Event::readable(1)).unwrap(); + } // Nothing should be available at first. let mut events = vec![]; @@ -28,6 +30,8 @@ fn basic_io() { 1 ); assert_eq!(&*events, &[Event::readable(1)]); + + poller.delete(&read).unwrap(); } fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> { diff --git a/tests/many_connections.rs b/tests/many_connections.rs index 06e4301..41d640b 100644 --- a/tests/many_connections.rs +++ b/tests/many_connections.rs @@ -20,7 +20,9 @@ fn many_connections() { let poller = polling::Poller::new().unwrap(); for (i, reader, _) in connections.iter() { - poller.add(reader, polling::Event::readable(*i)).unwrap(); + unsafe { + poller.add(reader, polling::Event::readable(*i)).unwrap(); + } } let mut events = vec![]; diff --git a/tests/other_modes.rs b/tests/other_modes.rs index fbe8779..12ef1bc 100644 --- a/tests/other_modes.rs +++ b/tests/other_modes.rs @@ -16,8 +16,7 @@ fn level_triggered() { // Create our poller and register our streams. let poller = Poller::new().unwrap(); - if poller - .add_with_mode(&reader, Event::readable(reader_token), PollMode::Level) + if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level) } .is_err() { // Only panic if we're on a platform that should support level mode. @@ -92,8 +91,7 @@ fn edge_triggered() { // Create our poller and register our streams. let poller = Poller::new().unwrap(); - if poller - .add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge) + if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge) } .is_err() { // Only panic if we're on a platform that should support level mode. @@ -170,13 +168,14 @@ fn edge_oneshot_triggered() { // Create our poller and register our streams. let poller = Poller::new().unwrap(); - if poller - .add_with_mode( + if unsafe { + poller.add_with_mode( &reader, Event::readable(reader_token), PollMode::EdgeOneshot, ) - .is_err() + } + .is_err() { // Only panic if we're on a platform that should support level mode. cfg_if::cfg_if! {