cuda_default_header = """
#include <cuda_runtime.h>
#include <math.h>
#include <mma.h>
"""

cuda_fp16_header = """
#include <cuda_fp16.h>
__device__ half max(half a, half b)
{
  return __hgt(__half(a), __half(b)) ? a : b;
}
__device__ half min(half a, half b)
{
  return __hlt(__half(a), __half(b)) ? a : b;
}
#define __int8_t_defined
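// Emulate half-precision math functions that cuda_fp16.h does not provide:
// upcast the operands to fp32, call the fp32 function, and convert back.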
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
  inline __device__ half HALF_MATH_NAME(half x, half y) { \
    float tmp_x = __half2float(x); \
    float tmp_y = __half2float(y); \
    float result = FP32_MATH_NAME(tmp_x, tmp_y); \
    return __float2half(result); \
  }
#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
  inline __device__ half HALF_MATH_NAME(half x) { \
    float tmp_x = __half2float(x); \
    float result = FP32_MATH_NAME(tmp_x); \
    return __float2half(result); \
  }
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
// Pack two half values.
inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
// There is no make_int8 in CUDA, but the TVM codegen seems to use it.
inline __device__ longlong4 make_int8(int x0, int x1, int x2, int x3, int x4, int x5, int x6, int x7) {
  int2 i0 = make_int2(x0, x1);
  int2 i1 = make_int2(x2, x3);
  int2 i2 = make_int2(x4, x5);
  int2 i3 = make_int2(x6, x7);
  long long l0 = *(long long*)&i0;
  long long l1 = *(long long*)&i1;
  long long l2 = *(long long*)&i2;
  long long l3 = *(long long*)&i3;
  return make_longlong4(l0, l1, l2, l3);
}
"""

rocm_default_header = """
#include <hip/hip_runtime.h>
#include <math.h>
"""

rocm_fp16_header = """
#define half _Float16
#define __float2half_rn(x) half(x)
#define htanh tanhf
#define htan tanf
#define hatan atanf
#define herf erff
#include <hip/hcc_detail/hip_fp16_math_fwd.h>
#define hpow __ocml_pown_f16
#define hsqrt __ocml_sqrt_f16
#define hexp __ocml_exp_f16
// Pack two half values.
inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}
// There is no make_int8 in HIP either, but the TVM codegen seems to use it.
inline __device__ longlong4 make_int8(int x0, int x1, int x2, int x3, int x4, int x5, int x6, int x7) {
  int2 i0 = make_int2(x0, x1);
  int2 i1 = make_int2(x2, x3);
  int2 i2 = make_int2(x4, x5);
  int2 i3 = make_int2(x6, x7);
  long long l0 = *(long long*)&i0;
  long long l1 = *(long long*)&i1;
  long long l2 = *(long long*)&i2;
  long long l3 = *(long long*)&i3;
  return make_longlong4(l0, l1, l2, l3);
}
"""

cutlass_header = """
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_sm70.h"
namespace cutlass {
namespace gemm {
namespace warp {
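// Warp-level GEMM wrapper around cutlass::gemm::warp::MmaTensorOp: it
// double-buffers the A/B operand fragments and pipelines shared-memory loads
// with 16x8x16 Tensor Core mma instructions over the K dimension of the tile.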
template <
  typename Shape,
  typename SMemLayoutA,
  typename SMemLayoutB
>
class GemmTensorOp {
public:
  using InstructionShape = GemmShape<16, 8, 16>;
  using WarpShape = GemmShape<Shape::kM, Shape::kN, InstructionShape::kK>;
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
    cutlass::arch::Mma<
      InstructionShape,
      32,
      cutlass::half_t,
      cutlass::layout::RowMajor,
      cutlass::half_t,
      cutlass::layout::ColumnMajor,
      cutlass::half_t,
      cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd
    >,
    cutlass::MatrixShape<1, 1>
  >;
  using MmaWarp = typename cutlass::gemm::warp::MmaTensorOp<
    WarpShape,
    cutlass::half_t,
    SMemLayoutA,
    cutlass::half_t,
    SMemLayoutB,
    cutlass::half_t,
    cutlass::layout::RowMajor,
    Policy
  >;
  static_assert(Shape::kK % InstructionShape::kK == 0);
  static int const kKgroups = Shape::kK / InstructionShape::kK;
  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
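  // Double-buffered operand fragments: load tile k+1 while computing on tile k.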
  typename MmaWarp::FragmentA frag_A[2];
  typename MmaWarp::FragmentB frag_B[2];
  typename MmaWarp::FragmentC accum;
  MmaWarp mma_op;
  typename MmaWarp::IteratorA iter_A;
  typename MmaWarp::IteratorB iter_B;
  const int warp_idx_m_, warp_idx_n_, lane_id_;
public:
  CUTLASS_DEVICE
  GemmTensorOp(int warp_idx_m, int warp_idx_n, int lane_id)
    : warp_idx_m_(warp_idx_m), warp_idx_n_(warp_idx_n), lane_id_(lane_id), iter_A({nullptr, 0}, 0), iter_B({nullptr, 0}, 0) {
    accum.clear();
  }
  CUTLASS_DEVICE
  void prologue(const TensorRefA &ref_A, const TensorRefB &ref_B) {
    iter_A = typename MmaWarp::IteratorA(ref_A, lane_id_);
    iter_B = typename MmaWarp::IteratorB(ref_B, lane_id_);
    iter_A.add_tile_offset({warp_idx_m_, 0});
    iter_B.add_tile_offset({0, warp_idx_n_});
    iter_A.load(frag_A[0]);
    iter_B.load(frag_B[0]);
    ++iter_A;
    ++iter_B;
  }
  CUTLASS_DEVICE
  void body() {
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups - 1; ++k) {
      iter_A.load(frag_A[(k + 1) % 2]);
      iter_B.load(frag_B[(k + 1) % 2]);
      ++iter_A;
      ++iter_B;
      mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
    }
    __syncthreads();
  }
  CUTLASS_DEVICE
  void epilogue() {
    mma_op(accum, frag_A[(kKgroups - 1) % 2], frag_B[(kKgroups - 1) % 2], accum);
  }
  CUTLASS_DEVICE
  half& operator[](const size_t i) const {
    return ((half*)accum.data())[i];
  }
  CUTLASS_DEVICE
  half* operator+(const size_t i) const {
    return (half*)accum.data() + i;
  }
};
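// Volta (sm70) variant of the same pipeline, built on MmaVoltaTensorOp with
// the 16x16x4 HMMA instruction shape and explicitly parameterized layouts.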
template <
  typename Shape,
  typename SMemLayoutA,
  typename LayoutA,
  typename SMemLayoutB,
  typename LayoutB,
  typename LayoutC
>
class VoltaGemmTensorOp {
public:
  using InstructionShape = GemmShape<16, 16, 4>;
  using WarpShape = GemmShape<Shape::kM, Shape::kN, InstructionShape::kK>;
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
    cutlass::arch::Mma<
      InstructionShape,
      32,
      cutlass::half_t,
      LayoutA,
      cutlass::half_t,
      LayoutB,
      cutlass::half_t,
      LayoutC,
      cutlass::arch::OpMultiplyAdd
    >,
    cutlass::MatrixShape<1, 1>
  >;
  using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
    WarpShape,
    cutlass::half_t,
    SMemLayoutA,
    cutlass::half_t,
    SMemLayoutB,
    cutlass::half_t,
    LayoutC,
    Policy
  >;
  static_assert(Shape::kK % InstructionShape::kK == 0);
  static int const kKgroups = Shape::kK / InstructionShape::kK;
  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
  typename MmaWarp::FragmentA frag_A[2];
  typename MmaWarp::FragmentB frag_B[2];
  typename MmaWarp::FragmentC accum;
  typename MmaWarp::IteratorA iter_A;
  typename MmaWarp::IteratorB iter_B;
  const int warp_idx_m_, warp_idx_n_, lane_id_;
  MmaWarp mma_op;
public:
  CUTLASS_DEVICE
  VoltaGemmTensorOp(int warp_idx_m, int warp_idx_n, int lane_id)
    : warp_idx_m_(warp_idx_m), warp_idx_n_(warp_idx_n), lane_id_(lane_id), iter_A({nullptr, 0}, 0), iter_B({nullptr, 0}, 0) {
    accum.clear();
  }
  CUTLASS_DEVICE
  void prologue(TensorRefA ref_A, TensorRefB ref_B) {
    iter_A = typename MmaWarp::IteratorA(ref_A, lane_id_);
    iter_B = typename MmaWarp::IteratorB(ref_B, lane_id_);
    iter_A.add_tile_offset({warp_idx_m_, 0});
    iter_B.add_tile_offset({0, warp_idx_n_});
    iter_A.load(frag_A[0]);
    iter_B.load(frag_B[0]);
    ++iter_A;
    ++iter_B;
  }
  CUTLASS_DEVICE
  void body() {
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups - 1; ++k) {
      iter_A.load(frag_A[(k + 1) % 2]);
      iter_B.load(frag_B[(k + 1) % 2]);
      ++iter_A;
      ++iter_B;
      mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
    }
    __syncthreads();
  }
  CUTLASS_DEVICE
  void epilogue() {
    mma_op(accum, frag_A[(kKgroups - 1) % 2], frag_B[(kKgroups - 1) % 2], accum);
  }
  CUTLASS_DEVICE
  half& operator[](size_t i) const {
    return ((half*)accum.data())[i];
  }
  CUTLASS_DEVICE
  half* operator+(size_t i) const {
    return (half*)accum.data() + i;
  }
};
}  // namespace warp
}  // namespace gemm
}  // namespace cutlass
"""