Skip to content

Commit

Permalink
fixup! Add ORT_MIGRAPHX_SET_FAST_MATH env option and api hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Ted Themistokleous committed Nov 29, 2023
1 parent 3819961 commit d369a15
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ typedef struct OrtTensorRTProviderOptions {
typedef struct OrtMIGraphXProviderOptions {
int device_id; // hip device id.
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_fast_math_enable; // MIGraphX Fast Math Optimize. Default 0 = false, nonzero = true
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
}

// whether fp16 is enable
const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFastMathOptimization);
if (!fast_math_env.empty()) {
const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSetFastMathOptimization);
if (!fast_math_enable_env.empty()) {
fast_math_enable_ = (std::stoi(fast_math_enable_env) == 0 ? false : true);
}

Expand Down Expand Up @@ -1195,6 +1195,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
migraphx::onnx_options& cmp_options = mgx_state->options;
bool& no_input_shape = mgx_state->no_input_shape;
bool fp16_enable = mgx_state->fp16_enable;
bool fast_math_enable = mgx_state->fast_math_enable;
bool int8_enable = mgx_state->int8_enable;
bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available;

Expand Down

0 comments on commit d369a15

Please sign in to comment.