Skip to content

Commit

Permalink
impl Index traits and add size checks for copying
Browse files Browse the repository at this point in the history
  • Loading branch information
beepster4096 committed Feb 22, 2022
1 parent f2811f3 commit f1242df
Showing 1 changed file with 59 additions and 105 deletions.
164 changes: 59 additions & 105 deletions crates/cust/src/memory/device/device_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ use crate::sys as cuda;
use bytemuck::{Pod, Zeroable};
use std::fmt::{self, Debug, Formatter};
use std::marker::PhantomData;
use std::mem::{self, size_of};
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
use std::ops::{
Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
use std::os::raw::c_void;
use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut};

Expand Down Expand Up @@ -236,16 +237,13 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
/// In total it will set `sizeof<T> * len` values of `value` contiguously.
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
pub fn set_8(&mut self, value: u8) -> CudaResult<()> {
if self.size_in_bytes() == 0 {
return Ok(());
}

// SAFETY: We know T can hold any value because it is `Pod`, and
// sub-byte alignment isn't a thing so we know the alignment is right.
unsafe {
cuda::cuMemsetD8_v2(
self.as_device_ptr().as_raw(),
value,
size_of::<T>() * self.len(),
)
.to_result()
}
unsafe { cuda::cuMemsetD8_v2(self.as_raw_ptr(), value, self.size_in_bytes()).to_result() }
}

