From 6e86734d12618328235748b2d54c9de1e3f203c7 Mon Sep 17 00:00:00 2001
From: kahmed10 <15948690+kahmed10@users.noreply.github.com>
Date: Wed, 18 Oct 2023 21:08:47 -0500
Subject: [PATCH 01/12] update script when offload copy is disabled (#2348)

---
 tools/accuracy/accuracy_checker.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tools/accuracy/accuracy_checker.py b/tools/accuracy/accuracy_checker.py
index d368ca2a29e..8752bbe7f78 100644
--- a/tools/accuracy/accuracy_checker.py
+++ b/tools/accuracy/accuracy_checker.py
@@ -220,10 +220,16 @@ def main():
         else:
             test_input = np.zeros(in_shape).astype(get_np_datatype(in_type))
         test_inputs[name] = test_input
-        params[name] = migraphx.argument(test_input)
+        migraphx_arg = migraphx.argument(test_input)
+        if not args.offload_copy:
+            migraphx_arg = migraphx.to_gpu(migraphx_arg)
+        params[name] = migraphx_arg
 
     if not args.ort_run:
-        pred_migx = np.array(model.run(params)[-1])
+        if not args.offload_copy:
+            pred_migx = np.array(migraphx.from_gpu(model.run(params)[-1]))
+        else:
+            pred_migx = np.array(model.run(params)[-1])
 
     if use_onnx:
         sess_op = ort.SessionOptions()
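With offload copy disabled, host/device transfers become the caller's job, which is what the change above adds to the checker. A minimal sketch of that calling pattern — only migraphx.argument, migraphx.to_gpu, migraphx.from_gpu and model.run come from the patch; the parse/compile calls, dtype and model path are illustrative assumptions:

    import migraphx
    import numpy as np

    model = migraphx.parse_onnx("model.onnx")  # hypothetical model file
    model.compile(migraphx.get_target("gpu"), offload_copy=False)

    params = {}
    for name, shape in model.get_parameter_shapes().items():
        arg = migraphx.argument(np.zeros(shape.lens(), dtype=np.float32))
        params[name] = migraphx.to_gpu(arg)  # inputs must be copied to the device by hand
    # outputs also live on the device, so copy the last one back before reading it
    out = np.array(migraphx.from_gpu(model.run(params)[-1]))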
From 581b1b5f2953feb6d8e9dcbb15d7b4e414097ae4 Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav@users.noreply.github.com>
Date: Thu, 19 Oct 2023 10:25:00 -0400
Subject: [PATCH 02/12] Add flag to accept non-uniform WG sizes (#2167)

* Disable -Wunsafe-buffer-usage when compiling gpu code
---
 src/targets/gpu/compile_hip_code_object.cpp | 26 +++++---
 .../include/migraphx/kernels/index.hpp      | 63 ++++++++++---------
 2 files changed, 52 insertions(+), 37 deletions(-)

diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp
index 74aa9d24138..d2c7dfc8fda 100644
--- a/src/targets/gpu/compile_hip_code_object.cpp
+++ b/src/targets/gpu/compile_hip_code_object.cpp
@@ -139,6 +139,12 @@ void hip_compile_options::set_launch_params(
     global = compute_global(local);
 }
 
+static bool hip_accept_non_uniform_wg()
+{
+    static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
+    return non_uniform_wg;
+}
+
 std::function<std::size_t(std::size_t)>
 compute_global_for(context& ctx, std::size_t n, std::size_t over)
 {
@@ -146,13 +152,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
     std::size_t max_global = ctx.get_current_device().get_cu_count() *
                              ctx.get_current_device().get_max_workitems_per_cu();
     return [n, over, max_global](std::size_t local) {
-        // hip requires global workitems to be a multiple of local workitems. It may degrade
-        // performance.
-        // [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
-        // https://reviews.llvm.org/D155213
-        std::size_t num_elements = ((n + local - 1) / local) * local;
-        std::size_t groups       = (num_elements + local - 1) / local;
-        std::size_t max_blocks   = max_global / local;
-        std::size_t nglobal      = std::min(max_blocks * over, groups) * local;
+        std::size_t num_elements = n;
+        if(not hip_accept_non_uniform_wg())
+        {
+            num_elements = (1 + (n - 1) / local) * local;
+        }
+        std::size_t groups     = 1 + (num_elements - 1) / local;
+        std::size_t max_blocks = max_global / local;
+        std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
         return std::min(nglobal, num_elements);
     };
 }
@@ -183,6 +190,11 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
         generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
     srcs.emplace_back("args.hpp", args_hpp);
 
+    if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
+        options.params += " -fno-offload-uniform-block";
+    else
+        assert(options.global % options.local == 0);
+
     options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
     options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
     options.params += " " + join_strings(compiler_warnings(), " ");
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
index beaf645c38a..a015e02e964 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
@@ -31,6 +31,14 @@
 #include
 #include
 
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wreserved-identifier"
+extern "C" __device__ size_t __ockl_get_enqueued_local_size(uint); // NOLINT
+extern "C" __device__ size_t __ockl_get_local_size(uint);          // NOLINT
+#pragma clang diagnostic pop
+#endif
+
 namespace migraphx {
 
 #if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
@@ -45,43 +53,37 @@ inline __device__ __attribute__((const)) index_int compute_global_size()
 // This actually works even when global is not divisible by local size.
 // This doesn't actually do a multiplication. Instead it calls a device
 // function to get the global size, which is why it works.
-    return blockDim.x * gridDim.x; // NOLINT
 #endif
 }
 
-// We cant just use blockDim.x to get the local size since its broken on hip
-// when global is not divisible by local size. In this case, we calulate the
-// size for the last group.
+#ifdef MIGRAPHX_NGROUP
+// If global is divisible by local then local can be a const
+#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
+#define MIGRAPHX_HAS_CONST_LOCAL 1
+#endif
+#endif
+
 inline __device__ __attribute__((const)) index_int compute_local_size()
 {
-#ifdef MIGRAPHX_NLOCAL
-    const auto nlocal = MIGRAPHX_NLOCAL;
-#else
-    const auto nlocal = blockDim.x; // NOLINT
-#endif
-#ifdef MIGRAPHX_NGROUP
-    const auto ngroup = MIGRAPHX_NGROUP;
+#ifdef MIGRAPHX_HAS_CONST_LOCAL
+    return MIGRAPHX_NLOCAL;
 #else
-    const auto ngroup = gridDim.x; // NOLINT
+    // Returns block size. For the non-uniform block it returns the size of the non-uniform block.
+    return __ockl_get_local_size(0); // NOLINT
 #endif
-    const auto group_id = blockIdx.x; // NOLINT
-    const auto nglobal  = compute_global_size();
-    if(group_id == ngroup - 1)
-    {
-        return 1 + (nglobal - 1) % nlocal;
-    }
-    else
-    {
-        return nlocal; // NOLINT
-    }
 }
 
-#ifdef MIGRAPHX_NGROUP
-// If global is divisible by local then local can be a const
-#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
-#define MIGRAPHX_HAS_CONST_LOCAL 1
-#endif
+inline __device__ __attribute__((const)) index_int compute_max_local_size()
+{
+#ifdef MIGRAPHX_NLOCAL
+    return MIGRAPHX_NLOCAL;
+#else
+    // Returns the block size. When the workgroup has a non-uniform block, this returns the size
+    // of the uniform block.
+    return __ockl_get_enqueued_local_size(0); // NOLINT
 #endif
+}
 
 struct index
 {
@@ -126,8 +128,8 @@ struct index
 #else
     __device__ index_int max_nlocal() const
     {
-        MIGRAPHX_ASSERT(blockDim.x > 0);
-        return blockDim.x;
+        MIGRAPHX_ASSERT(compute_max_local_size() > 0);
+        return compute_max_local_size();
     }
 #endif
 
@@ -249,7 +251,8 @@ struct index
 #endif
 inline __device__ __attribute__((const)) index make_index()
 {
-    return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
+    return index{
+        blockIdx.x * compute_max_local_size() + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
 }
 
 } // namespace migraphx
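In effect, the old lambda always padded the element count up to a multiple of the workgroup size, while the new one keeps the exact count whenever the compiler accepts non-uniform workgroups. A small sketch of that launch-size arithmetic (plain Python with illustrative numbers, not MIGraphX API):

    def compute_global(n, local, over, max_global, accept_non_uniform):
        # pad n up to a multiple of local only when non-uniform workgroups are rejected
        num_elements = n if accept_non_uniform else (1 + (n - 1) // local) * local
        groups = 1 + (num_elements - 1) // local
        max_blocks = max_global // local
        nglobal = min(max_blocks * over, groups) * local
        return min(nglobal, num_elements)

    # with n=1000, local=256: a uniform launch pads to 1024 work-items,
    # a non-uniform launch keeps exactly 1000
    assert compute_global(1000, 256, 1, 2**20, False) == 1024
    assert compute_global(1000, 256, 1, 2**20, True) == 1000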
From 07848b28dbc135f34a4c1bf57dc9af039d7b31c6 Mon Sep 17 00:00:00 2001
From: Paul Fultz II
Date: Thu, 19 Oct 2023 09:25:56 -0500
Subject: [PATCH 03/12] Make argument constructor explicit (#2346)

---
 src/include/migraphx/argument.hpp                    | 2 +-
 src/include/migraphx/op/allocate.hpp                 | 4 ++--
 src/include/migraphx/op/pooling.hpp                  | 4 ++--
 src/targets/gpu/include/migraphx/gpu/convolution.hpp | 6 +++---
 test/eliminate_allocation_test.cpp                   | 2 +-
 test/eliminate_concat_test.cpp                       | 4 ++--
 test/memory_coloring_test.cpp                        | 2 +-
 test/normalize_ops_test.cpp                          | 2 +-
 test/replace_allocate.cpp                            | 4 ++--
 9 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/include/migraphx/argument.hpp b/src/include/migraphx/argument.hpp
index 0326e460b0d..6f78d952d5c 100644
--- a/src/include/migraphx/argument.hpp
+++ b/src/include/migraphx/argument.hpp
@@ -46,7 +46,7 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument>
 {
     argument() = default;
 
-    argument(const shape& s);
+    explicit argument(const shape& s);
 
     template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
     argument(shape s, F d)
diff --git a/src/include/migraphx/op/allocate.hpp b/src/include/migraphx/op/allocate.hpp
index 33ea6bb2260..e2670c64c11 100644
--- a/src/include/migraphx/op/allocate.hpp
+++ b/src/include/migraphx/op/allocate.hpp
@@ -88,13 +88,13 @@ struct allocate
     {
         if(args.empty())
         {
-            return {output_shape};
+            return argument{output_shape};
         }
         else
        {
             std::vector<std::size_t> output_dims(output_shape.ndim());
             args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); });
-            return {shape{buf_type, output_dims}};
+            return argument{shape{buf_type, output_dims}};
         }
     }
 };
diff --git a/src/include/migraphx/op/pooling.hpp b/src/include/migraphx/op/pooling.hpp
index 276ad3295fe..7bfe456f3a0 100644
--- a/src/include/migraphx/op/pooling.hpp
+++ b/src/include/migraphx/op/pooling.hpp
@@ -411,7 +411,7 @@ struct pooling
             // for dynamic GlobalPooling, there's no padding
             kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
             output_shape = dyn_out.computed_shape;
-            result       = dyn_out.computed_shape;
+            result       = argument{dyn_out.computed_shape};
         }
         else if((padding_mode != op::padding_mode_t::default_))
         {
@@ -439,7 +439,7 @@ struct pooling
         {
             kernel_dims  = this->lengths;
             output_shape = dyn_out.computed_shape;
-            result       = dyn_out.computed_shape;
+            result       = argument{dyn_out.computed_shape};
         }
 
         // Perform the computation and populate result
diff --git a/src/targets/gpu/include/migraphx/gpu/convolution.hpp b/src/targets/gpu/include/migraphx/gpu/convolution.hpp
index d6680f17ec8..f88cee86855 100644
--- a/src/targets/gpu/include/migraphx/gpu/convolution.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/convolution.hpp
@@ -199,9 +199,9 @@ struct miopen_convolution
         // MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6
         preallocate = true;
 #endif
-        auto x = preallocate ? to_gpu(generate_argument(x_shape)) : inputs[0];
-        auto w = preallocate ? to_gpu(generate_argument(w_shape)) : inputs[1];
-        auto y = preallocate ? allocate_gpu(output_shape) : inputs[2];
+        auto x = preallocate ? to_gpu(generate_argument(x_shape)) : argument{inputs[0]};
+        auto w = preallocate ? to_gpu(generate_argument(w_shape)) : argument{inputs[1]};
+        auto y = preallocate ? allocate_gpu(output_shape) : argument{inputs[2]};
         auto workspace =
             preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape);
diff --git a/test/eliminate_allocation_test.cpp b/test/eliminate_allocation_test.cpp
index 2bfc7a54809..ba1179d1be7 100644
--- a/test/eliminate_allocation_test.cpp
+++ b/test/eliminate_allocation_test.cpp
@@ -55,7 +55,7 @@ struct allocate
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
diff --git a/test/eliminate_concat_test.cpp b/test/eliminate_concat_test.cpp
index 13984e98645..dc2834bfa91 100644
--- a/test/eliminate_concat_test.cpp
+++ b/test/eliminate_concat_test.cpp
@@ -60,7 +60,7 @@ struct concat
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
@@ -104,7 +104,7 @@ struct allocate
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
diff --git a/test/memory_coloring_test.cpp b/test/memory_coloring_test.cpp
index 7716c8b89a8..7cbb3efdec6 100644
--- a/test/memory_coloring_test.cpp
+++ b/test/memory_coloring_test.cpp
@@ -55,7 +55,7 @@ struct allocate
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
diff --git a/test/normalize_ops_test.cpp b/test/normalize_ops_test.cpp
index f9ec2f033c2..a48223dd5fe 100644
--- a/test/normalize_ops_test.cpp
+++ b/test/normalize_ops_test.cpp
@@ -57,7 +57,7 @@ struct normalize_test_op
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
diff --git a/test/replace_allocate.cpp b/test/replace_allocate.cpp
index 90b8a943973..68e3cfd5d37 100644
--- a/test/replace_allocate.cpp
+++ b/test/replace_allocate.cpp
@@ -54,7 +54,7 @@ struct allocate_no_out : migraphx::auto_register_op<allocate_no_out>
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
@@ -78,7 +78,7 @@ struct allocate_with_out : migraphx::auto_register_op<allocate_with_out>
                               const migraphx::shape& output_shape,
                               const std::vector<migraphx::argument>&) const
    {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
From 49eb032263d68b8ca56f9c347d0494b633e6a450 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Thu, 19 Oct 2023 10:26:31 -0400
Subject: [PATCH 04/12] Onnxruntime Weekly Sync 2023-10-13 (#2335)

* Update onnxruntime main 635d3faa3b3908d2806d009dc6872152cfcfcdda

* Update script for build_and_test_onnxrt.sh
---
 test/onnx/.onnxrt-commit       | 2 +-
 tools/build_and_test_onnxrt.sh | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit
index af082fddc91..c2996e0bb69 100644
--- a/test/onnx/.onnxrt-commit
+++ b/test/onnx/.onnxrt-commit
@@ -1 +1 @@
-6d7bc2a097a1a08541cd0d4628831c79ab8092d5
+635d3faa3b3908d2806d009dc6872152cfcfcdda
diff --git a/tools/build_and_test_onnxrt.sh b/tools/build_and_test_onnxrt.sh
index 75915b15f6e..52c2b93e0f7 100755
--- a/tools/build_and_test_onnxrt.sh
+++ b/tools/build_and_test_onnxrt.sh
@@ -40,4 +40,4 @@ echo 'InferenceSessionTests.CheckRunProfilerWithSessionOptions' >> ../../../tool
 echo 'InferenceSessionTests.CheckRunProfilerWithSessionOptions2' >> ../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
 echo 'InferenceSessionTests.Test3LayerNestedSubgraph' >> ../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
 echo 'InferenceSessionTests.Test2LayerNestedSubgraph' >> ../../../tools/ci_build/github/pai/migraphx-excluded-tests.txt
-../../../tools/ci_build/github/pai/migraphx_test_launcher.sh || (gdb ./onnxruntime_test_all core -batch -ex bt && exit 1)
+../../../tools/ci_build/github/pai/pai_test_launcher.sh || (gdb ./onnxruntime_test_all core -batch -ex bt && exit 1)

From 1f3e7c7cd35bbf0c67303dbf161b2ad9c2500ac8 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 19 Oct 2023 10:32:14 -0400
Subject: [PATCH 05/12] Bump urllib3 from 1.26.15 to 1.26.18 in /docs/.sphinx (#2342)

Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.15 to 1.26.18.
---
 docs/.sphinx/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/.sphinx/requirements.txt b/docs/.sphinx/requirements.txt
index 1cb36ab86ed..d8afc285852 100644
--- a/docs/.sphinx/requirements.txt
+++ b/docs/.sphinx/requirements.txt
@@ -130,7 +130,7 @@ sphinxcontrib-serializinghtml==1.1.5
     # via sphinx
 typing-extensions==4.5.0
     # via pydata-sphinx-theme
-urllib3==1.26.15
+urllib3==1.26.18
     # via requests
 wrapt==1.15.0
     # via deprecated

From 3d101611bd5408a58eacb2e0461e8e2ba437bdb3 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 19 Oct 2023 10:33:10 -0400
Subject: [PATCH 06/12] Bump onnxruntime from 1.10.0 to 1.16.1 in /tools/accuracy (#2321)

---
 tools/accuracy/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/accuracy/requirements.txt b/tools/accuracy/requirements.txt
index 0329701e0d7..07a39660d50 100644
--- a/tools/accuracy/requirements.txt
+++ b/tools/accuracy/requirements.txt
@@ -22,4 +22,4 @@
 # THE SOFTWARE.
 #####################################################################################
 numpy==1.21.6
-onnxruntime==1.10.0
+onnxruntime==1.16.1

From e64e17940d1d632f09884dba884a4b17d91809cc Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 19 Oct 2023 10:34:45 -0400
Subject: [PATCH 07/12] Bump rocm-docs-core from 0.24.2 to 0.26.0 in /docs/.sphinx (#2325)

---
 docs/.sphinx/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/.sphinx/requirements.txt b/docs/.sphinx/requirements.txt
index d8afc285852..71a3573b5bb 100644
--- a/docs/.sphinx/requirements.txt
+++ b/docs/.sphinx/requirements.txt
@@ -87,7 +87,7 @@ requests==2.28.2
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.24.2
+rocm-docs-core==0.26.0
     # via -r requirements.in
 smmap==5.0.0
     # via gitdb

From e7486577c5d604393889a3bdbb6c1ae8c5e1233d Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 19 Oct 2023 10:35:39 -0400
Subject: [PATCH 08/12] Bump gitpython from 3.1.32 to 3.1.37 in /docs/.sphinx (#2312)

---
 docs/.sphinx/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/.sphinx/requirements.txt b/docs/.sphinx/requirements.txt
index 71a3573b5bb..c6fdfee8b9b 100644
--- a/docs/.sphinx/requirements.txt
+++ b/docs/.sphinx/requirements.txt
@@ -35,7 +35,7 @@ fastjsonschema==2.16.3
     # via rocm-docs-core
 gitdb==4.0.10
     # via gitpython
-gitpython==3.1.32
+gitpython==3.1.37
     # via rocm-docs-core
 idna==3.4
     # via requests

From c8f1cd93c1eda9a67cc37fb8db145575d9ad3fba Mon Sep 17 00:00:00 2001
From: arvindcheru <90783369+arvindcheru@users.noreply.github.com>
Date: Thu, 19 Oct 2023 23:56:21 -0400
Subject: [PATCH 09/12] Update lib target soversion (#2267)

---
 CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 42f263a9af3..a76fb11c090 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -76,7 +76,7 @@ include(ROCMSetupVersion)
 option(BUILD_DEV "Build for development purpose only" OFF)
 
 rocm_setup_version(VERSION 2.8.0)
-set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH})
+set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
 
 option( BUILD_SHARED_LIBS "Build as a shared library" ON )

From 6072b2c4dfa0d74259b9cddbbd2aeebc1c290031 Mon Sep 17 00:00:00 2001
From: music-dino <111048524+music-dino@users.noreply.github.com>
Date: Fri, 20 Oct 2023 05:56:47 +0200
Subject: [PATCH 10/12] Add MeanVarianceNormalization ONNX parsing (#2255)

---
 .../parse_mean_variance_normalization.cpp     |  86 ++++++++++++++
 test/onnx/gen_onnx.py                         |  71 ++++++++++++
 test/onnx/mvn_axes_rank_too_big_test.onnx     | Bin 0 -> 173 bytes
 test/onnx/mvn_axes_rank_too_small_test.onnx   | Bin 0 -> 181 bytes
 test/onnx/mvn_default_axes_fp16_test.onnx     |  17 +++
 .../mvn_default_axes_rank_too_small_test.onnx |  13 +++
 test/onnx/mvn_default_axes_test.onnx          |  15 +++
 test/onnx/mvn_rank_2_fp16_test.onnx           |  14 +++
 test/onnx/mvn_rank_2_test.onnx                |  12 ++
 test/onnx/mvn_rank_3_fp16_test.onnx           | Bin 0 -> 163 bytes
 test/onnx/mvn_rank_3_test.onnx                | Bin 0 -> 152 bytes
 test/onnx/onnx_test.cpp                       |  60 ++++++++++
 test/onnx/verify_onnx.cpp                     | 109 ++++++++++++++++++
 test/py/onnx_backend_test.py                  |   2 -
 14 files changed, 397 insertions(+), 2 deletions(-)
 create mode 100644 src/onnx/parse_mean_variance_normalization.cpp
 create mode 100644 test/onnx/mvn_axes_rank_too_big_test.onnx
 create mode 100644 test/onnx/mvn_axes_rank_too_small_test.onnx
 create mode 100644 test/onnx/mvn_default_axes_fp16_test.onnx
 create mode 100644 test/onnx/mvn_default_axes_rank_too_small_test.onnx
 create mode 100644 test/onnx/mvn_default_axes_test.onnx
 create mode 100644 test/onnx/mvn_rank_2_fp16_test.onnx
 create mode 100644 test/onnx/mvn_rank_2_test.onnx
 create mode 100644 test/onnx/mvn_rank_3_fp16_test.onnx
 create mode 100644 test/onnx/mvn_rank_3_test.onnx

diff --git a/src/onnx/parse_mean_variance_normalization.cpp b/src/onnx/parse_mean_variance_normalization.cpp
new file mode 100644
index 00000000000..75287d300d4
--- /dev/null
+++ b/src/onnx/parse_mean_variance_normalization.cpp
@@ -0,0 +1,86 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/onnx/checks.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/instruction.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_mean_variance_normalization : op_parser<parse_mean_variance_normalization>
+{
+    std::vector<op_desc> operators() const { return {{"MeanVarianceNormalization"}}; }
+
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& /*parser*/,
+                          onnx_parser::node_info info,
+                          std::vector<instruction_ref> args) const
+    {
+        auto&& data    = args.front();
+        auto data_rank = data->get_shape().ndim();
+        std::vector<int64_t> axes{0, 2, 3};
+
+        if(contains(info.attributes, "axes"))
+        {
+            const auto& axes_attr = info.attributes["axes"].ints();
+            axes.assign(axes_attr.begin(), axes_attr.end());
+        }
+        else if(data_rank != 4)
+        {
+            MIGRAPHX_THROW(
+                "Input tensor needs to be rank 4 when axes is not specified. Instead it is rank " +
+                std::to_string(data_rank));
+        }
+
+        if(axes.size() != data_rank - 1)
+        {
+            MIGRAPHX_THROW("Length of axes array needs to be equal to input tensor rank - 1");
+        }
+
+        auto data_mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), data);
+        auto data_mean_squared = info.add_common_op("mul", data_mean, data_mean);
+
+        auto data_squared = info.add_common_op("mul", data, data);
+        auto data_squared_mean =
+            info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), data_squared);
+
+        auto mean_sub = info.add_common_op("sub", data_squared_mean, data_mean_squared);
+        auto std      = info.add_common_op("sqrt", mean_sub);
+
+        auto dividend = info.add_common_op("sub", data, data_mean);
+        auto epsilon =
+            info.add_literal({data->get_shape().type(),
+                              {data->get_shape().type() == shape::half_type ? 1e-7 : 1e-9}});
+        auto divisor = info.add_common_op("add", std, epsilon);
+
+        return info.add_common_op("div", dividend, divisor);
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index 9a1caeda801..c4fbbfc0d42 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -4681,6 +4681,77 @@ def mean_integral_test():
     return ([node], data, [mean])
 
 
+def mvn_default_axes_test_base(dims, type=TensorProto.FLOAT):
+    data = helper.make_tensor_value_info("data", type, dims)
+    out = helper.make_tensor_value_info("out", type, dims)
+    node = helper.make_node("MeanVarianceNormalization",
+                            inputs=["data"],
+                            outputs=["out"])
+
+    return ([node], [data], [out])
+
+
+@onnx_test()
+def mvn_default_axes_test():
+    return mvn_default_axes_test_base([2, 2, 2, 2])
+
+
+@onnx_test()
+def mvn_default_axes_fp16_test():
+    return mvn_default_axes_test_base([2, 2, 2, 2], TensorProto.FLOAT16)
+
+
+@onnx_test()
+def mvn_default_axes_rank_too_small_test():
+    return mvn_default_axes_test_base([2, 2, 2])
+
+
+@onnx_test()
+def mvn_default_axes_rank_too_big_test():
+    return mvn_default_axes_test_base([2, 2, 2, 2, 2])
+
+
+def mvn_n_rank_test_base(axes, dims, type=TensorProto.FLOAT):
+    data = helper.make_tensor_value_info("data", type, dims)
+    out = helper.make_tensor_value_info("out", type, dims)
+    node = helper.make_node("MeanVarianceNormalization",
+                            inputs=["data"],
+                            outputs=["out"],
+                            axes=axes)
+
+    return ([node], [data], [out])
+
+
+@onnx_test()
+def mvn_rank_2_test():
+    return mvn_n_rank_test_base([1], [2, 2])
+
+
+@onnx_test()
+def mvn_rank_2_fp16_test():
+    return mvn_n_rank_test_base([1], [2, 2], TensorProto.FLOAT16)
+
+
+@onnx_test()
+def mvn_rank_3_test():
+    return mvn_n_rank_test_base([0, 1], [2, 2, 2])
+
+
+@onnx_test()
+def mvn_rank_3_fp16_test():
+    return mvn_n_rank_test_base([0, 1], [2, 2, 2], TensorProto.FLOAT16)
+
+
+@onnx_test()
+def mvn_axes_rank_too_small_test():
+    return mvn_n_rank_test_base([0, 1, 2], [2, 2, 2])
+
+
+@onnx_test()
+def mvn_axes_rank_too_big_test():
+    return mvn_n_rank_test_base([0], [2, 2, 2])
+
+
 @onnx_test()
 def min_test():
     a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
diff --git a/test/onnx/mvn_axes_rank_too_big_test.onnx b/test/onnx/mvn_axes_rank_too_big_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..3d0da7e4f1eb052d01a26cdc72479410b3a9ed3d
GIT binary patch
literal 173
zcmdFG(#fvFc#tGUj4QNi0beV$LruQIhmc
zP0R~REXqvGOHTF6FUn2K$*f8&$;{8w;^txj8REdOfRP=anNdw1ANvzsXTjP-BI-1UMLl
Xc(|B2m|!?bk_%*+5EdyXCIMjpI;t;K

literal 0
HcmV?d00001

diff --git a/test/onnx/mvn_axes_rank_too_small_test.onnx b/test/onnx/mvn_axes_rank_too_small_test.onnx
new file mode 100644
GIT binary patch
literal 181

literal 0
HcmV?d00001

diff --git a/test/onnx/mvn_default_axes_fp16_test.onnx b/test/onnx/mvn_default_axes_fp16_test.onnx
new file mode 100644
index 00000000000..00519e61176
--- /dev/null
+++ b/test/onnx/mvn_default_axes_fp16_test.onnx
@@ -0,0 +1,17 @@
+ mvn_default_axes_fp16_test:ƒ
+&
+dataout"MeanVarianceNormalizationmvn_default_axes_fp16_testZ
+data
+
+
+
+
+
+b
+out
+
+
+
+
+
+B
\ No newline at end of file
diff --git a/test/onnx/mvn_default_axes_rank_too_small_test.onnx b/test/onnx/mvn_default_axes_rank_too_small_test.onnx
new file mode 100644
index 00000000000..75f169b793f
--- /dev/null
+++ b/test/onnx/mvn_default_axes_rank_too_small_test.onnx
@@ -0,0 +1,13 @@
+ $mvn_default_axes_rank_too_small_test:…
+&
+dataout"MeanVarianceNormalization$mvn_default_axes_rank_too_small_testZ
+data
+
+
+
+b
+out
+
+
+
+B
\ No newline at end of file
diff --git a/test/onnx/mvn_default_axes_test.onnx b/test/onnx/mvn_default_axes_test.onnx
new file mode 100644
index 00000000000..94b843e2dea
--- /dev/null
+++ b/test/onnx/mvn_default_axes_test.onnx
@@ -0,0 +1,15 @@
+ mvn_default_axes_test:~
+&
+dataout"MeanVarianceNormalizationmvn_default_axes_testZ
+data
+
+
+
+
+b
+out
+
+
+
+
+B
\ No newline at end of file
diff --git a/test/onnx/mvn_rank_2_fp16_test.onnx b/test/onnx/mvn_rank_2_fp16_test.onnx
new file mode 100644
index 00000000000..f0aa84a2de0
--- /dev/null
+++ b/test/onnx/mvn_rank_2_fp16_test.onnx
@@ -0,0 +1,14 @@
+ mvn_rank_2_fp16_test:z
+3
+dataout"MeanVarianceNormalization*
+axes@ mvn_rank_2_fp16_testZ
+data
+ 
+
+
+b
+out
+ 
+
+
+B
\ No newline at end of file
diff --git a/test/onnx/mvn_rank_2_test.onnx b/test/onnx/mvn_rank_2_test.onnx
new file mode 100644
index 00000000000..84ab941cf12
--- /dev/null
+++ b/test/onnx/mvn_rank_2_test.onnx
@@ -0,0 +1,12 @@
+ mvn_rank_2_test:u
+3
+dataout"MeanVarianceNormalization*
+axes@ mvn_rank_2_testZ
+data
+ 
+
+b
+out
+ 
+
+B
\ No newline at end of file
diff --git a/test/onnx/mvn_rank_3_fp16_test.onnx b/test/onnx/mvn_rank_3_fp16_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..33c3f7eb3628559a3c277f91561a2619023ad756
GIT binary patch
literal 163
zcmdb1uEiSQYVdOI9Vo6CXNfctvFD+4$^i5673rj4@
zOw3D8^~*2HP0Y!xN-W9D&(q@NVo9t>Ep}jVU|hh+j@`T{DX3vWTml?iLOfhd9855r
QB*_J`LI{hL6O({2038A*@c;k-

literal 0
HcmV?d00001

diff --git a/test/onnx/mvn_rank_3_test.onnx b/test/onnx/mvn_rank_3_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..4536ac8f1183d8e55447451ed7504ca43d270eba
GIT binary patch
literal 152
zcmd

literal 0
HcmV?d00001

diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp
--- a/test/onnx/onnx_test.cpp
+++ b/test/onnx/onnx_test.cpp
+static void mvn_n_rank_test(std::vector<int64_t> axes,
+                            std::vector<std::size_t> input_shape,
+                            const std::string& test_file)
+{
+    using migraphx::make_op;
+
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+
+    auto data = mm->add_parameter("data", {migraphx::shape::float_type, std::move(input_shape)});
+    auto data_mean         = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), data);
+    auto data_mean_squared = add_common_op(*mm, make_op("mul"), {data_mean, data_mean});
+
+    auto data_squared = add_common_op(*mm, make_op("mul"), {data, data});
+    auto data_squared_mean =
+        mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), data_squared);
+
+    auto mean_sub = add_common_op(*mm, make_op("sub"), {data_squared_mean, data_mean_squared});
+    auto std      = add_common_op(*mm, make_op("sqrt"), {mean_sub});
+
+    auto dividend = add_common_op(*mm, make_op("sub"), {data, data_mean});
+    auto epsilon  = mm->add_literal({migraphx::shape::float_type, {1e-9}});
+    auto divisor  = add_common_op(*mm, make_op("add"), {std, epsilon});
+    add_common_op(*mm, make_op("div"), {dividend, divisor});
+
+    auto prog = optimize_onnx(test_file);
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(mvn_default_axes_test)
+{
+    mvn_n_rank_test({0, 2, 3}, {2, 2, 2, 2}, "mvn_default_axes_test.onnx");
+}
+
+TEST_CASE(mvn_default_axes_rank_too_small_test)
+{
+    EXPECT(
+        test::throws([&] { migraphx::parse_onnx("mvn_default_axes_rank_too_small_test.onnx"); }));
+}
+
+TEST_CASE(mvn_default_axes_rank_too_big_test)
+{
+    EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_default_axes_rank_too_big_test.onnx"); }));
+}
+
+TEST_CASE(mvn_rank_2_test) { mvn_n_rank_test({1}, {2, 2}, "mvn_rank_2_test.onnx"); }
+
+TEST_CASE(mvn_rank_3_test) { mvn_n_rank_test({0, 1}, {2, 2, 2}, "mvn_rank_3_test.onnx"); }
+
+TEST_CASE(mvn_axes_rank_too_small_test)
+{
+    EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_axes_rank_too_small_test.onnx"); }));
+}
+
+TEST_CASE(mvn_axes_rank_too_big_test)
+{
+    EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_axes_rank_too_big_test.onnx"); }));
+}
+
 TEST_CASE(min_test)
 {
     migraphx::program p;
diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp
index 3079b440787..277e0c7d537 100644
--- a/test/onnx/verify_onnx.cpp
+++ b/test/onnx/verify_onnx.cpp
@@ -1211,6 +1211,115 @@ TEST_CASE(mean_integral_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
+template <class T>
+std::vector<T> mvn_test(std::vector<std::size_t> data_lens, const std::string& test_file)
+{
+    migraphx::program p = migraphx::parse_onnx(test_file);
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::shape data_shape(migraphx::shape::get_type<T>{}, std::move(data_lens));
+    std::vector<T> data(data_shape.elements());
+    std::iota(begin(data), end(data), 0);
+
+    migraphx::parameter_map pm;
+    pm["data"] = migraphx::argument(data_shape, data.data());
+
+    auto result = p.eval(pm).back();
+    std::vector<T> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+
+    return result_vector;
+}
+
+TEST_CASE(mvn_default_axes_test)
+{
+    auto result = mvn_test<float>({2, 2, 2, 2}, "mvn_default_axes_test.onnx");
+    std::vector<float> gold{-1.32424438,
+                            -1.08347268,
+                            -0.84270097,
+                            -0.60192927,
+                            -1.32424438,
+                            -1.08347268,
+                            -0.84270097,
+                            -0.60192927,
+                            0.60192927,
+                            0.84270097,
+                            1.08347268,
+                            1.32424438,
+                            0.60192927,
+                            0.84270097,
+                            1.08347268,
+                            1.32424438};
+    EXPECT(migraphx::verify::verify_rms_range(result, gold));
+}
+
+TEST_CASE(mvn_default_axes_fp16_test)
+{
+    using migraphx::half;
+    auto result = mvn_test<half>({2, 2, 2, 2}, "mvn_default_axes_fp16_test.onnx");
+    std::vector<half> gold{half{-1.324},
+                           half{-1.084},
+                           half{-0.843},
+                           half{-0.602},
+                           half{-1.324},
+                           half{-1.084},
+                           half{-0.843},
+                           half{-0.602},
+                           half{0.602},
+                           half{0.843},
+                           half{1.084},
+                           half{1.324},
+                           half{0.602},
+                           half{0.843},
+                           half{1.084},
+                           half{1.324}};
+    EXPECT(migraphx::verify::verify_rms_range(result, gold));
+}
+
+TEST_CASE(mvn_rank_2_test)
+{
+    auto result = mvn_test<float>({2, 2}, "mvn_rank_2_test.onnx");
+    std::vector<float> gold{-1, 1, -1, 1};
+    EXPECT(migraphx::verify::verify_rms_range(result, gold));
+}
+
+TEST_CASE(mvn_rank_2_fp16_test)
+{
+    using migraphx::half;
+    auto result = mvn_test<half>({2, 2}, "mvn_rank_2_fp16_test.onnx");
+    std::vector<half> gold{half{-1}, half{1}, half{-1}, half{1}};
+    EXPECT(migraphx::verify::verify_rms_range(result, gold));
+}
+
+TEST_CASE(mvn_rank_3_test)
+{
+    auto result = mvn_test<float>({2, 2, 2}, "mvn_rank_3_test.onnx");
+    std::vector<float> gold{-1.34164079,
+                            -1.34164079,
+                            -0.4472136,
+                            -0.4472136,
+                            0.4472136,
+                            0.4472136,
+                            1.34164079,
+                            1.34164079};
+    EXPECT(migraphx::verify::verify_rms_range(result, gold));
+}
+
+TEST_CASE(mvn_rank_3_fp16_test)
+{
+    using migraphx::half;
+    auto result = mvn_test<half>({2, 2, 2}, "mvn_rank_3_fp16_test.onnx");
+    std::vector<half> gold{half{-1.342},
+                           half{-1.342},
+                           half{-0.4473},
+                           half{-0.4473},
+                           half{0.4473},
+                           half{0.4473},
+                           half{1.342},
+                           half{1.342}};
+    EXPECT(migraphx::verify::verify_rms_range(result, gold));
+}
+
 TEST_CASE(mod_test)
 {
     migraphx::program p = migraphx::parse_onnx("mod_test.onnx");
diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py
index db2dec28e13..90a0dbaf902 100644
--- a/test/py/onnx_backend_test.py
+++ b/test/py/onnx_backend_test.py
@@ -154,7 +154,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
     backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu')
     backend_test.exclude(r'test_mod_mixed_sign_int32_cpu')
     backend_test.exclude(r'test_mod_mixed_sign_int8_cpu')
-    backend_test.exclude(r'test_mvn_cpu')
     backend_test.exclude(
         r'test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_cpu'
     )
@@ -803,7 +802,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
     backend_test.exclude(r'test_group_normalization_example_cpu')
     backend_test.exclude(r'test_group_normalization_example_expanded_cpu')
     backend_test.exclude(r'test_mish_cpu')
-    backend_test.exclude(r'test_mvn_expanded_ver18_cpu')
     backend_test.exclude(r'test_optional_get_element_optional_sequence_cpu')
    backend_test.exclude(r'test_optional_get_element_optional_tensor_cpu')
     backend_test.exclude(r'test_optional_get_element_tensor_cpu')
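The parser in this patch decomposes MeanVarianceNormalization into reduce_mean/mul/sub/sqrt/div, i.e. (x - E[x]) / (sqrt(E[x^2] - E[x]^2) + eps) over the reduction axes. A numpy sketch of the same math for the default axes (0, 2, 3), using the same iota input the verify tests feed in:

    import numpy as np

    x = np.arange(16, dtype=np.float32).reshape(2, 2, 2, 2)
    axes = (0, 2, 3)
    mean = x.mean(axis=axes, keepdims=True)
    var = (x * x).mean(axis=axes, keepdims=True) - mean * mean
    out = (x - mean) / (np.sqrt(var) + 1e-9)
    # first element: (0 - 5.5) / sqrt(47.5 - 30.25) = -1.32424438, matching the gold data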
From b8b4630b4666bd3420227c4b580e7e1242771c50 Mon Sep 17 00:00:00 2001
From: nives-vukovic <110852104+nives-vukovic@users.noreply.github.com>
Date: Fri, 20 Oct 2023 05:57:27 +0200
Subject: [PATCH 11/12] Fix trilu operator computation logic (#2212)

---
 src/onnx/parse_trilu.cpp               |  8 +--
 test/onnx/gen_onnx.py                  | 78 +++++++++++++++++++++++--
 test/onnx/onnx_test.cpp                |  5 --
 test/onnx/tril_batch_diff_k_test.onnx  | Bin 0 -> 150 bytes
 test/onnx/tril_neg_k_test.onnx         | Bin 0 -> 137 bytes
 test/onnx/tril_out_k_test.onnx         | Bin 0 -> 128 bytes
 test/onnx/tril_row_one_test.onnx       | Bin 0 -> 132 bytes
 test/onnx/tril_test.onnx               | Bin 0 -> 101 bytes
 test/onnx/trilu_batch_diff_k_test.onnx | 15 -----
 test/onnx/trilu_lower_test.onnx        | Bin 115 -> 0 bytes
 test/onnx/trilu_neg_k_test.onnx        | 13 -----
 test/onnx/trilu_out_k_test.onnx        | 13 -----
 test/onnx/trilu_row_one_test.onnx      | 13 -----
 test/onnx/trilu_test.onnx              | 13 -----
 test/onnx/triu_batch_diff_k_test.onnx  | 15 +++++
 test/onnx/triu_neg_k_test.onnx         | 13 +++++
 test/onnx/triu_out_k_test.onnx         | 13 +++++
 test/onnx/triu_row_one_test.onnx       | 13 +++++
 test/onnx/triu_test.onnx               | 11 ++++
 test/onnx/verify_onnx.cpp              | 76 ++++++++++++++++++++----
 test/py/onnx_backend_test.py           |  3 -
 21 files changed, 207 insertions(+), 95 deletions(-)
 create mode 100644 test/onnx/tril_batch_diff_k_test.onnx
 create mode 100644 test/onnx/tril_neg_k_test.onnx
 create mode 100644 test/onnx/tril_out_k_test.onnx
 create mode 100644 test/onnx/tril_row_one_test.onnx
 create mode 100644 test/onnx/tril_test.onnx
 delete mode 100644 test/onnx/trilu_batch_diff_k_test.onnx
 delete mode 100644 test/onnx/trilu_lower_test.onnx
 delete mode 100644 test/onnx/trilu_neg_k_test.onnx
 delete mode 100644 test/onnx/trilu_out_k_test.onnx
 delete mode 100644 test/onnx/trilu_row_one_test.onnx
 delete mode 100644 test/onnx/trilu_test.onnx
 create mode 100644 test/onnx/triu_batch_diff_k_test.onnx
 create mode 100644 test/onnx/triu_neg_k_test.onnx
 create mode 100644 test/onnx/triu_out_k_test.onnx
 create mode 100644 test/onnx/triu_row_one_test.onnx
 create mode 100644 test/onnx/triu_test.onnx

diff --git a/src/onnx/parse_trilu.cpp b/src/onnx/parse_trilu.cpp
index 28e0d4b7aaa..fec48a2e804 100644
--- a/src/onnx/parse_trilu.cpp
+++ b/src/onnx/parse_trilu.cpp
@@ -56,9 +56,6 @@ struct parse_trilu : op_parser<parse_trilu>
             k = arg_k.at<int64_t>();
         }
 
-        if(k < 0)
-            MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
-
         if(contains(info.attributes, "upper"))
         {
             upper = static_cast<bool>(info.attributes.at("upper").i());
@@ -69,9 +66,12 @@ struct parse_trilu : op_parser<parse_trilu>
         // when creating the mask, if upper == 1,
         // the inner triangle will have values set to 0
         std::vector<bool> mask_mat(num_rows * num_cols, upper);
+        // if upper == 0, kth diagonal must also be masked
+        if(not upper)
+            k++;
         for(size_t i = 0; i < num_rows; i++)
         {
-            for(size_t j = 0; j < std::min(k, static_cast<int64_t>(num_cols)); j++)
+            for(int j = 0; j < std::min(k, static_cast<int64_t>(num_cols)); j++)
             {
                 mask_mat[i * num_cols + j] = not
 upper;
             }
         }
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index c4fbbfc0d42..64adf6e3910 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -8573,7 +8573,7 @@ def transpose_gather_test():
 
 
 @onnx_test()
-def trilu_test():
+def triu_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
 
@@ -8586,7 +8586,7 @@
 
 
 @onnx_test()
-def trilu_batch_diff_k_test():
+def triu_batch_diff_k_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
     k = np.array([2])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
@@ -8604,7 +8604,24 @@
 
 
 @onnx_test()
-def trilu_lower_test():
+def tril_batch_diff_k_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
+    k = np.array([2])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
+    k_tensor = helper.make_tensor(name='k',
+                                  data_type=TensorProto.INT64,
+                                  dims=k.shape,
+                                  vals=k.astype(np.int64))
+
+    node = onnx.helper.make_node('Trilu',
+                                 inputs=['x', 'k'],
+                                 outputs=['y'],
+                                 upper=0)
+    return ([node], [x], [y], [k_tensor])
+
+
+@onnx_test()
+def tril_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
 
@@ -8613,7 +8630,7 @@
 
 
 @onnx_test()
-def trilu_neg_k_test():
+def triu_neg_k_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     k = np.array([-1])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
@@ -8627,7 +8644,23 @@
 
 
 @onnx_test()
-def trilu_out_k_test():
+def tril_neg_k_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
+    k = np.array([-1])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
+    k_tensor = helper.make_tensor(name='k',
+                                  data_type=TensorProto.INT64,
+                                  dims=k.shape,
+                                  vals=k.astype(np.int64))
+    node = onnx.helper.make_node('Trilu',
+                                 inputs=['x', 'k'],
+                                 outputs=['y'],
+                                 upper=0)
+    return ([node], [x], [y], [k_tensor])
+
+
+@onnx_test()
+def triu_out_k_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     k = np.array([5])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
@@ -8641,7 +8674,23 @@
 
 
 @onnx_test()
-def trilu_row_one_test():
+def tril_out_k_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
+    k = np.array([5])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
+    k_tensor = helper.make_tensor(name='k',
+                                  data_type=TensorProto.INT64,
+                                  dims=k.shape,
+                                  vals=k.astype(np.int64))
+    node = onnx.helper.make_node('Trilu',
+                                 inputs=['x', 'k'],
+                                 outputs=['y'],
+                                 upper=0)
+    return ([node], [x], [y], [k_tensor])
+
+
+@onnx_test()
+def triu_row_one_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
     k = np.array([1])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
@@ -8658,6 +8707,23 @@
     return ([node], [x], [y], [k_tensor])
 
 
+@onnx_test()
+def tril_row_one_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
+    k = np.array([1])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
+    k_tensor = helper.make_tensor(name='k',
+                                  data_type=TensorProto.INT64,
+                                  dims=k.shape,
+                                  vals=k.astype(np.int64))
+
+    node = onnx.helper.make_node('Trilu',
+                                 inputs=['x', 'k'],
+                                 outputs=['y'],
+                                 upper=0)
+    return ([node], [x], [y],
 [k_tensor])
+
+
 @onnx_test()
 def undefined_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp
index f42ce832f74..cea43ec487d 100644
--- a/test/onnx/onnx_test.cpp
+++ b/test/onnx/onnx_test.cpp
@@ -8091,11 +8091,6 @@ TEST_CASE(transpose_gather_test)
     EXPECT(p.sort() == prog.sort());
 }
 
-TEST_CASE(trilu_neg_k_test)
-{
-    EXPECT(test::throws([&] { migraphx::parse_onnx("trilu_neg_k_test.onnx"); }));
-}
-
 TEST_CASE(undefined_test)
 {
     migraphx::program p;
diff --git a/test/onnx/tril_batch_diff_k_test.onnx b/test/onnx/tril_batch_diff_k_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..ea81a0d1a2bd871f18139cb88f2f13196af3e18c
GIT binary patch
literal 150
zcmdB(ON-BrFG(#fu`1(|<6^AfV$2p|tW;tR0V*oh;^AT~
EhtDWl3-ZC$b?I$78eJj0J{|pN#aAFeR
F2LPYW9{vCT

literal 0
HcmV?d00001

diff --git a/test/onnx/tril_neg_k_test.onnx b/test/onnx/tril_neg_k_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..a446539f351d463a17f4146624518b47a626a3f7
GIT binary patch
literal 137
zcmd%FKz+OHGf@jxR|qF0m@$lH+2m;9|@cVyskR4FSrQYVmNfmKGGG7D+HH
zU}VD3q9x41D8O#T^&bWpofxyDgn_0A@p18RFbZ*SF>x>hF-sCquo7L+iAjJT0Hd-Z
ADgXcg

literal 0
HcmV?d00001

diff --git a/test/onnx/tril_out_k_test.onnx b/test/onnx/tril_out_k_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..cc432a726d1b202d571d104680dca7045884628c
GIT binary patch
literal 128
zcmd%FKz+FD;4BjxR|qF0o4IlH+2m;9|@cVyskR4FSrQYVmNfmKGGG7D+HH
xU}VD3qQ%9*D8O#T$m+zH9VHCZEX2pf!@(%T!NtVE48$x+K*36MK_?~wegFpH8CU=S

literal 0
HcmV?d00001

diff --git a/test/onnx/tril_row_one_test.onnx b/test/onnx/tril_row_one_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..04fceb90b58cd25165cedad492eee9a34e5608cc
GIT binary patch
literal 132
zcmdKaIuya6r>hO
zFf3qX!lFrwi-S>s-HMUXi7`7$7^q!{kBf(cQHX&q5mn

literal 0
HcmV?d00001

diff --git a/test/onnx/tril_test.onnx b/test/onnx/tril_test.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..74075cf4e170fe03b13bb9bbb18a11918c8e9a21
GIT binary patch
literal 101
zcmd
b$Hl|JD8#|V#K8>2EJ;AYN_0UdCINl`o|zJQ

literal 0
HcmV?d00001

diff --git a/test/onnx/trilu_batch_diff_k_test.onnx b/test/onnx/trilu_batch_diff_k_test.onnx
deleted file mode 100644
index 7c258a20196..00000000000
--- a/test/onnx/trilu_batch_diff_k_test.onnx
+++ /dev/null
@@ -1,15 +0,0 @@
-trilu_batch_diff_k_test:i
-
-x
-ky"Trilutrilu_batch_diff_k_test*
-:BkZ
-x
-
-
-
-b
-y
-
-
-
-B
\ No newline at end of file
diff --git a/test/onnx/trilu_lower_test.onnx b/test/onnx/trilu_lower_test.onnx
deleted file mode 100644
index 09ef4de5d93d51310009c098edddca5736ab161b..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 115
zcmdob_

diff --git a/test/onnx/trilu_neg_k_test.onnx b/test/onnx/trilu_neg_k_test.onnx
deleted file mode 100644
index 92418dbbc32..00000000000
--- a/test/onnx/trilu_neg_k_test.onnx
+++ /dev/null
@@ -1,13 +0,0 @@
-trilu_neg_k_test:c
-
-x
-ky"Trilutrilu_neg_k_test*:
-ÿÿÿÿÿÿÿÿÿBkZ
-x
- 
-
-b
-y
- 
-
-B
\ No newline at end of file
diff --git a/test/onnx/trilu_out_k_test.onnx b/test/onnx/trilu_out_k_test.onnx
deleted file mode 100644
index 0b1cd3235c2..00000000000
--- a/test/onnx/trilu_out_k_test.onnx
+++ /dev/null
@@ -1,13 +0,0 @@
-trilu_out_k_test:Z
-
-x
-ky"Trilutrilu_out_k_test*
-:BkZ
-x
- 
-
-b
-y
- 
-
-B
\ No newline at end of file
diff --git a/test/onnx/trilu_row_one_test.onnx b/test/onnx/trilu_row_one_test.onnx
deleted file mode 100644
index a88b5b750f2..00000000000
--- a/test/onnx/trilu_row_one_test.onnx
+++ /dev/null
@@ -1,13 +0,0 @@
-trilu_row_one_test:\
-
-x
-ky"Trilutrilu_row_one_test*
-:BkZ
-x
- 
-
-b
-y
- 
-
-B
\ No newline at end of file
diff --git a/test/onnx/trilu_test.onnx b/test/onnx/trilu_test.onnx
deleted file mode 100644
index 32965ca2761..00000000000
--- a/test/onnx/trilu_test.onnx
+++ /dev/null
@@ -1,13 +0,0 @@
-
-trilu_test:E
-
-xy"Trilu
-trilu_testZ
-x
- 
-
-b
-y
- 
-
-B
\ No newline at end of file
diff --git a/test/onnx/triu_batch_diff_k_test.onnx b/test/onnx/triu_batch_diff_k_test.onnx
new file mode 100644
index 00000000000..9dc389bbc2d
--- /dev/null
+++ b/test/onnx/triu_batch_diff_k_test.onnx
@@ -0,0 +1,15 @@
+triu_batch_diff_k_test:h
+
+x
+ky"Trilutriu_batch_diff_k_test*
+:BkZ
+x
+
+
+
+b
+y
+
+
+
+B
\ No newline at end of file
diff --git a/test/onnx/triu_neg_k_test.onnx b/test/onnx/triu_neg_k_test.onnx
new file mode 100644
index 00000000000..9dcbcce8f74
--- /dev/null
+++ b/test/onnx/triu_neg_k_test.onnx
@@ -0,0 +1,13 @@
+triu_neg_k_test:b
+
+x
+ky"Trilutriu_neg_k_test*:
+ÿÿÿÿÿÿÿÿÿBkZ
+x
+ 
+
+b
+y
+ 
+
+B
\ No newline at end of file
diff --git a/test/onnx/triu_out_k_test.onnx b/test/onnx/triu_out_k_test.onnx
new file mode 100644
index 00000000000..b9b5d33a4dc
--- /dev/null
+++ b/test/onnx/triu_out_k_test.onnx
@@ -0,0 +1,13 @@
+triu_out_k_test:Y
+
+x
+ky"Trilutriu_out_k_test*
+:BkZ
+x
+ 
+
+b
+y
+ 
+
+B
\ No newline at end of file
diff --git a/test/onnx/triu_row_one_test.onnx b/test/onnx/triu_row_one_test.onnx
new file mode 100644
index 00000000000..57aab2ce780
--- /dev/null
+++ b/test/onnx/triu_row_one_test.onnx
@@ -0,0 +1,13 @@
+triu_row_one_test:[
+
+x
+ky"Trilutriu_row_one_test*
+:BkZ
+x
+ 
+
+b
+y
+ 
+
+B
\ No newline at end of file
diff --git a/test/onnx/triu_test.onnx b/test/onnx/triu_test.onnx
new file mode 100644
index 00000000000..707cb609911
--- /dev/null
+++ b/test/onnx/triu_test.onnx
@@ -0,0 +1,11 @@
+ triu_test:D
+
+xy"Trilu triu_testZ
+x
+ 
+
+b
+y
+ 
+
+B
\ No newline at end of file
diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp
index 277e0c7d537..fd642ac2e53 100644
--- a/test/onnx/verify_onnx.cpp
+++ b/test/onnx/verify_onnx.cpp
@@ -2233,9 +2233,10 @@ std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::prog
     result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
     return result_vector;
 }
-TEST_CASE(trilu_test)
+
+TEST_CASE(triu_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_test.onnx");
 
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
 
@@ -2244,9 +2245,9 @@
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_batch_diff_k_test)
+TEST_CASE(triu_batch_diff_k_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_batch_diff_k_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_batch_diff_k_test.onnx");
 
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p);
 
@@ -2256,9 +2256,42 @@
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_lower_test)
+TEST_CASE(tril_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_lower_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("tril_test.onnx");
+
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
+
+    std::vector<float> gold = {1, 0, 0, 0, 5, 6, 0, 0, 9, 10, 11, 0};
+
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(tril_batch_diff_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_batch_diff_k_test.onnx");
+
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p);
+
+    std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(triu_neg_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("triu_neg_k_test.onnx");
+
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
+
+    std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 0, 10, 11, 12};
+
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(tril_neg_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_neg_k_test.onnx");
 
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
 
@@ -2266,9 +2300,9 @@
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_out_k_test)
+TEST_CASE(triu_out_k_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_out_k_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_out_k_test.onnx");
 
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
 
@@ -2277,9 +2311,20 @@
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_row_one_test)
+TEST_CASE(tril_out_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_out_k_test.onnx");
+
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
+
+    std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(triu_row_one_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_row_one_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_row_one_test.onnx");
 
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p);
 
@@ -2288,4 +2333,15 @@
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
+TEST_CASE(tril_row_one_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_row_one_test.onnx");
+
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p);
+
+    std::vector<float> gold = {1, 2, 0, 0};
+
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py
index 90a0dbaf902..e3505269d42 100644
--- a/test/py/onnx_backend_test.py
+++ b/test/py/onnx_backend_test.py
@@ -590,9 +590,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
     backend_test.exclude(r'test_gru_batchwise_cpu')
     backend_test.exclude(r'test_lstm_batchwise_cpu')
     backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
-    backend_test.exclude(r'test_tril_cpu')
-    backend_test.exclude(r'test_tril_one_row_neg_cpu')
-    backend_test.exclude(r'test_tril_square_cpu')
     # from OnnxBackendPyTorchConvertedModelTest
     backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
     backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu')
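The parser fix above shifts the diagonal by one in the lower (tril) case so that the kth diagonal itself is kept rather than masked. numpy's tril/triu give the reference behaviour the new gold values encode, as this short check shows:

    import numpy as np

    x = np.arange(1, 13).reshape(3, 4)
    # lower, k=0: keep the main diagonal and below (gold from tril_test)
    assert (np.tril(x, 0) == [[1, 0, 0, 0], [5, 6, 0, 0], [9, 10, 11, 0]]).all()
    # upper, k=-1: keep one diagonal below the main one too (gold from triu_neg_k_test)
    assert (np.triu(x, -1) == [[1, 2, 3, 4], [5, 6, 7, 8], [0, 10, 11, 12]]).all()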
From f47e0b5ba43686380c0881bb79c5fdc8c9c1d9a7 Mon Sep 17 00:00:00 2001
From: turneram <71655887+turneram@users.noreply.github.com>
Date: Thu, 19 Oct 2023 23:00:02 -0500
Subject: [PATCH 12/12] CK GEMM Int8 Bug Fixes (#2229)

Adds workarounds to avoid passing capture ops and scalar literals from
quantization as arguments to ck_gemm.
---
 src/rewrite_quantization.cpp       | 21 ++++++--
 src/targets/gpu/fuse_ck.cpp        | 21 ++++++--
 test/onnx/onnx_test.cpp            | 80 ++++++++++++++----------
 test/rewrite_quantization_test.cpp |  9 +++-
 4 files changed, 86 insertions(+), 45 deletions(-)

diff --git a/src/rewrite_quantization.cpp b/src/rewrite_quantization.cpp
index d2ce20868d9..2e98d8a0054 100644
--- a/src/rewrite_quantization.cpp
+++ b/src/rewrite_quantization.cpp
@@ -33,6 +33,8 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
+
 void apply_quantizelinear(module& m, instruction_ref ins)
 {
     assert(ins->name() == "quantizelinear");
@@ -62,9 +64,22 @@ void apply_quantizelinear(module& m, instruction_ref ins)
         max_quant = qt.max();
         min_quant = qt.min();
     });
-    auto s       = add_zero_point->get_shape();
-    auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
-    auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
+    auto s = add_zero_point->get_shape();
+    instruction_ref min_arg;
+    instruction_ref max_arg;
+
+    if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
+    {
+        std::vector<int64_t> min_data(s.elements(), min_quant);
+        std::vector<int64_t> max_data(s.elements(), max_quant);
+        min_arg = m.add_literal(literal(s, min_data));
+        max_arg = m.add_literal(literal(s, max_data));
+    }
+    else
+    {
+        min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
+        max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
+    }
     auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
     m.replace_instruction(
         ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index 43c7087bce7..7043985573b 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include <migraphx/gpu/device_name.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -92,6 +93,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
     auto m = a.lens()[a.lens().size() - 2];
     auto n = b.lens().back();
     auto k = a.lens().back();
+    auto batch_size = std::accumulate(
+        a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
     // Integer gemms must be divisible by 4 in ck
     if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
     {
@@ -102,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
         if(k % 4 != 0)
             return false;
     }
-    // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
-    // to avoid poor-performing GEMM kernels from CK
-    // To-do: Investigate a more precise strategy
+    auto device_name = trim(split_string(get_device_name(), ':').front());
+    if(device_name == "gfx940")
+    {
+        if(ins->get_shape().type() == shape::half_type)
+        {
+            if(batch_size >= 64)
+                return m < 2048 or k <= 64 or n <= 384 or n >= 2048;
+            return true;
+        }
+        return true;
+    }
     return k <= 2048;
 }
@@ -140,6 +151,10 @@ struct find_ck_gemm_pointwise
                return not input->inputs().empty() and input->inputs().front()->name() == "capture";
            }))
             return;
+        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
+               return not input->inputs().empty() and input->inputs().front()->name() == "capture";
+           }))
+            return;
         assert(gemm_it != inputs.end());
         if(gemm_idx != 0)
         {
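The rewrite_quantization workaround replaces the scalar clip bounds with literals of the full tensor shape so CK does not receive broadcast scalars as gemm arguments; numerically, the quantizelinear saturation is unchanged either way. A numpy sketch of the decomposition being clipped (int8 bounds, illustrative values):

    import numpy as np

    x = np.array([-300.0, -1.0, 0.5, 300.0])
    scale, zero_point = 2.0, 0
    q = np.round(x / scale) + zero_point
    # identical result whether the bounds are scalars or full-size arrays
    q = np.clip(q, -128, 127).astype(np.int8)
    assert (q == np.array([-128, 0, 0, 127], dtype=np.int8)).all()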
diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp
index cea43ec487d..027a6c9a639 100644
--- a/test/onnx/onnx_test.cpp
+++ b/test/onnx/onnx_test.cpp
@@ -42,11 +42,14 @@
 #include
 #include
 #include
+#include <migraphx/env.hpp>
 #include
 #include
 
 #include "test.hpp"
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
+
 migraphx::program optimize_onnx(const std::string& name, bool run_passes = false)
 {
     migraphx::onnx_options options;
@@ -5540,6 +5543,31 @@ TEST_CASE(qlinearmatmul_2D_test)
     EXPECT(p.sort() == prog.sort());
 }
 
+migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m,
+                                                     const migraphx::instruction_ref ins,
+                                                     const migraphx::instruction_ref round,
+                                                     const migraphx::shape s,
+                                                     const int64_t min_quant,
+                                                     const int64_t max_quant)
+{
+    migraphx::instruction_ref min_arg;
+    migraphx::instruction_ref max_arg;
+    if(migraphx::enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
+    {
+        std::vector<int64_t> min_data(s.elements(), min_quant);
+        std::vector<int64_t> max_data(s.elements(), max_quant);
+        min_arg = m.add_literal(migraphx::literal(s, min_data));
+        max_arg = m.add_literal(migraphx::literal(s, max_data));
+    }
+    else
+    {
+        min_arg = m.add_literal(migraphx::literal{migraphx::shape{s.type()}, {min_quant}});
+        max_arg = m.add_literal(migraphx::literal{migraphx::shape{s.type()}, {max_quant}});
+    }
+
+    return migraphx::insert_common_op(m, ins, migraphx::make_op("clip"), {round, min_arg, max_arg});
+}
+
 TEST_CASE(quantizelinear_test)
 {
     migraphx::program p;
@@ -5548,16 +5576,10 @@ TEST_CASE(quantizelinear_test)
     auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
     auto l1_mbcast =
         mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
-    auto div   = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
-    auto round = mm->add_instruction(migraphx::make_op("round"), div);
-    auto s = round->get_shape();
-    auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
-    auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
-    auto min_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
-    auto max_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
-    auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
+    auto div   = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
+    auto round = mm->add_instruction(migraphx::make_op("round"), div);
+    auto s     = round->get_shape();
+    auto clip  = insert_quantizelinear_clip(*mm, div, round, s, 0, 255);
     mm->add_instruction(
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
@@ -5579,16 +5601,10 @@ TEST_CASE(quantizelinear_int32_test)
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
         l0);
-    auto div   = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
-    auto round = mm->add_instruction(migraphx::make_op("round"),
 div);
-    auto s = round->get_shape();
-    auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}});
-    auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
-    auto min_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
-    auto max_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
-    auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
+    auto div   = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
+    auto round = mm->add_instruction(migraphx::make_op("round"), div);
+    auto s     = round->get_shape();
+    auto clip  = insert_quantizelinear_clip(*mm, div, round, s, 0, 255);
     mm->add_instruction(
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
@@ -5615,15 +5631,9 @@ TEST_CASE(quantizelinear_zero_point_test)
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
         l2_mbcast);
-    auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
-    auto s = round->get_shape();
-    auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
-    auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
-    auto min_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
-    auto max_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
-    auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
+    auto add  = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
+    auto s    = round->get_shape();
+    auto clip = insert_quantizelinear_clip(*mm, div, add, s, -128, 127);
     mm->add_instruction(
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
@@ -5654,15 +5664,9 @@ migraphx::program make_quantizelinear_axis_prog()
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
         l2_bcast);
-    auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
-    auto s = round->get_shape();
-    auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}});
-    auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
-    auto min_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
-    auto max_mbcast =
-        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
-    auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
+    auto add  = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
+    auto s    = round->get_shape();
+    auto clip = insert_quantizelinear_clip(*mm, div, add, s, -128, 127);
     mm->add_instruction(
         migraphx::make_op("convert",
                           {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
diff --git a/test/rewrite_quantization_test.cpp b/test/rewrite_quantization_test.cpp
index 282a89bf354..6e32dfdd637 100644
--- a/test/rewrite_quantization_test.cpp
+++ b/test/rewrite_quantization_test.cpp
@@ -31,10 +31,13 @@
 #include
 #include
 #include
+#include <migraphx/env.hpp>
 #include
 #include
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
+
 bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
 bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
 bool is_clip_scalar(migraphx::instruction& ins)
@@ -82,7 +85,11 @@ TEST_CASE(quantizelinear)
     EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
     EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
     // ensure clip literals created in quantized program are scalar
-    EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar));
+    // unless CK workarounds are enabled
+    if(migraphx::enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
+        EXPECT(none_of(*p2.get_main_module(), &is_clip_scalar));
+    else
+        EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar));
 }
 
 TEST_CASE(dequantizelinear)