diff --git a/docs/reference/py.rst b/docs/reference/py.rst index 4300d9e5a69..c68a2df0e54 100755 --- a/docs/reference/py.rst +++ b/docs/reference/py.rst @@ -314,6 +314,13 @@ program :type ins_names: list[str] +.. py:function:: autocast_fp8(prog) + + Auto-convert FP8 parameters and return values to Float for an MIGraphX program. + + :param program prog: Program to auto-convert parameters/return values. + + op -- .. py::class:: op(name, kwargs) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3e1e0e60d27..c20bc0a0a6c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -256,6 +256,7 @@ register_migraphx_ops( undefined unique unknown + unpack_int4 unsqueeze where ) diff --git a/src/include/migraphx/op/unpack_int4.hpp b/src/include/migraphx/op/unpack_int4.hpp new file mode 100644 index 00000000000..df7938ffea2 --- /dev/null +++ b/src/include/migraphx/op/unpack_int4.hpp @@ -0,0 +1,97 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_UNPACK_INT4_HPP +#define MIGRAPHX_GUARD_OPERATORS_UNPACK_INT4_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { +struct unpack_int4 +{ + int64_t axis = -1; + + std::string name() const { return "unpack_int4"; } + + value attributes() const + { + value normalize = value::object{}; + normalize["axis"] = value::array{normalize_attribute::include_min}; + return {{"normalize_axes", normalize}}; + } + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.axis, "axis")); + } + + migraphx::shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this}.same_dims().has(1); + auto in_shape = inputs.front(); + if(in_shape.type() != migraphx::shape::uint8_type) + { + MIGRAPHX_THROW("UNPACK_INT4: Only Unsigned Int8 type is supported for unpacking"); + } + auto new_lens = in_shape.lens(); + new_lens[axis] *= 2; + return {migraphx::shape::uint8_type, new_lens}; + } + + argument compute(const shape& output_shape, std::vector args) const + { + argument result{output_shape}; + auto in_shape = args.front().get_shape(); + auto input = args.at(0).get(); + auto output = result.get(); + par_for(in_shape.elements(), [&](auto i) { + auto data_idx = in_shape.multi(i); + auto out_data_multi_idx = data_idx; + out_data_multi_idx[axis] *= 2; + auto input_val = input[data_idx]; + // mask first 4 bits, packing is assumed to be little endian + output[out_data_multi_idx] = uint8_t(0x0F) & input_val; + out_data_multi_idx[axis] += 1; + output[out_data_multi_idx] = input_val >> 4; // 
NOLINT(hicpp-signed-bitwise) + }); + return result; + } +}; +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index e58fdf9a973..d01b9d15b1d 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,7 @@ #include #include #include +#include #ifdef HAVE_GPU #include #endif @@ -581,6 +583,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) py::arg("t"), py::arg("calibration") = std::vector{}, py::arg("ins_names") = std::unordered_set{"dot", "convolution"}); + m.def( + "autocast_fp8", + [](migraphx::program& prog) { + migraphx::run_passes(*prog.get_main_module(), {migraphx::autocast_fp8_pass{}}); + }, + "Auto-convert FP8 parameters and return values to Float for MIGraphX Program", + py::arg("prog")); #ifdef HAVE_GPU m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); diff --git a/src/targets/gpu/rocblas.cpp b/src/targets/gpu/rocblas.cpp index 9d5266e6176..1b37f08e1ed 100644 --- a/src/targets/gpu/rocblas.cpp +++ b/src/targets/gpu/rocblas.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,8 @@ namespace gpu { rocblas_handle_ptr create_rocblas_handle_ptr() { + // add a call to rocblas_initialize() to workaround a rocblas bug SWDEV-438929 + rocblas_initialize(); rocblas_handle handle; rocblas_create_handle(&handle); return rocblas_handle_ptr{handle}; diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 7eaae50f3a2..f5ae4259fc2 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -2325,6 +2325,53 @@ TEST_CASE(pack_int4_odd_lengths) throws_shape(migraphx::make_op("pack_int4", {{"axis", 0}}), input); } +TEST_CASE(unpack_int4) +{ + migraphx::shape input{migraphx::shape::uint8_type, {1, 4, 16, 8}}; + migraphx::shape output{migraphx::shape::uint8_type, {1, 4, 16, 16}}; + expect_shape(output, migraphx::make_op("unpack_int4"), input); +} + +TEST_CASE(unpack_int4_axis1) +{ + migraphx::shape input{migraphx::shape::uint8_type, {1, 2, 16, 16}}; + migraphx::shape output{migraphx::shape::uint8_type, {1, 4, 16, 16}}; + expect_shape(output, migraphx::make_op("unpack_int4", {{"axis", 1}}), input); +} + +TEST_CASE(unpack_int4_axis2) +{ + migraphx::shape input{migraphx::shape::uint8_type, {1, 2, 16, 16}}; + migraphx::shape output{migraphx::shape::uint8_type, {1, 4, 16, 16}}; + expect_shape(output, migraphx::make_op("unpack_int4", {{"axis", -3}}), input); +} + +TEST_CASE(unpack_int4_invalid_axis) +{ + migraphx::shape input{migraphx::shape::uint8_type, {1, 4, 16, 16}}; + throws_shape(migraphx::make_op("unpack_int4", {{"axis", 4}}), input); +} + +TEST_CASE(unpack_int4_nonstandard) +{ + migraphx::shape input{migraphx::shape::uint8_type, {1, 16, 16, 4}, {1024, 16, 1, 256}}; + migraphx::shape output{migraphx::shape::uint8_type, {1, 32, 16, 4}}; + expect_shape(output, migraphx::make_op("unpack_int4", {{"axis", 1}}), input); +} + +TEST_CASE(unpack_int4_invalid_dtype) +{ + 
migraphx::shape input{migraphx::shape::float_type, {1, 4, 16, 16}}; + throws_shape(migraphx::make_op("unpack_int4", {{"axis", 0}}), input); +} + +TEST_CASE(unpack_int4_odd_lengths) +{ + migraphx::shape input{migraphx::shape::uint8_type, {3, 4, 16, 16}}; + migraphx::shape output{migraphx::shape::uint8_type, {6, 4, 16, 16}}; + expect_shape(output, migraphx::make_op("unpack_int4", {{"axis", 0}}), input); +} + TEST_CASE(pad_shape0) { migraphx::shape input{migraphx::shape::float_type, {2, 3, 3, 3}}; diff --git a/test/py/CMakeLists.txt b/test/py/CMakeLists.txt index 5df5f9d82b2..f2bfd2e8555 100755 --- a/test/py/CMakeLists.txt +++ b/test/py/CMakeLists.txt @@ -103,6 +103,7 @@ add_py_test(op test_op.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(shape test_shape.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(module_construct test_module_construct.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(literal test_literal.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) +add_py_test(autocast_fp8 test_autocast_fp8.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) if(MIGRAPHX_ENABLE_GPU) add_py_test(gpu_offload test_gpu_offload.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR}) diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 89c95f0b58e..a6956c6199e 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -66,10 +66,6 @@ def assert_similar_outputs(cls, ref_outputs, outputs, rtol, atol): def disabled_tests_onnx_1_7_0(backend_test): # fails # from OnnxBackendNodeModelTest - backend_test.exclude(r'test_logsoftmax_axis_0_cpu') - backend_test.exclude(r'test_logsoftmax_axis_1_cpu') - backend_test.exclude(r'test_logsoftmax_default_axis_cpu') - backend_test.exclude(r'test_maxpool_2d_dilations_cpu') backend_test.exclude(r'test_maxpool_with_argmax_2d_precomputed_pads_cpu') backend_test.exclude( 
r'test_maxpool_with_argmax_2d_precomputed_strides_cpu') @@ -83,9 +79,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_nonmaxsuppression_two_batches_cpu') backend_test.exclude(r'test_nonmaxsuppression_two_classes_cpu') backend_test.exclude(r'test_nonzero_example_cpu') - backend_test.exclude(r'test_softmax_axis_0_cpu') - backend_test.exclude(r'test_softmax_axis_1_cpu') - backend_test.exclude(r'test_softmax_default_axis_cpu') # from OnnxBackendPyTorchConvertedModelTest backend_test.exclude(r'test_ConvTranspose2d_cpu') @@ -129,56 +122,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu') backend_test.exclude(r'test_mod_mixed_sign_int32_cpu') backend_test.exclude(r'test_mod_mixed_sign_int8_cpu') - backend_test.exclude( - r'test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NC_cpu') - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1_cpu') - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1_weight_cpu') - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_cpu') - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_cpu' - ) - backend_test.exclude( - 
r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_cpu' - ) - backend_test.exclude( - r'test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_cpu' - ) backend_test.exclude(r'test_qlinearmatmul_2D_cpu') backend_test.exclude(r'test_qlinearmatmul_3D_cpu') backend_test.exclude(r'test_range_float_type_positive_delta_expanded_cpu') @@ -197,8 +140,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude( r'test_resize_downsample_sizes_linear_pytorch_half_pixel_cpu') backend_test.exclude(r'test_resize_downsample_sizes_nearest_cpu') - backend_test.exclude( - r'test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn_cpu') backend_test.exclude(r'test_resize_tf_crop_and_resize_cpu') backend_test.exclude( r'test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_cpu') @@ -230,159 +171,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_slice_neg_steps_cpu') backend_test.exclude(r'test_slice_negative_axes_cpu') backend_test.exclude(r'test_slice_start_out_of_bounds_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded_cpu' - ) - backend_test.exclude( - 
r'test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded_cpu' - ) - backend_test.exclude( - 
r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_expanded_cpu' - ) - backend_test.exclude(r'test_softmax_cross_entropy_mean_3d_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_mean_3d_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_mean_3d_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_3d_log_prob_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_mean_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_mean_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_mean_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_log_prob_expanded_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_expanded_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob_cpu') - backend_test.exclude( - 
r'test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob_expanded_cpu' - ) - backend_test.exclude(r'test_softmax_cross_entropy_mean_weight_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_expanded_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_3d_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_3d_expanded_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_4d_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_4d_expanded_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_expanded_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_ignore_index_log_prob_expanded_cpu' - ) - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_mean_weight_log_prob_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_none_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_none_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_none_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_none_log_prob_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_none_weights_cpu') - backend_test.exclude( - 
r'test_softmax_cross_entropy_none_weights_expanded_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_none_weights_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_none_weights_log_prob_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_sum_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_sum_expanded_cpu') - backend_test.exclude(r'test_softmax_cross_entropy_sum_log_prob_cpu') - backend_test.exclude( - r'test_softmax_cross_entropy_sum_log_prob_expanded_cpu') - backend_test.exclude(r'test_split_zero_size_splits_cpu') backend_test.exclude( r'test_strnormalizer_export_monday_casesensintive_lower_cpu') backend_test.exclude( @@ -473,7 +261,6 @@ def disabled_tests_onnx_1_8_0(backend_test): backend_test.exclude(r'test_nllloss_NCd1d2d3_sum_weight_high_ii_cpu') backend_test.exclude(r'test_nllloss_NCd1d2d3d4d5_mean_weight_cpu') backend_test.exclude(r'test_nllloss_NCd1d2d3d4d5_none_no_weight_cpu') - backend_test.exclude(r'test_reduce_sum_empty_axes_input_noop_random_cpu') backend_test.exclude(r'test_sce_NCd1_mean_weight_negative_ii_cpu') backend_test.exclude(r'test_sce_NCd1_mean_weight_negative_ii_expanded_cpu') backend_test.exclude(r'test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu') @@ -553,9 +340,6 @@ def disabled_tests_onnx_1_8_0(backend_test): backend_test.exclude(r'test_sce_sum_log_prob_expanded_cpu') backend_test.exclude(r'test_sequence_insert_at_back_cpu') backend_test.exclude(r'test_sequence_insert_at_front_cpu') - backend_test.exclude(r'test_split_variable_parts_1d_cpu') - backend_test.exclude(r'test_split_variable_parts_2d_cpu') - backend_test.exclude(r'test_split_variable_parts_default_axis_cpu') def disabled_tests_onnx_1_9_0(backend_test): @@ -606,18 +390,10 @@ def disabled_tests_onnx_1_10_0(backend_test): backend_test.exclude(r'test_castlike_FLOAT_to_STRING_expanded_cpu') backend_test.exclude(r'test_castlike_STRING_to_FLOAT_cpu') backend_test.exclude(r'test_castlike_STRING_to_FLOAT_expanded_cpu') - 
backend_test.exclude(r'test_optional_get_element_cpu') backend_test.exclude(r'test_optional_get_element_sequence_cpu') - backend_test.exclude(r'test_optional_has_element_cpu') - backend_test.exclude(r'test_optional_has_element_empty_cpu') def disabled_tests_onnx_1_11_0(backend_test): - # fails - # from OnnxBackendNodeModelTest - backend_test.exclude(r'test_roialign_aligned_false_cpu') - backend_test.exclude(r'test_roialign_aligned_true_cpu') - # errors # from OnnxBackendNodeModelTest backend_test.exclude(r'test_gridsample_aligncorners_true_cpu') @@ -651,28 +427,6 @@ def disabled_tests_onnx_1_12_0(backend_test): backend_test.exclude(r'test_hannwindow_expanded_cpu') backend_test.exclude(r'test_hannwindow_symmetric_cpu') backend_test.exclude(r'test_hannwindow_symmetric_expanded_cpu') - backend_test.exclude(r'test_layer_normalization_2d_axis0_cpu') - backend_test.exclude(r'test_layer_normalization_2d_axis1_cpu') - backend_test.exclude(r'test_layer_normalization_2d_axis_negative_1_cpu') - backend_test.exclude(r'test_layer_normalization_2d_axis_negative_2_cpu') - backend_test.exclude(r'test_layer_normalization_3d_axis0_epsilon_cpu') - backend_test.exclude(r'test_layer_normalization_3d_axis1_epsilon_cpu') - backend_test.exclude(r'test_layer_normalization_3d_axis2_epsilon_cpu') - backend_test.exclude( - r'test_layer_normalization_3d_axis_negative_1_epsilon_cpu') - backend_test.exclude( - r'test_layer_normalization_3d_axis_negative_2_epsilon_cpu') - backend_test.exclude( - r'test_layer_normalization_3d_axis_negative_3_epsilon_cpu') - backend_test.exclude(r'test_layer_normalization_4d_axis0_cpu') - backend_test.exclude(r'test_layer_normalization_4d_axis1_cpu') - backend_test.exclude(r'test_layer_normalization_4d_axis2_cpu') - backend_test.exclude(r'test_layer_normalization_4d_axis3_cpu') - backend_test.exclude(r'test_layer_normalization_4d_axis_negative_1_cpu') - backend_test.exclude(r'test_layer_normalization_4d_axis_negative_2_cpu') - 
backend_test.exclude(r'test_layer_normalization_4d_axis_negative_3_cpu') - backend_test.exclude(r'test_layer_normalization_4d_axis_negative_4_cpu') - backend_test.exclude(r'test_layer_normalization_default_axis_cpu') backend_test.exclude(r'test_melweightmatrix_cpu') backend_test.exclude(r'test_sequence_map_add_1_sequence_1_tensor_cpu') backend_test.exclude( @@ -759,10 +513,6 @@ def disabled_tests_onnx_1_13_0(backend_test): backend_test.exclude(r'test_col2im_pads_cpu') backend_test.exclude(r'test_col2im_strides_cpu') backend_test.exclude(r'test_constant_pad_axes_cpu') - backend_test.exclude(r'test_group_normalization_epsilon_cpu') - backend_test.exclude(r'test_group_normalization_epsilon_expanded_cpu') - backend_test.exclude(r'test_group_normalization_example_cpu') - backend_test.exclude(r'test_group_normalization_example_expanded_cpu') backend_test.exclude(r'test_mish_cpu') backend_test.exclude(r'test_optional_get_element_optional_sequence_cpu') backend_test.exclude(r'test_optional_get_element_optional_tensor_cpu') @@ -778,8 +528,6 @@ def disabled_tests_onnx_1_13_0(backend_test): backend_test.exclude(r'test_optional_has_element_empty_optional_input_cpu') backend_test.exclude(r'test_optional_has_element_optional_input_cpu') backend_test.exclude(r'test_optional_has_element_tensor_input_cpu') - backend_test.exclude(r'test_prelu_broadcast_expanded_cpu') - backend_test.exclude(r'test_prelu_example_expanded_cpu') backend_test.exclude(r'test_reduce_l1_default_axes_keepdims_example_cpu') backend_test.exclude(r'test_reduce_l1_default_axes_keepdims_random_cpu') backend_test.exclude(r'test_reduce_l2_default_axes_keepdims_example_cpu') @@ -810,11 +558,6 @@ def disabled_tests_onnx_1_13_0(backend_test): def disabled_tests_onnx_1_14_0(backend_test): - # fails - # from OnnxBackendNodeModelTest - backend_test.exclude(r'test_averagepool_2d_dilations_cpu') - backend_test.exclude(r'test_roialign_mode_max_cpu') - # errors # from OnnxBackendNodeModelTest 
backend_test.exclude(r'test_basic_deform_conv_with_padding_cpu') @@ -839,10 +582,6 @@ def disabled_tests_onnx_1_14_0(backend_test): r'test_resize_downsample_scales_linear_half_pixel_symmetric_cpu') backend_test.exclude( r'test_resize_upsample_scales_linear_half_pixel_symmetric_cpu') - - # The following tests fail due to the CastLike operator being unsupported - backend_test.exclude(r'test_softplus_example_expanded_ver18_cpu') - backend_test.exclude(r'test_softplus_expanded_ver18_cpu') backend_test.exclude(r'test_split_to_sequence_1_cpu') backend_test.exclude(r'test_split_to_sequence_2_cpu') backend_test.exclude(r'test_split_to_sequence_nokeepdims_cpu') @@ -1270,6 +1009,7 @@ def create_backend_test(testname=None, target_device=None): backend_test.include(r'.*test_sequence_map.*') backend_test.include(r'.*test_shrink.*') backend_test.include(r'.*test_[sS]oftmax.*') + backend_test.include(r'.*test_[sS]oftmin.*') backend_test.include(r'.*test_[sS]oftplus.*') backend_test.include(r'.*test_[sS]oftsign.*') backend_test.include(r'.*test_sce.*') @@ -1334,14 +1074,11 @@ def create_backend_test(testname=None, target_device=None): # backend_test.include(r'.*test_momentum.*') # backend_test.include(r'.*test_nesterov_momentum.*') # backend_test.include(r'.*test_training_dropout.*') - # backend_test.include(r'.*test_Softmin.*') # Exclude failing tests # from OnnxBackendRealModelTest backend_test.exclude(r'test_inception_v1_cpu') - backend_test.exclude(r'test_resnet50_cpu') - backend_test.exclude(r'test_squeezenet_cpu') # PRelu OnnxBackendPyTorchConvertedModelTest has wrong dim for broadcasting backend_test.exclude(r'[a-z,_]*PReLU_[0-9]d_multiparam[a-z,_]*') diff --git a/test/py/test_autocast_fp8.py b/test/py/test_autocast_fp8.py new file mode 100644 index 00000000000..e084ac6bdf2 --- /dev/null +++ b/test/py/test_autocast_fp8.py @@ -0,0 +1,77 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 
2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+##################################################################################### +import migraphx + + +def test_autocast_fp8_1(): + p1 = migraphx.program() + m1 = p1.get_main_module() + x = m1.add_parameter("x", shape=migraphx.shape(type='fp8e4m3fnuz_type')) + y = m1.add_parameter("y", shape=migraphx.shape(type='fp8e4m3fnuz_type')) + sum_op = m1.add_instruction(migraphx.op("add"), [x, y]) + m1.add_return([sum_op]) + + m1 = migraphx.autocast_fp8(p1) + + p2 = migraphx.program() + m2 = p2.get_main_module() + y_fp32 = m2.add_parameter("y", shape=migraphx.shape(type='float_type')) + x_fp32 = m2.add_parameter("x", shape=migraphx.shape(type='float_type')) + + y_fp8 = m2.add_instruction( + migraphx.op("convert", + target_type=int(migraphx.shape.type_t.fp8e4m3fnuz_type)), + [y_fp32]) + x_fp8 = m2.add_instruction( + migraphx.op("convert", + target_type=int(migraphx.shape.type_t.fp8e4m3fnuz_type)), + [x_fp32]) + + sum_fp8 = m2.add_instruction(migraphx.op("add"), [x_fp8, y_fp8]) + sum_fp32 = m2.add_instruction( + migraphx.op("convert", + target_type=int(migraphx.shape.type_t.float_type)), + [sum_fp8]) + + m2.add_return([sum_fp32]) + assert p1 == p2 + + +def test_autocast_fp8_2(): + p1 = migraphx.program() + m1 = p1.get_main_module() + x = m1.add_parameter("x", shape=migraphx.shape(type='float_type')) + y = m1.add_parameter("y", shape=migraphx.shape(type='float_type')) + sum = m1.add_instruction(migraphx.op("add"), [x, y]) + m1.add_return([sum]) + + m1 = migraphx.autocast_fp8(p1) + + p2 = p1 + assert p1 == p2 + + +if __name__ == "__main__": + test_autocast_fp8_1() + test_autocast_fp8_2() diff --git a/test/ref/pack_unpack_int4.cpp b/test/ref/pack_unpack_int4.cpp new file mode 100644 index 00000000000..e79555855ef --- /dev/null +++ b/test/ref/pack_unpack_int4.cpp @@ -0,0 +1,135 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
+#include // NOTE(review): each #include in this patch text has no header name - looks stripped
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+TEST_CASE(pack_unpack_int4) // pack_int4 then unpack_int4 (no axis given) on a {2, 2} uint8 literal
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {2, 2}};
+    auto l0 = mm->add_literal(migraphx::literal{s, {0x0A, 0x0B, 0x0C, 0x0D}});
+    auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4"), l0);
+    mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins);
+    p.compile(migraphx::make_target("ref")); // built for the "ref" target
+    auto result = p.eval({}).back();         // evaluated with no parameters; last result is used
+    std::vector results_vector(4); // NOTE(review): no element type - looks stripped; confirm
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; // identical to the input literal
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(pack_unpack_int4_transposed) // same round trip, shape built with an extra {1, 2} list
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {2, 2}, {1, 2}}; // presumably lens, then strides
+    auto l0 = mm->add_literal(migraphx::literal{s, {0x0A, 0x0B, 0x0C, 0x0D}});
+    auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4"), l0);
+    mm->add_instruction(migraphx::make_op("unpack_int4"), pack_ins);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(4);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; // identical to the input literal
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(pack_multibroadcast_unpack_int4) // pack, multibroadcast the packed value, then unpack
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {4}, {1}};
+    auto l0 = mm->add_literal(migraphx::literal{s, {0x0A, 0x0B, 0x0C, 0x0D}});
+    auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4"), l0);
+    auto mb_pack = // packed result broadcast to out_lens {4, 2} before unpacking
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 2}}}), pack_ins);
+    mm->add_instruction(migraphx::make_op("unpack_int4"), mb_pack);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(16);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, // expected: the four input values repeated four times (16 entries)
+                     0x0B,
+                     0x0C,
+                     0x0D,
+                     0x0A,
+                     0x0B,
+                     0x0C,
+                     0x0D,
+                     0x0A,
+                     0x0B,
+                     0x0C,
+                     0x0D,
+                     0x0A,
+                     0x0B,
+                     0x0C,
+                     0x0D};
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(pack_unpack_int4_axis_0) // round trip with "axis" set to 0 on both ops
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {2, 2}};
+    auto l0 = mm->add_literal(migraphx::literal{s, {0x0A, 0x0B, 0x0C, 0x0D}});
+    auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", 0}}), l0);
+    mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", 0}}), pack_ins);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(4);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; // identical to the input literal
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(pack_unpack_int4_nchw) // round trip on a {1, 2, 4, 4} literal with "axis" set to -1
+{
+    // test with literal values such as 0x18 whose upper 4 bits are dropped by the round trip;
+    // ideally the quantizer should produce values that fit into 4 bits.
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 4}};
+    auto l0 = mm->add_literal(
+        migraphx::literal{s, {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
+                              0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
+                              0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F}});
+    auto pack_ins = mm->add_instruction(migraphx::make_op("pack_int4", {{"axis", -1}}), l0);
+    mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), pack_ins);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(32);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
+                     0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
+                     0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F}; // 0x1N -> 0x0N
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
diff --git a/test/ref/unpack_int4.cpp b/test/ref/unpack_int4.cpp
new file mode 100644
index 00000000000..c32e61ff2ea
--- /dev/null
+++ b/test/ref/unpack_int4.cpp
@@ -0,0 +1,127 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include // NOTE(review): each #include in this patch text has no header name - looks stripped
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+TEST_CASE(unpack_int4) // unpack_int4 with no axis given on a {2, 1} uint8 literal
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {2, 1}};
+    auto l0 = mm->add_literal(migraphx::literal{s, {0xBA, 0xDC}});
+    mm->add_instruction(migraphx::make_op("unpack_int4"), l0);
+    p.compile(migraphx::make_target("ref")); // built for the "ref" target
+    auto result = p.eval({}).back();         // evaluated with no parameters; last result is used
+    std::vector results_vector(4); // NOTE(review): no element type - looks stripped; confirm
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; // per byte: low nibble first, then high nibble
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(unpack_int4_transposed) // unpack a {2, 2} literal whose shape has an extra {1, 2} list
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {2, 2}, {1, 2}}; // presumably lens, then strides
+    auto l0 = mm->add_literal(migraphx::literal{s, {0x1A, 0x2B, 0x3C, 0x4D}});
+    mm->add_instruction(migraphx::make_op("unpack_int4"), l0);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(8);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, 0x01, 0x0B, 0x02, 0x0C, 0x03, 0x0D, 0x04}; // e.g. 0x1A -> 0x0A, 0x01
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(unpack_int4_broadcasted) // multibroadcast a 4-element literal to {4, 4}, then unpack
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {4}, {1}};
+    auto l0 = mm->add_literal(migraphx::literal{s, {0x1A, 0x2B, 0x3C, 0x4D}});
+    auto l0b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 4}}}), l0);
+    mm->add_instruction(migraphx::make_op("unpack_int4"), l0b);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(32);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, 0x01, 0x0B, 0x02, 0x0C, 0x03, 0x0D, 0x04, 0x0A, 0x01, 0x0B,
+                     0x02, 0x0C, 0x03, 0x0D, 0x04, 0x0A, 0x01, 0x0B, 0x02, 0x0C, 0x03,
+                     0x0D, 0x04, 0x0A, 0x01, 0x0B, 0x02, 0x0C, 0x03, 0x0D, 0x04}; // 8 values x 4
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(unpack_int4_axis_0) // unpack a {1, 2} literal with "axis" set to 0
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {1, 2}};
+    auto l0 = mm->add_literal(migraphx::literal{s, {0xCA, 0xDB}});
+    mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", 0}}), l0);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(4);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x0A, 0x0B, 0x0C, 0x0D}; // both low nibbles (A, B), then both high (C, D)
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}
+
+TEST_CASE(unpack_int4_nchw) // unpack a {1, 2, 4, 2} literal with "axis" set to -1
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::uint8_type, {1, 2, 4, 2}};
+    auto l0 = mm->add_literal(migraphx::literal{s, // 16 bytes; the same 8 bytes listed twice
+                                                {0x10,
+                                                 0x32,
+                                                 0x54,
+                                                 0x76,
+                                                 0x98,
+                                                 0xBA,
+                                                 0xDC,
+                                                 0xFE,
+                                                 0x10,
+                                                 0x32,
+                                                 0x54,
+                                                 0x76,
+                                                 0x98,
+                                                 0xBA,
+                                                 0xDC,
+                                                 0xFE}});
+    mm->add_instruction(migraphx::make_op("unpack_int4", {{"axis", -1}}), l0);
+    p.compile(migraphx::make_target("ref"));
+    auto result = p.eval({}).back();
+    std::vector results_vector(32);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector gold{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
+                     0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
+                     0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F}; // 0x00..0x0F x 2
+    EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
+}