
FP8 GPU implementation #2455

Merged: 105 commits, Dec 1, 2023
Commits
df7f8a3
changes for the FP8 ref implementation
umangyadav Nov 9, 2023
9bc1828
cppcheck fixes
umangyadav Nov 9, 2023
155a2b1
move FNUZ as template parameter
umangyadav Nov 10, 2023
d9f11e3
Fix numeric limits
umangyadav Nov 10, 2023
4e9d51f
Working FNUZ and FN
umangyadav Nov 10, 2023
7639c28
use float equal
umangyadav Nov 10, 2023
a6372c5
add test for fp8e5m2
umangyadav Nov 10, 2023
439ea40
add test for fp8e5m2fnuz
umangyadav Nov 10, 2023
183db78
refactor add some comments
umangyadav Nov 10, 2023
ab653af
Review updates
umangyadav Nov 13, 2023
8319e01
Fix tidy
umangyadav Nov 14, 2023
9ee0418
Fix test failure
umangyadav Nov 14, 2023
355e4f6
fix isfinite
umangyadav Nov 14, 2023
ba471f4
Merge remote-tracking branch 'origin/develop' into ref_fp8
umangyadav Nov 14, 2023
6aec703
fix test for neg inf
umangyadav Nov 14, 2023
12aac37
fix warning
umangyadav Nov 14, 2023
6009232
add tests
umangyadav Nov 14, 2023
03f7139
Fix tests
umangyadav Nov 14, 2023
1e220c0
add stringstream tests
umangyadav Nov 14, 2023
a83e9dc
Remove clang diagnostics
umangyadav Nov 15, 2023
dfb35a6
Merge remote-tracking branch 'origin/develop' into ref_fp8
umangyadav Nov 15, 2023
26956f1
Remove NOLINTS
umangyadav Nov 15, 2023
269ce6d
Bugfixes and additional tests
umangyadav Nov 16, 2023
6414ee3
Fix undoing
umangyadav Nov 16, 2023
cd26ada
Handle underflow case separately to avoid sanitization errors
umangyadav Nov 16, 2023
1cf87ef
use std::min to avoid sanitization errors
umangyadav Nov 16, 2023
e7e5ba2
Merge branch 'develop' into ref_fp8
umangyadav Nov 16, 2023
98a838f
formatting
umangyadav Nov 16, 2023
61e4e1d
use 31 for min value
umangyadav Nov 16, 2023
a5c38eb
add note
umangyadav Nov 16, 2023
61775ea
Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX…
umangyadav Nov 16, 2023
3806427
Merge branch 'develop' into ref_fp8
umangyadav Nov 16, 2023
017d67e
add some more comments
umangyadav Nov 17, 2023
9e6d866
Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX…
umangyadav Nov 17, 2023
a9dd42f
port gpu changes
umangyadav Nov 17, 2023
d7339e8
use bit cast
umangyadav Nov 17, 2023
6094234
Make FNUZ template param and add numeric limits
umangyadav Nov 17, 2023
78ec77e
only compile for device
umangyadav Nov 17, 2023
3411649
remove non-JIT related code
umangyadav Nov 17, 2023
d2c25a0
Remove FP8_Lowest/Max
umangyadav Nov 17, 2023
5da68df
remove using for dtypes
umangyadav Nov 17, 2023
b36f72d
Update float8_impl
umangyadav Nov 17, 2023
85ba819
constructor from float works with constexpr
umangyadav Nov 17, 2023
aed1922
Remove unnecessary pragmas
umangyadav Nov 17, 2023
f975c63
Remove clang diagnostics
umangyadav Nov 17, 2023
32033d8
Add back floatequal
umangyadav Nov 17, 2023
e88d46a
disable DPP For FP8
umangyadav Nov 17, 2023
3ae93ca
Merge remote-tracking branch 'origin/develop' into gpu_fp8
umangyadav Nov 17, 2023
60dd1f4
formatting
umangyadav Nov 17, 2023
ef425d0
revert unwanted changes
umangyadav Nov 17, 2023
76f0318
Merge branch 'gpu_fp8' of https://github.com/ROCmSoftwarePlatform/AMD…
umangyadav Nov 17, 2023
bd0ae5f
add some more tests
umangyadav Nov 17, 2023
91cc9c7
Add math and reduce tests
umangyadav Nov 18, 2023
e2b0c40
Fix tidy and other errors
umangyadav Nov 18, 2023
9f50051
fixes
umangyadav Nov 18, 2023
249464c
add nolint
umangyadav Nov 18, 2023
1be9587
tidy fix
umangyadav Nov 18, 2023
13403ab
roialign, softmax, pow, acosh, atanh,pad tests are enabled now
umangyadav Nov 20, 2023
f550f81
add layernorm, remove constexpr for 1/r
umangyadav Nov 20, 2023
7e3444c
tidy fixes
umangyadav Nov 20, 2023
6155c78
use __builtin_is_constant_evaluated
umangyadav Nov 20, 2023
13ef414
add test for rsqrt and remove old-styple-cast
umangyadav Nov 20, 2023
8660572
add comment about c++20 extensions
umangyadav Nov 20, 2023
6fbd997
Remove old cast
umangyadav Nov 20, 2023
2acd265
Remove DPP
umangyadav Nov 20, 2023
836e201
Remove MIN max overloads
umangyadav Nov 20, 2023
f9542d5
Put numeric_max and numeeric lowest into float8
umangyadav Nov 20, 2023
480288f
use void for highest to match template candidates
umangyadav Nov 21, 2023
a6c5772
add float8 for tensorview
umangyadav Nov 21, 2023
1a56e6d
remvoe static_casts
umangyadav Nov 28, 2023
8c25100
use float for roialign and add back static_cast for softmax
umangyadav Nov 28, 2023
18a129a
skip convert for find_concat_op
umangyadav Nov 28, 2023
4aa561b
formatting
umangyadav Nov 28, 2023
b6a3ba7
revert dnnl change
umangyadav Nov 29, 2023
838aebf
add test for concat convert fusion
umangyadav Nov 29, 2023
41e5cc0
Merge branch 'gpu_fp8' of https://github.com/ROCmSoftwarePlatform/AMD…
umangyadav Nov 29, 2023
7487568
disable lowering of contiguous and concat for dnnl
umangyadav Nov 29, 2023
3cc2046
formattimg
umangyadav Nov 29, 2023
6ea01a9
add static cast for implicit conversion
umangyadav Nov 29, 2023
90de973
use implicit conversions
umangyadav Nov 29, 2023
dad2a78
Merge branch 'gpu_fp8' of https://github.com/ROCmSoftwarePlatform/AMD…
umangyadav Nov 29, 2023
c0b5724
use {} initializer
umangyadav Nov 29, 2023
5308c81
disable lowering for fp8 ops for DNNL
umangyadav Nov 29, 2023
d6e2177
use vec for fp8 inside tests, it would have no effect
umangyadav Nov 29, 2023
24873eb
remove unwanted formatting change
umangyadav Nov 29, 2023
3733ecc
Revert "use vec for fp8 inside tests, it would have no effect"
umangyadav Nov 29, 2023
8d4fa29
misssed enabling fp8 JIT tests
umangyadav Nov 29, 2023
a6d8e43
add back .f
umangyadav Nov 30, 2023
0d220fd
add explicit cast for convert and add fmod/mod tests
umangyadav Nov 30, 2023
af499e4
add equal separately otherwise types are mismatching
umangyadav Nov 30, 2023
015e4bb
give better name
umangyadav Nov 30, 2023
f1c5544
address Ted's comments
umangyadav Nov 30, 2023
db55bc2
formatting
umangyadav Nov 30, 2023
c923e41
address comments
umangyadav Nov 30, 2023
9ac18df
remove numeric lowest
umangyadav Nov 30, 2023
ac73b33
renaminng stuff, using angled bracket for headers
umangyadav Nov 30, 2023
d26a86f
remove unnecessary line
umangyadav Nov 30, 2023
ba45008
add back lowest
umangyadav Nov 30, 2023
b11b2fe
Update src/targets/cpu/dnnl.cpp
umangyadav Nov 30, 2023
52cb87c
add another overload for numeric_max/lowest for the float8
umangyadav Dec 1, 2023
42a1686
fix bug
umangyadav Dec 1, 2023
b936b0e
Merge branch 'gpu_fp8' of https://github.com/ROCmSoftwarePlatform/AMD…
umangyadav Dec 1, 2023
86c4484
change comments
umangyadav Dec 1, 2023
8561d6d
dont' use abbreviation
umangyadav Dec 1, 2023
dbda1a1
Merge branch 'develop' into gpu_fp8
causten Dec 1, 2023
8 changes: 7 additions & 1 deletion src/include/migraphx/bit_cast.hpp
@@ -21,18 +21,24 @@
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP
#include <type_traits>
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif

#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>

// NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x))

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <typename To, typename From>
template <typename To,
typename From,
MIGRAPHX_REQUIRES(std::is_trivially_copyable<To>{} and
std::is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
1 change: 1 addition & 0 deletions src/targets/cpu/dnnl.cpp
@@ -68,6 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
case st::int32_type: return dt::s32;
case st::int8_type: return dt::s8;
case st::uint8_type: return dt::u8;
case st::fp8e4m3fnuz_type: MIGRAPHX_THROW("fp8e4m3fnuz unsupported in DNNL");
default: MIGRAPHX_THROW("Unsupported data type");
}
}
13 changes: 12 additions & 1 deletion src/targets/cpu/lowering.cpp
@@ -340,7 +340,6 @@ struct cpu_apply
{"reduce_min", "reduction_min"},
{"reduce_sum", "reduction_sum"},
});

extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution");
@@ -376,13 +375,25 @@
// Apply these operators first so the inputs can be const folded
for(auto it : iterator_for(*modl))
{
// Skip lowering if any input is fp8, since oneDNN does not support
// fp8 yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pow")
{
apply_pow(it);
}
}
for(auto it : iterator_for(*modl))
{
// Skip lowering if any input is fp8, since oneDNN does not support
// fp8 yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pooling")
{
apply_pooling(it);
10 changes: 10 additions & 0 deletions src/targets/gpu/compile_gen.cpp
@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
@@ -86,6 +91,11 @@

vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
37 changes: 37 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
@@ -0,0 +1,37 @@
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP

#include <migraphx/kernels/type_traits.hpp>

namespace migraphx {
template <typename To,
typename From,
MIGRAPHX_REQUIRES(is_trivially_copyable<To>{} and is_trivially_copyable<From>{})>
inline constexpr To bit_cast(From fr) noexcept
{
static_assert(sizeof(To) == sizeof(From));
return __builtin_bit_cast(To, fr);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP