Skip to content

Commit

Permalink
[TIR] Add is_vector Method to DataType class and update usages acro…
Browse files Browse the repository at this point in the history
…ss Codebase (#17443)

* Refactor data_type.h and c_runtime_api.h

This commit refactors the `data_type.h` and `c_runtime_api.h` files. It introduces a new function `is_vector()` in the `DataType` class to check if a type is a vector type. Additionally, it adds a new constant `kTVMGridConstant` in the `TVMTypeCode` enum in `c_runtime_api.h`. These changes improve the code organization and provide better support for vector types.

* revert kTVMGridConstant

* lint fix
  • Loading branch information
LeiWang1999 authored Oct 6, 2024
1 parent accd582 commit ff0b07b
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 10 deletions.
2 changes: 2 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class DataType {
bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
/*! \return Whether the type is a scalable vector. */
bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
/*! \return whether type is a vector type. */
bool is_vector() const { return lanes() > 1; }
/*! \return whether type is a bool vector type. */
bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
/*! \return whether type is a Void type. */
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
if (expr.dtype().lanes() == type.lanes()) {
return expr;
} else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
} else if (expr.dtype().lanes() == 1 && type.is_vector()) {
return tvm::tir::Broadcast(expr, type.lanes());
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper(
if (const RampNode* ramp = last_index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
} else if (last_index.dtype().is_vector()) {
if (i == 0) {
cached_vector_index = MakeValue(last_index);
}
Expand Down
8 changes: 4 additions & 4 deletions src/target/llvm/intrin_rule_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {

// Enable QHL library for FP16 data type
const PrimExpr& x = call->args[0];
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
return TVMExternCall(call, tvm_wrapper);
}
#endif
Expand Down Expand Up @@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh")
}

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
return TVMExternCall(call, tvm_wrapper);
}
Expand Down Expand Up @@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
}

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
return TVMExternCall(call, tvm_wrapper);
}
Expand Down Expand Up @@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid")
const tir::Call new_call = tir::Call(call->dtype, call->op, new_args);

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
return TVMExternCall(new_call.get(), tvm_wrapper);
}
Expand Down
8 changes: 4 additions & 4 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand Down Expand Up @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitExpr_(const CastNode* op) {
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand All @@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitExpr_(const BufferLoadNode* op) {
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand All @@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitStmt_(const BufferStoreNode* op) {
if (op->value->dtype.lanes() > 1) {
if (op->value->dtype.is_vector()) {
if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
Expand Down

0 comments on commit ff0b07b

Please sign in to comment.