Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial splitk reduce #2908

Merged
merged 60 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
ded44c3
Add initial split reduce pass
pfultz2 Mar 17, 2024
cc84a58
Format
pfultz2 Mar 17, 2024
35ffc5a
Format
pfultz2 Mar 17, 2024
cad365a
Fixes
pfultz2 Mar 17, 2024
363afbe
Format
pfultz2 Mar 17, 2024
8f007f7
Implement split reduce_kernel
pfultz2 Mar 18, 2024
aed9edb
Format
pfultz2 Mar 18, 2024
63be4df
Use unsafeAtomicAdd
pfultz2 Mar 18, 2024
3f052e9
Use array
pfultz2 Mar 18, 2024
2fc9d69
Format
pfultz2 Mar 18, 2024
b5e0eb4
Merge
pfultz2 Mar 18, 2024
6d4a878
Format
pfultz2 Mar 18, 2024
17e9128
Fix test
pfultz2 Mar 18, 2024
b907737
Format
pfultz2 Mar 18, 2024
1026502
Fix tests
pfultz2 Mar 18, 2024
a45febe
Format
pfultz2 Mar 18, 2024
369c952
Some refactoring
pfultz2 Mar 19, 2024
ac0a922
Format
pfultz2 Mar 19, 2024
02953fe
Some more refactoring
pfultz2 Mar 19, 2024
ce2a336
Format
pfultz2 Mar 19, 2024
be217d2
Move split to module class
pfultz2 Mar 19, 2024
4dbd08a
Format
pfultz2 Mar 19, 2024
88087f2
Add missing files
pfultz2 Mar 19, 2024
b1cc070
Format
pfultz2 Mar 19, 2024
d01103d
Add test case
pfultz2 Mar 19, 2024
792062c
Format
pfultz2 Mar 19, 2024
917695f
Move liveness
pfultz2 Mar 19, 2024
1f172c8
Seperate previous pointwise module if its used again
pfultz2 Mar 19, 2024
d9c2c9a
Format
pfultz2 Mar 19, 2024
9a86a41
Check threshold
pfultz2 Mar 20, 2024
d844e99
Add fill for split reduce
pfultz2 Mar 20, 2024
75aecd0
Format
pfultz2 Mar 20, 2024
3da1f32
Add license
pfultz2 Mar 20, 2024
650ac71
Format
pfultz2 Mar 20, 2024
be097a4
Fix tidy warnings
pfultz2 Mar 20, 2024
79ed517
Format
pfultz2 Mar 20, 2024
31021f3
Remvoe assert
pfultz2 Mar 20, 2024
fb1c9ba
Only use for reduce_sum
pfultz2 Mar 20, 2024
b0220b4
Fix typo
pfultz2 Mar 20, 2024
188602e
Add doc
pfultz2 Mar 29, 2024
cba899e
Add asserts
pfultz2 Mar 29, 2024
bc9827d
Format
pfultz2 Mar 29, 2024
155f2ff
Add unit tests for split function
pfultz2 Apr 4, 2024
f626268
Format
pfultz2 Apr 4, 2024
7c3069a
Add test for small reduce
pfultz2 Apr 4, 2024
bc41558
Format
pfultz2 Apr 4, 2024
f1e6ab4
Add docstring
pfultz2 Apr 4, 2024
9d71f9d
Remove TODO
pfultz2 Apr 4, 2024
be1d82d
Add assert
pfultz2 Apr 4, 2024
d19c734
Add more tests and TODOs
pfultz2 Apr 5, 2024
3a9267b
Format
pfultz2 Apr 5, 2024
7e3f587
Format
pfultz2 Apr 5, 2024
3a216c6
Fix windows
pfultz2 Apr 5, 2024
4db4a59
Format
pfultz2 Apr 5, 2024
0c51839
Add docstring
pfultz2 Apr 8, 2024
614830d
Format
pfultz2 Apr 8, 2024
b934717
Update src/targets/gpu/jit/reduce.cpp
pfultz2 Apr 9, 2024
936d5f6
Only use unsafe when available
pfultz2 Apr 9, 2024
7bcb560
Format
pfultz2 Apr 9, 2024
66e8053
Suppress tidy warnings
pfultz2 Apr 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ Disables the ``schedule`` pass.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Disables the ``fuse_reduce`` pass.

