Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Find dot slice #3268

Merged
merged 7 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,89 @@ struct find_mul_dot
}
};

/*
Moves the slice on the output of the Dot operation to slices on the inputs of the Dot operation to
avoid computing redundant values.
e.g. slice(gemm(a, b)) --> gemm(slice(a), slice(b))
*/
struct find_dot_slice
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
{
auto matcher() const
{
return match::name("slice")(
match::args(match::name("dot", "quant_dot")(match::used_once()).bind("dot_ins")));
}

void apply(module& m, const match::matcher_result& r) const
{
auto slice_ins = r.result;
auto dot_ins = r.instructions["dot_ins"];
auto slice_op = slice_ins->normalized_operator().to_value();
auto axes = slice_op["axes"].to_vector<int64_t>();
auto starts = slice_op["starts"].to_vector<int64_t>();
auto ends = slice_op["ends"].to_vector<int64_t>();
assert(starts.size() == ends.size() and starts.size() == axes.size());
auto has_neg_vals = [](auto vec) {
return std::any_of(vec.begin(), vec.end(), [](auto i) { return i < 0; });
};
if(has_neg_vals(starts) or has_neg_vals(ends) or has_neg_vals(axes))
{
MIGRAPHX_THROW("FIND_DOT_SLICE: slice is not normalized.");
}
auto dot_inputs = dot_ins->inputs();
auto num_batch_dims = dot_ins->get_shape().lens().size() - 2;
std::vector<int64_t> slice_axes_1, starts_1, ends_1; // NOLINT
std::vector<int64_t> slice_axes_2, starts_2, ends_2; // NOLINT
for(auto i : range(axes.size()))
{
if(axes[i] < num_batch_dims)
{
slice_axes_1.push_back(axes[i]);
starts_1.push_back(starts[i]);
ends_1.push_back(ends[i]);
slice_axes_2.push_back(axes[i]);
starts_2.push_back(starts[i]);
ends_2.push_back(ends[i]);
}
else if(axes[i] == num_batch_dims)
{
slice_axes_1.push_back(axes[i]);
starts_1.push_back(starts[i]);
ends_1.push_back(ends[i]);
}
else if(axes[i] == num_batch_dims + 1)
{
slice_axes_2.push_back(axes[i]);
starts_2.push_back(starts[i]);
ends_2.push_back(ends[i]);
}
else
{
MIGRAPHX_THROW("FIND_DOT_SLICE: invalid case");
}
}
auto slice_1 = dot_inputs.at(0);
if(not slice_axes_1.empty())
{
slice_1 = m.insert_instruction(
slice_ins,
migraphx::make_op("slice",
{{"axes", slice_axes_1}, {"starts", starts_1}, {"ends", ends_1}}),
dot_inputs.at(0));
}
auto slice_2 = dot_inputs.at(1);
if(not slice_axes_2.empty())
{
slice_2 = m.insert_instruction(
slice_ins,
migraphx::make_op("slice",
{{"axes", slice_axes_2}, {"starts", starts_2}, {"ends", ends_2}}),
dot_inputs.at(1));
}
m.replace_instruction(slice_ins, dot_ins->get_operator(), {slice_1, slice_2});
}
};

