Skip to content

Commit

Permalink
feat(driver): remove slab and use ptr as user data
Browse files Browse the repository at this point in the history
  • Loading branch information
Berrysoft committed May 14, 2024
1 parent 88495f1 commit f6e7c4a
Show file tree
Hide file tree
Showing 15 changed files with 287 additions and 353 deletions.
1 change: 0 additions & 1 deletion compio-driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ compio-log = { workspace = true }
cfg-if = { workspace = true }
crossbeam-channel = { workspace = true }
futures-util = { workspace = true }
slab = { workspace = true }
socket2 = { workspace = true }

# Windows specific dependencies
Expand Down
22 changes: 10 additions & 12 deletions compio-driver/src/fusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ pub use driver_type::DriverType;
pub(crate) use iour::{sockaddr_storage, socklen_t};
pub use iour::{OpCode as IourOpCode, OpEntry};
pub use poll::{Decision, OpCode as PollOpCode};
use slab::Slab;

pub(crate) use crate::RawOp;
use crate::{OutEntries, ProactorBuilder};
use crate::{Key, OutEntries, ProactorBuilder};

mod driver_type {
use std::sync::atomic::{AtomicU8, Ordering};
Expand Down Expand Up @@ -136,10 +134,10 @@ impl Driver {
}
}

pub fn create_op<T: OpCode + 'static>(&self, user_data: usize, op: T) -> RawOp {
pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
match &self.fuse {
FuseDriver::Poll(driver) => driver.create_op(user_data, op),
FuseDriver::IoUring(driver) => driver.create_op(user_data, op),
FuseDriver::Poll(driver) => driver.create_op(op),
FuseDriver::IoUring(driver) => driver.create_op(op),
}
}

Expand All @@ -150,17 +148,17 @@ impl Driver {
}
}

