-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Out-Tree EP feature #21450
base: main
Are you sure you want to change the base?
[WIP] Out-Tree EP feature #21450
Conversation
@@ -0,0 +1,35 @@ | |||
#pragma once |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,14 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
OrtExecutionProviderFactory* RegisterCustomEp() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return Status instead #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have to do this? This function will new a factory object by invoking its constructor which has no return type
… EP as graph API is not exported by ORT. Need to put these graph API into ortapi structure
…roviderAdapter::Compile()
} OrtMetaDef; | ||
|
||
typedef struct OrtIndexedSubGraph { | ||
OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this have to be a pointer to an OrtMetaDef? It may be simpler if this meta_def is contained by value instead. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks we will check the pointer is null or not to distinguish between single node mode and fused node mode (See base class IExecutionProvider::GetCapability() which does not set this pointer and TryAssignSingleNode() which will check this pointer)
|
||
OutTreeEp::OutTreeEp(const char* ep_type, const OutTreeEpInfo& ep_info) : info(ep_info) { | ||
type = ep_type; | ||
OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I'm understanding correctly, the type of the OrtIndexedSubGraph*** indexed_sub_graph
parameter is essentially asking the EP to fill out an array of pointers to OrtIndexedSubGraph
objects.
Would it be simpler to change this to OrtIndexedSubgraph** indexed_sub_graph
so that the EP fills out an array of OrtIndexedSubGraph
objects directly? Each OrtIndexedSubgraph struct is a simple POD that can be created on the stack and copied around. It seems like it would result in less pointer tracking. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, who is responsible for freeing this memory and when? If the EP allocates an array, then the EP should free it. The currently example leaks the allocations.
Edit: one possibility is to have onnxruntime call a new EP function (e.g., ReleaseOrtIndexedSubGraph()) so the the EP can free the memory. onnxruntime would call this once it is done using the indexed_sub_graph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is that we don't know how many OrtIndexedSubGraph would be before we call GetCapability() function. I will fix the leak issue in the coming commits
@@ -0,0 +1,21 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,89 @@ | |||
#include "kernel_ep.h" |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,36 @@ | |||
#pragma once |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,15 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,14 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,627 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,285 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,266 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,147 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,557 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,110 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,54 @@ | |||
#include "qnn_execution_provider.h" |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
@@ -0,0 +1,33 @@ | |||
#pragma once |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning
Run lintrunner -a to apply this patch.
OrtExecutionProviderFactory* RegisterCustomEp() { | ||
std::unique_ptr<onnxruntime::TensorrtExecutionProviderFactory> ret = std::make_unique<onnxruntime::TensorrtExecutionProviderFactory>(); | ||
return ret.release(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At first glance this seems to be a memory leak of the returned object. However, after some digging, it looks like ORT is freeing the memory when the EP library is unloaded. This is still an issue. Preferably, memory should not be allocated on one side of the API boundary and then deleted on the other. ORT does not know what allocator the EP library used to allocate the object, so it can't be expected to know exactly how to delete it.
This may not need a heap allocation at all. The RegisterCustomEp
function could accept a pointer to a OrtExecutionProviderFactory
object that was allocated by ORT, and then it could just fill out the members. Then ORT can decide how/when to free the object. The EP library only worries about filling out the function callbacks.
OrtStatus* RegisterCustomEp(OrtExecutionProviderFactory* ep_factory) {
ep_factory->CreateExecutionProvider = [](/*params*/) { /* impl to create EP instance */ };
return nullptr;
}
Also, I really think that this function should return an OrtStatus*
so that the EP library can indicate an error with a descriptive error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After more investigation, it seems like this implementation directly returns a OrtExecutionProviderFactory*
because you want to be able to inherit from OrtExecutionProviderFactory
and use polymorphism for calls to GetCapability
and Compile
.
Here's an alternative approach that keeps allocations on the EP side of the API boundary, returns OrtStatus*
where appropriate, and still allows you to create a custom TensorRTEP class.
// tensorrt_execution_provider.cc (EP Plugin Library).
#include <onnxruntime_c_api_ep.h>
struct TensorRTEpFactory {
// …
std::string ep_name;
OrtExecutionProviderFactory ort_ep_factory = {};
};
class TensorRTEp {
public:
static std::unique_ptr<TensorRTEp> CreateInstance(const std::string& ep_name, size_t num_ep_options,
const char* const* ep_option_keys, const char* const* ep_option_vals,
/*output*/ std::string& err_msg) {
if (/* EP options invalid */) { err_msg = "Invalid EP options"; return nullptr; }
// Create internal EP object/state and fill out ORT callbacks (only GetCapability and Compile are shown here).
std::unique_ptr<TensorRTEp> ep = std::make_unique<TensorRTEp>(ep_name, /*other args*/);
ep->ort_ep.GetCapability = [](const OrtExecutionProvider* ort_ep, /*params*/) {
TensorRTEp* this_ = reinterpret_cast<TensorRTEp*>(ort_ep->state);
/* impl */
};
ep->ort_ep.Compile = [](const OrtExecutionProvider* ort_ep, /*params*/) {/* impl */};
ep->ort_ep.state = ep->get();
return ep;
}
OrtExecutionProvider* GetOrtExecutionProvider() { return &this->ort_ep; }
private:
TensorRTEp(const std::string& ep_name, /* other params */) { /* … */ }
std::string ep_name;
OrtExecutionProvider ort_ep = {};
// Other state/methods here …
};
static void DestroyOnDllUnload(std::unique_prt<TensorRTEpFactory>&& ep_factory) {
static std::vector<std::unique_prt<TensorRTEp>> ep_factories;
static std::mutex m_;
std::lock_guard<std::mutex> lock(m_);
ep_factories.push_back(std::move(ep_factory));
}
static void DestroyOnDllUnload(std::unique<TensorRTEp> ep_instance) {/*similar implementation as above*/}
#ifdef __cplusplus
extern "C" {
#endif
// DLL ENTRY POINT
OrtStatus* ORT_API_CALL RegisterCustomEp(const OrtExecutionProviderFactory** ort_ep_factory, const char* ep_name) {
auto ep_factory = std::make_unique<TensorRTEpFactory>(ep_name, OrtExecutionProviderFactory{});
ep_factory->ort_ep_factory.CreateExecutionProvider = CreateExecutionProviderCallback;
ep_factory->ort_ep_factory.state = ep_factor->get(); // Can optionally store custom state.
*ort_ep_factory = &ep_factory->ort_ep_factory; // Update output parameter to point to our EP factory callbacks.
DestroyOnDllUnload(std::move(ep_factory));
return nullptr;
}
#ifdef __cplusplus
}
#endif
OrtStatus* ORT_API_CALL CreateExecutionProviderCallback(const OrtExectutionProviderFactory* ort_ep_factory,
const char* const* option_keys,
const char* const* option_vals,
size_t num_options,
const OrtExecutionProvider** ort_ep) {
TensorRTEpFactory* ep_factory = reinterpret_cast<TensorRTEpFactory*>(ort_ep_factory->state);
std::string err_msg;
auto ep = TensorRTEp::CreateInstance(ep_factory->ep_name, option_keys, option_vals, num_options, err_msg);
if (!err_msg.empty()) { /* return OrtStatus with error message */ }
*ort_ep = ep->GetOrtExecutionProvider(); // Update output parameter to point to our EP callbacks.
DestroyOnDllUnload(std::move(ep));
return nullptr;
}
The above requires modifying OrtExecutionProviderFactory
and OrtExecutionProvider
to store a void*
pointer to custom state.
struct OrtExecutionProviderFactory {
// Same
void* state; // State set by the EP plugin library.
};
struct OrtExecutionProvider {
// Same
void* state; // State set by the EP plugin library.
};
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's another alternative. In the above example, ORT receives pointers to OrtExecutionProviderFactory
and OrtExecutionProvider
structs that were allocated within the EP plugin library. An alternative is to just copy the structs since they’re bags of function pointers (plain-old-data). There is really no need to give ORT a pointer to these structs when we can just copy them so that ORT has its own versions. This would be achievable with the following changes:
#include <onnxruntime_c_api_ep.h>
// Same as previous example
#ifdef __cplusplus
extern "C" {
#endif
// DLL ENTRY POINT
OrtStatus* ORT_API_CALL RegisterCustomEp(/*out*/ OrtExecutionProviderFactory* ort_ep_factory, const char* ep_name) {
auto ep_factory = std::make_unique<TensorRTEpFactory>(ep_name, OrtExecutionProviderFactory{});
ep_factory->ort_ep_factory.CreateExecutionProvider = CreateExecutionProviderCallback;
ep_factory->ort_ep_factory.state = ep_factor->get(); // Can optionally store custom state.
// This is a struct copy. This updates the output parameter to a copy of our EP factory callbacks.
*ort_ep_factory = ep_factory->ort_ep_factory;
DestroyOnDllUnload(std::move(ep_factory));
return nullptr;
}
#ifdef __cplusplus
}
#endif
OrtStatus* ORT_API_CALL CreateExecutionProviderCallback(const OrtExectutionProviderFactory* ort_ep_factory,
const char* const* option_keys,
const char* const* option_vals,
size_t num_options,
/*output*/ OrtExecutionProvider* ort_ep) {
TensorRTEpFactory* ep_factory = reinterpret_cast<TensorRTEpFactory*>(ort_ep_factory->state);
std::string err_msg;
auto ep = TensorRTEp::CreateInstance(ep_factory->ep_name, option_keys, option_vals, num_options, err_msg);
if (!err_msg.empty()) { /* return OrtStatus with error message */ }
*ort_ep = *(ep->GetOrtExecutionProvider()); // Struct copy. This updates the out param to a copy of our EP callbacks.
DestroyOnDllUnload(std::move(ep));
return nullptr;
}
*ep_context_graph = reinterpret_cast<OrtGraphViewer*>(graph_build_viewer.release()); | ||
} else { | ||
::onnxruntime::GraphViewer* content_graph_viewer = reinterpret_cast<::onnxruntime::GraphViewer*>(*ep_context_graph); | ||
graph_build = const_cast<::onnxruntime::Graph*>(&(content_graph_viewer->GetGraph())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems using const_cast
here might result in undefined behavior:
"Modifying a const object through a non-const access path and referring to a volatile object through a non-volatile glvalue results in undefined behavior." - https://en.cppreference.com/w/cpp/language/const_cast
, because later this function calls graph_build->GetOrCreateNodeArg
which might modify the "constant" graph instance. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comment. Rolled back the change
|
||
namespace onnxruntime { | ||
|
||
static const std::string tensorrtEp = "tensorrtEp"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the same EP name set by the user application?
// Load the EP library and register EP-creation functions with ORT
status = api->RegisterPluginExecutionProviderLibrary(L"trt_ep_lib.dll", env, "trt_ep");
Seems like this should be initialized with the name set by the user, right? Otherwise, we can have a conflict.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, they should be the same name
void* extra_param_for_create_state_func; | ||
void* extra_param_for_compute_func; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this may be able to be replaced with a single void* state;
. See https://github.com/microsoft/onnxruntime/pull/21450/files#r1868553280
* | ||
* \since Version 1.xx. | ||
*/ | ||
ORT_API2_STATUS(SessionOptionsAppendPluginExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this can be removed? We may be able to use the existing C API function called SessionOptionsAppendExecutionProvider. The existing C API does not take an OrtEnv parameter, but we can just get the default OrtEnv since there is only one per process.
General comment, please, re-format so the lines do not exceed 120 chars limit. |
@@ -325,6 +327,8 @@ class IExecutionProvider { | |||
return InlinedVector<const Node*>(); | |||
} | |||
|
|||
bool IsBuiltInEp() const { return builtin_ep_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include <string> | ||
#include <set> | ||
|
||
struct OrtTypeConstraints { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include "core/common/common.h" | ||
#include "core/common/status.h" | ||
#include "core/platform/threadpool.h" | ||
#include "core/common/logging/logging.h" | ||
#include "core/framework/allocator.h" | ||
#include "core/session/onnxruntime_c_api_ep.h" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we forward declare the types and avoid public header inclusion in the header?
@@ -5,11 +5,13 @@ | |||
|
|||
#include <atomic> | |||
#include <memory> | |||
#include <unordered_set> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* | ||
* \param[in] kernel_registry Opaque pointer of KernelRegistry object | ||
* \param[in] custom_op Custom Op where the kernel compute function is defined | ||
* \param[in] type_constraints |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); | ||
} OrtExecutionProviderFactory; | ||
|
||
struct OrtGraphApi { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is really a read-only graph viewer.
Suggest to name it appropriatly
Perhaps, the name should reflect the fact that this API is specifically for EP interaction.
ONNXTensorElementDataType data_type; | ||
const char* data; | ||
size_t data_len; | ||
} OrtTensorRef; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be a repeat of the existing API such as TensorTypeAndShape. Can we re-use that part?
@@ -4665,7 +4671,128 @@ struct OrtApi { | |||
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, | |||
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, | |||
size_t num_external_initializer_files); | |||
}; | |||
|
|||
/** \brief Create OrtDevice object. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion is to create a table that is separate from
the main API. Reasons:
Most of the clients do not need that code.
But it does affect language bindings.
For example, we now need to pad C# imported API structure, although it is unlikely we would ever need that in the C#, but if we do we can add that separate.
ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* out); | ||
|
||
/** \brief Get the NodeIndex values of the graph nodes sorted in topological order | ||
* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like the output ptr data is const, please, specify that it is not to be freed by the client.
size_t node_index_len; | ||
} OrtIndexedSubGraph; | ||
|
||
typedef struct OrtComputeContext { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest adding a constructor for _cplusplus
/** \brief Gets the path of the owning model if any | ||
* | ||
* \param[in] graph The graph to query | ||
* \param[out] model_path The path of the owning model if any |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* \param[out] out True if the graph is a subgraph | ||
* | ||
*/ | ||
ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraphViewer* graph, _Out_ bool* out); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* \remarks The caller is responsible for freeing the byte array using OrtFreeMem. | ||
* | ||
*/ | ||
ORT_API2_STATUS(OrtGraph_SerializeToArray, const OrtGraphViewer* graph, _Out_ void** data, _Out_ size_t* data_size); // TODO(leca): review and discuss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* \param[in] onnx_model_path The file path to save to | ||
* | ||
*/ | ||
ORT_API2_STATUS(OrtGraph_DumpOnnxModel, const OrtGraphViewer* graph, const char* onnx_model_path); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto& eps = GetExecutionProviders(); | ||
for (auto& ep : eps) { | ||
ep->RegisterStreamHandlers(GetStreamHandleRegistryInstance(), *allocators_); | ||
std::string register_resource_after = ""; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::string register_resource_after = ""; | ||
IExecutionProvider* plugin_ep = nullptr; | ||
for (auto& ep : execution_providers_) { | ||
if (register_resource_after == "") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>::iterator iter = type_constraints_.find(type_symbol); | ||
if (iter == type_constraints_.end()) { | ||
std::set<ONNXTensorElementDataType> types{type}; | ||
type_constraints_[type_symbol] = types; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Modify CMakeLists.txt for TRT EP plugin - Add "-l" for specifying EP plugin lib path for onnxruntime_perf_test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some memory/lifetime comments for OrtIndexedSubGraph in GetCapability()
} OrtMetaDef; | ||
|
||
typedef struct OrtIndexedSubGraph { | ||
OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I don't think it's necessary for this to be a separate memory allocation. I see that it was based on our internal IndexedSubGraph implementation, but in my view reducing the number of memory allocation makes things simpler. Perhaps meta_def
can be stored inline and can add a boolean to indicate if the meta_def is valid.
typedef struct OrtMetaDef {
bool is_valid;
// ...
} OrtMetaDef;
typedef struct OrtIndexedSubGraph {
OrtMetaDef meta_def;
// ...
} OrtIndexedSubGraph;
virtual std::vector<std::unique_ptr<ComputeCapability>> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const override { | ||
size_t cnt = 0; | ||
OrtIndexedSubGraph** indexed_subgraph = nullptr; | ||
if (ep_impl_->GetCapability) ep_impl_->GetCapability(ep_impl_, reinterpret_cast<const OrtGraphViewer*>(&graph_viewer), &cnt, &indexed_subgraph); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is currently passing an OrtIndexedSubGraph***
to the EP's GetCapability() function and asking the EP to allocate an array of pointers to OrtIndexSubGraph objects. Because the EP is allocating the memory, we currently have a separate API to allow the EP to delete this memory.
I wonder if it would be simpler to allow ORT to pass in an OrtAllocator
to the EP. The EP would use this ORT-owned allocator to allocate memory for the array. This woud remove the need for a separate C API to clean up the memory. Also, it would allow the parameter to be a OrtIndexedSubGraph**
instead of OrtIndexedSubGraph***
.
Description
Out-Tree EP feature.
Motivation and Context