From 1b7af5455e5007481429e3b2dfae76de171405ba Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Fri, 16 Dec 2022 09:57:57 -0500 Subject: [PATCH] Fix conversion issue in layernorm fusion (#1483) (#1493) --- src/targets/gpu/include/migraphx/gpu/hip.hpp | 6 ++-- .../include/migraphx/kernels/layernorm.hpp | 1 + .../include/migraphx/kernels/pointwise.hpp | 32 ------------------- .../kernels/include/migraphx/kernels/vec.hpp | 32 +++++++++++++++++++ src/targets/gpu/prefuse_ops.cpp | 7 ++-- 5 files changed, 41 insertions(+), 37 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index 3e3ff3cb745..a6acbb3cb65 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -105,7 +105,7 @@ struct hip_copy_to_gpu std::string name() const { return "hip::copy_to_gpu"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1, 2); + check_shapes{inputs, *this}.has(1, 2).same_type(); return inputs.at(0); } argument compute(context& ctx, const shape&, const std::vector& args) const @@ -131,7 +131,7 @@ struct hip_copy_from_gpu std::string name() const { return "hip::copy_from_gpu"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1, 2); + check_shapes{inputs, *this}.has(1, 2).same_type(); return inputs.at(0); } argument @@ -159,7 +159,7 @@ struct hip_copy std::string name() const { return "hip::copy"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(2); + check_shapes{inputs, *this}.has(2).same_type(); return inputs.at(1); } argument compute(context& ctx, const shape&, std::vector args) const diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp index e358045db91..3d54c3302ef 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp @@ -25,6 +25,7 @@ #define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP #include #include +#include #include namespace migraphx { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp index 8fe5c954539..4b5f9fc865c 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp @@ -33,38 +33,6 @@ namespace migraphx { -template -struct implicit_conversion_op -{ - T x; - - template - constexpr operator vec() const - { - if constexpr(vec_size() == 0) - { - return x; - } - else - { - static_assert(vec_size() == N, "Vector mismatch size"); - return __builtin_convertvector(x, vec); - } - } - - template - constexpr operator U() const - { - return x; - } -}; - -template -constexpr implicit_conversion_op implicit_conversion(T x) -{ - return {x}; -} - template __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp index 5dea03d7d8b..9f012f29d8b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp @@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op) } } +template +struct implicit_conversion_op +{ + T x; + + template + constexpr operator vec() const + { + if constexpr(vec_size() == 0) + { + return x; + } + else + { + static_assert(vec_size() == N, "Vector mismatch size"); + return __builtin_convertvector(x, vec); + } + } + + template + constexpr operator U() const + { + return x; + } +}; + +template +constexpr implicit_conversion_op implicit_conversion(T x) +{ + return {x}; +} + } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index cbc66dbb1eb..e9a7823c002 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -51,17 +51,20 @@ struct layernorm_base } check_shapes{inputs, static_cast(*this)}.has(nargs + N); auto s = inputs.at(0); + auto t = s.type(); + if(not mods.empty()) + t = mods.front()->get_output_shapes().front().type(); if(s.scalar()) { return s; } else if(s.broadcasted()) { - return {s.type(), s.lens()}; + return {t, s.lens()}; } else { - return s.with_lens(s.lens()); + return s.with_lens(t, s.lens()); } } };