diff --git a/crates/cust/src/memory/device/device_slice.rs b/crates/cust/src/memory/device/device_slice.rs index e460c8c..ac1e4b2 100644 --- a/crates/cust/src/memory/device/device_slice.rs +++ b/crates/cust/src/memory/device/device_slice.rs @@ -10,6 +10,7 @@ use bytemuck::{Pod, Zeroable}; use std::mem::{self, size_of}; use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; use std::os::raw::c_void; +use std::slice; /// Fixed-size device-side slice. #[derive(Debug, Copy, Clone)] @@ -22,14 +23,6 @@ pub struct DeviceSlice { unsafe impl Send for DeviceSlice {} unsafe impl Sync for DeviceSlice {} -impl DeviceSlice { - pub fn as_host_vec(&self) -> CudaResult> { - let mut vec = vec![T::default(); self.len()]; - self.copy_to(&mut vec)?; - Ok(vec) - } -} - // This works by faking a regular slice out of the device raw-pointer and the length and transmuting // I have no idea if this is safe or not. Probably not, though I can't imagine how the compiler // could possibly know that the pointer is not de-referenceable. I'm banking that we get proper @@ -81,6 +74,17 @@ impl DeviceSlice { self.ptr } + pub fn as_host_vec(&self) -> CudaResult> { + let mut vec = Vec::with_capacity(self.len()); + // SAFETY: The slice points to uninitialized memory, but we only write to it. Once it is + // written, all values are valid, so we can (and must) change the length of the vector. + unsafe { + self.copy_to(slice::from_raw_parts_mut(vec.as_mut_ptr(), self.len()))?; + vec.set_len(self.len()) + } + Ok(vec) + } + /* TODO (AL): keep these? /// Divides one DeviceSlice into two at a given index. ///