diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index eda6ea626e4..613b21e662e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,6 +57,7 @@ add_library(migraphx file_buffer.cpp fileutils.cpp fp_to_double.cpp + fp8_ocp_to_fnuz.cpp fuse_concat.cpp fuse_pointwise.cpp fuse_pointwise_reduce.cpp diff --git a/src/cpp_generator.cpp b/src/cpp_generator.cpp index 433ccaadb5b..292d42e3e09 100644 --- a/src/cpp_generator.cpp +++ b/src/cpp_generator.cpp @@ -220,8 +220,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m, if(x < 0) string_literal = "-__builtin_huge_val()"; } - else if(std::isnan(static_cast(x))) - string_literal = "__builtin_nan()"; + else if(std::isnan(x)) + string_literal = "__builtin_nan(\"0\")"; else string_literal = ins->get_literal().to_string(); }); diff --git a/src/fp8_ocp_to_fnuz.cpp b/src/fp8_ocp_to_fnuz.cpp new file mode 100644 index 00000000000..305ca6058f1 --- /dev/null +++ b/src/fp8_ocp_to_fnuz.cpp @@ -0,0 +1,178 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace { + +using fp8::fp8e4m3fnuz; + +std::unordered_set get_quantizable_op_names() +{ + static std::unordered_set s = {"convolution", "dot"}; + return s; +} + +struct match_fp8ocp_convert_to_fp8fnuz +{ + auto matcher() const + { + auto dq1 = match::arg(0)( + skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1"))); + auto dq2 = match::arg(1)( + skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2"))); + return match::name(get_quantizable_op_names())(dq1, dq2); + } + + static auto bit_cast_and_handle_specials(module& m, + const instruction_ref dq, + const instruction_ref x, + const instruction_ref bits_0x80_lit, + const instruction_ref bits_0x7f_lit, + const instruction_ref bits_0xff_lit, + const instruction_ref bits_0x00_lit) + { + auto x_lens = x->get_shape().lens(); + auto cast_input = m.insert_instruction( + dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x); + auto mb_bits_0x80_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit); + auto mb_bits_0x7f_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit); + auto mb_bits_0xff_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit); + auto mb_zero_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit); + // negative zero in fp8e4m3fn to zero in fp8e4m3fnuz + // a == 0x80 ? 0x0 : a + auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), cast_input, mb_bits_0x80_lit); + auto ret = m.insert_instruction(dq, make_op("where"), is_neg_zero, mb_zero_lit, cast_input); + + // positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz + // (a == 0x7f or a == 0xff) ? 0x80 : a + auto eq_0x7f = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0x7f_lit); + + auto eq_0xff = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0xff_lit); + + auto cond = m.insert_instruction(dq, make_op("logical_or"), eq_0x7f, eq_0xff); + ret = m.insert_instruction(dq, make_op("where"), cond, mb_bits_0x80_lit, ret); + return ret; + } + + // Add the same broadcast instructions after adjusted scales or + // adjusted zero points from after the originals. Similar to + // propagate_quantized_ins in simplify_qdq. + static auto propagate_broadcasts(module& m, + const instruction_ref adj, + const instruction_ref ori, + const instruction_ref start, + const instruction_ref insert_pt) + { + auto prev_ins = start; + std::vector ins_between; + // matcher skips continguous, multi/broadcasts and transposes, collect all those + // instructions + while(prev_ins != ori) + { + ins_between.push_back(prev_ins); + prev_ins = prev_ins->inputs().front(); + } + auto ret = adj; + for(auto ins : reverse_iterator_for(ins_between)) + { + ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret}); + } + return ret; + } + + static auto cast_to_fnuz(module& m, + const instruction_ref dq, + const instruction_ref input, + const instruction_ref dq_scale, + const instruction_ref dq_zp) + { + auto x = input; + std::vector bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())}; + auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80); + + std::vector bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())}; + auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f); + + std::vector bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())}; + auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff); + + std::vector bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())}; + auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00); + + x = bit_cast_and_handle_specials( + m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); + auto adj_dq_zp = bit_cast_and_handle_specials( + m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); + + // adj_scale = 2 * scale + auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}}); + two_lit = m.insert_instruction( + dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit); + auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit); + + adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq); + adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq); + m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp); + } + + auto apply(module& m, const match::matcher_result& r) const + { + auto dq1 = r.instructions["dq1"]; + auto dq2 = r.instructions["dq2"]; + auto scale1 = r.instructions["scale1"]; + auto scale2 = r.instructions["scale2"]; + auto zp1 = r.instructions["zp1"]; + auto zp2 = r.instructions["zp2"]; + + std::set supported_types = {migraphx::shape::fp8e4m3fn_type}; + if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or + not contains(supported_types, dq2->inputs().front()->get_shape().type())) + return; + + cast_to_fnuz(m, dq1, dq1->inputs().front(), scale1, zp1); + cast_to_fnuz(m, dq2, dq2->inputs().front(), scale2, zp2); + } +}; + +} // namespace + +void fp8_ocp_to_fnuz::apply(module_pass_manager& mpm) const +{ + module_ref mm = &mpm.get_module(); + match::find_matches(*mm, match_fp8ocp_convert_to_fp8fnuz{}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/include/migraphx/fp8_ocp_to_fnuz.hpp b/src/include/migraphx/fp8_ocp_to_fnuz.hpp new file mode 100644 index 00000000000..19e4a1cda02 --- /dev/null +++ b/src/include/migraphx/fp8_ocp_to_fnuz.hpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP +#define MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +/** + * Convert fp8e4m3fn to fp8e4m3fnuz for hardware that only supports fp8e4m3fnuz data types + * intrinsically. Conversion uses the same bit representation and adjusts scaling factors at the + * dequantization. Using the same bit representation from fp8e4m3fn to fp8e4m3fnuz halves the + * floating point representation. This pass should run before simplify_qdq so that the scales and + * zero points calculated by simplify_qdq have the correct adjusted scaling factors + */ +struct MIGRAPHX_EXPORT fp8_ocp_to_fnuz +{ + std::string name() const { return "fp8_ocp_to_fnuz"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/match/dq_helpers.hpp b/src/include/migraphx/match/dq_helpers.hpp new file mode 100644 index 00000000000..cdb40ae977e --- /dev/null +++ b/src/include/migraphx/match/dq_helpers.hpp @@ -0,0 +1,62 @@ + +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MATCH_DQ_HELPERS_HPP +#define MIGRAPHX_GUARD_MATCH_DQ_HELPERS_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace match { + +/** + * Find dequantizelinear (DQ) instruction with constant scale and zero point input + * while skipping broadcast instructions between DQ and scale/zero point. Used + * in simplify_qdq and fp8_ocp_to_fnuz. + */ +inline auto dequantizelinear_op(const std::string& scale, const std::string& zp) +{ + return match::name("dequantizelinear")( + match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), + match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); +} + +/** + * Skip certain operators after DQ instruction. + * Used in simplify_qdq and fp8_ocp_to_fnuz. + */ +template +auto skip_post_dq_ops(Ms... ms) +{ + return match::skip(match::name( + "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); +} + +} // namespace match +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/bit_cast.hpp b/src/include/migraphx/op/bit_cast.hpp index eb233ad8b36..0112342a14b 100644 --- a/src/include/migraphx/op/bit_cast.hpp +++ b/src/include/migraphx/op/bit_cast.hpp @@ -80,6 +80,7 @@ struct bit_cast : unary args[0].visit([&](auto input) { using itype = typename decltype(input)::value_type; if constexpr(sizeof(otype) == sizeof(itype)) + { par_transform(input.begin(), input.end(), output.begin(), [&](auto x) { return migraphx::bit_cast(x); diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 86c2100a995..bd21564b618 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -36,18 +36,12 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace { -template -auto skip_post_dq_ops(Ms... ms) -{ - return match::skip(match::name( - "broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...); -} - std::unordered_set get_quantizable_op_names() { static std::unordered_set s = {"convolution", "dot"}; @@ -117,20 +111,12 @@ struct match_find_quantizable_ops return qinp; } - static auto dequantizelinear_op(const std::string& scale, const std::string& zp) - { - return match::name("dequantizelinear")( - match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())), - match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), - match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); - } - auto matcher() const { - auto dq1 = - match::arg(0)(skip_post_dq_ops(dequantizelinear_op("scale1", "zp1").bind("dq1"))); - auto dq2 = - match::arg(1)(skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2"))); + auto dq1 = match::arg(0)( + skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1"))); + auto dq2 = match::arg(1)( + skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2"))); return match::name(get_quantizable_op_names())(dq1, dq2); } @@ -231,7 +217,9 @@ struct match_find_quantizable_ops is_valid_qparam(zp1, out_lens, out_lens.size() - 2) and is_valid_qparam(scale2, out_lens, out_lens.size() - 1) and is_valid_qparam(zp2, out_lens, out_lens.size() - 1))) + { return; + } // This implementation supports both arguments being per-axis affine quantized // In practice, inputs are per-tensor affine and weights are per-axis symmetric diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 9320ed86f9f..9854db47c4c 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -178,6 +179,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, eliminate_identity{}, dead_code_elimination{}, + enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_fnuz{}), + enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}), simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, diff --git a/test/fp8_ocp_to_fnuz_test.cpp b/test/fp8_ocp_to_fnuz_test.cpp new file mode 100644 index 00000000000..58abb18bddc --- /dev/null +++ b/test/fp8_ocp_to_fnuz_test.cpp @@ -0,0 +1,229 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using migraphx::make_op; +using migraphx::shape; +using migraphx::fp8::fp8e4m3fnuz; + +void run_fp8_ocp_to_fnuz(migraphx::module& m) +{ + migraphx::run_passes(m, {migraphx::fp8_ocp_to_fnuz{}, migraphx::dead_code_elimination{}}); +} + +void run_simplify_qdq(migraphx::module& m) +{ + run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}}); +} + +void run_cse(migraphx::module& m) +{ + run_passes(m, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); +} + +void run_propagate_constant(migraphx::module& m, + const std::unordered_set& skip_ops = {}) +{ + migraphx::run_passes( + m, {migraphx::propagate_constant{skip_ops}, migraphx::dead_code_elimination{}}); +} + +auto bit_cast_and_handle_specials(migraphx::module& m, + const migraphx::instruction_ref x, + const migraphx::instruction_ref bits_0x80_lit, + const migraphx::instruction_ref bits_0x7f_lit, + const migraphx::instruction_ref bits_0xff_lit, + const migraphx::instruction_ref bits_0x00_lit) +{ + auto x_lens = x->get_shape().lens(); + auto cast_input = + m.add_instruction(make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x); + auto mb_bits_0x80_lit = + m.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit); + auto mb_bits_0x7f_lit = + m.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit); + auto mb_bits_0xff_lit = + m.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit); + auto mb_zero_lit = + m.add_instruction(make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit); + // negative zero in fp8e4m3fn to zero in fp8e4m3fnuz + // a == 0x80 ? 0x0 : a + auto is_neg_zero = m.add_instruction(make_op("equal"), cast_input, mb_bits_0x80_lit); + auto ret = m.add_instruction(make_op("where"), is_neg_zero, mb_zero_lit, cast_input); + + // positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz + // (a == 0x7f or a == 0xff) ? 0x80 : a + auto eq_0x7f = m.add_instruction(make_op("equal"), ret, mb_bits_0x7f_lit); + + auto eq_0xff = m.add_instruction(make_op("equal"), ret, mb_bits_0xff_lit); + + auto cond = m.add_instruction(make_op("logical_or"), eq_0x7f, eq_0xff); + ret = m.add_instruction(make_op("where"), cond, mb_bits_0x80_lit, ret); + return ret; +} + +auto cast_fp8_helper(migraphx::module& m, + const migraphx::instruction_ref dq_input, + const migraphx::instruction_ref dq_scale, + const migraphx::instruction_ref dq_zp) +{ + auto dq_input_lens = dq_input->get_shape().lens(); + std::vector bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())}; + std::vector bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())}; + std::vector bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())}; + std::vector bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())}; + auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80); + auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f); + auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff); + auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00); + + auto cast_input = bit_cast_and_handle_specials( + m, dq_input, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); + auto adj_zp = bit_cast_and_handle_specials( + m, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit); + + auto two_lit = m.add_literal(migraphx::literal{shape{dq_scale->get_shape().type()}, {2}}); + two_lit = m.add_instruction( + make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit); + auto adj_dq_scale = m.add_instruction(make_op("mul"), dq_scale, two_lit); + + return std::vector{cast_input, adj_dq_scale, adj_zp}; +} + +TEST_CASE(fp8_gemm_conversion) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {2, 3, 8, 8}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", {migraphx::shape::float_type, data_lens}); + auto b = m1.add_parameter("b", {migraphx::shape::float_type, data_lens}); + auto scale = m1.add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m1.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(m1, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(m1, "quantizelinear", b, scale, zero); + auto da = + add_quantize_op(m1, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(m1, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto dot = m1.add_instruction(migraphx::make_op("dot"), da, db); + m1.add_return({dot}); + } + run_fp8_ocp_to_fnuz(m1); + + // expected after fp8_ocp_to_fnuz + migraphx::module m2; + { + auto a = m2.add_parameter("a", {migraphx::shape::float_type, data_lens}); + auto b = m2.add_parameter("b", {migraphx::shape::float_type, data_lens}); + auto scale = m2.add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m2.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(m2, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(m2, "quantizelinear", b, scale, zero); + + auto outs_a = cast_fp8_helper(m2, qa, scale, zero); + auto adj_a = outs_a.at(0); + auto mb_scales_a = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(1)); + auto mb_zp_a = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_a.at(2)); + auto da = m2.add_instruction(make_op("dequantizelinear"), adj_a, mb_scales_a, mb_zp_a); + + auto outs_b = cast_fp8_helper(m2, qb, scale, zero); + auto adj_b = outs_b.at(0); + auto mb_scales_b = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(1)); + auto mb_zp_b = + m2.add_instruction(make_op("multibroadcast", {{"out_lens", data_lens}}), outs_b.at(2)); + auto db = m2.add_instruction(make_op("dequantizelinear"), adj_b, mb_scales_b, mb_zp_b); + + auto dot = m2.add_instruction(migraphx::make_op("dot"), da, db); + m2.add_return({dot}); + } + + EXPECT(m1 == m2); + + // expected after simplify_qdq + migraphx::module m3; + { + auto a = m3.add_parameter("a", {migraphx::shape::float_type, {2, 3, 8, 8}}); + auto b = m3.add_parameter("b", {migraphx::shape::float_type, {2, 3, 8, 8}}); + auto scale = m3.add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m3.add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(m3, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(m3, "quantizelinear", b, scale, zero); + + auto outs_a = cast_fp8_helper(m3, qa, qa->inputs().at(1), qa->inputs().at(2)); + auto outs_b = cast_fp8_helper(m3, qb, qb->inputs().at(1), qb->inputs().at(2)); + auto adj_qa = outs_a.at(0); + auto adj_scale_a = outs_a.at(1); + auto adj_qb = outs_b.at(0); + auto adj_scale_b = outs_b.at(1); + + auto dot = m3.add_instruction(migraphx::make_op("quant_dot"), adj_qa, adj_qb); + + auto out_scale = add_scale_mul(m3, adj_scale_a, adj_scale_b, 1, 1, dot->get_shape().lens()); + auto dq_out = add_quantize_op(m3, "dequantizelinear", dot, out_scale); + m3.add_return({dq_out}); + } + + run_simplify_qdq(m1); + // running propagate constant to simplify adjustments to literals + // could pass the test without, but a tedious amount of instructions to rearrange + run_propagate_constant(m1); + run_propagate_constant(m3); + run_cse(m1); + run_cse(m3); + EXPECT(m1 == m3); + m1.debug_print(); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/include/quantize_helpers.hpp b/test/include/quantize_helpers.hpp new file mode 100644 index 00000000000..43bde67199e --- /dev/null +++ b/test/include/quantize_helpers.hpp @@ -0,0 +1,73 @@ +#include +#include +#include +#include + +#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_QUANTIZE_HELPERS_HPP +#define MIGRAPHX_GUARD_TEST_INCLUDE_QUANTIZE_HELPERS_HPP + +inline migraphx::instruction_ref broadcast_scale(migraphx::module& m, + migraphx::instruction_ref scale, + const std::vector& out_lens, + std::size_t axis) +{ + if(scale->get_shape().lens() == out_lens) + return scale; + + migraphx::instruction_ref scale_mb; + auto scale_lens = scale->get_shape().lens(); + if(scale_lens.front() == 1 and scale_lens.size() == 1) + scale_mb = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), scale); + else + scale_mb = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}}), scale); + return scale_mb; +} + +inline migraphx::instruction_ref broadcast_shift(migraphx::module& m, + migraphx::instruction_ref shift, + const std::vector& out_lens) +{ + if(shift->get_shape().lens() == out_lens) + return shift; + return m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), shift); +} + +inline migraphx::instruction_ref add_scale_mul(migraphx::module& m, + migraphx::instruction_ref scale1, + migraphx::instruction_ref scale2, + std::size_t axis1, + std::size_t axis2, + const std::vector& out_lens) +{ + auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1); + auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2); + return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb); +} + +inline migraphx::instruction_ref add_quantize_op(migraphx::module& m, + const std::string& name, + migraphx::instruction_ref x, + migraphx::instruction_ref scale, + migraphx::instruction_ref shift, + std::size_t q_axis = 1) +{ + auto lens = x->get_shape().lens(); + auto scale_mb = broadcast_scale(m, scale, lens, q_axis); + auto shift_mb = broadcast_shift(m, shift, lens); + return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); +} + +inline migraphx::instruction_ref add_quantize_op(migraphx::module& m, + const std::string& name, + migraphx::instruction_ref x, + migraphx::instruction_ref scale, + std::size_t q_axis = 1) +{ + auto lens = x->get_shape().lens(); + auto scale_mb = broadcast_scale(m, scale, lens, q_axis); + return m.add_instruction(migraphx::make_op(name), x, scale_mb); +} + +#endif // MIGRAPHX_GUARD_TEST_INCLUDE_QUANTIZE_HELPERS_HPP diff --git a/test/ref/fp8_ocp_to_fnuz.cpp b/test/ref/fp8_ocp_to_fnuz.cpp new file mode 100644 index 00000000000..d0a00df9565 --- /dev/null +++ b/test/ref/fp8_ocp_to_fnuz.cpp @@ -0,0 +1,226 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +/** + * test that before and after the fp8_ocp_to_fnuz pass + * have equivalent results + */ + +void run_fp8_ocp_to_fnuz(migraphx::module& m) +{ + migraphx::run_passes(m, {migraphx::fp8_ocp_to_fnuz{}, migraphx::dead_code_elimination{}}); +} + +TEST_CASE(fp8_ocp_to_fnuz_gemm) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {2, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + + migraphx::program p1; + auto* m1 = p1.get_main_module(); + { + auto a = m1->add_parameter("a", data_shape); + auto b = m1->add_parameter("b", data_shape); + auto scale = m1->add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m1->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*m1, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(*m1, "quantizelinear", b, scale, zero); + auto da = + add_quantize_op(*m1, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*m1, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto dot = m1->add_instruction(migraphx::make_op("dot"), da, db); + m1->add_return({dot}); + } + + migraphx::program p2 = p1; + migraphx::module* m2 = p2.get_main_module(); + run_fp8_ocp_to_fnuz(*m2); + + p1.compile(migraphx::make_target("ref")); + p2.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + std::vector a_data = {20, -100, 100, 0.25}; + std::vector b_data = {28, 0.125, 2.5, 0.25}; + params["a"] = migraphx::argument(data_shape, a_data.data()); + params["b"] = migraphx::argument(data_shape, b_data.data()); + + auto result_1 = p1.eval({params}).back(); + auto result_2 = p2.eval({params}).back(); + std::vector results_vector_1(4); + std::vector results_vector_2(4); + result_1.visit([&](auto output) { results_vector_1.assign(output.begin(), output.end()); }); + result_2.visit([&](auto output) { results_vector_2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); +} + +TEST_CASE(fp8_ocp_to_fnuz_gemm_multi_scale) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {3, 3}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + migraphx::shape scales_shape{migraphx::shape::float_type, {3}}; + + migraphx::program p1; + auto* m1 = p1.get_main_module(); + { + auto a = m1->add_parameter("a", data_shape); + auto b = m1->add_parameter("b", data_shape); + auto scale1 = m1->add_literal(migraphx::generate_literal(scales_shape, 0)); + auto scale2 = m1->add_literal(0.4f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m1->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*m1, "quantizelinear", a, scale1, zero); + auto qb = add_quantize_op(*m1, "quantizelinear", b, scale2, zero); + auto da = + add_quantize_op(*m1, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*m1, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto dot = m1->add_instruction(migraphx::make_op("dot"), da, db); + m1->add_return({dot}); + } + + migraphx::program p2 = p1; + migraphx::module* m2 = p2.get_main_module(); + run_fp8_ocp_to_fnuz(*m2); + + p1.compile(migraphx::make_target("ref")); + p2.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + std::vector a_data = {20, -100, 100, 0.25, 0.3, 3.3, 5.0, -8.0, 63.0}; + std::vector b_data = {28, 0.125, 2.5, 0.25, 0.0582, -187, 0.716, 8.12, 1.87}; + params["a"] = migraphx::argument(data_shape, a_data.data()); + params["b"] = migraphx::argument(data_shape, b_data.data()); + + auto result_1 = p1.eval({params}).back(); + auto result_2 = p2.eval({params}).back(); + std::vector results_vector_1(9); + std::vector results_vector_2(9); + result_1.visit([&](auto output) { results_vector_1.assign(output.begin(), output.end()); }); + result_2.visit([&](auto output) { results_vector_2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); +} + +TEST_CASE(fp8_ocp_to_fnuz_conv) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + std::vector data_lens = {2, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + + migraphx::program p1; + auto* m1 = p1.get_main_module(); + { + std::vector a_data = { + 2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, + 0.80927712, -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, + 0.67726439, -0.65290606, 0.02345525, -0.33579525, 0.38901961, 1.05473483, + -1.31188095, 1.8963089, -0.07265259, 0.947339, 0.41949373, -0.70814759, + 0.25892952, 1.07311416, 1.2571274, -0.62318051, -0.19951548, -0.94232577, + -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158, 0.13900366, + 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101, + 0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, + -0.03024297, 1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, + 0.86956722, -0.40457946, 0.46691212, 1.29273605, 0.26464137, 0.22073045, + -1.02178168, 0.22163901, -1.84387338, 0.75522131, -0.45775682, -0.42241111, + -1.50944722, 1.07256448, -1.95876884, -0.28106022, 0.3341668, 2.13129425, + -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792, -2.06007552, + 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896, + 0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, + -0.68230027, -0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306}; + + std::vector b_data = { + 2.82721668e-02, 6.44195229e-02, 1.53499246e-02, 1.72468081e-01, -6.33238107e-02, + 9.49496776e-02, 1.40258059e-01, -7.92879611e-02, -1.29301161e-01, 3.11307609e-03, + -1.90624535e-01, 1.13238767e-01, -2.80647576e-02, 3.12882811e-02, -3.52091640e-02, + 3.33581865e-02, 6.43158704e-02, 7.40238279e-02, -1.00106120e-01, -9.56912562e-02, + 1.44342467e-01, 9.40258950e-02, 6.36333972e-02, 1.66158378e-03, -8.91554281e-02, + 2.58734226e-02, 1.70919895e-02, 1.78214177e-01, 8.84564668e-02, 8.98126513e-02, + -1.63809001e-01, 1.37802169e-01, 1.66439757e-01, -1.45631135e-02, 1.88469887e-04, + 4.76950556e-02, -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01, + 1.76608220e-01, -1.50728196e-01, 1.99946314e-02, -5.88052124e-02, 1.31612435e-01, + 1.61106288e-02, -1.35080189e-01, 1.49512306e-01, 3.86456847e-02, 1.29330024e-01, + -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02}; + + migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4, 4}}; + auto a = m1->add_literal(migraphx::literal{a_shape, a_data}); + + migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; + auto b = m1->add_literal(migraphx::literal{b_shape, b_data}); + auto scale = m1->add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + m1->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*m1, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(*m1, "quantizelinear", b, scale, zero); + auto da = + add_quantize_op(*m1, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*m1, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto conv_ins = m1->add_instruction(migraphx::make_op("convolution"), da, db); + m1->add_return({conv_ins}); + } + + migraphx::program p2 = p1; + migraphx::module* m2 = p2.get_main_module(); + run_fp8_ocp_to_fnuz(*m2); + + p1.compile(migraphx::make_target("ref")); + p2.compile(migraphx::make_target("ref")); + + auto result_1 = p1.eval({}).back(); + auto result_2 = p2.eval({}).back(); + std::vector results_vector_1(16); + std::vector results_vector_2(16); + result_1.visit([&](auto output) { results_vector_1.assign(output.begin(), output.end()); }); + result_2.visit([&](auto output) { results_vector_2.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_rms_range(results_vector_1, results_vector_2)); +} diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index c3c50cb4172..cef500fbfd3 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -45,75 +46,12 @@ void run_pass(migraphx::module& m) { run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}}); } + void run_cse(migraphx::module& m) { run_passes(m, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); } -migraphx::instruction_ref broadcast_scale(migraphx::module& m, - migraphx::instruction_ref scale, - const std::vector& out_lens, - std::size_t axis) -{ - if(scale->get_shape().lens() == out_lens) - return scale; - - migraphx::instruction_ref scale_mb; - auto scale_lens = scale->get_shape().lens(); - if(scale_lens.front() == 1 and scale_lens.size() == 1) - scale_mb = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), scale); - else - scale_mb = m.add_instruction( - migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}}), scale); - return scale_mb; -} - -migraphx::instruction_ref broadcast_shift(migraphx::module& m, - migraphx::instruction_ref shift, - const std::vector& out_lens) -{ - if(shift->get_shape().lens() == out_lens) - return shift; - return m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), shift); -} - -migraphx::instruction_ref add_quantize_op(migraphx::module& m, - const std::string& name, - migraphx::instruction_ref x, - migraphx::instruction_ref scale, - migraphx::instruction_ref shift, - std::size_t q_axis = 1) -{ - auto lens = x->get_shape().lens(); - auto scale_mb = broadcast_scale(m, scale, lens, q_axis); - auto shift_mb = broadcast_shift(m, shift, lens); - return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); -} - -migraphx::instruction_ref add_quantize_op(migraphx::module& m, - const std::string& name, - migraphx::instruction_ref x, - migraphx::instruction_ref scale, - std::size_t q_axis = 1) -{ - auto lens = x->get_shape().lens(); - auto scale_mb = broadcast_scale(m, scale, lens, q_axis); - return m.add_instruction(migraphx::make_op(name), x, scale_mb); -} - -migraphx::instruction_ref add_scale_mul(migraphx::module& m, - migraphx::instruction_ref scale1, - migraphx::instruction_ref scale2, - std::size_t axis1, - std::size_t axis2, - const std::vector& out_lens) -{ - auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1); - auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2); - return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb); -} - migraphx::instruction_ref init_zero_point(migraphx::module& m, migraphx::instruction_ref q_ins) { auto zp = m.add_literal(migraphx::literal{migraphx::shape{q_ins->get_shape().type()}, {0}}); diff --git a/test/verify/test_fp8_ocp_to_fnuz_gemm.cpp b/test/verify/test_fp8_ocp_to_fnuz_gemm.cpp new file mode 100644 index 00000000000..88fc9828034 --- /dev/null +++ b/test/verify/test_fp8_ocp_to_fnuz_gemm.cpp @@ -0,0 +1,60 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +struct test_fp8_ocp_to_fnuz_gemm : verify_program +{ + using fp8e4m3fn = migraphx::fp8::fp8e4m3fn; + using fp8e4m3fnuz = migraphx::fp8::fp8e4m3fnuz; + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector data_lens = {2, 2}; + migraphx::shape data_shape{migraphx::shape::float_type, data_lens}; + auto a = mm->add_parameter("a", data_shape); + auto b = mm->add_parameter("b", data_shape); + auto scale = mm->add_literal(0.5f); + std::vector data; + data.push_back(fp8e4m3fn{0.f}); + auto zero = + mm->add_literal(migraphx::shape{migraphx::shape::fp8e4m3fn_type, {1}, {0}}, data); + + auto qa = add_quantize_op(*mm, "quantizelinear", a, scale, zero); + auto qb = add_quantize_op(*mm, "quantizelinear", b, scale, zero); + auto da = + add_quantize_op(*mm, "dequantizelinear", qa, qa->inputs().at(1), qa->inputs().at(2)); + auto db = + add_quantize_op(*mm, "dequantizelinear", qb, qb->inputs().at(1), qb->inputs().at(2)); + auto dot = mm->add_instruction(migraphx::make_op("dot"), da, db); + mm->add_return({dot}); + return p; + } + std::string section() const { return "gemm"; } +};