[Fix] PagedKVCache fetching compute stream when copy stream is needed (#16714)

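This PR fixes an issue in PagedKVCache where the compute stream was
always fetched, regardless of backend. For backends like WebGPU, the
`GetCurrentStream` function is not implemented, so fetching the
compute stream led to an error. The compute stream is now fetched
only on CUDA/ROCm, where the standalone copy stream is also created.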
MasterJH5574 authored Mar 13, 2024
1 parent 8023a98 commit 981009d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -439,12 +439,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       free_page_ids_.push_back(page_id);
     }
 
-    // The compute stream is the default stream.
     // If the device is CUDA/ROCm, we create a standalone copy stream, in
     // purpose to hide the latency of auxiliary stream copy.
-    compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
     if (device.device_type == DLDeviceType::kDLCUDA ||
         device.device_type == DLDeviceType::kDLROCM) {
+      // The compute stream is the default stream.
+      compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
       copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
     }
   }
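
The effect of moving the `GetCurrentStream` call inside the device check can be sketched in isolation. The following is a minimal, self-contained illustration and not TVM code: `DeviceType`, `MockDeviceAPI`, `CacheStreams`, and `InitStreams` are hypothetical stand-ins, and modelling the missing WebGPU implementation as a thrown `std::runtime_error` is an assumption made only for this example.

```cpp
#include <stdexcept>

// Hypothetical stand-ins for illustration; these are NOT TVM's real types.
enum class DeviceType { kCUDA, kROCM, kWebGPU };
using StreamHandle = void*;

// Mock device API. Assumption: a backend without stream support reports
// the missing GetCurrentStream implementation by throwing.
class MockDeviceAPI {
 public:
  explicit MockDeviceAPI(DeviceType type) : type_(type) {}

  DeviceType type() const { return type_; }

  StreamHandle GetCurrentStream() {
    if (type_ == DeviceType::kWebGPU) {
      throw std::runtime_error("GetCurrentStream is not implemented");
    }
    return &default_stream_;
  }

  StreamHandle CreateStream() { return &extra_stream_; }

 private:
  DeviceType type_;
  int default_stream_ = 0;  // placeholder object standing in for a stream
  int extra_stream_ = 0;    // placeholder object standing in for a stream
};

struct CacheStreams {
  StreamHandle compute_stream = nullptr;
  StreamHandle copy_stream = nullptr;
};

// Mirrors the patched logic: both streams are touched only on CUDA/ROCm,
// so other backends never reach GetCurrentStream.
CacheStreams InitStreams(MockDeviceAPI& api) {
  CacheStreams streams;
  if (api.type() == DeviceType::kCUDA || api.type() == DeviceType::kROCM) {
    streams.compute_stream = api.GetCurrentStream();
    streams.copy_stream = api.CreateStream();
  }
  return streams;
}
```

In this sketch, calling `InitStreams` on a `MockDeviceAPI` built with `DeviceType::kWebGPU` returns with both handles left as `nullptr` and never throws, whereas an unguarded `GetCurrentStream` call placed before the `if` would throw for that device.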