.. envvar:: MIGRAPHX_ENABLE_SPLIT_REDUCE
Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable split_reduce.

.. envvar:: MIGRAPHX_ENABLE_NHWC

Set to "1", "enable", "enabled", "yes", or "true" to use.
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ add_library(migraphx
operation.cpp
optimize_module.cpp
pad_calc.cpp
param_utils.cpp
pass.cpp
pass_manager.cpp
permutation.cpp
Expand All @@ -94,6 +95,7 @@ add_library(migraphx
replace_allocate.cpp
rewrite_reduce.cpp
simplify_qdq.cpp
split_reduce.cpp
sqlite.cpp
rewrite_gelu.cpp
rewrite_low_precision.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_instructions(last, xm, map_ins));
pm->replace_return(pm->insert_instructions(last, xm, &map_ins));
return inputs;
}

Expand Down
23 changes: 3 additions & 20 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,6 @@ struct fused_reduce
};
MIGRAPHX_REGISTER_OP(fused_reduce);

static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{
std::unordered_map<instruction_ref, instruction_ref> result;
auto names = sm->get_parameter_names();
std::sort(names.begin(), names.end());
assert(names.size() == inputs.size());
std::transform(names.begin(),
names.end(),
inputs.begin(),
std::inserter(result, result.end()),
[&](const auto& name, auto input) {
return std::make_pair(input, sm->get_parameter(name));
});
return result;
}

static void insert_params(module_ref sm,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
Expand All @@ -119,7 +102,7 @@ static auto insert_ins_in_submodule(module_ref sm,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins->inputs(), map_ins);
return sm->add_instructions({ins}, map_ins);
return sm->add_instructions({ins}, &map_ins);
}

static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
Expand All @@ -136,12 +119,12 @@ insert_module_in_submodule(module_ref sm,
module::inserter insert = nullptr)
{
insert_params(sm, inputs, map_ins);
auto param_map = get_ins_param_map(inputs, m);
auto param_map = m->get_ins_param_map(inputs);
for(auto&& [input, param] : param_map)
{
map_ins[param] = map_ins.at(input);
}
return sm->add_instructions(m, map_ins, std::move(insert));
return sm->add_instructions(m, &map_ins, std::move(insert));
}

static auto
Expand Down
77 changes: 77 additions & 0 deletions src/include/migraphx/liveness.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP

#include <migraphx/config.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template <class F>
void liveness(const module& m, F f)
{
auto implicit_deps = m.calc_implicit_deps();
std::unordered_set<instruction_ref> live_set;
auto rp = reverse(m);
for(auto rins : iterator_for(rp)) // NOLINT
{
// The base iterator is one ahead, so we need to use the previous iterator
auto ins = std::prev(rins.base());
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
// Add live variables
auto add_live_variables = [&](const auto& inputs) {
for(auto input : inputs)
{
auto i = instruction::get_output_alias(input);
// Skip if variable comes from parent
if(not m.has_instruction(i))
continue;
live_set.insert(i);
}
};
add_live_variables(ins->inputs());
add_live_variables(implicit_deps[ins]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this call may not necessary. Because implicit deps would contain instructions for parent which are skipped anyways.

Any other input that are in current module are already being handled by add_live_variables(ins->inputs()).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this call may not necessary.

Not for split_reduce since the submodules dont reference the parent, but it is necessary for memory_coloring because there are submodules which will reference the parent(such as when using loop or if).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

, but it is necessary for memory_coloring because there are submodules which will reference the parent

Yeah but they are skipped inside the loop

                // Skip if variable comes from parent
                if(not m.has_instruction(i))
                    continue;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah but they are skipped inside the loop

Yes, its not the same thing. The implict deps do not calculate the dependencies of the parent module. It calculates the dependencies of the submodules of an instruction that reference back to the original module.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

. It calculates the dependencies of the submodules of an instruction that reference back to the original module.

Since submodule instruction is using instruction from original/parent module,
if(not m.has_instruction(i)) this should be true always.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if(not m.has_instruction(i)) this should be true always.

No, an instruction can have input that references a parent module(not in the case of split_reduce since those operators dont rely on lexical scoping, but it is the case when using other operators that do allow lexical scoping such as if or loop).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So say we have graph of modules like this:

module A -> module B -> module C

So module B can have inputs from module A, and module C can have inputs from module A or B. When we do liveness on module B we calculate implicit_deps, which finds instructions in module B that use module C where module C reference instructions from module B.

Even so, module B can still reference instructions from module A, but we want to skip those for liveness.

// Remove last usage
auto it = live_set.find(ins);
if(it != live_set.end())
{
live_set.erase(it);
f(ins, live_set);
}
}
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LIVENESS_HPP
65 changes: 53 additions & 12 deletions src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct module_impl;
using parameter_map = std::unordered_map<std::string, argument>;
using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;

struct module_with_inputs;

/**
* @brief Stores the instruction stream
*/
Expand Down Expand Up @@ -127,38 +129,38 @@ struct MIGRAPHX_EXPORT module

std::vector<instruction_ref>
add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {},
inserter insert = nullptr);
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

std::vector<instruction_ref>
add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {},
inserter insert = nullptr);
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

