Skip to content

Commit

Permalink
fix bugs for run tinyYolo
Browse files Browse the repository at this point in the history
  • Loading branch information
guyang3532 committed Sep 26, 2024
1 parent ff782e0 commit 22f9ff7
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 17 deletions.
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4777,7 +4777,7 @@ struct OrtApi {

ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order);

ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret);
ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret);

ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph);

Expand Down
20 changes: 10 additions & 10 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2434,9 +2434,9 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetNodesIndexInTopologicalOrder, const Ort
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret) {
const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
*ret = graph_viewer->IsSubgraph();
ORT_API_STATUS_IMPL(OrtApis::OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret) {
const ::onnxruntime::Graph* graph_ptr = reinterpret_cast<const ::onnxruntime::Graph*>(graph);
*ret = graph_ptr->IsSubgraph();
return nullptr;
}

Expand Down Expand Up @@ -2610,7 +2610,7 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph,
std::vector<ONNX_NAMESPACE::FunctionProto>(), graph_viewer->GetGraph().GetLogger());

auto& graph_build = model_build->MainGraph();
// bool has_control_flow_op = false;
bool has_control_flow_op = false;

std::vector<std::string> subgraph_output_names;
const std::vector<NodeIndex>& node_index = graph_viewer->GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
Expand Down Expand Up @@ -2646,10 +2646,10 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph,
}
}

// TODO: handle control flow ops
// if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
// has_control_flow_op = true;
// }
std::unordered_set<std::string> control_flow_op_set = {"If", "Loop", "Scan"};
if (control_flow_op_set.find(node->OpType()) != control_flow_op_set.end()) {
has_control_flow_op = true;
}

// If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node attributes are not in sync because of graph optimization.
// Therefore, we need to force GraphProto attributes to be updated in order to get the valid GraphProto.
Expand Down Expand Up @@ -2678,12 +2678,12 @@ ORT_API_STATUS_IMPL(OrtApis::OrtGraph_GetSubGraph, const OrtGraphViewer* graph,
// TODO:yang
// Only if the newly built graph has control flow op as well as it has parent node,
// it needs to handle outer scope values before calling graph.Resolve().
// if (has_control_flow_op && graph.ParentNode()) {
if (has_control_flow_op && graph_viewer->ParentNode()) {
// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name();
// BuildSubGraphContext(graph_build);
// SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph());
// SetAllGraphInputs(graph_build);
// }
}

common::Status status = graph_build.Resolve();
if (status != Status::OK()) return ToOrtStatus(status);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ ORT_API_STATUS_IMPL(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_

ORT_API_STATUS_IMPL(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path);

ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* ret);
ORT_API_STATUS_IMPL(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret);

ORT_API_STATUS_IMPL(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph);

Expand Down
2 changes: 1 addition & 1 deletion samples/c_test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so) {
THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input2, input2_len, input2_shape, sizeof(input2_shape)/sizeof(input2_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[1]));

const char* input_names[] = {"input_1", "image_shape"};
const char* output_names[] = {"6379", "6381", "6383"};
const char* output_names[] = {"yolonms_layer_1", "yolonms_layer_1:1", "yolonms_layer_1:2"};

size_t output_count = sizeof(output_names)/sizeof(output_names[0]);
std::vector<OrtValue*> output_tensors(output_count, nullptr);
Expand Down
6 changes: 4 additions & 2 deletions samples/tensorRTEp/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraphViewer*
const OrtGraph* cur_graph = nullptr;
api->OrtGraph_GetOrtGraph(graph, &cur_graph);
bool is_subgraph = false;
api->OrtGraph_IsSubgraph(graph, &is_subgraph);
api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph);
if (is_subgraph) {
const OrtNode* node = nullptr;
api->OrtGraph_GetParenNode(graph, &node);
Expand Down Expand Up @@ -1235,8 +1235,10 @@ std::unique_ptr<OrtIndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGr

// Generate unique kernel name for TRT subgraph
std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(subgraph_index);
const OrtGraph* cur_graph = nullptr;
api_->OrtGraph_GetOrtGraph(graph, &cur_graph);
bool is_subgraph = false;
api_->OrtGraph_IsSubgraph(graph, &is_subgraph);
api_->OrtGraph_IsSubgraph(cur_graph, &is_subgraph);
const std::string graph_type = is_subgraph ? "subgraph" : "graph";
const char* graph_name = api_->OrtGraph_GetName(graph);
std::string meta_def_name = "TRTKernel_" + graph_type + "_" + std::string(graph_name) + subgraph_id;
Expand Down
4 changes: 2 additions & 2 deletions samples/tensorRTEp/tensorrt_execution_provider_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ HashValue TRTGenerateId(const OrtGraphViewer* graph_viewer) {
const OrtGraph* cur_graph = nullptr;
api->OrtGraph_GetOrtGraph(graph_viewer, &cur_graph);
bool is_subgraph = false;
api->OrtGraph_IsSubgraph(graph_viewer, &is_subgraph);
api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph);
while (is_subgraph) {
const OrtGraph* parent_graph = nullptr;
api->OrtGraph_GetParentGraph(cur_graph, &parent_graph);
cur_graph = parent_graph;
api->OrtGraph_IsSubgraph(graph_viewer, &is_subgraph);
api->OrtGraph_IsSubgraph(cur_graph, &is_subgraph);
}

const OrtGraph* main_graph = cur_graph;
Expand Down

0 comments on commit 22f9ff7

Please sign in to comment.