From 8de08307be610a6b3921e0613fcca68119d0195a Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 12 Oct 2023 19:41:39 -0500 Subject: [PATCH 1/3] Change to check for fp16 and --fp16 --- src/driver/main.cpp | 88 +++++++++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 38 deletions(-) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 4978f0c7f6b..31f9138bda7 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -1,5 +1,5 @@ /* - * The MIT License (MIT) +ii The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * @@ -41,6 +41,8 @@ #include #include +#include +#include #include #include #include @@ -537,25 +539,57 @@ struct params : command } }; +/** + * Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults. + * Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the model. + */ +verify::tolerance get_tolerances(const program& p, precision quantize, std::optional rms_tol, std::optional atol, std::optional rtol) +{ + bool has_fp16 = any_of(p.get_modules(), [](auto&& m) { + return any_of(*m, [](auto&& ins) { + return (ins.get_shape().type() == shape::half_type); + }); + }); + migraphx::verify::tolerance result{}; + if(has_fp16 or quantize == precision::fp16) + { + result.rms_tol = 8e-2; + result.atol = 4e-2; + result.rtol = 4e-2; + } + if(rms_tol) + { + result.rms_tol = *rms_tol; + } + if(atol) + { + result.atol = *atol; + } + if(rtol) + { + result.rtol = *rtol; + } + return result; +} + struct verify : command { compiler c; - // Set to -1. as nonsense initial value - double rms_tol = -1.0; - double atol = -1.0; - double rtol = -1.0; + std::optional rms_tol; + std::optional atol; + std::optional rtol; bool per_instruction = false; bool reduce = false; void parse(argument_parser& ap) { c.parse(ap); - ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error (Default: 0.001)")); - ap(atol, + ap(*rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error")); + ap(*atol, {"--atol"}, - ap.help("Tolerance for the elementwise absolute difference (Default: 0.001)")); - ap(rtol, + ap.help("Tolerance for the elementwise absolute difference")); + ap(*rtol, {"--rtol"}, - ap.help("Tolerance for the elementwise relative difference (Default: 0.001)")); + ap.help("Tolerance for the elementwise relative difference")); ap(per_instruction, {"-i", "--per-instruction"}, ap.help("Verify each instruction"), @@ -571,34 +605,7 @@ struct verify : command auto t = c.ct.get_target(); auto m = c.parameters.generate(p, t, true, c.l.batch); - - // TODO remove this and make the driver able to figure out datatype most used in the model - // then set the tolerances appropriately. Need to check here because c.to_fp16 only set - // after argument_parser.parse() is run. This code is complicated because there's not a - // good way to change the default tolerances after reading `--fp16` but before reading - // `--rms-tol`, `--atol`, and `--rtol`. - migraphx::verify::tolerance tols{}; - if(c.to_fp16) - { - tols = migraphx::verify::tolerance{8e-2, 4e-2, 4e-2}; - } - if(not float_equal(this->rms_tol, -1.0)) - { - tols.rms_tol = this->rms_tol; - } - if(not float_equal(this->atol, -1.0)) - { - tols.atol = this->atol; - } - if(not float_equal(this->rtol, -1.0)) - { - tols.rtol = this->rtol; - } - - std::cout << "rms_tol: " << tols.rms_tol << std::endl; - std::cout << "atol: " << tols.atol << std::endl; - std::cout << "rtol: " << tols.rtol << std::endl; - + auto quantize = precision::fp32; if(c.to_fp16) { @@ -609,6 +616,11 @@ struct verify : command quantize = precision::int8; } + auto tols = get_tolerances(p, quantize, rms_tol, atol, rtol); + std::cout << "rms_tol: " << tols.rms_tol << std::endl; + std::cout << "atol: " << tols.atol << std::endl; + std::cout << "rtol: " << tols.rtol << std::endl; + if(per_instruction) { verify_instructions(p, t, c.co, quantize, tols); From 9dbca56bd813f6d29bc596adb0036dba0c270511 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 16 Oct 2023 14:04:27 -0500 Subject: [PATCH 2/3] Fix arg_parser for std::optional, review updates --- src/driver/argument_parser.hpp | 7 +++++ src/driver/main.cpp | 47 +++------------------------------- src/driver/verify.cpp | 36 ++++++++++++++++++++++++++ src/driver/verify.hpp | 6 +++++ 4 files changed, 53 insertions(+), 43 deletions(-) diff --git a/src/driver/argument_parser.hpp b/src/driver/argument_parser.hpp index 683df8478d3..a0ee434a2b6 100644 --- a/src/driver/argument_parser.hpp +++ b/src/driver/argument_parser.hpp @@ -187,6 +187,13 @@ struct value_parser } }; +// version for std::optional object +template +struct value_parser> +{ + static T apply(const std::string& x) { return value_parser::apply(x); } +}; + struct argument_parser { struct argument diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 31f9138bda7..5770dffeb51 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -1,5 +1,5 @@ /* -ii The MIT License (MIT) + * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * @@ -41,8 +41,6 @@ ii The MIT License (MIT) #include #include -#include -#include #include #include #include @@ -539,39 +537,6 @@ struct params : command } }; -/** - * Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults. - * Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the model. - */ -verify::tolerance get_tolerances(const program& p, precision quantize, std::optional rms_tol, std::optional atol, std::optional rtol) -{ - bool has_fp16 = any_of(p.get_modules(), [](auto&& m) { - return any_of(*m, [](auto&& ins) { - return (ins.get_shape().type() == shape::half_type); - }); - }); - migraphx::verify::tolerance result{}; - if(has_fp16 or quantize == precision::fp16) - { - result.rms_tol = 8e-2; - result.atol = 4e-2; - result.rtol = 4e-2; - } - if(rms_tol) - { - result.rms_tol = *rms_tol; - } - if(atol) - { - result.atol = *atol; - } - if(rtol) - { - result.rtol = *rtol; - } - return result; -} - struct verify : command { compiler c; @@ -583,13 +548,9 @@ struct verify : command void parse(argument_parser& ap) { c.parse(ap); - ap(*rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error")); - ap(*atol, - {"--atol"}, - ap.help("Tolerance for the elementwise absolute difference")); - ap(*rtol, - {"--rtol"}, - ap.help("Tolerance for the elementwise relative difference")); + ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error")); + ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference")); + ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference")); ap(per_instruction, {"-i", "--per-instruction"}, ap.help("Verify each instruction"), diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index df028a693ed..8a6c96b200b 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -36,6 +36,42 @@ namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { +/** + * Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults. + * Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the + * model. + */ +verify::tolerance get_tolerances(const program& p, + precision quantize, + std::optional rms_tol, + std::optional atol, + std::optional rtol) +{ + bool has_fp16 = any_of(p.get_modules(), [](auto&& m) { + return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); }); + }); + migraphx::verify::tolerance result{}; + if(has_fp16 or quantize == precision::fp16) + { + result.rms_tol = 8e-2; + result.atol = 4e-2; + result.rtol = 4e-2; + } + if(rms_tol) + { + result.rms_tol = *rms_tol; + } + if(atol) + { + result.atol = *atol; + } + if(rtol) + { + result.rtol = *rtol; + } + return result; +} + std::vector run_ref(program p, const parameter_map& inputs) { p.compile(migraphx::make_target("ref")); diff --git a/src/driver/verify.hpp b/src/driver/verify.hpp index 63ac161f252..582501c017a 100644 --- a/src/driver/verify.hpp +++ b/src/driver/verify.hpp @@ -32,6 +32,12 @@ namespace migraphx { namespace driver { inline namespace MIGRAPHX_INLINE_NS { +verify::tolerance get_tolerances(const program& p, + precision quantize, + std::optional rms_tol, + std::optional atol, + std::optional rtol); + void verify_program(const std::string& name, const program& p, const target& t, From e970cb40040ba57c7bd1f1b98680560a3ddaf6cd Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 16 Oct 2023 14:44:02 -0500 Subject: [PATCH 3/3] formatting --- src/driver/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 5770dffeb51..15da9b2d716 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -566,7 +566,7 @@ struct verify : command auto t = c.ct.get_target(); auto m = c.parameters.generate(p, t, true, c.l.batch); - + auto quantize = precision::fp32; if(c.to_fp16) {