Skip to content

Commit

Permalink
Add back mostly unchanged exchange wrapper + buffer with RustToCudaAs…
Browse files Browse the repository at this point in the history
…ync impls
  • Loading branch information
juntyr committed Jan 3, 2024
1 parent 9dc2ae7 commit 91f9246
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 410 deletions.
6 changes: 0 additions & 6 deletions src/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,7 @@ impl<'stream, 'a, T: PortableBitSemantics + TypeGraphLayout>
pub unsafe fn new(
device_box: &'a mut DeviceBox<DeviceCopyWithPortableBitSemantics<T>>,
host_ref: &'a mut T,
stream: &'stream Stream,
) -> Self {
let _ = stream;

Self {
device_box,
host_ref,
Expand Down Expand Up @@ -448,10 +445,7 @@ impl<'stream, 'a, T: PortableBitSemantics + TypeGraphLayout>
pub const unsafe fn new(
device_box: &'a DeviceBox<DeviceCopyWithPortableBitSemantics<T>>,
host_ref: &'a T,
stream: &'stream Stream,
) -> Self {
let _ = stream;

Self {
device_box,
host_ref,
Expand Down
26 changes: 26 additions & 0 deletions src/utils/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,32 @@ impl<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T>> Async<'a, 'strea
} => Ok((self.value, Some(completion))),
}
}

/// # Safety
///
/// The returned reference to the inner value of type `T` may not yet have
/// completed its asynchronous work and may thus be in an inconsistent
/// state.
///
/// This method must only be used to construct a larger asynchronous
/// computation out of smaller ones that have all been submitted to the
/// same [`Stream`].
pub const unsafe fn unwrap_ref_unchecked(&self) -> &T {
&self.value
}

/// # Safety
///
/// The returned reference to the inner value of type `T` may not yet have
/// completed its asynchronous work and may thus be in an inconsistent
/// state.
///
/// This method must only be used to construct a larger asynchronous
/// computation out of smaller ones that have all been submitted to the
/// same [`Stream`].
pub unsafe fn unwrap_mut_unchecked(&mut self) -> &mut T {
&mut self.value
}
}

#[cfg(feature = "host")]
Expand Down
51 changes: 34 additions & 17 deletions src/utils/exchange/buffer/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
utils::{
adapter::DeviceCopyWithPortableBitSemantics,
ffi::{DeviceAccessible, DeviceMutPointer},
r#async::{Async, CompletionFnMut, NoCompletion},
},
};

Expand Down Expand Up @@ -174,12 +175,12 @@ impl<T: StackOnly + PortableBitSemantics + TypeGraphLayout, const M2D: bool, con
CudaExchangeBufferHost<T, M2D, M2H>
{
#[allow(clippy::type_complexity)]
pub unsafe fn borrow_async<A: CudaAlloc>(
pub unsafe fn borrow_async<'stream, A: CudaAlloc>(
&self,
alloc: A,
stream: &rustacuda::stream::Stream,
stream: &'stream rustacuda::stream::Stream,
) -> rustacuda::error::CudaResult<(
DeviceAccessible<CudaExchangeBufferCudaRepresentation<T, M2D, M2H>>,
Async<'_, 'stream, DeviceAccessible<CudaExchangeBufferCudaRepresentation<T, M2D, M2H>>>,
CombinedCudaAlloc<NoCudaAlloc, A>,
)> {
// Safety: device_buffer is inside an UnsafeCell
Expand All @@ -196,33 +197,49 @@ impl<T: StackOnly + PortableBitSemantics + TypeGraphLayout, const M2D: bool, con
)?;
}

Ok((
DeviceAccessible::from(CudaExchangeBufferCudaRepresentation(
DeviceMutPointer(device_buffer.as_mut_ptr().cast()),
device_buffer.len(),
)),
CombinedCudaAlloc::new(NoCudaAlloc, alloc),
))
let cuda_repr = DeviceAccessible::from(CudaExchangeBufferCudaRepresentation(
DeviceMutPointer(device_buffer.as_mut_ptr().cast()),
device_buffer.len(),
));

