Skip to content

Commit

Permalink
Fix conversion issue in layernorm fusion (#1483) (#1493)
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Dec 16, 2022
1 parent fe19455 commit 1b7af54
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 37 deletions.
6 changes: 3 additions & 3 deletions src/targets/gpu/include/migraphx/gpu/hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct hip_copy_to_gpu
// Returns the identifier string of this operation.
std::string name() const { return "hip::copy_to_gpu"; }
// Validates the input shapes and reports the first input's shape as the output shape.
// NOTE(review): the two check_shapes statements below look like the removed and the
// added row of a diff hunk whose +/- markers were lost. The second statement repeats
// the first and additionally chains same_type(); confirm against the repository which
// one the file actually contains.
shape compute_shape(std::vector<shape> inputs) const
{
// Presumably accepts either 1 or 2 inputs — TODO confirm has() semantics.
check_shapes{inputs, *this}.has(1, 2);
// Same count check, plus same_type() (presumably: all inputs share one element type).
check_shapes{inputs, *this}.has(1, 2).same_type();
// The result shape is the shape of the first input.
return inputs.at(0);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Expand All @@ -131,7 +131,7 @@ struct hip_copy_from_gpu
// Returns the identifier string of this operation.
std::string name() const { return "hip::copy_from_gpu"; }
// Validates the input shapes and reports the first input's shape as the output shape.
// NOTE(review): the two check_shapes statements below look like the removed and the
// added row of a diff hunk whose +/- markers were lost. The second statement repeats
// the first and additionally chains same_type(); confirm against the repository which
// one the file actually contains.
shape compute_shape(std::vector<shape> inputs) const
{
// Presumably accepts either 1 or 2 inputs — TODO confirm has() semantics.
check_shapes{inputs, *this}.has(1, 2);
// Same count check, plus same_type() (presumably: all inputs share one element type).
check_shapes{inputs, *this}.has(1, 2).same_type();
// The result shape is the shape of the first input.
return inputs.at(0);
}
argument
Expand Down Expand Up @@ -159,7 +159,7 @@ struct hip_copy
// Returns the identifier string of this operation.
std::string name() const { return "hip::copy"; }
// Validates the input shapes and reports the second input's shape as the output shape.
// NOTE(review): the two check_shapes statements below look like the removed and the
// added row of a diff hunk whose +/- markers were lost. The second statement repeats
// the first and additionally chains same_type(); confirm against the repository which
// one the file actually contains.
shape compute_shape(std::vector<shape> inputs) const
{
// Presumably requires exactly 2 inputs — TODO confirm has() semantics.
check_shapes{inputs, *this}.has(2);
// Same count check, plus same_type() (presumably: both inputs share one element type).
check_shapes{inputs, *this}.has(2).same_type();
// The result shape is the shape of the second input.
return inputs.at(1);
}
argument compute(context& ctx, const shape&, std::vector<argument> args) const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/print.hpp>

namespace migraphx {
Expand Down
32 changes: 0 additions & 32 deletions src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,38 +33,6 @@

namespace migraphx {

/// Holder for a value of type T that converts implicitly to the type requested
/// by the surrounding context, including vec<U, N> targets.
template <class T>
struct implicit_conversion_op
{
    /// The wrapped value.
    T x;

    /// Conversion to vec<U, N>.
    /// When vec_size<T>() is non-zero the sizes must match (checked at compile
    /// time) and the value goes through __builtin_convertvector; when it is
    /// zero (presumably T is not a vector type) x is returned directly and the
    /// regular conversion to vec<U, N> applies.
    template <index_int N, class U>
    constexpr operator vec<U, N>() const
    {
        if constexpr(vec_size<T>() != 0)
        {
            static_assert(vec_size<T>() == N, "Vector mismatch size");
            return __builtin_convertvector(x, vec<U, N>);
        }
        else
        {
            return x;
        }
    }

    /// Conversion to any other type U: returns x, relying on the conversion
    /// from T to U.
    template <class U>
    constexpr operator U() const
    {
        return x;
    }
};

/// Wraps x in an implicit_conversion_op<T>, deducing T from the argument.
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
    return implicit_conversion_op<T>{x};
}

template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
Expand Down
32 changes: 32 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,37 @@ constexpr auto vec_reduce(T x, Op op)
}
}

/// Holder for a value of type T that converts implicitly to the type requested
/// by the surrounding context, including vec<U, N> targets.
template <class T>
struct implicit_conversion_op
{
    /// The wrapped value.
    T x;

    /// Conversion to vec<U, N>.
    /// When vec_size<T>() is non-zero the sizes must match (checked at compile
    /// time) and the value goes through __builtin_convertvector; when it is
    /// zero (presumably T is not a vector type) x is returned directly and the
    /// regular conversion to vec<U, N> applies.
    template <index_int N, class U>
    constexpr operator vec<U, N>() const
    {
        if constexpr(vec_size<T>() != 0)
        {
            static_assert(vec_size<T>() == N, "Vector mismatch size");
            return __builtin_convertvector(x, vec<U, N>);
        }
        else
        {
            return x;
        }
    }

    /// Conversion to any other type U: returns x, relying on the conversion
    /// from T to U.
    template <class U>
    constexpr operator U() const
    {
        return x;
    }
};

/// Wraps x in an implicit_conversion_op<T>, deducing T from the argument.
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
    return implicit_conversion_op<T>{x};
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
7 changes: 5 additions & 2 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,20 @@ struct layernorm_base
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
auto t = s.type();
if(not mods.empty())
t = mods.front()->get_output_shapes().front().type();
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
return {t, s.lens()};
}
else
{
return s.with_lens(s.lens());
return s.with_lens(t, s.lens());
}
}
};
Expand Down

0 comments on commit 1b7af54

Please sign in to comment.