Skip to content

Commit

Permalink
Add support for multi outputs in pointwise ops (#2957)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Apr 19, 2024
1 parent f6e22cb commit 4e1caca
Show file tree
Hide file tree
Showing 30 changed files with 581 additions and 150 deletions.
20 changes: 19 additions & 1 deletion src/argument.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -102,6 +102,24 @@ void argument::assign_buffer(std::function<char*()> d)
})(s);
}

std::vector<argument> flatten(const std::vector<argument>& args)
{
std::vector<argument> result;
for(const auto& arg : args)
{
if(arg.get_shape().type() == shape::tuple_type)
{
auto subs = flatten(arg.get_sub_objects());
result.insert(result.end(), subs.begin(), subs.end());
}
else
{
result.push_back(arg);
}
}
return result;
}

std::vector<shape> to_shapes(const std::vector<argument>& args)
{
std::vector<shape> shapes;
Expand Down
18 changes: 15 additions & 3 deletions src/cpp_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ inline namespace MIGRAPHX_INLINE_NS {
cpp_generator::function&
cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g)
{
const std::string prefix = "zz";
std::unordered_map<migraphx::instruction_ref, std::string> names;
std::stringstream ss;

Expand All @@ -53,12 +54,13 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate
}
else if(ins->name() == "@return")
{
assert(ins->inputs().size() == 1);
return_ins = ins->inputs().front();
names[ins] = prefix + "return";
ss << "auto " << names[ins] << " = " << g(ins, names) << ";\n";
return_ins = ins;
}
else
{
std::string n = "z" + std::to_string(names.size());
std::string n = prefix + std::to_string(names.size());
names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n";
}
Expand Down Expand Up @@ -125,6 +127,7 @@ struct cpp_generator_impl
std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
bool always_return_tuple = false;
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}

Expand All @@ -142,6 +145,8 @@ void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { imp

void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }

void cpp_generator::always_return_tuple(bool b) { impl->always_return_tuple = b; }

void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
Expand Down Expand Up @@ -222,6 +227,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
});
return shape::cpp_type(ins->get_shape().type()) + "(" + string_literal + ")";
}
if(ins->name() == "@return")
{
// TODO: Customize the make_tuple call
if(impl->always_return_tuple or ins->inputs().size() != 1)
return "make_tuple(" + join_strings(to_args(ins->inputs(), names), ", ") + ")";
return names.at(ins->inputs().front());
}
auto s = g(ins, names);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
Expand Down
36 changes: 18 additions & 18 deletions src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,17 @@ static void create_pointwise_modules(module_pass_manager& mpm)
}
}

static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
instruction_ref output)
static module::with_inputs append_pointwise_module(instruction_ref ins, instruction_ref output)
{
assert(contains(output->inputs(), ins));
module_ref pm = ins->module_inputs().at(0);
module pm = *ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0);

auto last = std::prev(pm->end());
auto last = std::prev(pm.end());
assert(last->name() == "@return");
assert(last->inputs().size() == 1);

assert(pm->get_parameter_names().size() == ins->inputs().size());
assert(pm.get_parameter_names().size() == ins->inputs().size());
assert(xm->get_parameter_names().size() == output->inputs().size());

std::vector<instruction_ref> inputs = ins->inputs();
Expand All @@ -134,8 +133,8 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
for(auto i : range(inputs.size()))
{
auto input = inputs[i];
auto param = pm->get_parameter("x" + std::to_string(i));
assert(param != pm->end());
auto param = pm.get_parameter("x" + std::to_string(i));
assert(param != pm.end());
input_map[input] = param;
}
// Add the new parameter and additional inputs
Expand All @@ -157,20 +156,20 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
else
{
map_ins[param] =
pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()});
pm.add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()});
inputs.push_back(input);
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_instructions(last, xm, &map_ins));
return inputs;
pm.replace_return(pm.insert_instructions(last, xm, &map_ins));
return {std::move(pm), inputs};
}

static bool find_pointwise_modules(module& m)
static bool find_pointwise_modules(module_pass_manager& mpm)
{
bool changed = false;
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
auto last = std::prev(mpm.get_module().end());
for(auto ins : iterator_for(mpm.get_module()))
{
if(ins->name() != "pointwise")
continue;
Expand All @@ -183,10 +182,11 @@ static bool find_pointwise_modules(module& m)
continue;
auto input = *it;

auto new_inputs = append_pointwise_module(input, ins);
m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs());
m.replace_instruction(ins, input);
m.move_instruction(input, ins);
auto fused = append_pointwise_module(input, ins);
auto name = fused.mod.name();
mpm.rename_module(name, name + ":" + ins->module_inputs().front()->name() + "-deleted");
auto* new_pm = mpm.create_module(name, std::move(fused.mod));
mpm.get_module().replace_instruction(ins, input->get_operator(), fused.inputs, {new_pm});

changed = true;
}
Expand All @@ -213,7 +213,7 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
for(int i = 0; i < 8; i++)
{
mpm.run_pass(rewrite_reshapes<pointwise_reshape>{});
if(not find_pointwise_modules(mpm.get_module()))
if(not find_pointwise_modules(mpm))
break;
mpm.run_pass(dead_code_elimination{});
}
Expand Down
4 changes: 3 additions & 1 deletion src/include/migraphx/argument.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -117,6 +117,8 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument>
data_t m_data{};
};

