diff --git a/src/include/migraphx/op/quant_convolution.hpp b/src/include/migraphx/op/quant_convolution.hpp index fb20eff6b74..5976f9163c2 100644 --- a/src/include/migraphx/op/quant_convolution.hpp +++ b/src/include/migraphx/op/quant_convolution.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -87,11 +88,13 @@ struct quant_convolution } // all input type must be int8_type and output is float_type - if(t != shape::int8_type) + std::set supported_types = {shape::int8_type, + shape::fp8e4m3fnuz_type}; + if(not contains(supported_types, t)) { - MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t"); + MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or " + "fp8e4m3fnuz_type"); } - t = shape::int32_type; std::vector output_lens{input.lens()[0], weights.lens()[0]}; auto padding_size = padding.size(); @@ -107,8 +110,11 @@ struct quant_convolution stride[i] + 1))); } - - return inputs[0].with_lens(t, output_lens); + if(t == shape::int8_type) + { + return inputs[0].with_lens(shape::int32_type, output_lens); + } // else fp8 conv + return inputs[0].with_lens(shape::float_type, output_lens); } size_t kdims() const diff --git a/src/targets/gpu/device_name.cpp b/src/targets/gpu/device_name.cpp index ac38d6e8057..e65b97622f6 100644 --- a/src/targets/gpu/device_name.cpp +++ b/src/targets/gpu/device_name.cpp @@ -49,6 +49,12 @@ std::string get_device_name() return props.gcnArchName; } +bool gfx_has_fp8_intrinsics() +{ + const auto device_name = trim(split_string(get_device_name(), ':').front()); + return (starts_with(device_name, "gfx9") and device_name >= "gfx940"); +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 6eab3706c32..f83f41cacc3 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode) return false; if(ins->name() != "convolution" and ins->name() != "quant_convolution") return false; + auto input_arg_t = ins->inputs().front()->get_shape().type(); value v = ins->get_operator().to_value(); auto group = v.at("group").to(); if(group != 1) @@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode) // Avoid MLIR assertion: Index < Length && "Invalid index!" if(ins->get_shape().lens().size() != 4) return false; + if(ins->get_shape().type() == shape::fp8e4m3fnuz_type) + return true; + if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type) + return true; if(ins->get_shape().type() == shape::int8_type) return true; if(mode == mlir_mode::int8) @@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) const auto result_type = i.get_shape().type(); const std::initializer_list allowed_types = {type_t::float_type, type_t::half_type, + type_t::fp8e4m3fnuz_type, type_t::int8_type, type_t::int32_type, type_t::bool_type}; @@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) "softmax", "tanh", }; - bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); + bool is_float = + contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type); if(contains(any_type_ops, name)) return true; if(result_type != type_t::bool_type and contains(no_bool_ops, name)) @@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) // supported. if(is_float and name == "convert") { + if(result_type == shape::fp8e4m3fnuz_type) + { + return false; + } // else return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); }); @@ -404,12 +415,13 @@ struct find_mlir_standalone_op void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto gemm_based_op = r.result; - // - // enable only for fp32/fp16/i8 types + // enable only for fp32/fp16/i8/fp8 types if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { - return not contains( - {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, - i->get_shape().type()); + return not contains({shape::type_t::float_type, + shape::type_t::half_type, + shape::type_t::int8_type, + shape::type_t::fp8e4m3fnuz_type}, + i->get_shape().type()); })) return; static size_t counter = 0; diff --git a/src/targets/gpu/include/migraphx/gpu/device_name.hpp b/src/targets/gpu/include/migraphx/gpu/device_name.hpp index 44312d1f845..54ea873feea 100644 --- a/src/targets/gpu/include/migraphx/gpu/device_name.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device_name.hpp @@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name(); MIGRAPHX_GPU_EXPORT int get_device_id(); +MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics(); + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index cf8967fcee6..c7a296cba95 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -300,6 +300,8 @@ struct mlir_program result = mlirF32TypeGet(ctx.get()); else if(as.type_enum() == shape::half_type) result = mlirF16TypeGet(ctx.get()); + else if(as.type_enum() == shape::fp8e4m3fnuz_type) + result = mlirFloat8E4M3FNUZTypeGet(ctx.get()); else if(as.type_enum() == shape::double_type) result = mlirF64TypeGet(ctx.get()); else if(as.is_integral()) diff --git a/src/targets/gpu/rocblas.cpp b/src/targets/gpu/rocblas.cpp index 59452408801..9697189d921 100644 --- a/src/targets/gpu/rocblas.cpp +++ b/src/targets/gpu/rocblas.cpp @@ -58,8 +58,7 @@ bool rocblas_fp8_available() #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API return false; #else - const auto device_name = trim(split_string(get_device_name(), ':').front()); - return (starts_with(device_name, "gfx9") and device_name >= "gfx940"); + return gfx_has_fp8_intrinsics(); #endif } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 1db85c4f934..0618d1a83c5 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -105,11 +105,19 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::tuple_type); + // whiltelist supported Ops for the FP8 std::set unsupported_fp8_ops = {}; if(not gpu::rocblas_fp8_available()) { unsupported_fp8_ops.insert("dot"); } + // MIOpen doesn't have support for fp8 pooling yet. + unsupported_fp8_ops.insert("pooling"); + if(not gpu::gfx_has_fp8_intrinsics()) + { + unsupported_fp8_ops.insert("convolution"); + unsupported_fp8_ops.insert("quant_convolution"); + } // add all device kernels unsupported_fp8_ops.insert("logsoftmax"); unsupported_fp8_ops.insert("nonzero"); diff --git a/test/verify/main.cpp b/test/verify/main.cpp index 9a7f8481226..23404864533 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -77,6 +77,5 @@ int main(int argc, const char* argv[]) "test_split_single_dyn_dim", "test_instancenorm_large_3d", "test_instancenorm_large_3d"}); - rv.disable_test_for("gpu", {"test_conv_bn_add"}); rv.run(argc, argv); } diff --git a/test/verify/quant_conv.cpp b/test/verify/quant_conv.cpp index 72f32f453f3..616b38b0e04 100644 --- a/test/verify/quant_conv.cpp +++ b/test/verify/quant_conv.cpp @@ -27,17 +27,21 @@ #include #include -struct quant_conv : verify_program +template +struct quant_conv : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc); return p; } }; + +template struct quant_conv; +template struct quant_conv; diff --git a/test/verify/quant_conv_1.cpp b/test/verify/quant_conv_1.cpp index 928badbd7cb..a13bd5ce3f9 100644 --- a/test/verify/quant_conv_1.cpp +++ b/test/verify/quant_conv_1.cpp @@ -27,17 +27,21 @@ #include #include -struct quant_conv_1 : verify_program +template +struct quant_conv_1 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); return p; } }; + +template struct quant_conv_1; +template struct quant_conv_1; diff --git a/test/verify/quant_conv_1d.cpp b/test/verify/quant_conv_1d.cpp index 2648134c4e3..069cc4efd22 100644 --- a/test/verify/quant_conv_1d.cpp +++ b/test/verify/quant_conv_1d.cpp @@ -27,15 +27,16 @@ #include #include -struct quant_conv_1d : verify_program +template +struct quant_conv_1d : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction( migraphx::make_op("quant_convolution", @@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program return p; } }; + +template struct quant_conv_1d; +// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later. +// template struct quant_conv_1d; diff --git a/test/verify/quant_conv_2.cpp b/test/verify/quant_conv_2.cpp index 9ae561f732b..1873852fee5 100644 --- a/test/verify/quant_conv_2.cpp +++ b/test/verify/quant_conv_2.cpp @@ -27,17 +27,21 @@ #include #include -struct quant_conv_2 : verify_program +template +struct quant_conv_2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}}; + migraphx::shape a_shape{DType, {16, 16, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}}; + migraphx::shape c_shape{DType, {16, 16, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); return p; } }; + +template struct quant_conv_2; +template struct quant_conv_2; diff --git a/test/verify/quant_conv_padding.cpp b/test/verify/quant_conv_padding.cpp index f566c314f4c..29159ef7f81 100644 --- a/test/verify/quant_conv_padding.cpp +++ b/test/verify/quant_conv_padding.cpp @@ -27,15 +27,16 @@ #include #include -struct quant_conv_padding : verify_program +template +struct quant_conv_padding : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction( migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), @@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program return p; } }; + +template struct quant_conv_padding; +template struct quant_conv_padding; diff --git a/test/verify/quant_conv_padding_stride.cpp b/test/verify/quant_conv_padding_stride.cpp index f1c07399fc0..955a3b23352 100644 --- a/test/verify/quant_conv_padding_stride.cpp +++ b/test/verify/quant_conv_padding_stride.cpp @@ -27,15 +27,16 @@ #include #include -struct quant_conv_padding_stride : verify_program +template +struct quant_conv_padding_stride : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction( migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), @@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program return p; } }; +template struct quant_conv_padding_stride; +template struct quant_conv_padding_stride; diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index c91a4aba73e..464cfb9d529 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -142,7 +142,8 @@ std::vector run_verify::run_ref(migraphx::program p, { migraphx::target t = migraphx::make_target("ref"); auto_print pp{p, t.name()}; - compile_check(p, t, c_opts); + auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); + compile_check(p, t, c_opts, (trace_target == "ref")); return p.eval(std::move(inputs)); } std::pair> diff --git a/test/verify/test_conv.cpp b/test/verify/test_conv.cpp index 873016bb5a6..9b5d0caef7d 100644 --- a/test/verify/test_conv.cpp +++ b/test/verify/test_conv.cpp @@ -27,17 +27,19 @@ #include #include -struct test_conv : verify_program +template +struct test_conv : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); mm->add_instruction(migraphx::make_op("convolution"), input, weights); return p; } }; + +template struct test_conv; +template struct test_conv; diff --git a/test/verify/test_conv2.cpp b/test/verify/test_conv2.cpp index e6dea116f20..bbdf9d1a1c2 100644 --- a/test/verify/test_conv2.cpp +++ b/test/verify/test_conv2.cpp @@ -27,16 +27,15 @@ #include #include -struct test_conv2 : verify_program +template +struct test_conv2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}}); mm->add_instruction( migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), @@ -45,3 +44,5 @@ struct test_conv2 : verify_program return p; } }; +template struct test_conv2; +template struct test_conv2; diff --git a/test/verify/test_conv_add.cpp b/test/verify/test_conv_add.cpp index 934a1985709..d97a4c9652f 100644 --- a/test/verify/test_conv_add.cpp +++ b/test/verify/test_conv_add.cpp @@ -27,18 +27,17 @@ #include #include -struct test_conv_add : verify_program +template +struct test_conv_add : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto w = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1)); - auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto v = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2)); + auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}}); + auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1)); + auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); + auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2)); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v); auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2); @@ -46,3 +45,6 @@ struct test_conv_add : verify_program return p; } }; + +template struct test_conv_add; +template struct test_conv_add; diff --git a/test/verify/test_conv_add_1x1_diff_strides.cpp b/test/verify/test_conv_add_1x1_diff_strides.cpp index 9e2be95966d..c07467fa99f 100644 --- a/test/verify/test_conv_add_1x1_diff_strides.cpp +++ b/test/verify/test_conv_add_1x1_diff_strides.cpp @@ -27,18 +27,17 @@ #include #include -struct test_conv_add_1x1_diff_strides : verify_program +template +struct test_conv_add_1x1_diff_strides : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}}); - auto w = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1)); - auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto v = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2)); + auto x = mm->add_parameter("x", {DType, {1, 8, 2, 2}}); + auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 1)); + auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); + auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 2)); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv2 = mm->add_instruction( migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); @@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program; +template struct test_conv_add_1x1_diff_strides; diff --git a/test/verify/test_conv_add_relu.cpp b/test/verify/test_conv_add_relu.cpp index 74533c86a13..2611e2f99d4 100644 --- a/test/verify/test_conv_add_relu.cpp +++ b/test/verify/test_conv_add_relu.cpp @@ -28,18 +28,17 @@ #include #include -struct test_conv_add_relu : verify_program +template +struct test_conv_add_relu : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto bias_literal = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}}, - {2.0f, 2.0f, 2.0f, 2.0f}}; + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto bias_literal = + migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; auto bias = mm->add_literal(bias_literal); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto bcast_bias = mm->add_instruction( @@ -50,3 +49,6 @@ struct test_conv_add_relu : verify_program return p; } }; + +template struct test_conv_add_relu; +template struct test_conv_add_relu; diff --git a/test/verify/test_conv_bias_clipped_relu.cpp b/test/verify/test_conv_bias_clipped_relu.cpp index bd9fc3bff07..28844dbfb55 100644 --- a/test/verify/test_conv_bias_clipped_relu.cpp +++ b/test/verify/test_conv_bias_clipped_relu.cpp @@ -29,26 +29,24 @@ #include -struct test_conv_bias_clipped_relu : verify_program +template +struct test_conv_bias_clipped_relu : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}}, - {2.0f, 2.0f, 2.0f, 2.0f}}; + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto l0 = migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; auto bias = mm->add_literal(l0); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto bcast_add = mm->add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), bias); auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add); - auto min_val = mm->add_literal(0.0f); - auto max_val = mm->add_literal(6.0f); + auto min_val = mm->add_literal(migraphx::literal(DType, {0.0f})); + auto max_val = mm->add_literal(migraphx::literal(DType, {6.0f})); min_val = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val); max_val = mm->add_instruction( @@ -57,3 +55,6 @@ struct test_conv_bias_clipped_relu : verify_program return p; } }; + +template struct test_conv_bias_clipped_relu; +template struct test_conv_bias_clipped_relu; diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index cda424de5c1..5f356636efe 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -29,16 +29,17 @@ #include #include -struct test_conv_bn : verify_program +template +struct test_conv_bn : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; - migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; - migraphx::shape vars{migraphx::shape::float_type, {64}}; + migraphx::shape xs{DType, {1, 3, 224, 224}}; + migraphx::shape ws{DType, {64, 3, 7, 7}}; + migraphx::shape vars{DType, {64}}; auto x = mm->add_parameter("x", xs); auto w = mm->add_parameter("w", ws); // non-symmetrical tiling @@ -53,8 +54,14 @@ struct test_conv_bn : verify_program auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); + auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); + + auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + // use 5e-2f for the fp8 + eps = mm->add_literal(migraphx::literal{DType, {5e-2f}}); + } auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); @@ -74,3 +81,6 @@ struct test_conv_bn : verify_program return p; } }; + +template struct test_conv_bn; +template struct test_conv_bn; diff --git a/test/verify/test_conv_bn_add.cpp b/test/verify/test_conv_bn_add.cpp index 3433314ecad..52a8486456a 100644 --- a/test/verify/test_conv_bn_add.cpp +++ b/test/verify/test_conv_bn_add.cpp @@ -29,22 +29,27 @@ #include #include -struct test_conv_bn_add : verify_program +template +struct test_conv_bn_add : verify_program> { static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x) { auto bn_lens = x->get_shape().lens(); auto c_len = bn_lens.at(1); - migraphx::shape vars{migraphx::shape::float_type, {c_len}}; + migraphx::shape vars{DType, {c_len}}; auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len))); auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len))); auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len))); - auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); - + auto rt = m.add_literal(migraphx::literal{DType, {0.5}}); + auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}}); + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + // use 5e-2f for the fp8 + eps = m.add_literal(migraphx::literal{DType, {5e-2f}}); + } auto usq_scale = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); @@ -66,12 +71,12 @@ struct test_conv_bn_add : verify_program auto* mm = p.get_main_module(); std::size_t ichannels = 64; std::size_t ochannels = 256; - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); - auto w = mm->add_literal(migraphx::generate_literal( - {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1)); - auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); - auto v = mm->add_literal(migraphx::generate_literal( - {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2)); + auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}}); + auto w = + mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1)); + auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}}); + auto v = + mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2)); auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w); auto bn1 = add_bn(*mm, conv1); @@ -83,3 +88,6 @@ struct test_conv_bn_add : verify_program return p; } }; + +template struct test_conv_bn_add; +template struct test_conv_bn_add; diff --git a/test/verify/test_conv_bn_relu_pooling.cpp b/test/verify/test_conv_bn_relu_pooling.cpp index 4d4c8abb1f4..4b283779c45 100644 --- a/test/verify/test_conv_bn_relu_pooling.cpp +++ b/test/verify/test_conv_bn_relu_pooling.cpp @@ -30,16 +30,17 @@ #include #include -struct test_conv_bn_relu_pooling : verify_program +template +struct test_conv_bn_relu_pooling : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; - migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; - migraphx::shape vars{migraphx::shape::float_type, {64}}; + migraphx::shape xs{DType, {1, 3, 224, 224}}; + migraphx::shape ws{DType, {64, 3, 7, 7}}; + migraphx::shape vars{DType, {64}}; auto x = mm->add_parameter("x", xs); auto w = mm->add_parameter("w", ws); auto conv = mm->add_instruction( @@ -52,9 +53,13 @@ struct test_conv_bn_relu_pooling : verify_program auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); - + auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); + auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + // use 5e-2f for the fp8 + eps = mm->add_literal(migraphx::literal{DType, {5e-2f}}); + } auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_bias = @@ -82,3 +87,6 @@ struct test_conv_bn_relu_pooling : verify_program return p; } }; + +template struct test_conv_bn_relu_pooling; +template struct test_conv_bn_relu_pooling; diff --git a/test/verify/test_conv_bn_relu_pooling2.cpp b/test/verify/test_conv_bn_relu_pooling2.cpp index 39abacd7c28..3bf9e907d97 100644 --- a/test/verify/test_conv_bn_relu_pooling2.cpp +++ b/test/verify/test_conv_bn_relu_pooling2.cpp @@ -30,22 +30,27 @@ #include #include -struct test_conv_bn_relu_pooling2 : verify_program +template +struct test_conv_bn_relu_pooling2 : verify_program> { static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x) { auto bn_lens = x->get_shape().lens(); auto c_len = bn_lens.at(1); - migraphx::shape vars{migraphx::shape::float_type, {c_len}}; + migraphx::shape vars{DType, {c_len}}; auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len))); auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len))); auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len))); - auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); - + auto rt = m.add_literal(migraphx::literal{DType, {0.5}}); + auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}}); + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + // use 5e-2f for the fp8 + eps = m.add_literal(migraphx::literal{DType, {5e-2f}}); + } auto usq_scale = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); @@ -66,10 +71,10 @@ struct test_conv_bn_relu_pooling2 : verify_program migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}}; - migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}}; - migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}}; - migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}}; + migraphx::shape xs1{DType, {1, 512, 7, 7}}; + migraphx::shape xs2{DType, {1, 1024, 14, 14}}; + migraphx::shape ws1{DType, {2048, 512, 1, 1}}; + migraphx::shape ws2{DType, {2048, 1024, 1, 1}}; auto x1 = mm->add_parameter("x1", xs1); auto w1 = mm->add_parameter("w1", ws1); auto conv1 = mm->add_instruction( @@ -98,3 +103,6 @@ struct test_conv_bn_relu_pooling2 : verify_program return p; } }; + +template struct test_conv_bn_relu_pooling2; +template struct test_conv_bn_relu_pooling2; diff --git a/test/verify/test_conv_group_add.cpp b/test/verify/test_conv_group_add.cpp index 0fc8569d211..ff8747b616d 100644 --- a/test/verify/test_conv_group_add.cpp +++ b/test/verify/test_conv_group_add.cpp @@ -27,16 +27,17 @@ #include #include -struct test_conv_group_add : verify_program +template +struct test_conv_group_add : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {1, 68, 28, 28}}; + migraphx::shape s{DType, {1, 68, 28, 28}}; auto x = mm->add_parameter("x", s); - auto w = mm->add_parameter("w", {migraphx::shape::float_type, {68, 17, 1, 1}}); - auto b = mm->add_parameter("b", {migraphx::shape::float_type, {68}}); + auto w = mm->add_parameter("w", {DType, {68, 17, 1, 1}}); + auto b = mm->add_parameter("b", {DType, {68}}); auto conv = mm->add_instruction(migraphx::make_op("convolution", {{"group", 4}}), x, w); auto bb = mm->add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}}), b); @@ -44,3 +45,6 @@ struct test_conv_group_add : verify_program return p; } }; +template struct test_conv_group_add; +// grouped convolutions are not supported with MLIR therefore disable it +// template struct test_conv_group_add; diff --git a/test/verify/test_conv_pooling.cpp b/test/verify/test_conv_pooling.cpp index d4e7b7b66af..4fbe8f17c65 100644 --- a/test/verify/test_conv_pooling.cpp +++ b/test/verify/test_conv_pooling.cpp @@ -28,16 +28,15 @@ #include #include -struct test_conv_pooling : verify_program +template +struct test_conv_pooling : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 32, 32}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto pooling = mm->add_instruction( migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); @@ -45,3 +44,6 @@ struct test_conv_pooling : verify_program return p; } }; + +template struct test_conv_pooling; +template struct test_conv_pooling; diff --git a/test/verify/test_conv_relu.cpp b/test/verify/test_conv_relu.cpp index 312cac4f6a5..aa9af88bf01 100644 --- a/test/verify/test_conv_relu.cpp +++ b/test/verify/test_conv_relu.cpp @@ -27,18 +27,20 @@ #include #include -struct test_conv_relu : verify_program +template +struct test_conv_relu : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); mm->add_instruction(migraphx::make_op("relu"), conv); return p; } }; +template struct test_conv_relu; +template struct test_conv_relu; +template struct test_conv_relu; diff --git a/test/verify/test_conv_relu_half.cpp b/test/verify/test_conv_relu_half.cpp deleted file mode 100644 index a61865c7347..00000000000 --- a/test/verify/test_conv_relu_half.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_conv_relu_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); - auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); - mm->add_instruction(migraphx::make_op("relu"), conv); - return p; - } -};