std::vector<instruction_ref>
add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {},
inserter insert = nullptr);
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {},
inserter insert = nullptr);
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
inserter insert = nullptr);

std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {},
inserter insert = nullptr);
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {},
inserter insert = nullptr);
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
Expand Down Expand Up @@ -186,6 +188,8 @@ struct MIGRAPHX_EXPORT module

instruction_ref get_parameter(std::string name) const;

std::vector<instruction_ref> get_parameters() const;

void rename_parameter(instruction_ref ins, const std::string& name);

std::unordered_map<std::string, shape> get_parameter_shapes() const;
Expand All @@ -205,6 +209,28 @@ struct MIGRAPHX_EXPORT module

void finalize(std::vector<context>& contexts);

/// Create a mapping from the input instruction to the corresponding
/// parameter instruction. Use the `reverse` flag to reverse the lookup
/// to be from parameter instruction to input instread.
std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, bool reverse = false) const;

using with_inputs = module_with_inputs;

/// This will split the module into two parts at the instruction splits.
/// Each split instruction becomes an input parameter in the second
/// module. As such the inputs instructions to the second module will use
/// the split instructions as input placeholders that can be replaced
/// later.
umangyadav marked this conversation as resolved.
Show resolved Hide resolved
std::array<with_inputs, 2> split(const std::vector<instruction_ref>& args,
const std::vector<instruction_ref>& splits) const;

/// This will split the module in 3 parts using different split
/// instruction for each additional module.
std::array<with_inputs, 3> split(const std::vector<instruction_ref>& args,
const std::vector<instruction_ref>& splits1,
const std::vector<instruction_ref>& splits2) const;

void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(instruction_ref ins,
Expand Down Expand Up @@ -266,6 +292,21 @@ struct MIGRAPHX_EXPORT module
std::unique_ptr<module_impl> impl;
};

struct module_with_inputs
{
module mod;
std::vector<instruction_ref> inputs;
/// Replace the instruction in the inputs with rep
void replace(instruction_ref ins, instruction_ref rep);
/// Replace the input instructions using the map_ins to lookup the replacement
void replace(const std::unordered_map<instruction_ref, instruction_ref>& map_ins);

/// Replace the input instructions of the keys with the instructions
/// passed as values. Both vectors should be in the same order.
void replace(const std::vector<instruction_ref>& keys,
const std::vector<instruction_ref>& values);
};

inline module& get_module(module& m) { return m; }

} // namespace MIGRAPHX_INLINE_NS
Expand Down
42 changes: 42 additions & 0 deletions src/include/migraphx/param_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP

#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <vector>
#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

std::string param_name(std::size_t i, const std::string& prefix = "x");

void sort_params(std::vector<instruction_ref>& params);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP
2 changes: 1 addition & 1 deletion src/include/migraphx/pass_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct module_pass_manager
module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual module* create_module(const std::string& name, const module& m) = 0;
virtual module* create_module(const std::string& name, module m) = 0;
virtual module* get_common_parent() = 0;
virtual module* get_root_module() = 0;
virtual void run_pass(const pass& p) = 0;
Expand Down
45 changes: 45 additions & 0 deletions src/include/migraphx/split_reduce.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP

#include <migraphx/config.hpp>
#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

struct MIGRAPHX_EXPORT split_reduce
{
std::size_t split_size = 8192;
TedThemistokleous marked this conversation as resolved.
Show resolved Hide resolved
std::string name() const { return "split_reduce"; }
void apply(module_pass_manager& mpm) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SPLIT_REDUCE_HPP
Loading
Loading