diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index cadcc05f577..368ee9bc5fa 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -172,6 +172,7 @@ constexpr auto where(bool cond, const T& a, const U& b) MIGRAPHX_DEVICE_MATH_FOR(float, abs, ::abs) MIGRAPHX_DEVICE_MATH_FOR(double, abs, ::abs) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) +MIGRAPHX_DEVICE_MATH_FOR(migraphx::bf16, abs, ::fabsf) MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::fmaxf) MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::fminf) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp index 1b0d1343ea2..24b7d4a5b22 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -262,6 +262,8 @@ constexpr T numeric_max() return __FLT_MAX__; else if constexpr(is_same{}) return __FLT16_MAX__; + else if constexpr(is_same{}) + return 338953138925153547590470800371487866880.000000; else return 0; } diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index 65abfc03020..d9412cb24f5 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -383,13 +383,10 @@ TEST_CASE(compile_math) auto vec_sizes = {2, 4, 6}; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, - migraphx::shape::tuple_type, - migraphx::shape::bf16_type}, - t)) + if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) continue; auto name = migraphx::shape::cpp_type(t); - if(t == migraphx::shape::half_type) + if(contains({migraphx::shape::half_type, migraphx::shape::bf16_type}, t)) name.insert(0, "migraphx::"); data_types.push_back(name); // fp8 doesn't have vectorization support yet, therefore skip it for now. @@ -444,15 +441,16 @@ TEST_CASE(assert_type_min_max) migraphx::gpu::context ctx; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, - migraphx::shape::tuple_type, - migraphx::shape::bf16_type}, - t)) + if(contains( + { + migraphx::shape::bool_type, + migraphx::shape::tuple_type, + }, + t)) continue; auto name = migraphx::shape::cpp_type(t); - if(t == migraphx::shape::half_type) + if(contains({migraphx::shape::half_type, migraphx::shape::bf16_type}, t)) name.insert(0, "migraphx::"); - migraphx::shape::visit(t, [&](auto as) { std::string min = ""; std::string max = "";