Skip to content

Commit

Permalink
FP8 2D forward convolution using rocMLIR (#2507)
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Dec 7, 2023
1 parent a09dc50 commit 6a72e8f
Show file tree
Hide file tree
Showing 29 changed files with 231 additions and 164 deletions.
16 changes: 11 additions & 5 deletions src/include/migraphx/op/quant_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/value.hpp>
Expand Down Expand Up @@ -87,11 +88,13 @@ struct quant_convolution
}

// all input type must be int8_type and output is float_type
if(t != shape::int8_type)
std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
if(not contains(supported_types, t))
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type");
}
t = shape::int32_type;

std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto padding_size = padding.size();
Expand All @@ -107,8 +110,11 @@ struct quant_convolution
stride[i] +
1)));
}

return inputs[0].with_lens(t, output_lens);
if(t == shape::int8_type)
{
return inputs[0].with_lens(shape::int32_type, output_lens);
} // else fp8 conv
return inputs[0].with_lens(shape::float_type, output_lens);
}

size_t kdims() const
Expand Down
6 changes: 6 additions & 0 deletions src/targets/gpu/device_name.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ std::string get_device_name()
return props.gcnArchName;
}

bool gfx_has_fp8_intrinsics()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
24 changes: 18 additions & 6 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,18 @@ auto is_mlir_conv(mlir_mode mode)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
auto input_arg_t = ins->inputs().front()->get_shape().type();
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
Expand Down Expand Up @@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
Expand Down Expand Up @@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
bool is_float =
contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
Expand All @@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
// supported.
if(is_float and name == "convert")
{
if(result_type == shape::fp8e4m3fnuz_type)
{
return false;
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
Expand Down Expand Up @@ -404,12 +415,13 @@ struct find_mlir_standalone_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto gemm_based_op = r.result;
//
// enable only for fp32/fp16/i8 types
// enable only for fp32/fp16/i8/fp8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/device_name.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();

MIGRAPHX_GPU_EXPORT int get_device_id();

MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get());
else if(as.is_integral())
Expand Down
3 changes: 1 addition & 2 deletions src/targets/gpu/rocblas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ bool rocblas_fp8_available()
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false;
#else
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
return gfx_has_fp8_intrinsics();
#endif
}

Expand Down
8 changes: 8 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
// whiltelist supported Ops for the FP8
std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
if(not gpu::gfx_has_fp8_intrinsics())
{
unsupported_fp8_ops.insert("convolution");
unsupported_fp8_ops.insert("quant_convolution");
}
// add all device kernels
unsupported_fp8_ops.insert("logsoftmax");
unsupported_fp8_ops.insert("nonzero");
Expand Down
1 change: 0 additions & 1 deletion test/verify/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv);
}
10 changes: 7 additions & 3 deletions test/verify/quant_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv : verify_program<quant_conv>
template <migraphx::shape::type_t DType>
struct quant_conv : verify_program<quant_conv<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc);
return p;
}
};

template struct quant_conv<migraphx::shape::int8_type>;
template struct quant_conv<migraphx::shape::fp8e4m3fnuz_type>;
10 changes: 7 additions & 3 deletions test/verify/quant_conv_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>

struct quant_conv_1 : verify_program<quant_conv_1>
template <migraphx::shape::type_t DType>
struct quant_conv_1 : verify_program<quant_conv_1<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};

template struct quant_conv_1<migraphx::shape::int8_type>;
template struct quant_conv_1<migraphx::shape::fp8e4m3fnuz_type>;
11 changes: 8 additions & 3 deletions test/verify/quant_conv_1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv_1d : verify_program<quant_conv_1d>
template <migraphx::shape::type_t DType>
struct quant_conv_1d : verify_program<quant_conv_1d<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}};
migraphx::shape a_shape{DType, {2, 3, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution",
Expand All @@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return p;
}
};

template struct quant_conv_1d<migraphx::shape::int8_type>;
// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
10 changes: 7 additions & 3 deletions test/verify/quant_conv_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>

struct quant_conv_2 : verify_program<quant_conv_2>
template <migraphx::shape::type_t DType>
struct quant_conv_2 : verify_program<quant_conv_2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}};
migraphx::shape a_shape{DType, {16, 16, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}};
migraphx::shape c_shape{DType, {16, 16, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};

template struct quant_conv_2<migraphx::shape::int8_type>;
template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
10 changes: 7 additions & 3 deletions test/verify/quant_conv_padding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv_padding : verify_program<quant_conv_padding>
template <migraphx::shape::type_t DType>
struct quant_conv_padding : verify_program<quant_conv_padding<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}),
Expand All @@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
return p;
}
};

template struct quant_conv_padding<migraphx::shape::int8_type>;
template struct quant_conv_padding<migraphx::shape::fp8e4m3fnuz_type>;
9 changes: 6 additions & 3 deletions test/verify/quant_conv_padding_stride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
template <migraphx::shape::type_t DType>
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
Expand All @@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
return p;
}
};
template struct quant_conv_padding_stride<migraphx::shape::int8_type>;
template struct quant_conv_padding_stride<migraphx::shape::fp8e4m3fnuz_type>;
3 changes: 2 additions & 1 deletion test/verify/run_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
{
migraphx::target t = migraphx::make_target("ref");
auto_print pp{p, t.name()};
compile_check(p, t, c_opts);
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, c_opts, (trace_target == "ref"));
return p.eval(std::move(inputs));
}
std::pair<migraphx::program, std::vector<migraphx::argument>>
Expand Down
12 changes: 7 additions & 5 deletions test/verify/test_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,19 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_conv : verify_program<test_conv>
template <migraphx::shape::type_t DType>
struct test_conv : verify_program<test_conv<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p;
}
};

template struct test_conv<migraphx::shape::float_type>;
template struct test_conv<migraphx::shape::fp8e4m3fnuz_type>;
11 changes: 6 additions & 5 deletions test/verify/test_conv2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_conv2 : verify_program<test_conv2>
template <migraphx::shape::type_t DType>
struct test_conv2 : verify_program<test_conv2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
Expand All @@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2>
return p;
}
};
template struct test_conv2<migraphx::shape::float_type>;
template struct test_conv2<migraphx::shape::fp8e4m3fnuz_type>;
Loading

0 comments on commit 6a72e8f

Please sign in to comment.