Skip to content

Commit

Permalink
Change driver verify to check for fp16 and --fp16 (#2334)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Oct 18, 2023
1 parent 94bda24 commit 5139b93
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 38 deletions.
7 changes: 7 additions & 0 deletions src/driver/argument_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ struct value_parser
}
};

// version for std::optional object
template <class T>
struct value_parser<std::optional<T>>
{
static T apply(const std::string& x) { return value_parser<T>::apply(x); }
};

struct argument_parser
{
struct argument
Expand Down
49 changes: 11 additions & 38 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,22 +540,17 @@ struct params : command<params>
struct verify : command<verify>
{
compiler c;
// Set to -1. as nonsense initial value
double rms_tol = -1.0;
double atol = -1.0;
double rtol = -1.0;
std::optional<double> rms_tol;
std::optional<double> atol;
std::optional<double> rtol;
bool per_instruction = false;
bool reduce = false;
void parse(argument_parser& ap)
{
c.parse(ap);
ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error (Default: 0.001)"));
ap(atol,
{"--atol"},
ap.help("Tolerance for the elementwise absolute difference (Default: 0.001)"));
ap(rtol,
{"--rtol"},
ap.help("Tolerance for the elementwise relative difference (Default: 0.001)"));
ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
ap(per_instruction,
{"-i", "--per-instruction"},
ap.help("Verify each instruction"),
Expand All @@ -572,33 +567,6 @@ struct verify : command<verify>
auto t = c.ct.get_target();
auto m = c.parameters.generate(p, t, true, c.l.batch);

// TODO remove this and make the driver able to figure out datatype most used in the model
// then set the tolerances appropriately. Need to check here because c.to_fp16 only set
// after argument_parser.parse() is run. This code is complicated because there's not a
// good way to change the default tolerances after reading `--fp16` but before reading
// `--rms-tol`, `--atol`, and `--rtol`.
migraphx::verify::tolerance tols{};
if(c.to_fp16)
{
tols = migraphx::verify::tolerance{8e-2, 4e-2, 4e-2};
}
if(not float_equal(this->rms_tol, -1.0))
{
tols.rms_tol = this->rms_tol;
}
if(not float_equal(this->atol, -1.0))
{
tols.atol = this->atol;
}
if(not float_equal(this->rtol, -1.0))
{
tols.rtol = this->rtol;
}

std::cout << "rms_tol: " << tols.rms_tol << std::endl;
std::cout << "atol: " << tols.atol << std::endl;
std::cout << "rtol: " << tols.rtol << std::endl;

auto quantize = precision::fp32;
if(c.to_fp16)
{
Expand All @@ -609,6 +577,11 @@ struct verify : command<verify>
quantize = precision::int8;
}

auto tols = get_tolerances(p, quantize, rms_tol, atol, rtol);
std::cout << "rms_tol: " << tols.rms_tol << std::endl;
std::cout << "atol: " << tols.atol << std::endl;
std::cout << "rtol: " << tols.rtol << std::endl;

if(per_instruction)
{
verify_instructions(p, t, c.co, quantize, tols);
Expand Down
36 changes: 36 additions & 0 deletions src/driver/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,42 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
* model.
*/
verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol)
{
bool has_fp16 = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); });
});
migraphx::verify::tolerance result{};
if(has_fp16 or quantize == precision::fp16)
{
result.rms_tol = 8e-2;
result.atol = 4e-2;
result.rtol = 4e-2;
}
if(rms_tol)
{
result.rms_tol = *rms_tol;
}
if(atol)
{
result.atol = *atol;
}
if(rtol)
{
result.rtol = *rtol;
}
return result;
}

std::vector<argument> run_ref(program p, const parameter_map& inputs)
{
p.compile(migraphx::make_target("ref"));
Expand Down
6 changes: 6 additions & 0 deletions src/driver/verify.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol);

void verify_program(const std::string& name,
const program& p,
const target& t,
Expand Down

0 comments on commit 5139b93

Please sign in to comment.