diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 765ebab111ac7..b15b9632e9e24 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -7,7 +7,7 @@ include(FindJava) find_package(Java REQUIRED) include(UseJava) -if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android") +if (NOT ANDROID) find_package(JNI REQUIRED) endif() @@ -21,23 +21,28 @@ endif() set(GRADLE_EXECUTABLE "${JAVA_ROOT}/gradlew") +set(COMMON_GRADLE_ARGS --console=plain) +if(WIN32) + list(APPEND COMMON_GRADLE_ARGS -Dorg.gradle.daemon=false) +elseif (ANDROID) + # For Android build, we may run gradle multiple times in same build, + # sometimes gradle JVM will run out of memory if we keep the daemon running + # it is better to not keep a daemon running + list(APPEND COMMON_GRADLE_ARGS --no-daemon) +endif() + # Specify the Java source files file(GLOB_RECURSE onnxruntime4j_gradle_files "${JAVA_ROOT}/*.gradle") file(GLOB_RECURSE onnxruntime4j_src "${JAVA_ROOT}/src/main/java/ai/onnxruntime/*.java") set(JAVA_OUTPUT_JAR ${JAVA_ROOT}/build/libs/onnxruntime.jar) # this jar is solely used to signaling mechanism for dependency management in CMake # if any of the Java sources change, the jar (and generated headers) will be regenerated and the onnxruntime4j_jni target will be rebuilt -set(GRADLE_ARGS --console=plain clean jar -x test) -if(WIN32) - set(GRADLE_ARGS ${GRADLE_ARGS} -Dorg.gradle.daemon=false) -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") - # For Android build, we may run gradle multiple times in same build, - # sometimes gradle JVM will run out of memory if we keep the daemon running - # it is better to not keep a daemon running - set(GRADLE_ARGS ${GRADLE_ARGS} --no-daemon) -endif() +set(GRADLE_ARGS clean jar -x test) -add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_ARGS} WORKING_DIRECTORY ${JAVA_ROOT} DEPENDS ${onnxruntime4j_gradle_files} ${onnxruntime4j_src}) +add_custom_command(OUTPUT ${JAVA_OUTPUT_JAR} + COMMAND ${GRADLE_EXECUTABLE} ${COMMON_GRADLE_ARGS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_ROOT} + DEPENDS ${onnxruntime4j_gradle_files} ${onnxruntime4j_src}) add_custom_target(onnxruntime4j DEPENDS ${JAVA_OUTPUT_JAR}) set_source_files_properties(${JAVA_OUTPUT_JAR} PROPERTIES GENERATED TRUE) set_property(TARGET onnxruntime4j APPEND PROPERTY ADDITIONAL_CLEAN_FILES "${JAVA_OUTPUT_DIR}") @@ -62,7 +67,7 @@ target_link_libraries(onnxruntime4j_jni PUBLIC onnxruntime) set(JAVA_PACKAGE_OUTPUT_DIR ${JAVA_OUTPUT_DIR}/build) file(MAKE_DIRECTORY ${JAVA_PACKAGE_OUTPUT_DIR}) -if (CMAKE_SYSTEM_NAME STREQUAL "Android") +if (ANDROID) set(ANDROID_PACKAGE_OUTPUT_DIR ${JAVA_PACKAGE_OUTPUT_DIR}/android) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_OUTPUT_DIR}) endif() @@ -88,7 +93,7 @@ if(APPLE) elseif(JNI_ARCH STREQUAL "arm64") set(JNI_ARCH aarch64) endif() -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") +elseif (ANDROID) set(JNI_ARCH ${ANDROID_ABI}) elseif (ARM64) set(JNI_ARCH aarch64) @@ -180,15 +185,7 @@ else() endif() # run the build process (this copies the results back into CMAKE_CURRENT_BINARY_DIR) -set(GRADLE_ARGS --console=plain cmakeBuild -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}) -if(WIN32) - set(GRADLE_ARGS ${GRADLE_ARGS} -Dorg.gradle.daemon=false) -elseif (CMAKE_SYSTEM_NAME STREQUAL "Android") - # For Android build, we may run gradle multiple times in same build, - # sometimes gradle JVM will run out of memory if we keep the daemon running - # it is better to not keep a daemon running - set(GRADLE_ARGS ${GRADLE_ARGS} --no-daemon) -endif() +set(GRADLE_ARGS cmakeBuild -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR}) # Append relevant native build flags to gradle command set(GRADLE_ARGS ${GRADLE_ARGS} ${ORT_PROVIDER_FLAGS}) @@ -197,9 +194,11 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() message(STATUS "GRADLE_ARGS: ${GRADLE_ARGS}") -add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${GRADLE_EXECUTABLE} ${GRADLE_ARGS} WORKING_DIRECTORY ${JAVA_ROOT}) +add_custom_command(TARGET onnxruntime4j_jni POST_BUILD + COMMAND ${GRADLE_EXECUTABLE} ${COMMON_GRADLE_ARGS} ${GRADLE_ARGS} + WORKING_DIRECTORY ${JAVA_ROOT}) -if (CMAKE_SYSTEM_NAME STREQUAL "Android") +if (ANDROID) set(ANDROID_PACKAGE_JNILIBS_DIR ${JAVA_OUTPUT_DIR}/android) set(ANDROID_PACKAGE_ABI_DIR ${ANDROID_PACKAGE_JNILIBS_DIR}/${ANDROID_ABI}) file(MAKE_DIRECTORY ${ANDROID_PACKAGE_JNILIBS_DIR}) @@ -214,6 +213,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android") POST_BUILD COMMAND ${CMAKE_COMMAND} -E echo "Generating Android AAR package..." COMMAND ${GRADLE_EXECUTABLE} + ${COMMON_GRADLE_ARGS} build -b build-android.gradle -c settings-android.gradle -DjniLibsDir=${ANDROID_PACKAGE_JNILIBS_DIR} -DbuildDir=${ANDROID_PACKAGE_OUTPUT_DIR} @@ -237,6 +237,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android") POST_BUILD COMMAND ${CMAKE_COMMAND} -E echo "Building and running Android test for Android AAR package..." COMMAND ${GRADLE_EXECUTABLE} + ${COMMON_GRADLE_ARGS} clean assembleDebug assembleDebugAndroidTest -DminSdkVer=${ANDROID_MIN_SDK} --stacktrace diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 159b8654c1cb1..fcb2c7d5de89b 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -66,6 +66,12 @@ "platform": "windows" } ], + "overrides": [ + { + "name": "flatbuffers", + "version": "23.5.26" + } + ], "features": { "tests": { "description": "Build ONNXRuntime unit tests", diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 7b3d04ee66d9e..aaf533135429c 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -8,6 +8,9 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { +namespace logging { +class Logger; +} using KernelCreateMap = std::multimap; using KernelDefHashes = std::vector>; @@ -33,6 +36,7 @@ class KernelRegistry { // Kernel matching uses the types from the node and the kernel_type_str_resolver. Status TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const; // map of type constraint name to required type @@ -42,6 +46,7 @@ class KernelRegistry { // Kernel matching uses the explicit type constraint name to required type map in type_constraints. Status TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; /** @@ -61,13 +66,15 @@ class KernelRegistry { std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; static bool HasImplementationOf(const KernelRegistry& r, const Node& node, ProviderType exec_provider, - const IKernelTypeStrResolver& kernel_type_str_resolver) { + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) { const KernelCreateInfo* info; - Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info); + Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info); return st.IsOK(); } @@ -83,6 +90,7 @@ class KernelRegistry { Status TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; // Check whether the types of inputs/outputs of the given node match the extra diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 6cff153c336f0..31b0f22340510 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -53,6 +53,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); @@ -84,6 +85,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 3963b80de58a4..d035fd34bd072 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -47,8 +47,20 @@ enum COREMLFlags { // and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead. static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits"; static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat"; +// same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes"; static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs"; +// provided by https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property +// Core ML segments the model’s compute graph and specializes each segment for the target compute device. +// This process can affect the model loading time and the prediction latency. +// Use this option to tailor the specialization strategy for your model. +static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy"; +// Profile the Core ML MLComputePlan. +// This logs the hardware each operator is dispatched to and the estimated execution time. +// Intended for developer usage but provide useful diagnostic information if performance is not as expected. +static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan"; +// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu +static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU"; #ifdef __cplusplus extern "C" { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8e881c757f9ac..a35d975ac8f1b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4615,6 +4615,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_ OrtSessionOptions* options, @@ -4632,6 +4634,8 @@ struct OrtApi { * \param[in] num_keys * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, @@ -4645,7 +4649,10 @@ struct OrtApi { * \param[in] mem_info OrtMemoryInfo instance * \param[in] count_or_bytes How many bytes is this scratch buffer * \param[out] out A pointer to the scrach buffer + * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); @@ -4656,6 +4663,8 @@ struct OrtApi { * \param[out] out A pointer to OrtAllocator * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); @@ -4677,6 +4686,8 @@ struct OrtApi { * \param[in] num_external_initializer_files Number of external files * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. */ ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, @@ -4699,6 +4710,8 @@ struct OrtApi { * OrtApi::ReleaseLoraAdapter. * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator, _Outptr_ OrtLoraAdapter** out); @@ -4717,6 +4730,8 @@ struct OrtApi { * OrtApi::ReleaseLoraAdapter. * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator, _Outptr_ OrtLoraAdapter** out); @@ -4738,6 +4753,8 @@ struct OrtApi { * \param[in] adapter OrtLoraAdapter instance * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); @@ -4756,6 +4773,8 @@ struct OrtApi { * \param[in] kv_len Number of elements in the keys and values arrays * * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. */ ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar index 249e5832f090a..e6441136f3d4b 100644 Binary files a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar and b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties index 012d6d90445b4..381baa9cef1ec 100644 --- a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties +++ b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,8 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=cb87f222c5585bd46838ad4db78463a5c5f3d336e5e2b98dc7c0c586527351c2 -distributionUrl=https\://services.gradle.org/distributions/gradle-7.5-bin.zip +distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d +distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip +networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/js/react_native/android/gradlew b/js/react_native/android/gradlew index a69d9cb6c2065..1aa94a4269074 100755 --- a/js/react_native/android/gradlew +++ b/js/react_native/android/gradlew @@ -55,7 +55,7 @@ # Darwin, MinGW, and NonStop. # # (3) This script is generated from the Groovy template -# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt # within the Gradle project. # # You can find Gradle at https://github.com/gradle/gradle/. @@ -80,13 +80,11 @@ do esac done -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit - -APP_NAME="Gradle" +# This is normally unused +# shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum @@ -133,22 +131,29 @@ location of your Java installation." fi else JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then case $MAX_FD in #( max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 MAX_FD=$( ulimit -H -n ) || warn "Could not query maximum file descriptor limit" esac case $MAX_FD in #( '' | soft) :;; #( *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 ulimit -n "$MAX_FD" || warn "Could not set maximum file descriptor limit to $MAX_FD" esac @@ -193,11 +198,15 @@ if "$cygwin" || "$msys" ; then done fi -# Collect all arguments for the java command; -# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of -# shell script including quotes and variable substitutions, so put them in -# double quotes to make sure that they get re-expanded; and -# * put everything else in single quotes, so that it's not re-expanded. + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. set -- \ "-Dorg.gradle.appname=$APP_BASE_NAME" \ diff --git a/js/react_native/android/gradlew.bat b/js/react_native/android/gradlew.bat index f127cfd49d402..25da30dbdeee9 100644 --- a/js/react_native/android/gradlew.bat +++ b/js/react_native/android/gradlew.bat @@ -26,6 +26,7 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -42,11 +43,11 @@ set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -56,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 8f5cdc97f27e5..b67d003eaceeb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -258,7 +258,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches current_length, cpu_state.sequences, parameters->max_length, - decoder_subgraph_.has_decoder_masked_attention_)); + decoder_subgraph_.has_decoder_masked_attention_, + this->cuda_device_prop_ != nullptr)); if (decoder_subgraph_.past_present_share_buffer_) { decoder_fetches.reserve(static_cast(decoder_subgraph_.GetFirstPresentOutputIndex()) + @@ -302,17 +303,24 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches auto cur_len = std::to_string(current_length); dumper->Print("***CurrentLength", cur_len, true); - for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) { + for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) { dumper->Print("decoder_feeds", i, true); dumper->Print("", decoder_feeds[i]); } - auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers; - dumper->Print("past_sequence_length", offset, true); - dumper->Print("", decoder_feeds[offset]); - dumper->Print("beam_width", offset + 1, true); - dumper->Print("", decoder_feeds[offset + 1]); - dumper->Print("cache_redir", offset + 2, true); - dumper->Print("", decoder_feeds[offset + 2]); + for (int i = 0; i < decoder_subgraph_.num_layers; i++) { + int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i; + int self_value_idx = self_key_idx + 1; + dumper->Print("past_key_self", i, true); + dumper->Print("", decoder_feeds[self_key_idx]); + dumper->Print("past_value_self", i + 1, true); + dumper->Print("", decoder_feeds[self_value_idx]); + int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i; + int cross_value_idx = cross_key_idx + 1; + dumper->Print("past_key_cross", i, true); + dumper->Print("", decoder_feeds[cross_key_idx]); + dumper->Print("past_value_cross", i, true); + dumper->Print("", decoder_feeds[cross_value_idx]); + } #endif #ifdef DEBUG_NODE_INPUTS_OUTPUTS diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 30bf3aa0a1212..8145fbd4a4123 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -100,6 +100,7 @@ struct ISequences { virtual gsl::span GetCurrentDeviceSequences() const = 0; // Get all current beam_index sequences in one continuous block (to pass to CUDA) virtual gsl::span GetNextDeviceSequences() = 0; // Get all next beam_index sequences in one continuous block (to pass to CUDA) virtual int GetSequenceLength() const = 0; + virtual int GetMaxLength() const = 0; }; struct ILogitsProcessorList { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 723c271897a78..ecad146da6777 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const { return current_length_; } +int Sequences::GetMaxLength() const { + return max_length_; +} + #ifdef DEBUG_GENERATION void Sequences::PrintSequences(const IConsoleDumper* dumper) const { for (int i = 0; i < batch_beam_size_; i++) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 440a07e14a6cc..7dd1f28d270c7 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -25,6 +25,9 @@ class Sequences : public ISequences { // Returns current sequence length. int GetSequenceLength() const override; + // Returns max sequence length. + int GetMaxLength() const override; + #ifdef DEBUG_GENERATION // Print the sequences to StdOut in debug mode void PrintSequences(const IConsoleDumper* dumper) const; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index d675ba742e03b..7757435990a65 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -31,6 +31,7 @@ Subgraph::Subgraph( allocator_(nullptr), is_output_float16_(false) { num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + used_implicit_inputs = std::vector(num_implicit_inputs, true); auto& subgraph_inputs = subgraph.GetInputs(); auto& subgraph_outputs = subgraph.GetOutputs(); @@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state, // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); - for (auto& entry : node.ImplicitInputDefs()) { - feed_names.push_back(entry->Name()); + const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap(); + + const auto& implicit_input_defs = node.ImplicitInputDefs(); + for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) { + const auto* entry = implicit_input_defs[i]; + int idx; + if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) { + feed_names.push_back(entry->Name()); + } else { + --num_implicit_inputs; + used_implicit_inputs[i] = false; + } } InlinedVector feed_locations; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index bde591626bb83..8ec9c9cbdc20f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -31,6 +31,7 @@ class Subgraph { const GraphViewer& subgraph; // The subgraph int num_implicit_inputs; + std::vector used_implicit_inputs; int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience. int num_subgraph_outputs; // Same as subgraph_output_names.size() diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 9037e58aaf31f..f4e7173c917c1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len, - bool need_cache_indir) { + bool need_cache_indir, + bool use_cuda) { ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); // Allocate subgraph inputs from same device as inputs of encoder subgraph. @@ -171,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); + size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); + size_t total_size_bytes = total_size * sizeof(int); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { @@ -182,19 +184,35 @@ Status T5DecoderSubgraph::CreateInitialFeeds( stream, DeviceCopyDirection::hostToDevice)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + if (use_cuda) { + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + for (int i = 0; i < batch_beam_size; i++) { + size_t batch_beam_stride = static_cast(i) * static_cast(sequences.GetMaxLength()); + int seq_size = sequences.GetSequenceLength(); + gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); + gsl::span temp_input(input_ids_data + static_cast(i) * seq_size, seq_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + sequence, + stream, + DeviceCopyDirection::deviceToDevice)); + } + } else { + const size_t cur_len_bytes = cur_len * sizeof(int); + for (int i = 0; i < batch_beam_size; i++) { + gsl::span sequence = sequences.GetSequence(i); + const int32_t* sequence_data = sequence.data(); + ptrdiff_t seq_index = static_cast(i) * cur_len; + memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes); + } + gsl::span temp_input(input_ids_data, total_size); + gsl::span temp_sequence(seq_copy_ptr, total_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + temp_sequence, + stream, + DeviceCopyDirection::hostToDevice)); } - gsl::span temp_input(input_ids_data, total_size); - gsl::span temp_sequence(seq_copy_ptr, total_size); - ORT_RETURN_IF_ERROR(device_copy_int32_func( - temp_input, - temp_sequence, - stream, - DeviceCopyDirection::hostToDevice)); } // The ordering is the same as used in Setup. @@ -230,7 +248,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } else { ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, @@ -238,7 +256,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } decoder_feeds.push_back(expanded_hidden_states); @@ -281,8 +299,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds( } // Pass through implicit inputs. - for (const auto* entry : implicit_inputs) { - decoder_feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + decoder_feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 83dae49c7dcbd..a72ce37a93aba 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph { int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len = -1, - bool need_cache_indir = false); + bool need_cache_indir = false, + bool use_cuda = false); Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc index 51473c0c931b9..d59db4afac2c2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds( pinned_allocator, location)); - for (const auto* entry : implicit_inputs) { - feeds.push_back(*entry); + for (size_t i = 0; i < implicit_inputs.size(); ++i) { + const auto* entry = implicit_inputs[i]; + if (used_implicit_inputs[i]) { + feeds.push_back(*entry); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e047bd948434d..4e65336665bf7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds( CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, cuda_stream)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - CUDA_RETURN_IF_ERROR( - cudaMemcpyAsync(input_ids_data + static_cast(i) * current_length, - sequence_data, - current_length * sizeof(int32_t), - cudaMemcpyHostToDevice, - cuda_stream)); - } + // We expect sequences to point directly to device memory + int max_length = sequences.GetMaxLength(); + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + CUDA_RETURN_IF_ERROR( + cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t), + sequences_buffer.data(), max_length * sizeof(int32_t), + current_length * sizeof(int32_t), batch_beam_size, + cudaMemcpyDeviceToDevice, cuda_stream)); } next_inputs[0] = input_ids; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 4e2429c7c2964..a567af4f9b076 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -341,6 +341,7 @@ const TILE_SIZE : u32 = 16u; const VALUES_PER_VEC4 : u32 = 4u; const QUANTIZATION_BLOCK_SIZE : u32 = 32; const A_REPEAT : u32 = 8u; + // We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM, // so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time. // This uses all 16 lanes on 12th gen intel chips. @@ -525,7 +526,6 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; constexpr uint32_t output_number = 1; MatMulNBitsProgram program{output_number, gsl::narrow(components_b), has_zero_points, use_block32}; - if (use_block32) { components = 1; constexpr uint32_t workgroup_size = 128; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 5dca4cf6c165b..ecd3960107926 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -138,7 +138,8 @@ class PlannerImpl { const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps, const InlinedHashMap& outer_scope_node_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, - const ISequentialPlannerContext& context, SequentialExecutionPlan& plan) + const ISequentialPlannerContext& context, SequentialExecutionPlan& plan, + const logging::Logger& logger) : context_(&context), plan_(plan), parent_node_(parent_node), @@ -148,14 +149,15 @@ class PlannerImpl { kernel_create_info_map_(kernel_create_info_map), subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps), outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map), - ort_value_name_idx_map_(ort_value_name_idx_map) {} + ort_value_name_idx_map_(ort_value_name_idx_map), + logger_(logger) { + } Status CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger); + const PathString& partition_config_file); private: gsl::not_null context_; @@ -183,6 +185,12 @@ class PlannerImpl { InlinedHashMap> dependence_graph_; InlinedHashMap value_node_map_; + // logger_ is not currently used in a minimal build +#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) + [[maybe_unused]] +#endif + const logging::Logger& logger_; + // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: struct OrtValueInfo { const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue @@ -213,6 +221,7 @@ class PlannerImpl { FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point) : ml_value(ort_value), deallocate_point(dealloc_point) {} }; + // freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when // they became free (more recently freed earlier in the list). std::list freelist_; @@ -225,7 +234,8 @@ class PlannerImpl { } int& UseCount(OrtValueIndex n) { - ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size()); + ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), + "invalid value index: ", n, " against size ", ort_value_info_.size()); return ort_value_info_[n].usecount; } int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); } @@ -335,9 +345,9 @@ class PlannerImpl { // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -361,9 +371,9 @@ class PlannerImpl { // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -397,8 +407,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -448,8 +458,8 @@ class PlannerImpl { return true; } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " - << producer_node->Name() << " as it has external outputs."; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " + << producer_node->Name() << " as it has external outputs."; #endif } } @@ -1198,9 +1208,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1241,9 +1251,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1290,8 +1300,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -1869,8 +1879,7 @@ class PlannerImpl { } #ifndef ORT_ENABLE_STREAM - void PartitionIntoStreams(const logging::Logger& /*logger*/, - const ExecutionProviders& /*execution_providers*/, + void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/, const PathString& /*partition_config_file*/) { if (graph_viewer_.NumberOfNodes() > 0) { stream_nodes_.push_back({}); @@ -1915,11 +1924,11 @@ class PlannerImpl { #else - void - PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, - const PathString& partition_config_file) { - auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file); - auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder()); + void PartitionIntoStreams(const ExecutionProviders& execution_providers, + const PathString& partition_config_file) { + auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file); + auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, + context_->GetExecutionOrder()); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); plan_.node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); for (size_t i = 0; i < stream_nodes_.size(); ++i) { @@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger) { + const PathString& partition_config_file) { // 1. partition graph into streams - PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file); + PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file); // 2. initialize the plan based on stream partition result int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; @@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan( PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_create_info_map, subgraphs_kernel_create_info_maps, outer_scope_node_arg_to_location_map, - ort_value_name_idx_map, context, *plan); + ort_value_name_idx_map, context, *plan, logger); return planner.CreatePlan( #ifdef ORT_ENABLE_STREAM stream_handle_registry, #endif - partition_config_file, - logger); + partition_config_file); } #ifdef ORT_ENABLE_STREAM diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index ef68b88187e08..1eb7420b44d2c 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { + gsl::span tentative_nodes, + const logging::Logger& logger) { // automatic conversion from const std::vector& const auto& ordered_nodes = graph.GetNodesInTopologicalOrder(); InlinedVector node_id_to_order_map(graph.MaxNodeIndex()); @@ -83,7 +84,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { candidates.push(consumer_node->Index()); - LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); + LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); } } return Status::OK(); @@ -159,9 +160,9 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe if (place_in_cpu) { cpu_nodes.insert(cur); - LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name() - << " because the CPU execution path is deemed faster than overhead involved with execution on other EPs " - << " capable of executing this node"; + LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name() + << " because the CPU execution path is deemed faster than overhead involved with execution " + "on other EPs capable of executing this node"; for (auto* output : node->OutputDefs()) { cpu_output_args.insert(output); } diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index c5bcd22888b7c..bca75adbfd5a7 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -9,6 +9,9 @@ #include "core/graph/graph_viewer.h" namespace onnxruntime { +namespace logging { +class Logger; +} /** Returns a list of nodes that are preferred on CPU. @@ -19,6 +22,7 @@ namespace onnxruntime { */ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); } // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 6174122cf3cb4..406fc1b15effc 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep, }; } // namespace -static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { +static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) { auto& current_ep = params.current_ep.get(); const auto& ep_type = current_ep.Type(); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) { - LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " + LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " "the layout transformer."; return Status::OK(); } @@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); @@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, const IExecutionProvider& current_ep, + const logging::Logger& logger, std::vector>& capabilities) { const auto& ep_type = current_ep.Type(); const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; // TODO: Provide EP with a capability to look inside the functions. capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); @@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, GraphPartitioner::Mode mode, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, - const layout_transformation::DebugGraphFn& debug_graph_fn) { + const layout_transformation::DebugGraphFn& debug_graph_fn, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn)); + transform_layout_fn, debug_graph_fn, logger)); } } @@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(transform_layout_fn), std::cref(debug_graph_fn)}; - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach - if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) { + if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { nodes_to_compile.push_back(n); capabilities_to_compile.push_back(std::move(capability)); } else { @@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, Graph& graph, + const logging::Logger& logger, InlinedHashSet& not_inlined, size_t& inlined_count) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. @@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, *subgraph, + logger, not_inlined, inlined_count)); } @@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities)); + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, + capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; if (nodes.size() == 1) { @@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { bool modified_graph = false; auto& graph = partition_params.graph.get(); @@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, - partition_params.debug_graph_fn)); + partition_params.debug_graph_fn, + logger)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, - IExecutionProvider& current_ep) { + IExecutionProvider& current_ep, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability auto& graph = partition_params.graph.get(); @@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param auto& subgraph = *entry.second; PartitionParams subgraph_partition_params = partition_params; subgraph_partition_params.graph = std::ref(subgraph); - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep, + logger)); } } @@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param }; // clang-format on - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param // Simplified partitioning where custom EPs may produce compiled nodes. static Status PartitionOrtFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { // process full graph with each EP for (const auto& ep : execution_providers) { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger)); } return Status::OK(); @@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, graph, + logger, not_inlined, inlined_count)); @@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger)); bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); @@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger)); } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h index 0dd17d2f4a624..fac43bad0fefb 100644 --- a/onnxruntime/core/framework/kernel_lookup.h +++ b/onnxruntime/core/framework/kernel_lookup.h @@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { public: KernelLookup(ProviderType provider_type, gsl::span> kernel_registries, - const IKernelTypeStrResolver& kernel_type_str_resolver) + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) : provider_type_{provider_type}, kernel_registries_{kernel_registries}, - kernel_type_str_resolver_{kernel_type_str_resolver} { + kernel_type_str_resolver_{kernel_type_str_resolver}, + logger_{logger} { ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified."); } const KernelCreateInfo* LookUpKernel(const Node& node) const override { const KernelCreateInfo* kernel_create_info{}; for (const auto& registry : kernel_registries_) { - const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, + const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return kernel_create_info; @@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { ProviderType provider_type_; const gsl::span> kernel_registries_; const IKernelTypeStrResolver& kernel_type_str_resolver_; + const logging::Logger& logger_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index cbbe0f86b8b7e..8602a3b4004ff 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -183,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { const auto& node_provider = node.GetExecutionProviderType(); const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider); @@ -215,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } @@ -224,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out); + return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out); } Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out); + return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out); } static bool KernelDefCompatible(int version, const KernelDef& kernel_def, @@ -261,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider)); if (out) *out = nullptr; @@ -289,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index f8ccdb8fb0238..721353854a474 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptrTryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if (status.IsOK()) { return status; } @@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, } if (p != nullptr) { - status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if (status.IsOK()) { return status; } @@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); } -bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) { +bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, + const Node& node, + const std::string& provider_type, + const logging::Logger& logger) { const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type); return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) { - return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver()); + return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(), + logger); }); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 1da73208cb536..72f0ed3c6268a 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -67,13 +67,14 @@ class KernelRegistryManager { // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done - Status SearchKernelRegistry(const Node& node, + Status SearchKernelRegistry(const Node& node, const logging::Logger& logger, /*out*/ const KernelCreateInfo** kernel_create_info) const; /** * Whether this node can be run on this provider */ - static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type); + static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type, + const logging::Logger& logger); Status CreateKernel(const Node& node, const IExecutionProvider& execution_provider, diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0d0b22ff61e01..0ac2271ba09f1 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -178,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne bool saving_ort_format) { for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; - auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); if (!status.IsOK() && saving_ort_format) { // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. // in that case we assigned the node to that EP but do not compile it into a fused node. @@ -187,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. // if that's not possible for some reason we can fallback to the CPU EP implementation. node.SetExecutionProviderType(kCpuExecutionProvider); - status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); } ORT_RETURN_IF_ERROR(status); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 6f1f1c831d191..5a3cd86b04492 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -9,7 +9,7 @@ #include "core/graph/constants.h" #include "core/graph/contrib_ops/contrib_defs.h" #include "core/graph/contrib_ops/shape_inference_functions.h" -#include "onnx/onnx-ml.pb.h" // ? +#include "core/graph/onnx_protobuf.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build @@ -23,7 +23,7 @@ void convTransposeShapeInference(InferenceContext& ctx); void convPoolShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape, int input1Idx, int input2Idx); namespace defs::math::utils { - void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); +void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); } } // namespace ONNX_NAMESPACE @@ -822,10 +822,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } } - if (all_lengths_known) { - output_shape->mutable_dim(axis)->set_dim_value(total_length); - } - })); + if (all_lengths_known) { + output_shape->mutable_dim(axis)->set_dim_value(total_length); + } + })); ONNX_MS_OPERATOR_SET_SCHEMA(QLinearWhere, 1, OpSchema() .SetDoc("Return elements, either from X or Y, depending on condition.") @@ -955,7 +955,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::INT, static_cast(0)) .Attr("do_rotary", "Whether to use rotary position embedding. Default value is 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is " + .Attr("past_present_share_buffer", + "Corresponding past and present are same tensor, its shape is " "(2, batch_size, num_heads, max_sequence_length, head_size)", AttributeProto::INT, OPTIONAL_VALUE) .Attr("mask_filter_value", diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 1466de51d0b99..e755b4bfa6364 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, #if !defined(DISABLE_SPARSE_TENSORS) // Create execution frame for executing constant nodes. OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - is_sparse_initializer_check); + is_sparse_initializer_check, logger); #else // Create execution frame for executing constant nodes. - OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + {node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; }, + logger); #endif std::vector fetch_mlvalue_idxs; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 2f2524420dc44..ba2b87b5aa0ca 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -190,6 +190,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -404,7 +405,8 @@ InlinedVector> GenerateTransformers( } auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } @@ -437,6 +439,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -490,7 +493,8 @@ InlinedVector> GenerateTransformersForMinimalB #ifndef DISABLE_CONTRIB_OPS AllocatorPtr cpu_allocator = std::make_shared(); auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 67ebc22dab41d..b1665c7172549 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // going to a node that will need a Cast. // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. -static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, + const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), // does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node), @@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const KernelCreateInfo* kernel_create_info{}; const auto lookup_status = cpu_kernel_registry.TryFindKernel( kCpuExecutionProvider, node.OpType(), node.Domain(), - node.SinceVersion(), type_constraint_map, &kernel_create_info); + node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return true; } @@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: return false; } -static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) { // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 node.SetExecutionProviderType(""); } @@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { return dst_bit_length <= src_bit_length; } - if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || + (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { return true; } @@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_)); + ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger)); GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index ee79fa620374e..cd654991c92d5 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -44,7 +44,9 @@ NhwcConvLookup( return &(iter->second); } -NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept +NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, + std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) { if (!cpu_kernel_registry) { // This is a CPU op nodes optimizer, not useful if cpu EP is not available. @@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_, - qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info); + qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_, - qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info); + qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, - nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info); + nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, - nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info); + nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -140,7 +142,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, - nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, - nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( diff --git a/onnxruntime/core/optimizer/nhwc_transformer.h b/onnxruntime/core/optimizer/nhwc_transformer.h index 000732060b889..c65f851fdab9d 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.h +++ b/onnxruntime/core/optimizer/nhwc_transformer.h @@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed. class NhwcTransformer : public GraphTransformer { private: public: - explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept; + explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept; /** * @brief Usually called right after constructor, it shows whether diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index ed7d5feb2beb3..b2e8e491c361c 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& /* model_path */, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const { std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; - return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out); + return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out); } static Status TryCreateKernel(const Node& node, @@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node, FuncManager& funcs_mgr, const DataTransferManager& data_transfer_mgr, const ConfigOptions& config_options, + const logging::Logger& logger, /*out*/ std::unique_ptr& op_kernel) { const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; const KernelCreateInfo* kernel_create_info = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, + ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger, &kernel_create_info)); static const AllocatorMap dummy_allocators; @@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); FuncManager func; auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_, - ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, + ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_, op_kernel); // Kernel found in the CPU kernel registry diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index b0f7f461661b5..24a23312feba9 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame { const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); ~Info() = default; @@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame { std::unique_ptr node_index_info_; const IExecutionProvider& execution_provider_; const std::function& is_sparse_initializer_func_; + const logging::Logger& logger_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info); }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 18e462c04dff3..5538aa54801cc 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion( return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end(); } -static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { +static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) { constexpr size_t w_idx = 1; constexpr size_t w_zp_idx = 9; constexpr size_t r_idx = 2; @@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) || !graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) || r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " cannot locate recurrence tensor of const int8 type," << " int8 overflow might impact precision !"; return false; @@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) || !graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) || r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " unable to locate recurrence tensor or its zero point value," << " int8 overflow might impact precision !"; return false; @@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int if (graph_utils::IsSupportedOptypeVersionAndDomain( op_node, "DynamicQuantizeLSTM", {1}, kMSDomain)) { // This one has two set of quantized arguments - modified |= TryConvertDynamicQuantizeLSTM(op_node, graph); + modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger); continue; // go on to next operator node } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index d2240b5d50194..81305f7effa16 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -291,7 +291,8 @@ SelectorManager::SelectorManager() { InitializeSelectorsMap(); } -std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const { +std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer, + const logging::Logger& logger) const { std::vector qdq_selections; for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { const auto* node = graph_viewer.GetNode(index); @@ -313,7 +314,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) { - LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType(); + LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType(); continue; } } @@ -329,7 +330,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap } std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer) { +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) { std::vector> node_unit_holder; std::unordered_map node_unit_map; @@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) { // Get QDQ NodeUnits first QDQ::SelectorManager selector_mgr; - const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger); for (const auto& qdq_selection : qdq_selections) { auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index f388206551172..ccc1844e3e985 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -15,7 +15,9 @@ #endif namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; class Node; @@ -65,7 +67,7 @@ class SelectorManager { // Methods that finds and returns a vector of QDQ::NodeGroup in a given graph // Can be used in QDQ support in different EPs - std::vector GetQDQSelections(const GraphViewer& graph_viewer) const; + std::vector GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const; private: Selectors qdq_selectors_; @@ -88,7 +90,7 @@ class SelectorManager { // We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer // library whereas it should be able to be used by an EP with no dependency on optimizers. std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 6a5a85ce0ff31..8c0136c495403 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -17,13 +17,22 @@ class TransformerMemcpyImpl { TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) : graph_(graph), provider_(provider) {} - bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); + bool ModifyGraph(const KernelRegistryManager& schema_registries, + const logging::Logger& logger, + int& copy_node_counter); private: - void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); - void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); + void ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); + void BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger); void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); - bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); + bool ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl); @@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // find defs that require copy for (auto& node : graph_.Nodes()) { // as we process the defs, collect all the initializers consumed at the current graph level - ProcessDefs(node, kernel_registries, initializers_consumed); + ProcessDefs(node, kernel_registries, initializers_consumed, logger); } // for initializers shared by different providers, create dups - if (ProcessInitializers(kernel_registries, initializers_consumed)) + if (ProcessInitializers(kernel_registries, initializers_consumed, logger)) modified = true; for (auto arg : graph_.GetInputs()) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_input_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_output_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : graph_.GetInputs()) // For inputs we need to create a copy node only when the input is connected to both provider @@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } -void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, - InitializedTensorSet& initializers_consumed) { +void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& logger) { auto node_provider_type = node.GetExecutionProviderType(); if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || @@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci)); bool is_implicit_input = false; auto process_inputs = @@ -278,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg } // for non_provider defs, collect the nodes that expect it is provider tensor as input/output. -void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) { +void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger) { for (auto& it : graph_.Nodes()) { if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue; auto input_it = @@ -296,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci)); if (arg_input_index != -1) { if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } @@ -351,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co // We duplicate any initializer that is used by both provider nodes and non-provider nodes // to ensure that provider nodes and non-provider nodes don't share initializers, as they // need to stay in different memory locations. -bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) { +bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger) { std::map replacements; for (const auto& pair : initializers_consumed) { const auto& name = pair.first; @@ -383,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker auto dup_replacements = replacements; const KernelCreateInfo* kci = nullptr; - auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci); + auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); if (kci == nullptr) continue; if (kci->kernel_def == nullptr) continue; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index a799ed743ef52..f954baf3eabae 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node); if (cann_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() - << " node name: " << node.Name(); + LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() + << " node name: " << node.Name(); continue; } candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger()); for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) continue; diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index cc68fa6ec399a..442194cb31cbc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -151,7 +151,7 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu return false; } -#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) +#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64 // To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index f161b309a2425..d533b867bd454 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -133,9 +133,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInpu return false; } -#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) - // to pass https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1563483&view=logs&j=f7cc61a9-cc70-56e7-b06c-4668ca17e426 - // ReductionOpTest.ReduceSum_half_bert +#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64 + // skip ReductionOpTest.ReduceSum_half_bert because reduce_sum will output all zeros int32_t input_type; GetType(*input_defs[0], input_type, logger); if (node.OpType() == "ReduceSum" && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index c8df7c1a43f65..a1b3a18265c70 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -13,6 +13,10 @@ #include "core/optimizer/initializer.h" #include "core/providers/cpu/tensor/unsqueeze.h" +#ifdef __APPLE__ +#include +#endif + namespace onnxruntime { namespace coreml { @@ -54,32 +58,50 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const } } +#if defined(COREML_ENABLE_MLPROGRAM) +void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder, + const Node& node, const logging::Logger& logger) { + const auto& input_defs(node.InputDefs()); + TensorShapeVector axes; + GetAxes(model_builder, node, axes); + + std::vector input_shape; + GetShape(*input_defs[0], input_shape, logger); + auto op = model_builder.CreateOperation(node, "reshape"); + AddOperationInput(*op, "x", input_defs[0]->Name()); + TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes); + AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape))); + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); +} +#endif + Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, [[maybe_unused]] const logging::Logger& logger) const { std::unique_ptr layer = model_builder.CreateNNLayer(node); - const auto& input_defs(node.InputDefs()); auto* coreml_squeeze = layer->mutable_squeeze(); TensorShapeVector axes; GetAxes(model_builder, node, axes); - std::vector input_shape; - GetShape(*input_defs[0], input_shape, logger); #if defined(COREML_ENABLE_MLPROGRAM) + const auto& input_defs(node.InputDefs()); if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; - std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "reshape"; +#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64 + // expand_dims has limited requirements for static shape, however, X86_64 has a bug that it can't handle scalar input + if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) { + HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger); + return Status::OK(); + } +#endif + std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "expand_dims"; std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); AddOperationInput(*op, "x", input_defs[0]->Name()); - if (coreml_op_type == "squeeze") { - if (!axes.empty()) { - // coreml squeeze op does support negative axes - AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes))); - } - } else { - TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes); - AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape))); + if (!axes.empty()) { + // coreml supports negative axes + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes))); } AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 2a02c1f4124f6..6486942199df7 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -408,7 +408,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge : graph_viewer_(graph_viewer), logger_(logger), coreml_version_(coreml_version), - coreml_compute_unit_(coreml_options.ComputeUnits()), + coreml_options_(coreml_options), create_ml_program_(coreml_options.CreateMLProgram()), model_output_path_(GetModelOutputPath(create_ml_program_)), onnx_input_names_(std::move(onnx_input_names)), @@ -989,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { get_sanitized_io_info(std::move(input_output_info_)), std::move(scalar_outputs_), std::move(int64_outputs_), - logger_, coreml_compute_unit_); + logger_, coreml_options_); } else #endif { @@ -999,7 +999,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { std::move(input_output_info_), std::move(scalar_outputs_), std::move(int64_outputs_), - logger_, coreml_compute_unit_); + logger_, coreml_options_); } return model->LoadModel(); // load using CoreML API, including compilation diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index af47869f7e1c3..e19597cf0dc2e 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -7,6 +7,7 @@ #include "core/graph/graph_viewer.h" #include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/model/model.h" +#include "core/providers/coreml/coreml_options.h" #if defined(COREML_ENABLE_MLPROGRAM) // coremltools classes @@ -22,8 +23,6 @@ class StorageWriter; #endif namespace onnxruntime { -class CoreMLOptions; - namespace coreml { class IOpBuilder; @@ -218,7 +217,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const int32_t coreml_version_; - const uint32_t coreml_compute_unit_; + CoreMLOptions coreml_options_; const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel diff --git a/onnxruntime/core/providers/coreml/coreml_options.cc b/onnxruntime/core/providers/coreml/coreml_options.cc index df78f74383871..4ec780208e528 100644 --- a/onnxruntime/core/providers/coreml/coreml_options.cc +++ b/onnxruntime/core/providers/coreml/coreml_options.cc @@ -63,11 +63,14 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option {"MLProgram", COREML_FLAG_CREATE_MLPROGRAM}, {"NeuralNetwork", COREML_FLAG_USE_NONE}, }; - std::unordered_set valid_options = { + const std::unordered_set valid_options = { kCoremlProviderOption_MLComputeUnits, kCoremlProviderOption_ModelFormat, kCoremlProviderOption_RequireStaticInputShapes, kCoremlProviderOption_EnableOnSubgraphs, + kCoremlProviderOption_SpecializationStrategy, + kCoremlProviderOption_ProfileComputePlan, + kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU, }; // Validate the options for (const auto& option : options) { @@ -90,6 +93,16 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option require_static_shape_ = option.second == "1"; } else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) { enable_on_subgraph_ = option.second == "1"; + } else if (kCoremlProviderOption_SpecializationStrategy == option.first) { + if (option.second != "Default" && option.second != "FastPrediction") { + ORT_THROW("Invalid value for option ", option.first, ": ", option.second, + ". Valid values are Default and FastPrediction."); + } + strategy_ = option.second; + } else if (kCoremlProviderOption_ProfileComputePlan == option.first) { + profile_compute_plan_ = option.second == "1"; + } else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) { + allow_low_precision_accumulation_on_gpu_ = option.second == "1"; } } } diff --git a/onnxruntime/core/providers/coreml/coreml_options.h b/onnxruntime/core/providers/coreml/coreml_options.h index 8bb748fcd69c9..fd05c96927bd1 100644 --- a/onnxruntime/core/providers/coreml/coreml_options.h +++ b/onnxruntime/core/providers/coreml/coreml_options.h @@ -14,6 +14,9 @@ class CoreMLOptions { bool create_mlprogram_{false}; bool enable_on_subgraph_{false}; uint32_t compute_units_{0}; + std::string strategy_; + bool profile_compute_plan_{false}; + bool allow_low_precision_accumulation_on_gpu_{false}; public: explicit CoreMLOptions(uint32_t coreml_flags); @@ -25,6 +28,9 @@ class CoreMLOptions { bool CreateMLProgram() const { return create_mlprogram_; } bool EnableOnSubgraph() const { return enable_on_subgraph_; } uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; } + bool AllowLowPrecisionAccumulationOnGPU() const { return allow_low_precision_accumulation_on_gpu_; } + bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; } + bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; } private: void ValidateAndParseProviderOption(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 68ecbe5fb80c4..84b7d741b4714 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -18,6 +18,7 @@ #endif namespace onnxruntime { +class CoreMLOptions; namespace coreml { class Execution; @@ -53,7 +54,7 @@ class Model { std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, - const logging::Logger& logger, uint32_t coreml_compute_unit); + const logging::Logger& logger, const CoreMLOptions& coreml_options); ~Model(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index c8edb64ff55d7..755dbfbd6e68c 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -25,6 +25,7 @@ #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/model/objc_str_utils.h" #include "core/providers/coreml/shape_utils.h" +#include "core/providers/coreml/coreml_options.h" // force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need // to manually do this @@ -300,6 +301,53 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, return Status::OK(); } +// since __clang_major__ >= 15, MLComputePlan is introduced in +// We are actually ensure the MacOS/IOS version and Xcode version is greater than `macOS 14.4, iOS 17.4`. +// The macro API_AVAILABLE should also be fine. +// Otherwise, the compiler will complain `MLComputePlan` is not defined. +// we define __clang_analyzer__ here is for bypass static analysis +void ProfileComputePlan(NSURL* compileUrl, MLModelConfiguration* config) { +#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__) + if (@available(macOS 14.4, iOS 17.4, *)) { + [MLComputePlan loadContentsOfURL:compileUrl + configuration:config + completionHandler:^(MLComputePlan* _Nullable computePlan, NSError* _Nullable error) { + if (!computePlan) { + NSLog(@"Error loading compute plan: %@", error); + // Handle error. + return; + } + MLModelStructureProgram* program = computePlan.modelStructure.program; + if (!program) { + NSLog(@"Error loading program from compute plan., this is not a mlprogram model"); + return; + } + + MLModelStructureProgramFunction* mainFunction = program.functions[@"main"]; + if (!mainFunction) { + NSLog(@"Error loading main function from program"); + return; + } + + NSArray* operations = mainFunction.block.operations; + NSLog(@"Number of operations, 'const' node is included. : %lu", operations.count); + for (MLModelStructureProgramOperation* operation in operations) { + // Get the compute device usage for the operation. + MLComputePlanDeviceUsage* computeDeviceUsage = [computePlan computeDeviceUsageForMLProgramOperation:operation]; + id preferredDevice = computeDeviceUsage.preferredComputeDevice; + // Get the estimated cost of executing the operation. + MLComputePlanCost* estimatedCost = [computePlan estimatedCostOfMLProgramOperation:operation]; + if (![operation.operatorName isEqualToString:@"const"]) { + NSLog(@"Operation: %@, Device Usage: %@, Estimated Cost: %f", operation.operatorName, preferredDevice, estimatedCost.weight); + } + } + }]; + } else { + NSLog(@"iOS 17.4+/macOS 14.4+ or later is required to use the compute plan API"); + } +#endif +} + // Internal Execution class // This class is part of the model class and handles the calls into CoreML. Specifically, it performs // 1. Compile the model by given path for execution @@ -307,7 +355,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, // 3. The compiled model will be removed in dealloc or removed using cleanup function class Execution { public: - Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); + Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options); ~Execution(); Status LoadModel(); @@ -320,13 +368,13 @@ Status Predict(const std::unordered_map& inputs, NSString* coreml_model_path_{nil}; NSString* compiled_model_path_{nil}; const logging::Logger& logger_; - uint32_t coreml_compute_unit_{0}; + CoreMLOptions coreml_options_; MLModel* model_{nil}; }; -Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_compute_unit) +Execution::Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options) : logger_(logger), - coreml_compute_unit_(coreml_compute_unit) { + coreml_options_(coreml_options) { @autoreleasepool { coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); } @@ -395,17 +443,41 @@ Status Predict(const std::unordered_map& inputs, compiled_model_path_ = [compileUrl path]; MLModelConfiguration* config = [[MLModelConfiguration alloc] init]; - - if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_ONLY) { + uint32_t coreml_compute_unit = coreml_options_.ComputeUnits(); + if (coreml_compute_unit & COREML_FLAG_USE_CPU_ONLY) { config.computeUnits = MLComputeUnitsCPUOnly; - } else if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_AND_GPU) { + } else if (coreml_compute_unit & COREML_FLAG_USE_CPU_AND_GPU) { config.computeUnits = MLComputeUnitsCPUAndGPU; - } else if (coreml_compute_unit_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) { + } else if (coreml_compute_unit & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) { config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine } else { config.computeUnits = MLComputeUnitsAll; } + if (coreml_options_.AllowLowPrecisionAccumulationOnGPU()) { + config.allowLowPrecisionAccumulationOnGPU = YES; + } + +// Set the specialization strategy to FastPrediction for macOS 10.15+ +// since __clang_major__ >= 15, optimizationHints is introduced in +// Same as above comments for why we are checking __clang_major__. +// we define __clang_analyzer__ here is for bypass static analysis +#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__) + if (HAS_COREML8_OR_LATER) { + MLOptimizationHints* optimizationHints = [[MLOptimizationHints alloc] init]; + if (coreml_options_.UseStrategy("FastPrediction")) { + optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction; + config.optimizationHints = optimizationHints; + } else if (coreml_options_.UseStrategy("Default")) { + optimizationHints.specializationStrategy = MLSpecializationStrategyDefault; + config.optimizationHints = optimizationHints; + } + } +#endif + if (coreml_options_.ProfileComputePlan()) { + ProfileComputePlan(compileUrl, config); + } + model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; if (error != nil || model_ == nil) { @@ -524,8 +596,8 @@ Status Predict(const std::unordered_map& inputs, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& logger, - uint32_t coreml_flags) - : execution_(std::make_unique(path, logger, coreml_flags)), + const CoreMLOptions& coreml_options) + : execution_(std::make_unique(path, logger, coreml_options)), model_input_names_(std::move(model_input_names)), model_output_names_(std::move(model_output_names)), input_output_info_(std::move(input_output_info)), diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc index c6f2e7401ea1e..e9036e2fc7e1a 100644 --- a/onnxruntime/core/providers/coreml/model/model_stub.cc +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -4,6 +4,7 @@ #include "core/providers/coreml/model/model.h" namespace onnxruntime { +class CoreMLOptions; namespace coreml { class Execution {}; @@ -15,7 +16,7 @@ Model::Model(const std::string& /*path*/, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& /*logger*/, - uint32_t /*coreml_flags*/) + const CoreMLOptions& /*coreml_flags*/) : execution_(std::make_unique()), model_input_names_(std::move(model_input_names)), model_output_names_(std::move(model_output_names)), diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8396e2629d2bf..d4013a7dc3d57 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2693,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 35a2c451a49a5..9f95818501dac 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -62,7 +62,8 @@ namespace Dml const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; const auto kernel_lookup = onnxruntime::KernelLookup{provider_type, gsl::make_span(®istry, 1), - kernel_type_str_resolver}; + kernel_type_str_resolver, + logger}; std::vector> compiledPartitionInfos; std::vector additionalSplittingNodes; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index 6318b0d5e2865..b9b90d6bc17bd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -54,7 +54,8 @@ namespace Dml const auto kernelLookup = onnxruntime::KernelLookup( providerType, gsl::make_span(®istry, 1), - kernelTypeStrResolver); + kernelTypeStrResolver, + logger); onnxruntime::GraphViewer graphViewer(graph); const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 228dfeb123175..826f48b5f7a68 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -95,7 +95,7 @@ namespace Dml const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_lookup); + return m_impl->GetCapability(graph, kernel_lookup, *GetLogger()); #else return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup); #endif @@ -876,7 +876,8 @@ namespace Dml std::vector> ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger) const { uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE. @@ -900,7 +901,7 @@ namespace Dml } // Get the list of nodes that should stay on the CPU - auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes); + auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger); for (size_t nodeIndex : toplogicalOrder) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 32a5b9add35a0..e7d859c5764de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -88,7 +88,8 @@ namespace Dml std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger ) const; uint32_t GetSupportedDeviceDataTypeMask() const; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 4cb40ec8bf5fd..c1a8b373bed84 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -818,7 +818,7 @@ std::vector> JsExecutionProvider::GetCapabili candidates.push_back(node.Index()); tenative_candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 12416ea0c121b..e4bee6f959a01 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -32,8 +32,16 @@ namespace nnapi { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, gsl::span nnapi_target_devices, - TargetDeviceOption target_device_option) - : nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) { + TargetDeviceOption target_device_option, + const logging::Logger& logger) + : nnapi_(nnapi_handle), + graph_viewer_(graph_viewer), + nnapi_model_{std::make_unique(nnapi_handle)}, + shaper_{graph_viewer}, + nnapi_target_devices_(nnapi_target_devices), + target_device_option_(target_device_option), + nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)), + logger_(logger) { nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_; } @@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { } void ModelBuilder::PreprocessNodeUnits() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); } // Help to get all quantized operators' input and the NodeUnit(s) using the input diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index b2118150dd304..4db335afa98b0 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -14,7 +14,9 @@ struct NnApi; namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; enum class DataLayout; class NodeUnit; @@ -31,7 +33,8 @@ class ModelBuilder { using Shape = Shaper::Shape; ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, - gsl::span nnapi_target_devices, TargetDeviceOption target_device_option); + gsl::span nnapi_target_devices, TargetDeviceOption target_device_option, + const logging::Logger& logger); common::Status Compile(std::unique_ptr& model); @@ -173,6 +176,9 @@ class ModelBuilder { // <1,1> <1,2> <1,3> InlinedVector> operations_recorder_; #endif + + const logging::Logger& logger_; + // Convert the ONNX model to ANeuralNetworksModel common::Status Prepare(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index fca52396a190c..f92c9592742d5 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; // TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators) @@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL; #endif }(); - LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; + LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; const nnapi::OpSupportCheckParams params{ android_feature_level, @@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view }; if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) { - LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" + LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" << params.android_feature_level << "] is lower than minimal supported NNAPI API feature level [" << ORT_NNAPI_MIN_API_LEVEL @@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view node_unit_supported_result[node_unit] = supported; } - LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported + LOGS(logger, VERBOSE) << "Node supported: [" << supported << "] Operator type: [" << node.OpType() << "] index: [" << node.Index() << "] name: [" << node.Name() @@ -224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // If the graph is partitioned in multiple subgraphs, and this may impact performance, // we want to give users a summary message at warning level. if (num_of_partitions > 1) { - LOGS_DEFAULT(WARNING) << summary_msg; + LOGS(logger, WARNING) << summary_msg; } else { - LOGS_DEFAULT(INFO) << summary_msg; + LOGS(logger, INFO) << summary_msg; } return result; @@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context, common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { using namespace android::nn::wrapper; + const auto& logger = *GetLogger(); + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_); + nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger); builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW); builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index f1df1abf4c49a..decfe91c598be 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -687,7 +687,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger); std::unordered_set seen_node_units; const auto& node_indices = src_graph.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 88fa6429fc01e..75973c7031d62 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -104,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3bb069196e31c..060bbd4f79bf2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -718,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // remove is_qnn_ctx_model related code const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 75b8ac7e134f3..0a427b146dcaa 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2493,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For ROCM EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index b84825236a453..45f81ed22b7f7 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -294,7 +294,8 @@ std::unique_ptr CreateGPUDataTransfer(); std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); std::string GetEnvironmentVar(const std::string& var_name); @@ -371,8 +372,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { namespace QDQ { inline std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer* graph_viewer) { - return g_host->QDQ__GetAllNodeUnits(graph_viewer); +GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { + return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger); } } // namespace QDQ diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d3b12f9728135..aa8c367d25d51 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -369,8 +369,9 @@ std::string GetEnvironmentVar(const std::string& var_name) { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { - return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) { + return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } namespace profiling { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f9f2bb69a9d1a..7ab93d56cfe26 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -202,7 +202,8 @@ struct ProviderHost { virtual std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) = 0; + gsl::span tentative_nodes, + const logging::Logger& logger) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0; @@ -890,7 +891,7 @@ struct ProviderHost { virtual std::unique_ptr NodeUnit__OutputEdgesEnd(const NodeUnit* p) = 0; virtual std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) = 0; + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0; // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc index bbf8255ac2940..db8a87d9eaf24 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -34,7 +34,8 @@ namespace onnxruntime { namespace vsi { namespace npu { -GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) { +GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) + : graph_viewer_(graph_viewer), logger_(logger) { Prepare(); context_ = tim::vx::Context::Create(); graph_ = context_->CreateGraph(); @@ -42,7 +43,7 @@ GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(g } bool GraphEP::Prepare() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); for (const auto& node_unit : node_unit_holder_) { auto quant_op_type = util::GetQuantizedOpType(*node_unit); diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h index 49344770d060e..5bb332fad0177 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -51,7 +51,7 @@ struct NodeIOInfo { class GraphEP { public: - explicit GraphEP(const GraphViewer& graph_viewer); + explicit GraphEP(const GraphViewer& graph_viewer, const logging::Logger& logger); ~GraphEP() {} bool Prepare(); @@ -104,6 +104,7 @@ class GraphEP { // In the form of {input_name, [NodeUnit(s) using the input]} std::unordered_map> all_quantized_op_inputs_; const GraphViewer& graph_viewer_; + const logging::Logger& logger_; // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is // valid throughout the lifetime of the ModelBuilder diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 669c702544de8..7da7cc6cb63ba 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; if (graph_viewer.IsSubgraph()) { @@ -82,7 +83,7 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -174,7 +175,8 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, - OrtKernelContext* context) { + OrtKernelContext* context, + const logging::Logger& logger) { Ort::KernelContext ctx(context); size_t num_in = ctx.GetInputCount(); const size_t num_inputs = graph_ep->GetGraphInputs().size(); @@ -192,7 +194,7 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, } if (!graph_ep->GetGraph()->Run()) { - LOGS_DEFAULT(ERROR) << "Failed to run graph."; + LOGS(logger, ERROR) << "Failed to run graph."; } for (size_t i = 0; i < ctx.GetOutputCount(); i++) { auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor; @@ -207,12 +209,14 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, Status VSINPUExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { + const auto& logger = *GetLogger(); + for (const auto& fused_node_graph : fused_nodes_and_graphs) { const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; - std::shared_ptr graph_ep = std::make_shared(graph_viewer); + std::shared_ptr graph_ep = std::make_shared(graph_viewer, logger); for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) { - LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" + LOGS(logger, VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" << graph_viewer.IsInitializedTensor(tensor->Name()); auto input = std::make_shared(); input->name = tensor->Name(); @@ -220,7 +224,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu graph_ep->GetGraphInputs().push_back(input); } for (auto tensor : graph_viewer.GetOutputs()) { - LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); + LOGS(logger, VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); auto output = std::make_shared(); output->name = tensor->Name(); output->is_initializer = false; @@ -236,16 +240,16 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu if (node != &node_unit.GetNode()) { continue; } - LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]"; + LOGS(logger, VERBOSE) << "Adding node: [" << node->OpType() << "]"; vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit); } - LOGS_DEFAULT(INFO) << "Verifying graph"; + LOGS(logger, INFO) << "Verifying graph"; graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile(); if (!graph_ep->GetCompiled()) { - LOGS_DEFAULT(ERROR) << "Failed to verify graph."; + LOGS(logger, ERROR) << "Failed to verify graph."; } else { - LOGS_DEFAULT(INFO) << "Graph has been verified successfully."; + LOGS(logger, INFO) << "Graph has been verified successfully."; } NodeComputeInfo compute_info; @@ -259,7 +263,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu [graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */, OrtKernelContext* context) { std::lock_guard lock(this->GetMutex()); - Status res = ComputeStateFunc(graph_ep.get(), context); + Status res = ComputeStateFunc(graph_ep.get(), context, *GetLogger()); return res; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 9fc886cb69bbf..809616660aa9e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -11,14 +11,20 @@ namespace onnxruntime { namespace webgpu { Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") - << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " - << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); - + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"); + if (input.NumComponents() != output.NumComponents()) { + const auto& output_indices = shader.AddIndices("output_indices"); + shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n " + << " let value = vec4(" << input.GetByOffset("input_offset") << ");\n" + << output.SetByOffset("global_idx", "value"); + } else { + shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + } return Status::OK(); } @@ -28,18 +34,27 @@ Status Expand::ComputeInternal(ComputeContext& context) const { auto output_dims = input_shape_tensor->DataAsSpan(); TensorShape output_shape{}; - ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); + TensorShape input_shape = input_tensor->Shape(); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_shape, output_dims, output_shape)); auto* output_tensor = context.Output(0, output_shape); - uint32_t data_size = gsl::narrow(output_shape.Size()); + const int components_i = input_shape.IsScalar() ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4 + : 1; + const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 + : 1; + uint32_t data_size = gsl::narrow(output_shape.Size() / components_o); + ExpandProgram program{}; program - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}}) .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({ {data_size}, }); + if (components_i != components_o) { + program.AddIndices(output_shape); + } return context.RunProgram(program); } @@ -55,8 +70,8 @@ Status Expand::ComputeInternal(ComputeContext& context) const { KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ KERNEL_CLASS); -WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes()) -WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes()) +WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedNumberTypes()) +WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedNumberTypes()) } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.cc b/onnxruntime/core/providers/webgpu/tensor/flatten.cc index 81d28bd3c0fa7..11ded865b6be2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/flatten.cc +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.cc @@ -13,7 +13,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 1, 8, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -21,7 +24,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 9, 10, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -29,7 +35,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 11, 12, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -37,7 +46,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 13, 20, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); ONNX_OPERATOR_KERNEL_EX( @@ -45,8 +57,11 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 21, kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1), + (*KernelDefBuilder::Create()) + .Alias(0, 0) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .InputMemoryType(OrtMemTypeCPU, 1), Flatten); } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 90b6862758bd7..66209adf6f1a9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -798,7 +798,8 @@ std::vector> WebGpuExecutionProvider::GetCapa candidates.push_back(node.Index()); tenative_candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) { diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index f36f8283e9bf6..45a87960126cd 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -178,14 +178,31 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) return false; + return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits, + webnn_input_output_name, onnx_input_output_name, logger); +} + +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + if (wnn_limits[webnn_op_type].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now"; + return false; + } + if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter [" + << webnn_input_output_name << "]"; + return false; + } if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { - LOGS(logger, VERBOSE) << "[" << onnx_op_type - << "] " << onnx_input_output_name - << " type: [" << onnx_data_type - << "] is not supported for now"; + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: [" + << onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now"; return false; } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 7fdfc5aefa798..a06f46f1bdf0a 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -340,6 +340,13 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, const std::string& webnn_input_output_name, const std::string& onnx_input_output_name, const logging::Logger& logger); +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 70fa0f9516c5c..290d16a48dd83 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const emscripten::val& wnn_limits, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, wnn_limits, logger)) + if (!HasSupportedInputs(initializers, node, wnn_limits, logger)) return false; if (!HasSupportedOutputs(node, wnn_limits, logger)) @@ -41,7 +41,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, +bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { @@ -50,10 +50,10 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& } } - return HasSupportedInputsImpl(node, wnn_limits, logger); + return HasSupportedInputsImpl(initializers, node, wnn_limits, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, +bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 9412fa8026fb3..0a4367a71add4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -40,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; @@ -56,7 +56,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index af82a01b14de5..e14507e8f5aea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,8 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -86,8 +86,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index 70ebe18c85b86..4b2f04bed0eb1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -86,8 +86,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 1a0d93ae7eada..bac528300e077 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -21,8 +21,8 @@ class ConcatOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -55,8 +55,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 52fcc39ae5418..81e688ea4f8ea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,8 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -397,8 +397,8 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index 931854d0f33c1..ef713f48b8135 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. @@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten: return false; } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { // Map to WebNN's gemm or matmul - return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); } else { - return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 225cfcdfc852c..cb7b7de74e121 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -20,8 +20,8 @@ class GatherElementsOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -49,7 +49,8 @@ Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde // Operator support related. -bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index cb4f85a40ee12..002a1a6a63026 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -22,8 +22,8 @@ class GatherNDOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -55,8 +55,8 @@ bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initial return true; } -bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index ae9fe3e3f3bd1..88d22f103cadc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,8 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -69,8 +69,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 252d49a2f4d4d..5f4e6de8fda98 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,8 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -215,8 +215,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // A data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index ffb9b7fbf2e7a..b240e30d38b22 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,8 +26,8 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -187,8 +187,8 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input_X_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index d56fdbc08c677..91910f55f37c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,8 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -71,8 +71,8 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LogicalOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 6213b039fb2f9..33ba22ac3fb5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -25,8 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -198,8 +198,8 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LstmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index e111ca412c6e9..40f94186e9ed6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,8 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -87,8 +87,8 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool MaxMinOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 79ed0393e3044..50e49884bdfa9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,8 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -228,7 +228,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index ca15e123d0999..b71507a871bf6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -22,8 +22,8 @@ class QDQOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -118,8 +118,8 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index c786aa468736c..8c70525835059 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -22,8 +22,8 @@ class ScatterElementsOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,7 +65,8 @@ bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* return true; } -bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index feb93cc14b7c4..8089b9706886f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -22,8 +22,8 @@ class ScatterNDOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -57,7 +57,8 @@ bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index d51297f19f1c2..41c66038c2694 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -27,6 +27,8 @@ class SliceOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } }; @@ -161,6 +163,30 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool SliceOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& input = *input_defs[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + // If there is step < 0, check data type support of reverse. + if (input_defs.size() > 4 && input_defs[4]->Exists()) { + std::vector steps; + if (!ReadIntArrayFrom1DTensor(*initializers.at(input_defs[4]->Name()), steps, logger)) + return false; + if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { + return false; + } + } + } + + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); +} + void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 4b6cf312074ba..c7b3129c0c85b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,8 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -46,8 +46,8 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool TernaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // condition data type diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 12e567e7080b3..ee4e7be0f1f49 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -258,6 +258,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> capabilities; std::shared_ptr registry = GetKernelRegistry(); @@ -268,7 +269,7 @@ std::vector> XnnpackExecutionProvider::GetCap // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit being checked for multiple times diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 2ff9fa525fa3b..a60ee500a9898 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1644,7 +1644,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( level <= static_cast(session_options.graph_optimization_level); ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( - static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, + static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, logger, optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { @@ -1840,7 +1840,8 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_, session_options_.graph_optimization_level, minimal_build_optimization_handling, - record_runtime_optimization_produced_op_schema)); + record_runtime_optimization_produced_op_schema, + *session_logger_)); #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); @@ -2112,7 +2113,7 @@ common::Status InferenceSession::Initialize() { std::vector tuning_results; bool found_tuning_results = false; ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( - model_metadata_, tuning_results, found_tuning_results)); + model_metadata_, tuning_results, found_tuning_results, *session_logger_)); if (found_tuning_results) { ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results, /*error_on_invalid*/ false, /*auto_enable*/ true)); } @@ -3233,7 +3234,8 @@ common::Status InferenceSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const { + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); @@ -3245,7 +3247,7 @@ common::Status InferenceSession::AddPredefinedTransformers( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; if (use_full_build_optimizations) { - return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, + return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); @@ -3257,6 +3259,7 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, + logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 0675f64848fd0..e28ff75345785 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -690,8 +690,9 @@ class InferenceSession { * If we encounter an invalid request, we return an error * back to the user. */ - [[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list, - /*out*/ InlinedVector& arenas_to_shrink) const; + [[nodiscard]] common::Status ValidateAndParseShrinkArenaString( + const std::string& ort_device_list, + /*out*/ InlinedVector& arenas_to_shrink) const; /* * Performs the shrinkage of arenas requested to be shrunk by the user @@ -708,7 +709,8 @@ class InferenceSession { GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const; + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const; common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format); diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index 3436eebda3819..8b9de0c604441 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -236,7 +236,8 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, std::vector& results, - bool& key_found) { + bool& key_found, + const logging::Logger& logger) { results.clear(); key_found = false; auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); @@ -245,7 +246,7 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met } key_found = true; - LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model"; + LOGS(logger, INFO) << "Found tuning results in the model file to be used while loading the model"; Status status; ORT_TRY { diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index a0bcdb9013bf0..f297d928f8a0d 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -19,7 +19,9 @@ using json = nlohmann::json; #endif namespace onnxruntime { - +namespace logging { +class Logger; +} namespace inference_session_utils { // need this value to be accessible in all builds in order to report error for attempted usage in a minimal build @@ -60,7 +62,8 @@ class JsonConfigParser { Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, /*out*/ std::vector& results, - /*out*/ bool& key_found); + /*out*/ bool& key_found, + const logging::Logger& logger); #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 109445c877786..ca6950af0227a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2803,12 +2803,15 @@ static constexpr OrtApi ort_api_1_to_21 = { &OrtApis::KernelInfoGetAllocator, &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + // End of Version 19 - DO NOT MODIFY ABOVE (see above text for more information) + &OrtApis::CreateLoraAdapter, &OrtApis::CreateLoraAdapterFromArray, &OrtApis::ReleaseLoraAdapter, &OrtApis::RunOptionsAddActiveLoraAdapter, &OrtApis::SetEpDynamicOptions, + // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2840,6 +2843,8 @@ static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_OpenVINO_V2) / sizeof(void*) == 275, "Size of version 17 API cannot change"); static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeof(void*) == 279, "Size of version 18 API cannot change"); +// no additions in version 19 +static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.21.0", diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d55fd34d5a8f2..c3832498af584 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -279,8 +279,9 @@ struct ProviderHostImpl : ProviderHost { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) override { - return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) override { + return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } @@ -1057,8 +1058,8 @@ struct ProviderHostImpl : ProviderHost { } std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) override { - return QDQ::GetAllNodeUnits(*graph_viewer); + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) override { + return QDQ::GetAllNodeUnits(*graph_viewer, logger); } // Model (wrapped) diff --git a/onnxruntime/core/session/standalone_op_invoker.cc b/onnxruntime/core/session/standalone_op_invoker.cc index 9cbf01946e92b..2706448d831cc 100644 --- a/onnxruntime/core/session/standalone_op_invoker.cc +++ b/onnxruntime/core/session/standalone_op_invoker.cc @@ -314,7 +314,8 @@ class StandAloneKernelContext : public OpKernelContext { AllocatorPtr allocator_; }; // StandAloneKernelContext -onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, OrtOpAttr** op_attr) { +onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, + OrtOpAttr** op_attr) { auto attr = std::make_unique(); onnxruntime::Status status = onnxruntime::Status::OK(); attr->set_name(std::string{name}); @@ -410,7 +411,9 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info, node_ptr->SetSinceVersion(version); - auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, &kernel_create_info); + auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, + logging::LoggingManager::DefaultLogger(), // no other logger available + &kernel_create_info); ORT_RETURN_IF_ERROR(status); auto& kernel_def = kernel_create_info->kernel_def; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index f6fc9ea7662cb..8c69e2d9810b8 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -7,6 +7,8 @@ #include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#include "test/providers/model_tester.h" +#include "test/util/include/current_test_name.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" @@ -394,5 +396,47 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } } +TEST(BeamSearchTest, DummyT5) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 16, 6, 14, 1, 15, 6, 14, 1, 15, 2, 3, 4, 15, 6, 14, 1, 15, 6, 14, 2, 16, 6, 14, 1, 15, 6, 14, 1, 14}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + +TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index a7f8a6424aa50..adab93908cdc4 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -252,6 +252,7 @@ class PlannerTest : public ::testing::Test { void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg, std::unordered_map>& kernel_create_info_map) { + const auto& logger = DefaultLoggingManager().DefaultLogger(); const IExecutionProvider* ep = execution_providers_.Get(*p_node); ASSERT_NE(ep, nullptr); auto info = std::make_unique( @@ -261,7 +262,7 @@ class PlannerTest : public ::testing::Test { op_kernel_infos_.push_back(std::move(info)); const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{}; if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider, - kernel_type_str_resolver)) { + kernel_type_str_resolver, logger)) { ASSERT_STATUS_OK(reg->Register( KernelCreateInfo(std::make_unique(kernel_def), [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { @@ -271,7 +272,7 @@ class PlannerTest : public ::testing::Test { } const KernelCreateInfo* kci; - ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, &kci)); + ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, logger, &kci)); kernel_create_info_map.insert({p_node->Index(), gsl::not_null(kci)}); } @@ -283,7 +284,8 @@ class PlannerTest : public ::testing::Test { } } - void CreatePlan(const std::vector& outer_scope_node_args = {}, bool invoke_createPlan_explicityly = true) { + void CreatePlan(const std::vector& outer_scope_node_args = {}, + bool invoke_createPlan_explicityly = true) { state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); EXPECT_EQ(graph_.Resolve(), Status::OK()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 67d60ea3a4ff6..2ff0b599beebf 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -831,7 +831,8 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_mapName() == "ConstantFolding") { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), TransformerLevel::Level1)); @@ -4704,7 +4705,8 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { // Compare results double per_sample_tolerance = 1e-3; double relative_per_sample_tolerance = 0.0; - auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], per_sample_tolerance, relative_per_sample_tolerance, false); + auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], + per_sample_tolerance, relative_per_sample_tolerance, false); EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; } @@ -4713,7 +4715,8 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt std::make_unique(CPUExecutionProviderInfo()); bool has_gelu_approximation = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "GeluApproximation") { has_gelu_approximation = true; @@ -4728,7 +4731,8 @@ TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { std::unique_ptr cpu_ep = std::make_unique(CPUExecutionProviderInfo()); bool has_double_qdq_remover = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "DoubleQDQPairsRemover") { has_double_qdq_remover = true; diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 66b74641e41d3..caa64560426af 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -36,9 +36,11 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger, + disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -61,8 +63,9 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, + disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index 81c1a4ace1e33..b306f026b2dfd 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -27,6 +27,7 @@ namespace test { TEST(OptimizerTest, Basic) { Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); auto& graph = model.MainGraph(); constexpr int tensor_dim = 10; @@ -66,22 +67,21 @@ TEST(OptimizerTest, Basic) { auto cpu_execution_provider = std::make_unique(CPUExecutionProviderInfo()); #if !defined(DISABLE_SPARSE_TENSORS) - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [&graph](const std::string& name) -> bool { - return graph.IsSparseInitializer(name); - }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [&graph](const std::string& name) -> bool { + return graph.IsSparseInitializer(name); + }, + logger); #else - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [](std::string const&) { return false; }, + logger); #endif //! defined(DISABLE_SPARSE_TENSORS) std::vector fetch_mlvalue_idxs{info.GetMLValueIndex("out")}; OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); - const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); const ConfigOptions empty_config_options; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index cfee4a83a4292..043b92d7ef121 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3928,6 +3928,7 @@ TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) { TEST(QDQTransformerTests, QDQ_Selector_Test) { const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx"); + const auto& logger = DefaultLoggingManager().DefaultLogger(); SessionOptions so; // We want to keep the graph un-optimized to prevent QDQ transformer to kick in @@ -3962,7 +3963,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check if SelectorManager get a conv qdq group selection as expected { - const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger); ASSERT_FALSE(result.empty()); const auto& qdq_group = result.at(0); ASSERT_EQ(std::vector({0, 1, 2}), qdq_group.dq_nodes); @@ -3977,7 +3978,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger); // We should get a single QDQ Node unit in the result ASSERT_EQ(1, node_unit_holder.size()); @@ -4045,7 +4046,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check SelectorManager will get empty result { - const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer, logger); ASSERT_TRUE(result.empty()); } } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 3f2c2cb7f761c..23c3812ebd025 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -135,6 +135,9 @@ namespace perftest { "\t [CoreML only] [MLComputeUnits]:[CPUAndNeuralEngine CPUAndGPU ALL CPUOnly] Specify to limit the backend device used to run the model.\n" "\t [CoreML only] [AllowStaticInputShapes]:[0 1].\n" "\t [CoreML only] [EnableOnSubgraphs]:[0 1].\n" + "\t [CoreML only] [SpecializationStrategy]:[Default FastPrediction].\n" + "\t [CoreML only] [ProfileComputePlan]:[0 1].\n" + "\t [CoreML only] [AllowLowPrecisionAccumulationOnGPU]:[0 1].\n" "\t [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram MLComputeUnits|CPUAndGPU\"\n" "\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 5db1894a5074b..a96028ed3903e 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -346,7 +346,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); static const std::unordered_set available_keys = {kCoremlProviderOption_MLComputeUnits, kCoremlProviderOption_ModelFormat, kCoremlProviderOption_RequireStaticInputShapes, - kCoremlProviderOption_EnableOnSubgraphs}; + kCoremlProviderOption_EnableOnSubgraphs, + kCoremlProviderOption_SpecializationStrategy, + kCoremlProviderOption_ProfileComputePlan, + kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU}; ParseSessionConfigs(ov_string, provider_options, available_keys); std::unordered_map available_options = { @@ -364,6 +367,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); (provider_option.second == "1" || provider_option.second == "0")) { } else if (provider_option.first == kCoremlProviderOption_EnableOnSubgraphs && (provider_option.second == "0" || provider_option.second == "1")) { + } else if (provider_option.first == kCoremlProviderOption_SpecializationStrategy && + (provider_option.second == "Default" || provider_option.second == "FastPrediction")) { + } else if (provider_option.first == kCoremlProviderOption_ProfileComputePlan && + (provider_option.second == "0" || provider_option.second == "1")) { + } else if (provider_option.first == kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU && + (provider_option.second == "0" || provider_option.second == "1")) { } else { ORT_THROW("Invalid value for option ", provider_option.first, ": ", provider_option.second); } diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 9d83c789c5124..b0958e05dc373 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -420,6 +420,7 @@ bool SetEpsForAllNodes(Graph& graph, continue; bool found = false; + const auto& logger = DefaultLoggingManager().DefaultLogger(); for (const auto& ep : execution_providers) { auto provider_type = ep->Type(); @@ -438,7 +439,8 @@ bool SetEpsForAllNodes(Graph& graph, } // Check the EP has an impl for the node from builtin registry. - if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver)) { + if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver, + logger)) { found = true; break; } @@ -451,6 +453,7 @@ bool SetEpsForAllNodes(Graph& graph, std::string_view(kMSInternalNHWCDomain), node.SinceVersion(), type_constraint_map, + logger, &kci); if (status.IsOK() && kci != nullptr) { found = true; @@ -463,7 +466,7 @@ bool SetEpsForAllNodes(Graph& graph, std::any_of(custom_registries->cbegin(), custom_registries->cend(), [&](auto reg) { return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(), - kernel_type_str_resolver); + kernel_type_str_resolver, logger); })) { found = true; break; diff --git a/onnxruntime/test/providers/cpu/tensor/expand_test.cc b/onnxruntime/test/providers/cpu/tensor/expand_test.cc index 4b0f4e84ca37d..38e3bc3af6d6b 100644 --- a/onnxruntime/test/providers/cpu/tensor/expand_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/expand_test.cc @@ -167,6 +167,20 @@ TEST(ExpandOpTest, Expand_2x2x1x2x1_float) { test.Run(); } +TEST(ExpandOpTest, Expand_3x1x8_float) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test.AddInput("data_1", {3}, {3, 1, 8}); + test.AddOutput("result", {3, 2, 8}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, + 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, + 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f}); + test.Run(); +} + #ifndef USE_TENSORRT TEST(ExpandOpTest, Expand_scalar_float) { OpTester test("Expand", 8); diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc index 23ec48fa649dd..93e688570631e 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.cc +++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc @@ -42,8 +42,9 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { } #endif + const auto& logger = DefaultLoggingManager().DefaultLogger(); Model model("test", false, ModelMetaData(), ORT_TSTR(""), IOnnxRuntimeOpSchemaRegistryList(), - {{domain_, opset_version_}}, {}, DefaultLoggingManager().DefaultLogger()); + {{domain_, opset_version_}}, {}, logger); std::vector input_args; std::unordered_map initializer_map; @@ -89,8 +90,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { ASSERT_STATUS_OK(graph.Resolve()); node.SetExecutionProviderType(ep_type); - OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), [](std::string const&) { return false; }, logger); const KernelCreateInfo* kernel_create_info = nullptr; ASSERT_STATUS_OK(info.TryFindKernel(&node, &kernel_create_info)); ASSERT_TRUE(kernel_create_info); @@ -139,7 +139,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { #pragma warning(disable : 6387) #endif OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs, outputs); - OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, DefaultLoggingManager().DefaultLogger()); + OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, logger); #ifdef _WIN32 #pragma warning(pop) #endif diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc index 5db69489afaef..f1fbb1cea7ea2 100644 --- a/onnxruntime/test/providers/partitioning_utils_test.cc +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -51,7 +51,7 @@ TEST(PartitioningUtilsTest, TestQDQHandling) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed, gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map, @@ -82,7 +82,7 @@ static void CheckAllNodesProcessed(const std::function& std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); const auto is_node_supported = [&](const Node& /*node*/) -> bool { return true; diff --git a/onnxruntime/test/testdata/dummy_t5.onnx b/onnxruntime/test/testdata/dummy_t5.onnx new file mode 100644 index 0000000000000..3a3bbf4767523 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5.onnx differ diff --git a/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx new file mode 100644 index 0000000000000..4b36cc9b6eca0 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_outer_scope_initializers.onnx differ diff --git a/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx new file mode 100644 index 0000000000000..5a5c302914890 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx differ diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 87a7cbc0375a4..f1545e96481fa 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -758,7 +758,8 @@ Status TrainingSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/) const { + RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/, + const logging::Logger& /*logger*/) const { ORT_RETURN_IF_NOT( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations, "Only applying full build optimizations is supported by TrainingSession."); diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 765f88e1c992e..58492dc62400f 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -489,7 +489,8 @@ class TrainingSession : public InferenceSession { GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const override; + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const override; /** Perform auto-diff to add backward graph into the model. @param weights_to_train a set of weights to be training. diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc index 0944e46ff8eaf..58c173ed90277 100644 --- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc +++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc @@ -139,7 +139,8 @@ void GradientOpTester::Run(int output_index_to_use_as_loss, auto reg = execution_provider->GetKernelRegistry(); const KernelCreateInfo* kci; - auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, &kci); + auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, + DefaultLoggingManager().DefaultLogger(), &kci); if (!st.IsOK()) { // The goal here is unclear. It seems best to leave it to the Session // creation to figure out whether the model can be executed using some diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 548f39bb0150c..1b8699d1de497 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -23,8 +23,10 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + const auto& logger = DefaultLoggingManager().DefaultLogger(); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger, + disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -47,8 +49,8 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index fdc85e2cafa31..6ee37b8b0519e 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1559,6 +1559,9 @@ def generate_build_tree( # The "/profile" flag implies "/DEBUG:FULL /DEBUGTYPE:cv,fixup /OPT:REF /OPT:NOICF /INCREMENTAL:NO /FIXED:NO". We set it for satisfying a Microsoft internal compliance requirement. External users # do not need to have it. ldflags = ["/profile", "/DYNAMICBASE"] + # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled. + if not args.enable_address_sanitizer: + cflags += ["/Qspectre"] if config == "Release": cflags += ["/O2", "/Ob2", "/DNDEBUG"] elif config == "RelWithDebInfo":