/// Sets the memory range of this buffer to contiguous `8-bit` values of `value` asynchronously.
Expand All @@ -258,10 +256,14 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
/// Therefore you should not read/write from/to the memory range until the operation is complete.
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
pub unsafe fn set_8_async(&mut self, value: u8, stream: &Stream) -> CudaResult<()> {
if self.size_in_bytes() == 0 {
return Ok(());
}

cuda::cuMemsetD8Async(
self.as_device_ptr().as_raw(),
self.as_raw_ptr(),
value,
size_of::<T>() * self.len(),
self.size_in_bytes(),
stream.as_inner(),
)
.to_result()
Expand All @@ -279,20 +281,18 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
#[track_caller]
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
pub fn set_16(&mut self, value: u16) -> CudaResult<()> {
let data_len = size_of::<T>() * self.len();
let data_len = self.size_in_bytes();
assert_eq!(
data_len % 2,
0,
"Buffer length is not a multiple of 2 bytes!"
);
assert_eq!(
self.as_device_ptr().as_raw() % 2,
self.as_raw_ptr() % 2,
0,
"Buffer pointer is not aligned to at least 2 bytes!"
);
unsafe {
cuda::cuMemsetD16_v2(self.as_device_ptr().as_raw(), value, data_len / 2).to_result()
}
unsafe { cuda::cuMemsetD16_v2(self.as_raw_ptr(), value, data_len / 2).to_result() }
}

/// Sets the memory range of this buffer to contiguous `16-bit` values of `value` asynchronously.
Expand All @@ -312,24 +312,19 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
#[track_caller]
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
pub unsafe fn set_16_async(&mut self, value: u16, stream: &Stream) -> CudaResult<()> {
let data_len = size_of::<T>() * self.len();
let data_len = self.size_in_bytes();
assert_eq!(
data_len % 2,
0,
"Buffer length is not a multiple of 2 bytes!"
);
assert_eq!(
self.as_device_ptr().as_raw() % 2,
self.as_raw_ptr() % 2,
0,
"Buffer pointer is not aligned to at least 2 bytes!"
);
cuda::cuMemsetD16Async(
self.as_device_ptr().as_raw(),
value,
data_len / 2,
stream.as_inner(),
)
.to_result()
cuda::cuMemsetD16Async(self.as_raw_ptr(), value, data_len / 2, stream.as_inner())
.to_result()
}

/// Sets the memory range of this buffer to contiguous `32-bit` values of `value`.
Expand All @@ -344,20 +339,18 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
#[track_caller]
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
pub fn set_32(&mut self, value: u32) -> CudaResult<()> {
let data_len = size_of::<T>() * self.len();
let data_len = self.size_in_bytes();
assert_eq!(
data_len % 4,
0,
"Buffer length is not a multiple of 4 bytes!"
);
assert_eq!(
self.as_device_ptr().as_raw() % 4,
self.as_raw_ptr() % 4,
0,
"Buffer pointer is not aligned to at least 4 bytes!"
);
unsafe {
cuda::cuMemsetD32_v2(self.as_device_ptr().as_raw(), value, data_len / 4).to_result()
}
unsafe { cuda::cuMemsetD32_v2(self.as_raw_ptr(), value, data_len / 4).to_result() }
}

/// Sets the memory range of this buffer to contiguous `32-bit` values of `value` asynchronously.
Expand All @@ -377,24 +370,19 @@ impl<T: DeviceCopy + Pod> DeviceSlice<T> {
#[track_caller]
#[cfg_attr(docsrs, doc(cfg(feature = "bytemuck")))]
pub unsafe fn set_32_async(&mut self, value: u32, stream: &Stream) -> CudaResult<()> {
let data_len = size_of::<T>() * self.len();
let data_len = self.size_in_bytes();
assert_eq!(
data_len % 4,
0,
"Buffer length is not a multiple of 4 bytes!"
);
assert_eq!(
self.as_device_ptr().as_raw() % 4,
self.as_raw_ptr() % 4,
0,
"Buffer pointer is not aligned to at least 4 bytes!"
);
cuda::cuMemsetD32Async(
self.as_device_ptr().as_raw(),
value,
data_len / 4,
stream.as_inner(),
)
.to_result()
cuda::cuMemsetD32Async(self.as_raw_ptr(), value, data_len / 4, stream.as_inner())
.to_result()
}
}

Expand All @@ -405,10 +393,7 @@ impl<T: DeviceCopy + Zeroable> DeviceSlice<T> {
// SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
// for this type. And a slice of bytes can represent any type.
let erased = unsafe {
DeviceSlice::from_raw_parts_mut(
self.as_device_ptr().cast::<u8>(),
size_of::<T>() * self.len(),
)
DeviceSlice::from_raw_parts_mut(self.as_device_ptr().cast::<u8>(), self.size_in_bytes())
};
erased.set_8(0)
}
Expand All @@ -420,14 +405,11 @@ impl<T: DeviceCopy + Zeroable> DeviceSlice<T> {
/// This operation is async so it does not complete immediately, it uses stream-ordering semantics.
/// Therefore you should not read/write from/to the memory range until the operation is complete.
pub unsafe fn set_zero_async(&mut self, stream: &Stream) -> CudaResult<()> {
if self.as_device_ptr().is_null() {
return Ok(());
}
// SAFETY: this is fine because Zeroable guarantees a zero byte-pattern is safe
// for this type. And a slice of bytes can represent any type.
let erased = DeviceSlice::from_raw_parts_mut(
self.as_device_ptr().cast::<u8>(),
size_of::<T>() * self.len(),
self.size_in_bytes(),
);
erased.set_8_async(0, stream)
}
Expand Down Expand Up @@ -636,13 +618,17 @@ impl<T: DeviceCopy> DeviceSliceIndex<T> for RangeToInclusive<usize> {
}
}

impl<T: DeviceCopy> DeviceSlice<T> {
pub fn index<Idx: DeviceSliceIndex<T>>(&self, idx: Idx) -> &DeviceSlice<T> {
idx.index(self)
impl<T: DeviceCopy, Idx: DeviceSliceIndex<T>> Index<Idx> for DeviceSlice<T> {
type Output = DeviceSlice<T>;

fn index(&self, index: Idx) -> &DeviceSlice<T> {
index.index(self)
}
}

pub fn index_mut<Idx: DeviceSliceIndex<T>>(&mut self, idx: Idx) -> &mut DeviceSlice<T> {
idx.index_mut(self)
impl<T: DeviceCopy, Idx: DeviceSliceIndex<T>> IndexMut<Idx> for DeviceSlice<T> {
fn index_mut(&mut self, index: Idx) -> &mut DeviceSlice<T> {
index.index_mut(self)
}
}

Expand All @@ -654,15 +640,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
unsafe {
cuda::cuMemcpyHtoD_v2(
self.as_device_ptr().as_raw(),
val.as_ptr() as *const c_void,
size,
)
.to_result()?
cuda::cuMemcpyHtoD_v2(self.as_raw_ptr(), val.as_ptr() as *const c_void, size)
.to_result()?
}
}
Ok(())
Expand All @@ -674,15 +656,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
unsafe {
cuda::cuMemcpyDtoH_v2(
val.as_mut_ptr() as *mut c_void,
self.as_device_ptr().as_raw(),
size,
)
.to_result()?
cuda::cuMemcpyDtoH_v2(val.as_mut_ptr() as *mut c_void, self.as_raw_ptr(), size)
.to_result()?
}
}
Ok(())
Expand All @@ -694,16 +672,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
unsafe {
cuda::cuMemcpyDtoD_v2(
self.as_device_ptr().as_raw(),
val.as_device_ptr().as_raw(),
size,
)
.to_result()?
}
unsafe { cuda::cuMemcpyDtoD_v2(self.as_raw_ptr(), val.as_raw_ptr(), size).to_result()? }
}
Ok(())
}
Expand All @@ -713,16 +684,9 @@ impl<T: DeviceCopy> CopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
unsafe {
cuda::cuMemcpyDtoD_v2(
val.as_device_ptr().as_raw(),
self.as_device_ptr().as_raw(),
size,
)
.to_result()?
}
unsafe { cuda::cuMemcpyDtoD_v2(val.as_raw_ptr(), self.as_raw_ptr(), size).to_result()? }
}
Ok(())
}
Expand All @@ -745,10 +709,10 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
cuda::cuMemcpyHtoDAsync_v2(
self.as_device_ptr().as_raw(),
self.as_raw_ptr(),
val.as_ptr() as *const c_void,
size,
stream.as_inner(),
Expand All @@ -764,11 +728,11 @@ impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> AsyncCopyDestination<I>
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
cuda::cuMemcpyDtoHAsync_v2(
val.as_mut_ptr() as *mut c_void,
self.as_device_ptr().as_raw(),
self.as_raw_ptr(),
size,
stream.as_inner(),
)
Expand All @@ -783,15 +747,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
cuda::cuMemcpyDtoDAsync_v2(
self.as_device_ptr().as_raw(),
val.as_device_ptr().as_raw(),
size,
stream.as_inner(),
)
.to_result()?
cuda::cuMemcpyDtoDAsync_v2(self.as_raw_ptr(), val.as_raw_ptr(), size, stream.as_inner())
.to_result()?
}
Ok(())
}
Expand All @@ -801,15 +760,10 @@ impl<T: DeviceCopy> AsyncCopyDestination<DeviceSlice<T>> for DeviceSlice<T> {
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = mem::size_of::<T>() * self.len();
let size = self.size_in_bytes();
if size != 0 {
cuda::cuMemcpyDtoDAsync_v2(
val.as_device_ptr().as_raw(),
self.as_device_ptr().as_raw(),
size,
stream.as_inner(),
)
.to_result()?
cuda::cuMemcpyDtoDAsync_v2(val.as_raw_ptr(), self.as_raw_ptr(), size, stream.as_inner())
.to_result()?
}
Ok(())
}
Expand Down

0 comments on commit f1242df

Please sign in to comment.