diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 4800c48744236..26f8987c76623 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -787,7 +787,7 @@ def main(): logger.error("fp16 is for GPU only") return - if args.precision == Precision.INT8 and args.use_gpu and args.provider != "migraphx": + if args.precision == Precision.INT8 and args.use_gpu and args.provider not in ["migraphx", "rocm"]: logger.error("int8 is for CPU only") return