Skip to content

Commit

Permalink
make device-local primitives device-safe by default (for allocation a…
Browse files Browse the repository at this point in the history
…nd deallocation only: caller remains responsible for correct device placement during usage)
  • Loading branch information
gerwin3 committed Aug 14, 2023
1 parent d27c95d commit 4ff6b46
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub async fn num_devices() -> Result<usize> {
}

/// CUDA device ID.
pub type DeviceId = usize;
pub type DeviceId = i32;

/// CUDA device.
pub struct Device;
Expand Down
57 changes: 51 additions & 6 deletions src/ffi/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,33 @@ impl Device {
let mut id: i32 = 0;
let id_ptr = std::ptr::addr_of_mut!(id);
let ret = cpp!(unsafe [
id_ptr as "std::int32_t*"
] -> i32 as "std::int32_t" {
id_ptr as "int"
] -> i32 as "int" {
return cudaGetDevice(id_ptr);
});
result!(ret, id as DeviceId)
result!(ret, id)
}

/// Select the device identified by `id` by calling `cudaSetDevice`.
///
/// # Arguments
///
/// * `id` - Device to select.
///
/// # Errors
///
/// The integer status returned by `cudaSetDevice` is passed to the `result!` macro;
/// presumably a non-success status is mapped to an `Err` (the macro is defined elsewhere,
/// confirm there).
pub fn set(id: DeviceId) -> Result<()> {
let id = id as i32;
let ret = cpp!(unsafe [
id as "std::int32_t"
] -> i32 as "std::int32_t" {
id as "int"
] -> i32 as "int" {
return cudaSetDevice(id);
});
result!(ret)
}

/// Make `id` the active device and return a guard for it.
///
/// Delegates to [`DeviceGuard::activate`]. When the returned [`DeviceGuard`] is dropped,
/// the device that was active before this call is selected again.
///
/// # Arguments
///
/// * `id` - Device to activate.
///
/// # Errors
///
/// Returns whatever error [`DeviceGuard::activate`] reports.
#[inline]
pub fn bind(id: DeviceId) -> Result<DeviceGuard> {
DeviceGuard::activate(id)
}

/// Make `id` the active device and return a guard for it, panicking on failure.
///
/// Infallible counterpart of [`Device::bind`], meant for call sites that cannot
/// propagate an error.
///
/// # Arguments
///
/// * `id` - Device to activate.
///
/// # Panics
///
/// Panics if [`DeviceGuard::activate`] returns an error. The panic message contains the
/// requested device and the underlying error.
#[inline]
pub fn bind_or_panic(id: DeviceId) -> DeviceGuard {
    match DeviceGuard::activate(id) {
        Ok(guard) => guard,
        Err(err) => panic!("failed to bind to device {}: {}", id, err),
    }
}

