From d94a07af312db5dcfff1c0374431097ad3e551c0 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 12 Aug 2024 15:17:58 +0100 Subject: [PATCH] Speed up wgpu passes/allocations (#56) --- crates/cubecl-cuda/src/compute/server.rs | 21 +- crates/cubecl-cuda/src/compute/storage.rs | 16 +- crates/cubecl-runtime/benches/dynamic.rs | 2 +- crates/cubecl-runtime/src/id.rs | 2 +- .../src/memory_management/base.rs | 14 +- .../src/memory_management/dynamic.rs | 28 +- .../src/memory_management/memory_pool/base.rs | 246 +++------- .../memory_management/memory_pool/handle.rs | 2 - .../src/memory_management/memory_pool/ring.rs | 439 +++++++----------- .../memory_management/memory_pool/small.rs | 87 ++-- .../src/memory_management/simple.rs | 56 +-- crates/cubecl-runtime/src/storage/base.rs | 3 - .../cubecl-runtime/src/storage/bytes_cpu.rs | 23 +- crates/cubecl-runtime/tests/dummy/server.rs | 19 +- crates/cubecl-wgpu/src/compute/server.rs | 331 ++++++------- crates/cubecl-wgpu/src/compute/storage.rs | 25 +- crates/cubecl-wgpu/src/runtime.rs | 2 +- 17 files changed, 480 insertions(+), 836 deletions(-) diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index 729b47c7..6ea4d209 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -65,7 +65,7 @@ unsafe impl> Send for CudaServer {} impl> CudaServer { fn read_sync(&mut self, binding: server::Binding) -> Vec { let ctx = self.get_context(); - let resource = ctx.memory_management.get(binding.memory); + let resource = ctx.memory_management.get_resource(binding.memory); // TODO: Check if it is possible to make this faster let mut data = vec![0; resource.size() as usize]; @@ -89,13 +89,12 @@ impl> ComputeServer for CudaServer { } fn create(&mut self, data: &[u8]) -> server::Handle { + let handle = self.empty(data.len()); let ctx = self.get_context(); - let handle = ctx.memory_management.reserve(data.len(), || unsafe { - cudarc::driver::result::stream::synchronize(ctx.stream).unwrap(); - }); - let handle = server::Handle::new(handle); - let binding = handle.clone().binding().memory; - let resource = ctx.memory_management.get(binding); + + let resource = ctx + .memory_management + .get_resource(handle.clone().binding().memory); unsafe { cudarc::driver::result::memcpy_htod_async(resource.ptr, data, ctx.stream).unwrap(); @@ -106,9 +105,7 @@ impl> ComputeServer for CudaServer { fn empty(&mut self, size: usize) -> server::Handle { let ctx = self.get_context(); - let handle = ctx.memory_management.reserve(size, || unsafe { - cudarc::driver::result::stream::synchronize(ctx.stream).unwrap(); - }); + let handle = ctx.memory_management.reserve(size, &[]); server::Handle::new(handle) } @@ -148,7 +145,7 @@ impl> ComputeServer for CudaServer { let resources = bindings .into_iter() - .map(|binding| ctx.memory_management.get(binding.memory)) + .map(|binding| ctx.memory_management.get_resource(binding.memory)) .collect::>(); ctx.execute_task(kernel_id, count, resources); @@ -171,7 +168,7 @@ impl> ComputeServer for CudaServer { binding: server::Binding, ) -> ::Resource { let ctx = self.get_context(); - ctx.memory_management.get(binding.memory) + ctx.memory_management.get_resource(binding.memory) } } diff --git a/crates/cubecl-cuda/src/compute/storage.rs b/crates/cubecl-cuda/src/compute/storage.rs index 4981b6b7..de7d1521 100644 --- a/crates/cubecl-cuda/src/compute/storage.rs +++ b/crates/cubecl-cuda/src/compute/storage.rs @@ -151,25 +151,11 @@ impl ComputeStorage for CudaStorage { fn 
alloc(&mut self, size: usize) -> StorageHandle { let id = StorageId::new(); let ptr = unsafe { cudarc::driver::result::malloc_async(self.stream, size).unwrap() }; - self.memory.insert(id.clone(), ptr); + self.memory.insert(id, ptr); StorageHandle::new(id, StorageUtilization::Full(size)) } fn dealloc(&mut self, id: StorageId) { self.deallocations.push(id); } - - fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) { - let num_bytes = from.size(); - - unsafe { - cudarc::driver::result::memcpy_dtod_async( - self.get(to).ptr, - self.get(from).ptr, - num_bytes, - self.stream, - ) - .unwrap(); - } - } } diff --git a/crates/cubecl-runtime/benches/dynamic.rs b/crates/cubecl-runtime/benches/dynamic.rs index 73353a8d..f1d21f49 100644 --- a/crates/cubecl-runtime/benches/dynamic.rs +++ b/crates/cubecl-runtime/benches/dynamic.rs @@ -22,7 +22,7 @@ fn main() { if handles.len() >= 4000 { handles.pop_front(); } - let handle = mm.reserve(MB, || {}); + let handle = mm.reserve(MB, &[]); handles.push_back(handle); } println!("{:?}", start.elapsed()); diff --git a/crates/cubecl-runtime/src/id.rs b/crates/cubecl-runtime/src/id.rs index dacc3cd6..e11eefff 100644 --- a/crates/cubecl-runtime/src/id.rs +++ b/crates/cubecl-runtime/src/id.rs @@ -5,7 +5,7 @@ use alloc::sync::Arc; macro_rules! storage_id_type { ($name:ident) => { /// Storage ID. - #[derive(Clone, Hash, PartialEq, Eq, Debug)] + #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)] pub struct $name { value: usize, } diff --git a/crates/cubecl-runtime/src/memory_management/base.rs b/crates/cubecl-runtime/src/memory_management/base.rs index 76858dda..deeb4b23 100644 --- a/crates/cubecl-runtime/src/memory_management/base.rs +++ b/crates/cubecl-runtime/src/memory_management/base.rs @@ -1,4 +1,4 @@ -use crate::storage::ComputeStorage; +use crate::storage::{ComputeStorage, StorageHandle, StorageId}; /// The managed tensor buffer handle that points to some memory segment. /// It should not contain actual data. @@ -23,18 +23,24 @@ pub trait MemoryManagement: Send + core::fmt::Debug { /// The associated type that must implement [MemoryBinding] type Binding: MemoryBinding; + /// Returns the storage from the specified binding + fn get(&mut self, binding: Self::Binding) -> StorageHandle; + /// Returns the resource from the storage at the specified handle - fn get(&mut self, binding: Self::Binding) -> Storage::Resource; + fn get_resource(&mut self, binding: Self::Binding) -> Storage::Resource { + let handle = self.get(binding); + self.storage().get(&handle) + } /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it - fn reserve(&mut self, size: usize, sync: Sync) -> Self::Handle; + fn reserve(&mut self, size: usize, exclude: &[StorageId]) -> Self::Handle; /// Bypass the memory allocation algorithm to allocate data directly. /// /// # Notes /// /// Can be useful for servers that want specific control over memory. - fn alloc(&mut self, size: usize, sync: Sync) -> Self::Handle; + fn alloc(&mut self, size: usize) -> Self::Handle; /// Bypass the memory allocation algorithm to deallocate data directly. 
/// diff --git a/crates/cubecl-runtime/src/memory_management/dynamic.rs b/crates/cubecl-runtime/src/memory_management/dynamic.rs index 4df0aa2e..71a51eb7 100644 --- a/crates/cubecl-runtime/src/memory_management/dynamic.rs +++ b/crates/cubecl-runtime/src/memory_management/dynamic.rs @@ -2,7 +2,7 @@ use super::memory_pool::{ MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy, SmallMemoryPool, }; -use crate::storage::ComputeStorage; +use crate::storage::{ComputeStorage, StorageHandle, StorageId}; use alloc::vec::Vec; use super::MemoryManagement; @@ -92,7 +92,7 @@ impl DynamicMemoryManagement { ); for _ in 0..option.chunk_num_prealloc { - pool.alloc(&mut storage, option.chunk_size, || {}); + pool.alloc(&mut storage, option.chunk_size); } pool @@ -125,46 +125,46 @@ impl MemoryManagement for DynamicMemoryManagem type Handle = MemoryPoolHandle; type Binding = MemoryPoolBinding; - fn get(&mut self, binding: Self::Binding) -> Storage::Resource { - if let Some(handle) = self.small_memory_pool.get(&mut self.storage, &binding) { - return handle; + fn get(&mut self, binding: Self::Binding) -> StorageHandle { + if let Some(handle) = self.small_memory_pool.get(&binding) { + return handle.clone(); } - for pool in &mut self.pools { - if let Some(handle) = pool.get(&mut self.storage, &binding) { - return handle; + for pool in &self.pools { + if let Some(handle) = pool.get(&binding) { + return handle.clone(); } } panic!("No handle found in memory pools"); } - fn reserve(&mut self, size: usize, sync: Sync) -> Self::Handle { + fn reserve(&mut self, size: usize, exclude: &[StorageId]) -> Self::Handle { if size <= self.min_chunk_alignment_offset { return self .small_memory_pool - .reserve(&mut self.storage, size, sync); + .reserve(&mut self.storage, size, exclude); } for (index, option) in self.options.iter().enumerate() { if size <= option.slice_max_size { let pool = &mut self.pools[index]; - return pool.reserve(&mut self.storage, size, sync); + return pool.reserve(&mut self.storage, size, exclude); } } panic!("No memory pool big enough to reserve {size} bytes."); } - fn alloc(&mut self, size: usize, sync: Sync) -> Self::Handle { + fn alloc(&mut self, size: usize) -> Self::Handle { if size <= self.min_chunk_alignment_offset { - return self.small_memory_pool.alloc(&mut self.storage, size, sync); + return self.small_memory_pool.alloc(&mut self.storage, size); } for (index, option) in self.options.iter().enumerate() { if size <= option.slice_max_size { let pool = &mut self.pools[index]; - return pool.alloc(&mut self.storage, size, sync); + return pool.alloc(&mut self.storage, size); } } diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs index 5c1f7848..dfffd755 100644 --- a/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs +++ b/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs @@ -1,34 +1,25 @@ use super::index::SearchIndex; -use super::{ - ChunkHandle, ChunkId, MemoryChunk, MemoryPoolBinding, MemoryPoolHandle, MemorySlice, - RingBuffer, SliceHandle, SliceId, -}; -use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization}; +use super::{MemoryPoolBinding, MemoryPoolHandle, RingBuffer, SliceHandle, SliceId}; +use crate::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; use alloc::vec::Vec; -use hashbrown::{HashMap, HashSet}; +use hashbrown::HashMap; pub struct MemoryPool { - chunks: HashMap, + chunks: HashMap, slices: 
HashMap, #[allow(unused)] // will be used when we rewrite memory extension memory_extension_strategy: MemoryExtensionStrategy, rounding: RoundingStrategy, - chunk_index: SearchIndex, - ring: RingBuffer, - recently_added_chunks: Vec, + storage_index: SearchIndex, + ring: RingBuffer, + recently_added_chunks: Vec, recently_allocated_size: usize, buffer_alignment: usize, } -struct SliceUpdate { - slice_id: SliceId, - size: usize, -} - #[derive(new, Debug)] pub struct Chunk { - pub storage: StorageHandle, - pub handle: ChunkHandle, + pub alloc_size: usize, pub slices: MemoryPage, } @@ -88,20 +79,12 @@ impl MemoryPage { fn insert_slice(&mut self, address: usize, slice: SliceId) { self.slices.insert(address, slice); } - - fn slices_sorted_by_address(&self) -> Vec { - let mut entries: Vec<(usize, SliceId)> = self.slices.clone().into_iter().collect(); - entries.sort_by_key(|&(key, _)| key); - let sorted_slices: Vec = entries.into_iter().map(|(_, values)| values).collect(); - sorted_slices - } } #[derive(new, Debug)] pub struct Slice { pub storage: StorageHandle, pub handle: SliceHandle, - pub chunk: ChunkHandle, pub padding: usize, } @@ -176,7 +159,7 @@ impl MemoryPool { slices: HashMap::new(), memory_extension_strategy: merging_strategy, rounding: alloc_strategy, - chunk_index: SearchIndex::new(), + storage_index: SearchIndex::new(), ring: RingBuffer::new(buffer_alignment), recently_added_chunks: Vec::new(), recently_allocated_size: 0, @@ -185,42 +168,34 @@ impl MemoryPool { } /// Returns the resource from the storage, for the specified handle. - pub fn get( - &mut self, - storage: &mut Storage, - binding: &MemoryPoolBinding, - ) -> Option { - self.slices - .get(binding.slice.id()) - .map(|s| &s.storage) - .map(|h| storage.get(h)) + pub fn get(&self, binding: &MemoryPoolBinding) -> Option<&StorageHandle> { + self.slices.get(binding.slice.id()).map(|s| &s.storage) } /// Reserves memory of specified size using the reserve algorithm, and return /// a handle to the reserved memory. 
/// /// Also clean ups, merging free slices together if permitted by the merging strategy - pub fn reserve( + pub fn reserve( &mut self, storage: &mut Storage, size: usize, - sync: Sync, + exclude: &[StorageId], ) -> MemoryPoolHandle { - let slice = self.get_free_slice(size); + let slice = self.get_free_slice(size, exclude); match slice { Some(slice) => MemoryPoolHandle { slice: slice.clone(), }, - None => self.alloc(storage, size, sync), + None => self.alloc(storage, size), } } - pub fn alloc( + pub fn alloc( &mut self, storage: &mut Storage, size: usize, - #[allow(unused)] sync: Sync, ) -> MemoryPoolHandle { let alloc_size = self.rounding.alloc_size(size); self.alloc_slice(storage, alloc_size, size) @@ -233,17 +208,15 @@ impl MemoryPool { slice_size: usize, ) -> MemoryPoolHandle { let chunk_size = self.rounding.alloc_size(alloc_size); - let handle_chunk = self.create_chunk(storage, chunk_size); - let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size(); - self.recently_added_chunks.push(*handle_chunk.id()); + let storage_id = self.create_chunk(storage, chunk_size); + let chunk_size = self.chunks.get(&storage_id).unwrap().alloc_size; + self.recently_added_chunks.push(storage_id); self.recently_allocated_size += chunk_size; - let chunk_id = *handle_chunk.id(); - let (slice, extra_slice) = - self.allocate_slices(handle_chunk.clone(), chunk_size, slice_size); + let (slice, extra_slice) = self.allocate_slices(storage_id, chunk_size, slice_size); let handle_slice = slice.handle.clone(); - self.update_chunk_metadata(chunk_id, slice, extra_slice); + self.update_chunk_metadata(slice, extra_slice); MemoryPoolHandle { slice: handle_slice, @@ -252,16 +225,16 @@ impl MemoryPool { fn allocate_slices( &self, - handle_chunk: ChunkHandle, + storage_id: StorageId, alloc_size: usize, slice_size: usize, ) -> (Slice, Option) { - let slice = self.create_slice(0, slice_size, handle_chunk.clone()); + let slice = self.create_slice(0, slice_size, storage_id); let effective_size = slice.effective_size(); let extra_slice = if effective_size < alloc_size { - Some(self.create_slice(effective_size, alloc_size - effective_size, handle_chunk)) + Some(self.create_slice(effective_size, alloc_size - effective_size, storage_id)) } else { None }; @@ -269,30 +242,20 @@ impl MemoryPool { (slice, extra_slice) } - fn update_chunk_metadata( - &mut self, - chunk_id: ChunkId, - slice: Slice, - extra_slice: Option, - ) { + fn update_chunk_metadata(&mut self, slice: Slice, extra_slice: Option) { + let storage_id = slice.storage.id; let slice_id = *slice.handle.id(); let slice_offset = slice.storage.offset(); self.slices.insert(slice_id, slice); - self.chunks - .get_mut(&chunk_id) - .unwrap() - .slices - .slices - .insert(slice_offset, slice_id); + let chunk = self.chunks.get_mut(&storage_id).unwrap(); + chunk.slices.slices.insert(slice_offset, slice_id); if let Some(extra_slice) = extra_slice { let extra_slice_id = *extra_slice.handle.id(); let extra_slice_offset = extra_slice.storage.offset(); self.slices.insert(extra_slice_id, extra_slice); - self.chunks - .get_mut(&chunk_id) - .unwrap() + chunk .slices .slices .insert(extra_slice_offset, extra_slice_id); @@ -304,7 +267,7 @@ impl MemoryPool { let total_memory_usage: f64 = self .chunks .values() - .map(|chunk| chunk.storage.size() as f64) + .map(|chunk| chunk.alloc_size as f64) .sum(); let effective_memory_usage: f64 = self .slices @@ -318,7 +281,7 @@ impl MemoryPool { /// Finds a free slice that can contain the given size /// Returns the chunk's id and size. 
- fn get_free_slice(&mut self, size: usize) -> Option { + fn get_free_slice(&mut self, size: usize, exclude: &[StorageId]) -> Option { if size < MIN_SIZE_NEEDED_TO_OFFSET { return None; } @@ -326,9 +289,12 @@ impl MemoryPool { let padding = calculate_padding(size, self.buffer_alignment); let effective_size = size + padding; - let slice_id = - self.ring - .find_free_slice(effective_size, &mut self.chunks, &mut self.slices)?; + let slice_id = self.ring.find_free_slice( + effective_size, + &mut self.chunks, + &mut self.slices, + exclude, + )?; let slice = self.slices.get_mut(&slice_id).unwrap(); let old_slice_size = slice.effective_size(); @@ -350,7 +316,7 @@ impl MemoryPool { } /// Creates a slice of size `size` upon the given chunk with the given offset. - fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice { + fn create_slice(&self, offset: usize, size: usize, storage_id: StorageId) -> Slice { assert_eq!( offset % self.buffer_alignment, 0, @@ -360,17 +326,16 @@ impl MemoryPool { if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET { panic!("tried to create slice of size {size} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support"); } - let chunk = self.chunks.get(handle_chunk.id()).unwrap(); let handle = SliceHandle::new(); let storage = StorageHandle { - id: chunk.storage.id.clone(), + id: storage_id, utilization: StorageUtilization::Slice { offset, size }, }; let padding = calculate_padding(size, self.buffer_alignment); - Slice::new(storage, handle, chunk.handle.clone(), padding) + Slice::new(storage, handle, padding) } /// Creates a chunk of given size by allocating on the storage. @@ -378,119 +343,22 @@ impl MemoryPool { &mut self, storage: &mut Storage, size: usize, - ) -> ChunkHandle { + ) -> StorageId { let padding = calculate_padding(size, self.buffer_alignment); let effective_size = size + padding; let storage = storage.alloc(effective_size); - let handle = ChunkHandle::new(); - let id = *handle.id(); + let id = storage.id; self.ring.push_chunk(id); self.chunks.insert( id, - Chunk::new(storage, handle.clone(), MemoryPage::new(HashMap::new())), + Chunk::new(effective_size, MemoryPage::new(HashMap::new())), ); - self.chunk_index.insert(id, size); + self.storage_index.insert(id, size); - handle - } - - #[allow(unused)] - fn extend_max_memory(&mut self, storage: &mut Storage) { - let mut slices = Vec::::new(); - - let mut deallocations = HashSet::::new(); - - let mut chunks_total_size: usize = 0; - - for chunk_id in &self.recently_added_chunks { - let chunk = self.chunks.get(chunk_id).unwrap(); - let chunk_id = *chunk.handle.id(); - let sorted_slice = chunk.slices.slices_sorted_by_address(); - for slice_id in sorted_slice { - let slice = self.slices.get(&slice_id).unwrap(); - let size = slice.storage.size(); - - slices.push(SliceUpdate { slice_id, size }); - } - chunks_total_size += chunk.storage.size(); - deallocations.insert(chunk_id); - } - - if !slices.is_empty() { - self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations); - } else { - self.deallocate(storage, &mut deallocations); - } - } - - fn deallocate( - &mut self, - storage: &mut Storage, - deallocations: &mut HashSet, - ) { - for id in deallocations.drain() { - let mut chunk = self.chunks.remove(&id).unwrap(); - self.ring.remove_chunk(id); - - for (_address, slice_id) in chunk.slices.slices.drain() { - let slice = self.slices.get(&slice_id).unwrap(); - let chunk_id = *slice.chunk.id(); - - assert_ne!(chunk_id, id, 
"Chunk id should be updated"); - } - - self.chunk_index.remove(&id); - storage.dealloc(chunk.storage.id); - } - } - - fn move_to_new_chunk( - &mut self, - alloc_size: usize, - storage: &mut Storage, - slices: &mut Vec, - deallocations: &mut HashSet, - ) { - let chunk = self.create_chunk(storage, alloc_size); - let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone(); - let mut offset = 0; - let mut slices_ids: Vec<(usize, SliceId)> = Vec::new(); - - for update in slices.drain(..) { - let slice_id = update.slice_id; - - let slice = self.slices.get_mut(&slice_id).unwrap(); - let old_storage = slice.storage.clone(); - - slice.chunk = chunk.clone(); - slice.storage = StorageHandle { - id: storage_id.clone(), - utilization: StorageUtilization::Slice { - offset, - size: update.size, - }, - }; - storage.copy(&old_storage, &slice.storage); - slices_ids.push((offset, slice_id)); - offset += slice.effective_size(); - } - - let chunk = self.chunks.get_mut(chunk.id()).unwrap(); - let chunk_handle = chunk.handle.clone(); - for (address, slice_id) in slices_ids.drain(..) { - chunk.slices.insert_slice(address, slice_id); - } - let chunk_size = chunk.storage.size(); - let last_slice_size = chunk_size - offset; - assert_eq!(last_slice_size % self.buffer_alignment, 0); - if last_slice_size != 0 { - self.create_slice(offset, last_slice_size, chunk_handle); - } - - self.deallocate(storage, deallocations); + id } } @@ -503,22 +371,22 @@ fn calculate_padding(size: usize, buffer_alignment: usize) -> usize { } } -impl MemorySlice for Slice { - fn is_free(&self) -> bool { +impl Slice { + pub(crate) fn is_free(&self) -> bool { self.handle.is_free() } - fn size(&self) -> usize { + pub(crate) fn size(&self) -> usize { self.effective_size() } - fn split(&mut self, offset_slice: usize, buffer_alignment: usize) -> Option { + pub(crate) fn split(&mut self, offset_slice: usize, buffer_alignment: usize) -> Option { let size_new = self.effective_size() - offset_slice; let offset_new = self.storage.offset() + offset_slice; let old_size = self.effective_size(); let storage_new = StorageHandle { - id: self.storage.id.clone(), + id: self.storage.id, utilization: StorageUtilization::Slice { offset: offset_new, size: size_new, @@ -549,20 +417,20 @@ impl MemorySlice for Slice { ); self.padding = 0; let padding = calculate_padding(size_new - buffer_alignment, buffer_alignment); - Some(Slice::new(storage_new, handle, self.chunk.clone(), padding)) + Some(Slice::new(storage_new, handle, padding)) } - fn id(&self) -> SliceId { + pub(crate) fn id(&self) -> SliceId { *self.handle.id() } - fn next_slice_position(&self) -> usize { + pub(crate) fn next_slice_position(&self) -> usize { self.storage.offset() + self.effective_size() } } -impl MemoryChunk for Chunk { - fn merge_next_slice( +impl Chunk { + pub(crate) fn merge_next_slice( &mut self, from_slice_index: usize, slices: &mut HashMap, @@ -570,11 +438,11 @@ impl MemoryChunk for Chunk { self.slices.merge_with_next_slice(from_slice_index, slices) } - fn slice(&self, index: usize) -> Option { + pub(crate) fn slice(&self, index: usize) -> Option { self.slices.find_slice(index) } - fn insert_slice( + pub(crate) fn insert_slice( &mut self, position: usize, slice: Slice, diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/handle.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/handle.rs index 3bb04c2a..575fa14f 100644 --- a/crates/cubecl-runtime/src/memory_management/memory_pool/handle.rs +++ 
b/crates/cubecl-runtime/src/memory_management/memory_pool/handle.rs @@ -1,8 +1,6 @@ use crate::memory_id_type; use crate::memory_management::{MemoryBinding, MemoryHandle}; -// The ChunkId allows to keep track of how many references there are to a specific chunk. -memory_id_type!(ChunkId, ChunkHandle); // The SliceId allows to keep track of how many references there are to a specific slice. memory_id_type!(SliceId, SliceHandle, SliceBinding); diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs index 00f80ff9..4b2cdc94 100644 --- a/crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs +++ b/crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs @@ -1,75 +1,46 @@ use alloc::vec::Vec; -use core::marker::PhantomData; use hashbrown::HashMap; -use super::{ChunkId, SliceId}; +use crate::storage::StorageId; + +use super::{Chunk, Slice, SliceId}; #[derive(Debug)] -pub struct RingBuffer, S: MemorySlice> { - queue: Vec, - chunk_positions: HashMap, +pub struct RingBuffer { + queue: Vec, + chunk_positions: HashMap, cursor_slice: usize, cursor_chunk: usize, - _s: PhantomData, - _c: PhantomData, buffer_alignment: usize, } -pub trait MemoryChunk { - fn merge_next_slice(&mut self, slice_position: usize, slices: &mut HashMap) - -> bool; - fn slice(&self, index: usize) -> Option; - fn insert_slice(&mut self, position: usize, slice: S, slices: &mut HashMap); -} - -pub trait MemorySlice: Sized { - fn is_free(&self) -> bool; - fn size(&self) -> usize; - fn split(&mut self, offset: usize, buffer_alignment: usize) -> Option; - fn id(&self) -> SliceId; - fn next_slice_position(&self) -> usize; -} - -impl, S: MemorySlice> RingBuffer { +impl RingBuffer { pub fn new(buffer_alignment: usize) -> Self { Self { queue: Vec::new(), chunk_positions: HashMap::new(), cursor_slice: 0, cursor_chunk: 0, - _s: PhantomData, - _c: PhantomData, buffer_alignment, } } - pub fn push_chunk(&mut self, chunk_id: ChunkId) { - self.queue.push(chunk_id); - self.chunk_positions.insert(chunk_id, self.queue.len() - 1); - } - - pub fn remove_chunk(&mut self, chunk_id: ChunkId) { - if let Some(position) = self.chunk_positions.remove(&chunk_id) { - self.queue.remove(position); - } - - self.chunk_positions.clear(); - - for (pos, id) in self.queue.iter().enumerate() { - self.chunk_positions.insert(*id, pos); - } - self.cursor_chunk = 0; - self.cursor_slice = 0; + pub fn push_chunk(&mut self, storage_id: StorageId) { + self.queue.push(storage_id); + self.chunk_positions + .insert(storage_id, self.queue.len() - 1); } pub fn find_free_slice( &mut self, size: usize, - chunks: &mut HashMap, - slices: &mut HashMap, + chunks: &mut HashMap, + slices: &mut HashMap, + exclude: &[StorageId], ) -> Option { let max_second = self.cursor_chunk; - let result = self.find_free_slice_in_all_chunks(size, chunks, slices, self.queue.len()); + let result = + self.find_free_slice_in_all_chunks(size, chunks, slices, self.queue.len(), exclude); if result.is_some() { return result; @@ -77,14 +48,14 @@ impl, S: MemorySlice> RingBuffer { self.cursor_chunk = 0; self.cursor_slice = 0; - self.find_free_slice_in_all_chunks(size, chunks, slices, max_second) + self.find_free_slice_in_all_chunks(size, chunks, slices, max_second, exclude) } fn find_free_slice_in_chunk( &mut self, size: usize, - chunk: &mut C, - slices: &mut HashMap, + chunk: &mut Chunk, + slices: &mut HashMap, mut slice_index: usize, ) -> Option<(usize, SliceId)> { while let Some(slice_id) = 
chunk.slice(slice_index) { @@ -127,9 +98,10 @@ impl, S: MemorySlice> RingBuffer { fn find_free_slice_in_all_chunks( &mut self, size: usize, - chunks: &mut HashMap, - slices: &mut HashMap, + chunks: &mut HashMap, + slices: &mut HashMap, max_cursor_position: usize, + exclude: &[StorageId], ) -> Option { let start = self.cursor_chunk; let end = usize::min(self.queue.len(), max_cursor_position); @@ -141,6 +113,10 @@ impl, S: MemorySlice> RingBuffer { } if let Some(id) = self.queue.get(chunk_index) { + if exclude.contains(id) { + continue; + } + let chunk = chunks.get_mut(id).unwrap(); let result = self.find_free_slice_in_chunk(size, chunk, slices, slice_index); @@ -161,294 +137,221 @@ impl, S: MemorySlice> RingBuffer { #[cfg(test)] mod tests { - use super::stub::*; + use crate::{ + memory_management::memory_pool::{MemoryPage, SliceHandle}, + storage::StorageHandle, + }; + use super::*; - use alloc::vec; #[test] fn simple_1() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); + let mut ring = RingBuffer::new(1); - let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]); + let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 200]); - ring.push_chunk(ChunkId { value: 0 }); + ring.push_chunk(storage_id); + let mut chunks = HashMap::from([(storage_id, chunk)]); - let slice = ring.find_free_slice(50, &mut chunks, &mut slices).unwrap(); + let slice = ring + .find_free_slice(50, &mut chunks, &mut slices, &[]) + .unwrap(); - assert_eq!(slice, SliceId { value: 0 }); - assert_eq!(slices.get(&slice).unwrap().size, 50); + assert_eq!(slice, slice_ids[0]); + assert_eq!(slices.get(&slice).unwrap().size(), 50); assert_eq!(slices.len(), 3); - assert_eq!(chunks.values().last().unwrap().slices.len(), 3); + assert_eq!(chunks.values().last().unwrap().slices.slices.len(), 3); } #[test] fn simple_2() { - let mut ring = RingBuffer::::new(0); + let mut ring = RingBuffer::new(1); - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); + let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 200]); - let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]); + ring.push_chunk(storage_id); + let mut chunks = HashMap::from([(storage_id, chunk)]); - ring.push_chunk(ChunkId { value: 0 }); + let slice = ring + .find_free_slice(150, &mut chunks, &mut slices, &[]) + .unwrap(); - let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 0 }); - assert_eq!(slices.get(&slice).unwrap().size, 150); + assert_eq!(slice, slice_ids[0]); + assert_eq!(slices.get(&slice).unwrap().size(), 150); assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 2); + assert_eq!(chunks.values().last().unwrap().slices.slices.len(), 2); } #[test] fn multiple_chunks() { - let mut ring = RingBuffer::::new(0); + let mut ring = RingBuffer::new(1); + + let (storage_id_1, mut slice_ids, mut slices, chunk_1) = new_chunk(&[100, 200]); + let (storage_id_2, slice_ids_2, slices_2, chunk_2) = new_chunk(&[200, 200]); - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let slice_3 = new_slice(2, 200, 0); - let slice_4 = new_slice(3, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); - let chunk_2 = 
new_chunk(1, vec![2, 3]); + ring.push_chunk(storage_id_1); + ring.push_chunk(storage_id_2); - let mut slices = HashMap::from([ - (slice_1.id, slice_1), - (slice_2.id, slice_2), - (slice_3.id, slice_3), - (slice_4.id, slice_4), - ]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]); + let mut chunks = HashMap::from([(storage_id_1, chunk_1), (storage_id_2, chunk_2)]); - ring.push_chunk(ChunkId { value: 0 }); - ring.push_chunk(ChunkId { value: 1 }); + slice_ids.extend(slice_ids_2); + slices.extend(slices_2); - slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = false; - slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = false; + // Clone references to control what slice is free: + let _slice_1 = slices.get(&slice_ids[1]).unwrap().handle.clone(); + let _slice_3 = slices.get(&slice_ids[3]).unwrap().handle.clone(); - let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap(); + let slice = ring + .find_free_slice(200, &mut chunks, &mut slices, &[]) + .unwrap(); - assert_eq!(slice, SliceId { value: 2 }); + assert_eq!(slice, slice_ids[2]); - let slice = ring.find_free_slice(100, &mut chunks, &mut slices).unwrap(); + let slice = ring + .find_free_slice(100, &mut chunks, &mut slices, &[]) + .unwrap(); - assert_eq!(slice, SliceId { value: 0 }); + assert_eq!(slice, slice_ids[0]); } #[test] fn find_free_slice_with_exact_fit() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); + let mut ring = RingBuffer::new(1); - let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]); + let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 200]); - ring.push_chunk(ChunkId { value: 0 }); + ring.push_chunk(storage_id); + let mut chunks = HashMap::from([(storage_id, chunk)]); - slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = false; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true; + // Clone reference to control what slice is free: + let _slice_1 = slices.get(&slice_ids[0]).unwrap().handle.clone(); - let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap(); + let slice = ring + .find_free_slice(200, &mut chunks, &mut slices, &[]) + .unwrap(); - assert_eq!(slice, SliceId { value: 1 }); - assert_eq!(slices.get(&slice).unwrap().size, 200); + assert_eq!(slice, slice_ids[1]); + assert_eq!(slices.get(&slice).unwrap().size(), 200); assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 2); + assert_eq!(chunks.values().last().unwrap().slices.slices.len(), 2); } #[test] fn find_free_slice_with_merging() { - let mut ring = RingBuffer::::new(0); + let mut ring = RingBuffer::new(1); - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 50, 1); - let slice_3 = new_slice(2, 100, 2); - let chunk_1 = new_chunk(0, vec![0, 1, 2]); + let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 50, 100]); - let mut slices = HashMap::from([ - (slice_1.id, slice_1), - (slice_2.id, slice_2), - (slice_3.id, slice_3), - ]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]); + ring.push_chunk(storage_id); + let mut chunks = HashMap::from([(storage_id, chunk)]); - ring.push_chunk(ChunkId { value: 0 }); + let slice = ring + .find_free_slice(250, &mut chunks, &mut slices, &[]) + .unwrap(); - slices.get_mut(&SliceId { value: 0 
}).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true; - - let slice = ring.find_free_slice(250, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 0 }); - assert_eq!(slices.get(&slice).unwrap().size, 250); + assert_eq!(slice, slice_ids[0]); + assert_eq!(slices.get(&slice).unwrap().size(), 250); assert_eq!(slices.len(), 1); - assert_eq!(chunks.values().last().unwrap().slices.len(), 1); + assert_eq!(chunks.values().last().unwrap().slices.slices.len(), 1); } #[test] fn find_free_slice_with_multiple_chunks_and_merging() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 50, 0); - let slice_2 = new_slice(1, 50, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); - - let slice_3 = new_slice(2, 100, 0); - let slice_4 = new_slice(3, 50, 1); - let chunk_2 = new_chunk(1, vec![2, 3]); + let mut ring = RingBuffer::new(1); - let mut slices = HashMap::from([ - (slice_1.id, slice_1), - (slice_2.id, slice_2), - (slice_3.id, slice_3), - (slice_4.id, slice_4), - ]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]); + let (storage_id_1, mut slice_ids, mut slices, chunk_1) = new_chunk(&[50, 50]); + let (storage_id_2, slice_ids_2, slices_2, chunk_2) = new_chunk(&[100, 50]); + slice_ids.extend(slice_ids_2); + slices.extend(slices_2); - ring.push_chunk(ChunkId { value: 0 }); - ring.push_chunk(ChunkId { value: 1 }); + ring.push_chunk(storage_id_1); + ring.push_chunk(storage_id_2); - slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = true; + let mut chunks = HashMap::from([(storage_id_1, chunk_1), (storage_id_2, chunk_2)]); - let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap(); + let slice = ring + .find_free_slice(150, &mut chunks, &mut slices, &[]) + .unwrap(); - assert_eq!(slices.get(&slice).unwrap().size, 150); + assert_eq!(slices.get(&slice).unwrap().size(), 150); assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 1); - } - - fn new_slice(id: usize, size: usize, position: usize) -> TestSlice { - TestSlice { - id: SliceId { value: id }, - is_free: true, - size, - position, - } + assert_eq!(chunks.values().last().unwrap().slices.slices.len(), 1); } - fn new_chunk(id: usize, slices: Vec) -> TestChunk { - TestChunk { - id: ChunkId { value: id }, - slices: slices.into_iter().map(|i| SliceId { value: i }).collect(), - } - } -} - -#[cfg(test)] -mod stub { - use super::*; - use cubecl_common::*; - - #[derive(Debug)] - pub struct TestChunk { - pub id: ChunkId, - pub slices: Vec, - } + #[test] + fn excludes_excluded_storage() { + let mut ring = RingBuffer::new(1); - #[derive(Debug)] - pub struct TestSlice { - pub id: SliceId, - pub is_free: bool, - pub size: usize, - pub position: usize, - } + let (storage_id_1, mut slice_ids, mut slices, chunk_1) = new_chunk(&[100, 100]); + let (storage_id_2, slice_ids_2, slices_2, chunk_2) = new_chunk(&[100, 100]); - impl MemorySlice for TestSlice { - fn is_free(&self) -> bool { - self.is_free - } + ring.push_chunk(storage_id_1); + ring.push_chunk(storage_id_2); - fn size(&self) -> usize { - self.size - } - - fn split(&mut self, offset: usize, _buffer_alignment: usize) -> Option { - let size_remained = self.size - offset; - self.size = offset; + let mut 
chunks = HashMap::from([(storage_id_1, chunk_1), (storage_id_2, chunk_2)]); - Some(Self { - id: SliceId { - value: rand::gen_random(), - }, - is_free: true, - size: size_remained, - position: self.position + 1, - }) - } + slice_ids.extend(slice_ids_2); + slices.extend(slices_2); - fn id(&self) -> SliceId { - self.id - } + let slice = ring + .find_free_slice(100, &mut chunks, &mut slices, &[]) + .unwrap(); + assert_eq!(slice, slice_ids[0]); - fn next_slice_position(&self) -> usize { - self.position + 1 - } + let slice = ring + .find_free_slice(100, &mut chunks, &mut slices, &[storage_id_1]) + .unwrap(); + assert_eq!(slice, slice_ids[2]); } - impl MemoryChunk for TestChunk { - fn merge_next_slice( - &mut self, - from_slice_index: usize, - slices: &mut HashMap, - ) -> bool { - let slice_id_current = self.slices.get(from_slice_index).unwrap(); - let slice_id_next = self.slices.get(from_slice_index + 1); - let slice_id_next = match slice_id_next { - Some(val) => val, - None => return false, - }; - - let slice_next = slices.get(slice_id_next).unwrap(); - let is_free = slice_next.is_free; - let size = slice_next.size; - - let slice_current = slices.get_mut(slice_id_current).unwrap(); - - if is_free { - slice_current.size += size; - slices.remove(slice_id_next); - self.slices.remove(from_slice_index + 1); - - for (index, temp_slice_id) in self.slices.iter_mut().enumerate() { - let slice = slices.get_mut(temp_slice_id).unwrap(); - slice.position = index; - } - return true; - } - - false - } + fn new_chunk( + slice_sizes: &[usize], + ) -> (StorageId, Vec, HashMap, Chunk) { + let offsets: Vec<_> = slice_sizes + .iter() + .scan(0, |state, size| { + let offset = *state; + *state += *size; + Some(offset) + }) + .collect(); - fn slice(&self, index: usize) -> Option { - self.slices.get(index).copied() - } + let storage_id = StorageId::new(); - fn insert_slice( - &mut self, - position: usize, - slice: TestSlice, - slices: &mut HashMap, - ) { - self.slices.insert(position, slice.id()); - slices.insert(slice.id(), slice); - for (index, temp_slice_id) in self.slices.iter_mut().enumerate() { - let temp_slice = slices.get_mut(temp_slice_id).unwrap(); - temp_slice.position = index; - } - } + let slices: Vec<_> = offsets + .iter() + .zip(slice_sizes) + .map(|(&offset, &size)| Slice { + storage: StorageHandle { + id: storage_id, + utilization: crate::storage::StorageUtilization::Slice { offset, size }, + }, + handle: SliceHandle::new(), + padding: 0, + }) + .collect(); + + let mem_page = MemoryPage { + slices: slices + .iter() + .zip(offsets) + .map(|(slice, offset)| (offset, slice.id())) + .collect(), + }; + + let chunk = Chunk { + alloc_size: 1024 * 1024, // Arbitrary, just pretend we have a big enough allocation. 
+ slices: mem_page, + }; + + ( + storage_id, + slices.iter().map(|slice| slice.id()).collect(), + slices + .into_iter() + .map(|slice| (slice.id(), slice)) + .collect(), + chunk, + ) } } diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/small.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/small.rs index 7eb9cf5c..aad7c3c8 100644 --- a/crates/cubecl-runtime/src/memory_management/memory_pool/small.rs +++ b/crates/cubecl-runtime/src/memory_management/memory_pool/small.rs @@ -1,5 +1,5 @@ -use super::{ChunkHandle, ChunkId, MemoryPoolBinding, MemoryPoolHandle, SliceHandle, SliceId}; -use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization}; +use super::{MemoryPoolBinding, MemoryPoolHandle, SliceHandle, SliceId}; +use crate::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; use alloc::vec::Vec; use hashbrown::HashMap; @@ -14,18 +14,15 @@ use hashbrown::HashMap; /// - `ring_buffer`: A vector used as a ring buffer to manage chunk reuse. /// - `index`: The current position in the ring buffer. pub struct SmallMemoryPool { - chunks: HashMap, + chunks: HashMap, slices: HashMap, - ring_buffer: Vec, + ring_buffer: Vec, index: usize, buffer_storage_alignment_offset: usize, } #[derive(new, Debug)] pub struct SmallChunk { - pub storage: StorageHandle, - #[allow(dead_code)] - pub handle: ChunkHandle, pub slice: Option, } @@ -33,8 +30,6 @@ pub struct SmallChunk { pub struct SmallSlice { pub storage: StorageHandle, pub handle: SliceHandle, - #[allow(dead_code)] - pub chunk: ChunkHandle, pub padding: usize, } @@ -56,43 +51,35 @@ impl SmallMemoryPool { } /// Returns the resource from the storage, for the specified handle. - pub fn get( - &mut self, - storage: &mut Storage, - binding: &MemoryPoolBinding, - ) -> Option { - self.slices - .get(binding.slice.id()) - .map(|s| &s.storage) - .map(|h| storage.get(h)) + pub fn get(&self, binding: &MemoryPoolBinding) -> Option<&StorageHandle> { + self.slices.get(binding.slice.id()).map(|s| &s.storage) } /// Reserves memory of specified size using the reserve algorithm, and return /// a handle to the reserved memory. 
/// /// Also clean ups, merging free slices together if permitted by the merging strategy - pub fn reserve( + pub fn reserve( &mut self, storage: &mut Storage, size: usize, - sync: Sync, + exclude: &[StorageId], ) -> MemoryPoolHandle { assert!(size <= self.buffer_storage_alignment_offset); - let slice = self.get_free_slice(size); + let slice = self.get_free_slice(size, exclude); match slice { Some(slice) => MemoryPoolHandle { slice: slice.clone(), }, - None => self.alloc(storage, size, sync), + None => self.alloc(storage, size), } } - pub fn alloc( + pub fn alloc( &mut self, storage: &mut Storage, size: usize, - _sync: Sync, ) -> MemoryPoolHandle { assert!(size <= self.buffer_storage_alignment_offset); @@ -104,20 +91,19 @@ impl SmallMemoryPool { storage: &mut Storage, slice_size: usize, ) -> MemoryPoolHandle { - let handle_chunk = self.create_chunk(storage, self.buffer_storage_alignment_offset); - let chunk_id = *handle_chunk.id(); - let slice = self.allocate_slice(handle_chunk.clone(), slice_size); + let storage_id = self.create_chunk(storage, self.buffer_storage_alignment_offset); + let slice = self.allocate_slice(storage_id, slice_size); let handle_slice = slice.handle.clone(); - self.update_chunk_metadata(chunk_id, slice); + self.update_chunk_metadata(slice); MemoryPoolHandle { slice: handle_slice, } } - fn allocate_slice(&self, handle_chunk: ChunkHandle, slice_size: usize) -> SmallSlice { - let slice = self.create_slice(0, slice_size, handle_chunk.clone()); + fn allocate_slice(&self, storage_id: StorageId, slice_size: usize) -> SmallSlice { + let slice = self.create_slice(0, slice_size, storage_id); let effective_size = slice.effective_size(); assert_eq!(effective_size, self.buffer_storage_alignment_offset); @@ -125,20 +111,21 @@ impl SmallMemoryPool { slice } - fn update_chunk_metadata(&mut self, chunk_id: ChunkId, slice: SmallSlice) { + fn update_chunk_metadata(&mut self, slice: SmallSlice) { let slice_id = *slice.handle.id(); + self.chunks.get_mut(&slice.storage.id).unwrap().slice = Some(slice_id); self.slices.insert(slice_id, slice); - self.chunks.get_mut(&chunk_id).unwrap().slice = Some(slice_id); } - fn find_free_slice(&mut self) -> Option { - if self.ring_buffer.is_empty() { - return None; - } + fn find_free_slice(&mut self, exclude: &[StorageId]) -> Option { for _ in 0..self.ring_buffer.len() { - let chunk_id = self.ring_buffer.get(self.index).unwrap(); - let chunk = self.chunks.get(chunk_id).unwrap(); + let storage_id = self.ring_buffer.get(self.index).unwrap(); + if exclude.contains(storage_id) { + continue; + } + + let chunk = self.chunks.get(storage_id).unwrap(); let slice = self.slices.get(&chunk.slice.unwrap()).unwrap(); self.index = (self.index + 1) % self.ring_buffer.len(); if slice.handle.is_free() { @@ -150,8 +137,8 @@ impl SmallMemoryPool { /// Finds a free slice that can contain the given size /// Returns the chunk's id and size. - fn get_free_slice(&mut self, size: usize) -> Option { - let slice_id = self.find_free_slice()?; + fn get_free_slice(&mut self, size: usize, exclude: &[StorageId]) -> Option { + let slice_id = self.find_free_slice(exclude)?; let slice = self.slices.get_mut(&slice_id).unwrap(); let old_slice_size = slice.effective_size(); @@ -174,19 +161,18 @@ impl SmallMemoryPool { } /// Creates a slice of size `size` upon the given chunk with the given offset. 
- fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> SmallSlice { + fn create_slice(&self, offset: usize, size: usize, storage_id: StorageId) -> SmallSlice { assert_eq!(offset, 0); - let chunk = self.chunks.get(handle_chunk.id()).unwrap(); let handle = SliceHandle::new(); let storage = StorageHandle { - id: chunk.storage.id.clone(), + id: storage_id, utilization: StorageUtilization::Slice { offset, size }, }; let padding = calculate_padding(size, self.buffer_storage_alignment_offset); - SmallSlice::new(storage, handle, chunk.handle.clone(), padding) + SmallSlice::new(storage, handle, padding) } /// Creates a chunk of given size by allocating on the storage. @@ -194,20 +180,15 @@ impl SmallMemoryPool { &mut self, storage: &mut Storage, size: usize, - ) -> ChunkHandle { + ) -> StorageId { let padding = calculate_padding(size, self.buffer_storage_alignment_offset); let effective_size = size + padding; let storage = storage.alloc(effective_size); - let handle = ChunkHandle::new(); - let id = *handle.id(); - + let id = storage.id; self.ring_buffer.push(id); - - self.chunks - .insert(id, SmallChunk::new(storage, handle.clone(), None)); - - handle + self.chunks.insert(id, SmallChunk::new(None)); + id } #[allow(unused)] diff --git a/crates/cubecl-runtime/src/memory_management/simple.rs b/crates/cubecl-runtime/src/memory_management/simple.rs index ad243860..f2ffcbf6 100644 --- a/crates/cubecl-runtime/src/memory_management/simple.rs +++ b/crates/cubecl-runtime/src/memory_management/simple.rs @@ -1,6 +1,6 @@ use crate::{ memory_id_type, - storage::{ComputeStorage, StorageHandle, StorageUtilization}, + storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}, }; use alloc::vec::Vec; use hashbrown::HashMap; @@ -175,35 +175,31 @@ impl MemoryManagement for SimpleMemoryManageme type Binding = SimpleBinding; /// Returns the resource from the storage, for the specified handle. - fn get(&mut self, binding: Self::Binding) -> Storage::Resource { - let storage = match binding { - SimpleBinding::Chunk(chunk) => { - &self - .chunks - .get(chunk.id()) - .expect("Storage found for the given execution buffer handle") - .storage - } - SimpleBinding::Slice(slice) => { - &self - .slices - .get(slice.id()) - .expect("Storage found for the given execution buffer handle") - .storage - } - }; - - self.storage.get(storage) + fn get(&mut self, binding: Self::Binding) -> StorageHandle { + match binding { + SimpleBinding::Chunk(chunk) => self + .chunks + .get(chunk.id()) + .expect("Storage found for the given execution buffer handle") + .storage + .clone(), + SimpleBinding::Slice(slice) => self + .slices + .get(slice.id()) + .expect("Storage found for the given execution buffer handle") + .storage + .clone(), + } } /// Reserves memory of specified size using the reserve algorithm, and return /// a handle to the reserved memory. /// /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy. 
- fn reserve(&mut self, size: usize, _sync: Sync) -> Self::Handle { + fn reserve(&mut self, size: usize, exclude: &[StorageId]) -> Self::Handle { self.cleanup_slices(); - let handle = self.reserve_algorithm(size); + let handle = self.reserve_algorithm(size, exclude); if self.dealloc_strategy.should_dealloc() { self.cleanup_chunks(); @@ -212,7 +208,7 @@ impl MemoryManagement for SimpleMemoryManageme handle } - fn alloc(&mut self, size: usize, _sync: Sync) -> Self::Handle { + fn alloc(&mut self, size: usize) -> Self::Handle { self.create_chunk(size) } @@ -248,9 +244,9 @@ impl SimpleMemoryManagement { } } - fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { + fn reserve_algorithm(&mut self, size: usize, exclude: &[StorageId]) -> SimpleHandle { // Looks for a large enough, existing but unused chunk of memory. - let chunk = self.find_free_chunk(size); + let chunk = self.find_free_chunk(size, exclude); match chunk { Some(chunk) => { @@ -269,7 +265,7 @@ impl SimpleMemoryManagement { /// Finds the smallest of the free and large enough chunks to fit `size` /// Returns the chunk's id and size. - fn find_free_chunk(&self, size: usize) -> Option<&Chunk> { + fn find_free_chunk(&self, size: usize, exclude: &[StorageId]) -> Option<&Chunk> { let mut size_diff_current = usize::MAX; let mut current = None; @@ -279,6 +275,10 @@ impl SimpleMemoryManagement { continue; } + if exclude.contains(&chunk.storage.id) { + continue; + } + let storage_size = chunk.storage.size(); // If we find a chunk of exactly the right size, we stop searching altogether @@ -310,7 +310,7 @@ impl SimpleMemoryManagement { let handle_slice = SliceHandle::new(); let storage = StorageHandle { - id: chunk.storage.id.clone(), + id: chunk.storage.id, utilization: StorageUtilization::Slice { offset: 0, size }, }; @@ -389,7 +389,7 @@ mod tests { impl SimpleMemoryManagement { fn reserve_no_sync(&mut self, size: usize) -> SimpleHandle { - self.reserve(size, || {}) + self.reserve(size, &[]) } } diff --git a/crates/cubecl-runtime/src/storage/base.rs b/crates/cubecl-runtime/src/storage/base.rs index db7e95be..968aadbc 100644 --- a/crates/cubecl-runtime/src/storage/base.rs +++ b/crates/cubecl-runtime/src/storage/base.rs @@ -58,7 +58,4 @@ pub trait ComputeStorage: Send { /// Deallocates the memory pointed by the given storage id. 
fn dealloc(&mut self, id: StorageId); - - /// Copy - fn copy(&mut self, from: &StorageHandle, to: &StorageHandle); } diff --git a/crates/cubecl-runtime/src/storage/bytes_cpu.rs b/crates/cubecl-runtime/src/storage/bytes_cpu.rs index 3b6c493c..18bd9bce 100644 --- a/crates/cubecl-runtime/src/storage/bytes_cpu.rs +++ b/crates/cubecl-runtime/src/storage/bytes_cpu.rs @@ -57,7 +57,7 @@ impl ComputeStorage for BytesStorage { type Resource = BytesResource; fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let allocated_bytes = self.memory.get_mut(&handle.id).unwrap(); + let allocated_bytes = self.memory.get(&handle.id).unwrap(); BytesResource { ptr: allocated_bytes.ptr, @@ -68,7 +68,7 @@ impl ComputeStorage for BytesStorage { fn alloc(&mut self, size: usize) -> StorageHandle { let id = StorageId::new(); let handle = StorageHandle { - id: id.clone(), + id, utilization: StorageUtilization::Full(size), }; @@ -90,23 +90,6 @@ impl ComputeStorage for BytesStorage { } } } - - fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) { - assert_eq!(from.size(), to.size()); - - let input = self.get(from); - let output = self.get(to); - - for i in 0..from.size() { - let offset = i + from.offset(); - let ptr_out = output.ptr.wrapping_add(offset); - - let offset = i + to.offset(); - let ptr_in = input.ptr.wrapping_add(offset); - - unsafe { *ptr_in = *ptr_out } - } - } } #[cfg(test)] @@ -127,7 +110,7 @@ mod tests { let mut storage = BytesStorage::default(); let handle_1 = storage.alloc(64); let handle_2 = StorageHandle::new( - handle_1.id.clone(), + handle_1.id, StorageUtilization::Slice { offset: 24, size: 8, diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs index 2f5ade31..8c6ee178 100644 --- a/crates/cubecl-runtime/tests/dummy/server.rs +++ b/crates/cubecl-runtime/tests/dummy/server.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use cubecl_common::{reader::reader_from_concrete, sync_type::SyncType}; +use cubecl_runtime::storage::ComputeStorage; use cubecl_runtime::{ - memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, + memory_management::{simple::SimpleMemoryManagement, MemoryManagement}, server::{Binding, ComputeServer, Handle}, storage::{BytesResource, BytesStorage}, ExecutionMode, @@ -29,17 +30,19 @@ where type FeatureSet = (); fn read(&mut self, binding: Binding) -> cubecl_common::reader::Reader { - let bytes = self.memory_management.get(binding.memory); + let bytes_handle = self.memory_management.get(binding.memory); + let bytes = self.memory_management.storage().get(&bytes_handle); reader_from_concrete(bytes.read().to_vec()) } fn get_resource(&mut self, binding: Binding) -> BytesResource { - self.memory_management.get(binding.memory) + let handle = self.memory_management.get(binding.memory); + self.memory_management.storage().get(&handle) } fn create(&mut self, data: &[u8]) -> Handle { - let handle = self.memory_management.reserve(data.len(), || {}); - let resource = self.memory_management.get(handle.clone().binding()); + let handle = self.empty(data.len()); + let resource = self.get_resource(handle.clone().binding()); let bytes = resource.write(); @@ -47,11 +50,11 @@ where bytes[i] = *val; } - Handle::new(handle) + handle } fn empty(&mut self, size: usize) -> Handle { - Handle::new(self.memory_management.reserve(size, || {})) + Handle::new(self.memory_management.reserve(size, &[])) } unsafe fn execute( @@ -63,7 +66,7 @@ where ) { let mut resources = bindings .into_iter() - .map(|binding| 
self.memory_management.get(binding.memory)) + .map(|binding| self.get_resource(binding)) .collect::>(); kernel.compute(&mut resources); diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 4d908d99..28146ed6 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -1,25 +1,18 @@ -use std::num::NonZeroU64; +use std::num::NonZero; use super::WgpuStorage; use alloc::{borrow::Cow, sync::Arc}; use cubecl_common::{reader::Reader, sync_type::SyncType}; -use cubecl_core::{compute::DebugInformation, prelude::*, FeatureSet, KernelId}; +use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, FeatureSet, KernelId}; use cubecl_runtime::{ debug::DebugLogger, memory_management::MemoryManagement, server::{self, ComputeServer}, + storage::{ComputeStorage, StorageId}, ExecutionMode, }; use hashbrown::HashMap; -use wgpu::{ - util::{BufferInitDescriptor, DeviceExt, StagingBelt}, - BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, -}; - -// Allocations with existing data smaller than this can use a staging belt -// which speeds up the allocation. A higher number here will catch more -// allocations, but can also increase memory usage. -const SMALL_ALLOC_SIZE: usize = 512; +use wgpu::{CommandEncoder, ComputePass, ComputePipeline, ShaderModuleDescriptor}; /// Wgpu compute server. #[derive(Debug)] @@ -28,13 +21,20 @@ pub struct WgpuServer> { device: Arc, queue: Arc, encoder: CommandEncoder, - staging_belt: StagingBelt, + current_pass: Option>, + tasks_count: usize, + compute_storage_used: Vec, pipelines: HashMap>, tasks_max: usize, - tasks_count: usize, logger: DebugLogger, } +fn create_encoder(device: &wgpu::Device) -> CommandEncoder { + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("CubeCL Command Encoder"), + }) +} + impl WgpuServer where MM: MemoryManagement, @@ -46,59 +46,20 @@ where queue: Arc, tasks_max: usize, ) -> Self { - let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Command Encoder"), - }); - Self { memory_management, - device, - queue, - encoder, - staging_belt: StagingBelt::new(SMALL_ALLOC_SIZE as u64), + device: device.clone(), + queue: queue.clone(), + encoder: create_encoder(&device), + current_pass: None, + tasks_count: 0, + compute_storage_used: Vec::new(), pipelines: HashMap::new(), tasks_max, - tasks_count: 0, logger: DebugLogger::new(), } } - fn register_compute( - &mut self, - pipeline: Arc, - bind_group: BindGroup, - count: CubeCount, - ) { - // First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this - // needs to be longer than the compute pass, so we can't do this just before dispatching. 
- let dispatch_resource = match count.clone() { - CubeCount::Dynamic(binding) => Some(self.memory_management.get(binding.memory)), - _ => None, - }; - - let mut compute = self - .encoder - .begin_compute_pass(&wgpu::ComputePassDescriptor { - label: None, - timestamp_writes: None, - }); - - compute.set_pipeline(&pipeline); - compute.set_bind_group(0, &bind_group, &[]); - - match count { - CubeCount::Static(x, y, z) => { - compute.dispatch_workgroups(x, y, z); - } - CubeCount::Dynamic(_) => { - let resource = dispatch_resource.as_ref().unwrap(); - compute.dispatch_workgroups_indirect(&resource.buffer, resource.offset()); - } - } - - self.tasks_count += 1; - } - fn pipeline( &mut self, kernel: ::Kernel, @@ -152,72 +113,74 @@ where ) } - fn create_read_buffer(&mut self, handle: server::Binding) -> wgpu::Buffer { - let resource = self.memory_management.get(handle.memory); + fn clear_compute_pass(&mut self) { + self.current_pass = None; + } +} + +impl ComputeServer for WgpuServer +where + MM: MemoryManagement, +{ + type Kernel = Box; + type DispatchOptions = CubeCount; + type Storage = WgpuStorage; + type MemoryManagement = MM; + type FeatureSet = FeatureSet; + + fn read(&mut self, binding: server::Binding) -> Reader { + let resource = self.get_resource(binding); let size = resource.size(); - let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { + let read_buffer = self.device.create_buffer(&wgpu::BufferDescriptor { label: None, size, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, }); + self.clear_compute_pass(); + self.encoder.copy_buffer_to_buffer( &resource.buffer, resource.offset(), - &buffer_dest, + &read_buffer, 0, size, ); - self.tasks_count += 1; + // Flush all commands to the queue, so GPU gets started on copying to the staging buffer. self.sync(SyncType::Flush); - buffer_dest - } -} -impl ComputeServer for WgpuServer -where - MM: MemoryManagement, -{ - type Kernel = Box; - type DispatchOptions = CubeCount; - type Storage = WgpuStorage; - type MemoryManagement = MM; - type FeatureSet = FeatureSet; + let (sender, receiver) = async_channel::bounded(1); + let slice = read_buffer.slice(..); + slice.map_async(wgpu::MapMode::Read, move |v| { + sender + .try_send(v) + .expect("Unable to send buffer slice result to async channel."); + }); - fn read(&mut self, binding: server::Binding) -> Reader { let device = self.device.clone(); - let buffer = self.create_read_buffer(binding); Box::pin(async move { - let buffer_slice = buffer.slice(..); - let (sender, receiver) = async_channel::bounded(1); - - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .try_send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - + // Now wait for the GPU to finish. 
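// Note on the wait below: `map_async` only schedules the mapping, and wgpu delivers
// its callback while the device is being polled, so `poll(Maintain::Wait)` is what
// completes the mapping and lets the channel receive a value before it is awaited.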
device.poll(wgpu::Maintain::Wait); - let result = receiver + let slice = read_buffer.slice(..); + + receiver .recv() .await - .expect("Unable to receive buffer slice result."); + .expect("Unable to receive buffer slice result.") + .expect("Failed to map buffer"); - if let Ok(()) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); + let data = slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); - drop(data); - buffer.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) - } + drop(data); + read_buffer.unmap(); + + result }) } @@ -225,7 +188,8 @@ where &mut self, binding: server::Binding, ) -> ::Resource { - self.memory_management.get(binding.memory) + let handle = self.memory_management.get(binding.memory); + self.memory_management.storage().get(&handle) } /// When we create a new handle from existing data, we use custom allocations so that we don't @@ -234,68 +198,28 @@ where /// This is important, otherwise the compute passes are going to be too small and we won't be able to /// fully utilize the GPU. fn create(&mut self, data: &[u8]) -> server::Handle { - let handle = server::Handle::new(self.memory_management.reserve(data.len(), || { - flush_tasks( - &mut self.encoder, - &self.queue, - &self.device, - &mut self.tasks_count, - &mut self.staging_belt, - ); - self.device.poll(wgpu::Maintain::Wait); - })); - - let non_zero_len = NonZeroU64::new(data.len() as u64); - - // If there's nothing to copy, don't need to do any work here. - if let Some(len) = non_zero_len { - let binding = handle.clone().binding(); - let resource = self.memory_management.get(binding.memory); - - if data.len() < SMALL_ALLOC_SIZE { - // Use a staging belt if the allocation is small enough. This is faster than allocating a new buffer. - // Ideally, we could use queue.write_buffer_with(), which seems to be the recommended method for performance, - // but that doesn't seem to work, as we might re-use a buffer multiple times, and need to schedule this - // precisely in the encoder. - let mut write_buf = self.staging_belt.write_buffer( - &mut self.encoder, - &resource.buffer, - resource.offset(), - len, - &self.device, - ); - write_buf.copy_from_slice(data); - } else { - let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { - label: Some("Buffer Src"), - contents: data, - usage: wgpu::BufferUsages::COPY_SRC, - })); - self.encoder.copy_buffer_to_buffer( - &buffer_src, - 0, - &resource.buffer, - resource.offset(), - buffer_src.size(), - ); - } - self.tasks_count += 1; + // Reserve memory on some storage we haven't yet used this command queue. + let memory = self + .memory_management + .reserve(data.len(), &self.compute_storage_used); + + let handle = Handle::new(memory); + + if let Some(len) = NonZero::new(data.len() as u64) { + let resource = self.get_resource(handle.clone().binding()); + + // Write to the staging buffer. Next queue submission this will copy the data to the GPU. 
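// Note on `write_buffer_with` below: it returns a staging view of exactly `len`
// bytes that wgpu copies into `resource.buffer` at the start of the next
// `queue.submit`, so no intermediate COPY_SRC buffer or blocking sync is needed;
// it returns `None` when the write cannot be validated (e.g. it would overrun the
// buffer), hence the `expect`. Since the staged copy lands before the commands
// still being recorded on the current encoder, reserving only from storage this
// command stream has not touched yet (`compute_storage_used`) keeps an upload from
// clobbering data a pending kernel still reads.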
+ self.queue + .write_buffer_with(&resource.buffer, resource.offset(), len) + .expect("Failed to write to staging buffer.") + .copy_from_slice(data); } handle } fn empty(&mut self, size: usize) -> server::Handle { - server::Handle::new(self.memory_management.reserve(size, || { - flush_tasks( - &mut self.encoder, - &self.queue, - &self.device, - &mut self.tasks_count, - &mut self.staging_belt, - ); - self.device.poll(wgpu::Maintain::Wait); - })) + server::Handle::new(self.memory_management.reserve(size, &[])) } unsafe fn execute( @@ -308,27 +232,67 @@ where let pipeline = self.pipeline(kernel, mode); let group_layout = pipeline.get_bind_group_layout(0); - let memory_handles = bindings - .into_iter() - .map(|binding| self.memory_management.get(binding.memory)) - .collect::>(); + // Store all the resources we'll be using. This could be eliminated if + // there was a way to tie the lifetime of the resource to the memory handle. + let resources: Vec<_> = bindings + .iter() + .map(|binding| { + let resource_handle = self.memory_management.get(binding.memory.clone()); + // Keep track of the storage we've used so far. + self.compute_storage_used.push(resource_handle.id); + + self.memory_management.storage().get(&resource_handle) + }) + .collect(); - let entries = memory_handles + let entries = &resources .iter() .enumerate() - .map(|(i, buffer)| wgpu::BindGroupEntry { + .map(|(i, r)| wgpu::BindGroupEntry { binding: i as u32, - resource: buffer.as_binding(), + resource: r.as_binding(), }) .collect::>(); let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { label: None, layout: &group_layout, - entries: &entries, + entries, }); - self.register_compute(pipeline, bind_group, count); + // First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this + // needs to be longer than the compute pass, so we can't do this just before dispatching. + let dispatch_resource = match count.clone() { + CubeCount::Dynamic(binding) => Some(self.get_resource(binding)), + _ => None, + }; + + self.tasks_count += 1; + + // Start a new compute pass if needed. The forget_lifetime allows + // to store this with a 'static lifetime, but the compute pass must + // be dropped before the encoder. This isn't unsafe - it's still checked at runtime. + let pass = self.current_pass.get_or_insert_with(|| { + self.encoder + .begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }) + .forget_lifetime() + }); + + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, &bind_group, &[]); + + match count { + CubeCount::Static(x, y, z) => { + pass.dispatch_workgroups(x, y, z); + } + CubeCount::Dynamic(_) => { + let resource = dispatch_resource.as_ref().unwrap(); + pass.dispatch_workgroups_indirect(&resource.buffer, resource.offset()); + } + } if self.tasks_count >= self.tasks_max { self.sync(SyncType::Flush); @@ -336,41 +300,20 @@ where } fn sync(&mut self, sync_type: SyncType) { - flush_tasks( - &mut self.encoder, - &self.queue, - &self.device, - &mut self.tasks_count, - &mut self.staging_belt, - ); + // End the current compute pass. + self.clear_compute_pass(); + let new_encoder = create_encoder(&self.device); + let encoder = std::mem::replace(&mut self.encoder, new_encoder); + self.queue.submit([encoder.finish()]); - // Cleanup allocations and deallocations. 
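// A minimal sketch of the pattern `execute` and `sync` rely on, assuming wgpu 22 or
// later (`ComputePass::forget_lifetime`); `PassCache` and its methods are
// illustrative names, not part of this patch. One compute pass is kept open across
// dispatches, and it must be dropped before the encoder is finished and submitted;
// wgpu still checks that ordering at runtime even though the 'static lifetime no
// longer expresses it.
struct PassCache {
    encoder: wgpu::CommandEncoder,
    pass: Option<wgpu::ComputePass<'static>>,
}

impl PassCache {
    /// Reuse the open pass if there is one, otherwise begin a new one on the encoder.
    fn pass(&mut self) -> &mut wgpu::ComputePass<'static> {
        self.pass.get_or_insert_with(|| {
            self.encoder
                .begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: None,
                    timestamp_writes: None,
                })
                .forget_lifetime()
        })
    }

    /// Drop the pass, swap in a fresh encoder and submit the finished one.
    fn flush(&mut self, device: &wgpu::Device, queue: &wgpu::Queue) {
        self.pass = None;
        let finished = std::mem::replace(
            &mut self.encoder,
            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }),
        );
        queue.submit([finished.finish()]);
    }
}
// Typical use mirrors the server above: record dispatches on `cache.pass()`, then
// call `cache.flush(&device, &queue)` once a batch is full or a read-back needs the
// submitted work.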
- self.memory_management.storage().perform_deallocations(); + self.tasks_count = 0; + self.compute_storage_used.clear(); if sync_type == SyncType::Wait { self.device.poll(wgpu::Maintain::Wait); } - } -} -/// Flush tasks using the [command encoder](CommandEncoder). -/// -/// This implementation is decoupled from both the [server](WgpuServer) and [memory management](MemoryManagement). -/// This decoupling allows for safe usage within sync callbacks when allocating memory buffers. -fn flush_tasks( - encoder: &mut CommandEncoder, - queue: &wgpu::Queue, - device: &wgpu::Device, - tasks_count: &mut usize, - staging_belt: &mut StagingBelt, -) { - staging_belt.finish(); - - let mut new_encoder = - device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - core::mem::swap(&mut new_encoder, encoder); - - queue.submit(Some(new_encoder.finish())); - *tasks_count = 0; - staging_belt.recall(); + // Cleanup allocations and deallocations. + self.memory_management.storage().perform_deallocations(); + } } diff --git a/crates/cubecl-wgpu/src/compute/storage.rs b/crates/cubecl-wgpu/src/compute/storage.rs index d3974efc..ac73bd1e 100644 --- a/crates/cubecl-wgpu/src/compute/storage.rs +++ b/crates/cubecl-wgpu/src/compute/storage.rs @@ -7,7 +7,6 @@ pub struct WgpuStorage { memory: HashMap>, deallocations: Vec, device: Arc, - queue: Arc, } impl core::fmt::Debug for WgpuStorage { @@ -68,12 +67,11 @@ pub enum WgpuResourceKind { /// Keeps actual wgpu buffer references in a hashmap with ids as key. impl WgpuStorage { /// Create a new storage on the given [device](wgpu::Device). - pub fn new(device: Arc, queue: Arc) -> Self { + pub fn new(device: Arc) -> Self { Self { memory: HashMap::new(), deallocations: Vec::new(), device, - queue, } } @@ -116,7 +114,7 @@ impl ComputeStorage for WgpuStorage { mapped_at_creation: false, })); - self.memory.insert(id.clone(), buffer); + self.memory.insert(id, buffer); StorageHandle::new(id, StorageUtilization::Full(size)) } @@ -124,23 +122,4 @@ impl ComputeStorage for WgpuStorage { fn dealloc(&mut self, id: StorageId) { self.deallocations.push(id); } - - fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) { - let mut encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - - let from = self.get(from); - let to = self.get(to); - - encoder.copy_buffer_to_buffer( - &from.buffer, - from.offset(), - &to.buffer, - to.offset(), - to.size(), - ); - - self.queue.submit(Some(encoder.finish())); - } } diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 2b2c7db2..df935910 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -117,7 +117,7 @@ fn create_client( MutexComputeChannel>>, > { let limits = device_wgpu.limits(); - let storage = WgpuStorage::new(device_wgpu.clone(), queue.clone()); + let storage = WgpuStorage::new(device_wgpu.clone()); let memory_management = DynamicMemoryManagement::new( storage, DynamicMemoryManagementOptions::preset(