Some initial implementations for DSL frontend (#524)
* add version 15 support for the Shape operator and a unit test

* update the op registration for Shape version 15

* add support for the Slice operator version 13 and a unit test

* generate a JSON graph from an Antares expression and run the custom op

* fix kernel packing by adding the extern "C" keyword

* add a comparison with CustomOp in the AlexNet case
donglinb authored Jul 7, 2023
1 parent bd4f6fe commit 5730e93
Showing 10 changed files with 1,056 additions and 9 deletions.
31 changes: 31 additions & 0 deletions src/nnfusion/frontend/onnx_import/op/shape.cpp
@@ -45,6 +45,37 @@ namespace nnfusion

} // namespace set_1

namespace set_15
{
NamedNodeVector TranslateShapeOp(const onnx::NodeProto& node_proto,
const NodeMap& all_ng_nodes,
std::shared_ptr<nnfusion::graph::Graph> m_graph)
{
auto data = GetInputIndex(all_ng_nodes, node_proto, 0);
auto data_shape = data.get_shape();

Node node(node_proto);
int64_t rank = data_shape.size();
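                        // Shape-15 adds optional 'start'/'end' attributes selecting the
                        // sub-range shape(data)[start:end]; negative values count from the
                        // back and out-of-range values are clamped to [0, rank].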
int64_t start = node.get_attribute_value<int64_t>("start", 0);
if(start < 0)
start += rank;
start = (start < 0) ? 0 : (start > rank) ? rank : start;
int64_t end = node.get_attribute_value<int64_t>("end", rank);
if (end < 0)
end += rank;
end = (end < 0) ? 0 : (end > rank) ? rank : end;
int64_t dim_value = (end - start) < 0 ? 0 : (end - start);

auto op = std::make_shared<op::Constant>(
                            nnfusion::element::i64, Shape{dim_value}, Shape(data_shape.begin() + start, data_shape.begin() + start + dim_value));
op->set_name(node_proto.output(0));
auto gnode = m_graph->add_node_and_edge(op, nnfusion::graph::GNodeVector{});
NamedNodeVector ret{{node_proto.output(0), gnode}};
return ret;
}

} // namespace set_15

} //namespace op

} // namespace onnx_import
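
For reference, a minimal Python sketch (not part of this commit) of the Shape-15 start/end semantics that the translator above implements; the helper name shape_15 is made up for illustration:

def shape_15(data_shape, start=0, end=None):
    # Negative indices count from the back; out-of-range values clamp to [0, rank].
    rank = len(data_shape)
    end = rank if end is None else end
    if start < 0:
        start += rank
    start = min(max(start, 0), rank)
    if end < 0:
        end += rank
    end = min(max(end, 0), rank)
    return list(data_shape[start:end])  # empty when end <= start

assert shape_15([3, 4, 5]) == [3, 4, 5]
assert shape_15([3, 4, 5], start=1) == [4, 5]
assert shape_15([3, 4, 5], start=-1) == [5]
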
7 changes: 7 additions & 0 deletions src/nnfusion/frontend/onnx_import/op/shape.hpp
@@ -37,6 +37,13 @@ namespace nnfusion

} // namespace set_1

namespace set_15
{
NamedNodeVector TranslateShapeOp(const onnx::NodeProto& node_proto,
const NodeMap& all_ng_nodes,
std::shared_ptr<nnfusion::graph::Graph> m_graph);
} // namespace set_15

} //namespace op

} // namespace onnx_import
56 changes: 47 additions & 9 deletions src/nnfusion/frontend/onnx_import/op/slice.cpp
@@ -20,6 +20,7 @@
//----------------------------------------------------------------------------------------------

#include <vector>
#include <unordered_set>

#include "../util/util.hpp"
#include "nnfusion/frontend/util/evaluator.hpp"
@@ -30,6 +31,30 @@ static inline int64_t get_valid_array_idx(int64_t idx, int64_t last_idx)
return (idx >= 0) ? std::min(idx, last_idx) : std::max<int64_t>(0, last_idx + idx);
}

static inline void processSliceInputs(const int64_t input_rank, int64_t& start, int64_t& end, int64_t& step)
{
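    // Normalize start/end for Slice-13: negative values wrap once by the axis
    // length, then both are clamped so the selected range stays inside the axis
    // (end may clamp to -1 when stepping backwards). For an axis of length 4:
    //   start = -3, end = -1, step =  1  ->  start = 1, end = 3
    //   start = -1, end = -5, step = -1  ->  start = 3, end = -1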
auto clamp = [](int64_t val, int64_t min, int64_t max) -> int64_t
{
return (val < min) ? min : (val > max) ? max : val;
};
// process step
NNFUSION_CHECK(step != 0);
// process start
if (start < 0)
start += input_rank;
if (step < 0)
start = clamp(start, 0, input_rank - 1);
else
start = clamp(start, 0, input_rank);
// process end
if (end < 0)
end += input_rank;
if (step < 0)
end = clamp(end, -1, input_rank - 1);
else
end = clamp(end, 0, input_rank);
}

namespace nnfusion
{
namespace frontend
@@ -94,6 +119,7 @@ namespace nnfusion
NNFUSION_CHECK(GetValueFromNGraphOp(inputs[1].gnode, &starts));
std::vector<int64_t> ends;
NNFUSION_CHECK(GetValueFromNGraphOp(inputs[2].gnode, &ends));
NNFUSION_CHECK(starts.size() == ends.size());
std::vector<int64_t> axes;
if (inputs.size() > 3)
{
@@ -104,6 +130,7 @@ namespace nnfusion
axes.resize(starts.size());
std::iota(axes.begin(), axes.end(), 0);
}
NNFUSION_CHECK(axes.size() == starts.size());

std::vector<int64_t> steps;
if (inputs.size() > 4)
@@ -114,20 +141,31 @@
{
steps.resize(starts.size(), 1);
}
NNFUSION_CHECK(steps.size() == axes.size());

Shape data_shape = data.get_shape();
size_t data_rank = data_shape.size();
Shape lower_bounds(data_rank, 0);
Shape upper_bounds = data_shape;
Strides strides(data_rank, 1);

std::unordered_set<int64_t> unique_axes;
for (size_t idx = 0; idx < axes.size(); ++idx)
{
int64_t axis = axes.at(idx) < 0 ? axes.at(idx) + static_cast<int64_t>(data_rank) : axes.at(idx);
NNFUSION_CHECK(axis >= 0 && axis < static_cast<int64_t>(data_rank));
NNFUSION_CHECK(unique_axes.find(axis) == unique_axes.end());
unique_axes.insert(axis);

int64_t start = starts.at(idx);
int64_t end = ends.at(idx);
int64_t step = steps.at(idx);
int64_t data_dim = static_cast<int64_t>(data_shape.at(static_cast<size_t>(axis)));
processSliceInputs(data_dim, start, end, step);

lower_bounds.at(static_cast<size_t>(axis)) = start;
upper_bounds.at(static_cast<size_t>(axis)) = end;
strides.at(static_cast<size_t>(axis)) = step;
}

auto op = std::make_shared<op::Slice>(lower_bounds, upper_bounds, strides);
1 change: 1 addition & 0 deletions src/nnfusion/frontend/onnx_import/ops_bridge.cpp
@@ -241,6 +241,7 @@ namespace nnfusion
REGISTER_OPERATOR("ReshapeGrad", 1, TranslateReshapeGradOp);
//REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, TranslateShapeOp);
REGISTER_OPERATOR("Shape", 15, TranslateShapeOp);
REGISTER_OPERATOR("Sigmoid", 1, TranslateUnaryOp<op::Sigmoid>);
REGISTER_OPERATOR("Sin", 1, TranslateUnaryOp<op::Sin>);
REGISTER_OPERATOR("Slice", 1, TranslateSliceOp);
Binary file added test/models/onnx/slice.onnx
Binary file not shown.
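
The new test model is checked in as a binary, so its contents are not reproduced here. A minimal sketch of how an equivalent Slice-13 model could be generated with the onnx helper API, assuming (from the commented-out inputs in the slice_op test below) that starts/ends/axes/steps are baked into the model as initializers:

import numpy as np
import onnx
from onnx import helper, numpy_helper, TensorProto

data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 4])
out = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
inits = [
    numpy_helper.from_array(np.array([1, 0], dtype=np.int64), name='starts'),
    numpy_helper.from_array(np.array([2, 3], dtype=np.int64), name='ends'),
    numpy_helper.from_array(np.array([0, 1], dtype=np.int64), name='axes'),
    numpy_helper.from_array(np.array([1, 2], dtype=np.int64), name='steps'),
]
node = helper.make_node('Slice', ['data', 'starts', 'ends', 'axes', 'steps'], ['output'])
graph = helper.make_graph([node], 'slice_test', [data], [out], initializer=inits)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])
onnx.save(model, 'slice.onnx')
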
41 changes: 41 additions & 0 deletions test/nnfusion/frontend/onnx_import.cpp
@@ -1217,6 +1217,27 @@ TEST(nnfusion_onnx_import, relu_op)
}
}

TEST(nnfusion_onnx_import, shape_op)
{
    // the test model applies the Shape op to a 3x4x5 float input
auto model = frontend::load_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/shape.onnx"));

Inputs inputs;
inputs.emplace_back(test::NDArray<float, 3>(
{{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}})
.get_vector());
vector<vector<int64_t>> expected_outputs{{3, 4, 5}};

vector<vector<int64_t>> outputs{execute<float, int64_t>(model, inputs, "NNFusion")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for (size_t i = 0; i < expected_outputs.size(); ++i)
{
EXPECT_EQ(expected_outputs[i], outputs[i]);
}
}

TEST(nnfusion_onnx_import, sigmoid_op)
{
auto model =
@@ -1248,6 +1269,26 @@ TEST(nnfusion_onnx_import, sin_op)
}
}

TEST(nnfusion_onnx_import, slice_op)
{
auto model = frontend::load_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/slice.onnx"));

RawInputs inputs;
inputs.emplace_back(convert_to_raw(test::NDArray<float, 2>({{1, 2, 3, 4}, {5, 6, 7, 8}}).get_vector())); // data
// inputs.emplace_back(convert_to_raw(vector<int64_t>{1, 0})); // starts
// inputs.emplace_back(convert_to_raw(vector<int64_t>{2, 3})); // ends
// inputs.emplace_back(convert_to_raw(vector<int64_t>{0, 1})); // axes
// inputs.emplace_back(convert_to_raw(vector<int64_t>{1, 2})); // steps
vector<vector<float>> expected_outputs{{5, 7}};

RawOutputs outputs{mixed_type_execute(model, inputs, "NNFusion")};
EXPECT_EQ(outputs.size(), expected_outputs.size());
for(size_t i = 0; i < expected_outputs.size(); ++i)
{
EXPECT_TRUE(test::all_close_f(expected_outputs[i], convert_from_raw<float>(outputs[i])));
}
}

TEST(nnfusion_onnx_import, sparse_softmax_cross_entropy_op)
{
// copy from onnxruntime
124 changes: 124 additions & 0 deletions test/python/dsl_frontend/export_json_graph.py
@@ -0,0 +1,124 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
from ir_parser import ir_graph_parser


def get_input_dict(input_orders):
input_list, input_dict = [], {}
for k in input_orders:
if isinstance(input_orders[k], tuple):
input_list += [(k, input_orders[k][2], input_orders[k][1])]
else:
input_list += [(k, input_orders[k].shape, input_orders[k].dtype)]
for k, shape, dtype in input_list:
input_dict[k] = {
'dtype': str(dtype).split('.')[1],
'shape': list(shape)
}
for k in input_dict:
if len(input_dict[k]['shape']) == 0:
input_dict[k]['shape'] = [1]
return input_dict
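# Example of the dict produced above (illustrative): a 4-D float input maps to
# {'dtype': 'float32', 'shape': [64, 3, 227, 227]}; a scalar becomes
# {'dtype': 'float32', 'shape': [1]}.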

def construct_json_graph(ir, input_dict):
exprss = ir.replace('\n', ' ').strip()
ast_seq, input_dict, output_dict, _ = ir_graph_parser(exprss, input_dict)
# print('input_dict:', input_dict)
# print('output_dict:', output_dict)
# topological sort and construct graph
nodes = []
known_tensors = {k : v for v, k in enumerate(sorted(list(input_dict)))}
node_index_offset = len(known_tensors)
while not all([k in known_tensors for k in output_dict]):
node_len = len(nodes)
for index, ast in enumerate(ast_seq):
node_output_name = ast['props']['output_name']
if node_output_name in known_tensors:
continue # already added nodes
node_input_list = list(ast['props']['input_dict'])
node_input_list.sort(key=lambda x : ast['props']['raw_exprss'].find(x))
if all([k in known_tensors for k in node_input_list]):
# generate antares expression
expression_ast = ast['props']['raw_exprss'].replace('"', '`').replace('\n', ' ').strip()
input_dict_ast = json.dumps(ast['props']['input_dict'])
expression_ast = f'- einstein_v2(" {expression_ast}", input_dict={input_dict_ast})'
# replace input name in exprss
# print('expression_ast old:', expression_ast)
for v, k in enumerate(node_input_list):
expression_ast = expression_ast.replace(k, 'input%d' % (v))
# print('expression_ast new:', expression_ast)
# count edges
edges = [[known_tensors[k], 0] for k in node_input_list]
# construct new node
node_id = index + node_index_offset
nodes.append([node_id, expression_ast, node_output_name, edges])
known_tensors[node_output_name] = node_id
if node_len == len(nodes) and not all([k in known_tensors for k in output_dict]):
raise Exception('Invalid model graph.')
# add output node
node_index_offset += len(ast_seq)
for v, k in enumerate(output_dict):
nodes.append([v + node_index_offset, '', 'Result', [[known_tensors[k], 0]]])
return json.dumps(nodes, indent=2)
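# Each emitted node has the form [node_id, expression, output_name, edges], e.g. (illustrative)
#   [9, '- einstein_v2(" out[N] = input0[N] * 2; ", input_dict={...})', 'out', [[0, 0]]]
# and every graph output gets a terminal entry of the form
#   [node_id, '', 'Result', [[producer_id, 0]]].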


if __name__ == '__main__':
import torch
# from antares_core.frameworks.pytorch.custom_op import CustomOp
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float32
kwargs = {'dtype': dtype,
'device': device,
'requires_grad': False}
def create_param(name, shape):
return (torch.rand(shape, **kwargs) - 0.5) * 0.001
input_tensor = torch.ones([64, 3, 227, 227], **kwargs)
const_0_ = create_param('const_0_', [11, 11, 3, 64])
const_1_ = create_param('const_1_', [5, 5, 64, 192])
const_2_ = create_param('const_2_', [3, 3, 192, 384])
const_3_ = create_param('const_3_', [3, 3, 384, 256])
const_4_ = create_param('const_4_', [3, 3, 256, 256])
const_5_ = create_param('const_5_', [9216, 4096])
const_6_ = create_param('const_6_', [4096, 4096])
const_7_ = create_param('const_7_', [4096, 1000])
ir = f'''
conv_0[N, F, HO, WO] +=! input_tensor[N, C, HO * 4 + KH, WO * 4 + KW] * const_0_[KH, KW, C, F] where HO in 55, WO in 55;
mpool_0[N, C, HO, WO] >=! conv_0[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 27, WO in 27, KH in 3, KW in 3;
conv_1[N, F, HO, WO] +=! mpool_0[N, C, -2 + HO + KH, -2 + WO + KW].when([-2 + HO + KH >= 0, -2 + HO + KH < 27, -2 + WO + KW >= 0, -2 + WO + KW < 27], 0.0) * const_1_[KH, KW, C, F] where HO in 27, WO in 27;
mpool_1[N, C, HO, WO] >=! conv_1[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 13, WO in 13, KH in 3, KW in 3;
conv_2[N, F, HO, WO] +=! mpool_1[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_2_[KH, KW, C, F] where HO in 13, WO in 13;
conv_2_relu[N, F, HO, WO] = conv_2[N, F, HO, WO].call(`max`, [0.0]);
conv_3[N, F, HO, WO] +=! conv_2_relu[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_3_[KH, KW, C, F] where HO in 13, WO in 13;
conv_3_relu[N, F, HO, WO] = conv_3[N, F, HO, WO].call(`max`, [0.0]);
conv_4[N, F, HO, WO] +=! conv_3_relu[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_4_[KH, KW, C, F] where HO in 13, WO in 13;
mpool_2[N, C, HO, WO] >=! conv_4[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 6, WO in 6, KH in 3, KW in 3;
reshape_0[N0, N1] = mpool_2[N0, N1 // 36 % 256, N1 // 6 % 6, N1 % 6] where N1 in 9216;
dense_0[N, M] +=! reshape_0[N, K] * const_5_[K, M];
dense_0_relu[N, M] = dense_0[N, M].call(`max`, [0.0]);
dense_1[N, M] +=! dense_0_relu[N, K] * const_6_[K, M];
dense_1_relu[N, M] = dense_1[N, M].call(`max`, [0.0]);
dense_2[N, M] +=! dense_1_relu[N, K] * const_7_[K, M];
'''
input_orders={
'input_tensor': input_tensor,
'const_0_': const_0_,
'const_1_': const_1_,
'const_2_': const_2_,
'const_3_': const_3_,
'const_4_': const_4_,
'const_5_': const_5_,
'const_6_': const_6_,
'const_7_': const_7_,
}
input_dict = get_input_dict(input_orders)
graph = construct_json_graph(ir, input_dict)
with open('alexnet_ir_graph.json', 'w') as f:
f.write(graph)

# output_logits = CustomOp(ir, input_orders=input_orders, device=device).emit()
# result = output_logits(input_tensor, const_0_, const_1_, const_2_, const_3_, const_4_, const_5_, const_6_, const_7_)
# print('The result of tensor `%s` is:\n%s' % (output_logits.output_names[0], result))
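
# Illustrative sanity check of the emitted graph file:
# with open('alexnet_ir_graph.json') as f:
#     print('graph nodes:', len(json.load(f)))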