Skip to content

Commit

Permalink
Don't use mixed layouts with convolution (#3587) (#3614)
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Nov 15, 2024
1 parent f7eb605 commit 4b20cbc
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ add_library(migraphx
insert_pad.cpp
instruction.cpp
json.cpp
layout_nhwc.cpp
layout_convolution.cpp
lexing.cpp
load_save.cpp
make_op.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
Expand All @@ -34,14 +34,15 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;

/**
* Transform convolutions to nhwc
* Transform convolutions layout
*/
struct MIGRAPHX_EXPORT layout_nhwc
struct MIGRAPHX_EXPORT layout_convolution
{
std::string name() const { return "layout_nhwc"; }
bool channels_last = false;
std::string name() const { return "layout_convolution"; }
void apply(module_pass_manager& mpm) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP
72 changes: 43 additions & 29 deletions src/layout_nhwc.cpp → src/layout_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/layout_convolution.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
Expand All @@ -32,58 +32,71 @@
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
namespace {
// Compute the permutation to apply to a convolution's inputs.
// With channels_last set, builds {0, 2, 3, ..., ndim-1, 1} (batch first,
// channels moved last); otherwise reuses the permutation of the layout the
// first input already has.
std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
{
    if(lc.channels_last)
    {
        // perm[0] stays 0 (value-initialized); fill the middle with the
        // spatial axes 2..ndim-1 and put the channel axis (1) last.
        std::vector<int64_t> perm(ins->get_shape().ndim());
        std::iota(perm.begin() + 1, perm.end() - 1, 2);
        perm.back() = 1;
        return perm;
    }
    return find_permutation(ins->inputs().front()->get_shape());
}

// Shapes that cannot take a layout op: 1-D, dynamic, or tuple shapes.
bool skip_layout(const shape& s)
{
    if(s.dynamic())
        return true;
    return s.ndim() == 1 or s.type() == shape::tuple_type;
}

// Insert explicit layout ops on the module's outputs so that rewriting the
// convolutions' layout does not change the layout the caller observes.
void preserve_output_layout(module& m)
{
    auto last = std::prev(m.end());
    if(last->name() == "@return")
    {
        // Pin each returned value to its current layout; shapes that cannot
        // take a layout op (see skip_layout) are passed through unchanged.
        std::vector<instruction_ref> outputs;
        std::transform(last->inputs().begin(),
                       last->inputs().end(),
                       std::back_inserter(outputs),
                       [&](instruction_ref ins) {
                           if(skip_layout(ins->get_shape()))
                               return ins;
                           auto permutation = find_permutation(ins->get_shape());
                           return m.insert_instruction(
                               last, make_op("layout", {{"permutation", permutation}}), ins);
                       });
        m.replace_return(outputs);
    }
    else if(not skip_layout(last->get_shape()))
    {
        // No @return instruction: the last instruction is the implicit
        // output, so pin its layout in place.
        auto permutation = find_permutation(last->get_shape());
        m.add_instruction(make_op("layout", {{"permutation", permutation}}), last);
    }
}

void transform_convolutions(module& m)
void transform_convolutions(module& m, const layout_convolution& lc)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "convolution")
if(not contains({"convolution", "quant_convolution"}, ins->name()))
continue;
if(ins->get_shape().dynamic())
continue;
if(ins->get_shape().lens().size() != 4)
continue;
auto v = ins->get_operator().to_value();
if(v.at("group").to<int>() > 1)
continue;
auto args = ins->inputs();
auto perm = get_permutation(ins, lc);
std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
return m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
Expand All @@ -102,11 +115,12 @@ void remove_layout(module& m)
m.replace_instruction(ins, ins->inputs().front());
}
}
} // namespace

void layout_nhwc::apply(module_pass_manager& mpm) const
void layout_convolution::apply(module_pass_manager& mpm) const
{
preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module());
transform_convolutions(mpm.get_module(), *this);
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
Expand Down
2 changes: 0 additions & 2 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,6 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
{
impl->changed.notify();
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);

if(ins == std::prev(this->end()))
Expand Down Expand Up @@ -541,7 +540,6 @@ instruction_ref module::insert_parameter(instruction_ref ins, std::string name,
instruction_ref module::replace_return(std::vector<instruction_ref> args)
{
impl->changed.notify();
assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); }));
auto last = std::prev(this->end());
// If there is no return then add a return
if(last->name() != "@return")
Expand Down
1 change: 0 additions & 1 deletion src/targets/cpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/eliminate_convert.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
Expand Down
4 changes: 2 additions & 2 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#include <migraphx/fuse_pointwise_reduce.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/layout_convolution.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/optimize_module.hpp>
Expand Down Expand Up @@ -182,7 +182,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
rewrite_gelu{options.fast_math},
optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})},
dead_code_elimination{},
prefuse_ops{},
dead_code_elimination{},
Expand Down
Loading

0 comments on commit 4b20cbc

Please sign in to comment.