pub fn cancel(&mut self, user_data: usize, registry: &mut Slab<RawOp>) {
pub fn cancel<T>(&mut self, op: Key<T>) {
match &mut self.fuse {
FuseDriver::Poll(driver) => driver.cancel(user_data, registry),
FuseDriver::IoUring(driver) => driver.cancel(user_data, registry),
FuseDriver::Poll(driver) => driver.cancel(op),
FuseDriver::IoUring(driver) => driver.cancel(op),
}
}

pub fn push(&mut self, user_data: usize, op: &mut RawOp) -> Poll<io::Result<usize>> {
pub fn push<T: OpCode + 'static>(&mut self, op: &mut Key<T>) -> Poll<io::Result<usize>> {
match &mut self.fuse {
FuseDriver::Poll(driver) => driver.push(user_data, op),
FuseDriver::IoUring(driver) => driver.push(user_data, op),
FuseDriver::Poll(driver) => driver.push(op),
FuseDriver::IoUring(driver) => driver.push(op),
}
}

Expand Down
18 changes: 5 additions & 13 deletions compio-driver/src/iocp/cp/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,11 @@ impl GlobalPort {
self.port.attach(fd)
}

pub fn post<T: ?Sized>(
&self,
res: io::Result<usize>,
optr: *mut Overlapped<T>,
) -> io::Result<()> {
pub fn post(&self, res: io::Result<usize>, optr: *mut Overlapped) -> io::Result<()> {
self.port.post(res, optr)
}

pub fn post_raw<T: ?Sized>(&self, optr: *const Overlapped<T>) -> io::Result<()> {
pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> {
self.port.post_raw(optr)
}
}
Expand All @@ -62,7 +58,7 @@ fn iocp_start() -> io::Result<()> {
loop {
for entry in port.port.poll_raw(None)? {
// Any thin pointer is OK because we don't use the type of opcode.
let overlapped_ptr: *mut Overlapped<()> = entry.lpOverlapped.cast();
let overlapped_ptr: *mut Overlapped = entry.lpOverlapped.cast();
let overlapped = unsafe { &*overlapped_ptr };
if let Err(_e) = syscall!(
BOOL,
Expand Down Expand Up @@ -135,15 +131,11 @@ impl PortHandle {
Self { port }
}

pub fn post<T: ?Sized>(
&self,
res: io::Result<usize>,
optr: *mut Overlapped<T>,
) -> io::Result<()> {
pub fn post(&self, res: io::Result<usize>, optr: *mut Overlapped) -> io::Result<()> {
self.port.post(res, optr)
}

pub fn post_raw<T: ?Sized>(&self, optr: *const Overlapped<T>) -> io::Result<()> {
pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> {
self.port.post_raw(optr)
}
}
12 changes: 4 additions & 8 deletions compio-driver/src/iocp/cp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,7 @@ impl CompletionPort {
Ok(())
}

pub fn post<T: ?Sized>(
&self,
res: io::Result<usize>,
optr: *mut Overlapped<T>,
) -> io::Result<()> {
pub fn post(&self, res: io::Result<usize>, optr: *mut Overlapped) -> io::Result<()> {
if let Some(overlapped) = unsafe { optr.as_mut() } {
match &res {
Ok(transferred) => {
Expand All @@ -97,7 +93,7 @@ impl CompletionPort {
self.post_raw(optr)
}

pub fn post_raw<T: ?Sized>(&self, optr: *const Overlapped<T>) -> io::Result<()> {
pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> {
syscall!(
BOOL,
PostQueuedCompletionStatus(self.port.as_raw_handle() as _, 0, 0, optr.cast())
Expand Down Expand Up @@ -143,7 +139,7 @@ impl CompletionPort {
) -> io::Result<impl Iterator<Item = Entry>> {
Ok(self.poll_raw(timeout)?.filter_map(move |entry| {
// Any thin pointer is OK because we don't use the type of opcode.
let overlapped_ptr: *mut Overlapped<()> = entry.lpOverlapped.cast();
let overlapped_ptr: *mut Overlapped = entry.lpOverlapped.cast();
let overlapped = unsafe { &*overlapped_ptr };
if let Some(current_driver) = current_driver {
if overlapped.driver != current_driver {
Expand Down Expand Up @@ -181,7 +177,7 @@ impl CompletionPort {
_ => Err(io::Error::from_raw_os_error(error as _)),
}
};
Some(Entry::new(overlapped.user_data, res))
Some(Entry::new(overlapped_ptr as usize, res))
}))
}
}
Expand Down
8 changes: 2 additions & 6 deletions compio-driver/src/iocp/cp/multi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,11 @@ impl PortHandle {
Self { port }
}

pub fn post<T: ?Sized>(
&self,
res: io::Result<usize>,
optr: *mut Overlapped<T>,
) -> io::Result<()> {
pub fn post(&self, res: io::Result<usize>, optr: *mut Overlapped) -> io::Result<()> {
self.port.post(res, optr)
}

pub fn post_raw<T: ?Sized>(&self, optr: *const Overlapped<T>) -> io::Result<()> {
pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> {
self.port.post_raw(optr)
}
}
98 changes: 42 additions & 56 deletions compio-driver/src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ use std::{
},
},
pin::Pin,
ptr::{null, NonNull},
ptr::null,
sync::Arc,
task::Poll,
time::Duration,
};

use compio_log::{instrument, trace};
use slab::Slab;
use windows_sys::Win32::{
Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED, ERROR_TIMEOUT, WAIT_OBJECT_0, WAIT_TIMEOUT},
Networking::WinSock::{WSACleanup, WSAStartup, WSADATA},
Expand All @@ -29,7 +28,7 @@ use windows_sys::Win32::{
},
};

use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder, RawOp};
use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder};

pub(crate) mod op;

Expand Down Expand Up @@ -173,12 +172,10 @@ pub(crate) struct Driver {
waits: HashMap<usize, WinThreadpollWait>,
cancelled: HashSet<usize>,
pool: AsyncifyPool,
notify_overlapped: Arc<Overlapped<()>>,
notify_overlapped: Arc<Overlapped>,
}

impl Driver {
const NOTIFY: usize = usize::MAX;

pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
instrument!(compio_log::Level::TRACE, "new", ?builder);
let mut data: WSADATA = unsafe { std::mem::zeroed() };
Expand All @@ -191,33 +188,33 @@ impl Driver {
waits: HashMap::default(),
cancelled: HashSet::default(),
pool: builder.create_or_get_thread_pool(),
notify_overlapped: Arc::new(Overlapped::new(driver, Self::NOTIFY, ())),
notify_overlapped: Arc::new(Overlapped::new(driver)),
})
}

pub fn create_op<T: OpCode + 'static>(&self, user_data: usize, op: T) -> RawOp {
RawOp::new(self.port.as_raw_handle() as _, user_data, op)
pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
Key::new(self.port.as_raw_handle() as _, op)
}

pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
self.port.attach(fd)
}

pub fn cancel(&mut self, user_data: usize, registry: &mut Slab<RawOp>) {
instrument!(compio_log::Level::TRACE, "cancel", user_data);
pub fn cancel<T: OpCode>(&mut self, mut op: Key<T>) {
instrument!(compio_log::Level::TRACE, "cancel", ?op);
trace!("cancel RawOp");
let user_data = op.user_data();
self.cancelled.insert(user_data);
if let Some(op) = registry.get_mut(user_data) {
let overlapped_ptr = op.as_mut_ptr();
let op = op.as_op_pin();
// It's OK to fail to cancel.
trace!("call OpCode::cancel");
unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
}
let overlapped_ptr = op.as_mut_ptr();
let op = op.as_op_pin();
// It's OK to fail to cancel.
trace!("call OpCode::cancel");
unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
}

pub fn push(&mut self, user_data: usize, op: &mut RawOp) -> Poll<io::Result<usize>> {
instrument!(compio_log::Level::TRACE, "push", user_data);
pub fn push<T: OpCode + 'static>(&mut self, op: &mut Key<T>) -> Poll<io::Result<usize>> {
instrument!(compio_log::Level::TRACE, "push", ?op);
let user_data = op.user_data();
if self.cancelled.remove(&user_data) {
trace!("pushed RawOp already cancelled");
Poll::Ready(Err(io::Error::from_raw_os_error(
Expand All @@ -230,7 +227,7 @@ impl Driver {
match op_pin.op_type() {
OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
OpType::Blocking => {
if self.push_blocking(op)? {
if self.push_blocking(user_data)? {
Poll::Pending
} else {
Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _)))
Expand All @@ -247,20 +244,12 @@ impl Driver {
}
}

fn push_blocking(&mut self, op: &mut RawOp) -> io::Result<bool> {
// Safety: the RawOp is not released before the operation returns.
struct SendWrapper<T>(T);
unsafe impl<T> Send for SendWrapper<T> {}

let optr = SendWrapper(NonNull::from(op));
fn push_blocking(&mut self, user_data: usize) -> io::Result<bool> {
let port = self.port.handle();
Ok(self
.pool
.dispatch(move || {
#[allow(clippy::redundant_locals)]
let mut optr = optr;
// Safety: the pointer is created from a reference.
let op = unsafe { optr.0.as_mut() };
let op = unsafe { Key::upcast(user_data) };
let optr = op.as_mut_ptr();
let res = op.operate_blocking();
port.post(res, optr).ok();
Expand All @@ -269,12 +258,13 @@ impl Driver {
}

fn create_entry(
notify_user_data: usize,
cancelled: &mut HashSet<usize>,
waits: &mut HashMap<usize, WinThreadpollWait>,
entry: Entry,
) -> Option<Entry> {
let user_data = entry.user_data();
if user_data != Self::NOTIFY {
if user_data != notify_user_data {
waits.remove(&user_data);
let result = if cancelled.remove(&user_data) {
Err(io::Error::from_raw_os_error(ERROR_OPERATION_ABORTED as _))
Expand All @@ -294,11 +284,11 @@ impl Driver {
) -> io::Result<()> {
instrument!(compio_log::Level::TRACE, "poll", ?timeout);

entries.extend(
self.port
.poll(timeout)?
.filter_map(|e| Self::create_entry(&mut self.cancelled, &mut self.waits, e)),
);
let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;

entries.extend(self.port.poll(timeout)?.filter_map(|e| {
Self::create_entry(notify_user_data, &mut self.cancelled, &mut self.waits, e)
}));

Ok(())
}
Expand Down Expand Up @@ -326,11 +316,11 @@ impl Drop for Driver {
/// A notify handle to the inner driver.
pub struct NotifyHandle {
port: cp::PortHandle,
overlapped: Arc<Overlapped<()>>,
overlapped: Arc<Overlapped>,
}

impl NotifyHandle {
fn new(port: cp::PortHandle, overlapped: Arc<Overlapped<()>>) -> Self {
fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
Self { port, overlapped }
}

Expand All @@ -348,8 +338,11 @@ struct WinThreadpollWait {
}

impl WinThreadpollWait {
pub fn new(port: cp::PortHandle, event: RawFd, op: &mut RawOp) -> io::Result<Self> {
let mut context = Box::new(WinThreadpollWaitContext { port, op });
pub fn new<T>(port: cp::PortHandle, event: RawFd, op: &mut Key<T>) -> io::Result<Self> {
let mut context = Box::new(WinThreadpollWaitContext {
port,
user_data: op.user_data(),
});
let wait = syscall!(
BOOL,
CreateThreadpoolWait(
Expand All @@ -376,13 +369,13 @@ impl WinThreadpollWait {
WAIT_TIMEOUT => Err(io::Error::from_raw_os_error(ERROR_TIMEOUT as _)),
_ => Err(io::Error::from_raw_os_error(result as _)),
};
let op = unsafe { Key::upcast(context.user_data) };
let res = if res.is_err() {
res
} else {
let op = unsafe { &mut *context.op };
op.operate_blocking()
};
context.port.post(res, (*context.op).as_mut_ptr()).ok();
context.port.post(res, op.as_mut_ptr()).ok();
}
}

Expand All @@ -398,34 +391,27 @@ impl Drop for WinThreadpollWait {

struct WinThreadpollWaitContext {
port: cp::PortHandle,
op: *mut RawOp,
user_data: usize,
}

/// The overlapped struct we actually used for IOCP.
#[repr(C)]
pub struct Overlapped<T: ?Sized> {
pub struct Overlapped {
/// The base [`OVERLAPPED`].
pub base: OVERLAPPED,
/// The unique ID of created driver.
pub driver: RawFd,
/// The registered user defined data.
pub user_data: usize,
/// The opcode.
/// The user should guarantee the type is correct.
pub op: T,
}

impl<T> Overlapped<T> {
pub(crate) fn new(driver: RawFd, user_data: usize, op: T) -> Self {
impl Overlapped {
pub(crate) fn new(driver: RawFd) -> Self {
Self {
base: unsafe { std::mem::zeroed() },
driver,
user_data,
op,
}
}
}

// SAFETY: neither field of `OVERLAPPED` is used
unsafe impl Send for Overlapped<()> {}
unsafe impl Sync for Overlapped<()> {}
unsafe impl Send for Overlapped {}
unsafe impl Sync for Overlapped {}
Loading

0 comments on commit f6e7c4a

Please sign in to comment.