diff --git a/crates/cust/src/memory/device/device_buffer.rs b/crates/cust/src/memory/device/device_buffer.rs index 3e0ce92..76ff4d8 100644 --- a/crates/cust/src/memory/device/device_buffer.rs +++ b/crates/cust/src/memory/device/device_buffer.rs @@ -400,13 +400,13 @@ impl Deref for DeviceBuffer { type Target = DeviceSlice; fn deref(&self) -> &DeviceSlice { - unsafe { &*(self as *const _ as *const DeviceSlice) } + unsafe { DeviceSlice::from_raw_parts(self.buf, self.len) } } } impl DerefMut for DeviceBuffer { fn deref_mut(&mut self) -> &mut DeviceSlice { - unsafe { &mut *(self as *mut _ as *mut DeviceSlice) } + unsafe { DeviceSlice::from_raw_parts_mut(self.buf, self.len) } } } diff --git a/crates/cust/src/memory/device/device_slice.rs b/crates/cust/src/memory/device/device_slice.rs index e460c8c..6098ed5 100644 --- a/crates/cust/src/memory/device/device_slice.rs +++ b/crates/cust/src/memory/device/device_slice.rs @@ -1,27 +1,39 @@ use crate::error::{CudaResult, ToResult}; use crate::memory::device::AsyncCopyDestination; use crate::memory::device::{CopyDestination, DeviceBuffer}; -use crate::memory::DeviceCopy; use crate::memory::DevicePointer; +use crate::memory::{DeviceCopy, DeviceMemory}; use crate::stream::Stream; use crate::sys as cuda; #[cfg(feature = "bytemuck")] use bytemuck::{Pod, Zeroable}; +use std::fmt::{self, Debug, Formatter}; +use std::marker::PhantomData; use std::mem::{self, size_of}; use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; use std::os::raw::c_void; +use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut}; /// Fixed-size device-side slice. -#[derive(Debug, Copy, Clone)] -#[repr(C)] +#[repr(transparent)] pub struct DeviceSlice { - ptr: DevicePointer, - len: usize, + _phantom: PhantomData, + slice: [()], } unsafe impl Send for DeviceSlice {} unsafe impl Sync for DeviceSlice {} +impl Debug for DeviceSlice { + fn fmt(&self, formatter: &mut Formatter) -> fmt::Result { + formatter + .debug_struct("DeviceSlice") + .field("ptr", &self.as_device_ptr().as_ptr()) + .field("len", &self.len()) + .finish() + } +} + impl DeviceSlice { pub fn as_host_vec(&self) -> CudaResult> { let mut vec = vec![T::default(); self.len()]; @@ -46,7 +58,7 @@ impl DeviceSlice { /// assert_eq!(a.len(), 3); /// ``` pub fn len(&self) -> usize { - self.len + self.slice.len() } /// Returns `true` if the slice has a length of 0. @@ -60,7 +72,7 @@ impl DeviceSlice { /// assert!(a.is_empty()); /// ``` pub fn is_empty(&self) -> bool { - self.len == 0 + self.len() == 0 } /// Return a raw device-pointer to the slice's buffer. @@ -78,7 +90,7 @@ impl DeviceSlice { /// println!("{:p}", a.as_ptr()); /// ``` pub fn as_device_ptr(&self) -> DevicePointer { - self.ptr + DevicePointer::from_raw(self as *const _ as *const () as usize as u64) } /* TODO (AL): keep these? @@ -184,8 +196,8 @@ impl DeviceSlice { /// assert_eq!([1u64, 2], host_buf); /// ``` #[allow(clippy::needless_pass_by_value)] - pub unsafe fn from_raw_parts(ptr: DevicePointer, len: usize) -> DeviceSlice { - DeviceSlice { ptr, len } + pub unsafe fn from_raw_parts<'a>(ptr: DevicePointer, len: usize) -> &'a DeviceSlice { + &*(slice_from_raw_parts(ptr.as_ptr(), len) as *const DeviceSlice) } /// Performs the same functionality as `from_raw_parts`, except that a @@ -203,8 +215,11 @@ impl DeviceSlice { /// slices as with `from_raw_parts`. /// /// See the documentation of `from_raw_parts` for more details. - pub unsafe fn from_raw_parts_mut(ptr: DevicePointer, len: usize) -> DeviceSlice { - DeviceSlice { ptr, len } + pub unsafe fn from_raw_parts_mut<'a>( + ptr: DevicePointer, + len: usize, + ) -> &'a mut DeviceSlice { + &mut *(slice_from_raw_parts_mut(ptr.as_mut_ptr(), len) as *mut DeviceSlice) } } @@ -221,14 +236,15 @@ impl DeviceSlice { /// In total it will set `sizeof * len` values of `value` contiguously. #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))] pub fn set_8(&mut self, value: u8) -> CudaResult<()> { - if self.ptr.is_null() { - return Ok(()); - } - // SAFETY: We know T can hold any value because it is `Pod`, and // sub-byte alignment isn't a thing so we know the alignment is right. unsafe { - cuda::cuMemsetD8_v2(self.ptr.as_raw(), value, size_of::() * self.len).to_result() + cuda::cuMemsetD8_v2( + self.as_device_ptr().as_raw(), + value, + size_of::() * self.len(), + ) + .to_result() } } @@ -242,14 +258,10 @@ impl DeviceSlice { /// Therefore you should not read/write from/to the memory range until the operation is complete. #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))] pub unsafe fn set_8_async(&mut self, value: u8, stream: &Stream) -> CudaResult<()> { - if self.ptr.is_null() { - return Ok(()); - } - cuda::cuMemsetD8Async( - self.ptr.as_raw(), + self.as_device_ptr().as_raw(), value, - size_of::() * self.len, + size_of::() * self.len(), stream.as_inner(), ) .to_result() @@ -267,21 +279,20 @@ impl DeviceSlice { #[track_caller] #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))] pub fn set_16(&mut self, value: u16) -> CudaResult<()> { - if self.ptr.is_null() { - return Ok(()); - } - let data_len = size_of::() * self.len; + let data_len = size_of::() * self.len(); assert_eq!( data_len % 2, 0, "Buffer length is not a multiple of 2 bytes!" ); assert_eq!( - self.ptr.as_raw() % 2, + self.as_device_ptr().as_raw() % 2, 0, "Buffer pointer is not aligned to at least 2 bytes!" ); - unsafe { cuda::cuMemsetD16_v2(self.ptr.as_raw(), value, data_len / 2).to_result() } + unsafe { + cuda::cuMemsetD16_v2(self.as_device_ptr().as_raw(), value, data_len / 2).to_result() + } } /// Sets the memory range of this buffer to contiguous `16-bit` values of `value` asynchronously. @@ -301,22 +312,24 @@ impl DeviceSlice { #[track_caller] #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))] pub unsafe fn set_16_async(&mut self, value: u16, stream: &Stream) -> CudaResult<()> { - if self.ptr.is_null() { - return Ok(()); - } - let data_len = size_of::() * self.len; + let data_len = size_of::() * self.len(); assert_eq!( data_len % 2, 0, "Buffer length is not a multiple of 2 bytes!" ); assert_eq!( - self.ptr.as_raw() % 2, + self.as_device_ptr().as_raw() % 2, 0, "Buffer pointer is not aligned to at least 2 bytes!" ); - cuda::cuMemsetD16Async(self.ptr.as_raw(), value, data_len / 2, stream.as_inner()) - .to_result() + cuda::cuMemsetD16Async( + self.as_device_ptr().as_raw(), + value, + data_len / 2, + stream.as_inner(), + ) + .to_result() } /// Sets the memory range of this buffer to contiguous `32-bit` values of `value`. @@ -331,21 +344,20 @@ impl DeviceSlice { #[track_caller] #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))] pub fn set_32(&mut self, value: u32) -> CudaResult<()> { - if self.ptr.is_null() { - return Ok(()); - } - let data_len = size_of::() * self.len; + let data_len = size_of::() * self.len(); assert_eq!( data_len % 4, 0, "Buffer length is not a multiple of 4 bytes!" ); assert_eq!( - self.ptr.as_raw() % 4, + self.as_device_ptr().as_raw() % 4, 0, "Buffer pointer is not aligned to at least 4 bytes!" ); - unsafe { cuda::cuMemsetD32_v2(self.ptr.as_raw(), value, data_len / 4).to_result() } + unsafe { + cuda::cuMemsetD32_v2(self.as_device_ptr().as_raw(), value, data_len / 4).to_result() + } } /// Sets the memory range of this buffer to contiguous `32-bit` values of `value` asynchronously. @@ -365,22 +377,24 @@ impl DeviceSlice { #[track_caller] #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))] pub unsafe fn set_32_async(&mut self, value: u32, stream: &Stream) -> CudaResult<()> { - if self.ptr.is_null() { - return Ok(()); - } - let data_len = size_of::() * self.len; + let data_len = size_of::() * self.len(); assert_eq!( data_len % 4, 0, "Buffer length is not a multiple of 4 bytes!" ); assert_eq!( - self.ptr.as_raw() % 4, + self.as_device_ptr().as_raw() % 4, 0, "Buffer pointer is not aligned to at least 4 bytes!" ); - cuda::cuMemsetD32Async(self.ptr.as_raw(), value, data_len / 4, stream.as_inner()) - .to_result() + cuda::cuMemsetD32Async( + self.as_device_ptr().as_raw(), + value, + data_len / 4, + stream.as_inner(), + ) + .to_result() } } @@ -388,14 +402,13 @@ impl DeviceSlice { impl DeviceSlice { /// Sets this slice's data to zero. pub fn set_zero(&mut self) -> CudaResult<()> { - if self.ptr.is_null() { - return Ok(()); - } // SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe // for this type. And a slice of bytes can represent any type. - let mut erased = DeviceSlice { - ptr: self.ptr.cast::(), - len: size_of::() * self.len, + let erased = unsafe { + DeviceSlice::from_raw_parts_mut( + self.as_device_ptr().cast::(), + size_of::() * self.len(), + ) }; erased.set_8(0) } @@ -407,15 +420,15 @@ impl DeviceSlice { /// This operation is async so it does not complete immediately, it uses stream-ordering semantics. /// Therefore you should not read/write from/to the memory range until the operation is complete. pub unsafe fn set_zero_async(&mut self, stream: &Stream) -> CudaResult<()> { - if self.ptr.is_null() { + if self.as_device_ptr().is_null() { return Ok(()); } // SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe // for this type. And a slice of bytes can represent any type. - let mut erased = DeviceSlice { - ptr: self.ptr.cast::(), - len: size_of::() * self.len, - }; + let erased = DeviceSlice::from_raw_parts_mut( + self.as_device_ptr().cast::(), + size_of::() * self.len(), + ); erased.set_8_async(0, stream) } } @@ -426,8 +439,16 @@ pub trait DeviceSliceIndex { /// # Safety /// /// The range must be in-bounds of the slice. - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice; - fn index(self, slice: &DeviceSlice) -> DeviceSlice; + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice; + fn index(self, slice: &DeviceSlice) -> &DeviceSlice; + + /// Indexes into this slice without checking if it is in-bounds. + /// + /// # Safety + /// + /// The range must be in-bounds of the slice. + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice; + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice; } #[inline(never)] @@ -465,19 +486,26 @@ fn slice_end_index_overflow_fail() -> ! { } impl DeviceSliceIndex for usize { - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice { + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice { (self..self + 1).get_unchecked(slice) } - fn index(self, slice: &DeviceSlice) -> DeviceSlice { + fn index(self, slice: &DeviceSlice) -> &DeviceSlice { slice.index(self..self + 1) } + + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + (self..self + 1).get_unchecked_mut(slice) + } + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + slice.index_mut(self..self + 1) + } } impl DeviceSliceIndex for Range { - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice { + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice { DeviceSlice::from_raw_parts(slice.as_device_ptr().add(self.start), self.end - self.start) } - fn index(self, slice: &DeviceSlice) -> DeviceSlice { + fn index(self, slice: &DeviceSlice) -> &DeviceSlice { if self.start > self.end { slice_index_order_fail(self.start, self.end); } else if self.end > slice.len() { @@ -486,36 +514,77 @@ impl DeviceSliceIndex for Range { // SAFETY: `self` is checked to be valid and in bounds above. unsafe { self.get_unchecked(slice) } } + + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + DeviceSlice::from_raw_parts_mut( + slice.as_device_ptr().add(self.start), + self.end - self.start, + ) + } + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + if self.start > self.end { + slice_index_order_fail(self.start, self.end); + } else if self.end > slice.len() { + slice_end_index_len_fail(self.end, slice.len()); + } + // SAFETY: `self` is checked to be valid and in bounds above. + unsafe { self.get_unchecked_mut(slice) } + } } impl DeviceSliceIndex for RangeTo { - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice { + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice { (0..self.end).get_unchecked(slice) } - fn index(self, slice: &DeviceSlice) -> DeviceSlice { + fn index(self, slice: &DeviceSlice) -> &DeviceSlice { (0..self.end).index(slice) } + + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + (0..self.end).get_unchecked_mut(slice) + } + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + (0..self.end).index_mut(slice) + } } impl DeviceSliceIndex for RangeFrom { - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice { + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice { (self.start..slice.len()).get_unchecked(slice) } - fn index(self, slice: &DeviceSlice) -> DeviceSlice { + fn index(self, slice: &DeviceSlice) -> &DeviceSlice { if self.start > slice.len() { slice_start_index_len_fail(self.start, slice.len()); } // SAFETY: `self` is checked to be valid and in bounds above. unsafe { self.get_unchecked(slice) } } + + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + (self.start..slice.len()).get_unchecked_mut(slice) + } + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + if self.start > slice.len() { + slice_start_index_len_fail(self.start, slice.len()); + } + // SAFETY: `self` is checked to be valid and in bounds above. + unsafe { self.get_unchecked_mut(slice) } + } } impl DeviceSliceIndex for RangeFull { - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice { - *slice + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice { + slice + } + fn index(self, slice: &DeviceSlice) -> &DeviceSlice { + slice + } + + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + slice } - fn index(self, slice: &DeviceSlice) -> DeviceSlice { - *slice + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + slice } } @@ -530,30 +599,51 @@ fn into_slice_range(range: RangeInclusive) -> Range { } impl DeviceSliceIndex for RangeInclusive { - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice { + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice { into_slice_range(self).get_unchecked(slice) } - fn index(self, slice: &DeviceSlice) -> DeviceSlice { + fn index(self, slice: &DeviceSlice) -> &DeviceSlice { if *self.end() == usize::MAX { slice_end_index_overflow_fail(); } into_slice_range(self).index(slice) } + + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + into_slice_range(self).get_unchecked_mut(slice) + } + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + if *self.end() == usize::MAX { + slice_end_index_overflow_fail(); + } + into_slice_range(self).index_mut(slice) + } } impl DeviceSliceIndex for RangeToInclusive { - unsafe fn get_unchecked(self, slice: &DeviceSlice) -> DeviceSlice { + unsafe fn get_unchecked(self, slice: &DeviceSlice) -> &DeviceSlice { (0..=self.end).get_unchecked(slice) } - fn index(self, slice: &DeviceSlice) -> DeviceSlice { + fn index(self, slice: &DeviceSlice) -> &DeviceSlice { (0..=self.end).index(slice) } + + unsafe fn get_unchecked_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + (0..=self.end).get_unchecked_mut(slice) + } + fn index_mut(self, slice: &mut DeviceSlice) -> &mut DeviceSlice { + (0..=self.end).index_mut(slice) + } } impl DeviceSlice { - pub fn index>(&self, idx: Idx) -> DeviceSlice { + pub fn index>(&self, idx: Idx) -> &DeviceSlice { idx.index(self) } + + pub fn index_mut>(&mut self, idx: Idx) -> &mut DeviceSlice { + idx.index_mut(self) + } } impl crate::private::Sealed for DeviceSlice {} @@ -567,8 +657,12 @@ impl + AsMut<[T]> + ?Sized> CopyDestination for let size = mem::size_of::() * self.len(); if size != 0 { unsafe { - cuda::cuMemcpyHtoD_v2(self.ptr.as_raw(), val.as_ptr() as *const c_void, size) - .to_result()? + cuda::cuMemcpyHtoD_v2( + self.as_device_ptr().as_raw(), + val.as_ptr() as *const c_void, + size, + ) + .to_result()? } } Ok(()) @@ -603,8 +697,12 @@ impl CopyDestination> for DeviceSlice { let size = mem::size_of::() * self.len(); if size != 0 { unsafe { - cuda::cuMemcpyDtoD_v2(self.ptr.as_raw(), val.as_device_ptr().as_raw(), size) - .to_result()? + cuda::cuMemcpyDtoD_v2( + self.as_device_ptr().as_raw(), + val.as_device_ptr().as_raw(), + size, + ) + .to_result()? } } Ok(()) @@ -650,7 +748,7 @@ impl + AsMut<[T]> + ?Sized> AsyncCopyDestination let size = mem::size_of::() * self.len(); if size != 0 { cuda::cuMemcpyHtoDAsync_v2( - self.ptr.as_raw(), + self.as_device_ptr().as_raw(), val.as_ptr() as *const c_void, size, stream.as_inner(),