refactor: change tvm wrapper symbol names (#141)
This PR unifies the naming convention of the TVM wrapper symbols: all globals are now registered under the "flashinfer." prefix instead of "paged_kv_cache.".

@MasterJH5574
yzh119 authored Feb 27, 2024
1 parent 3d55c71 commit ae68085
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/tvm_wrapper.cu
@@ -729,23 +729,22 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in
 })});
 }
 
-// TODO(Zihao): Unify the symbol names
-TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_prefill")
+TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache")
     .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCache);
 
-TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_prefill_begin_forward")
+TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward")
     .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCacheBeginForward);
 
-TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_prefill_end_forward")
+TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward")
     .set_body_typed(_FlashInferAttentionPrefillWithPagedKVCacheEndForward);
 
-TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_decode")
+TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache")
     .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCache);
 
-TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_decode_begin_forward")
+TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward")
     .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCacheBeginForward);
 
-TVM_REGISTER_GLOBAL("paged_kv_cache.attention_kernel_decode_end_forward")
+TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward")
     .set_body_typed(_FlashInferAttentionDecodeWithPagedKVCacheEndForward);
 
 TVM_REGISTER_GLOBAL("flashinfer.attention_kernel_prefill_with_ragged_kv_cache")
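For downstream callers, the effect of this commit is that the wrapper kernels must be looked up under the new "flashinfer." names. Below is a minimal caller-side sketch, not part of this commit, assuming the standard TVM runtime headers are available and this wrapper is linked into the process:

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <iostream>

int main() {
  // "paged_kv_cache.attention_kernel_prefill" no longer resolves after this
  // commit; the kernel is registered under the unified "flashinfer." prefix.
  const tvm::runtime::PackedFunc* prefill = tvm::runtime::Registry::Get(
      "flashinfer.attention_kernel_prefill_with_paged_kv_cache");
  if (prefill == nullptr) {
    std::cerr << "global not registered; link the FlashInfer TVM wrapper"
              << std::endl;
    return 1;
  }
  // (*prefill)(...) would invoke the kernel; its DLTensor arguments are
  // omitted from this sketch.
  return 0;
}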
