Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify select_module after simplify_dyn_ops pass #2714

Merged
merged 13 commits into from
Feb 16, 2024
137 changes: 137 additions & 0 deletions src/simplify_dyn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/tensor_view.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -318,6 +319,141 @@
}
};

/**
* Go through `select_module` instructions and update the `output_dyn_shapes` attribute.
* Checks the submodule output shapes and determines an appropriate `output_dyn_shapes` attribute.
* This version ignores dynamic_dimension opt values.
* Intended to be run after the other simplify_dyn_ops passes.
*/
struct simplify_select_module_output_shape
{
auto matcher() const { return match::name("select_module"); }

void apply(module& m, const match::matcher_result& mr) const
{
auto sm_ins = mr.result;
auto sm_module_inputs = sm_ins->module_inputs();
std::vector<std::vector<shape>> all_output_shapes(sm_module_inputs.size());
std::transform(sm_module_inputs.begin(),
sm_module_inputs.end(),
all_output_shapes.begin(),
[](auto submod) { return submod->get_output_shapes(); });
// check that all of the submodules have the same number of outputs and all respective
// outputs have the same rank and type
auto shapes_ndim = get_shapes_ndim(all_output_shapes.front());
bpickrel marked this conversation as resolved.
Show resolved Hide resolved
auto shapes_types = get_shapes_types(all_output_shapes.front());
bool check = std::all_of(
all_output_shapes.begin() + 1, all_output_shapes.end(), [&](auto out_shapes) {
bool same_types = get_shapes_types(out_shapes) == shapes_types;
bool same_ndim = get_shapes_ndim(out_shapes) == shapes_ndim;
return same_types and same_ndim;
});
if(not check)
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
{
return;
}
auto num_out_shapes = shapes_ndim.size();
std::vector<shape> dyn_shapes(num_out_shapes);
auto num_submod = sm_module_inputs.size();
// compare respective output shapes from each submodule to get a range for the output shape
for(int i : range(num_out_shapes))
{
std::vector<shape> shapes_at_index(num_submod);
std::transform(all_output_shapes.begin(),
all_output_shapes.end(),
shapes_at_index.begin(),
[&](auto output_shapes) { return output_shapes.at(i); });
dyn_shapes.at(i) = dyn_shape_from_shapes(shapes_at_index);
}
auto tuple_shape = shape{dyn_shapes};
m.replace_instruction(
sm_ins,
make_op("select_module", {{"output_dyn_shapes", to_value(tuple_shape)}}),
sm_ins->inputs(),
sm_module_inputs);
}

std::vector<std::size_t> get_shapes_ndim(std::vector<shape> shapes) const

Check warning on line 376 in src/simplify_dyn_ops.cpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'shapes' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]
{
std::vector<std::size_t> ret(shapes.size());
std::transform(
shapes.cbegin(), shapes.cend(), ret.begin(), [](auto s) { return s.ndim(); });
return ret;
}

std::vector<shape::type_t> get_shapes_types(std::vector<shape> shapes) const

Check warning on line 384 in src/simplify_dyn_ops.cpp

View workflow job for this annotation

GitHub Actions / tidy

the parameter 'shapes' is copied for each invocation but only used as a const reference; consider making it a const reference [performance-unnecessary-value-param,-warnings-as-errors]
{
std::vector<shape::type_t> ret(shapes.size());
std::transform(
shapes.cbegin(), shapes.cend(), ret.begin(), [](auto s) { return s.type(); });
return ret;
}

/**
* Calculating an appropriate shape that encompasses all of the given vector of shapes.
* Equivalent to creating a 2D matrix of shape lengths and do a reduce_min over each axis.
* The shapes can be dynamic or static.
* Assuming all shapes have the same ndim.
*/
shape dyn_shape_from_shapes(std::vector<shape> shape_vec) const
{
// making 2D matrices of min_lens and max_lens
// specifically using uint64_t because we're going to put the values into a tensor_view
// later
std::vector<uint64_t> all_min_lens;
std::vector<uint64_t> all_max_lens;
for(int i : range(shape_vec.size()))
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
{
auto s = shape_vec.at(i);
auto min_lens = s.min_lens();
auto max_lens = s.max_lens();
std::copy(min_lens.begin(), min_lens.end(), std::back_inserter(all_min_lens));
std::copy(max_lens.begin(), max_lens.end(), std::back_inserter(all_max_lens));
}
assert(all_min_lens.size() == shape_vec.size() * shape_vec.front().ndim());
assert(all_max_lens.size() == shape_vec.size() * shape_vec.front().ndim());
auto num_rows = shape_vec.size();
auto num_cols = shape_vec.front().ndim();
shape tensor_shape{shape::uint64_type, {num_rows, num_cols}};
auto min_lens_matrix = make_view(tensor_shape, all_min_lens.data());
auto max_lens_matrix = make_view(tensor_shape, all_max_lens.data());

std::vector<uint64_t> mins(num_cols);
std::vector<uint64_t> maxes(num_cols);
// rearranging data into column vectors to reduce over
// i = row, j = column
for(int j : range(num_cols))
{
std::vector<uint64_t> reduce_min_vals(num_rows);
std::vector<uint64_t> reduce_max_vals(num_rows);
for(int i : range(num_rows))
{
reduce_min_vals.at(i) = min_lens_matrix(i, j);
reduce_max_vals.at(i) = max_lens_matrix(i, j);
}
uint64_t max_int = std::numeric_limits<uint64_t>::max();
uint64_t min_val =
std::accumulate(reduce_min_vals.begin(),
reduce_min_vals.end(),
max_int,
[](uint64_t x, uint64_t y) { return x < y ? x : y; });
uint64_t max_val = std::accumulate(
reduce_max_vals.begin(), reduce_max_vals.end(), 0, [](uint64_t x, uint64_t y) {
return x > y ? x : y;
});
mins.at(j) = min_val;
maxes.at(j) = max_val;
}
// fixed output shape case
if(mins == maxes)
{
return shape{shape_vec.front().type(), mins};
}
// dynamic output shape case
return shape{shape_vec.front().type(), mins, maxes, {}};
}
};

void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(m,
Expand All @@ -328,6 +464,7 @@
find_const_3in_slice{},
find_const_4in_slice{},
find_const_alloc_fill{});
match::find_matches(m, simplify_select_module_output_shape{});
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
Loading
Loading