Skip to content

Commit

Permalink
breaking: Rework the API for I/O safety
Browse files Browse the repository at this point in the history
* Rework the API for I/O safety

* Bump to rustix v0.38
  • Loading branch information
notgull authored Aug 4, 2023
1 parent c86c389 commit 6eb7679
Show file tree
Hide file tree
Showing 15 changed files with 228 additions and 154 deletions.
15 changes: 0 additions & 15 deletions .cirrus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ freebsd_task:
- sudo sysctl net.inet.tcp.blackhole=0
- . $HOME/.cargo/env
- cargo test --target $TARGET
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --target $TARGET --manifest-path=async-io/Cargo.toml

netbsd_task:
name: test ($TARGET)
Expand All @@ -49,11 +44,6 @@ netbsd_task:
test_script:
- . $HOME/.cargo/env
- cargo test
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --manifest-path=async-io/Cargo.toml

openbsd_task:
name: test ($TARGET)
Expand All @@ -69,8 +59,3 @@ openbsd_task:
- pkg_add git rust
test_script:
- cargo test
# Test async-io
- git clone https://github.com/smol-rs/async-io.git
- echo '[patch.crates-io]' >> async-io/Cargo.toml
- echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- cargo test --manifest-path=async-io/Cargo.toml
8 changes: 0 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@ jobs:
RUSTFLAGS: ${{ env.RUSTFLAGS }} --cfg polling_test_poll_backend
if: startsWith(matrix.os, 'ubuntu')
- run: cargo hack build --feature-powerset --no-dev-deps
- name: Clone async-io
run: git clone https://github.com/smol-rs/async-io.git
- name: Add patch section
run: echo '[patch.crates-io]' >> async-io/Cargo.toml
- name: Patch polling
run: echo 'polling = { path = ".." }' >> async-io/Cargo.toml
- name: Test async-io
run: cargo test --manifest-path=async-io/Cargo.toml

cross:
runs-on: ${{ matrix.os }}
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ tracing = { version = "0.1.37", default-features = false }

[target.'cfg(any(unix, target_os = "fuchsia", target_os = "vxworks"))'.dependencies]
libc = "0.2.77"
rustix = { version = "0.37.11", features = ["process", "time", "fs", "std"], default-features = false }
rustix = { version = "0.38", features = ["event", "fs", "pipe", "process", "std", "time"], default-features = false }

[target.'cfg(windows)'.dependencies]
concurrent-queue = "2.2.0"
Expand Down
6 changes: 4 additions & 2 deletions examples/two-listeners.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ fn main() -> io::Result<()> {
l2.set_nonblocking(true)?;

let poller = Poller::new()?;
poller.add(&l1, Event::readable(1))?;
poller.add(&l2, Event::readable(2))?;
unsafe {
poller.add(&l1, Event::readable(1))?;
poller.add(&l2, Event::readable(2))?;
}

println!("You can connect to the server using `nc`:");
println!(" $ nc 127.0.0.1 8001");
Expand Down
79 changes: 44 additions & 35 deletions src/epoll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::time::Duration;

use rustix::event::{epoll, eventfd, EventfdFlags};
use rustix::fd::OwnedFd;
use rustix::io::{epoll, eventfd, read, write, EventfdFlags};
use rustix::io::{read, write};
use rustix::time::{
timerfd_create, timerfd_settime, Itimerspec, TimerfdClockId, TimerfdFlags, TimerfdTimerFlags,
Timespec,
Expand All @@ -31,7 +32,7 @@ impl Poller {
// Create an epoll instance.
//
// Use `epoll_create1` with `EPOLL_CLOEXEC`.
let epoll_fd = epoll::epoll_create(epoll::CreateFlags::CLOEXEC)?;
let epoll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)?;

// Set up eventfd and timerfd.
let event_fd = eventfd(0, EventfdFlags::CLOEXEC | EventfdFlags::NONBLOCK)?;
Expand All @@ -47,24 +48,26 @@ impl Poller {
timer_fd,
};

if let Some(ref timer_fd) = poller.timer_fd {
unsafe {
if let Some(ref timer_fd) = poller.timer_fd {
poller.add(
timer_fd.as_raw_fd(),
Event::none(crate::NOTIFY_KEY),
PollMode::Oneshot,
)?;
}

poller.add(
timer_fd.as_raw_fd(),
Event::none(crate::NOTIFY_KEY),
poller.event_fd.as_raw_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
writable: false,
},
PollMode::Oneshot,
)?;
}

poller.add(
poller.event_fd.as_raw_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
writable: false,
},
PollMode::Oneshot,
)?;

tracing::trace!(
epoll_fd = ?poller.epoll_fd.as_raw_fd(),
event_fd = ?poller.event_fd.as_raw_fd(),
Expand All @@ -85,7 +88,12 @@ impl Poller {
}

/// Adds a new file descriptor.
pub fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
///
/// # Safety
///
/// The `fd` must be a valid file descriptor. The usual condition of remaining registered in
/// the `Poller` doesn't apply to `epoll`.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!(
"add",
epoll_fd = ?self.epoll_fd.as_raw_fd(),
Expand All @@ -94,18 +102,18 @@ impl Poller {
);
let _enter = span.enter();

epoll::epoll_add(
epoll::add(
&self.epoll_fd,
unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) },
ev.key as u64,
epoll::EventData::new_u64(ev.key as u64),
epoll_flags(&ev, mode),
)?;

Ok(())
}