pub fn synchronize() -> Result<()> {
let ret = cpp!(unsafe [] -> i32 as "std::int32_t" {
return cudaDeviceSynchronize();
Expand All @@ -71,6 +81,36 @@ impl Device {
}
}

/// Guard that keeps the specified device active for the duration of the surrounding scope.
///
/// Creating the guard switches to the requested device; dropping it switches back to the
/// device that was active before. The guard therefore has to be bound to a variable
/// (for example `let _guard = ...;`). A guard that is discarded right away restores the
/// previous device immediately, which is why the type is marked `#[must_use]`.
#[must_use = "the previous device is restored as soon as the guard is dropped; bind it to a variable"]
pub struct DeviceGuard {
    /// Device that was made active when the guard was created.
    pub active: DeviceId,
    /// Device that was active before the guard was created; restored on drop.
    pub previous: DeviceId,
}

impl DeviceGuard {
/// Create [`DeviceGuard`] and activate it.
///
/// # Arguments
///
/// * `device` - Device to activate.
fn activate(device: DeviceId) -> Result<DeviceGuard> {
let previous = Device::get()?;
Device::set(device)?;
Ok(Self {
active: device,
previous,
})
}
}

impl Drop for DeviceGuard {
    /// Switch back to the device that was active before the guard was created.
    ///
    /// # Panics
    ///
    /// Panics if the previous device cannot be selected again.
    fn drop(&mut self) {
        // NOTE(review): a panic raised here while the thread is already unwinding
        // aborts the process; confirm that this is the intended failure mode.
        if let Err(err) = Device::set(self.previous) {
            panic!("failed to set device ordinal: {} ({})", self.previous, err);
        }
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -91,6 +131,11 @@ mod tests {
assert!(matches!(Device::get(), Ok(0)));
}

/// Binding to device 0 must succeed.
#[test]
fn test_bind_device() {
    // The temporary guard is dropped at the end of this statement, so the previously
    // active device is restored before the assertion runs.
    let bound = Device::bind(0).is_ok();
    assert!(bound);
}

#[test]
fn test_synchronize() {
assert!(Device::synchronize().is_ok());
Expand Down
14 changes: 13 additions & 1 deletion src/ffi/memory/device.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use cpp::cpp;

use crate::device::DeviceId;
use crate::ffi::device::Device;
use crate::ffi::memory::host::HostBuffer;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
Expand All @@ -13,6 +15,7 @@ type Result<T> = std::result::Result<T, crate::error::Error>;
pub struct DeviceBuffer<T: Copy> {
// Number of `T` elements in the buffer.
pub num_elements: usize,
// Pointer to the allocation backing the buffer.
internal: DevicePtr,
// Device that was current when the buffer was created; `free` binds to it before
// releasing the memory.
device: DeviceId,
// Ties the buffer to its element type `T` without storing a `T`.
_phantom: std::marker::PhantomData<T>,
}

Expand All @@ -32,6 +35,8 @@ unsafe impl<T: Copy> Sync for DeviceBuffer<T> {}

impl<T: Copy> DeviceBuffer<T> {
pub fn new(num_elements: usize, stream: &Stream) -> Self {
let device =
Device::get().unwrap_or_else(|err| panic!("could not determine current device: {err}"));
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
let ptr_ptr = std::ptr::addr_of_mut!(ptr);
let size = num_elements * std::mem::size_of::<T>();
Expand All @@ -43,9 +48,10 @@ impl<T: Copy> DeviceBuffer<T> {
] -> i32 as "std::int32_t" {
return cudaMallocAsync(ptr_ptr, size, (cudaStream_t) stream_ptr);
});
match result!(ret, ptr.into()) {
match result!(ret, DevicePtr::from_addr(ptr)) {
Ok(internal) => Self {
internal,
device,
num_elements,
_phantom: Default::default(),
},
Expand Down Expand Up @@ -175,6 +181,10 @@ impl<T: Copy> DeviceBuffer<T> {

/// Release the buffer memory.
///
/// # Panics
///
/// This function panics if binding to the corresponding device fails.
///
/// # Safety
///
/// The buffer may not be used after this function is called, except for being dropped.
Expand All @@ -183,6 +193,8 @@ impl<T: Copy> DeviceBuffer<T> {
return;
}

let _device_guard = Device::bind_or_panic(self.device);

// SAFETY: Safe because we won't use pointer after this.
let mut internal = unsafe { self.internal.take() };
let ptr = internal.as_mut_ptr();
Expand Down
14 changes: 13 additions & 1 deletion src/ffi/memory/device2d.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use cpp::cpp;

use crate::device::DeviceId;
use crate::ffi::device::Device;
use crate::ffi::memory::host::HostBuffer;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
Expand All @@ -16,6 +18,7 @@ pub struct DeviceBuffer2D<T: Copy> {
pub num_channels: usize,
pub pitch: usize,
internal: DevicePtr,
device: DeviceId,
_phantom: std::marker::PhantomData<T>,
}

Expand All @@ -35,6 +38,8 @@ unsafe impl<T: Copy> Sync for DeviceBuffer2D<T> {}

impl<T: Copy> DeviceBuffer2D<T> {
pub fn new(width: usize, height: usize, num_channels: usize) -> Self {
let device =
Device::get().unwrap_or_else(|err| panic!("could not determine current device: {err}"));
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
let ptr_ptr = std::ptr::addr_of_mut!(ptr);
let mut pitch = 0_usize;
Expand All @@ -53,13 +58,14 @@ impl<T: Copy> DeviceBuffer2D<T> {
height
);
});
match result!(ret, ptr.into()) {
match result!(ret, DevicePtr::from_addr(ptr)) {
Ok(internal) => Self {
width,
height,
num_channels,
pitch,
internal,
device,
_phantom: Default::default(),
},
Err(err) => {
Expand Down Expand Up @@ -200,6 +206,10 @@ impl<T: Copy> DeviceBuffer2D<T> {

/// Release the buffer memory.
///
/// # Panics
///
/// This function panics if binding to the corresponding device fails.
///
/// # Safety
///
/// The buffer may not be used after this function is called, except for being dropped.
Expand All @@ -208,6 +218,8 @@ impl<T: Copy> DeviceBuffer2D<T> {
return;
}

let _device_guard = Device::bind_or_panic(self.device);

// SAFETY: Safe because we won't use pointer after this.
let mut internal = unsafe { self.internal.take() };
let ptr = internal.as_mut_ptr();
Expand Down
14 changes: 13 additions & 1 deletion src/ffi/memory/host.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use cpp::cpp;

use crate::device::DeviceId;
use crate::ffi::device::Device;
use crate::ffi::memory::device::DeviceBuffer;
use crate::ffi::ptr::DevicePtr;
use crate::ffi::result;
Expand All @@ -13,6 +15,7 @@ type Result<T> = std::result::Result<T, crate::error::Error>;
pub struct HostBuffer<T: Copy> {
// Number of `T` elements in the buffer.
pub num_elements: usize,
// Pointer to the allocation backing the buffer.
internal: DevicePtr,
// Device that was current when the buffer was created; `free` binds to it before
// releasing the memory.
device: DeviceId,
// Ties the buffer to its element type `T` without storing a `T`.
_phantom: std::marker::PhantomData<T>,
}

Expand All @@ -32,6 +35,8 @@ unsafe impl<T: Copy> Sync for HostBuffer<T> {}

impl<T: Copy> HostBuffer<T> {
pub fn new(num_elements: usize) -> Self {
let device =
Device::get().unwrap_or_else(|err| panic!("could not determine current device: {err}"));
let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
let ptr_ptr = std::ptr::addr_of_mut!(ptr);
let size = num_elements * std::mem::size_of::<T>();
Expand All @@ -41,9 +46,10 @@ impl<T: Copy> HostBuffer<T> {
] -> i32 as "std::int32_t" {
return cudaMallocHost(ptr_ptr, size);
});
match result!(ret, ptr.into()) {
match result!(ret, DevicePtr::from_addr(ptr)) {
Ok(internal) => Self {
internal,
device,
num_elements,
_phantom: Default::default(),
},
Expand Down Expand Up @@ -153,6 +159,10 @@ impl<T: Copy> HostBuffer<T> {

/// Release the buffer memory.
///
/// # Panics
///
/// This function panics if binding to the corresponding device fails.
///
/// # Safety
///
/// The buffer may not be used after this function is called, except for being dropped.
Expand All @@ -161,6 +171,8 @@ impl<T: Copy> HostBuffer<T> {
return;
}

let _device_guard = Device::bind_or_panic(self.device);

// SAFETY: Safe because we won't use the pointer after this.
let mut internal = unsafe { self.internal.take() };
let ptr = internal.as_mut_ptr();
Expand Down
Loading

0 comments on commit 4ff6b46

Please sign in to comment.