Skip to content

Commit

Permalink
Speed up wgpu passes/allocations (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBrussee authored Aug 12, 2024
1 parent 2d4d5a2 commit d94a07a
Show file tree
Hide file tree
Showing 17 changed files with 480 additions and 836 deletions.
21 changes: 9 additions & 12 deletions crates/cubecl-cuda/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {}
impl<MM: MemoryManagement<CudaStorage>> CudaServer<MM> {
fn read_sync(&mut self, binding: server::Binding<Self>) -> Vec<u8> {
let ctx = self.get_context();
let resource = ctx.memory_management.get(binding.memory);
let resource = ctx.memory_management.get_resource(binding.memory);

// TODO: Check if it is possible to make this faster
let mut data = vec![0; resource.size() as usize];
Expand All @@ -89,13 +89,12 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
}

fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
let handle = self.empty(data.len());
let ctx = self.get_context();
let handle = ctx.memory_management.reserve(data.len(), || unsafe {
cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
});
let handle = server::Handle::new(handle);
let binding = handle.clone().binding().memory;
let resource = ctx.memory_management.get(binding);

let resource = ctx
.memory_management
.get_resource(handle.clone().binding().memory);

unsafe {
cudarc::driver::result::memcpy_htod_async(resource.ptr, data, ctx.stream).unwrap();
Expand All @@ -106,9 +105,7 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {

fn empty(&mut self, size: usize) -> server::Handle<Self> {
let ctx = self.get_context();
let handle = ctx.memory_management.reserve(size, || unsafe {
cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
});
let handle = ctx.memory_management.reserve(size, &[]);
server::Handle::new(handle)
}

Expand Down Expand Up @@ -148,7 +145,7 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {

let resources = bindings
.into_iter()
.map(|binding| ctx.memory_management.get(binding.memory))
.map(|binding| ctx.memory_management.get_resource(binding.memory))
.collect::<Vec<_>>();

ctx.execute_task(kernel_id, count, resources);
Expand All @@ -171,7 +168,7 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
binding: server::Binding<Self>,
) -> <Self::Storage as cubecl_runtime::storage::ComputeStorage>::Resource {
let ctx = self.get_context();
ctx.memory_management.get(binding.memory)
ctx.memory_management.get_resource(binding.memory)
}
}

Expand Down
16 changes: 1 addition & 15 deletions crates/cubecl-cuda/src/compute/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,25 +151,11 @@ impl ComputeStorage for CudaStorage {
fn alloc(&mut self, size: usize) -> StorageHandle {
let id = StorageId::new();
let ptr = unsafe { cudarc::driver::result::malloc_async(self.stream, size).unwrap() };
self.memory.insert(id.clone(), ptr);
self.memory.insert(id, ptr);
StorageHandle::new(id, StorageUtilization::Full(size))
}

fn dealloc(&mut self, id: StorageId) {
self.deallocations.push(id);
}

fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
let num_bytes = from.size();

unsafe {
cudarc::driver::result::memcpy_dtod_async(
self.get(to).ptr,
self.get(from).ptr,
num_bytes,
self.stream,
)
.unwrap();
}
}
}
2 changes: 1 addition & 1 deletion crates/cubecl-runtime/benches/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn main() {
if handles.len() >= 4000 {
handles.pop_front();
}
let handle = mm.reserve(MB, || {});
let handle = mm.reserve(MB, &[]);
handles.push_back(handle);
}
println!("{:?}", start.elapsed());
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-runtime/src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use alloc::sync::Arc;
macro_rules! storage_id_type {
($name:ident) => {
/// Storage ID.
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
pub struct $name {
value: usize,
}
Expand Down
14 changes: 10 additions & 4 deletions crates/cubecl-runtime/src/memory_management/base.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::storage::ComputeStorage;
use crate::storage::{ComputeStorage, StorageHandle, StorageId};

/// The managed tensor buffer handle that points to some memory segment.
/// It should not contain actual data.
Expand All @@ -23,18 +23,24 @@ pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
/// The associated type that must implement [MemoryBinding]
type Binding: MemoryBinding;

/// Returns the storage from the specified binding
fn get(&mut self, binding: Self::Binding) -> StorageHandle;

/// Returns the resource from the storage at the specified handle
fn get(&mut self, binding: Self::Binding) -> Storage::Resource;
fn get_resource(&mut self, binding: Self::Binding) -> Storage::Resource {
let handle = self.get(binding);
self.storage().get(&handle)
}

/// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;
fn reserve(&mut self, size: usize, exclude: &[StorageId]) -> Self::Handle;

/// Bypass the memory allocation algorithm to allocate data directly.
///
/// # Notes
///
/// Can be useful for servers that want specific control over memory.
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;
fn alloc(&mut self, size: usize) -> Self::Handle;

/// Bypass the memory allocation algorithm to deallocate data directly.
///
Expand Down
28 changes: 14 additions & 14 deletions crates/cubecl-runtime/src/memory_management/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::memory_pool::{
MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy,
SmallMemoryPool,
};
use crate::storage::ComputeStorage;
use crate::storage::{ComputeStorage, StorageHandle, StorageId};
use alloc::vec::Vec;

use super::MemoryManagement;
Expand Down Expand Up @@ -92,7 +92,7 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
);

for _ in 0..option.chunk_num_prealloc {
pool.alloc(&mut storage, option.chunk_size, || {});
pool.alloc(&mut storage, option.chunk_size);
}

pool
Expand Down Expand Up @@ -125,46 +125,46 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
type Handle = MemoryPoolHandle;
type Binding = MemoryPoolBinding;

fn get(&mut self, binding: Self::Binding) -> Storage::Resource {
if let Some(handle) = self.small_memory_pool.get(&mut self.storage, &binding) {
return handle;
fn get(&mut self, binding: Self::Binding) -> StorageHandle {
if let Some(handle) = self.small_memory_pool.get(&binding) {
return handle.clone();
}

for pool in &mut self.pools {
if let Some(handle) = pool.get(&mut self.storage, &binding) {
return handle;
for pool in &self.pools {
if let Some(handle) = pool.get(&binding) {
return handle.clone();
}
}

panic!("No handle found in memory pools");
}

fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
fn reserve(&mut self, size: usize, exclude: &[StorageId]) -> Self::Handle {
if size <= self.min_chunk_alignment_offset {
return self
.small_memory_pool
.reserve(&mut self.storage, size, sync);
.reserve(&mut self.storage, size, exclude);
}

for (index, option) in self.options.iter().enumerate() {
if size <= option.slice_max_size {
let pool = &mut self.pools[index];
return pool.reserve(&mut self.storage, size, sync);
return pool.reserve(&mut self.storage, size, exclude);
}
}

panic!("No memory pool big enough to reserve {size} bytes.");
}

fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
fn alloc(&mut self, size: usize) -> Self::Handle {
if size <= self.min_chunk_alignment_offset {
return self.small_memory_pool.alloc(&mut self.storage, size, sync);
return self.small_memory_pool.alloc(&mut self.storage, size);
}

for (index, option) in self.options.iter().enumerate() {
if size <= option.slice_max_size {
let pool = &mut self.pools[index];
return pool.alloc(&mut self.storage, size, sync);
return pool.alloc(&mut self.storage, size);
}
}

Expand Down
Loading

0 comments on commit d94a07a

Please sign in to comment.