/// Modifies an existing file descriptor.
pub fn modify(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
pub fn modify(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> {
let span = tracing::trace_span!(
"modify",
epoll_fd = ?self.epoll_fd.as_raw_fd(),
Expand All @@ -114,28 +122,26 @@ impl Poller {
);
let _enter = span.enter();

epoll::epoll_mod(
epoll::modify(
&self.epoll_fd,
unsafe { rustix::fd::BorrowedFd::borrow_raw(fd) },
ev.key as u64,
fd,
epoll::EventData::new_u64(ev.key as u64),
epoll_flags(&ev, mode),
)?;

Ok(())
}

/// Deletes a file descriptor.
pub fn delete(&self, fd: RawFd) -> io::Result<()> {
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
let span = tracing::trace_span!(
"delete",
epoll_fd = ?self.epoll_fd.as_raw_fd(),
?fd,
);
let _enter = span.enter();

epoll::epoll_del(&self.epoll_fd, unsafe {
rustix::fd::BorrowedFd::borrow_raw(fd)
})?;
epoll::delete(&self.epoll_fd, fd)?;

Ok(())
}
Expand Down Expand Up @@ -170,7 +176,7 @@ impl Poller {

// Set interest in timerfd.
self.modify(
timer_fd.as_raw_fd(),
timer_fd.as_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
Expand All @@ -195,7 +201,7 @@ impl Poller {
};

// Wait for I/O events.
epoll::epoll_wait(&self.epoll_fd, &mut events.list, timeout_ms)?;
epoll::wait(&self.epoll_fd, &mut events.list, timeout_ms)?;
tracing::trace!(
epoll_fd = ?self.epoll_fd.as_raw_fd(),
res = ?events.list.len(),
Expand All @@ -206,7 +212,7 @@ impl Poller {
let mut buf = [0u8; 8];
let _ = read(&self.event_fd, &mut buf);
self.modify(
self.event_fd.as_raw_fd(),
self.event_fd.as_fd(),
Event {
key: crate::NOTIFY_KEY,
readable: true,
Expand Down Expand Up @@ -255,9 +261,9 @@ impl Drop for Poller {
let _enter = span.enter();

if let Some(timer_fd) = self.timer_fd.take() {
let _ = self.delete(timer_fd.as_raw_fd());
let _ = self.delete(timer_fd.as_fd());
}
let _ = self.delete(self.event_fd.as_raw_fd());
let _ = self.delete(self.event_fd.as_fd());
}
}

Expand Down Expand Up @@ -310,10 +316,13 @@ impl Events {

/// Iterates over I/O events.
pub fn iter(&self) -> impl Iterator<Item = Event> + '_ {
self.list.iter().map(|(flags, data)| Event {
key: data as usize,
readable: flags.intersects(read_flags()),
writable: flags.intersects(write_flags()),
self.list.iter().map(|ev| {
let flags = ev.flags;
Event {
key: ev.data.u64() as usize,
readable: flags.intersects(read_flags()),
writable: flags.intersects(write_flags()),
}
})
}
}
23 changes: 17 additions & 6 deletions src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ use std::collections::hash_map::{Entry, HashMap};
use std::fmt;
use std::io;
use std::marker::PhantomPinned;
use std::os::windows::io::{AsHandle, AsRawHandle, BorrowedHandle, RawHandle, RawSocket};
use std::os::windows::io::{
AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket,
};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak};
Expand Down Expand Up @@ -134,7 +136,16 @@ impl Poller {
}

/// Add a new source to the poller.
pub(super) fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> {
///
/// # Safety
///
/// The socket must be a valid socket and must last until it is deleted.
pub(super) unsafe fn add(
&self,
socket: RawSocket,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
let span = tracing::trace_span!(
"add",
handle = ?self.port,
Expand Down Expand Up @@ -192,7 +203,7 @@ impl Poller {
/// Update a source in the poller.
pub(super) fn modify(
&self,
socket: RawSocket,
socket: BorrowedSocket<'_>,
interest: Event,
mode: PollMode,
) -> io::Result<()> {
Expand All @@ -217,7 +228,7 @@ impl Poller {
let sources = lock!(self.sources.read());

sources
.get(&socket)
.get(&socket.as_raw_socket())
.cloned()
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?
};
Expand All @@ -231,7 +242,7 @@ impl Poller {
}

/// Delete a source from the poller.
pub(super) fn delete(&self, socket: RawSocket) -> io::Result<()> {
pub(super) fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> {
let span = tracing::trace_span!(
"remove",
handle = ?self.port,
Expand All @@ -243,7 +254,7 @@ impl Poller {
let source = {
let mut sources = lock!(self.sources.write());

match sources.remove(&socket) {
match sources.remove(&socket.as_raw_socket()) {
Some(s) => s,
None => {
// If the source has already been removed, then we can just return.
Expand Down
Loading

0 comments on commit 6eb7679

Please sign in to comment.