diff --git a/crates/cust/src/memory/device/device_slice.rs b/crates/cust/src/memory/device/device_slice.rs
index 6098ed5..4c3b93f 100644
--- a/crates/cust/src/memory/device/device_slice.rs
+++ b/crates/cust/src/memory/device/device_slice.rs
@@ -9,8 +9,9 @@
 use crate::sys as cuda;
 use bytemuck::{Pod, Zeroable};
 use std::fmt::{self, Debug, Formatter};
 use std::marker::PhantomData;
-use std::mem::{self, size_of};
-use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
+use std::ops::{
+    Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
+};
 use std::os::raw::c_void;
 use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut};
@@ -236,16 +237,13 @@ impl<T: Pod + DeviceCopy> DeviceSlice<T> {
     /// In total it will set `sizeof<T> * len` values of `value` contiguously.
     #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
     pub fn set_8(&mut self, value: u8) -> CudaResult<()> {
+        if self.size_in_bytes() == 0 {
+            return Ok(());
+        }
+
         // SAFETY: We know T can hold any value because it is `Pod`, and
         // sub-byte alignment isn't a thing so we know the alignment is right.
-        unsafe {
-            cuda::cuMemsetD8_v2(
-                self.as_device_ptr().as_raw(),
-                value,
-                size_of::<T>() * self.len(),
-            )
-            .to_result()
-        }
+        unsafe { cuda::cuMemsetD8_v2(self.as_raw_ptr(), value, self.size_in_bytes()).to_result() }
     }
 
     /// Sets the memory range of this buffer to contiguous `8-bit` values of `value` asynchronously.
@@ -258,10 +256,14 @@ impl<T: Pod + DeviceCopy> DeviceSlice<T> {
     /// Therefore you should not read/write from/to the memory range until the operation is complete.
     #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
     pub unsafe fn set_8_async(&mut self, value: u8, stream: &Stream) -> CudaResult<()> {
+        if self.size_in_bytes() == 0 {
+            return Ok(());
+        }
+
         cuda::cuMemsetD8Async(
-            self.as_device_ptr().as_raw(),
+            self.as_raw_ptr(),
             value,
-            size_of::<T>() * self.len(),
+            self.size_in_bytes(),
             stream.as_inner(),
         )
         .to_result()
@@ -279,20 +281,18 @@ impl<T: Pod + DeviceCopy> DeviceSlice<T> {
     #[track_caller]
     #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
     pub fn set_16(&mut self, value: u16) -> CudaResult<()> {
-        let data_len = size_of::<T>() * self.len();
+        let data_len = self.size_in_bytes();
         assert_eq!(
             data_len % 2,
             0,
             "Buffer length is not a multiple of 2 bytes!"
         );
         assert_eq!(
-            self.as_device_ptr().as_raw() % 2,
+            self.as_raw_ptr() % 2,
             0,
             "Buffer pointer is not aligned to at least 2 bytes!"
         );
-        unsafe {
-            cuda::cuMemsetD16_v2(self.as_device_ptr().as_raw(), value, data_len / 2).to_result()
-        }
+        unsafe { cuda::cuMemsetD16_v2(self.as_raw_ptr(), value, data_len / 2).to_result() }
     }
 
     /// Sets the memory range of this buffer to contiguous `16-bit` values of `value` asynchronously.
@@ -312,24 +312,19 @@ impl<T: Pod + DeviceCopy> DeviceSlice<T> {
     #[track_caller]
     #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
     pub unsafe fn set_16_async(&mut self, value: u16, stream: &Stream) -> CudaResult<()> {
-        let data_len = size_of::<T>() * self.len();
+        let data_len = self.size_in_bytes();
         assert_eq!(
             data_len % 2,
             0,
             "Buffer length is not a multiple of 2 bytes!"
         );
         assert_eq!(
-            self.as_device_ptr().as_raw() % 2,
+            self.as_raw_ptr() % 2,
             0,
             "Buffer pointer is not aligned to at least 2 bytes!"
         );
-        cuda::cuMemsetD16Async(
-            self.as_device_ptr().as_raw(),
-            value,
-            data_len / 2,
-            stream.as_inner(),
-        )
-        .to_result()
+        cuda::cuMemsetD16Async(self.as_raw_ptr(), value, data_len / 2, stream.as_inner())
+            .to_result()
     }
 
     /// Sets the memory range of this buffer to contiguous `32-bit` values of `value`.
@@ -344,20 +339,18 @@ impl<T: Pod + DeviceCopy> DeviceSlice<T> {
     #[track_caller]
     #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
     pub fn set_32(&mut self, value: u32) -> CudaResult<()> {
-        let data_len = size_of::<T>() * self.len();
+        let data_len = self.size_in_bytes();
         assert_eq!(
             data_len % 4,
             0,
             "Buffer length is not a multiple of 4 bytes!"
         );
         assert_eq!(
-            self.as_device_ptr().as_raw() % 4,
+            self.as_raw_ptr() % 4,
             0,
             "Buffer pointer is not aligned to at least 4 bytes!"
         );
-        unsafe {
-            cuda::cuMemsetD32_v2(self.as_device_ptr().as_raw(), value, data_len / 4).to_result()
-        }
+        unsafe { cuda::cuMemsetD32_v2(self.as_raw_ptr(), value, data_len / 4).to_result() }
     }
 
     /// Sets the memory range of this buffer to contiguous `32-bit` values of `value` asynchronously.
@@ -377,24 +370,19 @@ impl<T: Pod + DeviceCopy> DeviceSlice<T> {
     #[track_caller]
     #[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
     pub unsafe fn set_32_async(&mut self, value: u32, stream: &Stream) -> CudaResult<()> {
-        let data_len = size_of::<T>() * self.len();
+        let data_len = self.size_in_bytes();
         assert_eq!(
             data_len % 4,
             0,
             "Buffer length is not a multiple of 4 bytes!"
         );
         assert_eq!(
-            self.as_device_ptr().as_raw() % 4,
+            self.as_raw_ptr() % 4,
             0,
             "Buffer pointer is not aligned to at least 4 bytes!"
         );
-        cuda::cuMemsetD32Async(
-            self.as_device_ptr().as_raw(),
-            value,
-            data_len / 4,
-            stream.as_inner(),
-        )
-        .to_result()
+        cuda::cuMemsetD32Async(self.as_raw_ptr(), value, data_len / 4, stream.as_inner())
+            .to_result()
     }
 }
 
@@ -405,10 +393,7 @@ impl<T: Zeroable + DeviceCopy> DeviceSlice<T> {
         // SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
         // for this type. And a slice of bytes can represent any type.
         let erased = unsafe {
-            DeviceSlice::from_raw_parts_mut(
-                self.as_device_ptr().cast::<u8>(),
-                size_of::<T>() * self.len(),
-            )
+            DeviceSlice::from_raw_parts_mut(self.as_device_ptr().cast::<u8>(), self.size_in_bytes())
         };
         erased.set_8(0)
     }
@@ -420,14 +405,11 @@ impl<T: Zeroable + DeviceCopy> DeviceSlice<T> {
     /// This operation is async so it does not complete immediately, it uses stream-ordering semantics.
     /// Therefore you should not read/write from/to the memory range until the operation is complete.
     pub unsafe fn set_zero_async(&mut self, stream: &Stream) -> CudaResult<()> {
-        if self.as_device_ptr().is_null() {
-            return Ok(());
-        }
         // SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
        // for this type. And a slice of bytes can represent any type.
         let erased = DeviceSlice::from_raw_parts_mut(
             self.as_device_ptr().cast::<u8>(),
-            size_of::<T>() * self.len(),
+            self.size_in_bytes(),
         );
         erased.set_8_async(0, stream)
     }
@@ -636,13 +618,17 @@ impl<T: DeviceCopy> DeviceSliceIndex<T> for RangeToInclusive<usize> {
     }
 }
 
-impl<T: DeviceCopy> DeviceSlice<T> {
-    pub fn index<Idx: DeviceSliceIndex<T>>(&self, idx: Idx) -> &DeviceSlice<T> {
-        idx.index(self)
+impl<T: DeviceCopy, Idx: DeviceSliceIndex<T>> Index<Idx> for DeviceSlice<T> {
+    type Output = DeviceSlice<T>;
+
+    fn index(&self, index: Idx) -> &DeviceSlice<T> {
+        index.index(self)
     }
+}
 
-    pub fn index_mut<Idx: DeviceSliceIndex<T>>(&mut self, idx: Idx) -> &mut DeviceSlice<T> {
-        idx.index_mut(self)
+impl<T: DeviceCopy, Idx: DeviceSliceIndex<T>> IndexMut<Idx> for DeviceSlice<T> {
+    fn index_mut(&mut self, index: Idx) -> &mut DeviceSlice<T> {
+        index.index_mut(self)
     }
 }
 
@@ -654,15 +640,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
             self.len() == val.len(),
             "destination and source slices have different lengths"
         );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
             unsafe {
-                cuda::cuMemcpyHtoD_v2(
-                    self.as_device_ptr().as_raw(),
-                    val.as_ptr() as *const c_void,
-                    size,
-                )
-                .to_result()?
+                cuda::cuMemcpyHtoD_v2(self.as_raw_ptr(), val.as_ptr() as *const c_void, size)
+                    .to_result()?
             }
         }
         Ok(())
@@ -674,15 +656,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
             self.len() == val.len(),
             "destination and source slices have different lengths"
         );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
             unsafe {
-                cuda::cuMemcpyDtoH_v2(
-                    val.as_mut_ptr() as *mut c_void,
-                    self.as_device_ptr().as_raw(),
-                    size,
-                )
-                .to_result()?
+                cuda::cuMemcpyDtoH_v2(val.as_mut_ptr() as *mut c_void, self.as_raw_ptr(), size)
+                    .to_result()?
             }
         }
         Ok(())
@@ -694,16 +672,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
             self.len() == val.len(),
             "destination and source slices have different lengths"
         );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
-            unsafe {
-                cuda::cuMemcpyDtoD_v2(
-                    self.as_device_ptr().as_raw(),
-                    val.as_device_ptr().as_raw(),
-                    size,
-                )
-                .to_result()?
-            }
+            unsafe { cuda::cuMemcpyDtoD_v2(self.as_raw_ptr(), val.as_raw_ptr(), size).to_result()? }
         }
         Ok(())
     }
@@ -713,16 +684,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
             self.len() == val.len(),
             "destination and source slices have different lengths"
         );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
-            unsafe {
-                cuda::cuMemcpyDtoD_v2(
-                    val.as_device_ptr().as_raw(),
-                    self.as_device_ptr().as_raw(),
-                    size,
-                )
-                .to_result()?
-            }
+            unsafe { cuda::cuMemcpyDtoD_v2(val.as_raw_ptr(), self.as_raw_ptr(), size).to_result()? }
         }
         Ok(())
     }
@@ -745,10 +709,10 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination
             self.len() == val.len(),
             "destination and source slices have different lengths"
         );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
             cuda::cuMemcpyHtoDAsync_v2(
-                self.as_device_ptr().as_raw(),
+                self.as_raw_ptr(),
                 val.as_ptr() as *const c_void,
                 size,
                 stream.as_inner(),
@@ -764,11 +728,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination
             self.len() == val.len(),
             "destination and source slices have different lengths"
         );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
             cuda::cuMemcpyDtoHAsync_v2(
                 val.as_mut_ptr() as *mut c_void,
-                self.as_device_ptr().as_raw(),
+                self.as_raw_ptr(),
                 size,
                 stream.as_inner(),
             )
@@ -783,15 +747,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
             self.len() == val.len(),
             "destination and source slices have different lengths"
         );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
-            cuda::cuMemcpyDtoDAsync_v2(
-                self.as_device_ptr().as_raw(),
-                val.as_device_ptr().as_raw(),
-                size,
-                stream.as_inner(),
-            )
-            .to_result()?
+            cuda::cuMemcpyDtoDAsync_v2(self.as_raw_ptr(), val.as_raw_ptr(), size, stream.as_inner())
+                .to_result()?
         }
         Ok(())
     }
@@ -801,15 +760,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
             self.len() == val.len(),
             "destination and source slices have different lengths"
        );
-        let size = mem::size_of::<T>() * self.len();
+        let size = self.size_in_bytes();
         if size != 0 {
-            cuda::cuMemcpyDtoDAsync_v2(
-                val.as_device_ptr().as_raw(),
-                self.as_device_ptr().as_raw(),
-                size,
-                stream.as_inner(),
-            )
-            .to_result()?
+            cuda::cuMemcpyDtoDAsync_v2(val.as_raw_ptr(), self.as_raw_ptr(), size, stream.as_inner())
+                .to_result()?
         }
         Ok(())
     }
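
Usage sketch (not part of the patch): with `Index` and `IndexMut` implemented for `DeviceSlice`, the range indexing that previously went through the inherent `index`/`index_mut` helpers can use the `[]` operator, and the memset helpers now derive their byte count from `size_in_bytes()`. The snippet below is an illustration only: it assumes the crate's `bytemuck` feature is enabled, that `DeviceBuffer` still dereferences to `DeviceSlice`, and the buffer, values, and function name are made up for the example.

    use cust::error::CudaResult;
    use cust::memory::DeviceBuffer;

    fn fill_middle_words() -> CudaResult<()> {
        // Keep the CUDA context alive for the duration of the device work.
        let _ctx = cust::quick_init()?;
        // Hypothetical example data: sixteen zeroed u32 words on the device.
        let mut buf = DeviceBuffer::from_slice(&[0u32; 16])?;
        // `&mut buf[4..8]` resolves through the new IndexMut impl (via the Range
        // DeviceSliceIndex impl) and yields a &mut DeviceSlice<u32> over words 4..8.
        let view = &mut buf[4..8];
        // set_32 now computes the memset length from size_in_bytes() internally.
        view.set_32(0xdead_beef)?;
        Ok(())
    }

The same sub-slice could previously be obtained with the inherent `buf.index(4..8)`; the trait impls simply make the operator syntax available.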