From 91f9246832390215bc91ebf00b6692839918a13c Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Wed, 3 Jan 2024 04:02:35 +0000 Subject: [PATCH] Add back mostly unchanged exchange wrapper + buffer with RustToCudaAsync impls --- src/host/mod.rs | 6 - src/utils/async.rs | 26 ++ src/utils/exchange/buffer/host.rs | 51 ++- src/utils/exchange/buffer/mod.rs | 47 ++- src/utils/exchange/mod.rs | 6 +- src/utils/exchange/wrapper.rs | 536 +++++++++--------------------- 6 files changed, 262 insertions(+), 410 deletions(-) diff --git a/src/host/mod.rs b/src/host/mod.rs index f77c75792..a705c8504 100644 --- a/src/host/mod.rs +++ b/src/host/mod.rs @@ -362,10 +362,7 @@ impl<'stream, 'a, T: PortableBitSemantics + TypeGraphLayout> pub unsafe fn new( device_box: &'a mut DeviceBox>, host_ref: &'a mut T, - stream: &'stream Stream, ) -> Self { - let _ = stream; - Self { device_box, host_ref, @@ -448,10 +445,7 @@ impl<'stream, 'a, T: PortableBitSemantics + TypeGraphLayout> pub const unsafe fn new( device_box: &'a DeviceBox>, host_ref: &'a T, - stream: &'stream Stream, ) -> Self { - let _ = stream; - Self { device_box, host_ref, diff --git a/src/utils/async.rs b/src/utils/async.rs index 6aab8adca..f408431ae 100644 --- a/src/utils/async.rs +++ b/src/utils/async.rs @@ -223,6 +223,32 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea } => Ok((self.value, Some(completion))), } } + + /// # Safety + /// + /// The returned reference to the inner value of type `T` may not yet have + /// completed its asynchronous work and may thus be in an inconsistent + /// state. + /// + /// This method must only be used to construct a larger asynchronous + /// computation out of smaller ones that have all been submitted to the + /// same [`Stream`]. + pub const unsafe fn unwrap_ref_unchecked(&self) -> &T { + &self.value + } + + /// # Safety + /// + /// The returned reference to the inner value of type `T` may not yet have + /// completed its asynchronous work and may thus be in an inconsistent + /// state. + /// + /// This method must only be used to construct a larger asynchronous + /// computation out of smaller ones that have all been submitted to the + /// same [`Stream`]. 
+ pub unsafe fn unwrap_mut_unchecked(&mut self) -> &mut T { + &mut self.value + } } #[cfg(feature = "host")] diff --git a/src/utils/exchange/buffer/host.rs b/src/utils/exchange/buffer/host.rs index e62227d8e..ce0cb9d41 100644 --- a/src/utils/exchange/buffer/host.rs +++ b/src/utils/exchange/buffer/host.rs @@ -16,6 +16,7 @@ use crate::{ utils::{ adapter::DeviceCopyWithPortableBitSemantics, ffi::{DeviceAccessible, DeviceMutPointer}, + r#async::{Async, CompletionFnMut, NoCompletion}, }, }; @@ -174,12 +175,12 @@ impl { #[allow(clippy::type_complexity)] - pub unsafe fn borrow_async( + pub unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &rustacuda::stream::Stream, + stream: &'stream rustacuda::stream::Stream, ) -> rustacuda::error::CudaResult<( - DeviceAccessible>, + Async<'_, 'stream, DeviceAccessible>>, CombinedCudaAlloc, )> { // Safety: device_buffer is inside an UnsafeCell @@ -196,33 +197,49 @@ impl( - &mut self, + pub unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>( + mut this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &rustacuda::stream::Stream, - ) -> rustacuda::error::CudaResult { + stream: &'stream rustacuda::stream::Stream, + ) -> rustacuda::error::CudaResult<( + Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, + A, + )> { let (_alloc_front, alloc_tail) = alloc.split(); if M2H { // Only move the buffer contents back to the host if needed + let this: &mut Self = &mut this; + rustacuda::memory::AsyncCopyDestination::async_copy_to( - &***self.device_buffer.get_mut(), - self.host_buffer.as_mut_slice(), + &***this.device_buffer.get_mut(), + this.host_buffer.as_mut_slice(), stream, )?; } - Ok(alloc_tail) + let r#async = if M2H { + Async::<_, CompletionFnMut<'a, Self>>::pending(this, stream, Box::new(|_this| Ok(())))? + } else { + Async::ready(this, stream) + }; + + Ok((r#async, alloc_tail)) } } diff --git a/src/utils/exchange/buffer/mod.rs b/src/utils/exchange/buffer/mod.rs index 9dfc4414e..c48a715ac 100644 --- a/src/utils/exchange/buffer/mod.rs +++ b/src/utils/exchange/buffer/mod.rs @@ -20,6 +20,7 @@ use crate::{ use crate::{ alloc::{CombinedCudaAlloc, CudaAlloc}, utils::ffi::DeviceAccessible, + utils::r#async::{Async, CompletionFnMut}, }; #[cfg(any(feature = "host", feature = "device"))] @@ -133,25 +134,51 @@ unsafe impl( + unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &rustacuda::stream::Stream, + stream: &'stream rustacuda::stream::Stream, ) -> rustacuda::error::CudaResult<( - DeviceAccessible, - CombinedCudaAlloc, + Async<'_, 'stream, DeviceAccessible>, + CombinedCudaAlloc, )> { self.inner.borrow_async(alloc, stream) } #[cfg(feature = "host")] #[allow(clippy::type_complexity)] - unsafe fn restore_async( - &mut self, - alloc: CombinedCudaAlloc, - stream: &rustacuda::stream::Stream, - ) -> rustacuda::error::CudaResult { - self.inner.restore_async(alloc, stream) + unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>( + this: owning_ref::BoxRefMut<'a, O, Self>, + alloc: CombinedCudaAlloc, + stream: &'stream rustacuda::stream::Stream, + ) -> rustacuda::error::CudaResult<( + Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, + A, + )> { + let this_backup = unsafe { std::mem::ManuallyDrop::new(std::ptr::read(&this)) }; + + let (r#async, alloc_tail) = host::CudaExchangeBufferHost::restore_async( + this.map_mut(|this| &mut this.inner), + alloc, + stream, + )?; + + let (inner, on_completion) = unsafe { r#async.unwrap_unchecked()? 
}; + + std::mem::forget(inner); + let this = std::mem::ManuallyDrop::into_inner(this_backup); + + if let Some(on_completion) = on_completion { + let r#async = Async::<_, CompletionFnMut<'a, Self>>::pending( + this, + stream, + Box::new(|this: &mut Self| on_completion(&mut this.inner)), + )?; + Ok((r#async, alloc_tail)) + } else { + let r#async = Async::ready(this, stream); + Ok((r#async, alloc_tail)) + } } } diff --git a/src/utils/exchange/mod.rs b/src/utils/exchange/mod.rs index 9c0de5e36..722e02559 100644 --- a/src/utils/exchange/mod.rs +++ b/src/utils/exchange/mod.rs @@ -1,4 +1,4 @@ -// pub mod buffer; +pub mod buffer; -// #[cfg(feature = "host")] -// pub mod wrapper; +#[cfg(feature = "host")] +pub mod wrapper; diff --git a/src/utils/exchange/wrapper.rs b/src/utils/exchange/wrapper.rs index 454ecc8f3..09aef582d 100644 --- a/src/utils/exchange/wrapper.rs +++ b/src/utils/exchange/wrapper.rs @@ -1,16 +1,9 @@ -use std::{ - future::{Future, IntoFuture}, - marker::PhantomData, - ops::{Deref, DerefMut}, - sync::{Arc, Mutex}, - task::{Poll, Waker}, -}; +use std::ops::{Deref, DerefMut}; use rustacuda::{ - error::{CudaError, CudaResult}, - event::{Event, EventFlags, EventStatus}, + error::CudaResult, memory::{AsyncCopyDestination, CopyDestination, DeviceBox, LockedBox}, - stream::{Stream, StreamWaitEventFlags}, + stream::Stream, }; use crate::{ @@ -20,32 +13,16 @@ use crate::{ HostAndDeviceMutRefAsync, }, lend::{RustToCuda, RustToCudaAsync}, - utils::{adapter::DeviceCopyWithPortableBitSemantics, ffi::DeviceAccessible}, + utils::{ + adapter::DeviceCopyWithPortableBitSemantics, + ffi::DeviceAccessible, + r#async::{Async, CompletionFnMut, NoCompletion}, + }, }; #[allow(clippy::module_name_repetitions)] pub struct ExchangeWrapperOnHost> { - value: T, - device_box: CudaDropWrapper< - DeviceBox< - DeviceCopyWithPortableBitSemantics< - DeviceAccessible<::CudaRepresentation>, - >, - >, - >, - locked_cuda_repr: CudaDropWrapper< - LockedBox< - DeviceCopyWithPortableBitSemantics< - DeviceAccessible<::CudaRepresentation>, - >, - >, - >, - move_event: CudaDropWrapper, -} - -#[allow(clippy::module_name_repetitions)] -pub struct ExchangeWrapperOnHostAsync<'stream, T: RustToCuda> { - value: T, + value: Box, device_box: CudaDropWrapper< DeviceBox< DeviceCopyWithPortableBitSemantics< @@ -60,34 +37,11 @@ pub struct ExchangeWrapperOnHostAsync<'stream, T: RustToCuda, >, >, - move_event: CudaDropWrapper, - stream: PhantomData<&'stream Stream>, - waker: Arc>>, } #[allow(clippy::module_name_repetitions)] pub struct ExchangeWrapperOnDevice> { - value: T, - device_box: CudaDropWrapper< - DeviceBox< - DeviceCopyWithPortableBitSemantics< - DeviceAccessible<::CudaRepresentation>, - >, - >, - >, - locked_cuda_repr: CudaDropWrapper< - LockedBox< - DeviceCopyWithPortableBitSemantics< - DeviceAccessible<::CudaRepresentation>, - >, - >, - >, - move_event: CudaDropWrapper, -} - -#[allow(clippy::module_name_repetitions)] -pub struct ExchangeWrapperOnDeviceAsync<'stream, T: RustToCuda> { - value: T, + value: Box, device_box: CudaDropWrapper< DeviceBox< DeviceCopyWithPortableBitSemantics< @@ -102,9 +56,6 @@ pub struct ExchangeWrapperOnDeviceAsync<'stream, T: RustToCuda, >, >, - move_event: CudaDropWrapper, - stream: &'stream Stream, - waker: Arc>>, } impl> ExchangeWrapperOnHost { @@ -130,16 +81,14 @@ impl> ExchangeWrapperOnHost { uninit }; - let move_event = Event::new(EventFlags::DISABLE_TIMING)?.into(); - Ok(Self { - value, + value: Box::new(value), device_box, locked_cuda_repr, - move_event, }) } + // TODO: safety constraint? 
/// Moves the data synchronously to the CUDA device, where it can then be /// lent out immutably via [`ExchangeWrapperOnDevice::as_ref`], or mutably /// via [`ExchangeWrapperOnDevice::as_mut`]. @@ -164,7 +113,6 @@ impl> ExchangeWrapperOnHost { value: self.value, device_box: self.device_box, locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, }) } } @@ -172,6 +120,8 @@ impl> ExchangeWrapperOnHost { impl> ExchangeWrapperOnHost { + #[allow(clippy::needless_lifetimes)] // keep 'stream explicit + // TODO: safety constraint? /// Moves the data asynchronously to the CUDA device. /// /// To avoid aliasing, each CUDA thread will get access to its own shallow @@ -182,11 +132,14 @@ impl( mut self, - stream: &Stream, - ) -> CudaResult> { - let (cuda_repr, null_alloc) = unsafe { self.value.borrow_async(NoCudaAlloc, stream) }?; + stream: &'stream Stream, + ) -> CudaResult, NoCompletion>> { + let (cuda_repr, _null_alloc) = unsafe { self.value.borrow_async(NoCudaAlloc, stream) }?; + let (cuda_repr, _completion): (_, Option) = + unsafe { cuda_repr.unwrap_unchecked()? }; + **self.locked_cuda_repr = DeviceCopyWithPortableBitSemantics::from(cuda_repr); // Safety: The device value is not safely exposed until either @@ -196,112 +149,16 @@ impl>> = Arc::new(Mutex::new(None)); - - let waker_callback = waker.clone(); - stream.add_callback(Box::new(move |_| { - if let Ok(mut w) = waker_callback.lock() { - if let Some(w) = w.take() { - w.wake(); - } - } - }))?; - - let _: NoCudaAlloc = null_alloc.into(); - - Ok(ExchangeWrapperOnDeviceAsync { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - stream, - waker, - }) - } -} - -impl<'stream, T: RustToCuda> - ExchangeWrapperOnHostAsync<'stream, T> -{ - /// Synchronises the host CPU thread until the data has moved to the CPU. - /// - /// # Errors - /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside - /// CUDA - pub fn sync_to_host(self) -> CudaResult> { - self.move_event.synchronize()?; - - Ok(ExchangeWrapperOnHost { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - }) - } - - /// Moves the asynchronous data move to a different [`Stream`]. 
- /// - /// # Errors - /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside - /// CUDA - pub fn move_to_stream(self, stream: &Stream) -> CudaResult> { - stream.wait_event(&self.move_event, StreamWaitEventFlags::DEFAULT)?; - self.move_event.record(stream)?; - - let waker_callback = self.waker.clone(); - stream.add_callback(Box::new(move |_| { - if let Ok(mut w) = waker_callback.lock() { - if let Some(w) = w.take() { - w.wake(); - } - } - }))?; - - Ok(ExchangeWrapperOnHostAsync { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - stream: PhantomData::<&Stream>, - waker: self.waker, - }) - } -} -impl<'stream, T: RustToCuda> IntoFuture - for ExchangeWrapperOnHostAsync<'stream, T> -{ - type Output = CudaResult>; - - type IntoFuture = impl Future; - - fn into_future(self) -> Self::IntoFuture { - let mut wrapper = Some(self); - - core::future::poll_fn(move |cx| match &wrapper { - Some(inner) => match inner.move_event.query() { - Ok(EventStatus::NotReady) => inner.waker.lock().map_or_else( - |_| Poll::Ready(Err(CudaError::OperatingSystemError)), - |mut w| { - *w = Some(cx.waker().clone()); - Poll::Pending - }, - ), - Ok(EventStatus::Ready) => match wrapper.take() { - Some(inner) => Poll::Ready(Ok(ExchangeWrapperOnHost { - value: inner.value, - device_box: inner.device_box, - locked_cuda_repr: inner.locked_cuda_repr, - move_event: inner.move_event, - })), - None => Poll::Ready(Err(CudaError::AlreadyAcquired)), - }, - Err(err) => Poll::Ready(Err(err)), + Async::pending( + ExchangeWrapperOnDevice { + value: self.value, + device_box: self.device_box, + locked_cuda_repr: self.locked_cuda_repr, }, - None => Poll::Ready(Err(CudaError::AlreadyAcquired)), - }) + stream, + NoCompletion, + ) } } @@ -319,83 +176,60 @@ impl> DerefMut for ExchangeWrapper } } -impl<'stream, T: RustToCuda> - ExchangeWrapperOnDeviceAsync<'stream, T> -{ - /// Synchronises the host CPU thread until the data has moved to the GPU. +impl> ExchangeWrapperOnDevice { + // TODO: safety constraint? + /// Moves the data synchronously back to the host CPU device. + /// + /// To avoid aliasing, each CUDA thread only got access to its own shallow + /// copy of the data. Hence, + /// - any shallow changes to the data will NOT be reflected back to the CPU + /// - any deep changes to the data WILL be reflected back to the CPU /// /// # Errors /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside /// CUDA - pub fn sync_to_device(self) -> CudaResult> { - self.move_event.synchronize()?; + pub fn move_to_host(mut self) -> CudaResult> { + let null_alloc = NoCudaAlloc.into(); - Ok(ExchangeWrapperOnDevice { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - }) - } + // Reflect deep changes back to the CPU + let _null_alloc: NoCudaAlloc = unsafe { self.value.restore(null_alloc) }?; - /// Moves the asynchronous data move to a different [`Stream`]. 
- /// - /// # Errors - /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside - /// CUDA - pub fn move_to_stream( - self, - stream: &Stream, - ) -> CudaResult> { - stream.wait_event(&self.move_event, StreamWaitEventFlags::DEFAULT)?; - self.move_event.record(stream)?; - - let waker_callback = self.waker.clone(); - stream.add_callback(Box::new(move |_| { - if let Ok(mut w) = waker_callback.lock() { - if let Some(w) = w.take() { - w.wake(); - } - } - }))?; - - Ok(ExchangeWrapperOnDeviceAsync { + // Note: Shallow changes are not reflected back to the CPU + + Ok(ExchangeWrapperOnHost { value: self.value, device_box: self.device_box, locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - stream, - waker: self.waker, }) } - pub fn as_ref_async( + #[must_use] + pub fn as_ref( &self, - ) -> HostAndDeviceConstRefAsync::CudaRepresentation>> { + ) -> HostAndDeviceConstRef::CudaRepresentation>> { // Safety: `device_box` contains exactly the device copy of `locked_cuda_repr` unsafe { - HostAndDeviceConstRefAsync::new( - &*self.device_box, - (**self.locked_cuda_repr).into_ref(), - self.stream, - ) + HostAndDeviceConstRef::new(&self.device_box, (**self.locked_cuda_repr).into_ref()) } } - pub fn as_mut_async( + #[must_use] + pub fn as_mut( &mut self, - ) -> HostAndDeviceMutRefAsync::CudaRepresentation>> { + ) -> HostAndDeviceMutRef::CudaRepresentation>> { // Safety: `device_box` contains exactly the device copy of `locked_cuda_repr` unsafe { - HostAndDeviceMutRefAsync::new( - &mut self.device_box, - (**self.locked_cuda_repr).into_mut(), - self.stream, - ) + HostAndDeviceMutRef::new(&mut self.device_box, (**self.locked_cuda_repr).into_mut()) } } +} - /// Moves the data synchronously back to the host CPU device. +impl> + ExchangeWrapperOnDevice +{ + #[allow(clippy::needless_lifetimes)] // keep 'stream explicit + // TODO: safety constraint? + /// Moves the data asynchronously back to the host CPU device. /// /// To avoid aliasing, each CUDA thread only got access to its own shallow /// copy of the data. Hence, @@ -405,28 +239,60 @@ impl<'stream, T: RustToCuda> /// # Errors /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside /// CUDA - pub fn move_to_host(mut self) -> CudaResult> { + pub fn move_to_host_async<'stream>( + self, + stream: &'stream Stream, + ) -> CudaResult< + Async< + 'static, + 'stream, + ExchangeWrapperOnHost, + CompletionFnMut<'static, ExchangeWrapperOnHost>, + >, + > { let null_alloc = NoCudaAlloc.into(); + let value = owning_ref::BoxRefMut::new(self.value); + // Reflect deep changes back to the CPU - let _null_alloc: NoCudaAlloc = unsafe { self.value.restore(null_alloc) }?; + let (r#async, _null_alloc): (_, NoCudaAlloc) = + unsafe { RustToCudaAsync::restore_async(value, null_alloc, stream) }?; + let (value, on_complete) = unsafe { r#async.unwrap_unchecked()? 
}; + + let value = value.into_owner(); // Note: Shallow changes are not reflected back to the CPU - Ok(ExchangeWrapperOnHost { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - }) + if let Some(on_complete) = on_complete { + Async::<_, CompletionFnMut>>::pending( + ExchangeWrapperOnHost { + value, + device_box: self.device_box, + locked_cuda_repr: self.locked_cuda_repr, + }, + stream, + Box::new(|on_host: &mut ExchangeWrapperOnHost| on_complete(&mut on_host.value)), + ) + } else { + Ok(Async::ready( + ExchangeWrapperOnHost { + value, + device_box: self.device_box, + locked_cuda_repr: self.locked_cuda_repr, + }, + stream, + )) + } } } impl< + 'a, 'stream, T: RustToCudaAsync, - > ExchangeWrapperOnDeviceAsync<'stream, T> + > Async<'a, 'stream, ExchangeWrapperOnDevice, NoCompletion> { + // TODO: safety constraint? /// Moves the data asynchronously back to the host CPU device. /// /// To avoid aliasing, each CUDA thread only got access to its own shallow @@ -438,165 +304,87 @@ impl< /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside /// CUDA pub fn move_to_host_async( - mut self, + self, stream: &'stream Stream, - ) -> CudaResult> { - let null_alloc = NoCudaAlloc.into(); - - // Reflect deep changes back to the CPU - let _null_alloc: NoCudaAlloc = unsafe { self.value.restore_async(null_alloc, stream) }?; - - // Note: Shallow changes are not reflected back to the CPU - - self.move_event.record(stream)?; - - let waker: Arc>> = Arc::new(Mutex::new(None)); - - let waker_callback = waker.clone(); - stream.add_callback(Box::new(move |_| { - if let Ok(mut w) = waker_callback.lock() { - if let Some(w) = w.take() { - w.wake(); - } - } - }))?; - - Ok(ExchangeWrapperOnHostAsync { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - stream: PhantomData::<&'stream Stream>, - waker, - }) - } -} - -impl<'stream, T: RustToCuda> IntoFuture - for ExchangeWrapperOnDeviceAsync<'stream, T> -{ - type Output = CudaResult>; - - type IntoFuture = impl Future; - - fn into_future(self) -> Self::IntoFuture { - let mut wrapper = Some(self); - - core::future::poll_fn(move |cx| match &wrapper { - Some(inner) => match inner.move_event.query() { - Ok(EventStatus::NotReady) => inner.waker.lock().map_or_else( - |_| Poll::Ready(Err(CudaError::OperatingSystemError)), - |mut w| { - *w = Some(cx.waker().clone()); - Poll::Pending - }, - ), - Ok(EventStatus::Ready) => match wrapper.take() { - Some(inner) => Poll::Ready(Ok(ExchangeWrapperOnDevice { - value: inner.value, - device_box: inner.device_box, - locked_cuda_repr: inner.locked_cuda_repr, - move_event: inner.move_event, - })), - None => Poll::Ready(Err(CudaError::AlreadyAcquired)), - }, - Err(err) => Poll::Ready(Err(err)), - }, - None => Poll::Ready(Err(CudaError::AlreadyAcquired)), - }) - } -} + ) -> CudaResult< + Async< + 'static, + 'stream, + ExchangeWrapperOnHost, + CompletionFnMut<'static, ExchangeWrapperOnHost>, + >, + > { + let (this, completion): (_, Option) = unsafe { self.unwrap_unchecked()? }; -impl> ExchangeWrapperOnDevice { - /// Moves the data synchronously back to the host CPU device. - /// - /// To avoid aliasing, each CUDA thread only got access to its own shallow - /// copy of the data. 
Hence, - /// - any shallow changes to the data will NOT be reflected back to the CPU - /// - any deep changes to the data WILL be reflected back to the CPU - /// - /// # Errors - /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside - /// CUDA - pub fn move_to_host(mut self) -> CudaResult> { let null_alloc = NoCudaAlloc.into(); + let value = owning_ref::BoxRefMut::new(this.value); + // Reflect deep changes back to the CPU - let _null_alloc: NoCudaAlloc = unsafe { self.value.restore(null_alloc) }?; + let (r#async, _null_alloc): (_, NoCudaAlloc) = + unsafe { RustToCudaAsync::restore_async(value, null_alloc, stream) }?; + let (value, on_complete) = unsafe { r#async.unwrap_unchecked()? }; + + let value = value.into_owner(); // Note: Shallow changes are not reflected back to the CPU - Ok(ExchangeWrapperOnHost { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - }) + let on_host = ExchangeWrapperOnHost { + value, + device_box: this.device_box, + locked_cuda_repr: this.locked_cuda_repr, + }; + + if let Some(on_complete) = on_complete { + Async::<_, CompletionFnMut>>::pending( + on_host, + stream, + Box::new(|on_host: &mut ExchangeWrapperOnHost| on_complete(&mut on_host.value)), + ) + } else if matches!(completion, Some(NoCompletion)) { + Async::<_, CompletionFnMut>>::pending( + on_host, + stream, + Box::new(|_on_host: &mut ExchangeWrapperOnHost| Ok(())), + ) + } else { + Ok(Async::ready(on_host, stream)) + } } - pub fn as_ref( + // TODO: replace by async borrow map + #[must_use] + pub fn as_ref_async( &self, - ) -> HostAndDeviceConstRef::CudaRepresentation>> { + ) -> HostAndDeviceConstRefAsync< + 'stream, + '_, + DeviceAccessible<::CudaRepresentation>, + > { + let this = unsafe { self.unwrap_ref_unchecked() }; + // Safety: `device_box` contains exactly the device copy of `locked_cuda_repr` unsafe { - HostAndDeviceConstRef::new(&self.device_box, (**self.locked_cuda_repr).into_ref()) + HostAndDeviceConstRefAsync::new( + &*(this.device_box), + (**(this.locked_cuda_repr)).into_ref(), + ) } } - pub fn as_mut( + // TODO: replace by async borrow map mut + #[must_use] + pub fn as_mut_async( &mut self, - ) -> HostAndDeviceMutRef::CudaRepresentation>> { + ) -> HostAndDeviceMutRefAsync::CudaRepresentation>> { + let this = unsafe { self.unwrap_mut_unchecked() }; + // Safety: `device_box` contains exactly the device copy of `locked_cuda_repr` unsafe { - HostAndDeviceMutRef::new(&mut self.device_box, (**self.locked_cuda_repr).into_mut()) + HostAndDeviceMutRefAsync::new( + &mut *(this.device_box), + (**(this.locked_cuda_repr)).into_mut(), + ) } } } - -impl> - ExchangeWrapperOnDevice -{ - /// Moves the data asynchronously back to the host CPU device. - /// - /// To avoid aliasing, each CUDA thread only got access to its own shallow - /// copy of the data. 
Hence, - /// - any shallow changes to the data will NOT be reflected back to the CPU - /// - any deep changes to the data WILL be reflected back to the CPU - /// - /// # Errors - /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside - /// CUDA - pub fn move_to_host_async( - mut self, - stream: &Stream, - ) -> CudaResult> { - let null_alloc = NoCudaAlloc.into(); - - // Reflect deep changes back to the CPU - let _null_alloc: NoCudaAlloc = unsafe { self.value.restore_async(null_alloc, stream) }?; - - // Note: Shallow changes are not reflected back to the CPU - - self.move_event.record(stream)?; - - let waker: Arc>> = Arc::new(Mutex::new(None)); - - let waker_callback = waker.clone(); - stream.add_callback(Box::new(move |_| { - if let Ok(mut w) = waker_callback.lock() { - if let Some(w) = w.take() { - w.wake(); - } - } - }))?; - - Ok(ExchangeWrapperOnHostAsync { - value: self.value, - device_box: self.device_box, - locked_cuda_repr: self.locked_cuda_repr, - move_event: self.move_event, - stream: PhantomData::<&Stream>, - waker, - }) - } -}
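
For context, a minimal usage sketch of the synchronous exchange path that this patch keeps (construction, `move_to_device`, lending via `as_ref`/`as_mut`, `move_to_host`). The example type `Counter`, its derive, and the import paths are assumptions for illustration and are not part of the diff above; only the wrapper methods themselves come from this file.

// Minimal sketch, assuming `Counter: RustToCuda` (derive crate/name assumed) and
// that `ExchangeWrapperOnHost::new(value)` returns `CudaResult<Self>`, as implied
// by the constructor body in the hunk above. Import paths are assumptions.
use rust_cuda::utils::exchange::wrapper::ExchangeWrapperOnHost;

#[derive(rust_cuda_derive::LendRustToCuda)] // assumed derive for RustToCuda(Async)
struct Counter {
    count: u64,
}

fn sync_round_trip() -> rustacuda::error::CudaResult<()> {
    let _context = rustacuda::quick_init()?; // CUDA context for this thread

    // Allocates the device box and the page-locked staging copy once, up front.
    let mut on_host = ExchangeWrapperOnHost::new(Counter { count: 0 })?;
    on_host.count += 1; // Deref/DerefMut pass through to the wrapped value

    // Move to the device; `as_ref()`/`as_mut()` lend the device copy out,
    // e.g. as kernel launch parameters.
    let on_device = on_host.move_to_device()?;
    {
        let _kernel_param = on_device.as_ref();
    }

    // Move back: deep (RustToCuda-managed) changes are restored, while the
    // shallow host-side value itself never left the CPU.
    let on_host = on_device.move_to_host()?;
    assert_eq!(on_host.count, 1);
    Ok(())
}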
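
The new stream-ordered path can be sketched the same way. How the returned `Async` is ultimately awaited or consumed lies outside this hunk, so the sketch falls back to a plain `stream.synchronize()`; `Counter` and the imports are the assumptions from the previous sketch, now also assumed to implement `RustToCudaAsync`.

// Minimal sketch: every operation, including any kernel launch that would use
// `_kernel_param`, must be submitted to the *same* stream, per the safety
// contract of `unwrap_ref_unchecked` underneath `as_ref_async`.
use rustacuda::stream::{Stream, StreamFlags};

fn async_round_trip(on_host: ExchangeWrapperOnHost<Counter>) -> rustacuda::error::CudaResult<()> {
    let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;

    // Stream-ordered move to the device; the wrapper comes back inside an `Async`.
    let on_device = on_host.move_to_device_async(&stream)?;

    {
        // Lend the (possibly still in-flight) device copy to work on `stream`.
        let _kernel_param = on_device.as_ref_async();
    }

    // Stream-ordered move back; the returned `Async` carries the completion
    // that patches up the host-side value once the copy has finished.
    let on_host_async = on_device.move_to_host_async(&stream)?;

    // Simplest way to make the result observable in this sketch; a real caller
    // would await or otherwise consume the `Async` instead.
    stream.synchronize()?;
    drop(on_host_async);
    Ok(())
}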
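
The completion plumbing in `CudaExchangeBuffer::restore_async` and in `move_to_host_async` repeats one pattern: a completion closure registered against an inner value is lifted to the outer wrapper by projecting through it, as in `Box::new(|this: &mut Self| on_completion(&mut this.inner))`. A plain-Rust illustration of just that lifting step, with illustrative names and a simplified error type instead of the crate's `CompletionFnMut`/`CudaResult`:

// Standalone sketch of lifting an inner completion to an outer wrapper; the
// types and the error handling here are illustrative, not the crate's.
type Completion<'a, T> = Box<dyn FnOnce(&mut T) -> Result<(), ()> + 'a>;

struct Inner(u32);
struct Outer {
    inner: Inner,
}

fn lift<'a>(inner_completion: Completion<'a, Inner>) -> Completion<'a, Outer> {
    // Project through the outer value, exactly like `&mut this.inner` above.
    Box::new(move |outer: &mut Outer| inner_completion(&mut outer.inner))
}

fn main() {
    let completion: Completion<'static, Inner> = Box::new(|inner: &mut Inner| {
        inner.0 += 1;
        Ok(())
    });
    let lifted = lift(completion);

    let mut outer = Outer { inner: Inner(41) };
    lifted(&mut outer).unwrap();
    assert_eq!(outer.inner.0, 42);
}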
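
`restore_async` also leans on `owning_ref` so that ownership of the boxed value stays in one place while the asynchronous restore only needs mutable access to a projected field; `BoxRefMut::new`, `map_mut`, and `into_owner` are the three calls used above. A standalone sketch of that projection follows; the exact `owning_ref` version/fork is taken on trust from the calls in this patch (the published crate's `BoxRefMut`, for instance, takes no lifetime parameter, unlike the `BoxRefMut<'a, O, Self>` spelled here), so treat it as the general pattern rather than the exact dependency.

// Standalone sketch of the owner-preserving projection; `Wrapper` is illustrative
// and an `owning_ref` providing `BoxRefMut::new`/`map_mut`/`into_owner` is assumed.
use owning_ref::BoxRefMut;

struct Wrapper {
    inner: Vec<u8>,
}

fn main() {
    let wrapper = Box::new(Wrapper { inner: vec![1, 2, 3] });

    // Keep owning the box, but hand out `&mut` access to the projected field only.
    let mut projected: BoxRefMut<Wrapper, Vec<u8>> =
        BoxRefMut::new(wrapper).map_mut(|w| &mut w.inner);
    projected.push(4);

    // Once the mutable projection is no longer needed, take the owner back out.
    let wrapper: Box<Wrapper> = projected.into_owner();
    assert_eq!(wrapper.inner, vec![1, 2, 3, 4]);
}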