From d56693f8408801c3512e8eaad388b703231f725d Mon Sep 17 00:00:00 2001 From: Gyula Zakor Date: Wed, 15 May 2024 10:23:12 +0000 Subject: [PATCH] Fix mod compile issue on GPU (#2268) --- src/include/migraphx/op/mod.hpp | 3 +-- .../kernels/include/migraphx/kernels/math.hpp | 18 +++++++++++++++++- test/py/onnx_backend_test.py | 2 -- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/op/mod.hpp b/src/include/migraphx/op/mod.hpp index f1a48e3c58f..38f947a3587 100644 --- a/src/include/migraphx/op/mod.hpp +++ b/src/include/migraphx/op/mod.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,7 +38,6 @@ struct mod : binary { auto a = base_attributes(); a["commutative"] = false; - a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})"; return a; } auto apply() const diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index 5a6cca7bc24..da00ff9c781 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -256,6 +256,21 @@ constexpr auto min(const T& a, const U& b) return min>(a, b); } +template ())> +constexpr T mod(const T& a, const T& b) +{ + if constexpr(is_integral{}) + // onnx mod operator requires numpy style modulus + return ((a % b) + b) % b; + return static_cast(fmod(remainder(a, b) + b, b)); +} + +template {} and not is_any_vec())> +constexpr auto mod(const T& a, const U& b) +{ + return mod>(a, b); +} + MIGRAPHX_DEVICE_MATH_VEC(abs) MIGRAPHX_DEVICE_MATH_VEC(acos) MIGRAPHX_DEVICE_MATH_VEC(acosh) @@ -275,6 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(isnan) MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(min) +MIGRAPHX_DEVICE_MATH_VEC(mod) MIGRAPHX_DEVICE_MATH_VEC(nearbyint) MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(remainder) diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 353bcea3944..2d847c97300 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -119,8 +119,6 @@ def disabled_tests_onnx_1_7_0(backend_test): backend_test.exclude(r'test_einsum_transpose_cpu') backend_test.exclude(r'test_maxunpool_export_with_output_shape_cpu') backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu') - backend_test.exclude(r'test_mod_mixed_sign_int32_cpu') - backend_test.exclude(r'test_mod_mixed_sign_int8_cpu') backend_test.exclude(r'test_qlinearmatmul_2D_cpu') backend_test.exclude(r'test_qlinearmatmul_3D_cpu') backend_test.exclude(r'test_range_float_type_positive_delta_expanded_cpu')