Skip to content

Commit

Permalink
Fix mod compile issue on GPU (#2268)
Browse files Browse the repository at this point in the history
  • Loading branch information
gyulaz-htec committed May 21, 2024
1 parent 1f07af9 commit d56693f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
3 changes: 1 addition & 2 deletions src/include/migraphx/op/mod.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -38,7 +38,6 @@ struct mod : binary<mod>
{
auto a = base_attributes();
a["commutative"] = false;
a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})";
return a;
}
auto apply() const
Expand Down
18 changes: 17 additions & 1 deletion src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -256,6 +256,21 @@ constexpr auto min(const T& a, const U& b)
return min<common_type_t<T, U>>(a, b);
}

template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr T mod(const T& a, const T& b)
{
if constexpr(is_integral<T>{})
// onnx mod operator requires numpy style modulus
return ((a % b) + b) % b;
return static_cast<T>(fmod(remainder(a, b) + b, b));
}

template <class T, class U, MIGRAPHX_REQUIRES(not is_same<T, U>{} and not is_any_vec<T, U>())>
constexpr auto mod(const T& a, const U& b)
{
return mod<common_type_t<T, U>>(a, b);
}

MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
Expand All @@ -275,6 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(mod)
MIGRAPHX_DEVICE_MATH_VEC(nearbyint)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder)
Expand Down
2 changes: 0 additions & 2 deletions test/py/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(r'test_einsum_transpose_cpu')
backend_test.exclude(r'test_maxunpool_export_with_output_shape_cpu')
backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu')
backend_test.exclude(r'test_mod_mixed_sign_int32_cpu')
backend_test.exclude(r'test_mod_mixed_sign_int8_cpu')
backend_test.exclude(r'test_qlinearmatmul_2D_cpu')
backend_test.exclude(r'test_qlinearmatmul_3D_cpu')
backend_test.exclude(r'test_range_float_type_positive_delta_expanded_cpu')
Expand Down

0 comments on commit d56693f

Please sign in to comment.