std::vector<argument> flatten(const std::vector<argument>& args);

MIGRAPHX_EXPORT std::vector<shape> to_shapes(const std::vector<argument>& args);
MIGRAPHX_EXPORT void migraphx_to_value(value& v, const argument& a);
MIGRAPHX_EXPORT void migraphx_from_value(const value& v, argument& a);
Expand Down
4 changes: 3 additions & 1 deletion src/include/migraphx/cpp_generator.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -95,6 +95,8 @@ struct MIGRAPHX_EXPORT cpp_generator

void fresult(const std::function<std::string(shape)>& f);

void always_return_tuple(bool b = true);

void add_point_op(const std::string& op_name, const std::string& code);

std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
Expand Down
14 changes: 14 additions & 0 deletions src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ struct MIGRAPHX_EXPORT module
instruction_ref begin() const;
instruction_ref end() const;

struct compute_shapes_options
{
std::string name = "compute_shapes";
bool strict_type = false;
bool strict_lens = false;
std::vector<std::size_t> scalar_const_out_lens = {};
};

/// Compute a new ouput shape by replacing each parameter with input
/// shapes passed in.
std::vector<shape> compute_shapes(const std::vector<shape>& inputs,
compute_shapes_options options) const;
std::vector<shape> compute_shapes(const std::vector<shape>& inputs) const;

std::vector<shape> get_output_shapes() const;

instruction_ref validate() const;
Expand Down
34 changes: 18 additions & 16 deletions src/include/migraphx/op/pointwise.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -45,23 +45,18 @@ struct pointwise
{
MIGRAPHX_THROW("should have one submodule.");
}
auto* pm = mods.front();
if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("pointwise should have only one output.");
if(inputs.empty())
MIGRAPHX_THROW("pointwise should have at least one input");
auto* pm = mods.front();
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());
check_shapes{inputs, *this}.has(pnames.size()).same_dims();

auto type = pm->get_output_shapes().front().type();

// Scalar output if all inputs are scalar
if(inputs.front().elements() == 1 and
all_of(inputs, [](const auto& s) { return s.scalar(); }))
return shape{type};

return shape::from_permutation(type, inputs.front().lens(), find_permutation(inputs));
auto result = pm->compute_shapes(
inputs,
{.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()});
if(result.size() == 1)
return result.front();
return shape{result};
}

argument compute(const shape& output_shape,
Expand All @@ -75,7 +70,7 @@ struct pointwise
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());

par_for(output_shape.elements(), [&](auto i) {
par_for(args[0].get_shape().elements(), [&](auto i) {
std::unordered_map<std::string, argument> params;

std::transform(
Expand All @@ -86,8 +81,15 @@ struct pointwise
[&](auto&& name, auto&& arg) { return std::make_pair(name, arg.element(i)); });

auto results = run(pm, params);
assert(results.size() == 1);
visit_all(output, results.front())([&](auto out, auto x) { out[i] = x.front(); });
assert(results.size() == output.get_sub_objects().size() or
(results.size() == 1 and output.get_sub_objects().empty()));
std::vector<argument> outputs;
if(results.size() == 1)
outputs = {output.share()};
else
outputs = output.share().get_sub_objects();
for(auto j : range(results.size()))
visit_all(outputs[j], results[j])([&](auto out, auto x) { out[i] = x.front(); });
});
return output;
}
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/pass_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct module_pass_manager
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual module* create_module(const std::string& name, module m) = 0;
virtual void rename_module(const std::string& old_name, const std::string& new_name) = 0;
virtual module* get_common_parent() = 0;
virtual module* get_root_module() = 0;
virtual void run_pass(const pass& p) = 0;
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ struct MIGRAPHX_EXPORT program
std::unordered_multimap<module_ref, module_ref> get_module_tree();

void remove_module(const std::string& name);
void rename_module(const std::string& old_name, const std::string& new_name);
void remove_unused_modules();

private:
Expand Down
8 changes: 8 additions & 0 deletions src/include/migraphx/raw_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ struct raw_data : raw_data_base
ss << static_cast<const Derived&>(*this);
return ss.str();
}

template <class T>
std::vector<T> to_vector() const
{
std::vector<T> result(static_cast<const Derived&>(*this).get_shape().elements());
this->visit([&](auto x) { result.assign(x.begin(), x.end()); });
return result;
}
};

namespace detail {
Expand Down
3 changes: 3 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,9 @@ struct MIGRAPHX_EXPORT shape
std::shared_ptr<const shape_impl> impl;
};

/// Flatten subshapes to a single vector of non-tuple type of shapes
std::vector<shape> flatten(const std::vector<shape>& shapes);

MIGRAPHX_EXPORT void migraphx_to_value(value& v, const shape& s);
MIGRAPHX_EXPORT void migraphx_from_value(const value& v, shape& s);

Expand Down
Loading

0 comments on commit 4e1caca

Please sign in to comment.