struct find_dot_mul
{
auto matcher() const
Expand Down Expand Up @@ -1896,6 +1979,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{},
find_mul_slice_conv{},
find_mul_dot{},
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_unit_ops{},
Expand Down
167 changes: 167 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4081,6 +4081,173 @@ TEST_CASE(dot_mul_b_non_const)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_b)
{
    // A slice on the last axis of the dot output is expected to move onto input b only.
    const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 256, 32}};
    const migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 32, 128}};
    const auto last_axis_slice =
        migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}});

    // Input graph: slice(dot(a, b)).
    migraphx::module m1;
    {
        auto lhs    = m1.add_parameter("a", lhs_shape);
        auto rhs    = m1.add_parameter("b", rhs_shape);
        auto gemm   = m1.add_instruction(migraphx::make_op("dot"), lhs, rhs);
        auto sliced = m1.add_instruction(last_axis_slice, gemm);
        m1.add_return({sliced});
    }
    run_pass(m1);

    // Expected graph: dot(a, slice(b)).
    migraphx::module m2;
    {
        auto lhs        = m2.add_parameter("a", lhs_shape);
        auto rhs        = m2.add_parameter("b", rhs_shape);
        auto rhs_sliced = m2.add_instruction(last_axis_slice, rhs);
        auto gemm       = m2.add_instruction(migraphx::make_op("dot"), lhs, rhs_sliced);
        m2.add_return({gemm});
    }
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_a)
{
    // A slice on the second-to-last axis of the dot output is expected to move onto input a only.
    const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 256, 32}};
    const migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 32, 128}};
    const auto row_axis_slice =
        migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}});

    // Input graph: slice(dot(a, b)).
    migraphx::module m1;
    {
        auto lhs    = m1.add_parameter("a", lhs_shape);
        auto rhs    = m1.add_parameter("b", rhs_shape);
        auto gemm   = m1.add_instruction(migraphx::make_op("dot"), lhs, rhs);
        auto sliced = m1.add_instruction(row_axis_slice, gemm);
        m1.add_return({sliced});
    }
    run_pass(m1);

    // Expected graph: dot(slice(a), b).
    migraphx::module m2;
    {
        auto lhs        = m2.add_parameter("a", lhs_shape);
        auto rhs        = m2.add_parameter("b", rhs_shape);
        auto lhs_sliced = m2.add_instruction(row_axis_slice, lhs);
        auto gemm       = m2.add_instruction(migraphx::make_op("dot"), lhs_sliced, rhs);
        m2.add_return({gemm});
    }
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_ab)
{
    // A slice over both matrix axes of the dot output is expected to split into one slice per
    // input: axis 1 on a, axis 2 on b.
    const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 256, 32}};
    const migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 32, 128}};

    // Input graph: slice(dot(a, b)) over axes {1, 2}.
    migraphx::module m1;
    {
        auto lhs  = m1.add_parameter("a", lhs_shape);
        auto rhs  = m1.add_parameter("b", rhs_shape);
        auto gemm = m1.add_instruction(migraphx::make_op("dot"), lhs, rhs);
        const auto both_axes_slice = migraphx::make_op(
            "slice", {{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}});
        auto sliced = m1.add_instruction(both_axes_slice, gemm);
        m1.add_return({sliced});
    }
    run_pass(m1);

    // Expected graph: dot(slice(a), slice(b)).
    migraphx::module m2;
    {
        auto lhs        = m2.add_parameter("a", lhs_shape);
        auto rhs        = m2.add_parameter("b", rhs_shape);
        auto lhs_sliced = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), lhs);
        auto rhs_sliced = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), rhs);
        auto gemm = m2.add_instruction(migraphx::make_op("dot"), lhs_sliced, rhs_sliced);
        m2.add_return({gemm});
    }
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_batch_dims)
{
    // A slice that also covers the batch axis is expected to repeat the batch slice on both
    // inputs, in addition to the per-input matrix-axis slices.
    const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 256, 32}};
    const migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 32, 128}};

    // Input graph: slice(dot(a, b)) over axes {0, 1, 2}.
    migraphx::module m1;
    {
        auto lhs  = m1.add_parameter("a", lhs_shape);
        auto rhs  = m1.add_parameter("b", rhs_shape);
        auto gemm = m1.add_instruction(migraphx::make_op("dot"), lhs, rhs);
        const auto all_axes_slice = migraphx::make_op(
            "slice", {{"axes", {0, 1, 2}}, {"starts", {0, 64, 64}}, {"ends", {1, 128, 128}}});
        auto sliced = m1.add_instruction(all_axes_slice, gemm);
        m1.add_return({sliced});
    }
    run_pass(m1);

    // Expected graph: dot(slice(a) over {0, 1}, slice(b) over {0, 2}).
    migraphx::module m2;
    {
        auto lhs = m2.add_parameter("a", lhs_shape);
        auto rhs = m2.add_parameter("b", rhs_shape);
        const auto lhs_slice_op = migraphx::make_op(
            "slice", {{"axes", {0, 1}}, {"starts", {0, 64}}, {"ends", {1, 128}}});
        auto lhs_sliced = m2.add_instruction(lhs_slice_op, lhs);
        const auto rhs_slice_op = migraphx::make_op(
            "slice", {{"axes", {0, 2}}, {"starts", {0, 64}}, {"ends", {1, 128}}});
        auto rhs_sliced = m2.add_instruction(rhs_slice_op, rhs);
        auto gemm = m2.add_instruction(migraphx::make_op("dot"), lhs_sliced, rhs_sliced);
        m2.add_return({gemm});
    }
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_slice_not_applicable_1)
{
    // The dot output feeds two different slices, so the module is expected to come out of the
    // pass unchanged.
    const migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 256, 32}};
    const migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 32, 128}};

    migraphx::module m1;
    {
        auto lhs  = m1.add_parameter("a", lhs_shape);
        auto rhs  = m1.add_parameter("b", rhs_shape);
        auto gemm = m1.add_instruction(migraphx::make_op("dot"), lhs, rhs);
        const auto matrix_axes_slice = migraphx::make_op(
            "slice", {{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}});
        auto first_user = m1.add_instruction(matrix_axes_slice, gemm);
        const auto batch_axis_slice =
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}});
        auto second_user = m1.add_instruction(batch_axis_slice, gemm);

        m1.add_return({first_user, second_user});
    }
    // Snapshot the module before running the pass; this copy is the expected result.
    migraphx::module m2 = m1;
    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}

// Slice written with negative axes {-2, -1} on the output of a rank-3 dot.
// Note: despite the "not_applicable" name, the expected module below has the slice moved onto
// the dot inputs with non-negative axes {1} and {2}, i.e. the rewrite is expected to apply.
TEST_CASE(dot_slice_not_applicable_2)
{
    migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
    migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}};
    // Input graph: slice(dot(a, b)) using negative axes.
    migraphx::module m1;
    {
        auto a     = m1.add_parameter("a", as);
        auto b     = m1.add_parameter("b", bs);
        auto dot   = m1.add_instruction(migraphx::make_op("dot"), a, b);
        auto slice = m1.add_instruction(
            migraphx::make_op("slice",
                              {{"axes", {-2, -1}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
            dot);
        m1.add_return({slice});
    };
    run_pass(m1);
    // Expected graph: dot(slice(a) on axis 1, slice(b) on axis 2).
    migraphx::module m2;
    {
        auto a       = m2.add_parameter("a", as);
        auto b       = m2.add_parameter("b", bs);
        auto a_slice = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {64}}, {"ends", {128}}}), a);
        auto b_slice = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {64}}, {"ends", {128}}}), b);
        auto dot = m2.add_instruction(migraphx::make_op("dot"), a_slice, b_slice);
        m2.add_return({dot});
    };
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(conv_concat)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 8, 4, 4}};
Expand Down
Loading