Skip to content

Commit

Permalink
Add FP8 rocblas gemm support (#2473)
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Dec 6, 2023
1 parent e3e0054 commit 6d0b6bc
Show file tree
Hide file tree
Showing 48 changed files with 490 additions and 312 deletions.
92 changes: 72 additions & 20 deletions src/eliminate_data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,72 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void insert_convert_to_supported_type(module& m,
instruction_ref ins,
migraphx::shape::type_t target_type,
std::set<migraphx::shape::type_t> unsupported_types)
{
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
if(contains(unsupported_types, i->get_shape().type()))
{
return m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i);
}
else
{
return i;
}
});
// if no change
if(inputs == ins->inputs())
return;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto new_ins = m.insert_instruction(ins, op, inputs);
if(orig_type == shape::tuple_type)
{
auto orig_outs = ins->outputs();
if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
return out_ins->name() == "get_tuple_elem";
}))
MIGRAPHX_THROW(
"eliminate_data_type: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction");

std::transform(
orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
auto orig_out_type = out_ins->get_shape().type();
if(contains(unsupported_types, orig_out_type))
{
auto gte_convert = m.insert_instruction(
ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
return m.replace_instruction(out_ins, gte_convert);
}
else
{
return m.replace_instruction(out_ins, gte_ins);
}
});
}
else
{
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}

void eliminate_data_type::apply(module& m) const
{
static const std::vector<std::string> skip_op_names = {"convert",
Expand All @@ -42,31 +108,17 @@ void eliminate_data_type::apply(module& m) const
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
if(unsupported_types.empty())
return;

for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
continue;
if(contains(skip_op_names, ins->name()))
continue;
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if(types.count(i->get_shape().type()) == 0)
return i;
return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i);
});
if(inputs == ins->inputs())
if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
continue;
auto op = ins->get_operator();
auto attributes = op.attributes();
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto old_type = ins->get_shape().type();
auto out = m.insert_instruction(ins, op, inputs);
auto convert =
m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out);
m.replace_instruction(ins, convert);
if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/include/migraphx/eliminate_data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ struct module;
*/
struct MIGRAPHX_EXPORT eliminate_data_type
{
std::set<shape::type_t> types;
std::set<shape::type_t> unsupported_types;
shape::type_t target_type;
std::set<std::string> unsupported_ops = {"all"};
std::string name() const { return "eliminate_data_type"; }
void apply(module& m) const;
};
Expand Down
9 changes: 9 additions & 0 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)

set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")

Expand Down Expand Up @@ -288,6 +290,13 @@ else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()

if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()

target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
Expand Down
85 changes: 68 additions & 17 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,35 @@
* THE SOFTWARE.
*/

#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <type_traits>

using microseconds = std::chrono::duration<double, std::micro>;

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct rb_compute_type
{
int type = 0;
rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
};

// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type)
{
Expand All @@ -46,7 +63,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::fp8e4m3fnuz_type:
case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r;
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
Expand Down Expand Up @@ -183,12 +200,17 @@ struct gemm_impl
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
compute_type = rb_compute_type{output_type};
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
if(arg_type == rocblas_datatype_f8_r)
{
assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
compute_type = rocblas_compute_type_f32;
}

auto a_lens = input_shapes[0].lens();
auto b_lens = input_shapes[1].lens();
Expand Down Expand Up @@ -217,23 +239,52 @@ struct gemm_impl

void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
{
if(strided_batched)
#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API
if(rocblas_fp8_available() and
std::any_of(input_args.begin(), input_args.end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_gemm_algo_standard,
solution_idx,
gemm_flags);
}
}
else
#endif
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
}
}
}

Expand Down Expand Up @@ -331,7 +382,6 @@ struct gemm_impl
num_matrices,
compute_type);
}

/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
Expand Down Expand Up @@ -366,6 +416,7 @@ struct gemm_impl
ldd,
compute_type);
}

#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
Expand Down Expand Up @@ -481,8 +532,8 @@ struct gemm_impl
rocblas_int b_stride = 0;
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
rb_compute_type compute_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/rocblas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct context;

MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();

MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
8 changes: 2 additions & 6 deletions src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static constexpr __device__ fp8e5m2fnuz min()
{
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
Expand All @@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
}

static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }

static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
Expand Down
10 changes: 10 additions & 0 deletions src/targets/gpu/rocblas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ bool get_compute_fp32_flag()
return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
}

bool rocblas_fp8_available()
{
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false;
#else
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
#endif
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
7 changes: 7 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
}
// clang-format off
return
{
Expand Down Expand Up @@ -136,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{},
optimize_module{},
fuse_pointwise{},
dead_code_elimination{},
Expand Down
11 changes: 8 additions & 3 deletions test/verify/gemm_2args_bmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
template <migraphx::shape::type_t DType>
struct gemm_2args_bmv : verify_program<gemm_2args_bmv<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
migraphx::shape m1_shape{DType, {2, 3, 3, 5}};
migraphx::shape m2_shape{DType, {5}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
Expand All @@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
return p;
}
};

template struct gemm_2args_bmv<migraphx::shape::float_type>;
template struct gemm_2args_bmv<migraphx::shape::half_type>;
template struct gemm_2args_bmv<migraphx::shape::fp8e4m3fnuz_type>;
Loading

0 comments on commit 6d0b6bc

Please sign in to comment.