Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Out-Tree EP feature #21450

Draft
wants to merge 45 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
0e6a80c
opaque pointer for graph
jslhcl Jul 17, 2024
c30a639
ORT C API RegisterOrtExecutionProviderLibrary work
jslhcl Jul 23, 2024
7bfe57e
ORT C-API SessionOptionsAppendOrtExecutionProvider work
jslhcl Jul 23, 2024
8e7d28d
Test Relu with compile based EP, build work, runtime error of loading…
jslhcl Jul 26, 2024
808bfc3
prototype works with hardcode node_compute_info's index in ExecutionP…
jslhcl Jul 29, 2024
49e396c
prototype works without hardcode
jslhcl Jul 29, 2024
e790105
fix comments for Compile function
jslhcl Jul 31, 2024
92f529d
add provider_factory_adapter.h
jslhcl Aug 1, 2024
3d83ed1
fix crash after introducing kernel based EP
jslhcl Aug 5, 2024
e29499a
kernel based EP work with type constraint check commented out
jslhcl Aug 6, 2024
f3678c4
add kernel type constraints from out tree EP
jslhcl Aug 7, 2024
ac5ae0a
add API ReleaseOrtTypeConstraints
jslhcl Aug 7, 2024
0cc78e8
introduce qnn ep
jslhcl Aug 12, 2024
740a687
more graph/node C API
jslhcl Aug 13, 2024
dad6397
stream support
jslhcl Aug 15, 2024
94e9cf7
support data transfer and OrtDevice in out tree EP API
jslhcl Aug 16, 2024
8698517
change compile return type from void to OrtStatusPtr
jslhcl Aug 20, 2024
3d5d2bf
add TensorRT dependency in tensorRT EP's CMakeLists.txt
jslhcl Aug 20, 2024
1f10c28
Add extra parameters in OrtExecutionProvider to avoid capture variabl…
jslhcl Aug 22, 2024
5e46d0f
add OrtGraph_SerializeToArray
jslhcl Aug 23, 2024
85c168d
finish Compile function
jslhcl Aug 24, 2024
7bdb36a
add override function implementation and cudart dependency for tensorrt
jslhcl Aug 26, 2024
7d915b7
add outOfTree tensorrt ep.1 (#21830)
guyang3532 Aug 27, 2024
4aea94b
GetSupportedList
jslhcl Aug 28, 2024
865a17f
GetSubGraph and TensorrtExecutionProviderInfo
jslhcl Aug 29, 2024
2811541
Add simple CUDA allocators for TRT EP (#21901)
chilo-ms Aug 29, 2024
c97b19f
add constructor for tensorrt ep and refine GetCapability (#21914)
guyang3532 Aug 29, 2024
36f97b5
relu can work on out tree TRT now
jslhcl Aug 29, 2024
2fc7aac
rebuild graph proto from scratch with the information needed from gra…
jslhcl Aug 31, 2024
4ad6993
complete the GetCapability (#21956)
guyang3532 Sep 2, 2024
53c736f
Chi's fix and reorder ep for registering shared resource
jslhcl Sep 4, 2024
5fcb972
complete the GetSubGraph (#21998)
guyang3532 Sep 5, 2024
c3bb437
run resnet18v1_7, crash on GetSubGraph()
jslhcl Sep 6, 2024
d1c657c
Merge branch 'leca/outOfTreeEP' of https://github.com/microsoft/onnxr…
jslhcl Sep 6, 2024
3efac97
resnet18-v1-7 works for TRT EP, with next_nodes_list assignment comme…
jslhcl Sep 6, 2024
766fec9
test cases for decoder and fast_rcnn, delete dynamic_cast in ShouldPo…
jslhcl Sep 9, 2024
ea2465c
add tensorrt home in CMakeLists, add trt and CUDA ep for test, change…
jslhcl Sep 11, 2024
76a9305
[WIP, DONT REVIEW] add initializer to graph proto (#22085)
jslhcl Sep 18, 2024
330cdb6
use parameter ExecutionOrder::PRIORITY_BASED for GraphViewerToProto()…
jslhcl Sep 19, 2024
6fd50f0
can create session with out tree trt ep now. Error:Name:'tensorrtEp_T…
jslhcl Sep 23, 2024
681585f
make trt_node_name_with_precision_ from string to map, to capture the…
jslhcl Sep 23, 2024
7db20cb
fix redundant inputs and outputs in GetSubgraph (#22201)
guyang3532 Sep 24, 2024
ff782e0
RunTinyYolov3()
jslhcl Sep 25, 2024
1d7b2df
fix bugs for run tinyYolo (#22233)
guyang3532 Sep 26, 2024
a407944
sample code to separate graph C API to different files
jslhcl Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
*/
Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg = nullptr);

void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory);

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
Expand All @@ -99,5 +101,6 @@
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};
std::vector<AllocatorPtr> shared_allocators_;
std::map<std::string, OrtExecutionProviderFactory*> custom_ep_factories_;

Check warning on line 104 in include/onnxruntime/core/session/environment.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <map> for map<> [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/session/environment.h:104: Add #include <map> for map<> [build/include_what_you_use] [4]

Check warning on line 104 in include/onnxruntime/core/session/environment.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/session/environment.h:104: Add #include <string> for string [build/include_what_you_use] [4]
};
} // namespace onnxruntime
16 changes: 16 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@
ORT_RUNTIME_CLASS(OpAttr);
ORT_RUNTIME_CLASS(Logger);
ORT_RUNTIME_CLASS(ShapeInferContext);
ORT_RUNTIME_CLASS(ExecutionProvider);
ORT_RUNTIME_CLASS(ExecutionProviderFactory);
ORT_RUNTIME_CLASS(Graph);

#ifdef _WIN32
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
Expand Down Expand Up @@ -681,6 +684,15 @@
const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION;
};

typedef struct OrtExecutionProvider {
//void(ORT_API_CALL* GetCapability)(const OrtExecutionProvider* this_, const OrtGraph* graph, _Out_ int* cnt, _Outptr_ OrtComputeCapability** compute_capability);

Check warning on line 688 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_c_api.h:688: Lines should be <= 120 characters long [whitespace/line_length] [2]

Check warning on line 688 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: include/onnxruntime/core/session/onnxruntime_c_api.h:688: Should have a space between // and comment [whitespace/comments] [4]
//void(ORT_API_CALL* Compile)(OrtExecutionProvider* this_, const OrtGraph* graph, const OrtNode* node, int size, _Out_ int* cnt, _Outptr_ OrtNodeComputeInfo** node_compute_info);

Check warning on line 689 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_c_api.h:689: Lines should be <= 120 characters long [whitespace/line_length] [2]

Check warning on line 689 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: include/onnxruntime/core/session/onnxruntime_c_api.h:689: Should have a space between // and comment [whitespace/comments] [4]
} OrtExecutionProvider;

typedef struct OrtExecutionProviderFactory {
void*(ORT_API_CALL* CreateExecutionProvider)(OrtExecutionProviderFactory* this_);
} OrtExecutionProviderFactory;

typedef struct OrtApiBase OrtApiBase;

/** \brief The Onnxruntime library's entry point to access the C API
Expand Down Expand Up @@ -4665,6 +4677,8 @@
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array,
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
size_t num_external_initializer_files);

ORT_API2_STATUS(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name);
};

/*
Expand Down Expand Up @@ -4825,6 +4839,8 @@
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);

ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraph* graph, const char* name, bool check_outer_scope);

#ifdef __cplusplus
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,4 +348,8 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ
return Status{ONNXRUNTIME, common::INVALID_ARGUMENT, provider_type + " is not implemented in CreateAndRegisterAllocatorV2()"};
}

void Environment::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) {
custom_ep_factories_.insert({std::string(ep_name), ep_factory});
}

} // namespace onnxruntime
22 changes: 22 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "core/common/safeint.h"
#include "core/graph/constants.h"
#include "core/graph/graph.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
#include "core/framework/ort_value.h"
Expand Down Expand Up @@ -2353,6 +2354,20 @@
#endif
}

ORT_API_STATUS_IMPL(OrtApis::RegisterOrtExecutionProviderLibrary, _In_ const char* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name) {

Check warning on line 2357 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2357: Lines should be <= 120 characters long [whitespace/line_length] [2]
API_IMPL_BEGIN
void* handle = nullptr;
ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(ToPathString(lib_path), false, &handle));
if (handle) {
OrtExecutionProviderFactory* (*symbol)();
ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, "RegisterCustomEp", (void**)&symbol));

Check warning on line 2363 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use reinterpret_cast<void**>(...) instead [readability/casting] [4] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2363: Using C-style cast. Use reinterpret_cast<void**>(...) instead [readability/casting] [4]
env->InsertCustomEp(ep_name, symbol());
return nullptr;
}
return CreateStatus(ORT_RUNTIME_EXCEPTION, "cannot load the shared library for out-tree EP");
API_IMPL_END
}

static constexpr OrtApiBase ort_api_base = {
&OrtApis::GetApi,
&OrtApis::GetVersionString};
Expand Down Expand Up @@ -2730,6 +2745,8 @@
&OrtApis::KernelInfoGetAllocator,
&OrtApis::AddExternalInitializersFromFilesInMemory,
// End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information)

&OrtApis::RegisterOrtExecutionProviderLibrary,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down Expand Up @@ -2802,3 +2819,8 @@
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession)
DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata)

ORT_API(bool, OrtGraph_IsConstantInitializer, const OrtGraph* graph, const char* name, bool check_outer_scope) {
::onnxruntime::GraphViewer graph_viewer(*(reinterpret_cast<const ::onnxruntime::Graph*>(graph)));
return graph_viewer.IsConstantInitializer(std::string(name), check_outer_scope);
}
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -523,4 +523,6 @@
ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);

ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);

ORT_API_STATUS_IMPL(RegisterOrtExecutionProviderLibrary, _In_ const ORTCHAR_T* lib_path, _In_ OrtEnv* env, _In_ const char* ep_name);

Check warning on line 527 in onnxruntime/core/session/ort_apis.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/ort_apis.h:527: Lines should be <= 120 characters long [whitespace/line_length] [2]
} // namespace OrtApis
4 changes: 4 additions & 0 deletions onnxruntime/core/session/ort_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,7 @@ onnxruntime::common::Status OrtEnv::UnregisterAllocator(const OrtMemoryInfo& mem
onnxruntime::common::Status OrtEnv::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
return value_->CreateAndRegisterAllocatorV2(provider_type, mem_info, options, arena_cfg);
}

void OrtEnv::InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory) {
value_->InsertCustomEp(ep_name, ep_factory);
}
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ struct OrtEnv {
~OrtEnv();
onnxruntime::common::Status CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg = nullptr);

void InsertCustomEp(const char* ep_name, OrtExecutionProviderFactory* ep_factory);

private:
static std::unique_ptr<OrtEnv> p_instance_;
static onnxruntime::OrtMutex m_;
Expand Down
10 changes: 10 additions & 0 deletions samples/c_test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# usage:
# cd build/
# cmake -S ../ -B ./
# cmake --build ./
cmake_minimum_required(VERSION 3.26)
project(TestOutTreeEp)
add_executable(TestOutTreeEp test.cpp)

target_include_directories(TestOutTreeEp PUBLIC "../../include/onnxruntime")
target_link_libraries(TestOutTreeEp PUBLIC "/home/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so")
14 changes: 14 additions & 0 deletions samples/c_test/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "core/session/onnxruntime_c_api.h"
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning test

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.

inline void THROW_ON_ERROR(OrtStatus* status) {
if (status != nullptr) abort();
}

int main() {
const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
OrtEnv* p_env = nullptr;
OrtLoggingLevel log_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO;
THROW_ON_ERROR(g_ort->CreateEnv(log_level, "", &p_env));
THROW_ON_ERROR(g_ort->RegisterOrtExecutionProviderLibrary("/home/leca/code/onnxruntime/samples/outTreeEp/build/liboutTreeEp.so", p_env, "outTreeEp"));

Check warning on line 12 in samples/c_test/test.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: samples/c_test/test.cpp:12: Lines should be <= 120 characters long [whitespace/line_length] [2]
return 0;
}
9 changes: 9 additions & 0 deletions samples/outTreeEp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# usage:
# cd build/
# cmake -S ../ -B ./
# cmake --build ./
cmake_minimum_required(VERSION 3.26)
project(outTreeEp VERSION 1.0)
set(CMAKE_CXX_STANDARD 17)
add_library(outTreeEp SHARED out_tree_ep.cc)
target_include_directories(outTreeEp PUBLIC "../../include/onnxruntime")
22 changes: 22 additions & 0 deletions samples/outTreeEp/out_tree_ep.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "out_tree_ep.h"
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
#include <memory>
namespace onnxruntime {

OutTreeEpFactory::OutTreeEpFactory() {
OrtExecutionProviderFactory::CreateExecutionProvider = [](OrtExecutionProviderFactory* this_) -> void* {
std::unique_ptr<OutTreeEp> ret = std::make_unique<OutTreeEp>();
return ret.release(); };
}

}

#ifdef __cplusplus
extern "C" {
#endif
OrtExecutionProviderFactory* RegisterCustomEp() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return Status instead

std::unique_ptr<onnxruntime::OutTreeEpFactory> ret = std::make_unique<onnxruntime::OutTreeEpFactory>();
return ret.release();
}
#ifdef __cplusplus
}
#endif
35 changes: 35 additions & 0 deletions samples/outTreeEp/out_tree_ep.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
#include "core/session/onnxruntime_c_api.h"
#include <string>

#ifdef _WIN32
#define EXPORT_API __declspec(dllexport)
#else
#define EXPORT_API
#endif

namespace onnxruntime {

struct OutTreeEpInfo {
int int_property;
std::string str_property;
};

struct OutTreeEp : public OrtExecutionProvider {
OutTreeEp() {}
};

struct OutTreeEpFactory : public OrtExecutionProviderFactory {
OutTreeEpFactory();
};
}

#ifdef __cplusplus
extern "C" {
#endif

EXPORT_API OrtExecutionProviderFactory* RegisterCustomEp();

#ifdef __cplusplus
}
#endif
Loading