Commit fcf5f3f

fp16 support

jlxue committed Nov 22, 2023
1 parent 8f50973 commit fcf5f3f
Showing 2 changed files with 308 additions and 2 deletions.
301 changes: 301 additions & 0 deletions test/python/dsl_frontend/common_header.py
@@ -0,0 +1,301 @@
cuda_default_header = """
#include <cuda_runtime.h>
#include <math.h>
#include <mma.h>
"""

cuda_fp16_header = """
#include <cuda_fp16.h>
__device__ half max(half a, half b)
{
return __hgt(__half(a), __half(b)) ? a : b;
}
__device__ half min(half a, half b)
{
return __hlt(__half(a), __half(b)) ? a : b;
}
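// NOTE: CUDA's <math.h> does not provide max/min overloads for half, so the
// helpers above compare with __hgt/__hlt and return the selected operand.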
#define __int8_t_defined
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
inline __device__ half HALF_MATH_NAME(half x, half y) { \
float tmp_x = __half2float(x); \
float tmp_y = __half2float(y); \
float result = FP32_MATH_NAME(tmp_x, tmp_y); \
return __float2half(result); \
}
#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
inline __device__ half HALF_MATH_NAME(half x) { \
float tmp_x = __half2float(x); \
float result = FP32_MATH_NAME(tmp_x); \
return __float2half(result); \
}
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
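// For reference, CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) above expands to:
//   inline __device__ half htanh(half x) {
//     float tmp_x = __half2float(x);
//     float result = tanhf(tmp_x);
//     return __float2half(result);
//   }
// i.e. half operands are widened to float, evaluated with the float math
// function, and narrowed back to half.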
// Pack two half values into one 32-bit word (x in the low 16 bits, y in the high 16 bits).
inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// There is no make_int8 in CUDA, but TVM codegen seems to use it.
inline __device__ longlong4 make_int8(int x0, int x1, int x2, int x3, int x4, int x5, int x6, int x7) {
int2 i0 = make_int2(x0, x1);
int2 i1 = make_int2(x2, x3);
int2 i2 = make_int2(x4, x5);
int2 i3 = make_int2(x6, x7);
long long l0 = *(long long*)&i0;
long long l1 = *(long long*)&i1;
long long l2 = *(long long*)&i2;
long long l3 = *(long long*)&i3;
return make_longlong4(l0, l1, l2, l3);
}
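// The eight ints are packed pairwise into four 64-bit lanes; the returned
// longlong4 simply acts as a 256-bit container for the 8-lane int vector.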
"""

rocm_default_header = """
#include <hip/hip_runtime.h>
#include <math.h>
"""

rocm_fp16_header = """
#define half _Float16
#define __float2half_rn(x) half(x)
#define htanh tanhf
#define htan tanf
#define hatan atanf
#define herf erff
#include <hip/hcc_detail/hip_fp16_math_fwd.h>
#define hpow __ocml_pown_f16
#define hsqrt __ocml_sqrt_f16
#define hexp __ocml_exp_f16
// Pack two half values into one 32-bit word (x in the low 16 bits, y in the high 16 bits).
inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// There is no make_int8 in HIP either, but TVM codegen seems to use it.
inline __device__ longlong4 make_int8(int x0, int x1, int x2, int x3, int x4, int x5, int x6, int x7) {
int2 i0 = make_int2(x0, x1);
int2 i1 = make_int2(x2, x3);
int2 i2 = make_int2(x4, x5);
int2 i3 = make_int2(x6, x7);
long long l0 = *(long long*)&i0;
long long l1 = *(long long*)&i1;
long long l2 = *(long long*)&i2;
long long l3 = *(long long*)&i3;
return make_longlong4(l0, l1, l2, l3);
}
"""

cutlass_header = """
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_sm70.h"
namespace cutlass {
namespace gemm {
namespace warp {
template <
typename Shape,
typename SMemLayoutA,
typename SMemLayoutB
>
class GemmTensorOp {
public:
using InstructionShape = GemmShape<16, 8, 16>;
using WarpShape = GemmShape<Shape::kM, Shape::kN, InstructionShape::kK>;
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
InstructionShape,
32,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::arch::OpMultiplyAdd
>,
cutlass::MatrixShape<1, 1>
>;
using MmaWarp = typename cutlass::gemm::warp::MmaTensorOp<
WarpShape,
cutlass::half_t,
SMemLayoutA,
cutlass::half_t,
SMemLayoutB,
cutlass::half_t,
cutlass::layout::RowMajor,
Policy
>;
static_assert(Shape::kK % InstructionShape::kK == 0);
static int const kKgroups = Shape::kK / InstructionShape::kK;
using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
typename MmaWarp::FragmentA frag_A[2];
typename MmaWarp::FragmentB frag_B[2];
typename MmaWarp::FragmentC accum;
MmaWarp mma_op;
typename MmaWarp::IteratorA iter_A;
typename MmaWarp::IteratorB iter_B;
const int warp_idx_m_, warp_idx_n_, lane_id_;
public:
CUTLASS_DEVICE
GemmTensorOp(int warp_idx_m, int warp_idx_n, int lane_id)
: warp_idx_m_(warp_idx_m), warp_idx_n_(warp_idx_n), lane_id_(lane_id), iter_A({nullptr, 0}, 0), iter_B({nullptr, 0}, 0) {
accum.clear();
}
CUTLASS_DEVICE
void prologue(const TensorRefA &ref_A, const TensorRefB &ref_B) {
iter_A = typename MmaWarp::IteratorA(ref_A, lane_id_);
iter_B = typename MmaWarp::IteratorB(ref_B, lane_id_);
iter_A.add_tile_offset({warp_idx_m_, 0});
iter_B.add_tile_offset({0, warp_idx_n_});
iter_A.load(frag_A[0]);
iter_B.load(frag_B[0]);
++iter_A;
++iter_B;
}
CUTLASS_DEVICE
void body() {
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups - 1; ++k) {
iter_A.load(frag_A[(k + 1) % 2]);
iter_B.load(frag_B[(k + 1) % 2]);
++iter_A;
++iter_B;
mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
}
__syncthreads();
}
CUTLASS_DEVICE
void epilogue() {
mma_op(accum, frag_A[(kKgroups - 1) % 2], frag_B[(kKgroups - 1) % 2], accum);
}
CUTLASS_DEVICE
half& operator[](const size_t i) const {
return ((half*)accum.data())[i];
}
CUTLASS_DEVICE
half* operator+(const size_t i) const {
return (half*)accum.data() + i;
}
};
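// A rough usage sketch for the wrapper above (illustrative only; the shape and
// layout parameters below are placeholders, not taken from this commit):
//   using Gemm = GemmTensorOp<GemmShape<32, 32, 32>, LayoutA, LayoutB>;
//   Gemm op(warp_idx_m, warp_idx_n, lane_id);
//   op.prologue(ref_A, ref_B);   // load the first A/B fragments from shared memory
//   op.body();                   // kKgroups - 1 double-buffered MMA iterations
//   op.epilogue();               // final MMA on the last pair of fragments
//   half v = op[0];              // accumulator elements are exposed as half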
template <
typename Shape,
typename SMemLayoutA,
typename LayoutA,
typename SMemLayoutB,
typename LayoutB,
typename LayoutC
>
class VoltaGemmTensorOp {
public:
using InstructionShape = GemmShape<16, 16, 4>;
using WarpShape = GemmShape<Shape::kM, Shape::kN, InstructionShape::kK>;
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
InstructionShape,
32,
cutlass::half_t,
LayoutA,
cutlass::half_t,
LayoutB,
cutlass::half_t,
LayoutC,
cutlass::arch::OpMultiplyAdd
>,
cutlass::MatrixShape<1, 1>
>;
using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
WarpShape,
cutlass::half_t,
SMemLayoutA,
cutlass::half_t,
SMemLayoutB,
cutlass::half_t,
LayoutC,
Policy
>;
static_assert(Shape::kK % InstructionShape::kK == 0);
static int const kKgroups = Shape::kK / InstructionShape::kK;
using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
typename MmaWarp::FragmentA frag_A[2];
typename MmaWarp::FragmentB frag_B[2];
typename MmaWarp::FragmentC accum;
typename MmaWarp::IteratorA iter_A;
typename MmaWarp::IteratorB iter_B;
const int warp_idx_m_, warp_idx_n_, lane_id_;
MmaWarp mma_op;
public:
CUTLASS_DEVICE
VoltaGemmTensorOp(int warp_idx_m, int warp_idx_n, int lane_id)
: warp_idx_m_(warp_idx_m), warp_idx_n_(warp_idx_n), lane_id_(lane_id), iter_A({nullptr, 0}, 0), iter_B({nullptr, 0}, 0) {
accum.clear();
}
CUTLASS_DEVICE
void prologue(TensorRefA ref_A, TensorRefB ref_B) {
iter_A = typename MmaWarp::IteratorA(ref_A, lane_id_);
iter_B = typename MmaWarp::IteratorB(ref_B, lane_id_);
iter_A.add_tile_offset({warp_idx_m_, 0});
iter_B.add_tile_offset({0, warp_idx_n_});
iter_A.load(frag_A[0]);
iter_B.load(frag_B[0]);
++iter_A;
++iter_B;
}
CUTLASS_DEVICE
void body() {
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups - 1; ++k) {
iter_A.load(frag_A[(k + 1) % 2]);
iter_B.load(frag_B[(k + 1) % 2]);
++iter_A;
++iter_B;
mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
}
__syncthreads();
}
CUTLASS_DEVICE
void epilogue() {
mma_op(accum, frag_A[(kKgroups - 1) % 2], frag_B[(kKgroups - 1) % 2], accum);
}
CUTLASS_DEVICE
half& operator[](size_t i) const {
return ((half*)accum.data())[i];
}
CUTLASS_DEVICE
half* operator+(size_t i) const {
return (half*)accum.data() + i;
}
};
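// VoltaGemmTensorOp follows the same prologue()/body()/epilogue() protocol as
// GemmTensorOp above, but targets the SM70 (Volta) 16x16x4 MMA instruction via
// cutlass::gemm::warp::MmaVoltaTensorOp.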
}}}
"""
9 changes: 7 additions & 2 deletions test/python/dsl_frontend/kernel_packer.py
@@ -1,7 +1,7 @@
import os
import re
import json

from common_header import *

code_header = '''
#include <cuda_runtime.h>
@@ -85,7 +85,12 @@ def kernel_slice_to_code(kernel_slice, kernel_name, in_args, out_args, thread_ex
idx2 = kernel_slice.find(') {\n', idx1) + 4
kernel_slice = kernel_slice[:idx1] + 'extern "C" ' + kernel_slice[idx1 : idx2] + '\n'.join(grid_size + block_size) +'\n' + kernel_slice[idx2:]
# add code header
kernel = code_header + '\n' + kernel_slice
header = cuda_default_header
if re.search('cutlass', kernel_slice):
header += cutlass_header
if re.search('half', kernel_slice):
header += cuda_fp16_header
kernel = header + '\n' + kernel_slice
display_inputs = ', '.join([tensor_display(name, prop) for (name, prop) in in_args])
display_outputs = ', '.join([tensor_display(name, prop) for (name, prop) in out_args])
code = f'// LOCAL: {kernel_name} -- {display_inputs} -> {display_outputs}\n\n{kernel}\n'
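The header-selection change above can be exercised on its own. The following is a minimal standalone sketch; select_headers and the example kernel string are hypothetical, not part of this commit:

import re
from common_header import cuda_default_header, cuda_fp16_header, cutlass_header

def select_headers(kernel_src):
    # Mirror the logic added to kernel_slice_to_code: always start from the
    # default CUDA header, then append the CUTLASS and fp16 headers only when
    # the generated kernel source actually references them.
    header = cuda_default_header
    if re.search('cutlass', kernel_src):
        header += cutlass_header
    if re.search('half', kernel_src):
        header += cuda_fp16_header
    return header

# A plain fp16 kernel picks up the fp16 header but not the CUTLASS one.
kernel_src = '__global__ void scale(half* x, half s) { /* ... */ }'
assert 'cuda_fp16.h' in select_headers(kernel_src)
assert 'cutlass/cutlass.h' not in select_headers(kernel_src)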