Skip to content

Commit

Permalink
Merge pull request #247 from Berrysoft/refactor/shared-fd-generic
Browse files Browse the repository at this point in the history
refactor(driver): generic SharedFd
  • Loading branch information
Berrysoft authored May 2, 2024
2 parents c0164a9 + 73824eb commit 62a3a9e
Show file tree
Hide file tree
Showing 18 changed files with 374 additions and 344 deletions.
95 changes: 42 additions & 53 deletions compio-driver/src/fd.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#[cfg(unix)]
use std::os::fd::FromRawFd;
#[cfg(windows)]
use std::os::windows::io::{
FromRawHandle, FromRawSocket, OwnedHandle, OwnedSocket, RawHandle, RawSocket,
};
use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
use std::{
future::{poll_fn, Future},
mem::ManuallyDrop,
ops::Deref,
panic::RefUnwindSafe,
sync::{
atomic::{AtomicBool, Ordering},
Expand All @@ -17,35 +16,35 @@ use std::{

use futures_util::task::AtomicWaker;

use crate::{AsRawFd, OwnedFd, RawFd};
use crate::{AsRawFd, RawFd};

#[derive(Debug)]
struct Inner {
fd: OwnedFd,
struct Inner<T> {
fd: T,
// whether there is a future waiting
waits: AtomicBool,
waker: AtomicWaker,
}

impl RefUnwindSafe for Inner {}
impl<T> RefUnwindSafe for Inner<T> {}

/// A shared fd. It is passed to the operations to make sure the fd won't be
/// closed before the operations complete.
#[derive(Debug, Clone)]
pub struct SharedFd(Arc<Inner>);
#[derive(Debug)]
pub struct SharedFd<T>(Arc<Inner<T>>);

impl SharedFd {
impl<T> SharedFd<T> {
/// Create the shared fd from an owned fd.
pub fn new(fd: impl Into<OwnedFd>) -> Self {
pub fn new(fd: T) -> Self {
Self(Arc::new(Inner {
fd: fd.into(),
fd,
waits: AtomicBool::new(false),
waker: AtomicWaker::new(),
}))
}

/// Try to take the inner owned fd.
pub fn try_unwrap(self) -> Result<OwnedFd, Self> {
pub fn try_unwrap(self) -> Result<T, Self> {
let this = ManuallyDrop::new(self);
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
Ok(fd)
Expand All @@ -55,7 +54,7 @@ impl SharedFd {
}

// SAFETY: if `Some` is returned, the method should not be called again.
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<OwnedFd> {
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
// The ptr is duplicated without increasing the strong count, should forget.
match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
Expand All @@ -68,7 +67,7 @@ impl SharedFd {
}

/// Wait and take the inner owned fd.
pub fn take(self) -> impl Future<Output = Option<OwnedFd>> {
pub fn take(self) -> impl Future<Output = Option<T>> {
let this = ManuallyDrop::new(self);
async move {
if !this.0.waits.swap(true, Ordering::AcqRel) {
Expand All @@ -93,7 +92,7 @@ impl SharedFd {
}
}

impl Drop for SharedFd {
impl<T> Drop for SharedFd<T> {
fn drop(&mut self) {
// It's OK to wake multiple times.
if Arc::strong_count(&self.0) == 2 {
Expand All @@ -102,71 +101,61 @@ impl Drop for SharedFd {
}
}

#[cfg(windows)]
#[doc(hidden)]
impl SharedFd {
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
ManuallyDrop::new(std::fs::File::from_raw_handle(self.as_raw_fd() as _))
}

pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
ManuallyDrop::new(socket2::Socket::from_raw_socket(self.as_raw_fd() as _))
}
}

#[cfg(unix)]
#[doc(hidden)]
impl SharedFd {
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
ManuallyDrop::new(std::fs::File::from_raw_fd(self.as_raw_fd() as _))
}

pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
ManuallyDrop::new(socket2::Socket::from_raw_fd(self.as_raw_fd() as _))
}
}

impl AsRawFd for SharedFd {
impl<T: AsRawFd> AsRawFd for SharedFd<T> {
fn as_raw_fd(&self) -> RawFd {
self.0.fd.as_raw_fd()
}
}

#[cfg(windows)]
impl FromRawHandle for SharedFd {
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
unsafe fn from_raw_handle(handle: RawHandle) -> Self {
Self::new(OwnedFd::File(OwnedHandle::from_raw_handle(handle)))
Self::new(T::from_raw_handle(handle))
}
}

#[cfg(windows)]
impl FromRawSocket for SharedFd {
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
unsafe fn from_raw_socket(sock: RawSocket) -> Self {
Self::new(OwnedFd::Socket(OwnedSocket::from_raw_socket(sock)))
Self::new(T::from_raw_socket(sock))
}
}

#[cfg(unix)]
impl FromRawFd for SharedFd {
impl<T: FromRawFd> FromRawFd for SharedFd<T> {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Self::new(OwnedFd::from_raw_fd(fd))
Self::new(T::from_raw_fd(fd))
}
}

impl From<OwnedFd> for SharedFd {
fn from(value: OwnedFd) -> Self {
impl<T> From<T> for SharedFd<T> {
fn from(value: T) -> Self {
Self::new(value)
}
}

impl<T> Clone for SharedFd<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<T> Deref for SharedFd<T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.0.fd
}
}

/// Get a clone of [`SharedFd`].
pub trait ToSharedFd {
pub trait ToSharedFd<T> {
/// Return a cloned [`SharedFd`].
fn to_shared_fd(&self) -> SharedFd;
fn to_shared_fd(&self) -> SharedFd<T>;
}

impl ToSharedFd for SharedFd {
fn to_shared_fd(&self) -> SharedFd {
impl<T> ToSharedFd<T> for SharedFd<T> {
fn to_shared_fd(&self) -> SharedFd<T> {
self.clone()
}
}
12 changes: 6 additions & 6 deletions compio-driver/src/fusion/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub use crate::unix::op::*;
use crate::SharedFd;

macro_rules! op {
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ident),* $(,)? )) => {
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? )) => {
::paste::paste!{
enum [< $name Inner >] <$($ty: $trait),*> {
Poll(poll::$name<$($ty),*>),
Expand Down Expand Up @@ -92,9 +92,9 @@ mod iour { pub use crate::sys::iour::{op::*, OpCode}; }
#[rustfmt::skip]
mod poll { pub use crate::sys::poll::{op::*, OpCode}; }

op!(<T: IoBufMut> RecvFrom(fd: SharedFd, buffer: T));
op!(<T: IoBuf> SendTo(fd: SharedFd, buffer: T, addr: SockAddr));
op!(<T: IoVectoredBufMut> RecvFromVectored(fd: SharedFd, buffer: T));
op!(<T: IoVectoredBuf> SendToVectored(fd: SharedFd, buffer: T, addr: SockAddr));
op!(<> FileStat(fd: SharedFd));
op!(<T: IoBufMut, S: AsRawFd> RecvFrom(fd: SharedFd<S>, buffer: T));
op!(<T: IoBuf, S: AsRawFd> SendTo(fd: SharedFd<S>, buffer: T, addr: SockAddr));
op!(<T: IoVectoredBufMut, S: AsRawFd> RecvFromVectored(fd: SharedFd<S>, buffer: T));
op!(<T: IoVectoredBuf, S: AsRawFd> SendToVectored(fd: SharedFd<S>, buffer: T, addr: SockAddr));
op!(<S: AsRawFd> FileStat(fd: SharedFd<S>));
op!(<> PathStat(path: CString, follow_symlink: bool));
30 changes: 30 additions & 0 deletions compio-driver/src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ impl AsRawFd for OwnedFd {
}
}

impl AsRawFd for RawFd {
fn as_raw_fd(&self) -> RawFd {
*self
}
}

impl AsRawFd for std::fs::File {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_handle() as _
}
}

impl AsRawFd for OwnedHandle {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_handle() as _
}
}

impl AsRawFd for socket2::Socket {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_socket() as _
}
}

impl AsRawFd for OwnedSocket {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_socket() as _
}
}

impl From<OwnedHandle> for OwnedFd {
fn from(value: OwnedHandle) -> Self {
Self::File(value)
Expand Down
Loading

0 comments on commit 62a3a9e

Please sign in to comment.