
Broadcast support for elementwise ops #148

Merged (6 commits) on Jul 25, 2024
Changes from all commits
@@ -58,6 +58,25 @@ void ConversionContext::set_convertor(NodePtr node, const Convertor& convertor)
node->get_rt_info()[rt_info_convertor()] = as_any;
}

Value ConversionContext::get_dimension_value(const Dimension& d) {
auto symbol = d.get_symbol();
assert(symbol);
symbol = ov::symbol::ancestor_of(symbol);
// Assumes all dynamic dimensions are already known and the map is populated
// FIXME: Add dimensions on demand to avoid unnecessary operations in the produced MLIR
assert(dimension_map.count(symbol));
return dimension_map.at(symbol);
}

SmallVector<Value> ConversionContext::get_dynamic_dimension_values (const PartialShape& shape) {
SmallVector<Value> dims;
for (const auto& dim: shape) {
if (dim.is_dynamic()) {
dims.push_back(get_dimension_value(dim));
}
}
return dims;
}
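A minimal sketch of how a converter could use these helpers when building its destination tensor; `convert_sketch` and the choice of `outType` are illustrative assumptions, not part of this patch:

void convert_sketch(ConversionContext& context, NodePtr node) {
    auto loc = createLocation(context.context, node);
    auto& builder = context.builder();
    const auto inputs = context.getInputs(node);
    // Assume the output keeps the element type and shape of the first input
    auto outType = cast<mlir::ShapedType>(inputs[0].getType());
    // Dynamic sizes are looked up by symbol instead of re-emitting tensor.dim in every converter
    auto dynamicSizes = context.get_dynamic_dimension_values(node->get_output_partial_shape(0));
    auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamicSizes);
    // ... create the target linalg op with `empty` as its destination and register the outputs ...
}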


const std::string& subgraph_mark() {
@@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h"

#include "typedefs.hpp"
#include "convert_common.hpp"

namespace ov {
namespace mlir {
@@ -20,6 +21,7 @@ using ::mlir::MLIRContext;
using ::mlir::OpBuilder;
using ::mlir::Operation;
using ::mlir::SmallVector;
using ::mlir::ValueRange;

class ConversionContext {
static std::string rt_info_convertor ();
@@ -32,6 +34,7 @@ class ConversionContext {
mlir::MLIRContext* context;
mlir::OpBuilder* block_builder;
NodeOutputMap nodeOutputMap;
std::map<SymbolPtr, Value> dimension_map;

ConversionContext(mlir::MLIRContext* context, mlir::OpBuilder* block_builder);

@@ -45,6 +48,10 @@
static void set_convertor(NodePtr node, const Convertor& convertor);

void convert(NodePtr node);

Value get_dimension_value(const Dimension& d);

SmallVector<Value> get_dynamic_dimension_values (const PartialShape& shape);
};


56 changes: 23 additions & 33 deletions src/common/transformations/src/transformations/mlir/convert.cpp
@@ -68,6 +68,7 @@
#include "mlir_op.hpp"
#include "op/matmul.hpp"
#include "op/relu.hpp"
#include "op/binary_eltwise.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/symbol.hpp"
@@ -107,30 +108,6 @@ SmallVector<mlir::Type> get_types_for_values(mlir::MLIRContext* context, const o
return types;
}

template <typename TargetOp>
struct ConvertBinary {
void operator()(ConversionContext& context, NodePtr node) {
auto loc = createLocation(context.context, node);
auto& builder = context.builder();
// TODO: Support broadcasts
const auto inputs = context.getInputs(node);
auto outType = cast<mlir::ShapedType>(inputs[0].getType());
// Named binary ops directly overwrite data in `outs` buffer so, there is no need to provide non-empty
// destination at the tensor-level.
// Use `tensor.empty` to avoid temporary buffer allocation and memcpy after bufferization.
llvm::SmallVector<Value> dynamicSizes;
for (auto [idx, dim] : llvm::enumerate(outType.getShape())) {
if (!mlir::ShapedType::isDynamic(dim))
continue;
auto dimSize = builder.create<tensor::DimOp>(loc, inputs[0], idx);
dynamicSizes.push_back(dimSize);
}
auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamicSizes);
auto op = builder.create<TargetOp>(loc, mlir::ValueRange{inputs[0], inputs[1]}, mlir::ValueRange{empty});
context.addOutputs(node, op);
}
};


mlir::OwningOpRef<mlir::ModuleOp> ngraph_to_mlir(MLIRContext* context,
const ov::OutputVector& inputs,
@@ -159,6 +136,24 @@ mlir::OwningOpRef<mlir::ModuleOp> ngraph_to_mlir(MLIRContext* context,
auto loc = createLocation(context, inputs[i].get_node_shared_ptr());
auto tensor = block_builder.create<bufferization::ToTensorOp>(loc, funcInputVal, /*restrict = */ true);
conversion_context.nodeOutputMap.emplace(inputs[i], tensor);

// FIXME: Avoid pre-population of dimension_map, take dimension values only if needed
auto input_shape = inputs[i].get_partial_shape();
auto input_rank = input_shape.rank();
if(input_rank.is_static()) {
for(size_t j = 0; j < input_rank.get_length(); ++j) {
auto dim = input_shape[j];
if(dim.is_dynamic()) {
auto symbol = dim.get_symbol();
assert(symbol);
symbol = ov::symbol::ancestor_of(symbol);
if(dim.is_dynamic() && !conversion_context.dimension_map.count(symbol)) {
auto dimSize = block_builder.create<tensor::DimOp>(loc, tensor, j);
conversion_context.dimension_map[symbol] = dimSize;
}
}
}
}
}

for (size_t i = 0; i < nodes.size(); ++i) {
@@ -276,21 +271,16 @@ class Partitioner : public ov::pass::ModelPass {
}
};

template <typename Op>
NodePtr elementwise_f32_binary_no_broadcast() {
using namespace ov::pass::pattern;
return wrap_type<Op>({any_input(), any_input()}, elementwise_no_broadcast_predicate<ov::element::f32>);
}

void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context) {
ov::pass::Manager manager;
using namespace ov::op;
manager.set_per_pass_validation(false);
manager.register_pass<ov::pass::SymbolicPropagation>();
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Add>(), ConvertBinary<linalg::AddOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Subtract>(), ConvertBinary<linalg::SubOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Multiply>(), ConvertBinary<linalg::MulOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Divide>(), ConvertBinary<linalg::DivOp>());
manager.register_pass<BinaryEltwisePattern<v1::Add, linalg::AddOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Subtract, linalg::SubOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Multiply, linalg::MulOp>>(ov::element::f32);
manager.register_pass<BinaryEltwisePattern<v1::Divide, linalg::DivOp>>(ov::element::f32);
manager.register_pass<ReluPattern>();
manager.register_pass<MatMulPattern>();
manager.register_pass<Partitioner>(context);
120 changes: 105 additions & 15 deletions src/common/transformations/src/transformations/mlir/convert_common.cpp
@@ -132,38 +132,128 @@ bool elementwise_no_broadcast_predicate_impl(const ov::Output<ov::Node>& output,
if (output.get_element_type() != type) {
return false;
}
if (has_dynamic_rank(output.get_node_shared_ptr())) {
return false;
}
// Check if implicit broadcast is possible, reject in this case
// Relies on symbolic information -- register SymbolicPropagation before applying this pattern
auto inputs = output.get_node_shared_ptr()->inputs();
auto output_shape = output.get_partial_shape();
if (output_shape.rank().is_dynamic()) {
return false;
}

if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
auto input_shape = input.get_partial_shape();
if(output_shape.rank().get_length() != input_shape.rank().get_length()) {
return true;
}
for (size_t i = 0; i < output_shape.size(); ++i) {
if(!are_equal_dimensions(input_shape[i], output_shape[i]))
return true;
}
return false;
})) {
return false;
}

return true;
}

bool has_dynamic_rank(NodePtr node) {
auto inputs = node->inputs();
auto outputs = node->outputs();
if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
return input.get_partial_shape().rank().is_dynamic();
})) {
return true;
}
if (std::any_of(outputs.begin(), outputs.end(), [&](const ov::Output<ov::Node>& output) {
return output.get_partial_shape().rank().is_dynamic();
})) {
return true;
}
return false;
}

bool are_equal_dimensions(Dimension d1, Dimension d2) {
return
(d1.is_static() && d2.is_static() && d1 == d2) ||
ov::symbol::are_equal(d1.get_symbol(), d2.get_symbol());
}

bool has_broadcast(Dimension from, Dimension to) {
return from.is_static() && from.get_length() == 1 && !are_equal_dimensions(from, to);
}

bool statically_broadcastable(const PartialShape& from, const PartialShape& to) {
if(from.rank().is_dynamic() || to.rank().is_dynamic()) { // FIXME: `from` can have a dynamic rank
return false;
}

auto from_rank = from.rank().get_length();
auto to_rank = to.rank().get_length();

if(from_rank > to_rank) { // such cases shouldn't reach this function, but the check is kept to make the function generic
return false;
}

auto offset = to_rank - from_rank;
for(size_t i = 0; i < from_rank; ++i) {
auto d_from = from[i];
auto d_to = to[offset + i];
if(!are_equal_dimensions(d_from, d_to) && !has_broadcast(d_from, d_to)) {
// can deduce neither a broadcast nor equality for these dimensions
return false;
}
}

return true;
}
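A few concrete cases may help; `N` denotes a dynamic dimension carrying the same symbol on both sides, and the shapes are illustrative only:

// statically_broadcastable([1, 256],  [N, 256]) -> true   (static 1 can be broadcast)
// statically_broadcastable([256],     [N, 256]) -> true   (ranks are right-aligned, 256 == 256)
// statically_broadcastable([N, 256],  [N, 256]) -> true   (equal symbols / equal static dims)
// statically_broadcastable([2, 256],  [N, 256]) -> false  (2 vs N: neither provably equal nor a 1-broadcast)
// statically_broadcastable([N, 1, 3], [1, 3])   -> false  (from_rank > to_rank is rejected)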

BroadcastDimensions broadcast_dimensions(const PartialShape& src, const PartialShape& dst) {
assert(statically_broadcastable(src, dst));

auto src_rank = src.rank().get_length();
auto dst_rank = dst.rank().get_length();
auto offset = dst_rank - src_rank;

BroadcastDimensions result;
auto& [collapse_groups, dimensions] = result;
ReassociationIndices group;
bool group_bonded = false; // true if `group` contains a non-broadcast dimension

size_t dst_i = 0; // dimension index in the `dst` shape
for(; dst_i < offset; ++dst_i) {
dimensions.push_back(dst_i);
}
for(; dst_i < dst_rank; ++dst_i) {
auto src_i = dst_i - offset;
auto src_d = src[src_i];
auto dst_d = dst[dst_i];
if(has_broadcast(src_d, dst_d)) {
dimensions.push_back(dst_i);
} else {
if(group_bonded) {
collapse_groups.emplace_back(group);
group = ReassociationIndices();
} else {
group_bonded = true;
}
}
group.push_back(src_i);
}

if(group_bonded && !group.empty()) {
collapse_groups.emplace_back(group);
}

assert(dst_rank - dimensions.size() == collapse_groups.size());

return result;
}
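To make the returned pair concrete: a worked example, followed by a hedged sketch of how a converter might lower it with tensor.collapse_shape and linalg.broadcast (the names `src_shape`, `dst_shape`, `src_value`, `empty_dst`, `builder`, and `loc` are placeholders; this is not necessarily how BinaryEltwisePattern consumes the result):

// src = [1, C], dst = [N, C], where C carries the same symbol in both shapes:
//   collapse_groups = {{0, 1}}  -- tensor.collapse_shape folds [1, C] down to [C]
//   dimensions      = {0}       -- linalg.broadcast re-expands along dst dimension 0
auto [collapse_groups, dimensions] = broadcast_dimensions(src_shape, dst_shape);
Value collapsed = builder.create<tensor::CollapseShapeOp>(loc, src_value, collapse_groups);
Value broadcasted = builder.create<linalg::BroadcastOp>(loc, collapsed, empty_dst, dimensions)->getResult(0);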

bool symbol_ancestor_less (SymbolPtr x, SymbolPtr y) {
return ov::symbol::ancestor_of(x) < ov::symbol::ancestor_of(y);
}

} // namespace mlir
} // namespace ov
@@ -9,6 +9,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Location.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "typedefs.hpp"

@@ -52,5 +53,18 @@ mlir::arith::ConstantOp getConstant(OpBuilder &builder, const ov::element::Type&
return builder.create<arith::ConstantOp>(unkLoc, type, attr);
}

bool has_dynamic_rank(NodePtr node);

bool are_equal_dimensions(Dimension d1, Dimension d2);

bool has_broadcast(Dimension from, Dimension to);

bool statically_broadcastable(const PartialShape& from, const PartialShape& to);

using BroadcastDimensions = std::tuple<SmallVector<ReassociationIndices>, SmallVector<int64_t>>;
BroadcastDimensions broadcast_dimensions(const PartialShape& from, const PartialShape& to);

bool symbol_ancestor_less (SymbolPtr x, SymbolPtr y);

} // namespace mlir
} // namespace ov