Skip to content

Commit

Permalink
Merge branch 'develop' into dependabot/pip/test/py/numpy-1.22.0
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Oct 19, 2023
2 parents f5b1db0 + 6e86734 commit 249b924
Show file tree
Hide file tree
Showing 58 changed files with 2,603 additions and 124 deletions.
7 changes: 7 additions & 0 deletions src/driver/argument_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ struct value_parser
}
};

// version for std::optional object
template <class T>
struct value_parser<std::optional<T>>
{
static T apply(const std::string& x) { return value_parser<T>::apply(x); }
};

struct argument_parser
{
struct argument
Expand Down
49 changes: 11 additions & 38 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,22 +540,17 @@ struct params : command<params>
struct verify : command<verify>
{
compiler c;
// Set to -1. as nonsense initial value
double rms_tol = -1.0;
double atol = -1.0;
double rtol = -1.0;
std::optional<double> rms_tol;
std::optional<double> atol;
std::optional<double> rtol;
bool per_instruction = false;
bool reduce = false;
void parse(argument_parser& ap)
{
c.parse(ap);
ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error (Default: 0.001)"));
ap(atol,
{"--atol"},
ap.help("Tolerance for the elementwise absolute difference (Default: 0.001)"));
ap(rtol,
{"--rtol"},
ap.help("Tolerance for the elementwise relative difference (Default: 0.001)"));
ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
ap(per_instruction,
{"-i", "--per-instruction"},
ap.help("Verify each instruction"),
Expand All @@ -572,33 +567,6 @@ struct verify : command<verify>
auto t = c.ct.get_target();
auto m = c.parameters.generate(p, t, true, c.l.batch);

// TODO remove this and make the driver able to figure out datatype most used in the model
// then set the tolerances appropriately. Need to check here because c.to_fp16 only set
// after argument_parser.parse() is run. This code is complicated because there's not a
// good way to change the default tolerances after reading `--fp16` but before reading
// `--rms-tol`, `--atol`, and `--rtol`.
migraphx::verify::tolerance tols{};
if(c.to_fp16)
{
tols = migraphx::verify::tolerance{8e-2, 4e-2, 4e-2};
}
if(not float_equal(this->rms_tol, -1.0))
{
tols.rms_tol = this->rms_tol;
}
if(not float_equal(this->atol, -1.0))
{
tols.atol = this->atol;
}
if(not float_equal(this->rtol, -1.0))
{
tols.rtol = this->rtol;
}

std::cout << "rms_tol: " << tols.rms_tol << std::endl;
std::cout << "atol: " << tols.atol << std::endl;
std::cout << "rtol: " << tols.rtol << std::endl;

auto quantize = precision::fp32;
if(c.to_fp16)
{
Expand All @@ -609,6 +577,11 @@ struct verify : command<verify>
quantize = precision::int8;
}

auto tols = get_tolerances(p, quantize, rms_tol, atol, rtol);
std::cout << "rms_tol: " << tols.rms_tol << std::endl;
std::cout << "atol: " << tols.atol << std::endl;
std::cout << "rtol: " << tols.rtol << std::endl;

if(per_instruction)
{
verify_instructions(p, t, c.co, quantize, tols);
Expand Down
36 changes: 36 additions & 0 deletions src/driver/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,42 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
* model.
*/
verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol)
{
bool has_fp16 = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); });
});
migraphx::verify::tolerance result{};
if(has_fp16 or quantize == precision::fp16)
{
result.rms_tol = 8e-2;
result.atol = 4e-2;
result.rtol = 4e-2;
}
if(rms_tol)
{
result.rms_tol = *rms_tol;
}
if(atol)
{
result.atol = *atol;
}
if(rtol)
{
result.rtol = *rtol;
}
return result;
}

std::vector<argument> run_ref(program p, const parameter_map& inputs)
{
p.compile(migraphx::make_target("ref"));
Expand Down
6 changes: 6 additions & 0 deletions src/driver/verify.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol);

