From 38199614505e0b4f7eb018acd84f81c3184a86e3 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 29 Nov 2023 04:46:08 +0000 Subject: [PATCH] Add ORT_MIGRAPHX_SET_FAST_MATH env option and api hooks Allow users to set the fast math option for MIGraphX compilation for quantized data types (fp16) This allows us to toggle whether we can use faster math with the tradeoff of accuracy. --- .../migraphx/migraphx_execution_provider.cc | 13 ++++++++++--- .../migraphx/migraphx_execution_provider.h | 3 +++ .../migraphx/migraphx_execution_provider_info.cc | 4 ++++ .../migraphx/migraphx_execution_provider_info.h | 1 + .../providers/migraphx/migraphx_provider_factory.cc | 2 ++ onnxruntime/python/onnxruntime_pybind_state.cc | 8 ++++++++ 6 files changed, 28 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 2ab54e285c98d..4cb9a04c5aebf 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + // whether fp16 is enable + const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFastMathOptimization); + if (!fast_math_env.empty()) { + fast_math_enable_ = (std::stoi(fast_math_enable_env) == 0 ? false : true); + } + // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { @@ -168,6 +174,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " << "device_id: " << device_id_ << ", migraphx_fp16_enable: " << fp16_enable_ + << ", migraphx_fast_math: " << fast_math_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ @@ -1145,7 +1152,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::quantize_int8(prog, t_, quant_opts); } migraphx::compile_options co; - co.set_fast_math(false); + co.set_fast_math(fast_math_enable_); prog.compile(t_, co); auto prog_output_shapes = prog.get_output_shapes(); for (std::size_t i = 0; i < output_names.size(); ++i) { @@ -1165,7 +1172,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, fast_math_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_}; *state = p.release(); return 0; @@ -1265,7 +1272,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } migraphx::compile_options co; - co.set_fast_math(false); + co.set_fast_math(fast_math_enable); prog.compile(t, co); mgx_state->prog = prog; param_shapes = prog.get_parameter_shapes(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index c094be51012e4..69b1d180316cd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -26,6 +26,7 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; +static const char kSetFastMathOptimization[] = "ORT_MIGRAPHX_SET_FAST_MATH"; }; // namespace migraphx_env_vars // Information to construct kernel function state. @@ -41,6 +42,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fast_math_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -78,6 +80,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: bool fp16_enable_ = false; + bool fast_math_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index b7d7a77853df6..85843765a1e57 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,6 +14,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFastMathEnable = "migx_fast_math_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -38,6 +39,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFastMathEnable, info.fast_math_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .Parse(options)); @@ -48,6 +50,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.fast_math_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, }; return options; @@ -57,6 +60,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.migraphx_fast_math_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, }; return options; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 18ac30fdc1283..95e8538b12ca4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; int device_id{0}; bool fp16_enable{false}; + bool fast_math_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index f985682ddc735..c4dfdf95416c5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -47,6 +47,7 @@ struct MIGraphX_Provider : Provider { info.device_id = options.device_id; info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fast_math_enable = options.migraphx_fast_math_enable; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; if (options.migraphx_int8_calibration_table_name != nullptr) { @@ -61,6 +62,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fast_math_enable = internal_options.fast_math_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; char* dest = nullptr; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 8423dcfbadc58..800fabbd88a6d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -734,6 +734,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr}; for (auto option : it->second) { if (option.first == "device_id") { @@ -752,6 +753,13 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } + else if (option.first == "migraphx_set_fast_math") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fast_math_enable = true; + } else { + params.migraphx_fast_math_enable = false; + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true;