diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 5b6fbe7b2546..7aefe78155a3 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -75,6 +75,12 @@
 # Contrib initializers
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
 
+# Relay and Relax contain modules that are only available in the compiler package.
+# Do not import them if TVM is built with runtime only.
+if not _RUNTIME_ONLY:
+    from . import relay
+    from . import relax
+
 if not _RUNTIME_ONLY and support.libinfo().get("USE_MICRO", "OFF") == "ON":
     from . import micro
 
diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py
index e4efcc51b311..75932c0d5e34 100644
--- a/python/tvm/contrib/debugger/debug_executor.py
+++ b/python/tvm/contrib/debugger/debug_executor.py
@@ -76,12 +76,12 @@ def create(graph_json_str, libmod, device, dump_root=None):
     # Automatically set params if they can be extracted from the libmod
     try:
         params = libmod["get_graph_params"]()
+        if isinstance(params, tvm.ir.container.Map):
+            gmod.set_input(**params)
     except (AttributeError, tvm.error.RPCError):
         # Params can not be extracted from the libmod and must be set somewhere else manually
         # Do not set params during RPC communication
         pass
-    else:
-        gmod.set_input(**params)
 
     return gmod
 
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index 5bf96aef775e..7b186d7098b9 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -128,11 +128,15 @@ def __init__(self, *args, **kwargs):
             self._inst._outer = weakref.ref(self)
 
         def __getattr__(self, name):
-            # fall back to instance attribute if there is not any
-            # return self._inst.__getattribute__(name)
             import inspect  # pylint: disable=import-outside-toplevel
 
-            result = self._inst.__getattribute__(name)
+            try:
+                # fall back to instance attribute if there is not any
+                # return self._inst.__getattribute__(name)
+                result = self._inst.__getattribute__(name)
+            except AttributeError:
+                result = super(TVMDerivedObject, self).__getattr__(name)
+
             if inspect.ismethod(result):
 
                 def method(*args, **kwargs):
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 366784c04fc0..f90df9941766 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2143,6 +2143,15 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
   Array<IterSumExpr> rewrite = res->indices;
 
+  if (rewrite.empty() && !is_one(input_pred) && check_level != IterMapLevel::Bijective) {
+    // The input predicate may cause DetectIterMap to fail,
+    // but we can still detect the iter map without the input predicate,
+    // in which case the resulting iter map is valid and can be used for simplification.
+    rewrite = DetectIterMap(indices, input_iters, const_true(), check_level, ana,
+                            /*simplify_trivial_iterators=*/simplify_trivial_iterators)
+                  ->indices;
+  }
+
   if (rewrite.empty()) {
     return indices;
   }
diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc
index 50c8b8f5254e..e0996cd72fc2 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -318,7 +318,17 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) {
     for (size_t i = 0; i < op->match_buffers.size(); i++) {
       auto buf = op->match_buffers[i]->buffer;
       auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer");
+      auto buffer_strides_path = buffer_path->Attr("strides");
       context.push_back(WithDef(buf->data, buffer_path->Attr("data")));
+      // Define the buffer's elem_offset and strides if they are vars
+      if (const auto* v = buf->elem_offset.as<VarNode>()) {
+        context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset")));
+      }
+      for (size_t i = 0; i < buf->strides.size(); ++i) {
+        if (const auto* v = buf->strides[i].as<VarNode>()) {
+          context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i)));
+        }
+      }
       context.push_back(WithDef(buf, buffer_path));
     }
   }
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 4c0f422c0d3a..6da2f873b728 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -320,13 +320,16 @@ class BuiltinLower : public StmtExprMutator {
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->op.same_as(builtin::tvm_call_packed())) {
       return MakeCallPackedGeneric(op, 0, builtin::tvm_call_packed_lowered(),
-                                   /* use_string_lookup */ true);
+                                   /* use_string_lookup */ true,
+                                   /* pass_last_arg_as_traced_value */ false);
     } else if (op->op.same_as(builtin::tvm_call_cpacked())) {
       return MakeCallPackedGeneric(op, 0, builtin::tvm_call_cpacked_lowered(),
-                                   /* use_string_lookup */ false);
+                                   /* use_string_lookup */ false,
+                                   /* pass_last_arg_as_traced_value */ false);
     } else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
       return MakeCallPackedGeneric(op, 0, builtin::tvm_call_trace_packed_lowered(),
-                                   /* use_string_lookup */ true);
+                                   /* use_string_lookup */ true,
+                                   /* pass_last_arg_as_traced_value */ true);
     } else if (op->op.same_as(builtin::anylist_setitem_call_packed())) {
       return MakeAnyListSetItemCallPacked(op, builtin::tvm_call_packed_lowered(), true);
     } else if (op->op.same_as(builtin::anylist_setitem_call_cpacked())) {
@@ -510,7 +513,7 @@ class BuiltinLower : public StmtExprMutator {
     PrimExpr list_handle = op->args[0];
     PrimExpr list_index = op->args[1];
 
-    Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup);
+    Call call = MakeCallPackedGeneric(op, 2, lowered_op, use_string_lookup, false);
     PrimExpr value_stack = call->args[1];
     PrimExpr tcode_stack = call->args[2];
     // The stack offset of return value stack_end
@@ -528,9 +531,10 @@ class BuiltinLower : public StmtExprMutator {
    * \param name_offset The beginning of function name and call packed section.
    * \param lowered_packed_op The target lowered op.
    * \param use_string_lookup Whether to lookup function by string.
+   * \param pass_last_arg_as_traced_value Whether to pass the last argument as the traced value.
    */
   Call MakeCallPackedGeneric(const CallNode* op, size_t name_offset, const Op& lowered_packed_op,
-                             bool use_string_lookup) {
+                             bool use_string_lookup, bool pass_last_arg_as_traced_value) {
     auto& scope = alloca_scope_.back();
     auto& prep_seq = prep_seq_stack_.back();
 
@@ -571,6 +575,7 @@ class BuiltinLower : public StmtExprMutator {
                                    ConstInt32(arg_stack_begin + num_args)};
     // cpacked call resource_handle
     if (!use_string_lookup) {
+      ICHECK(!pass_last_arg_as_traced_value);
       PrimExpr last_arg = op->args[args_end];
       const VarNode* var_node = last_arg.as<VarNode>();
       if (var_node != nullptr) {
@@ -579,6 +584,10 @@ class BuiltinLower : public StmtExprMutator {
       } else {
        packed_args.push_back(last_arg);
       }
+    } else if (pass_last_arg_as_traced_value) {
+      // Pass the last argument through as the traced value;
+      // used by tvm_call_trace_packed.
+      packed_args.push_back(op->args[op->args.size() - 1]);
     }
     return Call(op->dtype, lowered_packed_op, packed_args);
   }
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 3b5cc0001975..5fb7526b217b 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -1069,6 +1069,8 @@ def check_cuda(n, lanes):
     check_cuda(64, 2)
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
 def test_cuda_thread_sync_inside_condition():
     @T.prim_func
     def func1(A: T.Buffer((4, 4), "float32")) -> None:
diff --git a/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py b/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py
index 441aa39c7278..68f26bd3ee6c 100644
--- a/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py
+++ b/tests/python/meta_schedule/test_meta_schedule_mma_m16n8k8_auto_tensorization.py
@@ -696,10 +696,6 @@ class TVM_ALIGNED(2) half {
   return (v1 << 16) | v0;
 }
 
-// Some fp16 math functions are not supported in cuda_fp16.h,
-// so we define them here to make sure the generated CUDA code
-// is valid.
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
 #define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
   static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \
     float tmp_x = __half2float(x); \
@@ -715,16 +711,31 @@ class TVM_ALIGNED(2) half {
     return __float2half(result); \
   }
 
+// Some fp16 math functions are not supported in cuda_fp16.h,
+// so we define them here to make sure the generated CUDA code
+// is valid.
+#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 530) CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) +#else +CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp) +#endif +#endif #undef CUDA_UNSUPPORTED_HALF_MATH_BINARY #undef CUDA_UNSUPPORTED_HALF_MATH_UNARY - -#endif +__forceinline__ __device__ unsigned int +cast_smem_ptr_to_int(const void* const smem_ptr) +{ + unsigned int smem_int; + asm volatile ("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }" + : "=r"(smem_int) : "l"(smem_ptr)); + return smem_int; +} #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ (__CUDACC_VER_MAJOR__ > 11)) @@ -746,6 +757,7 @@ class TVM_ALIGNED(2) half { #define int64_t long long #define uint64_t unsigned long long #endif +extern "C" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ C, half* __restrict__ X, half* __restrict__ Y); extern "C" __global__ void __launch_bounds__(128) main_kernel(half* __restrict__ C, half* __restrict__ X, half* __restrict__ Y) { extern __shared__ uchar buf_dyn_shmem[]; uint1 C_reindex_m16n8k8_matrixC[64]; @@ -761,12 +773,7 @@ class TVM_ALIGNED(2) half { for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_fused_0 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16)) + 24576))) - ); + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + ((((ax0_ax1_fused_0 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16))); __asm__ __volatile__( #if TVM_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" @@ -780,12 +787,7 @@ class TVM_ALIGNED(2) half { for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 4; ++ax0_ax1_fused_0_1) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)(buf_dyn_shmem + ((((ax0_ax1_fused_0_1 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((int)threadIdx.x) & 7) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)))) - ); + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((ax0_ax1_fused_0_1 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + (((((int)threadIdx.x) & 15) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)) + 24576)); __asm__ __volatile__( #if TVM_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" @@ -801,12 +803,7 @@ class TVM_ALIGNED(2) half { for (int ax0_ax1_fused_0_2 = 0; ax0_ax1_fused_0_2 < 4; ++ax0_ax1_fused_0_2) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_fused_0_2 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16)) + 32768))) - ); + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((ax0_ax1_fused_0_2 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 
2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16)) + 8192)); __asm__ __volatile__( #if TVM_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" @@ -820,12 +817,7 @@ class TVM_ALIGNED(2) half { for (int ax0_ax1_fused_0_3 = 0; ax0_ax1_fused_0_3 < 4; ++ax0_ax1_fused_0_3) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_fused_0_3 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((int)threadIdx.x) & 7) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)) + 8192))) - ); + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((ax0_ax1_fused_0_3 * 2048) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + (((((int)threadIdx.x) & 15) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)) + 32768)); __asm__ __volatile__( #if TVM_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" @@ -844,12 +836,7 @@ class TVM_ALIGNED(2) half { for (int ax0_0 = 0; ax0_0 < 2; ++ax0_0) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[12288])) + (((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -861,12 +848,7 @@ class TVM_ALIGNED(2) half { for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[0])) + ((((((int)threadIdx.x) & 7) * 128) + ((((int)threadIdx.y) & 1) * 64)) + ((((ax1_0 * 4) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((int)threadIdx.x) & 7) * 128) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 12288)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -880,12 +862,7 @@ class TVM_ALIGNED(2) half { for (int ax0_ax1_fused_0_4 = 0; ax0_ax1_fused_0_4 < 4; ++ax0_ax1_fused_0_4) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)(buf_dyn_shmem + ((((((((ax2_0_0 + 2) % 3) * 8192) + (ax0_ax1_fused_0_4 * 2048)) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16)) + 24576))) - ); + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + (((((((ax2_0_0 + 2) % 3) * 8192) + (ax0_ax1_fused_0_4 * 2048)) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 16))); __asm__ __volatile__( #if TVM_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" @@ -899,12 +876,7 @@ class TVM_ALIGNED(2) half { for (int ax0_ax1_fused_0_5 = 0; ax0_ax1_fused_0_5 < 4; ++ax0_ax1_fused_0_5) { { - 
unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)(buf_dyn_shmem + (((((((ax2_0_0 + 2) % 3) * 8192) + (ax0_ax1_fused_0_5 * 2048)) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((int)threadIdx.x) & 7) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)))) - ); + unsigned int addr = cast_smem_ptr_to_int(buf_dyn_shmem + ((((((((ax2_0_0 + 2) % 3) * 8192) + (ax0_ax1_fused_0_5 * 2048)) + (((int)threadIdx.y) * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + (((((int)threadIdx.x) & 15) ^ ((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 4))) * 16)) + 24576)); __asm__ __volatile__( #if TVM_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" @@ -924,12 +896,7 @@ class TVM_ALIGNED(2) half { for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[(((ax2_0_0 % 3) * 4096) + 12288)])) + (((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_1 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((ax2_0_0 % 3) * 4096) + ((((int)threadIdx.y) >> 1) * 2048)) + (ax0_0_1 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -941,12 +908,7 @@ class TVM_ALIGNED(2) half { for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[((ax2_0_0 % 3) * 4096)])) + (((((ax2_0_1 * 1024) + ((((int)threadIdx.x) & 7) * 128)) + ((((int)threadIdx.y) & 1) * 64)) + ((((ax1_0_1 * 4) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 1024))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((ax2_0_0 % 3) * 4096) + (ax2_0_1 * 1024)) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_1 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 13312)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -971,12 +933,7 @@ class TVM_ALIGNED(2) half { for (int ax0_0_2 = 0; ax0_0_2 < 2; ++ax0_0_2) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[((((ax2_0_0 + 1) % 3) * 4096) + 12288)])) + (((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_2 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[(((((((ax2_0_0 + 1) % 3) * 4096) + ((((int)threadIdx.y) >> 1) * 2048)) + (ax0_0_2 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -988,12 +945,7 @@ class TVM_ALIGNED(2) half { for (int ax1_0_2 = 0; ax1_0_2 < 2; ++ax1_0_2) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void 
*)((&(((half*)buf_dyn_shmem)[(((ax2_0_0 + 1) % 3) * 4096)])) + ((((((int)threadIdx.x) & 7) * 128) + ((((int)threadIdx.y) & 1) * 64)) + ((((ax1_0_2 * 4) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((ax2_0_0 + 1) % 3) * 4096) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_2 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 12288)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -1022,12 +974,7 @@ class TVM_ALIGNED(2) half { for (int ax0_0_3 = 0; ax0_0_3 < 2; ++ax0_0_3) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[12288])) + (((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_3 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1_1 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[(((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_3 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1_1 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -1039,12 +986,7 @@ class TVM_ALIGNED(2) half { for (int ax1_0_3 = 0; ax1_0_3 < 2; ++ax1_0_3) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[0])) + (((((ax2_0_1_1 * 1024) + ((((int)threadIdx.x) & 7) * 128)) + ((((int)threadIdx.y) & 1) * 64)) + ((((ax1_0_3 * 4) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 1024))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((ax2_0_1_1 * 1024) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_3 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 13312)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -1069,12 +1011,7 @@ class TVM_ALIGNED(2) half { for (int ax0_0_5 = 0; ax0_0_5 < 2; ++ax0_0_5) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[16384])) + (((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_5 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_5 * 1024)) + (((int)threadIdx.x) * 32)) + ((0 ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 4096)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -1086,12 +1023,7 @@ class TVM_ALIGNED(2) half { for (int ax1_0_5 = 0; ax1_0_5 < 2; ++ax1_0_5) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[4096])) + ((((((int)threadIdx.x) & 7) * 128) + ((((int)threadIdx.y) & 1) * 64)) + ((((ax1_0_5 * 4) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((int)threadIdx.x) & 7) * 128) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_5 * 4)) + 
(((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 16384)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -1116,12 +1048,7 @@ class TVM_ALIGNED(2) half { for (int ax0_0_6 = 0; ax0_0_6 < 2; ++ax0_0_6) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[16384])) + (((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_6 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1_2 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((((((int)threadIdx.y) >> 1) * 2048) + (ax0_0_6 * 1024)) + (((int)threadIdx.x) * 32)) + (((ax2_0_1_2 + 1) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 4096)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -1133,12 +1060,7 @@ class TVM_ALIGNED(2) half { for (int ax1_0_6 = 0; ax1_0_6 < 2; ++ax1_0_6) { { - unsigned int addr; - __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(((half*)buf_dyn_shmem)[4096])) + (((((ax2_0_1_2 * 1024) + ((((int)threadIdx.x) & 7) * 128)) + ((((int)threadIdx.y) & 1) * 64)) + ((((ax1_0_6 * 4) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 1024))) - ); + unsigned int addr = cast_smem_ptr_to_int((&(((half*)buf_dyn_shmem)[((((ax2_0_1_2 * 1024) + ((((int)threadIdx.x) & 7) * 128)) + ((((((((int)threadIdx.y) & 1) * 8) + (ax1_0_6 * 4)) + (((int)threadIdx.x) >> 3)) ^ (((int)threadIdx.x) & 7)) * 8)) + 17408)])) + 0); __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];\n" @@ -1175,11 +1097,11 @@ class TVM_ALIGNED(2) half { for (int ax0_0_7 = 0; ax0_0_7 < 8; ++ax0_0_7) { __syncthreads(); for (int ax1_0_7 = 0; ax1_0_7 < 8; ++ax1_0_7) { - *(uint1*)(((half*)buf_dyn_shmem) + ((((((int)threadIdx.y) * 512) + (ax1_0_7 * 64)) + (((int)threadIdx.x) * 2)) + 12288)) = C_reindex_m16n8k8_matrixC[((((ax0_0_7 >> 1) * 16) + (ax1_0_7 * 2)) + (ax0_0_7 & 1))]; + *(uint1*)(((half*)buf_dyn_shmem) + ((((((int)threadIdx.x) * 2050) + (((int)threadIdx.y) * 512)) + (ax1_0_7 * 64)) + 12288)) = C_reindex_m16n8k8_matrixC[((((ax0_0_7 >> 1) * 16) + (ax1_0_7 * 2)) + (ax0_0_7 & 1))]; } __syncthreads(); - for (int ax0_0_2_ax1_0_2_fused_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 = 0; ax0_0_2_ax1_0_2_fused_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 < 16; ++ax0_0_2_ax1_0_2_fused_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0) { - C[(((((((((((((int)blockIdx.y) >> 3) * 524288) + ((ax0_0_2_ax1_0_2_fused_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 >> 3) * 262144)) + (ax0_0_7 * 32768)) + ((((int)threadIdx.y) & 1) * 16384)) + ((((int)threadIdx.x) >> 3) * 4096)) + (((int)blockIdx.x) * 1024)) + ((((int)blockIdx.y) & 7) * 128)) + ((ax0_0_2_ax1_0_2_fused_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 & 7) * 16)) + ((((int)threadIdx.y) >> 1) * 8)) + (((int)threadIdx.x) & 7))] = ((half*)buf_dyn_shmem)[((((ax0_0_2_ax1_0_2_fused_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 128) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)) + 12288)]; + for (int threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 = 0; threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 < 512; ++threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0) { + 
C[(((((((((((((int)blockIdx.y) >> 3) * 524288) + (((threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 & 15) >> 3) * 262144)) + (ax0_0_7 * 32768)) + ((((int)threadIdx.y) & 1) * 16384)) + ((((int)threadIdx.x) >> 3) * 4096)) + (((int)blockIdx.x) * 1024)) + ((((int)blockIdx.y) & 7) * 128)) + ((threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 & 7) * 16)) + ((((int)threadIdx.y) >> 1) * 8)) + (((int)threadIdx.x) & 7))] = ((half*)buf_dyn_shmem)[((((threadIdx_x_cache_threadIdx_y_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 128) + (((int)threadIdx.y) * 32)) + ((int)threadIdx.x)) + 12288)]; } } } @@ -1229,7 +1151,7 @@ def test_mma_script_after_build(): with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): rt_mod = tvm.build(sch.mod, target="cuda") - + print(rt_mod.imported_modules[0].get_source()) assert rt_mod.imported_modules[0].get_source() == expected_cuda_script diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py index a8ce704bd0ce..8cc1c7c7aa44 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py @@ -296,7 +296,7 @@ def main( W_shared[j, k_o * 4 : k_o * 4 + 4], ) T.writes(compute_local[i, j]) - T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) + T.block_attr({"meta_schedule.auto_tensorize": "dp4a_s8s8s32"}) with T.init(): with T.block("compute_init"): T.reads() diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py index d1f4b6bdce7c..034bddd97132 100644 --- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py @@ -900,128 +900,130 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 def test_conv_1x1(): # fmt: off @T.prim_func - def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")) -> None: + def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")): T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) - conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="shared") - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="wmma.accumulator") + # with T.block("root"): + conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="wmma.accumulator") PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared") weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared") PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a") weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b") - for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"): - for ax2_0_1_ax3_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): - for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 1): - for 
ax0_ax1_fused in range(1024): + for ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused in T.thread_binding(4, thread="blockIdx.y"): + for ax0_1_ax1_1_ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): + for ax0_2_ax1_2_ax2_0_2_ax3_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax4_0_0 in range(1): + for ax0_ax1_fused in range(8192): with T.block("PadInput_reindex_shared"): - v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) + v0 = T.axis.spatial(256, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 128 + ax0_ax1_fused // 64) v1 = T.axis.spatial(64, ax0_ax1_fused % 64) T.reads(inputs[0, v0 // 16, v0 % 16, v1]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1] for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32) - v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) + v3 = T.axis.spatial(64, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4}) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 8}) weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] - for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): - for ax0_0_1, ax1_0_1 in T.grid(1, 4): + for ax4_0_1 in range(1): + for ax0_0, ax1_0 in T.grid(8, 4): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1) - v1_o = T.axis.spatial(4, ax1_0_1) + v0_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax0_0) + v1_o = T.axis.spatial(4, ax1_0) T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) - for ax0_1_1, ax1_1_1 in T.grid(16, 16): + for ax0_1, ax1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 1): + for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 2): with T.block("weight_reindex_shared_wmma.matrix_b_o"): - v0, v1, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0]) - v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0) - T.reads(weight_reindex_shared[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + v0_o, v1_o, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0]) + v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 
2 + ax3_0) + T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) - T.reads(weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] - for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 4, 1, 1): + T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] + for ax0_3, ax1_3, ax2_0_3, ax3_0_3, ax4_0_2, ax0_4, ax1_4, ax2_0_4, ax3_0_4 in T.grid(1, 1, 8, 2, 4, 1, 1, 1, 1): with T.block("conv2d_nhwc_o"): - v0 = T.axis.reduce(1, ax0_0 + ax0_1 + ax0_2) - v1 = T.axis.reduce(1, ax1_0 + ax1_1 + ax1_2) - v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3 + ax2_0_4) - v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3 + ax3_0_4) + v0_o = T.axis.spatial(1, ax0_3 + ax0_4) + v1_o = T.axis.spatial(1, ax1_3 + ax1_4) + v2_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax2_0_3 + ax2_0_4) + v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0_3 + ax3_0_4) v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 0:16, 0:16]) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, 0:16, 0:16]) T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init] = T.float32(0) for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i], 
PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i]) T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) - for ax2 in range(1): - for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_1, ax3 in T.grid(1, 1): + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + for ax2 in range(8): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) - v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax0_ax1_fused) - v2, v3 = T.axis.remap("SS", [ax2_1, ax3]) + v0_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2) + v1_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2) + v2_o = T.axis.spatial(8, ax2 + ax2_1) + v3_o = T.axis.spatial(2, ax3) v4_o = T.axis.spatial(1, 0) v5_o = T.axis.spatial(1, 0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) - T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) for ax4, ax5 in T.grid(16, 16): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) - T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) - conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] for ax0_ax1_ax3_ax4_ax5_fused in range(512): with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) - v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + 
ax0_ax1_ax3_ax4_ax5_fused // 256) - v2 = T.axis.spatial(1, ax2) - v3 = T.axis.spatial(1, 0) + v0 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2) + v1 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2) + v2 = T.axis.spatial(8, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256) v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) - T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) - T.block_attr({"meta_schedule.cooperative_fetch": 2}) - conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] + T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ - ("SamplePerfectTile", [1, 1, 1]), - ("SamplePerfectTile", [1, 1, 1]), - ("SamplePerfectTile", [8, 2, 1, 1, 1]), - ("SamplePerfectTile", [2, 1, 2, 1, 1]), + ("SamplePerfectTile", [1, 1, 1, 1, 1]), + ("SamplePerfectTile", [1, 1, 1, 1, 1]), + ("SamplePerfectTile", [2, 1, 1, 8, 1]), + ("SamplePerfectTile", [2, 1, 1, 2, 1]), ("SamplePerfectTile", [1, 1, 4]), - ("SampleCategorical", 1), ("SampleCategorical", 0), - ("SampleCategorical", 2), + ("SampleCategorical", 1), + ("SampleCategorical", 3), ] mod = te.create_prim_func( diff --git a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py index 8562205753d3..2bc8717b38c4 100644 --- a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py @@ -632,51 +632,31 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " @tvm.script.ir_module class Conv2dInt8_tensorcore_scheduled: @T.prim_func - def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "uint8")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_s0 = T.int32() - A_s0_1 = T.int32() - A_s0_2 = T.int32() - A_s0_3 = T.int32() - A_s1 = T.int32() - A_s1_1 = T.int32() - A_s1_2 = T.int32() - A_s1_3 = T.int32() - B_s0 = T.int32() - B_s1 = T.int32() - C_s0 = T.int32() - C_s0_1 = T.int32() - C_s0_2 = T.int32() - C_s0_3 = T.int32() - C_s0_4 = T.int32() - C_s1 = T.int32() - C_s1_1 = T.int32() - C_s1_2 = T.int32() - C_s1_3 = T.int32() - C_s1_4 = T.int32() - # body - # with T.block("root") - conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", scope="shared") - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator") - pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared") - p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared") - pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", 
scope="wmma.matrix_a") - p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b") - for ax2_0_0_ax3_0_0_fused in T.thread_binding(3136, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step":512, "pragma_unroll_explicit":1}): + def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "uint8")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + conv2d_nhwc_reindex_shared = T.alloc_buffer((50176, 256), "int32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") + pad_temp_reindex_shared = T.alloc_buffer((50176, 64), "int8", scope="shared") + p1_reindex_shared = T.alloc_buffer((1, 1, 256, 64), "int8", scope="shared") + pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer((50176, 64), "int8", scope="wmma.matrix_a") + p1_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 256, 64), "int8", scope="wmma.matrix_b") + for ax2_0_0_ax3_0_0_fused in T.thread_binding(3136, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="vthread.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(16, thread="threadIdx.x"): - for ax2_0_3_init, ax3_0_3_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1): + for ax0_0_init, ax1_0_init, ax0_1_init, ax1_1_init, ax2_0_3_init, ax3_0_3_init, ax0_2_init, ax1_2_init, ax2_0_4_init, ax3_0_4_init in T.grid(1, 1, 1, 1, 1, 1, 1, 1, 1, 1): with T.block("conv2d_nhwc_o_init"): + v0_o = T.axis.spatial(1, ax0_0_init + ax0_1_init + ax0_2_init) + v1_o = T.axis.spatial(1, ax1_0_init + ax1_1_init + ax1_2_init) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3_init + ax2_0_4_init) v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3_init + ax3_0_4_init) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) - C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0, C_s1], scope="wmma.accumulator", offset_factor=16) - T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.float32(0), dtype="handle") + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) + C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) + T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) for ax0_0, ax1_0, ax4_0_0 in 
T.grid(1, 1, 2): - for ax0_ax1_fused_0 in T.serial(16): + for ax0_ax1_fused_0 in range(16): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax0_ax1_fused_2 in T.vectorized(16): with T.block("pad_temp_reindex_shared"): @@ -684,9 +664,9 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v1 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2) % 32) T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) T.writes(pad_temp_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]]}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 16]]}) pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] - for ax0_ax1_ax2_ax3_fused_0 in T.serial(8): + for ax0_ax1_ax2_ax3_fused_0 in range(8): for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(8): with T.block("p1_reindex_shared"): @@ -696,51 +676,51 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v3 = T.axis.spatial(64, ax4_0_0 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 8 + ax0_ax1_ax2_ax3_fused_2) % 32) T.reads(p1[v2, v0, v1, v3]) T.writes(p1_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]]}) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 16]]}) p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): for ax0_0_1, ax1_0_1 in T.grid(1, 2): with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1) v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0_1) - T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0, A_s1], scope="shared", offset_factor=16) - C_1 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int8", strides=[C_s0_1, C_s1_1], scope="wmma.matrix_a", offset_factor=16) - T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // C_s0_1 // 16 * (C_s0_1 // 16) + C_1.elem_offset % C_s0_1 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A.data, A.elem_offset, A_s0 * 16, 1, dtype="handle"), A_s0, "row_major", dtype="handle") + T.reads(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + A = T.match_buffer(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared", offset_factor=16) + C = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 1, 2): with T.block("p1_reindex_shared_wmma.matrix_b_o"): - v0, v1 = T.axis.remap("SS", 
[ax0, ax1]) + v0_o, v1_o = T.axis.remap("SS", [ax0, ax1]) v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax2_0) v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax3_0) - T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - A_1 = T.match_buffer(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0_1, A_s1_1], scope="shared", offset_factor=16) - C_2 = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int8", strides=[C_s0_2, C_s1_2], scope="wmma.matrix_b", offset_factor=16) - T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // C_s0_2 // 16 * (C_s0_2 // 16) + C_2.elem_offset % C_s0_2 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int8"), A_1.data, A_1.elem_offset, A_s0_1 * 16, 1, dtype="handle"), A_s0_1, "col_major", dtype="handle") + T.reads(p1_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + A = T.match_buffer(p1_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared", offset_factor=16) + C = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 2, 1, 1): with T.block("conv2d_nhwc_o_update"): - v0 = T.axis.reduce(1, ax0_0 + ax0_1 + ax0_2) - v1 = T.axis.reduce(1, ax1_0 + ax1_1 + ax1_2) + v0_o = T.axis.spatial(1, ax0_0 + ax0_1 + ax0_2) + v1_o = T.axis.spatial(1, ax1_0 + ax1_1 + ax1_2) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax3_0_3 + ax3_0_4) v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) - A_2 = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[A_s0_2, A_s1_2], scope="wmma.matrix_a", offset_factor=16) - B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], [16, 16], dtype="int8", strides=[B_s0, B_s1], scope="wmma.matrix_b", offset_factor=16) - C_3 = 
T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0_3, C_s1_3], scope="wmma.accumulator", offset_factor=16) - T.tvm_mma_sync(C_3.data, C_3.elem_offset // C_s0_3 // 16 * (C_s0_3 // 16) + C_3.elem_offset % C_s0_3 // 16, A_2.data, A_2.elem_offset // A_s0_2 // 16 * (A_s0_2 // 16) + A_2.elem_offset % A_s0_2 // 16, B.data, B.elem_offset // B_s0 // 16 * (B_s0 // 16) + B.elem_offset % B_s0 // 16, C_3.data, C_3.elem_offset // C_s0_3 // 16 * (C_s0_3 // 16) + C_3.elem_offset % C_s0_3 // 16, dtype="handle") + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) + A = T.match_buffer(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) + B = T.match_buffer(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) + C = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) + T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) for ax0_0, ax1_0 in T.grid(1, 1): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 8 * 8 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0) v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 8 * 2 + ax2_0_2_ax3_0_2_fused % 2 + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - A_3 = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[A_s0_3, A_s1_3], scope="wmma.accumulator", offset_factor=16) - C_4 = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="int32", strides=[C_s0_4, C_s1_4], scope="shared", offset_factor=16) - T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // A_s0_3 // 16 * (A_s0_3 // 16) + A_3.elem_offset % A_s0_3 // 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_4.data, C_4.elem_offset, C_s0_4 * 16, 2, dtype="handle"), C_s0_4, "row_major", dtype="handle") + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 
16:v1_o * 16 + 16]) + A = T.match_buffer(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "int32", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) + C = T.match_buffer(conv2d_nhwc_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared", offset_factor=16) + T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") for ax0, ax1_0 in T.grid(128, 2): for ax1_1 in T.thread_binding(16, thread="threadIdx.x"): with T.block("conv2d_nhwc_reindex_shared"): @@ -748,8 +728,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 8 * 32 + ax1_0 * 16 + ax1_1) T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[0, 0, 0, v1], p5[0, 0, 0, v1], p6[0, 0, 0, v1], p8[0], p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) T.writes(compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) - compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(p7[()] + T.cast(T.shift_right(T.cast(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], "int64") * p4[0, 0, 0, v1] + p5[0, 0, 0, v1], p6[0, 0, 0, v1], dtype="int64"), "int32"), 255), 0), "uint8"), "int32") - p8[0], 1098990753, 31, 1, dtype="int32") + p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 255), 0), "uint8"), T.uint8(255)), T.uint8(0)) - + compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.Cast("uint8", T.max(T.min(T.q_multiply_shift(T.Cast("int32", T.Cast("uint8", T.max(T.min(p7[()] + T.Cast("int32", T.shift_right(T.Cast("int64", conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1]) * p4[0, 0, 0, v1] + p5[0, 0, 0, v1], p6[0, 0, 0, v1])), 255), 0))) - p8[0], 1098990753, 31, 1) + p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 255), 0)), T.uint8(255)), T.uint8(0)) @tvm.script.ir_module class Conv2dInt8_NCHWc: @@ -1698,33 +1677,31 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " @tvm.script.ir_module class Conv2dInt8_with_predicate_scheduled: @T.prim_func - def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer(256, "int32"), p5: T.Buffer(256, "int32"), p6: T.Buffer(256, "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body + def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((256,), "int32"), p5: T.Buffer((256,), "int32"), p6: T.Buffer((256,), "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): + T.func_attr({"tir.noalias": T.bool(True)}) with T.block("root"): T.reads() T.writes() - T.block_attr({"meta_schedule.unroll_explicit":1024}) - conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256], dtype="int32", 
scope="shared") - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([50176, 256], dtype="int32", scope="wmma.accumulator") - pad_temp_reindex_shared = T.alloc_buffer([50176, 64], dtype="int8", scope="shared") - p1_reindex_shared = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="shared") - pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer([50176, 64], dtype="int8", scope="wmma.matrix_a") - p1_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 256, 64], dtype="int8", scope="wmma.matrix_b") + T.block_attr({"meta_schedule.unroll_explicit": 1024}) + conv2d_nhwc_reindex_shared = T.alloc_buffer((50176, 256), "int32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") + pad_temp_reindex_shared = T.alloc_buffer((50176, 64), "int8", scope="shared") + p1_reindex_shared = T.alloc_buffer((1, 1, 256, 64), "int8", scope="shared") + pad_temp_reindex_shared_wmma_matrix_a = T.alloc_buffer((50176, 64), "int8", scope="wmma.matrix_a") + p1_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 256, 64), "int8", scope="wmma.matrix_b") for ax2_0_0_ax3_0_0_fused in T.thread_binding(32, thread="blockIdx.y"): for ax2_0_1_ax3_0_1_fused in T.thread_binding(196, thread="blockIdx.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 2): - for ax0_ax1_fused in T.serial(1024): + for ax0_ax1_fused in range(1024): with T.block("pad_temp_reindex_shared"): v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32) T.reads(p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1]) T.writes(pad_temp_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 16]], "meta_schedule.cooperative_fetch":4}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 16]], "meta_schedule.cooperative_fetch": 4}) pad_temp_reindex_shared[v0, v1] = p0[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] - for ax0_ax1_ax2_ax3_fused in T.serial(2048): + for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("p1_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) @@ -1732,16 +1709,16 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " v3 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(p1[v2, v0, v1, v3]) T.writes(p1_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 16]], "meta_schedule.cooperative_fetch":3}) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 16]], "meta_schedule.cooperative_fetch": 3}) p1_reindex_shared[v0, v1, v2, v3] = p1[v2, v0, v1, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 2): for ax0_0_1, ax1_0_1 in T.grid(1, 1): with T.block("pad_temp_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0_1) v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax1_0_1) - T.reads(pad_temp_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"}) + T.reads(pad_temp_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": 
"wmma_load_16x16x16_s8_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): with T.block("pad_temp_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1_1, ax1_1_1]) @@ -1750,28 +1727,28 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = pad_temp_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 1): with T.block("p1_reindex_shared_wmma.matrix_b_o"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v0_o, v1_o = T.axis.remap("SS", [ax0, ax1]) v2_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax2_0) v3_o = T.axis.spatial(4, ax4_0_0 * 2 + ax4_0_1 + ax3_0) - T.reads(p1_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"}) + T.reads(p1_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_s8_b_trans_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("p1_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) - T.reads(p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0, v1, v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.reads(p1_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + T.writes(p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]) + p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = p1_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] for ax2_0_3, ax3_0_3, ax0_2, ax1_2, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(1, 1, 1, 1, 1, 1, 2): with T.block("conv2d_nhwc_o"): - v0 = T.axis.reduce(1, ax0_0 + ax0_1 + ax0_2) - v1 = T.axis.reduce(1, ax1_0 + ax1_1 + ax1_2) + v0_o = T.axis.spatial(1, ax0_0 + ax0_1 + ax0_2) + v1_o = T.axis.spatial(1, ax1_0 + ax1_1 + ax1_2) v2_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax2_0_3 + ax2_0_4) v3_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax3_0_3 * 2 + ax3_0_4) v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 + ax4_0_2) - T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 : v3_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "warp_execution":1}) + T.reads(pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16:v3_o * 16 + 16, v4_o * 16:v4_o * 16 + 16]) + 
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_s8s8s32_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_s32", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 32, "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): @@ -1782,17 +1759,17 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0, v1, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.Cast("int32", pad_temp_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("int32", p1_reindex_shared_wmma_matrix_b[v0_o, v1_o, v3_o * 16 + v3_i, v4_o * 16 + v4_i]) for ax0_0, ax1_0 in T.grid(1, 2): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): v0_o = T.axis.spatial(3136, ax2_0_0_ax3_0_0_fused // 4 * 392 + ax2_0_1_ax3_0_1_fused * 2 + ax2_0_2_ax3_0_2_fused // 2 + ax0_0) v1_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused % 4 * 4 + ax2_0_2_ax3_0_2_fused % 2 * 2 + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_s32_shared"}) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_s32_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -1801,13 +1778,12 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] for ax0, ax1_0, ax1_1, ax1_2, ax1_3 in T.grid(32, 1, 4, 32, 2): with T.block("conv2d_nhwc_reindex_shared"): - T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) 
* 2 + ax1_3 < 64)
 v0 = T.axis.spatial(50176, ax2_0_0_ax3_0_0_fused // 4 * 6272 + ax2_0_1_ax3_0_1_fused * 32 + ax0)
 v1 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused % 4 * 64 + (ax1_0 * 256 + ax1_1 * 64 + ax1_2 * 2 + ax1_3))
+ T.where(((ax1_0 * 4 + ax1_1) * 32 + ax1_2) * 2 + ax1_3 < 64)
 T.reads(p7[()], conv2d_nhwc_reindex_shared[v0, v1], p2[0, 0, 0, v1], p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], p8[0], p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
 T.writes(compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1])
- compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, False, True, dtype="int32"), 255), 0) - p8[0], 1457846997, 31, 0, dtype="int32") + T.q_multiply_shift(p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 2101000910, 31, 0, dtype="int32"), 255), 0)
-
+ compute[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1] = T.max(T.min(T.q_multiply_shift(T.max(T.min(p7[()] + T.q_multiply_shift_per_axis(conv2d_nhwc_reindex_shared[v0, v1] - p2[0, 0, 0, v1] + p3[0, 0, 0, v1], p4[v1], p5[v1], p6[v1], 31, T.bool(False), T.bool(True)), 255), 0) - p8[0], 1457846997, 31, 0) + T.q_multiply_shift(p9[v0 // 3136, v0 % 3136 // 56, v0 % 56, v1], 2101000910, 31, 0), 255), 0)
 # fmt: on
 def verify(anchor_mod, anchor_trace_fun, target_mod, target, ref):
diff --git a/tests/python/runtime/test_runtime_trace.py b/tests/python/runtime/test_runtime_trace.py
index b1a7a5338fb5..08f56b56c8c7 100644
--- a/tests/python/runtime/test_runtime_trace.py
+++ b/tests/python/runtime/test_runtime_trace.py
@@ -212,10 +212,4 @@ def check_assign(dtype):
 if __name__ == "__main__":
- test_trace_expr_assign()
- test_trace_expr_sum_generated()
- test_trace_expr_sum_custom()
- test_trace_expr_sum_args()
- test_trace_default_action()
- test_trace_can_change_traced_value_int()
- test_trace_can_change_traced_value_float()
+ tvm.testing.main()
diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py
index 326ef2b8ce56..ade414f4234f 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -244,8 +244,8 @@ def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
 C = T.match_buffer(c, (128, 128), elem_offset=off3)
 # body
 with T.block("C"):
- T.reads([A[0:128, 0:128], B[0:128, 0:128]])
- T.writes([C[0:128, 0:128]])
+ T.reads()
+ T.writes()
 T.evaluate(
 T.tvm_call_packed(
 "tvm.contrib.cblas.matmul",
@@ -778,8 +778,8 @@ def tir_extern(var_A: T.handle, var_B: T.handle, var_P: T.handle, var_C: T.handl
 P = T.match_buffer(var_P, [1], dtype="float32", offset_factor=1)
 C = T.match_buffer(var_C, [128, 128], dtype="float32", offset_factor=1)
 with T.block("C"):
- T.reads(A[0:128, 0:128], B[0:128, 0:128], P[0])
- T.writes(C[0:128, 0:128])
+ T.reads()
+ T.writes()
 T.call_extern("myfunc", A.data, B.data, C.data, P[0], dtype="")
 _check_workload(te_extern, tir_extern)
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
index e839f44b3306..8c153afc9de9 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
@@ -162,8 +162,7 @@ def func(A: T.Buffer([256], "float32")):
 with T.launch_thread("threadIdx.x", 256) as threadIdx_x:
 A[threadIdx_x] = A[threadIdx_x] + 2.0
- with pytest.raises(ValueError):
- tvm.tir.analysis.verify_well_formed(func)
+ tvm.tir.analysis.verify_well_formed(func)
 def test_reuse_of_env_thread_across_functions_is_ill_formed():
diff --git a/tests/python/tir-base/test_debug_info.py b/tests/python/tir-base/test_debug_info.py
index 16a589a2ddcc..8bd22f1bb6bd 100644
--- a/tests/python/tir-base/test_debug_info.py
+++ b/tests/python/tir-base/test_debug_info.py
@@ -155,16 +155,14 @@ def test_llvm_ir_debug_accuracy():
 locations = find_di_locations(source)
 # Find the 'assert' from MyModule
- debug_dir_match = re.search(
- r"tail call void %0\(i8\* getelementptr inbounds .* !dbg !(\d+)\n", source
- )
+ debug_dir_match = re.search(r"tail call void %0\(.* !dbg !(\d+)\n", source)
 # Extract out the debug directive line
 directive_idx = debug_dir_match.groups()[0]
 # Check that it matches the expected line number (in main.tir)
 debug_line_no = int(locations[directive_idx])
- assert debug_line_no == 43
+ assert debug_line_no == 56
 if __name__ == "__main__":
diff --git a/tests/python/tir-schedule/test_tir_schedule_rfactor.py b/tests/python/tir-schedule/test_tir_schedule_rfactor.py
index 9856c082045e..37e68fa21a0e 100644
--- a/tests/python/tir-schedule/test_tir_schedule_rfactor.py
+++ b/tests/python/tir-schedule/test_tir_schedule_rfactor.py
@@ -1582,6 +1582,7 @@ def test_reduction_rfactor_argmax_body_bufferstore_value_not_var():
 s.rfactor(ki, 1)
+@pytest.mark.xfail(reason="The input IR is not well-formed")
 def test_reduction_rfactor_argmax_body_bufferstore_value_unbound_var():
 s = tir.Schedule(argmax_split_body_bufferstore_value_unbound_var, debug_mask="all")
 argmax = s.get_block("argmax")
diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh
index de1a1a6d384b..5b07b5256ea5 100755
--- a/tests/scripts/task_python_unittest.sh
+++ b/tests/scripts/task_python_unittest.sh
@@ -54,7 +54,7 @@ TEST_FILES=(
 "usmp"
 )
-for TEST_FILE in ${TEST_FILES}; do
+for TEST_FILE in ${TEST_FILES[@]}; do
 run_pytest ctypes ${TEST_FILE}-0, tests/python/${TEST_FILE}
 run_pytest cython ${TEST_FILE}-1, tests/python/${TEST_FILE}
 done
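
Why the shell-script hunk above matters: in Bash, referring to an array without a subscript, as in ${TEST_FILES}, expands to only the first element, so the old loop ran the tests for a single directory; ${TEST_FILES[@]} expands to every element. A minimal sketch of the difference, using placeholder array contents rather than the real test list:

TEST_FILES=("alpha" "beta" "gamma")

# Old form: expands to the first element only, so this prints just "alpha"
for TEST_FILE in ${TEST_FILES}; do
    echo "${TEST_FILE}"
done

# Fixed form: expands to all elements, so this prints "alpha", "beta", "gamma"
for TEST_FILE in ${TEST_FILES[@]}; do
    echo "${TEST_FILE}"
done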