-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Out-Tree EP feature #21450
base: main
Are you sure you want to change the base?
[WIP] Out-Tree EP feature #21450
Changes from 23 commits
0e6a80c
c30a639
7bfe57e
8e7d28d
808bfc3
49e396c
e790105
92f529d
3d83ed1
e29499a
f3678c4
ac5ae0a
0cc78e8
740a687
dad6397
94e9cf7
8698517
3d5d2bf
1f10c28
5e46d0f
85c168d
7bdb36a
7d915b7
4aea94b
865a17f
2811541
c97b19f
36f97b5
2fc7aac
4ad6993
53c736f
5fcb972
c3bb437
d1c657c
3efac97
766fec9
ea2465c
76a9305
330cdb6
6fd50f0
681585f
7db20cb
ff782e0
1d7b2df
a407944
f871b25
e84f00c
5b2de22
b1f8e2a
7acaaab
d150a03
da5b6eb
d280e59
cbe98e7
1529059
fa549f8
a28ad38
aa49805
bc65613
a1a3eea
0fe5f01
6bae1b9
ab75d98
c5510f2
08e3f20
b0b3123
9dbb0b1
5a59803
999e7fd
084f735
2b1cfdf
e337d8f
afe92e1
63f8774
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "core/session/onnxruntime_c_api.h" | ||
#include <unordered_map> | ||
Check warning on line 6 in include/onnxruntime/core/framework/ort_type_constraints.h GitHub Actions / Optional Lint C++
|
||
#include <string> | ||
Check warning on line 7 in include/onnxruntime/core/framework/ort_type_constraints.h GitHub Actions / Optional Lint C++
|
||
#include <set> | ||
Check warning on line 8 in include/onnxruntime/core/framework/ort_type_constraints.h GitHub Actions / Optional Lint C++
|
||
|
||
// Collects, per type symbol, the set of tensor element types allowed for that symbol. | ||
struct OrtTypeConstraints { | ||
// Adds `type` to the set stored under `type_symbol`. | ||
// Returns true when the type was newly added and false when it was already present. | ||
bool AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type); | ||
// Read-only view of the accumulated map: type symbol -> set of element types. | ||
inline const std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>& GetTypeConstraints() const { return type_constraints_; }; | ||
Check warning on line 12 in include/onnxruntime/core/framework/ort_type_constraints.h GitHub Actions / Optional Lint C++
Check warning on line 12 in include/onnxruntime/core/framework/ort_type_constraints.h GitHub Actions / Optional Lint C++
|
||
private: | ||
Check warning on line 13 in include/onnxruntime/core/framework/ort_type_constraints.h GitHub Actions / Optional Lint C++
|
||
// Backing storage returned by GetTypeConstraints(). | ||
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>> type_constraints_; | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,10 @@ | |
*/ | ||
Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg = nullptr); | ||
|
||
void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Given that SessionOptionsAppendOrtExecutionProvider allows the user to register the instance of the EP, when do we need this factory? #Resolved There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is another C API, RegisterOrtExecutionProviderLibrary, which loads the shared library, creates the plugin EP factory, and saves it in the Environment. Please see the implementation of RegisterOrtExecutionProviderLibrary and its usage in test.cpp as examples. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changed to a new name. Hope it is clearer now. |
||
|
||
OrtExecutionProviderFactory* GetOrtExecutionProviderFactory(const std::string& ep_name); | ||
|
||
private: | ||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); | ||
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager, | ||
|
@@ -99,5 +103,6 @@ | |
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_; | ||
bool create_global_thread_pools_{false}; | ||
std::vector<AllocatorPtr> shared_allocators_; | ||
std::unordered_map<std::string, std::unique_ptr<OrtExecutionProviderFactory>> custom_ep_factories_; | ||
Check warning on line 106 in include/onnxruntime/core/session/environment.h GitHub Actions / Optional Lint C++
Check warning on line 106 in include/onnxruntime/core/session/environment.h GitHub Actions / Optional Lint C++
|
||
}; | ||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -304,6 +304,15 @@ | |
ORT_RUNTIME_CLASS(OpAttr); | ||
ORT_RUNTIME_CLASS(Logger); | ||
ORT_RUNTIME_CLASS(ShapeInferContext); | ||
ORT_RUNTIME_CLASS(ExecutionProvider); | ||
ORT_RUNTIME_CLASS(ExecutionProviderFactory); | ||
ORT_RUNTIME_CLASS(Node); | ||
ORT_RUNTIME_CLASS(Model); | ||
ORT_RUNTIME_CLASS(Graph); | ||
ORT_RUNTIME_CLASS(GraphViewer); | ||
ORT_RUNTIME_CLASS(KernelRegistry); | ||
ORT_RUNTIME_CLASS(TypeConstraints); | ||
ORT_RUNTIME_CLASS(Device); | ||
|
||
#ifdef _WIN32 | ||
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; | ||
|
@@ -395,6 +404,13 @@ | |
OrtMemoryInfoDeviceType_FPGA = 2 | ||
} OrtMemoryInfoDeviceType; | ||
|
||
// Memory type tag used with OrtDevice in the C API (see CreateDevice / DeviceGetMemoryType | ||
// and the source_mem_type argument of OrtExecutionProvider::CopyTensor). | ||
// NOTE(review): the EP adapter produces these values with a static_cast from the | ||
// device's MemType(), so the numbers presumably mirror the internal memory-type constants - confirm. | ||
typedef enum OrtMemoryType { | ||
OrtMemoryType_Default = 0, | ||
OrtMemoryType_CUDA_PINNED = 1, | ||
OrtMemoryType_HIP_PINNED = 2, | ||
OrtMemoryType_CANN_PINNED = 3, | ||
} OrtMemoryType; | ||
|
||
Comment on lines
+397
to
+403
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we avoid having EP-specific enum values in the public API? If 'pinned' equates to 'visible by device and host', there's hopefully a better place to plug in device-specific info like CUDA/HIP/CANN than the high-level memory type. |
||
/** \brief Algorithm to use for cuDNN Convolution Op | ||
*/ | ||
typedef enum OrtCudnnConvAlgoSearch { | ||
|
@@ -689,6 +705,66 @@ | |
*/ | ||
ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; | ||
|
||
// Stream-creation hook supplied by a plugin execution provider. | ||
typedef struct OrtCreateStream { | ||
// Device type the hook is registered for; the EP adapter casts it to OrtDevice::DeviceType. | ||
int device_type; | ||
// Called with the target device; returns an opaque stream handle that the adapter wraps in a Stream. | ||
void*(ORT_API_CALL* CreateStreamFunc)(const OrtDevice*); | ||
} OrtCreateStream; | ||
|
||
// C view of the metadata describing a fused subgraph. The EP adapter copies every field | ||
// into an IndexedSubGraph::MetaDef; null name/domain/doc_string are copied as empty strings. | ||
typedef struct OrtMetaDef { | ||
const char* name; | ||
const char* domain; | ||
int since_version; | ||
|
||
// Each array below is paired with the element count that follows it. | ||
const char** inputs; | ||
size_t input_len; | ||
const char** outputs; | ||
size_t output_len; | ||
const char** constant_initializers; | ||
size_t initializer_len; | ||
|
||
const char* doc_string; | ||
} OrtMetaDef; | ||
|
||
// C view of one subgraph claimed by a plugin EP in GetCapability. | ||
typedef struct OrtIndexedSubGraph { | ||
// Optional: the EP adapter only builds a MetaDef for the subgraph when this pointer is non-null. | ||
OrtMetaDef* meta_def; // TODO(leca): how to define a nested structure pointer? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this have to be a pointer to an OrtMetaDef? It may be simpler if this meta_def is contained by value instead. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks we will check the pointer is null or not to distinguish between single node mode and fused node mode (See base class IExecutionProvider::GetCapability() which does not set this pointer and TryAssignSingleNode() which will check this pointer) |
||
// Array of node indices belonging to the subgraph; node_index_len gives its length. | ||
size_t* node_index; | ||
Check warning on line 730 in include/onnxruntime/core/session/onnxruntime_c_api.h GitHub Actions / Optional Lint C++
|
||
size_t node_index_len; | ||
} OrtIndexedSubGraph; | ||
|
||
// Context handed to OrtNodeComputeInfo::CreateFunctionStateFunc. | ||
// NOTE(review): the EP adapter obtains this by reinterpret_cast from the internal ComputeContext, | ||
// so the field order presumably has to match that type exactly - confirm. | ||
typedef struct OrtComputeContext { | ||
void*(ORT_API_CALL* AllocateFunc)(void*, size_t, size_t); | ||
void(ORT_API_CALL* DestroyFunc)(void*, void*); | ||
void* allocator_handle; | ||
const char* node_name; | ||
} OrtComputeContext; | ||
|
||
// Per-fused-node callbacks filled in by OrtExecutionProvider::Compile. | ||
typedef struct OrtNodeComputeInfo { | ||
// Optional. Arguments: compute context, the EP's extra_param_for_create_state_func, out state pointer. | ||
// When it is null the EP adapter reports 0 without creating any state. | ||
int(ORT_API_CALL* CreateFunctionStateFunc)(OrtComputeContext*, void*, void**); | ||
// Arguments: state, the EP's extra_param_for_compute_func, the OrtApi table, the kernel context. | ||
OrtStatusPtr(ORT_API_CALL* ComputeFunc)(void*, void*, const OrtApi*, OrtKernelContext*); | ||
Check warning on line 743 in include/onnxruntime/core/session/onnxruntime_c_api.h GitHub Actions / Optional Lint C++
|
||
// Optional. Receives the state pointer for release; skipped by the EP adapter when null. | ||
void(ORT_API_CALL* DestroyFunctionStateFunc)(void*); | ||
} OrtNodeComputeInfo; | ||
|
||
// Table of callbacks and settings that a plugin execution provider exposes through the C API. | ||
// ExecutionProviderAdapter and DataTransferAdapter wrap an instance of this struct. | ||
typedef struct OrtExecutionProvider { | ||
#ifdef __cplusplus | ||
// C++ only: start with every callback and pointer member set to nullptr. | ||
OrtExecutionProvider() : GetCapability{nullptr}, Compile{nullptr}, RegisterKernels{nullptr}, CanCopy{nullptr}, CopyTensor{nullptr}, type{nullptr}, create_stream{nullptr}, default_device{nullptr}, | ||
extra_param_for_create_state_func{nullptr}, extra_param_for_compute_func{nullptr} {} | ||
#endif | ||
// Optional. Writes the number of claimed subgraphs to `cnt` and an array of that many | ||
// OrtIndexedSubGraph pointers to the last argument. The adapter falls back to the base-class | ||
// GetCapability when this is null or when it reports zero subgraphs. | ||
void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph***); | ||
// Invoked by ExecutionProviderAdapter::Compile with parallel arrays of `cnt` graph viewers and fused nodes; | ||
// `node_compute_info` points at an array of `cnt` entries allocated by the adapter. | ||
OrtStatusPtr(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraphViewer** graph, const OrtNode** node, size_t cnt, OrtNodeComputeInfo** node_compute_info); | ||
// Optional. When set, the adapter creates a KernelRegistry in its constructor and passes it here. | ||
void(ORT_API_CALL* RegisterKernels)(OrtKernelRegistry* kernel_registry); | ||
// Used by DataTransferAdapter::CanCopy. | ||
bool(ORT_API_CALL* CanCopy)(const OrtDevice* source, const OrtDevice* target); | ||
// Used by DataTransferAdapter: `count` is the source tensor's size in bytes; `stream` is nullptr for | ||
// CopyTensor and the stream's handle for CopyTensorAsync. | ||
OrtStatusPtr(ORT_API_CALL* CopyTensor)(const void* src, OrtMemoryInfoDeviceType source_device_type, OrtMemoryType source_mem_type, void* dst, OrtMemoryInfoDeviceType target_device_type, size_t count, void* stream); | ||
// Provider type string; the adapter forwards it to the IExecutionProvider constructor. | ||
const char* type; | ||
// Optional stream-creation hook, registered in RegisterStreamHandlers when non-null. | ||
OrtCreateStream* create_stream; | ||
// Optional. When null the adapter uses a default-constructed OrtDevice. | ||
const OrtDevice* default_device; | ||
// Passed as the second argument of OrtNodeComputeInfo::CreateFunctionStateFunc. | ||
void* extra_param_for_create_state_func; | ||
// Passed as the second argument of OrtNodeComputeInfo::ComputeFunc. | ||
void* extra_param_for_compute_func; | ||
} OrtExecutionProvider; | ||
|
||
// Factory for plugin execution providers. | ||
typedef struct OrtExecutionProviderFactory { | ||
// Creates an OrtExecutionProvider from `option_size` option key/value pairs given as parallel arrays. | ||
// NOTE(review): ownership of the returned provider is not visible here - confirm who releases it. | ||
OrtExecutionProvider*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_, const char* const* ep_option_keys, const char* const* ep_option_values, size_t option_size); | ||
} OrtExecutionProviderFactory; | ||
|
||
/** \brief Thread work loop function | ||
* | ||
* Onnxruntime will provide the working loop on custom thread creation | ||
|
@@ -4665,7 +4741,112 @@ | |
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, | ||
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, | ||
size_t num_external_initializer_files); | ||
}; | ||
|
||
ORT_API2_STATUS(CreateDevice, _In_ enum OrtMemoryInfoDeviceType device_type, _In_ enum OrtMemoryType memory_type, _In_ int16_t device_id, _Outptr_ const OrtDevice** out); | ||
|
||
ORT_API2_STATUS(DeviceGetDeviceType, _In_ const OrtDevice* device, _Out_ OrtMemoryInfoDeviceType* out); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please document all functions including 'since version ...' info. It's easier to review if the intent of the API is documented. nit: Does 'device' need to be in the function name twice? #Resolved |
||
|
||
ORT_API2_STATUS(DeviceGetMemoryType, _In_ const OrtDevice* device, _Out_ OrtMemoryType* out); | ||
|
||
ORT_API2_STATUS(DeviceGetDeviceId, _In_ const OrtDevice* device, _Out_ int16_t* out); | ||
|
||
ORT_CLASS_RELEASE(Device); | ||
|
||
ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name); | ||
|
||
ORT_API2_STATUS(SessionOptionsAppendOrtExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* ep_name, _In_ OrtEnv* env, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would |
||
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); | ||
|
||
ORT_API2_STATUS(OrtGraph_IsConstantInitializer, const OrtGraphViewer* graph, const char* name, bool check_outer_scope, _Out_ bool* ret); | ||
|
||
ORT_API2_STATUS(OrtGraph_GetNodesIndexInTopologicalOrder, const OrtGraphViewer* graph, int execution_order, _Out_ size_t* len, _Out_ const size_t** nodes_index_in_topological_order); | ||
|
||
ORT_API2_STATUS(OrtGraph_IsSubgraph, const OrtGraph* graph, _Out_ bool* ret); | ||
|
||
ORT_API2_STATUS(OrtGraph_GetParentGraph, const OrtGraph* graph, _Outptr_ const OrtGraph** parent_graph); | ||
|
||
ORT_API2_STATUS(OrtGraph_GetParenNode, const OrtGraphViewer* graph, _Outptr_ const OrtNode** parent_node); | ||
|
||
ORT_API2_STATUS(OrtGraph_GetModelPath, const OrtGraphViewer* graph, _Outptr_ const void** path); | ||
|
||
ORT_API2_STATUS(OrtGraph_GetOrtGraph, const OrtGraphViewer* graph_viewer, _Outptr_ const OrtGraph** graph); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to expose OrtGraph. We can keep them here and will see |
||
|
||
ORT_API2_STATUS(OrtGraph_GetInputsIncludingInitializers, const OrtGraphViewer* graph, _Out_ size_t* num_inputs, _Outptr_ const char*** input_names); | ||
|
||
ORT_API2_STATUS(OrtGraph_GetOrtNode, const OrtGraphViewer* graph, size_t node_index, _Outptr_ const OrtNode** node); | ||
|
||
ORT_API2_STATUS(OrtGraph_GetNodesConsumingInput, const OrtGraphViewer* graph, const char* input_name, _Out_ size_t* len, _Outptr_ const OrtNode*** consumers); // TODO(leca): ValueConsumers::comprehensive ? | ||
|
||
ORT_API2_STATUS(OrtGraph_GetNodeProducingOutput, const OrtGraphViewer* graph, const char* output_name, _Outptr_ const OrtNode** producer); | ||
|
||
int(ORT_API_CALL* OrtGraph_NumberOfNodes)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
ORT_API2_STATUS(OrtGraph_MaxNodeIndex, const OrtGraphViewer* graph, _Out_ int* out); | ||
|
||
size_t(ORT_API_CALL* OrtGraph_GetOutputSize)(const OrtGraphViewer*)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
const char*(ORT_API_CALL* OrtGraph_GetIthOutputName)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
int32_t(ORT_API_CALL* OrtGraph_GetIthOutputElemType)(const OrtGraphViewer*, size_t i)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
size_t(ORT_API_CALL* OrtGraph_SerializeToArray)(const OrtGraphViewer*, _Out_ void** data)NO_EXCEPTION; | ||
|
||
ORT_API2_STATUS(OrtNode_GetName, const OrtNode* node, _Out_ const char** name); | ||
|
||
ORT_API2_STATUS(OrtNode_GetDescription, const OrtNode* node, _Out_ const char** description); | ||
|
||
ORT_API2_STATUS(OrtNode_GetDomain, const OrtNode* node, _Out_ const char** domain); | ||
|
||
ORT_API2_STATUS(OrtNode_SinceVersion, const OrtNode* node, _Out_ int* since_version); | ||
|
||
ORT_API2_STATUS(OrtNode_GetExecutionProviderType, const OrtNode* node, _Out_ const char** ep_type); | ||
|
||
ORT_API2_STATUS(OrtNode_GetOpType, const OrtNode* node, _Out_ const char** op_type); | ||
|
||
ORT_API2_STATUS(OrtNode_GetImplicitInputSize, const OrtNode* node, _Out_ size_t* input_size); | ||
|
||
ORT_API2_STATUS(OrtNode_GetIthImplicitInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); | ||
|
||
ORT_API2_STATUS(OrtNode_GetInputSize, const OrtNode* node, _Out_ size_t* input_size); | ||
|
||
ORT_API2_STATUS(OrtNode_GetIthInputName, const OrtNode* node, size_t i, _Out_ const char** ith_input_name); | ||
|
||
ORT_API2_STATUS(OrtNode_GetOutputSize, const OrtNode* node, _Out_ size_t* output_size); | ||
|
||
ORT_API2_STATUS(OrtNode_GetIthOutputName, const OrtNode* node, size_t i, _Out_ const char** ith_output_name); | ||
|
||
ORT_API2_STATUS(OrtNode_GetIndex, const OrtNode* node, _Out_ size_t* index); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeSize, const OrtNode* node, _Out_ size_t* attr_size); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeKeyCount, const OrtNode* node, const char* key, _Out_ size_t* count); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeIntSize, const OrtNode* node, const char* key, _Out_ int* int_size); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeFloatSize, const OrtNode* node, const char* key, _Out_ int* float_size); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeStringSize, const OrtNode* node, const char* key, _Out_ int* str_size); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeIthInt, const OrtNode* node, const char* key, int i, _Out_ int64_t* ints); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeIthFloat, const OrtNode* node, const char* key, int i, _Out_ float* floats); | ||
|
||
ORT_API2_STATUS(OrtNode_GetAttributeIthStr, const OrtNode* node, const char* key, int i, _Out_ const char** strs); | ||
|
||
const char*(ORT_API_CALL* OrtNode_GetAttributeStr)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
int64_t(ORT_API_CALL* OrtNode_GetAttributeInt)(const OrtNode*, const char* key)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
ORT_API2_STATUS(OrtNode_GetSubgraphs, const OrtNode* node, _Out_ size_t* len, _Outptr_ const OrtGraphViewer*** subgraphs); | ||
|
||
ORT_API2_STATUS(OrtKernelRegistry_RegisterKernel, OrtKernelRegistry* kernel_registry, OrtCustomOp* custom_op, OrtTypeConstraints* type_constraints); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: slightly more readable if this is after the type constraint functions, given that it takes OrtTypeConstraints as an input. #Resolved |
||
|
||
ORT_API2_STATUS(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type_constraints); | ||
|
||
ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); | ||
|
||
ORT_CLASS_RELEASE(TypeConstraints); | ||
}; // struct OrtApi | ||
|
||
/* | ||
* Steps to use a custom op: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
Check warning Code scanning / lintrunner CLANGFORMAT/format Warning
See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch. |
||
// Licensed under the MIT License. | ||
|
||
#include "core/framework/ort_type_constraints.h" | ||
|
||
bool OrtTypeConstraints::AddTypeConstraint(const char* type_symbol, ONNXTensorElementDataType type) { | ||
std::unordered_map<std::string, std::set<ONNXTensorElementDataType>>::iterator iter = type_constraints_.find(type_symbol); | ||
if (iter == type_constraints_.end()) { | ||
std::set<ONNXTensorElementDataType> types{type}; | ||
type_constraints_[type_symbol] = types; | ||
return true; | ||
} | ||
return (iter->second).insert(type).second; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
Check warning Code scanning / lintrunner CLANGFORMAT/format Warning
See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch. |
||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "core/session/onnxruntime_c_api.h" | ||
#include "core/framework/compute_capability.h" | ||
|
||
namespace onnxruntime { | ||
|
||
class DataTransferAdapter : public IDataTransfer { | ||
public: | ||
DataTransferAdapter(OrtExecutionProvider* ep) : ep_impl_(ep) {} | ||
virtual bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override { | ||
return ep_impl_->CanCopy(&src_device, &dst_device); | ||
} | ||
|
||
virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const override { | ||
OrtMemoryInfoDeviceType source_device_type = static_cast<OrtMemoryInfoDeviceType>(src.Location().device.Type()); | ||
OrtMemoryInfoDeviceType target_device_type = static_cast<OrtMemoryInfoDeviceType>(dst.Location().device.Type()); | ||
OrtMemoryType source_mem_type = static_cast<OrtMemoryType>(src.Location().device.MemType()); | ||
return ToStatus(ep_impl_->CopyTensor(src.DataRaw(), source_device_type, source_mem_type, dst.MutableDataRaw(), target_device_type, src.SizeInBytes(), nullptr)); | ||
} | ||
|
||
virtual common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override { | ||
OrtMemoryInfoDeviceType source_device_type = static_cast<OrtMemoryInfoDeviceType>(src.Location().device.Type()); | ||
OrtMemoryInfoDeviceType target_device_type = static_cast<OrtMemoryInfoDeviceType>(dst.Location().device.Type()); | ||
OrtMemoryType source_mem_type = static_cast<OrtMemoryType>(src.Location().device.MemType()); | ||
return ToStatus(ep_impl_->CopyTensor(src.DataRaw(), source_device_type, source_mem_type, dst.MutableDataRaw(), target_device_type, src.SizeInBytes(), stream.GetHandle())); | ||
} | ||
private: | ||
OrtExecutionProvider* ep_impl_; | ||
}; | ||
|
||
class ExecutionProviderAdapter : public IExecutionProvider { | ||
public: | ||
ExecutionProviderAdapter(OrtExecutionProvider* ep) : IExecutionProvider(ep->type, ep->default_device ? *(ep->default_device) : OrtDevice()), ep_impl_(ep) { | ||
if (ep_impl_->RegisterKernels) { | ||
kernel_registry_ = std::make_shared<KernelRegistry>(); | ||
ep_impl_->RegisterKernels(reinterpret_cast<OrtKernelRegistry*>(kernel_registry_.get())); | ||
} | ||
} | ||
virtual std::vector<std::unique_ptr<ComputeCapability>> GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const override { | ||
size_t cnt = 0; | ||
OrtIndexedSubGraph** indexed_subgraph = nullptr; | ||
if (ep_impl_->GetCapability) ep_impl_->GetCapability(ep_impl_, reinterpret_cast<const OrtGraphViewer*>(&graph_viewer), &cnt, &indexed_subgraph); | ||
|
||
if (cnt == 0) return IExecutionProvider::GetCapability(graph_viewer, kernel_lookup); | ||
|
||
std::vector<std::unique_ptr<ComputeCapability>> ret; | ||
for (size_t i = 0; i < cnt; i++) { | ||
std::unique_ptr<IndexedSubGraph> sb = std::make_unique<IndexedSubGraph>(); | ||
sb->nodes.reserve(indexed_subgraph[i]->node_index_len); | ||
for (size_t j = 0; j < indexed_subgraph[i]->node_index_len; j++) sb->nodes.push_back((indexed_subgraph[i]->node_index)[j]); | ||
if (indexed_subgraph[i]->meta_def != nullptr) { | ||
std::unique_ptr<IndexedSubGraph::MetaDef> meta_def = std::make_unique<IndexedSubGraph::MetaDef>(); | ||
meta_def->name = indexed_subgraph[i]->meta_def->name ? indexed_subgraph[i]->meta_def->name : ""; | ||
meta_def->doc_string = indexed_subgraph[i]->meta_def->doc_string ? indexed_subgraph[i]->meta_def->doc_string : ""; | ||
meta_def->domain = indexed_subgraph[i]->meta_def->domain ? indexed_subgraph[i]->meta_def->domain : ""; | ||
meta_def->since_version = indexed_subgraph[i]->meta_def->since_version; | ||
|
||
meta_def->inputs.reserve(indexed_subgraph[i]->meta_def->input_len); | ||
for (size_t j = 0; j < indexed_subgraph[i]->meta_def->input_len; j++) meta_def->inputs.push_back(indexed_subgraph[i]->meta_def->inputs[j]); | ||
|
||
meta_def->outputs.reserve(indexed_subgraph[i]->meta_def->output_len); | ||
for (size_t j = 0; j < indexed_subgraph[i]->meta_def->output_len; j++) meta_def->outputs.push_back(indexed_subgraph[i]->meta_def->outputs[j]); | ||
|
||
meta_def->constant_initializers.reserve(indexed_subgraph[i]->meta_def->initializer_len); | ||
for (size_t j = 0; j < indexed_subgraph[i]->meta_def->initializer_len; j++) meta_def->constant_initializers.push_back(indexed_subgraph[i]->meta_def->constant_initializers[j]); | ||
|
||
sb->SetMetaDef(std::move(meta_def)); | ||
} | ||
|
||
ret.push_back(std::make_unique<ComputeCapability>(std::move(sb))); | ||
} | ||
return ret; | ||
} | ||
|
||
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs, std::vector<NodeComputeInfo>& node_compute_funcs) override { | ||
std::vector<const OrtGraphViewer*> ortGraphs; | ||
std::vector<const OrtNode*> ortNodes; | ||
for (auto& fused_node_graph : fused_nodes_and_graphs) { | ||
const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; | ||
const Node& fused_node = fused_node_graph.fused_node; | ||
ortGraphs.push_back(reinterpret_cast<const OrtGraphViewer*>(&graph_viewer)); | ||
ortNodes.push_back(reinterpret_cast<const OrtNode*>(&fused_node)); | ||
} | ||
size_t count = fused_nodes_and_graphs.size(); | ||
std::vector<OrtNodeComputeInfo> cache; | ||
cache.resize(count); | ||
OrtNodeComputeInfo* cache_data = cache.data(); | ||
OrtStatus* ret = ep_impl_->Compile(ep_impl_, ortGraphs.data(), ortNodes.data(), count, &cache_data); | ||
if (ret != nullptr) return ToStatus(ret); | ||
node_compute_funcs.reserve(count); | ||
for (size_t i = 0; i < count; i++) { | ||
NodeComputeInfo compute_info; | ||
compute_info.create_state_func = [&, cache, i](ComputeContext* context, void** state) { | ||
if (cache[i].CreateFunctionStateFunc) return cache[i].CreateFunctionStateFunc(reinterpret_cast<OrtComputeContext*>(context), ep_impl_->extra_param_for_create_state_func, state); | ||
return 0; | ||
}; | ||
compute_info.compute_func = [&, cache, i](void* state, const OrtApi* api, OrtKernelContext* context) { | ||
return ToStatus(cache[i].ComputeFunc(state, ep_impl_->extra_param_for_compute_func, api, context)); | ||
}; | ||
compute_info.release_state_func = [&, cache, i](void* state) { | ||
if (cache[i].DestroyFunctionStateFunc) { | ||
cache[i].DestroyFunctionStateFunc(state); | ||
} | ||
}; | ||
node_compute_funcs.emplace_back(std::move(compute_info)); | ||
} | ||
|
||
return Status::OK(); | ||
} | ||
|
||
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap&) const override { | ||
if (ep_impl_->create_stream) { | ||
CreateStreamFn csf = [&](const OrtDevice& device) -> std::unique_ptr<Stream> { | ||
void* stream = ep_impl_->create_stream->CreateStreamFunc(&device); | ||
return std::make_unique<Stream>(stream, device); | ||
}; | ||
stream_handle_registry.RegisterCreateStreamFn(static_cast<OrtDevice::DeviceType>(ep_impl_->create_stream->device_type), csf); | ||
} | ||
} | ||
|
||
virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override { | ||
return std::make_unique<DataTransferAdapter>(ep_impl_); | ||
} | ||
|
||
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override { return kernel_registry_; } | ||
private: | ||
OrtExecutionProvider* ep_impl_; | ||
std::shared_ptr<KernelRegistry> kernel_registry_; // TODO(leca): should be static local | ||
}; | ||
} |
Check warning
Code scanning / lintrunner
CLANGFORMAT/format Warning