diff --git a/requirements.txt b/requirements.txt
index 3a06a254528..b4d4a4f84aa 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,5 +28,5 @@ ROCmSoftwarePlatform/half@rocm-5.6.0
 pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
 sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCmSoftwarePlatform/composable_kernel@a22e479b8e1557961039db2d5c5ff89cff35e86b -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCmSoftwarePlatform/rocMLIR@12748a3402c069f733ea7f2ba1f8d8a070b3622a -DBUILD_FAT_LIBROCKCOMPILER=On
+ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
+ROCmSoftwarePlatform/rocMLIR@12748a3402c069f733ea7f2ba1f8d8a070b3622a -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index fc3b3e773c8..43c7087bce7 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -22,9 +22,9 @@
  * THE SOFTWARE.
  */
 #include
+#include
 #include
 #include
-#include
 #include

 namespace migraphx {
@@ -55,7 +55,7 @@ struct ck_gemm
     {
         check_shapes{inputs, *this}.same_ndims();
         if(inputs.size() < 2)
-            MIGRAPHX_THROW("should have at least two inputs.");
+            MIGRAPHX_THROW(name() + ": should have at least two inputs.");
         auto a = inputs[0];
         auto b = inputs[1];
         for(const auto& input : inputs)
@@ -65,21 +65,27 @@ struct ck_gemm
             return r;
         return r.with_type(mods.front()->get_output_shapes().front().type());
     }
+
+    static bool is_ck_supported_type(shape::type_t t)
+    {
+        return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
+    }
 };
 MIGRAPHX_REGISTER_OP(ck_gemm);

-namespace {
-
-bool is_ck_supported_type(shape::type_t t)
+struct ck_gemm_softmax_gemm : gemm_softmax_gemm
 {
-    return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
-}
+    std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
+};
+MIGRAPHX_REGISTER_OP(ck_gemm_softmax_gemm);
+
+namespace {

 MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
 {
     if(ins->name() != "dot" and ins->name() != "quant_dot")
         return false;
-    if(not is_ck_supported_type(ins->get_shape().type()))
+    if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
         return false;
     auto a = ins->inputs().front()->get_shape();
     auto b = ins->inputs().back()->get_shape();
@@ -127,7 +133,11 @@ struct find_ck_gemm_pointwise
            ins->get_shape().type() != gemm_ins->get_shape().type())
             return;
         if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
-               return not is_ck_supported_type(input->get_shape().type());
+               return not ck_gemm::is_ck_supported_type(input->get_shape().type());
+           }))
+            return;
+        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
+               return not input->inputs().empty() and input->inputs().front()->name() == "capture";
            }))
             return;
         assert(gemm_it != inputs.end());
@@ -152,7 +162,7 @@ struct find_ck_gemm_pointwise

 struct find_ck_gemm
 {
-    auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
+    auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }

     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
@@ -161,11 +171,26 @@ struct find_ck_gemm
     }
 };

+struct find_ck_gemm_softmax_gemm
+{
+    auto matcher() const { return match::name("gpu::pre_gemm_softmax_gemm"); }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto v = ins->get_operator().to_value();
+        assert(v.contains("scale"));
+        auto scale = v.at("scale").to<float>();
+        mpm.get_module().replace_instruction(
+            ins, ck_gemm_softmax_gemm{migraphx::make_op("dot"), scale}, ins->inputs());
+    }
+};
+
 } // namespace

 void fuse_ck::apply(module_pass_manager& mpm) const
 {
-    match::find_matches(mpm, find_ck_gemm_pointwise{});
+    match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
     match::find_matches(mpm, find_ck_gemm{});
 }
diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp
new file mode 100644
index 00000000000..1b7f5ad3e81
--- /dev/null
+++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp
@@ -0,0 +1,165 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
+#define MIGRAPHX_GUARD_GPU_CK_HPP
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ck/host/device_gemm_multiple_d.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+
+#ifndef _WIN32
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
+#endif
+
+// NOLINTNEXTLINE
+const char* const disable_warning_pragma = R"__migraphx__(
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Weverything"
+${content}
+#pragma clang diagnostic pop
+)__migraphx__";
+
+template <class P>
+std::string ck_disable_warnings(P p)
+{
+    return interpolate_string(disable_warning_pragma,
+                              {{"content", std::string{p.data(), p.size()}}});
+}
+
+static std::unordered_map<std::string, std::string> create_ck_header_strings()
+{
+    std::unordered_map<std::string, std::string> result;
+    auto ck_headers = ck::host::GetHeaders();
+
+    std::transform(
+        ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
+            return std::pair(p.first, ck_disable_warnings(p.second));
+        });
+    return result;
+}
+
+static std::vector<src_file> create_ck_headers()
+{
+    static const auto& header_strings = create_ck_header_strings();
+    std::vector<src_file> srcs;
+    std::transform(header_strings.begin(),
+                   header_strings.end(),
+                   std::back_inserter(srcs),
+                   [&](auto& p) { return src_file{p}; });
+    return srcs;
+}
+
+static inline const std::vector<src_file>& ck_headers()
+{
+    static const auto& headers = create_ck_headers();
+    return headers;
+}
+
+inline bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
+
+inline ck::host::DataType get_type(const shape& s)
+{
+    if(s.type() == shape::half_type)
+        return ck::host::DataType::Half;
+    else if(s.type() == shape::float_type)
+        return ck::host::DataType::Float;
+    else if(s.type() == shape::int8_type)
+        return ck::host::DataType::Int8;
+    else if(s.type() == shape::int32_type)
+        return ck::host::DataType::Int32;
+    MIGRAPHX_THROW("Unsupported ck type");
+}
+
+inline std::size_t get_batch_count(const shape& s)
+{
+    return std::accumulate(
+        s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
+}
+
+inline void fold_batch_dims(shape& s)
+{
+    auto lens = s.lens();
+    if(lens.size() <= 2)
+        return;
+    auto batch_count = get_batch_count(s);
+    auto m1 = lens.at(lens.size() - 2);
+    auto m2 = lens.at(lens.size() - 1);
+    if(transposed_matrix(s))
+        s = shape{s.type(), {m1, m2 * batch_count}};
+    else
+        s = shape{s.type(), {m1 * batch_count, m2}};
+}
+
+inline void remove_batch_dims(shape& s)
+{
+    auto lens = s.lens();
+    if(lens.size() <= 2)
+        return;
+    auto m1 = lens.at(lens.size() - 2);
+    auto m2 = lens.at(lens.size() - 1);
+    s = shape{s.type(), {m1, m2}};
+}
+
+inline bool standard_batch(const shape& s)
+{
+    if(s.lens().size() < 3)
+        return true;
+    std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
+    std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
+    auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
+    std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
+        return stride / base;
+    });
+    return shape{s.type(), lens, strides}.standard();
+}
+
+inline bool can_fold_batch(const std::vector<shape>& inputs)
+{
+    const auto& b_shape = inputs[1];
+    if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
+           return not standard_batch(input);
+       }))
+        return false;
+    const auto& b_strides = b_shape.strides();
+    return std::all_of(
+        b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
+}
+
+} // namespace gpu
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+#endif // MIGRAPHX_GUARD_GPU_CK_HPP
diff --git a/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
new file mode 100644
index 00000000000..f27b30659ea
--- /dev/null
+++ b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
@@ -0,0 +1,75 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
+#define MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
+
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+
+struct gemm_softmax_gemm
+{
+    operation op = make_op("dot");
+    float scale = 1.0;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.op, "op"), f(self.scale, "scale"));
+    }
+
+    std::string name() const { return "gpu::gemm_softmax_gemm"; }
+
+    void check_gemm_shape(const shape& s) const
+    {
+        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
+            MIGRAPHX_THROW("Invalid shape for " + name());
+    }
+
+    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>&) const
+    {
+        check_shapes{inputs, *this}.same_ndims();
+        if(inputs.size() < 3)
+            MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
+        auto a = inputs[0];
+        auto b = inputs[1];
+        auto b1 = inputs[2];
+        for(const auto& input : inputs)
+        {
+            check_gemm_shape(input);
+        }
+        return op.compute_shape({op.compute_shape({a, b}), b1});
+    }
+
+    static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
+};
+
+} // namespace gpu
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+#endif // MIGRAPHX_GUARD_GPU_GEMM_SOFTMAX_GEMM_HPP
diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp
index 2937f653c09..7d0c9676e99 100644
--- a/src/targets/gpu/jit/ck_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm.cpp
@@ -27,6 +27,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -37,8 +38,6 @@
 #include
 #include

-#include "ck/host/device_gemm_multiple_d.hpp"
-
 namespace migraphx {

 inline namespace MIGRAPHX_INLINE_NS {
@@ -46,12 +45,6 @@ namespace gpu {

 using namespace migraphx::gpu::gen; // NOLINT

-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
-
 // NOLINTNEXTLINE
 static const char* const ck_gemm_kernel = R"__migraphx__(
 #include
@@ -79,219 +72,10 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})

 )__migraphx__";

-// NOLINTNEXTLINE
-static const char* const disable_warning_pragma = R"__migraphx__(
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Weverything"
-${content}
-#pragma clang diagnostic pop
-)__migraphx__";
-
-template <class P>
-static std::string ck_disable_warnings(P p)
-{
-    return interpolate_string(disable_warning_pragma,
-                              {{"content", std::string{p.first, p.second}}});
-}
-
-static std::unordered_map<std::string, std::string> create_ck_header_strings()
-{
-    std::unordered_map<std::string, std::string> result;
-    auto ck_headers = ck::host::GetHeaders();
-
-    std::transform(
-        ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto&& p) {
-            return std::make_pair(p.first, ck_disable_warnings(p.second));
-        });
-    return result;
-}
-
-static std::vector<src_file> create_ck_headers()
-{
-    static const auto& header_strings = create_ck_header_strings();
-    std::vector<src_file> srcs;
-    std::transform(
-        header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto&& p) {
-            return src_file{p.first, p.second};
-        });
-    return srcs;
-}
-
-static const std::vector<src_file>& ck_headers()
-{
-    static const auto& headers = create_ck_headers();
-    return headers;
-}
-
-static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
-
-using tuning_entry = std::pair<std::vector<shape>, size_t>;
-static std::vector<tuning_entry> read_tuning(const std::string& s)
-{
-    if(not fs::exists(s))
-        return {};
-    return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
-}
-
-static float matrix_distance(const shape& x, const shape& y)
-{
-    if(x.type() != y.type())
-        return std::numeric_limits<float>::max();
-    if(transposed_matrix(x) != transposed_matrix(y))
-        return std::numeric_limits<float>::max();
-    auto sum_squared = std::inner_product(x.lens().rbegin(),
-                                          x.lens().rbegin() + 2,
-                                          y.lens().rbegin(),
-                                          0,
-                                          std::plus<>{},
-                                          [](auto a, auto b) { return (a - b) * (a - b); });
-    return std::sqrt(sum_squared);
-}
-
-static std::size_t get_tuning_for(const std::vector<shape>& inputs)
-{
-    static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
-    if(tuning.empty())
-    {
-        std::cout << "*********** Warning: No CK tuning! for config:" << std::endl;
-        std::cout << "  " << inputs[0] << std::endl;
-        std::cout << "  " << inputs[1] << std::endl;
-        std::cout << "  " << inputs[2] << std::endl;
-    }
-    auto it = std::find_if(
-        tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
-    if(it == tuning.end())
-    {
-        std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
-        std::cout << "  " << inputs[0] << std::endl;
-        std::cout << "  " << inputs[1] << std::endl;
-        std::cout << "  " << inputs[2] << std::endl;
-        std::vector<std::pair<float, std::size_t>> w;
-        std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
-            if(inputs.size() < 3 or p.first.size() < 3)
-                MIGRAPHX_THROW("Invalid CK config");
-            auto avg_distance = std::inner_product(
-                p.first.begin(),
-                p.first.begin() + 3,
-                inputs.begin(),
-                0.0f,
-                std::plus<>{},
-                [](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
-            return std::make_pair(avg_distance, p.second);
-        });
-        std::sort(w.begin(), w.end());
-        std::size_t default_value = 4;
-        if(not w.empty())
-            default_value = w.front().second;
-        auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
-        std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
-        return tuning_val;
-    }
-    return it->second;
-}
-
 struct ck_gemm_compiler : compiler<ck_gemm_compiler>
 {
-    static std::string get_layout(const shape& s)
-    {
-        return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
-                                    : "ck::tensor_layout::gemm::RowMajor";
-    }
-
-    static ck::host::DataType get_type(const shape& s)
-    {
-        if(s.type() == shape::half_type)
-            return ck::host::DataType::Half;
-        else if(s.type() == shape::float_type)
-            return ck::host::DataType::Float;
-        else if(s.type() == shape::int8_type)
-            return ck::host::DataType::Int8;
-        else if(s.type() == shape::int32_type)
-            return ck::host::DataType::Int32;
-        MIGRAPHX_THROW("Unsupported ck type");
-    }
-
-    template <class Iterator, class F>
-    static std::string ck_tuple(Iterator start, Iterator last, F f)
-    {
-        std::vector<std::string> s;
-        std::transform(start, last, std::back_inserter(s), f);
-        return "ck::Tuple<" + join_strings(s, ",") + ">";
-    }
-
-    static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
-    {
-        swap_inputs = false;
-        auto c_shape = inputs.back();
-        if(not transposed_matrix(c_shape))
-            return inputs;
-        std::vector<int64_t> perm(c_shape.lens().size());
-        std::iota(perm.begin(), perm.end(), 0);
-        std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
-        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
-            return reorder_shape(s, perm);
-        });
-        swap_inputs = true;
-        return inputs;
-    }
-
-    static std::size_t get_batch_count(const shape& s)
-    {
-        return std::accumulate(
-            s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
-    }
-
-    static void fold_batch_dims(shape& s)
-    {
-        auto lens = s.lens();
-        if(lens.size() <= 2)
-            return;
-        auto batch_count = get_batch_count(s);
-        auto m1 = lens.at(lens.size() - 2);
-        auto m2 = lens.at(lens.size() - 1);
-        if(transposed_matrix(s))
-            s = shape{s.type(), {m1, m2 * batch_count}};
-        else
-            s = shape{s.type(), {m1 * batch_count, m2}};
-    }
-
-    static void remove_batch_dims(shape& s)
-    {
-        auto lens = s.lens();
-        if(lens.size() <= 2)
-            return;
-        auto m1 = lens.at(lens.size() - 2);
-        auto m2 = lens.at(lens.size() - 1);
-        s = shape{s.type(), {m1, m2}};
-    }
-
     std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

-    static bool standard_batch(const shape& s)
-    {
-        if(s.lens().size() < 3)
-            return true;
-        std::vector<std::size_t> lens(s.lens().begin(), s.lens().end() - 2);
-        std::vector<std::size_t> strides(s.strides().begin(), s.strides().end() - 2);
-        auto base = *(s.lens().end() - 2) * *(s.lens().end() - 1);
-        std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto stride) {
-            return stride / base;
-        });
-        return shape{s.type(), lens, strides}.standard();
-    }
-
-    bool can_fold_batch(const std::vector<shape>& inputs) const
-    {
-        const auto& b_shape = inputs[1];
-        if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
-               return not standard_batch(input);
-           }))
-            return false;
-        const auto& b_strides = b_shape.strides();
-        return std::all_of(
-            b_strides.begin(), b_strides.end() - 2, [](auto stride) { return stride == 0; });
-    }
-
     ck::host::device_gemm_multiple_d::Problem
     create_problem(const std::vector<shape>& inputs, const value& v) const
     {
@@ -300,8 +84,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         const auto& c_shape = inputs.back();

         // cppcheck-suppress unreadVariable
-        auto rank = a_shape.ndim();
-
+        auto rank = a_shape.ndim();
         auto batch_count = get_batch_count(c_shape);
         auto m = c_shape.lens()[rank - 2];
         m = can_fold_batch(inputs) ? m * batch_count : m;
@@ -351,12 +134,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
-        const auto& a_shape = inputs[0];
-        const auto& b_shape = inputs[1];
         const auto& c_shape = inputs.back();
-        auto tuning_value = v.get("tuning_value", 4);
-        if(not v.contains("tuning_value"))
-            tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
+        auto tuning_value = v.get("tuning_value", 34);
         auto batch_count = get_batch_count(c_shape);
         auto problem = create_problem(inputs, v);
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
new file mode 100644
index 00000000000..4176ed04e1d
--- /dev/null
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -0,0 +1,236 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+namespace gpu {
+
+using namespace migraphx::gpu::gen; // NOLINT
+
+// NOLINTNEXTLINE
+static const char* const ck_gemm_softmax_gemm_kernel = R"__migraphx__(
+#include
+#include
+#include
+#include
+#include
+#include
+#include <${include}>
+
+namespace migraphx {
+
+${preamble}
+
+extern "C" {
+
+MIGRAPHX_GLOBAL void ${kernel}(${params})
+{
+    transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
+        auto settings = make_ck_gemm_softmax_gemm_settings(MIGRAPHX_MAKE_CONSTANT(float{SCALE}));
+        ck_gemm_softmax_gemm<${solution}, ${blocks_per_batch}>(settings, xs...);
+    });
+}
+
+}
+
+} // namespace migraphx
+
+)__migraphx__";
+
+struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
+{
+    std::vector<std::string> names() const
+    {
+        return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
+    }
+
+    ck::host::device_batched_gemm_softmax_gemm::Problem
+    create_problem(const std::vector<shape>& inputs, const value&) const
+    {
+        const auto& a_shape = inputs[0];
+        const auto& b_shape = inputs[1];
+        const auto& b1_shape = inputs[2];
+        const auto& c_shape = inputs.back();
+
+        // cppcheck-suppress unreadVariable
+        auto rank = a_shape.ndim();
+        auto batch_count = get_batch_count(c_shape);
+        auto m = c_shape.lens()[rank - 2];
+        m = can_fold_batch(inputs) ? m * batch_count : m;
+        auto n = c_shape.lens().back();
+        auto k = a_shape.lens().back();
+        auto o = c_shape.lens().back();
+
+        const bool trans_a = transposed_matrix(a_shape);
+        const bool trans_b = transposed_matrix(b_shape);
+        const bool trans_b1 = transposed_matrix(b1_shape);
+        const bool trans_c = transposed_matrix(c_shape);
+        const auto a_type = get_type(a_shape);
+        const auto b_type = get_type(b_shape);
+        const auto b1_type = get_type(b1_shape);
+        const auto c_type = get_type(c_shape);
+
+        std::string ck_passthrough = "ck_passthrough";
+        return ck::host::device_batched_gemm_softmax_gemm::Problem{m,
+                                                                   n,
+                                                                   k,
+                                                                   o,
+                                                                   trans_a,
+                                                                   trans_b,
+                                                                   trans_b1,
+                                                                   trans_c,
+                                                                   a_type,
+                                                                   b_type,
+                                                                   b1_type,
+                                                                   c_type,
+                                                                   ck_passthrough,
+                                                                   ck_passthrough,
+                                                                   ck_passthrough,
+                                                                   ck_passthrough};
+    }
+
+    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
+    {
+        const auto& c_shape = inputs.back();
+        auto tuning_value = v.get("tuning_value", 5);
+        auto batch_count = get_batch_count(c_shape);
+        auto problem = create_problem(inputs, v);
+
+        const auto include_header = problem.GetIncludeHeader();
+        const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+        const auto& solution = solutions.at(tuning_value);
+        const auto template_str = solution.template_str;
+        const auto blocks_per_batch = solution.grid_size;
+        const auto block_size = solution.block_size;
+
+        hip_compile_options options;
+        options.additional_src_files = ck_headers();
+        auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
+        options.set_launch_params(v, grid_size * block_size, block_size);
+        options.inputs = inputs;
+        options.output = c_shape;
+        options.kernel_name = v.get("kernel", "ck_gemm_softmax_gemm_kernel");
+        options.virtual_inputs = inputs;
+        if(can_fold_batch(inputs))
+        {
+            auto vinputs = inputs;
+            fold_batch_dims(vinputs[0]);
+            remove_batch_dims(vinputs[1]);
+            std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
+            options.virtual_inputs = vinputs;
+        }
+
+        if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
+            options.params += " -DMIGRAPHX_CK_CHECK=1";
+
+        // scale
+        assert(v.contains("scale"));
+        auto scale = v.at("scale").to<float>();
+        options.params += " -DSCALE=" + std::to_string(scale);
+
+        auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
+                                      {{"solution", template_str},
+                                       {"include", include_header},
+                                       {"params", enum_params(inputs.size(), "void * private_p")},
+                                       {"args", enum_params(inputs.size(), "private_p")},
+                                       {"blocks_per_batch", to_string(blocks_per_batch)},
+                                       {"preamble", v.get("preamble", std::string{})},
+                                       {"kernel", options.kernel_name}});
+
+        return compile_hip_code_object(src, options);
+    }
+
+    value create_settings(instruction_ref ins, const operation& op) const
+    {
+        auto v = op.to_value();
+        v["kernel"] = "ck_gemm_softmax_gemm_kernel";
+        if(not ins->module_inputs().empty())
+        {
+            auto* pm = ins->module_inputs().front();
+            v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
+                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
+                            "post_ck_gemm_softmax_gemm_function);";
+            v["post"] = "ck_function_adaptor";
+            v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
+        }
+        return v;
+    }
+
+    compiler_replace
+    compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
+    {
+        auto shapes = to_shapes(ins->inputs());
+        auto v = create_settings(ins, op);
+        if(not solution.is_null())
+            v["tuning_value"] = solution;
+        return {compile_op(ctx, shapes, v),
+                [=](module& m, instruction_ref ins2, const operation& code_object) {
+                    if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
+                    {
+                        std::vector<shape> gemm_shapes{
+                            shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
+                        std::cout << "gpu::ck_gemm_softmax_gemm: "
+                                  << to_json_string(to_value(gemm_shapes)) << std::endl;
+                    }
+                    m.replace_instruction(ins2, code_object, ins2->inputs());
+                }};
+    }
+
+    optional<tuning_config>
+    get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
+    {
+        if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{}))
+            return nullopt;
+        tuning_config tc;
+        auto shapes = to_shapes(ins->inputs());
+        auto problem = create_problem(shapes, create_settings(ins, op));
+        auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+        tc.solutions.resize(solutions.size());
+        std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
+        std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
+        tc.problem = to_value(gemm_shapes);
+        return tc;
+    }
+};
+
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
index f8ba21d9570..370191155da 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
@@ -154,6 +154,17 @@ struct ck_add
     }
 };

+// In CK, the B matrix is ordered as N,K instead of K,N
+template <class Dims>
+constexpr auto ck_transposeb_dims(Dims dims)
+{
+    return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
+}
+
+template <class T>
+using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<T>{}.lens),
+                                          ck_transposeb_dims(get_shape_c<T>{}.strides)));
+
 #ifdef MIGRAPHX_CK_CHECK
 #define MIGRAPHX_CK_STATIC_ASSERT static_assert
 #else
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
index fb032ca7e96..bc942029a29 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
@@ -33,17 +33,6 @@

 namespace migraphx {

-// In CK, the B matrix is ordered as N,K instead of K,N
-template <class Dims>
-constexpr auto ck_transposeb_dims(Dims dims)
-{
-    return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
-}
-
-template <class T>
-using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<T>{}.lens),
-                                          ck_transposeb_dims(get_shape_c<T>{}.strides)));
-
 template <class G, class E, class A, class B, class... Ds>
 __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
 {
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
new file mode 100644
index 00000000000..80d4f69f549
--- /dev/null
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
@@ -0,0 +1,74 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_SOFTMAX_GEMM_HPP
+#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_SOFTMAX_GEMM_HPP
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+
+template <class T>
+struct ck_gemm_softmax_gemm_settings
+{
+    T scale{};
+};
+
+template <class... Ts>
+constexpr ck_gemm_softmax_gemm_settings<Ts...> make_ck_gemm_softmax_gemm_settings(Ts... xs)
+{
+    return {xs...};
+}
+
+template <class G, class C, class A, class B, class B1, class Settings>
+__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1, Settings s)
+{
+    constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(),
+                                             to_ck_tensor<ck_transposeb<B>>(),
+                                             to_ck_tensor<ck_transposeb<B1>>(),
+                                             to_ck_tensor<C>());
+
+    static_assert(desc.IsValid(), "Invalid ck gemm.");
+
+    G::Run(desc,
+           s.scale,
+           to_ck_const_pointer(a.data()),
+           to_ck_const_pointer(b.data()),
+           to_ck_const_pointer(b1.data()),
+           to_ck_pointer(c.data()));
+}
+
+template <class G, index_int BlocksPerBatch, class Settings, class... Ts>
+__device__ void ck_gemm_softmax_gemm(Settings s, Ts... xs)
+{
+    gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(
+        [&](auto... ys) { ck_gemm_softmax_gemm_matrix<G>(ys..., s); });
+}
+
+} // namespace migraphx
+#endif
diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp
index 0c93c6c67db..e689eb7e741 100644
--- a/src/targets/gpu/prefuse_ops.cpp
+++ b/src/targets/gpu/prefuse_ops.cpp
@@ -23,16 +23,17 @@
  */
 #include
 #include
+#include
 #include
-#include
-#include
 #include
 #include
 #include
+#include

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
+
 namespace {

 template
@@ -120,6 +121,60 @@ struct find_add_layernorm
         m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };
+
+struct pre_gemm_softmax_gemm : gemm_softmax_gemm
+{
+    std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
+};
+MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
+
+MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
+{
+    if(ins->name() != "dot")
+        return false;
+    if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
+        return false;
+    return true;
+}
+
+struct find_gemm_softmax_gemm
+{
+    auto matcher() const
+    {
+        auto gemm1 =
+            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+        auto mul = match::name("mul")(
+            match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
+        auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
+
+        return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm2_ins = r.instructions["gemm2"];
+        auto gemm1_ins = r.instructions["gemm1"];
+        auto scale_lit = r.instructions["scale"];
+
+        float scale = 1.0;
+        scale_lit->eval().visit([&](const auto s) {
+            // CK only supports single-valued scale
+            if(std::all_of(
+                   s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
+                scale = s.front();
+            else
+                return;
+        });
+
+        auto inputs = gemm1_ins->inputs();            // A, B
+        inputs.push_back(gemm2_ins->inputs().back()); // B1
+
+        mpm.get_module().replace_instruction(
+            ins, pre_gemm_softmax_gemm{gemm2_ins->get_operator(), scale}, inputs);
+    }
+};
+
 } // namespace

 void prefuse_ops::apply(module_pass_manager& mpm) const
@@ -127,6 +182,8 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
     match::find_matches(mpm.get_module(), find_layernorm{});
     mpm.run_pass(dead_code_elimination{});
     match::find_matches(mpm.get_module(), find_add_layernorm{});
+    if(enabled(MIGRAPHX_ENABLE_CK{}))
+        match::find_matches(mpm, find_gemm_softmax_gemm{});
 }

 } // namespace gpu
diff --git a/test/verify/ck_gemm_softmax_gemm.cpp b/test/verify/ck_gemm_softmax_gemm.cpp
new file mode 100644
index 00000000000..84c309cb734
--- /dev/null
+++ b/test/verify/ck_gemm_softmax_gemm.cpp
@@ -0,0 +1,56 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
+        migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
+        auto m2_elements = m2_shape.elements();
+        auto a  = mm->add_parameter("1", m1_shape);
+        auto b  = mm->add_parameter("2", m1_shape);
+        auto b1 = mm->add_parameter("3", m1_shape);
+        std::vector<float> eights(m2_elements, 0.125);
+        auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
+        std::vector<float> zeros(m2_elements, 0);
+        auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
+
+        b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+        auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);
+        auto scale   = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
+        auto bias    = mm->add_instruction(migraphx::make_op("add"), scale, zero);
+        auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
+        mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
+
+        return p;
+    }
+};
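Note (not part of the diff): taken together, these changes fuse the attention-style pattern dot -> mul(scale) -> softmax -> dot into a single Composable Kernel launch, i.e. the fused operator computes

    C = dot(softmax(scale * dot(A, B)), B1)

with the scale baked into the generated kernel through the -DSCALE compile definition. The rewrite only fires when prefuse_ops runs with the MIGRAPHX_ENABLE_CK environment variable enabled (see prefuse_ops.cpp above), so the new ck_gemm_softmax_gemm verify case exercises the CK path only under that setting; an invocation such as MIGRAPHX_ENABLE_CK=1 together with the verify test binary is an assumed way to run it, not something this diff establishes.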