void verify_program(const std::string& name,
const program& p,
const target& t,
Expand Down
58 changes: 49 additions & 9 deletions src/include/migraphx/op/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/**
* 1 input version:
* reshape(input_data)
* this.dims = output_dims
* Makes a copy of input_data to the output shape.
*
* 2 input version:
* reshape(input_data, output_buffer)
* this.dims = unset
* Copies input_data to output_buffer; output_buffer already has the output shape.
* This version will not fail gracefully if the input shape and output_buffer shape are
* incompatible. There's a throw that will catch when the number of elements do not match at
* runtime. This version should only be used for dynamic reshapes (output dimensions only known at
* runtime). If output_buffer has a static shape during compile/parse, you can use the 1 input
* version.
*/
struct reshape
{
std::vector<int64_t> dims;
Expand Down Expand Up @@ -215,32 +231,56 @@ struct reshape

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
check_shapes{inputs, *this, true}.has(1, 2);

auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");

auto s0 = inputs.front();
if(s0.dynamic())
if(inputs.size() == 1)
{
return dyn_compute_shape(s0);
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
else
{
return static_compute_shape(inputs, n_neg_dims);
return inputs.back();
}
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};
if(args.size() == 1)
{
argument result{dyn_out.computed_shape};

visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
else
{
// 2 arg
if(args[0].get_shape().elements() != args[1].get_shape().elements())
{
MIGRAPHX_THROW("Reshape: Number of elements must match at runtime. Input: " +
std::to_string(args[0].get_shape().elements()) +
" Output buffer: " + std::to_string(args[1].get_shape().elements()));
}
visit_all(args[1], args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return args[1];
}
}
};

Expand Down
130 changes: 130 additions & 0 deletions src/onnx/parse_groupnorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_groupnorm : op_parser<parse_groupnorm>
{
std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }

instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
size_t num_groups;
if(contains(info.attributes, "num_groups"))
{
num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
}
else
{
MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
}

if(args.size() != 3)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
}

auto x = args.at(0);
auto scale = args.at(1);
auto bias = args.at(2);

auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
auto x_dims = x_shape.lens();

if(x_shape.ndim() <= 2)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input shape");
}

auto c = x_shape.lens().at(1);
if(c % num_groups != 0)
{
MIGRAPHX_THROW(
"PARSE_GROUPNORM: num_groups should be a divisor of the number of channels");
}
auto group_size = c / num_groups;
if(scale->get_shape().ndim() != 1 or scale->get_shape().lens().at(0) != num_groups)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: scale tensor shape should be num_groups");
}
if(bias->get_shape().ndim() != 1 or bias->get_shape().lens().at(0) != num_groups)
{
MIGRAPHX_THROW("PARSE_GROUPNORM: bias tensor shape should be num_groups");
}

// Original shape: N x C x D1 x ... x Dn
// New shape: N x num_groups x C // num_groups x D1 x ... x Dn

std::vector<size_t> dims = {x_dims.at(0), num_groups, group_size};
std::copy(x_dims.begin() + 2, x_dims.end(), std::back_inserter(dims));
auto x_reshaped = info.add_instruction(make_op("reshape", {{"dims", dims}}), x);

// Axes for D1 x ... x Dn
std::vector<size_t> axes(dims.size() - 2);
std::iota(axes.begin(), axes.end(), 2);

// y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)

auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_reshaped);
auto x_sub_mean = info.add_common_op("sub", x_reshaped, mean);
auto x_sqdiff_mean = info.add_common_op("sqdiff", x_reshaped, mean);
auto variance =
info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
auto var_eps = info.add_common_op("add", variance, eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto result = info.add_common_op("mul", x_sub_mean, rsqrt);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
auto y = info.add_instruction(make_op("add"), scaled, bias_bcast);
auto y_reshaped = info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);
return y_reshaped;
}
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Loading

0 comments on commit 249b924

Please sign in to comment.