From ebdbbb7531f6be3a4df7901ff5482a8174b51bd7 Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Fri, 20 Dec 2024 22:03:27 -0800 Subject: [PATCH 1/3] [VitisAI] Int4 support (#22850) ### Description 1. Add support for throwing an error when the hardware is not supported for VitisAI. 2. Add support for unloading the VitisAI EP. 3. Add API for Win25. ### Motivation and Context This is a requirement for Win25. --- .../shared_library/provider_interfaces.h | 1 + .../shared_library/provider_wrappedtypes.h | 3 ++ .../core/providers/vitisai/imp/global_api.cc | 52 ++++++++++++++++--- .../providers/vitisai/imp/tensor_proto.cc | 13 +++++ .../core/providers/vitisai/imp/tensor_proto.h | 4 ++ .../vitisai/include/vaip/global_api.h | 1 + .../providers/vitisai/include/vaip/my_ort.h | 1 + .../vitisai/include/vaip/vaip_ort_api.h | 10 +++- .../vitisai/vitisai_provider_factory.cc | 2 +- .../core/session/provider_bridge_ort.cc | 1 + 10 files changed, 80 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 8bd4067e59492..5a179ec622f8c 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -589,6 +589,7 @@ struct ProviderHost { virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0; // OrtSessionOptions virtual const std::unordered_map& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) = 0; + virtual bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) = 0; // ComputeCapability virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index d8516d5858a2f..76b6d8063fd66 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1476,5 +1476,8 @@ struct OrtSessionOptions final { const std::unordered_map& GetConfigOptions() const { return onnxruntime::g_host->SessionOptions__GetConfigOptionsMap(this); } + bool GetEnableProfiling() const { + return onnxruntime::g_host->SessionOptions__GetEnableProfiling(this); + } PROVIDER_DISALLOW_ALL(OrtSessionOptions) }; diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index cccaa65de45f2..8111ee3c1fe61 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -47,6 +47,8 @@ struct OrtVitisAIEpAPI { void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); std::vector>* (*compile_onnx_model_with_options)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); + std::vector>* (*compile_onnx_model_vitisai_ep_with_error_handling)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func); uint32_t (*vaip_get_version)(); void (*create_ep_context_nodes)( const std::vector>& eps, @@ -77,10 +79,11 @@ struct OrtVitisAIEpAPI { ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); #endif ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_,
"initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep)); - auto status = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); - if (!status.IsOK()) { - ::onnxruntime::LogRuntimeError(0, status, __FILE__, static_cast(__FUNCTION__), __LINE__); - ORT_THROW(status); + auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_error_handling", (void**)&compile_onnx_model_vitisai_ep_with_error_handling); + auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); + if ((!status1.IsOK()) && (!status2.IsOK())) { + ::onnxruntime::LogRuntimeError(0, status2, __FILE__, static_cast(__FUNCTION__), __LINE__); + ORT_THROW(status2); } std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version", (void**)&vaip_get_version); @@ -89,6 +92,14 @@ struct OrtVitisAIEpAPI { ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start)); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options)); } + void Clear() { + if (handle_) { + auto& env = Provider_GetHost()->Env__Default(); + auto status = env.UnloadDynamicLibrary(handle_); + vai_assert(status.IsOK(), status.ErrorMessage()); + handle_ = nullptr; + } + } private: void* handle_{}; @@ -109,10 +120,25 @@ void profiler_collect( } } +void change_status_with_error(void* status_ptr, int error_code, const char* error_msg) { + auto status = reinterpret_cast(status_ptr); + *status = Status(onnxruntime::common::ONNXRUNTIME, error_code, error_msg); +} + vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { + const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options) { auto model_path = graph_viewer.ModelPath().string(); - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); + if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling) { + Status status = Status::OK(); + auto status_ptr = reinterpret_cast(&status); + auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling(model_path, graph_viewer.GetGraph(), options, status_ptr, change_status_with_error)); + if (!status.IsOK()) { + ORT_THROW(status); + } + return ret; + } else { + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); + } } std::optional> create_ep_context_nodes( @@ -396,10 +422,12 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape; the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); }; the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; }; + the_global_api.tensor_proto_new_i4 = vaip::tensor_proto_new_i4; the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8; the_global_api.tensor_proto_new_i16 = vaip::tensor_proto_new_i16; the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32; the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64; + the_global_api.tensor_proto_new_u4 = vaip::tensor_proto_new_u4; 
the_global_api.tensor_proto_new_u8 = vaip::tensor_proto_new_u8; the_global_api.tensor_proto_new_u16 = vaip::tensor_proto_new_u16; the_global_api.tensor_proto_new_u32 = vaip::tensor_proto_new_u32; @@ -468,9 +496,21 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(local_str)); }; + the_global_api.is_profiling_enabled = [](void* session_options) { + auto options = reinterpret_cast(session_options); + return options->GetEnableProfiling(); + }; + the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) { + graph.RemoveInitializedTensor(tensor_name); + }; if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { return &the_global_api; } } + +void deinitialize_vitisai_ep() { + s_library_vitisaiep.Clear(); + s_kernel_registry_vitisaiep.reset(); +} diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 872d022e85264..bb942c69003a1 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -87,6 +87,12 @@ static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, co return tensor_proto.release(); } +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT4, + reinterpret_cast(&data[0]), data.size() * sizeof(data[0])); +} + ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, const std::vector& data) { return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8, @@ -108,6 +114,13 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64, reinterpret_cast(&data[0]), data.size() * sizeof(data[0])); } + +ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT4, + reinterpret_cast(&data[0]), data.size() * sizeof(data[0])); +} + ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector& shape, const std::vector& data) { return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT8, diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 618d9c4728e2f..73015d3411a54 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -9,6 +9,10 @@ namespace vaip { gsl::span tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor); vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor); const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector& shape, + const std::vector& data); ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, const std::vector& data); ONNX_NAMESPACE::TensorProto* 
tensor_proto_new_u8(const std::string& name, const std::vector& shape, diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 704b156dff57f..7791ea430054a 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -11,6 +11,7 @@ #include "vaip/custom_op.h" #include void initialize_vitisai_ep(); +void deinitialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 7628e45d2b933..85a1262d8489b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -122,4 +122,5 @@ using InitializedTensorSet = std::unordered_map; using ModelMetaData = std::unordered_map; +using error_report_func = void (*)(void*, int, const char*); } // namespace vaip_core diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 9425c08dceebc..6a51ef862280b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (12u) +#define VAIP_ORT_API_MAJOR (13u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -235,6 +235,14 @@ struct OrtApiForVaip { DllSafe (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96] void (*model_proto_delete)(ModelProto* p); // [97] DllSafe (*attr_proto_release_string)(AttributeProto* attr); // [98] + bool (*is_profiling_enabled)(void* session_options); // [99] // [98] + TensorProto* (*tensor_proto_new_i4)(const std::string& name, + const std::vector& shape, + const std::vector& data); // [100] + TensorProto* (*tensor_proto_new_u4)(const std::string& name, + const std::vector& shape, + const std::vector& data); // [101] + void (*graph_remove_initialized_tensor)(Graph& graph, const std::string& tensor_name); // [102] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 453db30e1320f..99d9845302d9a 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -50,7 +50,7 @@ struct VitisAI_Provider : Provider { // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded void Initialize() override { initialize_vitisai_ep(); } // Called right before unloading the shared library - void Shutdown() override {} + void Shutdown() override { deinitialize_vitisai_ep(); } } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index a40fabd6a607c..af39edae2074d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -720,6 +720,7 @@ struct ProviderHostImpl : ProviderHost { // 
OrtSessionOptions (wrapped) const std::unordered_map& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) override { return p->value.config_options.configurations; } + bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) override { return p->value.enable_profiling; }; // ComputeCapability (wrapped) std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } From c6ba7edd830087bc52311a3b10b1f0692ef64b3b Mon Sep 17 00:00:00 2001 From: amancini-N <63410090+amancini-N@users.noreply.github.com> Date: Mon, 23 Dec 2024 06:30:49 +0100 Subject: [PATCH 2/3] Enable pointer-generator T5 models in BeamSearch (#23134) ### Description Introduces a new optional input (encoder_input_ids) in the decoder graph of the T5 implementation for BeamSearch. This allows the use of pointer-generator networks in the decoder graph. ### Motivation and Context - Fixes #23123 --- .../cpu/transformers/subgraph_t5_decoder.cc | 65 ++- .../cpu/transformers/subgraph_t5_decoder.h | 10 +- .../test/contrib_ops/beam_search_test.cc | 22 + .../test/testdata/dummy_t5_model_generator.py | 377 ++++++++++++++++++ .../testdata/dummy_t5_pointer_generator.onnx | Bin 0 -> 7100 bytes 5 files changed, 448 insertions(+), 26 deletions(-) create mode 100644 onnxruntime/test/testdata/dummy_t5_model_generator.py create mode 100644 onnxruntime/test/testdata/dummy_t5_pointer_generator.onnx diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index f4e7173c917c1..997beb198f450 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -20,8 +20,9 @@ namespace transformers { Inputs: input_ids: int32 (B, 1) + encoder_input_ids: int32 (B, encode_sequence_length) (optional) encoder_attention_mask: int32 (B, encode_sequence_length) - encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional) past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) @@ -49,11 +50,9 @@ namespace transformers { Status T5DecoderSubgraph::Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) { - bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ?
true : false; - SetPastInputIndex(has_hidden_state); - - ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3, - "kFirstPastInputIndex currently only supports 2 or 3"); + bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids"; + bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states"; + SetPastInputIndex(has_hidden_state, has_encoder_input_ids); if (!past_present_share_buffer_) { ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer"); @@ -75,13 +74,17 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); - ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask", - "decoder subgraph input 1 shall be named as encoder_attention_mask, got: ", - subgraph_inputs[1]->Name()); - if (first_past_input_index_ == 3) { - ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states", - "decoder subgraph input 2 shall be named as encoder_hidden_states, got: ", - subgraph_inputs[2]->Name()); + const int enc_attn_mask_index = 1 + has_encoder_input_ids_; + const int enc_hidden_state_index = enc_attn_mask_index + 1; + ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask", + "decoder subgraph input ", std::to_string(enc_attn_mask_index), + " shall be named as encoder_attention_mask, got: ", + subgraph_inputs[enc_attn_mask_index]->Name()); + if (has_hidden_state_) { + ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states", + "decoder subgraph input ", std::to_string(enc_hidden_state_index), + " shall be named as encoder_hidden_states, got: ", + subgraph_inputs[enc_hidden_state_index]->Name()); } // check subgraph outputs @@ -108,12 +111,19 @@ Status T5DecoderSubgraph::Validate(const std::vector& subgraph_i ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type, "decoder subgraph input 0 (input_ids) shall have int32 type"); - ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, - "decoder subgraph input 1 (encoder_attention_mask) shall have int32 type"); - - auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type(); - ORT_RETURN_IF(float_type != float32_type && float_type != float16_type, - "decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type"); + if (has_encoder_input_ids_) { + ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "decoder subgraph input 1 (encoder_input_ids) shall have int32 type"); + } + ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "decoder subgraph input ", std::to_string(enc_attn_mask_index), + " (encoder_attention_mask) shall have int32 type"); + + auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type(); + if (has_hidden_state_) { + ORT_RETURN_IF(float_type != float32_type && float_type != float16_type, + "decoder subgraph input ", std::to_string(enc_hidden_state_index), " (encoder_hidden_states) shall have float or float16 type"); + } for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) { ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, @@ -219,6 +229,19 
@@ Status T5DecoderSubgraph::CreateInitialFeeds( decoder_feeds.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); decoder_feeds.push_back(input_ids); + if (has_encoder_input_ids_) { + // The encoder_input_ids is copied from the first input of encoder. + OrtValue expanded_encoder_input_ids; + ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream, + encoder_feeds[0], + num_beam, + allocator, + expanded_encoder_input_ids, + false, + 0 /*max_sequence_length*/)); + decoder_feeds.push_back(expanded_encoder_input_ids); + } + // The encoder_attention_mask is copied from the second input of encoder. OrtValue expanded_decoder_attention_masks; ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream, @@ -238,7 +261,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output // of encoder. // When first_past_input_index_ == 2, the past states are copied from the second output of encoder. - for (size_t j = static_cast(4) - first_past_input_index_; j < encoder_fetches.size(); j++) { + // TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions. + // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds? + for (size_t j = static_cast(2) - has_hidden_state_; j < encoder_fetches.size(); j++) { if (j == 1) { ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false"); OrtValue expanded_hidden_states; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index a72ce37a93aba..b5d727b67924c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -54,13 +54,10 @@ class T5DecoderSubgraph : public Subgraph { Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; - void SetPastInputIndex(bool has_hidden_state) { + void SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) { has_hidden_state_ = has_hidden_state; - if (!has_hidden_state_) { - first_past_input_index_ = 2; - } else { - first_past_input_index_ = 3; - } + has_encoder_input_ids_ = has_encoder_input_ids; + first_past_input_index_ = 2 + has_hidden_state_ + has_encoder_input_ids_; } int GetFirstPastInputIndex() const { @@ -79,6 +76,7 @@ class T5DecoderSubgraph : public Subgraph { int first_past_input_index_; int first_present_output_index_; bool has_hidden_state_; + bool has_encoder_input_ids_; bool use_sequence_as_input_ids_; }; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 9f4ee071925b4..1ae15afdf7482 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -394,6 +394,8 @@ TEST(BeamSearchTest, DummyT5) { #if defined(USE_CUDA) && defined(USE_DML) SKIP_CUDA_TEST_WITH_DML; #endif + // dummy_t5.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5.onnx ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx")); tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); @@ -408,6 +410,8 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { #if defined(USE_CUDA) && defined(USE_DML) 
SKIP_CUDA_TEST_WITH_DML; #endif + // dummy_t5_with_outer_scope_initializers.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_outer_scope_initializers.onnx --move-initializers ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx")); tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); @@ -422,6 +426,8 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { #if defined(USE_CUDA) && defined(USE_DML) SKIP_CUDA_TEST_WITH_DML; #endif + // dummy_t5_with_sequence_input_ids.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_sequence_input_ids.onnx --sequence-as-input ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); @@ -432,5 +438,21 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { tester.RunWithConfig(); } +TEST(BeamSearchTest, DummyT5PointerGenerator) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + // dummy_t5_pointer_generator.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_pointer_generator.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 3, 6, 7, 3, 6, 7, 18, 3, 6, 2, 3, 6, 7, 18, 3, 6, 7, 18, 3, 2, 3, 6, 7, 3, 6, 7, 3, 6, 7}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/dummy_t5_model_generator.py b/onnxruntime/test/testdata/dummy_t5_model_generator.py new file mode 100644 index 0000000000000..1ecd8b9ee9c92 --- /dev/null +++ b/onnxruntime/test/testdata/dummy_t5_model_generator.py @@ -0,0 +1,377 @@ +""" Script to generate a dummy ONNX model emulating T5 model with BeamSearch op. 
""" + +import argparse + +import numpy as np +import onnx + +import onnxruntime as ort +from onnxruntime.transformers.convert_generation import move_initializers + + +def create_model( + vocab_size: int, + embed_dim: int, + num_heads: int, + head_size: int, + beam_size: int, + min_length: int, + max_length: int, + length_penalty: float, + sequence_as_input: bool, + decoder_needs_input_ids: bool, +) -> onnx.ModelProto: + encoder_graph = create_encoder(vocab_size, embed_dim, num_heads, head_size) + decoder_graph = create_decoder( + vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids + ) + + # Inputs: encoder_input_ids + encoder_input_ids = onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + + # Outputs: sequences, scores + sequences = onnx.helper.make_tensor_value_info( + "sequences", onnx.TensorProto.INT32, ["batch_size", beam_size, "decode_sequence_length"] + ) + scores = onnx.helper.make_tensor_value_info("scores", onnx.TensorProto.FLOAT, ["batch_size", beam_size]) + + # Tensors + max_length_t = onnx.numpy_helper.from_array(np.array(max_length, dtype=np.int32), name="max_length") + min_length_t = onnx.numpy_helper.from_array(np.array(min_length, dtype=np.int32), name="min_length") + num_beams_t = onnx.numpy_helper.from_array(np.array(beam_size, dtype=np.int32), name="num_beams") + length_penalty_t = onnx.numpy_helper.from_array( + np.array(length_penalty, dtype=np.float32), name="length_penalty_as_tensor" + ) + + # Nodes + beam_search = onnx.helper.make_node( + "BeamSearch", + ["encoder_input_ids", "max_length", "min_length", "num_beams", "num_beams", "length_penalty_as_tensor"], + ["sequences", "scores"], + decoder_start_token_id=2, + eos_token_id=2, + early_stopping=0, + model_type=1, + pad_token_id=1, + decoder=decoder_graph, + encoder=encoder_graph, + domain="com.microsoft", + ) + + # Graph + graph = onnx.helper.make_graph( + [beam_search], + "model", + [encoder_input_ids], + [sequences, scores], + [max_length_t, min_length_t, num_beams_t, length_penalty_t], + ) + + # Model + model = onnx.helper.make_model( + graph, opset_imports=[onnx.helper.make_opsetid("", 17), onnx.helper.make_opsetid("com.microsoft", 1)] + ) + + return model + + +def create_encoder(vocab_size, embed_dim, num_heads, head_size) -> onnx.GraphProto: + # Inputs: encoder_input_ids, encoder_attention_mask, decoder_input_ids + encoder_input_ids = onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + encoder_attention_mask = onnx.helper.make_tensor_value_info( + "encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + decoder_input_ids = onnx.helper.make_tensor_value_info( + "decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", 1] + ) + + # Outputs: logits, present_key_self_0, present_value_self_0, present_key_cross_0, present_value_cross_0, encoder_hidden_states + logits = onnx.helper.make_tensor_value_info( + "logits", onnx.TensorProto.FLOAT, ["batch_size", "decode_sequence_length", vocab_size] + ) + present_key_self_0 = onnx.helper.make_tensor_value_info( + "present_key_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_value_self_0 = onnx.helper.make_tensor_value_info( + "present_value_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_key_cross_0 = onnx.helper.make_tensor_value_info( + "present_key_cross_0", 
onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + present_value_cross_0 = onnx.helper.make_tensor_value_info( + "present_value_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + encoder_hidden_states = onnx.helper.make_tensor_value_info( + "encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", embed_dim] + ) + + # Tensors + encoder_embeddings_tensor = onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="encoder_embeddings" + ) + num_heads_and_size_tensor = onnx.numpy_helper.from_array( + np.array([num_heads, head_size], dtype=np.int64), name="num_heads_and_size" + ) + final_proj_tensor = onnx.numpy_helper.from_array( + np.random.randn(embed_dim, vocab_size).astype(np.float32), name="init_final_proj" + ) + self_state_before_tranpose_shape_tensor = onnx.numpy_helper.from_array( + np.array([-1, 1, num_heads, head_size], dtype=np.int64), name="self_state_before_tranpose_shape" + ) + + # Nodes + nodes = [ + onnx.helper.make_node("Gather", ["encoder_embeddings", "encoder_input_ids"], ["encoder_hidden_states"]), + onnx.helper.make_node("Shape", ["encoder_hidden_states"], ["encoder_batch_seq_len"], end=2), + onnx.helper.make_node( + "Concat", ["encoder_batch_seq_len", "num_heads_and_size"], ["encoder_final_shape"], axis=0 + ), + onnx.helper.make_node( + "Reshape", ["encoder_hidden_states", "encoder_final_shape"], ["encoder_hidden_states_reshaped"] + ), + onnx.helper.make_node( + "Transpose", ["encoder_hidden_states_reshaped"], ["present_key_cross_0"], perm=[0, 2, 1, 3] + ), + onnx.helper.make_node( + "Transpose", ["encoder_hidden_states_reshaped"], ["present_value_cross_0"], perm=[0, 2, 1, 3] + ), + onnx.helper.make_node("Gather", ["encoder_embeddings", "decoder_input_ids"], ["decoder_hidden_states"]), + onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["encoder_hidden_states_mean"], axes=[1]), + onnx.helper.make_node("Add", ["decoder_hidden_states", "encoder_hidden_states_mean"], ["encoder_decoder_sum"]), + onnx.helper.make_node("MatMul", ["encoder_decoder_sum", "init_final_proj"], ["logits"]), + onnx.helper.make_node( + "Reshape", ["encoder_decoder_sum", "self_state_before_tranpose_shape"], ["self_state_before_tranpose"] + ), + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["present_key_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["present_value_self_0"], perm=[0, 2, 1, 3]), + ] + + # Graph + graph = onnx.helper.make_graph( + nodes, + "encoder", + [encoder_input_ids, encoder_attention_mask, decoder_input_ids], + [ + logits, + encoder_hidden_states, + present_key_self_0, + present_value_self_0, + present_key_cross_0, + present_value_cross_0, + ], + [ + encoder_embeddings_tensor, + num_heads_and_size_tensor, + final_proj_tensor, + self_state_before_tranpose_shape_tensor, + ], + ) + return graph + + +def create_decoder( + vocab_size, embed_dim, num_heads, head_size, sequence_as_input, decoder_needs_input_ids +) -> onnx.GraphProto: + # Inputs: input_ids, encoder_input_ids (optional), encoder_attention_mask, past_self_key_0, past_self_value_0, past_cross_key_0, past_cross_value_0 + inputs = [] + inputs.append( + onnx.helper.make_tensor_value_info( + "input_ids", onnx.TensorProto.INT32, ["batch_size", "decode_sequence_length" if sequence_as_input else 1] + ) + ) + if decoder_needs_input_ids: + inputs.append( + 
onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "encoder_attention_mask", onnx.TensorProto.INT32, ["batch_size", "encode_sequence_length"] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_self_key_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "decode_sequence_length", head_size] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_self_value_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "decode_sequence_length", head_size] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_cross_key_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + ) + inputs.append( + onnx.helper.make_tensor_value_info( + "past_cross_value_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + ) + + # Outputs: logits, present_key_self_0, present_value_self_0 + outputs = [ + onnx.helper.make_tensor_value_info("logits", onnx.TensorProto.FLOAT, ["batch_size", 1, vocab_size]), + onnx.helper.make_tensor_value_info( + "present_key_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "present_value_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), + ] + + # Tensors: decoder_embeddings, final_proj, self_state_before_tranpose_shape_no_batch, hidden_states_mean + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="decoder_embeddings" + ), + onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name="final_proj"), + onnx.numpy_helper.from_array( + np.array([-1, num_heads, head_size], dtype=np.int64), name="self_state_before_tranpose_shape_no_batch" + ), + onnx.numpy_helper.from_array(np.array([-1, 1, embed_dim], dtype=np.int64), name="hidden_states_mean_shape"), + ] + + # Nodes + nodes = [] + nodes.append(onnx.helper.make_node("Gather", ["decoder_embeddings", "input_ids"], ["decoder_hidden_states"])) + if decoder_needs_input_ids: + nodes.append( + onnx.helper.make_node("Gather", ["decoder_embeddings", "encoder_input_ids"], ["encoder_input_embeddings"]) + ) + nodes.append( + onnx.helper.make_node( + "ReduceMean", ["encoder_input_embeddings"], ["encoder_input_embeddings_mean"], axes=[1] + ) + ) + nodes.append( + onnx.helper.make_node( + "Mul", ["decoder_hidden_states", "encoder_input_embeddings_mean"], ["combined_hidden_states"] + ) + ) + else: + nodes.append(onnx.helper.make_node("Identity", ["decoder_hidden_states"], ["combined_hidden_states"])) + nodes.append(onnx.helper.make_node("ReduceMean", ["past_cross_key_0"], ["encoder_hidden_states_mean"], axes=[2])) + nodes.append( + onnx.helper.make_node( + "Reshape", + ["encoder_hidden_states_mean", "hidden_states_mean_shape"], + ["encoder_hidden_states_mean_reshaped"], + ) + ) + if sequence_as_input: + nodes.append( + onnx.helper.make_node("ReduceMean", ["combined_hidden_states"], ["decoder_hidden_states_mean"], axes=[1]) + ) + nodes.append( + onnx.helper.make_node( + "Add", ["decoder_hidden_states_mean", "encoder_hidden_states_mean_reshaped"], ["encoder_decoder_sum"] + ) + ) + else: + nodes.append( + onnx.helper.make_node( + "Add", ["combined_hidden_states", "encoder_hidden_states_mean_reshaped"], ["encoder_decoder_sum"] + ) + 
) + nodes.append(onnx.helper.make_node("Shape", ["combined_hidden_states"], ["decoder_batch"], end=1)) + nodes.append( + onnx.helper.make_node( + "Concat", + ["decoder_batch", "self_state_before_tranpose_shape_no_batch"], + ["self_state_before_tranpose_shape_dec"], + axis=0, + ) + ) + nodes.append(onnx.helper.make_node("MatMul", ["encoder_decoder_sum", "final_proj"], ["logits"])) + nodes.append( + onnx.helper.make_node( + "Reshape", ["encoder_decoder_sum", "self_state_before_tranpose_shape_dec"], ["self_state_before_tranpose"] + ) + ) + nodes.append( + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["single_self_key_0"], perm=[0, 2, 1, 3]) + ) + nodes.append( + onnx.helper.make_node("Transpose", ["self_state_before_tranpose"], ["single_self_value_0"], perm=[0, 2, 1, 3]) + ) + nodes.append( + onnx.helper.make_node("Concat", ["past_self_key_0", "single_self_key_0"], ["present_key_self_0"], axis=2) + ) + nodes.append( + onnx.helper.make_node("Concat", ["past_self_value_0", "single_self_value_0"], ["present_value_self_0"], axis=2) + ) + + # Graph + graph = onnx.helper.make_graph(nodes, "decoder", inputs, outputs, initializers) + return graph + + +def run_model(model_path): + ort_session = ort.InferenceSession(model_path) + encoder_input_ids = np.array([[14, 6, 13, 9, 7]]).astype(np.int32) + print("encoder_input_ids: ", encoder_input_ids) + sequence, scores = ort_session.run(None, {"encoder_input_ids": encoder_input_ids}) + print("sequence: ", sequence) + print("scores: ", scores) + + +def move_initializers_on_outer_scope(model) -> None: + main_graph = model.graph + beam_search_node = model.graph.node[0] + decoder_graph = next(attr for attr in beam_search_node.attribute if attr.name == "decoder").g + encoder_graph = next(attr for attr in beam_search_node.attribute if attr.name == "encoder").g + main_graph.initializer.extend(move_initializers(decoder_graph, min_elements=10)) + main_graph.initializer.extend(move_initializers(encoder_graph, min_elements=10)) + + +def arg_parser(): + parser = argparse.ArgumentParser(description="Generate a dummy ONNX model emulating T5 model with BeamSearch op.") + parser.add_argument("--output-path", type=str, default="model.onnx", help="Model output path") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--vocab-size", type=int, default=20, help="Vocab size") + parser.add_argument("--embed-dim", type=int, default=8, help="Embedding dimension") + parser.add_argument("--num-heads", type=int, default=2, help="Number of heads") + parser.add_argument("--head-size", type=int, default=4, help="Head size") + parser.add_argument("--beam-size", type=int, default=3, help="Beam size") + parser.add_argument("--min-length", type=int, default=1, help="Min length") + parser.add_argument("--max-length", type=int, default=10, help="Max length") + parser.add_argument("--length-penalty", type=float, default=1.1, help="Length penalty") + parser.add_argument("--move-initializers", action="store_true", help="Move initializers to outer scope") + parser.add_argument("--sequence-as-input", action="store_true", help="Use sequence as input") + parser.add_argument("--decoder-needs-input-ids", action="store_true", help="Decoder needs model/encoder input ids") + + return parser.parse_args() + + +if __name__ == "__main__": + args = arg_parser() + np.random.seed(args.seed) + + model = create_model( + args.vocab_size, + args.embed_dim, + args.num_heads, + args.head_size, + args.beam_size, + args.min_length, + args.max_length, + 
args.length_penalty, + args.sequence_as_input, + args.decoder_needs_input_ids, + ) + if args.move_initializers: + move_initializers_on_outer_scope(model) + onnx.save(model, args.output_path) + + run_model(args.output_path) diff --git a/onnxruntime/test/testdata/dummy_t5_pointer_generator.onnx b/onnxruntime/test/testdata/dummy_t5_pointer_generator.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f7fee773cffe17cafa21af369ad06887d7cb38b4 GIT binary patch literal 7100 zcmd5>c|cRg)`zf#D=J(RktZ%FsIjh9R3vi(A}F#6xD=3>kgGu4>AWWQtDPKR&jaOU7w0ualy8}cLPK~0{Z%We|-5TH@WBh=FH4FGr!-A zknzYJ$}AYd*m0_u1O=BUja8|Ww9;6GhGF7m$qXKPhbyj0G~;x$Zz zN5@)AshU&C;^~OOyc1=4Qr{1Nl4^0E{r$t9>&QMW0a!HPK=d; zt7&}DNhgLdY~3#@7cb`&idfY&zM;2avF#17m9Yv1r;=*4GA$2j=_k`FxkP4+QU9jh zV;vj!xW2KsP@zn-Hmq}V-qLtZrgC9IIYm+o7sx-0S~2FbWKJV7-C}CRgfYhBXIchi z9byvV<*_PGA#L1kv%sV{#*eX8%QRYPOk#pYBb~veNC&V^h5#GcHy@oqADx0}!Hscj z3MJJjWonM?-h5+He!lfK3Ky%;+9zfr(@bxe55ZyW4UHKRL6a2kV&0%M44@tx@W}Hgl$WM%bNn#q?FugTg+|=6a;(qRx^%E4xv|z^CY50jlR4nCUclF+ zT|Hb~1dOq;*mefUTB{a7S;Lwp+GHVYCluP6dK+2S{L?JX=9P-wqa5L1v&P|`e_2TS z9!tY7rf(M7Yqjv$4uHj=L% z2BF!j33zg5N4!^QReWQuB`%oYjpt09$oAKJNwnJ&cpapH=lkq2{SAlbuKZ4>{ka24 z{C&9nat)jq`h>pmyNutAEr6>l@9Rp`xiG}xByLxC#EIp*M;um3#J{WVP)AJ{_%Z(r ztUO^ymrQ;k9{H^qa38sYZ`D$~G`R%tXMF{>uDi&Yjeu_1rBHt7q1gSmHPFp3OFZ_b znMAPf8YDfLN3VUC3{gW}A+-NK7;Z5e*{%ISo0DKclp)VxHF*-lzdxAUTjcf zS+*Glugb(blhbrVGPJOG(p(&9Z3-``6&|Yg$D%=(b!R`%z!!zB(Kb8>o#*}z8ERXx zd$*7AOXo-;vALtWJW5WjJe~5eUY`3yKm&5MDBN-ik4*|8gSU8Ltj$3@ zwZIic>MZoGx&W>RWUzO$py(%+4tzDa;x3;?Vd3+mAb4R7b4oI>$31Vlpw|>!{qOx) z-7y|Yrd7d6^8;Ai*%sz6bHg-cHSLz)0n_?T#!b{$>~=UD$E9SE)Y3rR$-P?YI6;YH z^6fEH<^`QRO`xq;Di*0ui0?&6fZg{C=67rb2eQ)IhVBu9a?Bwb#ZtW9=vu~Bq!Kc$;0H1FU9*+ml^~(^p zT0!W7?;x&!=i+eF05mDIhi(so=`WU)l;%3&pX0v~UyAKQZ=|dS&B)J*_oY1~>-t%` zFu$7mB>$6cd;TNwTYP~^C$A#W<>nGAffSZJ{F=&+@D+{953;;q9krmlc#!389D%Q^by&u$-o+hzE1W+%zgP;(+c7xs7jcil%|Qu{Hl`xXtWfohh%)O(uNc?j#iBU~FYPe;Zm74P)}fZa@cU;iEw1C$5u=&9H;Gc2mgmme z*aVd{UZ$DxUn7lavOlu!%wVA@JK&vQ1VRDB2UJUr-VKQFnR&!mBYWdt2J<(7*F13SQ8%s z)0ORPuy1~lctHeI7bnoz(R8r}f5QRwwly8#5-*=RS7 zD4j6oJ*KR^AwBd&x>movkNu8+FKBxMXl+@17;H0>Y!FT#EWiWmXQz`?bg!S#2Je`*v8q_Dv<1x-W!$f8jRL4SfiSKSKTK~0wJ4B`tT^w{ zMh>1mqg7&b>EBVdz7XyX zoJKtYvczuD2XS{)E0p)SDLxvu66Zx0lceuE;*CkGX`;0gYCBZI?Cz6`*qBTVS`i8x zd^Y05YvW1s`PXP7*oALak0Gu$V(2rcKM`Dfs2eF-Plrf!c+hVcR#ZHoWx<)Cd$f-P z?!Jx_9_5kr-pA8r7NrT8B4F?R1xS2)xMNgx(L7L>rP1W=BuqPFpdFYP*0eb{!4V z!gKMk`A&SX$pWUdSx0gtX=JDD8{*{95AT=9lXb1vkfLAR@xafIV8r@oWZT?zc*njM zh(tmG&l-(v2L859XyJN;@_D!X`9#)bq~Mf!W8g}_DiAw{;?J9Bf#sxfvM4bWoR2>P zOTRhvQk4XU9g4xH6(N{*catt<_g*q%*Oz4U(J;7aG2T>f^>dM?C>r+}JUkoH~zBuc*Md;djD0Y6`7lNxjFf@NM$=c~59{rOW%G|9b8KYg`$p+Kn z>y?C@pWYkU;SsR?hCe=eeP4Ic*A!f~O@@%Kau7%EroxYYqiZBRQTh5heE-Q$Bu}vw zO{%AZ@@Ff_*`672DD^02R=%XJqHR##dk6-4-Xt@Ed|}+!c)W5(qpMU07rS(skDEf$ zY1E<3c=_2G__fU;oF90dUb$vL^5%R8AfLGZ~}DdT+E#>l$c$u#xLXlMxSqRfIe5363^H!XwrE+&7J-S z-FVFb7k>5$6dgZE7Rm==-+-wwdBP@A*dYWW`*#EL;o)ehY>V5106e@p!P()vizbeF zB3{TJH4LBG0av!MfwirI>Djx{^tO3;s z$AXGR3(fgNe(wh3-_@Xc4BF-0H2?o5PWm!-hOM`rQ#32B Date: Mon, 23 Dec 2024 10:02:04 -0800 Subject: [PATCH 3/3] [QNN EP] Fix multithread sync bug in ETW callback (#23156) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Fixes crash in QNN dlls when an ETW callback tries to change the QNN log level. This is caused by a function that does not lock a mutex before modifying the QNN log level. ### Motivation and Context An ETW callback into QNN EP leads to a crash within QNN SDK dlls. It happens approximately 1 out of 3 full QNN unit tests runs. 
The cause is a multithreading synchronization bug in QNN EP. We're not always locking a mutex when ETW calls QNN EP to notify it of an ETW config change. There are two branches in the QNN EP callback function that try to update the QNN log handle. One branch correctly locks a mutex, but the other does not lock it at all. This causes crashes within QNN dlls. - Does not lock mutex: [onnxruntime/onnxruntime/core/providers/qnn/qnn_execution_provider.cc at main · microsoft/onnxruntime](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/qnn/qnn_execution_provider.cc#L426) - Locks mutex: [onnxruntime/onnxruntime/core/providers/qnn/qnn_execution_provider.cc at main · microsoft/onnxruntime](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/qnn/qnn_execution_provider.cc#L442) The fix is to lock the mutex in both paths. --- .../qnn/builder/qnn_backend_manager.cc | 46 ++++++++++++------- .../qnn/builder/qnn_backend_manager.h | 46 ++++++++----------- .../providers/qnn/qnn_execution_provider.cc | 4 +- 3 files changed, 49 insertions(+), 47 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 3af646c3ce13a..fa38fad8eeb59 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -255,7 +255,9 @@ void QnnLogging(const char* format, } } -Status QnnBackendManager::InitializeQnnLog() { +Status QnnBackendManager::InitializeQnnLog(const logging::Logger& logger) { + logger_ = &logger; + // Set Qnn log level align with Ort log level auto ort_log_level = logger_->GetSeverity(); QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); @@ -303,23 +305,15 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity } } -Status QnnBackendManager::ResetQnnLogLevel() { +Status QnnBackendManager::ResetQnnLogLevel(std::optional ort_log_level) { std::lock_guard lock(logger_mutex_); - - if (backend_setup_completed_ && logger_ != nullptr) { - auto ort_log_level = logger_->GetSeverity(); - LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; - return UpdateQnnLogLevel(ort_log_level); + if (!backend_setup_completed_ || logger_ == nullptr) { + return Status::OK(); } - return Status::OK(); -} - -Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. Invalid QNN log handle."); - ORT_RETURN_IF(false == backend_setup_completed_, "Unable to update QNN Log Level. Backend setup not completed."); - ORT_RETURN_IF(nullptr == logger_, "Unable to update QNN Log Level. Invalid logger."); - QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); + logging::Severity actual_log_level = ort_log_level.has_value() ? *ort_log_level : logger_->GetSeverity(); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(actual_log_level); LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level; @@ -332,7 +326,8 @@ Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { LOGS(*logger_, ERROR) << "Invalid log handle provided to QnnLog_setLogLevel."; } } - ORT_RETURN_IF(QNN_BACKEND_NO_ERROR != result, "Failed to set log level in Qnn backend. Error: ", QnnErrorHandleToString(result)); + ORT_RETURN_IF(QNN_BACKEND_NO_ERROR != result, + "Failed to set log level in Qnn backend.
Error: ", QnnErrorHandleToString(result)); return Status::OK(); } @@ -823,7 +818,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, LOGS(logger, VERBOSE) << "Backend build version: " << sdk_build_version_; - SetLogger(&logger); + ORT_RETURN_IF_ERROR(InitializeQnnLog(logger)); LOGS(logger, VERBOSE) << "SetLogger succeed."; ORT_RETURN_IF_ERROR(InitializeBackend()); @@ -1049,6 +1044,24 @@ Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) return Status::OK(); } +Status QnnBackendManager::TerminateQnnLog() { + std::lock_guard lock(logger_mutex_); + if (logger_ == nullptr) { + return Status::OK(); + } + + if (nullptr != qnn_interface_.logFree && nullptr != log_handle_) { + auto ret_val = qnn_interface_.logFree(log_handle_); + + // Reset QNN log handle to nullptr so other threads that are waiting on logger_mutex_ know it was freed. + log_handle_ = nullptr; + ORT_RETURN_IF(QNN_SUCCESS != ret_val, + "Unable to terminate logging in the backend."); + } + + return Status::OK(); +} + void QnnBackendManager::ReleaseResources() { if (!backend_setup_completed_) { return; @@ -1074,7 +1087,6 @@ void QnnBackendManager::ReleaseResources() { ORT_THROW("Failed to ShutdownBackend."); } - std::lock_guard lock(logger_mutex_); result = TerminateQnnLog(); if (Status::OK() != result) { ORT_THROW("Failed to TerminateQnnLog."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b145f2a2cd724..beabc9bd71b94 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -96,6 +96,8 @@ class QnnBackendManager { std::unordered_map>& qnn_models, int64_t max_spill_fill_size); + // Initializes handles to QNN resources (device, logger, etc.). + // NOTE: This function locks the internal `logger_mutex_`. Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib); Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); @@ -121,34 +123,10 @@ class QnnBackendManager { const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; } - void SetLogger(const logging::Logger* logger) { - if (logger_ == nullptr) { - logger_ = logger; - (void)InitializeQnnLog(); - } - } - - Status InitializeQnnLog(); - - Status UpdateQnnLogLevel(logging::Severity ort_log_level); - - Status ResetQnnLogLevel(); - - // Terminate logging in the backend - Status TerminateQnnLog() { - if (logger_ == nullptr) { - return Status::OK(); - } - - if (nullptr != qnn_interface_.logFree && nullptr != log_handle_) { - ORT_RETURN_IF(QNN_SUCCESS != qnn_interface_.logFree(log_handle_), - "Unable to terminate logging in the backend."); - } - - return Status::OK(); - } - - void ReleaseResources(); + // Resets the QNN log level to the given ORT log level or to the default log level if the argument is + // std::nullopt. + // NOTE: This function locks the internal `logger_mutex_`. + Status ResetQnnLogLevel(std::optional ort_log_level = std::nullopt); Status ExtractBackendProfilingInfo(); Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile, @@ -171,6 +149,18 @@ class QnnBackendManager { uint64_t& max_spill_fill_buffer_size); private: + // Sets the ORT logger and creates a corresponding QNN logger with the same log level. + // NOTE: caller must lock the `logger_mutex_` before calling this function. 
+ Status InitializeQnnLog(const logging::Logger& logger); + + // Terminate logging in the backend + // NOTE: This function locks the internal `logger_mutex_`. + Status TerminateQnnLog(); + + // Releases all QNN resources. Called in the destructor. + // NOTE: This function indirectly locks the internal `logger_mutex_` via nested function calls. + void ReleaseResources(); + void* LoadLib(const char* file_name, int flags, std::string& error_msg); Status LoadQnnSystemLib(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 27e195dea73d2..1d9242f8a5939 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -423,7 +423,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); + (void)qnn_backend_manager_->ResetQnnLogLevel(ortETWSeverity); } if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { if (Level != 0) { @@ -439,7 +439,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); - (void)qnn_backend_manager_->ResetQnnLogLevel(); + (void)qnn_backend_manager_->ResetQnnLogLevel(std::nullopt); } }); etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_);
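For illustration, a minimal sketch of the locking discipline the QNN EP fix above describes (hypothetical class and member names, not the actual `QnnBackendManager` API): the ETW-triggered log-level reset and the backend teardown path take the same mutex, and teardown invalidates the log handle so a late callback degrades to a safe no-op.

```cpp
// Simplified sketch (assumed names): both the callback path and the teardown
// path serialize on the same mutex before touching the backend log handle.
#include <iostream>
#include <mutex>
#include <optional>

class BackendLogManager {
 public:
  // May be called from an ETW callback thread with an override level,
  // or with std::nullopt to fall back to the default level.
  void ResetLogLevel(std::optional<int> override_level) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!log_handle_valid_) {
      return;  // handle already freed by another thread; nothing to do
    }
    const int level = override_level.value_or(default_level_);
    std::cout << "backend log level set to " << level << "\n";
  }

  // Called during backend shutdown; invalidates the handle under the same lock.
  void Terminate() {
    std::lock_guard<std::mutex> lock(mutex_);
    log_handle_valid_ = false;
  }

 private:
  std::mutex mutex_;
  bool log_handle_valid_ = true;  // stands in for the real QNN log handle
  int default_level_ = 2;
};

int main() {
  BackendLogManager mgr;
  mgr.ResetLogLevel(4);             // e.g. an ETW callback raises verbosity
  mgr.Terminate();                  // session shutdown frees the handle
  mgr.ResetLogLevel(std::nullopt);  // a late callback is now a safe no-op
}
```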