diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index ab94f203c231..0db7877b5e42 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -173,7 +173,9 @@ def __init__(self, module): self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] self._get_input_index = module["get_input_index"] + self._get_output_index = module["get_output_index"] self._get_input_info = module["get_input_info"] + self._get_output_info = module["get_output_info"] self._get_num_inputs = module["get_num_inputs"] self._load_params = module["load_params"] self._share_params = module["share_params"] @@ -315,6 +317,21 @@ def get_input_index(self, name): """ return self._get_input_index(name) + def get_output_index(self, name): + """Get outputs index via output name. + + Parameters + ---------- + name : str + The output key name + + Returns + ------- + index: int + The output index. -1 will be returned if the given output name is not found. + """ + return self._get_output_index(name) + def get_input_info(self): """Return the 'shape' and 'dtype' dictionaries of the graph. @@ -341,6 +358,24 @@ def get_input_info(self): return shape_dict, dtype_dict + def get_output_info(self): + """Return the 'shape' and 'dtype' dictionaries of the graph. + + Returns + ------- + shape_dict : Map + Shape dictionary - {output_name: tuple}. + dtype_dict : Map + dtype dictionary - {output_name: dtype}. + """ + output_info = self._get_output_info() + assert "shape" in output_info + shape_dict = output_info["shape"] + assert "dtype" in output_info + dtype_dict = output_info["dtype"] + + return shape_dict, dtype_dict + def get_output(self, index, out=None): """Get index-th output to out diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 107613e5a28c..6324da9c27ef 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -745,6 +745,18 @@ PackedFunc GraphExecutor::GetFunction(const String& name, const ObjectPtrGetInputIndex(args[0].operator String()); }); + } else if (name == "get_output_index") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(String::CanConvertFrom(args[0])) << "Output key is not a string"; + int out_idx = -1; + for (size_t i = 0; i < outputs_.size(); i++) { + std::string& name = nodes_[outputs_[i].node_id].name; + if (args[0].operator String() == name) { + out_idx = i; + } + } + *rv = out_idx; + }); } else if (name == "get_input_info") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { auto [shape_info, dtype_info] = this->GetInputInfo(); diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index fc6ec59a6d51..d7b6e13c18b6 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -467,6 +467,12 @@ def test_graph_executor_api(): assert isinstance(dtype_dict[name], tvm.runtime.container.String) assert dtype_dict[name] == ty.dtype + shape_dict, dtype_dict = mod.get_output_info() + assert isinstance(shape_dict, tvm.container.Map) + assert isinstance(dtype_dict, tvm.container.Map) + for i, key in enumerate(shape_dict): + assert mod.get_output_index(key) == i + @tvm.testing.requires_llvm def test_benchmark():