Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CK GEMM-Softmax-GEMM fusion #2250

Merged
merged 49 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
f5ebc8f
Fix inverted logic
turneram Jul 27, 2023
26d1a96
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMD…
turneram Aug 10, 2023
b119ed8
Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMD…
turneram Aug 25, 2023
8ab0b22
Add gemm_softmax_gemm
turneram Aug 30, 2023
0a463c1
Formatting
turneram Aug 30, 2023
3793980
Move fuse_gsg to fuse_ck and fix bugs
turneram Sep 21, 2023
c393f23
Formatting
turneram Sep 21, 2023
59993d9
Update CK commit hash
turneram Sep 22, 2023
d1d48bd
Cleanup
turneram Sep 26, 2023
e781c38
Formatting
turneram Sep 26, 2023
59a0b0c
Format and cppcheck
turneram Sep 27, 2023
4a066e4
Address fuse_ck review comments
turneram Sep 27, 2023
791addf
Formatting
turneram Sep 27, 2023
135eb63
Move common CK device functions to ck.hpp
turneram Sep 27, 2023
a045fb1
Merge branch 'develop' into ck-flash-attn
turneram Sep 27, 2023
7e8b69a
Update CK SHA
turneram Sep 28, 2023
7103d3e
Call eval on scale lit and move gsg settings out of ck.hpp
turneram Sep 28, 2023
6fe8b43
Update CK SHA
turneram Sep 29, 2023
6d57f78
Add name() to CK compute shape throws; enforce mul has 2 args
turneram Sep 29, 2023
a9b32b7
Formatting
turneram Sep 29, 2023
250d3c8
Merge branch 'develop' into ck-flash-attn
causten Oct 1, 2023
c96139f
Move common functions to ck.hpp + other cleanup
turneram Oct 3, 2023
0b6b490
Formatting
turneram Oct 3, 2023
370b2cc
Update CK SHA
turneram Oct 3, 2023
c2bafa5
Merge branch 'ck-flash-attn' of https://github.com/ROCmSoftwarePlatfo…
turneram Oct 3, 2023
a7d5704
Move gemm_softmax_gemm matching to prefuse_ops
turneram Oct 4, 2023
4acc55c
Formatting
turneram Oct 4, 2023
f19e41b
Fix cppcheck + other cleanup
turneram Oct 4, 2023
29aae08
Formatting
turneram Oct 4, 2023
c04204a
Fix header schema
turneram Oct 4, 2023
5a3dff2
Formatting
turneram Oct 4, 2023
321367b
Naming
turneram Oct 4, 2023
a5dab12
Merge branch 'develop' into ck-flash-attn
turneram Oct 6, 2023
1b89f91
Merge remote-tracking branch 'origin/develop' into ck-flash-attn
turneram Oct 6, 2023
51b6ddf
Use new embed.cmake
turneram Oct 10, 2023
3aeab10
Formatting
turneram Oct 10, 2023
f407c8c
CK SHA
turneram Oct 10, 2023
36ea759
Use string instead of string_view for embedded ck headers; use safe d…
turneram Oct 11, 2023
c451aa9
Formatting
turneram Oct 11, 2023
9861856
Update CK SHA
turneram Oct 11, 2023
5b2b748
Merge remote-tracking branch 'origin/develop' into ck-flash-attn
turneram Oct 11, 2023
44a7755
Cppcheck
turneram Oct 12, 2023
b71fdf8
Formatting
turneram Oct 12, 2023
6fcf99b
Merge remote-tracking branch 'origin/develop' into ck-flash-attn
turneram Oct 12, 2023
9548894
Update src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
turneram Oct 12, 2023
1cd455e
Remove ck.hpp include where not needed
turneram Oct 12, 2023
85a473b
Merge remote-tracking branch 'origin/develop' into ck-flash-attn
turneram Oct 12, 2023
99a9f23
Add ck.hpp back to prefuse_op.cpp
turneram Oct 12, 2023
8837320
Merge branch 'develop' into ck-flash-attn
causten Oct 13, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ ROCmSoftwarePlatform/[email protected]
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
[email protected] -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@a22e479b8e1557961039db2d5c5ff89cff35e86b -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@12748a3402c069f733ea7f2ba1f8d8a070b3622a -DBUILD_FAT_LIBROCKCOMPILER=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@12748a3402c069f733ea7f2ba1f8d8a070b3622a -DBUILD_FAT_LIBROCKCOMPILER=On
47 changes: 36 additions & 11 deletions src/targets/gpu/fuse_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>

