Skip to content

Commit

Permalink
Merge main and fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Dec 23, 2024
2 parents 17c3bde + 81cd6ea commit 7e7fe68
Show file tree
Hide file tree
Showing 18 changed files with 559 additions and 54 deletions.
65 changes: 45 additions & 20 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ namespace transformers {
Inputs:
input_ids: int32 (B, 1)
encoder_input_ids: int32 (B, encode_sequence_length) (optional)
encoder_attention_mask: int32 (B, encode_sequence_length)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) (optional)
past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
Expand Down Expand Up @@ -49,11 +50,9 @@ namespace transformers {

Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false;
SetPastInputIndex(has_hidden_state);

ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3,
"kFirstPastInputIndex currently only supports 2 or 3");
bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids";
bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states";
SetPastInputIndex(has_hidden_state, has_encoder_input_ids);

if (!past_present_share_buffer_) {
ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer");
Expand All @@ -75,13 +74,17 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i

ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
"decoder subgraph input 1 shall be named as encoder_attention_mask, got: ",
subgraph_inputs[1]->Name());
if (first_past_input_index_ == 3) {
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
"decoder subgraph input 2 shall be named as encoder_hidden_states, got: ",
subgraph_inputs[2]->Name());
const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
const int enc_hidden_state_index = enc_attn_mask_index + 1;
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask",
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
" shall be named as encoder_attention_mask, got: ",
subgraph_inputs[enc_attn_mask_index]->Name());
if (has_hidden_state_) {
ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states",
"decoder subgraph input ", std::to_string(enc_hidden_state_index),
" shall be named as encoder_hidden_states, got: ",
subgraph_inputs[enc_hidden_state_index]->Name());
}

// check subgraph outputs
Expand All @@ -108,12 +111,19 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i

ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 0 (input_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");

auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
if (has_encoder_input_ids_) {
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 1 (encoder_input_ids) shall have int32 type");
}
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
" (encoder_attention_mask) shall have int32 type");

auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type();
if (has_hidden_state_) {
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input ", std::to_string(enc_hidden_state_index), " (encoder_hidden_states) shall have float or float16 type");
}

for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
Expand Down Expand Up @@ -219,6 +229,19 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
decoder_feeds.push_back(input_ids);

if (has_encoder_input_ids_) {
// The encoder_input_ids is copied from the first input of encoder.
OrtValue expanded_encoder_input_ids;
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
encoder_feeds[0],
num_beam,
allocator,
expanded_encoder_input_ids,
false,
0 /*max_sequence_length*/));
decoder_feeds.push_back(expanded_encoder_input_ids);
}

// The encoder_attention_mask is copied from the second input of encoder.
OrtValue expanded_decoder_attention_masks;
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
Expand All @@ -238,7 +261,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
// of encoder.
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
for (size_t j = static_cast<size_t>(4) - first_past_input_index_; j < encoder_fetches.size(); j++) {
// TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.
// What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
if (j == 1) {
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
OrtValue expanded_hidden_states;
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,10 @@ class T5DecoderSubgraph : public Subgraph {
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;

void SetPastInputIndex(bool has_hidden_state) {
void SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) {
has_hidden_state_ = has_hidden_state;
if (!has_hidden_state_) {
first_past_input_index_ = 2;
} else {
first_past_input_index_ = 3;
}
has_encoder_input_ids_ = has_encoder_input_ids;
first_past_input_index_ = 2 + has_hidden_state_ + has_encoder_input_ids_;
}

int GetFirstPastInputIndex() const {
Expand All @@ -79,6 +76,7 @@ class T5DecoderSubgraph : public Subgraph {
int first_past_input_index_;
int first_present_output_index_;
bool has_hidden_state_;
bool has_encoder_input_ids_;
bool use_sequence_as_input_ids_;
};

Expand Down
19 changes: 18 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,24 @@ Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id)
return Status::OK();
}

// Frees the backend's QNN log handle, if logging was ever initialized.
// NOTE: This function locks the internal `logger_mutex_`.
Status QnnBackendManager::TerminateQnnLog() {
  std::lock_guard<std::mutex> lock(logger_mutex_);

  // Logging was never set up; nothing to tear down.
  if (logger_ == nullptr) {
    return Status::OK();
  }

  if (log_handle_ != nullptr && qnn_interface_.logFree != nullptr) {
    const auto free_result = qnn_interface_.logFree(log_handle_);

    // Clear the handle right away so other threads waiting on logger_mutex_
    // can see that it has already been freed.
    log_handle_ = nullptr;

    ORT_RETURN_IF(QNN_SUCCESS != free_result,
                  "Unable to terminate logging in the backend.");
  }

  return Status::OK();
}

void QnnBackendManager::ReleaseResources() {
if (!backend_setup_completed_) {
return;
Expand All @@ -1064,7 +1082,6 @@ void QnnBackendManager::ReleaseResources() {
ORT_THROW("Failed to ShutdownBackend.");
}

std::lock_guard<std::mutex> lock(logger_mutex_);
result = TerminateQnnLog();
if (Status::OK() != result) {
ORT_THROW("Failed to TerminateQnnLog.");
Expand Down
30 changes: 12 additions & 18 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class QnnBackendManager {
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
int64_t max_spill_fill_size);

// Initializes handles to QNN resources (device, logger, etc.).
// NOTE: This function locks the internal `logger_mutex_`.
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);

Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
Expand Down Expand Up @@ -122,25 +124,9 @@ class QnnBackendManager {

// Resets the QNN log level to the given ORT log level or to the default log level if the argument is
// std::nullopt.
// IMPORTANT: This function locks the internal `logging_mutex_`.
// NOTE: This function locks the internal `logger_mutex_`.
Status ResetQnnLogLevel(std::optional<logging::Severity> ort_log_level = std::nullopt);

// Terminate logging in the backend
Status TerminateQnnLog() {
  // Logging was never initialized; nothing to free.
  if (logger_ == nullptr) {
    return Status::OK();
  }

  if (nullptr != qnn_interface_.logFree && nullptr != log_handle_) {
    auto ret_val = qnn_interface_.logFree(log_handle_);
    // Reset the handle before checking the result so a repeated call (or an
    // early error return) can never pass the same handle to logFree twice.
    log_handle_ = nullptr;
    ORT_RETURN_IF(QNN_SUCCESS != ret_val,
                  "Unable to terminate logging in the backend.");
  }

  return Status::OK();
}

void ReleaseResources();

Status ExtractBackendProfilingInfo();
Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
Expand All @@ -163,9 +149,17 @@ class QnnBackendManager {

private:
// Sets the ORT logger and creates a corresponding QNN logger with the same log level.
// IMPORTANT: caller must lock the `logger_mutex_` before calling this function.
// NOTE: caller must lock the `logger_mutex_` before calling this function.
Status InitializeQnnLog(const logging::Logger& logger);

// Terminate logging in the backend
// NOTE: This function locks the internal `logger_mutex_`.
Status TerminateQnnLog();

// Releases all QNN resources. Called in the destructor.
// NOTE: This function indirectly locks the internal `logger_mutex_` via nested function calls.
void ReleaseResources();

void* LoadLib(const char* file_name, int flags, std::string& error_msg);

Status LoadQnnSystemLib();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio

if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) {
// (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID);
(void)qnn_backend_manager_->ResetQnnLogLevel();
(void)qnn_backend_manager_->ResetQnnLogLevel(std::nullopt);
}
});
etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ struct ProviderHost {
virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0;
// OrtSessionOptions
virtual const std::unordered_map<std::string, std::string>& SessionOptions__GetConfigOptionsMap(const OrtSessionOptions* p) = 0;
virtual bool SessionOptions__GetEnableProfiling(const OrtSessionOptions* p) = 0;
// ComputeCapability
virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1561,5 +1561,8 @@ struct OrtSessionOptions final {
// Returns the session's configuration key/value map. This type sits behind
// the provider bridge, so the call is forwarded to the host-side
// implementation via g_host.
const std::unordered_map<std::string, std::string>& GetConfigOptions() const {
  return onnxruntime::g_host->SessionOptions__GetConfigOptionsMap(this);
}
// Reports whether profiling is enabled for this session, forwarding to the
// host-side implementation via g_host (provider-bridge pattern).
bool GetEnableProfiling() const {
  return onnxruntime::g_host->SessionOptions__GetEnableProfiling(this);
}
PROVIDER_DISALLOW_ALL(OrtSessionOptions)
};
52 changes: 46 additions & 6 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ struct OrtVitisAIEpAPI {
void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector<OrtCustomOpDomain*>& ret_domain);
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_with_options)(
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options);
std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>* (*compile_onnx_model_vitisai_ep_with_error_handling)(
const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func);
uint32_t (*vaip_get_version)();
void (*create_ep_context_nodes)(
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps,
Expand Down Expand Up @@ -77,10 +79,11 @@ struct OrtVitisAIEpAPI {
ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_));
#endif
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep));
auto status = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options);
if (!status.IsOK()) {
::onnxruntime::LogRuntimeError(0, status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__);
ORT_THROW(status);
auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_error_handling", (void**)&compile_onnx_model_vitisai_ep_with_error_handling);
auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options);
if ((!status1.IsOK()) && (!status2.IsOK())) {
::onnxruntime::LogRuntimeError(0, status2, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__);
ORT_THROW(status2);
}
std::ignore = env.GetSymbolFromLibrary(handle_, "vaip_get_version",
(void**)&vaip_get_version);
Expand All @@ -89,6 +92,14 @@ struct OrtVitisAIEpAPI {
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start));
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options));
}
// Unloads the VitisAI EP shared library, if one is currently loaded, and
// forgets its handle so a later load starts from a clean state.
void Clear() {
  if (handle_ == nullptr) {
    return;
  }
  auto& env = Provider_GetHost()->Env__Default();
  auto unload_status = env.UnloadDynamicLibrary(handle_);
  vai_assert(unload_status.IsOK(), unload_status.ErrorMessage());
  handle_ = nullptr;
}

