diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 5efae0f0..d96f33a7 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -729,23 +729,22 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in })}); } -// TODO(Zihao): Unify the symbol names -TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_prefill") +TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache") .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCache); -TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_prefill_begin_forward") +TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward") .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCacheBeginForward); -TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_prefill_end_forward") +TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward") .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCacheEndForward); -TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_decode") +TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache") .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCache); -TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_decode_begin_forward") +TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward") .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCacheBeginForward); -TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_decode_end_forward") +TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward") .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCacheEndForward); TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_ragged_kv_cache")