namespace migraphx {
Expand Down Expand Up @@ -55,7 +55,7 @@ struct ck_gemm
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
MIGRAPHX_THROW(name() + ": should have at least two inputs.");
auto a = inputs[0];
auto b = inputs[1];
for(const auto& input : inputs)
Expand All @@ -65,21 +65,27 @@ struct ck_gemm
return r;
return r.with_type(mods.front()->get_output_shapes().front().type());
}

static bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);

namespace {

bool is_ck_supported_type(shape::type_t t)
struct ck_gemm_softmax_gemm : gemm_softmax_gemm
{
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);

namespace {

MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(not is_ck_supported_type(ins->get_shape().type()))
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
Expand Down Expand Up @@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise
ins->get_shape().type() != gemm_ins->get_shape().type())
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not is_ck_supported_type(input->get_shape().type());
return not ck_gemm::is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end());
Expand All @@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise

struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
Expand All @@ -161,11 +171,26 @@ struct find_ck_gemm
}
};

struct find_ck_gemm_softmax_gemm
{
auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto v = ins->get_operator().to_value();
assert(v.contains("scale"));
auto scale = v.at("scale").to<float>();
mpm.get_module().replace_instruction(
ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
}
};

} // namespace

void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}

Expand Down
165 changes: 165 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/ck.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP

#include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>

#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
#endif

// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";

template <class P>
std::string ck_disable_warnings(P p)
{
return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}

static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();

std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}

static std::vector<src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs;
std::transform(header_strings.begin(),
header_strings.end(),
std::back_inserter(srcs),
[&](auto& p) { return src_file{p}; });
return srcs;
}

static inline const std::vector<src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}

inline bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }

inline ck::host::DataType get_type(const shape& s)
{
if(s.type() == shape::half_type)
return ck::host::DataType::Half;
else if(s.type() == shape::float_type)
return ck::host::DataType::Float;
else if(s.type() == shape::int8_type)
return ck::host::DataType::Int8;
else if(s.type() == shape::int32_type)
return ck::host::DataType::Int32;
MIGRAPHX_THROW("Unsupported ck type");
}

inline std::size_t get_batch_count(const shape& s)
{
return std::accumulate(
s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
}

inline void fold_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto batch_count = get_batch_count(s);
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
if(transposed_matrix(s))
s = shape{s.type(), {m1, m2 * batch_count}};
else
s = shape{s.type(), {m1 * batch_count, m2}};
}

inline void remove_batch_dims(shape& s)
{
auto lens = s.lens();
if(lens.size() <= 2)
return;
auto m1 = lens.at(lens.size() - 2);
auto m2 = lens.at(lens.size() - 1);
s = shape{s.type(), {m1, m2}};
}

inline bool standard_batch(const shape& s)
{
if(s.lens().size() < 3)
return true;
std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
return stride / base;
});
return shape{s.type(), lens, strides}.standard();
}

inline bool can_fold_batch(const std::vector<shape>& inputs)
{
const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input);
}))
return false;
const auto& b_strides = b_shape.strides();
return std::all_of(
b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
}

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_CK_HPP
75 changes: 75 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP

#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct gemm_softmax_gemm
{
operation op = make_op("dot");
float scale = 1.0;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"), f(self.scale, "scale"));
}

std::string name() const { return "gpu::gemm_softmax_gemm"; }

void check_gemm_shape(const shape& s) const
{
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
MIGRAPHX_THROW("Invalid shape for " + name());
}

shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>&) const
{
check_shapes{inputs, *this}.same_ndims();
if(inputs.size() < 3)
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[2];
for(const auto& input : inputs)
{
check_gemm_shape(input);
}
return op.compute_shape({op.compute_shape({a, b}), b1});
}

static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
};

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
Loading
Loading