private:
void* handle_{};
Expand All @@ -109,10 +120,25 @@ void profiler_collect(
}
}

// Error-reporting callback handed to the VitisAI EP DLL: writes the given
// error code and message into the ORT Status object behind status_ptr.
void change_status_with_error(void* status_ptr, int error_code, const char* error_msg) {
  *reinterpret_cast<Status*>(status_ptr) = Status(onnxruntime::common::ONNXRUNTIME, error_code, error_msg);
}

// Compiles an ONNX model through the VitisAI EP DLL.
//
// Prefers the newer `..._with_error_handling` entry point, which reports
// failures through a Status callback instead of crashing/ignoring them;
// falls back to the legacy `..._with_options` entry point when the loaded
// DLL does not export the new symbol.
vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(
    const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options) {
  auto model_path = graph_viewer.ModelPath().string();
  if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling) {
    // The DLL reports failures by mutating this Status through the callback.
    Status status = Status::OK();
    auto status_ptr = reinterpret_cast<void*>(&status);
    auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling(
        model_path, graph_viewer.GetGraph(), options, status_ptr, change_status_with_error));
    if (!status.IsOK()) {
      ORT_THROW(status);
    }
    return ret;
  } else {
    return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options));
  }
}

std::optional<std::vector<Node*>> create_ep_context_nodes(
Expand Down Expand Up @@ -396,10 +422,12 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape;
the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); };
the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; };
the_global_api.tensor_proto_new_i4 = vaip::tensor_proto_new_i4;
the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8;
the_global_api.tensor_proto_new_i16 = vaip::tensor_proto_new_i16;
the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32;
the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64;
the_global_api.tensor_proto_new_u4 = vaip::tensor_proto_new_u4;
the_global_api.tensor_proto_new_u8 = vaip::tensor_proto_new_u8;
the_global_api.tensor_proto_new_u16 = vaip::tensor_proto_new_u16;
the_global_api.tensor_proto_new_u32 = vaip::tensor_proto_new_u32;
Expand Down Expand Up @@ -468,9 +496,21 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
return vaip_core::DllSafe<std::string>(std::move(local_str));
};

the_global_api.is_profiling_enabled = [](void* session_options) {
auto options = reinterpret_cast<OrtSessionOptions*>(session_options);
return options->GetEnableProfiling();
};
the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) {
graph.RemoveInitializedTensor(tensor_name);
};
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
return &the_global_api;
}
}

// Global teardown for the VitisAI EP: unloads the EP shared library and
// releases the cached kernel registry.
// NOTE(review): the DLL is unloaded before the kernel registry is reset —
// presumably nothing in the registry still points into the DLL at this
// point; confirm, since reversing the order would be safer if it does.
void deinitialize_vitisai_ep() {
  s_library_vitisaiep.Clear();
  s_kernel_registry_vitisaiep.reset();
}
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, co
return tensor_proto.release();
}

// Builds a named INT4 TensorProto with the given shape. `data` carries the
// raw bytes of the tensor (for INT4, presumably two 4-bit values packed per
// byte per the ONNX convention — caller is responsible for the packing).
// Returns an owning raw pointer (released by tensor_proto_delete).
ONNX_NAMESPACE::TensorProto* tensor_proto_new_i4(const std::string& name, const std::vector<int64_t>& shape,
                                                 const std::vector<int8_t>& data) {
  // data.data() is well-defined for an empty vector, unlike &data[0] (UB).
  return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT4,
                          reinterpret_cast<const char*>(data.data()), data.size() * sizeof(data[0]));
}

ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<int8_t>& data) {
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8,
Expand All @@ -108,6 +114,13 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64,
reinterpret_cast<const char*>(&data[0]), data.size() * sizeof(data[0]));
}

// Builds a named UINT4 TensorProto with the given shape. `data` carries the
// raw bytes of the tensor (for UINT4, presumably two 4-bit values packed per
// byte per the ONNX convention — caller is responsible for the packing).
// Returns an owning raw pointer (released by tensor_proto_delete).
ONNX_NAMESPACE::TensorProto* tensor_proto_new_u4(const std::string& name, const std::vector<int64_t>& shape,
                                                 const std::vector<uint8_t>& data) {
  // data.data() is well-defined for an empty vector, unlike &data[0] (UB).
  return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT4,
                          reinterpret_cast<const char*>(data.data()), data.size() * sizeof(data[0]));
}

ONNX_NAMESPACE::TensorProto* tensor_proto_new_u8(const std::string& name, const std::vector<int64_t>& shape,
const std::vector<uint8_t>& data) {
return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_UINT8,
Expand Down
Loading

0 comments on commit 7e7fe68

Please sign in to comment.