let r#async = if M2D {
Async::pending(cuda_repr, stream, NoCompletion)?
} else {
Async::ready(cuda_repr, stream)
};

Ok((r#async, CombinedCudaAlloc::new(NoCudaAlloc, alloc)))
}

#[allow(clippy::type_complexity)]
pub unsafe fn restore_async<A: CudaAlloc>(
&mut self,
pub unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
mut this: owning_ref::BoxRefMut<'a, O, Self>,
alloc: CombinedCudaAlloc<NoCudaAlloc, A>,
stream: &rustacuda::stream::Stream,
) -> rustacuda::error::CudaResult<A> {
stream: &'stream rustacuda::stream::Stream,
) -> rustacuda::error::CudaResult<(
Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
A,
)> {
let (_alloc_front, alloc_tail) = alloc.split();

if M2H {
// Only move the buffer contents back to the host if needed

let this: &mut Self = &mut this;

rustacuda::memory::AsyncCopyDestination::async_copy_to(
&***self.device_buffer.get_mut(),
self.host_buffer.as_mut_slice(),
&***this.device_buffer.get_mut(),
this.host_buffer.as_mut_slice(),
stream,
)?;
}

Ok(alloc_tail)
let r#async = if M2H {
Async::<_, CompletionFnMut<'a, Self>>::pending(this, stream, Box::new(|_this| Ok(())))?
} else {
Async::ready(this, stream)
};

Ok((r#async, alloc_tail))
}
}
47 changes: 37 additions & 10 deletions src/utils/exchange/buffer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{
use crate::{
alloc::{CombinedCudaAlloc, CudaAlloc},
utils::ffi::DeviceAccessible,
utils::r#async::{Async, CompletionFnMut},
};

#[cfg(any(feature = "host", feature = "device"))]
Expand Down Expand Up @@ -133,25 +134,51 @@ unsafe impl<T: StackOnly + PortableBitSemantics + TypeGraphLayout, const M2D: bo

#[cfg(feature = "host")]
#[allow(clippy::type_complexity)]
unsafe fn borrow_async<A: CudaAlloc>(
unsafe fn borrow_async<'stream, A: CudaAlloc>(
&self,
alloc: A,
stream: &rustacuda::stream::Stream,
stream: &'stream rustacuda::stream::Stream,
) -> rustacuda::error::CudaResult<(
DeviceAccessible<Self::CudaRepresentation>,
CombinedCudaAlloc<Self::CudaAllocation, A>,
Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
)> {
self.inner.borrow_async(alloc, stream)
}

#[cfg(feature = "host")]
#[allow(clippy::type_complexity)]
unsafe fn restore_async<A: CudaAlloc>(
&mut self,
alloc: CombinedCudaAlloc<Self::CudaAllocation, A>,
stream: &rustacuda::stream::Stream,
) -> rustacuda::error::CudaResult<A> {
self.inner.restore_async(alloc, stream)
unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
this: owning_ref::BoxRefMut<'a, O, Self>,
alloc: CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
stream: &'stream rustacuda::stream::Stream,
) -> rustacuda::error::CudaResult<(
Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
A,
)> {
let this_backup = unsafe { std::mem::ManuallyDrop::new(std::ptr::read(&this)) };

let (r#async, alloc_tail) = host::CudaExchangeBufferHost::restore_async(
this.map_mut(|this| &mut this.inner),
alloc,
stream,
)?;

let (inner, on_completion) = unsafe { r#async.unwrap_unchecked()? };

std::mem::forget(inner);
let this = std::mem::ManuallyDrop::into_inner(this_backup);

if let Some(on_completion) = on_completion {
let r#async = Async::<_, CompletionFnMut<'a, Self>>::pending(
this,
stream,
Box::new(|this: &mut Self| on_completion(&mut this.inner)),
)?;
Ok((r#async, alloc_tail))
} else {
let r#async = Async::ready(this, stream);
Ok((r#async, alloc_tail))
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/utils/exchange/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// pub mod buffer;
pub mod buffer;

// #[cfg(feature = "host")]
// pub mod wrapper;
#[cfg(feature = "host")]
pub mod wrapper;
Loading

0 comments on commit 91f9246

Please sign in to comment.