From f6e7c4a5c30b11b296ee6df12112a603b2fb4a5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Wed, 15 May 2024 03:44:25 +0800 Subject: [PATCH] feat(driver): remove slab and use ptr as user data --- compio-driver/Cargo.toml | 1 - compio-driver/src/fusion/mod.rs | 22 ++-- compio-driver/src/iocp/cp/global.rs | 18 +-- compio-driver/src/iocp/cp/mod.rs | 12 +- compio-driver/src/iocp/cp/multi.rs | 8 +- compio-driver/src/iocp/mod.rs | 98 +++++++--------- compio-driver/src/iour/mod.rs | 34 +++--- compio-driver/src/key.rs | 142 +++++++++++++++++++++-- compio-driver/src/lib.rs | 169 +++++----------------------- compio-driver/src/poll/mod.rs | 33 ++---- compio-driver/src/unix/mod.rs | 11 +- compio-driver/tests/file.rs | 33 ++---- compio-runtime/src/runtime/mod.rs | 20 ++-- compio-runtime/src/runtime/op.rs | 35 +++--- compio/examples/driver.rs | 4 +- 15 files changed, 287 insertions(+), 353 deletions(-) diff --git a/compio-driver/Cargo.toml b/compio-driver/Cargo.toml index 941d0310..2de24dd6 100644 --- a/compio-driver/Cargo.toml +++ b/compio-driver/Cargo.toml @@ -36,7 +36,6 @@ compio-log = { workspace = true } cfg-if = { workspace = true } crossbeam-channel = { workspace = true } futures-util = { workspace = true } -slab = { workspace = true } socket2 = { workspace = true } # Windows specific dependencies diff --git a/compio-driver/src/fusion/mod.rs b/compio-driver/src/fusion/mod.rs index d123cf23..b621e909 100644 --- a/compio-driver/src/fusion/mod.rs +++ b/compio-driver/src/fusion/mod.rs @@ -14,10 +14,8 @@ pub use driver_type::DriverType; pub(crate) use iour::{sockaddr_storage, socklen_t}; pub use iour::{OpCode as IourOpCode, OpEntry}; pub use poll::{Decision, OpCode as PollOpCode}; -use slab::Slab; -pub(crate) use crate::RawOp; -use crate::{OutEntries, ProactorBuilder}; +use crate::{Key, OutEntries, ProactorBuilder}; mod driver_type { use std::sync::atomic::{AtomicU8, Ordering}; @@ -136,10 +134,10 @@ impl Driver { } } - pub fn create_op(&self, user_data: usize, op: T) -> RawOp { + pub fn create_op(&self, op: T) -> Key { match &self.fuse { - FuseDriver::Poll(driver) => driver.create_op(user_data, op), - FuseDriver::IoUring(driver) => driver.create_op(user_data, op), + FuseDriver::Poll(driver) => driver.create_op(op), + FuseDriver::IoUring(driver) => driver.create_op(op), } } @@ -150,17 +148,17 @@ impl Driver { } } - pub fn cancel(&mut self, user_data: usize, registry: &mut Slab) { + pub fn cancel(&mut self, op: Key) { match &mut self.fuse { - FuseDriver::Poll(driver) => driver.cancel(user_data, registry), - FuseDriver::IoUring(driver) => driver.cancel(user_data, registry), + FuseDriver::Poll(driver) => driver.cancel(op), + FuseDriver::IoUring(driver) => driver.cancel(op), } } - pub fn push(&mut self, user_data: usize, op: &mut RawOp) -> Poll> { + pub fn push(&mut self, op: &mut Key) -> Poll> { match &mut self.fuse { - FuseDriver::Poll(driver) => driver.push(user_data, op), - FuseDriver::IoUring(driver) => driver.push(user_data, op), + FuseDriver::Poll(driver) => driver.push(op), + FuseDriver::IoUring(driver) => driver.push(op), } } diff --git a/compio-driver/src/iocp/cp/global.rs b/compio-driver/src/iocp/cp/global.rs index f9a045bc..f16ea270 100644 --- a/compio-driver/src/iocp/cp/global.rs +++ b/compio-driver/src/iocp/cp/global.rs @@ -29,15 +29,11 @@ impl GlobalPort { self.port.attach(fd) } - pub fn post( - &self, - res: io::Result, - optr: *mut Overlapped, - ) -> io::Result<()> { + pub fn post(&self, res: io::Result, optr: *mut Overlapped) -> io::Result<()> { self.port.post(res, optr) } - pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { self.port.post_raw(optr) } } @@ -62,7 +58,7 @@ fn iocp_start() -> io::Result<()> { loop { for entry in port.port.poll_raw(None)? { // Any thin pointer is OK because we don't use the type of opcode. - let overlapped_ptr: *mut Overlapped<()> = entry.lpOverlapped.cast(); + let overlapped_ptr: *mut Overlapped = entry.lpOverlapped.cast(); let overlapped = unsafe { &*overlapped_ptr }; if let Err(_e) = syscall!( BOOL, @@ -135,15 +131,11 @@ impl PortHandle { Self { port } } - pub fn post( - &self, - res: io::Result, - optr: *mut Overlapped, - ) -> io::Result<()> { + pub fn post(&self, res: io::Result, optr: *mut Overlapped) -> io::Result<()> { self.port.post(res, optr) } - pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { self.port.post_raw(optr) } } diff --git a/compio-driver/src/iocp/cp/mod.rs b/compio-driver/src/iocp/cp/mod.rs index c07c6381..f8b413df 100644 --- a/compio-driver/src/iocp/cp/mod.rs +++ b/compio-driver/src/iocp/cp/mod.rs @@ -77,11 +77,7 @@ impl CompletionPort { Ok(()) } - pub fn post( - &self, - res: io::Result, - optr: *mut Overlapped, - ) -> io::Result<()> { + pub fn post(&self, res: io::Result, optr: *mut Overlapped) -> io::Result<()> { if let Some(overlapped) = unsafe { optr.as_mut() } { match &res { Ok(transferred) => { @@ -97,7 +93,7 @@ impl CompletionPort { self.post_raw(optr) } - pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { syscall!( BOOL, PostQueuedCompletionStatus(self.port.as_raw_handle() as _, 0, 0, optr.cast()) @@ -143,7 +139,7 @@ impl CompletionPort { ) -> io::Result> { Ok(self.poll_raw(timeout)?.filter_map(move |entry| { // Any thin pointer is OK because we don't use the type of opcode. - let overlapped_ptr: *mut Overlapped<()> = entry.lpOverlapped.cast(); + let overlapped_ptr: *mut Overlapped = entry.lpOverlapped.cast(); let overlapped = unsafe { &*overlapped_ptr }; if let Some(current_driver) = current_driver { if overlapped.driver != current_driver { @@ -181,7 +177,7 @@ impl CompletionPort { _ => Err(io::Error::from_raw_os_error(error as _)), } }; - Some(Entry::new(overlapped.user_data, res)) + Some(Entry::new(overlapped_ptr as usize, res)) })) } } diff --git a/compio-driver/src/iocp/cp/multi.rs b/compio-driver/src/iocp/cp/multi.rs index df2d56c6..461cf28a 100644 --- a/compio-driver/src/iocp/cp/multi.rs +++ b/compio-driver/src/iocp/cp/multi.rs @@ -48,15 +48,11 @@ impl PortHandle { Self { port } } - pub fn post( - &self, - res: io::Result, - optr: *mut Overlapped, - ) -> io::Result<()> { + pub fn post(&self, res: io::Result, optr: *mut Overlapped) -> io::Result<()> { self.port.post(res, optr) } - pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { + pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> { self.port.post_raw(optr) } } diff --git a/compio-driver/src/iocp/mod.rs b/compio-driver/src/iocp/mod.rs index 39fa01ea..f2a5d837 100644 --- a/compio-driver/src/iocp/mod.rs +++ b/compio-driver/src/iocp/mod.rs @@ -9,14 +9,13 @@ use std::{ }, }, pin::Pin, - ptr::{null, NonNull}, + ptr::null, sync::Arc, task::Poll, time::Duration, }; use compio_log::{instrument, trace}; -use slab::Slab; use windows_sys::Win32::{ Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED, ERROR_TIMEOUT, WAIT_OBJECT_0, WAIT_TIMEOUT}, Networking::WinSock::{WSACleanup, WSAStartup, WSADATA}, @@ -29,7 +28,7 @@ use windows_sys::Win32::{ }, }; -use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder, RawOp}; +use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder}; pub(crate) mod op; @@ -173,12 +172,10 @@ pub(crate) struct Driver { waits: HashMap, cancelled: HashSet, pool: AsyncifyPool, - notify_overlapped: Arc>, + notify_overlapped: Arc, } impl Driver { - const NOTIFY: usize = usize::MAX; - pub fn new(builder: &ProactorBuilder) -> io::Result { instrument!(compio_log::Level::TRACE, "new", ?builder); let mut data: WSADATA = unsafe { std::mem::zeroed() }; @@ -191,33 +188,33 @@ impl Driver { waits: HashMap::default(), cancelled: HashSet::default(), pool: builder.create_or_get_thread_pool(), - notify_overlapped: Arc::new(Overlapped::new(driver, Self::NOTIFY, ())), + notify_overlapped: Arc::new(Overlapped::new(driver)), }) } - pub fn create_op(&self, user_data: usize, op: T) -> RawOp { - RawOp::new(self.port.as_raw_handle() as _, user_data, op) + pub fn create_op(&self, op: T) -> Key { + Key::new(self.port.as_raw_handle() as _, op) } pub fn attach(&mut self, fd: RawFd) -> io::Result<()> { self.port.attach(fd) } - pub fn cancel(&mut self, user_data: usize, registry: &mut Slab) { - instrument!(compio_log::Level::TRACE, "cancel", user_data); + pub fn cancel(&mut self, mut op: Key) { + instrument!(compio_log::Level::TRACE, "cancel", ?op); trace!("cancel RawOp"); + let user_data = op.user_data(); self.cancelled.insert(user_data); - if let Some(op) = registry.get_mut(user_data) { - let overlapped_ptr = op.as_mut_ptr(); - let op = op.as_op_pin(); - // It's OK to fail to cancel. - trace!("call OpCode::cancel"); - unsafe { op.cancel(overlapped_ptr.cast()) }.ok(); - } + let overlapped_ptr = op.as_mut_ptr(); + let op = op.as_op_pin(); + // It's OK to fail to cancel. + trace!("call OpCode::cancel"); + unsafe { op.cancel(overlapped_ptr.cast()) }.ok(); } - pub fn push(&mut self, user_data: usize, op: &mut RawOp) -> Poll> { - instrument!(compio_log::Level::TRACE, "push", user_data); + pub fn push(&mut self, op: &mut Key) -> Poll> { + instrument!(compio_log::Level::TRACE, "push", ?op); + let user_data = op.user_data(); if self.cancelled.remove(&user_data) { trace!("pushed RawOp already cancelled"); Poll::Ready(Err(io::Error::from_raw_os_error( @@ -230,7 +227,7 @@ impl Driver { match op_pin.op_type() { OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) }, OpType::Blocking => { - if self.push_blocking(op)? { + if self.push_blocking(user_data)? { Poll::Pending } else { Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _))) @@ -247,20 +244,12 @@ impl Driver { } } - fn push_blocking(&mut self, op: &mut RawOp) -> io::Result { - // Safety: the RawOp is not released before the operation returns. - struct SendWrapper(T); - unsafe impl Send for SendWrapper {} - - let optr = SendWrapper(NonNull::from(op)); + fn push_blocking(&mut self, user_data: usize) -> io::Result { let port = self.port.handle(); Ok(self .pool .dispatch(move || { - #[allow(clippy::redundant_locals)] - let mut optr = optr; - // Safety: the pointer is created from a reference. - let op = unsafe { optr.0.as_mut() }; + let op = unsafe { Key::upcast(user_data) }; let optr = op.as_mut_ptr(); let res = op.operate_blocking(); port.post(res, optr).ok(); @@ -269,12 +258,13 @@ impl Driver { } fn create_entry( + notify_user_data: usize, cancelled: &mut HashSet, waits: &mut HashMap, entry: Entry, ) -> Option { let user_data = entry.user_data(); - if user_data != Self::NOTIFY { + if user_data != notify_user_data { waits.remove(&user_data); let result = if cancelled.remove(&user_data) { Err(io::Error::from_raw_os_error(ERROR_OPERATION_ABORTED as _)) @@ -294,11 +284,11 @@ impl Driver { ) -> io::Result<()> { instrument!(compio_log::Level::TRACE, "poll", ?timeout); - entries.extend( - self.port - .poll(timeout)? - .filter_map(|e| Self::create_entry(&mut self.cancelled, &mut self.waits, e)), - ); + let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize; + + entries.extend(self.port.poll(timeout)?.filter_map(|e| { + Self::create_entry(notify_user_data, &mut self.cancelled, &mut self.waits, e) + })); Ok(()) } @@ -326,11 +316,11 @@ impl Drop for Driver { /// A notify handle to the inner driver. pub struct NotifyHandle { port: cp::PortHandle, - overlapped: Arc>, + overlapped: Arc, } impl NotifyHandle { - fn new(port: cp::PortHandle, overlapped: Arc>) -> Self { + fn new(port: cp::PortHandle, overlapped: Arc) -> Self { Self { port, overlapped } } @@ -348,8 +338,11 @@ struct WinThreadpollWait { } impl WinThreadpollWait { - pub fn new(port: cp::PortHandle, event: RawFd, op: &mut RawOp) -> io::Result { - let mut context = Box::new(WinThreadpollWaitContext { port, op }); + pub fn new(port: cp::PortHandle, event: RawFd, op: &mut Key) -> io::Result { + let mut context = Box::new(WinThreadpollWaitContext { + port, + user_data: op.user_data(), + }); let wait = syscall!( BOOL, CreateThreadpoolWait( @@ -376,13 +369,13 @@ impl WinThreadpollWait { WAIT_TIMEOUT => Err(io::Error::from_raw_os_error(ERROR_TIMEOUT as _)), _ => Err(io::Error::from_raw_os_error(result as _)), }; + let op = unsafe { Key::upcast(context.user_data) }; let res = if res.is_err() { res } else { - let op = unsafe { &mut *context.op }; op.operate_blocking() }; - context.port.post(res, (*context.op).as_mut_ptr()).ok(); + context.port.post(res, op.as_mut_ptr()).ok(); } } @@ -398,34 +391,27 @@ impl Drop for WinThreadpollWait { struct WinThreadpollWaitContext { port: cp::PortHandle, - op: *mut RawOp, + user_data: usize, } /// The overlapped struct we actually used for IOCP. #[repr(C)] -pub struct Overlapped { +pub struct Overlapped { /// The base [`OVERLAPPED`]. pub base: OVERLAPPED, /// The unique ID of created driver. pub driver: RawFd, - /// The registered user defined data. - pub user_data: usize, - /// The opcode. - /// The user should guarantee the type is correct. - pub op: T, } -impl Overlapped { - pub(crate) fn new(driver: RawFd, user_data: usize, op: T) -> Self { +impl Overlapped { + pub(crate) fn new(driver: RawFd) -> Self { Self { base: unsafe { std::mem::zeroed() }, driver, - user_data, - op, } } } // SAFETY: neither field of `OVERLAPPED` is used -unsafe impl Send for Overlapped<()> {} -unsafe impl Sync for Overlapped<()> {} +unsafe impl Send for Overlapped {} +unsafe impl Sync for Overlapped {} diff --git a/compio-driver/src/iour/mod.rs b/compio-driver/src/iour/mod.rs index 6fc30235..38e16849 100644 --- a/compio-driver/src/iour/mod.rs +++ b/compio-driver/src/iour/mod.rs @@ -1,7 +1,7 @@ #[cfg_attr(all(doc, docsrs), doc(cfg(all())))] #[allow(unused_imports)] pub use std::os::fd::{AsRawFd, OwnedFd, RawFd}; -use std::{io, os::fd::FromRawFd, pin::Pin, ptr::NonNull, sync::Arc, task::Poll, time::Duration}; +use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration}; use compio_log::{instrument, trace, warn}; use crossbeam_queue::SegQueue; @@ -25,12 +25,10 @@ use io_uring::{ IoUring, }; pub(crate) use libc::{sockaddr_storage, socklen_t}; -use slab::Slab; -use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder}; +use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder}; pub(crate) mod op; -pub(crate) use crate::RawOp; /// The created entry of [`OpCode`]. pub enum OpEntry { @@ -159,16 +157,16 @@ impl Driver { has_entry } - pub fn create_op(&self, user_data: usize, op: T) -> RawOp { - RawOp::new(self.as_raw_fd(), user_data, op) + pub fn create_op(&self, op: T) -> Key { + Key::new(self.as_raw_fd(), op) } pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> { Ok(()) } - pub fn cancel(&mut self, user_data: usize, _registry: &mut Slab) { - instrument!(compio_log::Level::TRACE, "cancel", user_data); + pub fn cancel(&mut self, op: Key) { + instrument!(compio_log::Level::TRACE, "cancel", ?op); trace!("cancel RawOp"); unsafe { #[allow(clippy::useless_conversion)] @@ -176,7 +174,7 @@ impl Driver { .inner .submission() .push( - &AsyncCancel::new(user_data as _) + &AsyncCancel::new(op.user_data() as _) .build() .user_data(Self::CANCEL) .into(), @@ -204,8 +202,9 @@ impl Driver { } } - pub fn push(&mut self, user_data: usize, op: &mut RawOp) -> Poll> { - instrument!(compio_log::Level::TRACE, "push", user_data); + pub fn push(&mut self, op: &mut Key) -> Poll> { + instrument!(compio_log::Level::TRACE, "push", ?op); + let user_data = op.user_data(); let op_pin = op.as_op_pin(); trace!("push RawOp"); match op_pin.create_entry() { @@ -220,7 +219,7 @@ impl Driver { Poll::Pending } OpEntry::Blocking => { - if self.push_blocking(user_data, op)? { + if self.push_blocking(user_data)? { Poll::Pending } else { Poll::Ready(Err(io::Error::from_raw_os_error(libc::EBUSY))) @@ -229,20 +228,13 @@ impl Driver { } } - fn push_blocking(&mut self, user_data: usize, op: &mut RawOp) -> io::Result { - // Safety: the RawOp is not released before the operation returns. - struct SendWrapper(T); - unsafe impl Send for SendWrapper {} - - let op = SendWrapper(NonNull::from(op)); + fn push_blocking(&mut self, user_data: usize) -> io::Result { let handle = self.handle()?; let completed = self.pool_completed.clone(); let is_ok = self .pool .dispatch(move || { - #[allow(clippy::redundant_locals)] - let mut op = op; - let op = unsafe { op.0.as_mut() }; + let op = unsafe { Key::upcast(user_data) }; let op_pin = op.as_op_pin(); let res = op_pin.call_blocking(); completed.push(Entry::new(user_data, res)); diff --git a/compio-driver/src/key.rs b/compio-driver/src/key.rs index 1afacaf7..0472b497 100644 --- a/compio-driver/src/key.rs +++ b/compio-driver/src/key.rs @@ -1,20 +1,111 @@ +use std::{ + io, + marker::PhantomData, + ops::{Deref, DerefMut}, + pin::Pin, + task::Waker, +}; + +use compio_buf::BufResult; + +use crate::{OpCode, Overlapped, PushEntry, RawFd}; + +#[repr(C)] +pub struct RawOp { + header: Overlapped, + // The two flags here are manual reference counting. The driver holds the strong ref until it + // completes; the runtime holds the strong ref until the future is dropped. + cancelled: bool, + upcast_fn: unsafe fn(usize) -> *mut RawOp, + result: PushEntry, io::Result>, + op: T, +} + +impl RawOp { + pub fn as_op_pin(&mut self) -> Pin<&mut T> { + unsafe { Pin::new_unchecked(&mut self.op) } + } + + #[cfg(windows)] + pub fn as_mut_ptr(&mut self) -> *mut Overlapped { + &mut self.header + } + + pub fn set_cancelled(&mut self) -> bool { + self.cancelled = true; + self.has_result() + } + + pub fn set_result(&mut self, res: io::Result) -> bool { + if let PushEntry::Pending(Some(w)) = + std::mem::replace(&mut self.result, PushEntry::Ready(res)) + { + w.wake(); + } + self.cancelled + } + + pub fn has_result(&self) -> bool { + self.result.is_ready() + } + + pub fn set_waker(&mut self, waker: Waker) { + if let PushEntry::Pending(w) = &mut self.result { + *w = Some(waker) + } + } + + pub fn into_inner(self) -> BufResult + where + T: Sized, + { + BufResult(self.result.take_ready().unwrap(), self.op) + } +} + +#[cfg(windows)] +impl RawOp { + pub fn operate_blocking(&mut self) -> io::Result { + use std::task::Poll; + + let optr = self.as_mut_ptr(); + let op = self.as_op_pin(); + let res = unsafe { op.operate(optr.cast()) }; + match res { + Poll::Pending => unreachable!("this operation is not overlapped"), + Poll::Ready(res) => res, + } + } +} + +unsafe fn upcast(user_data: usize) -> *mut RawOp { + user_data as *mut RawOp as *mut RawOp +} + /// A typed wrapper for key of Ops submitted into driver #[derive(PartialEq, Eq, Hash)] pub struct Key { user_data: usize, - _p: std::marker::PhantomData, + _p: PhantomData>, } impl Unpin for Key {} -impl Clone for Key { - fn clone(&self) -> Self { - *self +impl Key { + /// Create [`RawOp`] and get the [`Key`] to it. + pub fn new(driver: RawFd, op: T) -> Self { + let header = Overlapped::new(driver); + let raw_op = Box::new(RawOp { + header, + cancelled: false, + upcast_fn: upcast::, + result: PushEntry::Pending(None), + op, + }); + unsafe { Self::new_unchecked(Box::into_raw(raw_op) as _) } } } -impl Copy for Key {} - impl Key { /// Create a new `Key` with the given user data. /// @@ -22,19 +113,48 @@ impl Key { /// /// Caller needs to ensure that `T` does correspond to `user_data` in driver /// this `Key` is created with. - pub unsafe fn new(user_data: usize) -> Self { + pub unsafe fn new_unchecked(user_data: usize) -> Self { Self { user_data, - _p: std::marker::PhantomData, + _p: PhantomData, } } + + /// Get the unique user-defined data. + pub const fn user_data(&self) -> usize { + self.user_data + } + + /// Get the inner result if it is completed. + pub fn into_inner(self) -> BufResult { + unsafe { Box::from_raw(self.user_data as *mut RawOp) }.into_inner() + } +} + +impl Key<()> { + pub(crate) unsafe fn drop_in_place(user_data: usize) { + let op = &*(user_data as *const RawOp<()>); + let ptr = (op.upcast_fn)(user_data); + let _ = Box::from_raw(ptr); + } + + pub(crate) unsafe fn upcast<'a>(user_data: usize) -> &'a mut RawOp { + let op = &*(user_data as *const RawOp<()>); + &mut *(op.upcast_fn)(user_data) + } } -impl std::ops::Deref for Key { - type Target = usize; +impl Deref for Key { + type Target = RawOp; fn deref(&self) -> &Self::Target { - &self.user_data + unsafe { &*(self.user_data as *const RawOp) } + } +} + +impl DerefMut for Key { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *(self.user_data as *mut RawOp) } } } diff --git a/compio-driver/src/lib.rs b/compio-driver/src/lib.rs index eefa4245..5b1c6d18 100644 --- a/compio-driver/src/lib.rs +++ b/compio-driver/src/lib.rs @@ -14,16 +14,12 @@ compile_error!("You must choose at least one of these features: [\"io-uring\", \ use std::{ io, - mem::ManuallyDrop, - pin::Pin, - ptr::NonNull, task::{Poll, Waker}, time::Duration, }; use compio_buf::BufResult; -use compio_log::{instrument, trace}; -use slab::Slab; +use compio_log::instrument; mod key; pub use key::Key; @@ -216,7 +212,6 @@ impl PushEntry { /// It owns the operations to keep the driver safe. pub struct Proactor { driver: Driver, - ops: Slab, } impl Proactor { @@ -233,7 +228,6 @@ impl Proactor { fn with_builder(builder: &ProactorBuilder) -> io::Result { Ok(Self { driver: Driver::new(builder)?, - ops: Slab::with_capacity(builder.capacity as _), }) } @@ -258,32 +252,25 @@ impl Proactor { /// contains a cancelled user-defined data, the operation will be ignored. /// However, to make the operation dropped correctly, you should cancel /// after push. - pub fn cancel(&mut self, user_data: usize) { - instrument!(compio_log::Level::DEBUG, "cancel", user_data); - if let Some(op) = self.ops.get_mut(user_data) { - if op.set_cancelled() { - // The op is completed. - trace!("cancel and remove {}", user_data); - self.ops.remove(user_data); - return; - } + pub fn cancel(&mut self, mut op: Key) -> Option> { + instrument!(compio_log::Level::DEBUG, "cancel", ?op); + if op.set_cancelled() { + Some(op.into_inner()) + } else { + self.driver.cancel(op); + None } - self.driver.cancel(user_data, &mut self.ops); } /// Push an operation into the driver, and return the unique key, called /// user-defined data, associated with it. pub fn push(&mut self, op: T) -> PushEntry, BufResult> { - let entry = self.ops.vacant_entry(); - let user_data = entry.key(); - let op = self.driver.create_op(user_data, op); - let op = entry.insert(op); - match self.driver.push(user_data, op) { - Poll::Pending => PushEntry::Pending(unsafe { Key::new(user_data) }), + let mut op = self.driver.create_op(op); + match self.driver.push(&mut op) { + Poll::Pending => PushEntry::Pending(op), Poll::Ready(res) => { - let mut op = self.ops.remove(user_data); op.set_result(res); - PushEntry::Ready(unsafe { op.into_inner::() }) + PushEntry::Ready(op.into_inner()) } } } @@ -296,8 +283,7 @@ impl Proactor { entries: &mut impl Extend, ) -> io::Result<()> { unsafe { - self.driver - .poll(timeout, OutEntries::new(entries, &mut self.ops))?; + self.driver.poll(timeout, OutEntries::new(entries))?; } Ok(()) } @@ -307,24 +293,18 @@ impl Proactor { /// # Panics /// This function will panic if the requested operation has not been /// completed. - pub fn pop(&mut self, user_data: Key) -> Option> { - instrument!(compio_log::Level::DEBUG, "pop", ?user_data); - if self.ops[*user_data].has_result() { - let op = self - .ops - .try_remove(*user_data) - .expect("the entry should be valid"); - trace!("poped {}", *user_data); - // Safety: user cannot create key with safe code, so the type should be correct - Some(unsafe { op.into_inner::() }) + pub fn pop(&mut self, op: Key) -> PushEntry, BufResult> { + instrument!(compio_log::Level::DEBUG, "pop", ?op); + if op.has_result() { + PushEntry::Ready(op.into_inner()) } else { - None + PushEntry::Pending(op) } } /// Update the waker of the specified op. - pub fn update_waker(&mut self, user_data: usize, waker: Waker) { - self.ops[user_data].set_waker(waker); + pub fn update_waker(&mut self, op: &mut Key, waker: Waker) { + op.set_waker(waker); } /// Create a notify handle to interrupt the inner driver. @@ -362,118 +342,25 @@ impl Entry { } } -pub(crate) struct RawOp { - op: NonNull>, - // The two flags here are manual reference counting. The driver holds the strong ref until it - // completes; the runtime holds the strong ref until the future is dropped. - cancelled: bool, - result: PushEntry, io::Result>, -} - -impl RawOp { - pub(crate) fn new(driver: RawFd, user_data: usize, op: impl OpCode + 'static) -> Self { - let op = Overlapped::new(driver, user_data, op); - let op = Box::new(op) as Box>; - Self { - op: unsafe { NonNull::new_unchecked(Box::into_raw(op)) }, - cancelled: false, - result: PushEntry::Pending(None), - } - } - - pub fn as_op_pin(&mut self) -> Pin<&mut dyn OpCode> { - unsafe { Pin::new_unchecked(&mut self.op.as_mut().op) } - } - - #[cfg(windows)] - pub fn as_mut_ptr(&mut self) -> *mut Overlapped { - self.op.as_ptr() - } - - pub fn set_cancelled(&mut self) -> bool { - self.cancelled = true; - self.has_result() - } - - pub fn set_result(&mut self, res: io::Result) -> bool { - if let PushEntry::Pending(Some(w)) = - std::mem::replace(&mut self.result, PushEntry::Ready(res)) - { - w.wake(); - } - self.cancelled - } - - pub fn has_result(&self) -> bool { - self.result.is_ready() - } - - pub fn set_waker(&mut self, waker: Waker) { - if let PushEntry::Pending(w) = &mut self.result { - *w = Some(waker) - } - } - - /// # Safety - /// The caller should ensure the correct type. - /// - /// # Panics - /// This function will panic if the result has not been set. - pub unsafe fn into_inner(self) -> BufResult { - let mut this = ManuallyDrop::new(self); - let overlapped: Box> = Box::from_raw(this.op.cast().as_ptr()); - BufResult( - std::mem::replace(&mut this.result, PushEntry::Pending(None)) - .take_ready() - .unwrap(), - overlapped.op, - ) - } - - #[cfg(windows)] - fn operate_blocking(&mut self) -> io::Result { - let optr = self.as_mut_ptr(); - let op = self.as_op_pin(); - let res = unsafe { op.operate(optr.cast()) }; - match res { - Poll::Pending => unreachable!("this operation is not overlapped"), - Poll::Ready(res) => res, - } - } -} - -impl Drop for RawOp { - fn drop(&mut self) { - if self.has_result() { - let _ = unsafe { Box::from_raw(self.op.as_ptr()) }; - } - } -} - // The output entries need to be marked as `completed`. If an entry has been // marked as `cancelled`, it will be removed from the registry. -struct OutEntries<'a, 'b, E> { +struct OutEntries<'b, E> { entries: &'b mut E, - registry: &'a mut Slab, } -impl<'a, 'b, E> OutEntries<'a, 'b, E> { - pub fn new(entries: &'b mut E, registry: &'a mut Slab) -> Self { - Self { entries, registry } - } - - #[allow(dead_code)] - pub fn registry(&mut self) -> &mut Slab { - self.registry +impl<'b, E> OutEntries<'b, E> { + pub fn new(entries: &'b mut E) -> Self { + Self { entries } } } -impl> Extend for OutEntries<'_, '_, E> { +impl> Extend for OutEntries<'_, E> { fn extend>(&mut self, iter: T) { self.entries.extend(iter.into_iter().filter_map(|e| { let user_data = e.user_data(); - if self.registry[user_data].set_result(e.into_result()) { - self.registry.remove(user_data); + let op = unsafe { Key::upcast(user_data) }; + if op.set_result(e.into_result()) { + unsafe { Key::drop_in_place(user_data) }; None } else { Some(user_data) diff --git a/compio-driver/src/poll/mod.rs b/compio-driver/src/poll/mod.rs index e0144ce0..65afbcdb 100644 --- a/compio-driver/src/poll/mod.rs +++ b/compio-driver/src/poll/mod.rs @@ -7,7 +7,6 @@ use std::{ num::NonZeroUsize, os::fd::BorrowedFd, pin::Pin, - ptr::NonNull, sync::Arc, task::Poll, time::Duration, @@ -17,14 +16,11 @@ use compio_log::{instrument, trace}; use crossbeam_queue::SegQueue; pub(crate) use libc::{sockaddr_storage, socklen_t}; use polling::{Event, Events, PollMode, Poller}; -use slab::Slab; -use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder}; +use crate::{syscall, AsyncifyPool, Entry, Key, OutEntries, ProactorBuilder}; pub(crate) mod op; -pub(crate) use crate::RawOp; - /// Abstraction of operations. pub trait OpCode { /// Perform the operation before submit, and return [`Decision`] to @@ -182,8 +178,8 @@ impl Driver { }) } - pub fn create_op(&self, user_data: usize, op: T) -> RawOp { - RawOp::new(self.as_raw_fd(), user_data, op) + pub fn create_op(&self, op: T) -> Key { + Key::new(self.as_raw_fd(), op) } /// # Safety @@ -207,11 +203,12 @@ impl Driver { Ok(()) } - pub fn cancel(&mut self, user_data: usize, _registry: &mut Slab) { - self.cancelled.insert(user_data); + pub fn cancel(&mut self, op: Key) { + self.cancelled.insert(op.user_data()); } - pub fn push(&mut self, user_data: usize, op: &mut RawOp) -> Poll> { + pub fn push(&mut self, op: &mut Key) -> Poll> { + let user_data = op.user_data(); if self.cancelled.remove(&user_data) { Poll::Ready(Err(io::Error::from_raw_os_error(libc::ETIMEDOUT))) } else { @@ -226,7 +223,7 @@ impl Driver { } Ok(Decision::Completed(res)) => Poll::Ready(Ok(res)), Ok(Decision::Blocking(event)) => { - if self.push_blocking(user_data, op, event) { + if self.push_blocking(user_data, event) { Poll::Pending } else { Poll::Ready(Err(io::Error::from_raw_os_error(libc::EBUSY))) @@ -237,19 +234,12 @@ impl Driver { } } - fn push_blocking(&mut self, user_data: usize, op: &mut RawOp, event: Event) -> bool { - // Safety: the RawOp is not released before the operation returns. - struct SendWrapper(T); - unsafe impl Send for SendWrapper {} - - let op = SendWrapper(NonNull::from(op)); + fn push_blocking(&mut self, user_data: usize, event: Event) -> bool { let poll = self.poll.clone(); let completed = self.pool_completed.clone(); self.pool .dispatch(move || { - #[allow(clippy::redundant_locals)] - let mut op = op; - let op = unsafe { op.0.as_mut() }; + let op = unsafe { Key::upcast(user_data) }; let op_pin = op.as_op_pin(); let res = match op_pin.on_event(&event) { Poll::Pending => unreachable!("this operation is not non-blocking"), @@ -287,7 +277,8 @@ impl Driver { if self.cancelled.remove(&user_data) { entries.extend(Some(entry_cancelled(user_data))); } else { - let op = entries.registry()[user_data].as_op_pin(); + let op = Key::upcast(user_data); + let op = op.as_op_pin(); let res = match op.on_event(&event) { Poll::Pending => { // The operation should go back to the front. diff --git a/compio-driver/src/unix/mod.rs b/compio-driver/src/unix/mod.rs index fc986898..ed300d0d 100644 --- a/compio-driver/src/unix/mod.rs +++ b/compio-driver/src/unix/mod.rs @@ -6,13 +6,10 @@ pub(crate) mod op; use crate::RawFd; /// The overlapped struct for unix needn't contain extra fields. -#[repr(transparent)] -pub(crate) struct Overlapped { - pub op: T, -} +pub(crate) struct Overlapped; -impl Overlapped { - pub(crate) fn new(_driver: RawFd, _user_data: usize, op: T) -> Self { - Self { op } +impl Overlapped { + pub fn new(_driver: RawFd) -> Self { + Self } } diff --git a/compio-driver/tests/file.rs b/compio-driver/tests/file.rs index 371f2c0f..c4207765 100644 --- a/compio-driver/tests/file.rs +++ b/compio-driver/tests/file.rs @@ -52,31 +52,12 @@ fn push_and_wait(driver: &mut Proactor, op: O) -> BufResult while entries.is_empty() { driver.poll(None, &mut entries).unwrap(); } - assert_eq!(entries[0], *user_data); - driver.pop(user_data).unwrap() + assert_eq!(entries[0], user_data.user_data()); + driver.pop(user_data).take_ready().unwrap() } } } -#[test] -fn cancel_before_poll() { - let mut driver = Proactor::new().unwrap(); - - let fd = open_file(&mut driver); - let fd = SharedFd::new(fd); - driver.attach(fd.as_raw_fd()).unwrap(); - - driver.cancel(0); - - let op = ReadAt::new(fd.clone(), 0, Vec::with_capacity(8)); - let BufResult(res, _) = push_and_wait(&mut driver, op); - - assert!(res.is_ok() || res.unwrap_err().kind() == io::ErrorKind::TimedOut); - - let op = CloseFile::new(fd.try_unwrap().unwrap()); - push_and_wait(&mut driver, op).unwrap(); -} - #[test] fn timeout() { let mut driver = Proactor::new().unwrap(); @@ -98,11 +79,11 @@ fn register_multiple() { let fd = SharedFd::new(fd); driver.attach(fd.as_raw_fd()).unwrap(); - let mut need_wait = 0; + let mut keys = vec![]; for _i in 0..TASK_LEN { match driver.push(ReadAt::new(fd.clone(), 0, Vec::with_capacity(1024))) { - PushEntry::Pending(_) => need_wait += 1, + PushEntry::Pending(key) => keys.push(key), PushEntry::Ready(res) => { res.unwrap(); } @@ -110,13 +91,13 @@ fn register_multiple() { } let mut entries = ArrayVec::::new(); - while entries.len() < need_wait { + while entries.len() < keys.len() { driver.poll(None, &mut entries).unwrap(); } // Cancel the entries to drop the ops, and decrease the ref count of fd. - for entry in entries { - driver.cancel(entry); + for key in keys { + driver.cancel(key); } let op = CloseFile::new(fd.try_unwrap().unwrap()); diff --git a/compio-runtime/src/runtime/mod.rs b/compio-runtime/src/runtime/mod.rs index 31872eba..15c554c8 100644 --- a/compio-runtime/src/runtime/mod.rs +++ b/compio-runtime/src/runtime/mod.rs @@ -147,8 +147,8 @@ impl RuntimeInner { } } - pub fn cancel_op(&self, user_data: Key) { - self.driver.borrow_mut().cancel(*user_data); + pub fn cancel_op(&self, op: Key) { + self.driver.borrow_mut().cancel(op); } #[cfg(feature = "time")] @@ -159,16 +159,14 @@ impl RuntimeInner { pub fn poll_task( &self, cx: &mut Context, - user_data: Key, - ) -> Poll> { - instrument!(compio_log::Level::DEBUG, "poll_task", ?user_data,); + op: Key, + ) -> PushEntry, BufResult> { + instrument!(compio_log::Level::DEBUG, "poll_task", ?op); let mut driver = self.driver.borrow_mut(); - if let Some(res) = driver.pop(user_data) { - Poll::Ready(res) - } else { - driver.update_waker(*user_data, cx.waker().clone()); - Poll::Pending - } + driver.pop(op).map_pending(|mut k| { + driver.update_waker(&mut k, cx.waker().clone()); + k + }) } #[cfg(feature = "time")] diff --git a/compio-runtime/src/runtime/op.rs b/compio-runtime/src/runtime/op.rs index bcc2bfec..eba36666 100644 --- a/compio-runtime/src/runtime/op.rs +++ b/compio-runtime/src/runtime/op.rs @@ -5,22 +5,18 @@ use std::{ }; use compio_buf::BufResult; -use compio_driver::{Key, OpCode}; +use compio_driver::{Key, OpCode, PushEntry}; use crate::runtime::Runtime; #[derive(Debug)] -pub struct OpFuture { - user_data: Key, - completed: bool, +pub struct OpFuture { + key: Option>, } -impl OpFuture { - pub fn new(user_data: Key) -> Self { - Self { - user_data, - completed: false, - } +impl OpFuture { + pub fn new(key: Key) -> Self { + Self { key: Some(key) } } } @@ -28,18 +24,23 @@ impl Future for OpFuture { type Output = BufResult; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let res = Runtime::current().inner().poll_task(cx, self.user_data); - if res.is_ready() { - self.completed = true; + let res = Runtime::current() + .inner() + .poll_task(cx, self.key.take().unwrap()); + match res { + PushEntry::Pending(key) => { + self.key = Some(key); + Poll::Pending + } + PushEntry::Ready(res) => Poll::Ready(res), } - res } } -impl Drop for OpFuture { +impl Drop for OpFuture { fn drop(&mut self) { - if !self.completed { - Runtime::current().inner().cancel_op(self.user_data) + if let Some(key) = self.key.take() { + Runtime::current().inner().cancel_op(key) } } } diff --git a/compio/examples/driver.rs b/compio/examples/driver.rs index 4e4b8bb4..f114c70b 100644 --- a/compio/examples/driver.rs +++ b/compio/examples/driver.rs @@ -53,8 +53,8 @@ fn push_and_wait(driver: &mut Proactor, op: O) -> (usize, O while entries.is_empty() { driver.poll(None, &mut entries).unwrap(); } - assert_eq!(entries[0], *user_data); - driver.pop(user_data).unwrap().unwrap() + assert_eq!(entries[0], user_data.user_data()); + driver.pop(user_data).take_ready().unwrap().unwrap() } } }