Skip to content

Commit

Permalink
Merge branch 'develop' into msvc_windows_type_name
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored May 8, 2024
2 parents d019a2a + 70c338b commit ed8220a
Show file tree
Hide file tree
Showing 28 changed files with 1,320 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ idna==3.7
# via requests
imagesize==1.4.1
# via sphinx
jinja2==3.1.3
jinja2==3.1.4
# via
# myst-parser
# sphinx
Expand Down
2 changes: 1 addition & 1 deletion src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ inline namespace MIGRAPHX_INLINE_NS {

static literal get_scalar(instruction_ref ins)
{
if(ins->name() == "contiguous")
if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name()))
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(s.elements() != 1 and not(s.scalar()))
Expand Down
228 changes: 216 additions & 12 deletions src/onnx/parse_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -30,6 +30,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -42,18 +43,199 @@ struct parse_convolution : op_parser<parse_convolution>
return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}};
}

// Shifting a (u)int8 tensor directly can overflow/lose accuracy, so the
// shift is carried out in half precision and the result is converted back
// to int8.
static instruction_ref add_int8_shift(const onnx_parser::node_info& info,
                                      const instruction_ref& offset_op,
                                      instruction_ref& unshifted_input)
{
    auto as_half = info.add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}),
        unshifted_input);

    auto shifted_half = info.add_common_op("add", as_half, offset_op);

    return info.add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}),
        shifted_half);
}

// Re-center the input into int8 range; when a zero point accompanies it
// (has_bias), that zero point must be shifted by the same offset.
static void shift_input_and_bias(const onnx_parser::node_info& info,
                                 const instruction_ref& offset_op,
                                 const bool has_bias,
                                 instruction_ref& input,
                                 instruction_ref& input_bias)
{
    input = add_int8_shift(info, offset_op, input);
    if(not has_bias)
        return;
    input_bias = add_int8_shift(info, offset_op, input_bias);
}

// The zero-point value that represents "no shift" for the input's type:
// 128 for uint8 (midpoint of the unsigned range), 0 otherwise.
static float get_symmetric_value(const instruction_ref& input)
{
    const bool is_uint8 = input->get_shape().type() == migraphx::shape::uint8_type;
    return is_uint8 ? 128.0f : 0.0f;
}

// For quantized convolutions, materialize a one-element literal holding the
// symmetric zero-point value for the input's type; otherwise pass the input
// through untouched.
static instruction_ref gen_symmetric_literal(const instruction_ref& input,
                                             const bool is_quant_conv,
                                             onnx_parser::node_info& info)
{
    if(not is_quant_conv)
        return input;

    const float symmetric_value = get_symmetric_value(input);
    return info.add_literal(migraphx::literal{
        migraphx::shape{input->get_shape().type(), {1}, {0}}, {symmetric_value}});
}

// Fetch the zero point for `input` from args[index] when the ONNX node
// supplies one, validating that its type matches the data type. A supplied
// zero point that is symmetric, or an absent one, is replaced by a synthetic
// symmetric literal (for quantized convolutions only).
static instruction_ref get_zero_point(const instruction_ref& input,
                                      int index,
                                      const bool is_quant_conv,
                                      onnx_parser::node_info& info,
                                      const std::vector<instruction_ref>& args)
{
    instruction_ref ret = input;
    // Cast avoids a signed/unsigned mismatch between `int index` and
    // std::vector::size_type (-Wsign-compare).
    if(args.size() > static_cast<std::size_t>(index))
    {
        // Check for type mismatch on parse
        if(input->get_shape().type() != args[index]->get_shape().type())
            MIGRAPHX_THROW("PARSE:Conv Data and Data Zero Point must have same type");

        ret = args[index];
        if(is_symmetric_zero_point(ret))
        {
            ret = gen_symmetric_literal(ret, is_quant_conv, info);
        }
    }
    else
    {
        ret = gen_symmetric_literal(ret, is_quant_conv, info);
    }

    return ret;
}

// True when the zero point can be evaluated at parse time and every element
// equals the type's symmetric value (128 for uint8, 0 otherwise), meaning no
// zero-point correction is needed.
static bool is_symmetric_zero_point(instruction_ref zp)
{
    if(not zp->can_eval())
        return false;

    const float symmetric_value = get_symmetric_value(zp);

    bool symmetric = false;
    zp->eval().visit([&](auto z) {
        symmetric = std::all_of(z.begin(), z.end(), [&](auto val) {
            return float_equal(val, symmetric_value);
        });
    });
    return symmetric;
}

static auto
qparam_broadcast_op(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
{
if(qparam->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}

// Apply the zero-point correction terms to a quantized convolution result:
//   conv(x - x_zp, w - w_zp) = conv(x, w) - conv(x_zp, w) - conv(x, w_zp)
//                              + conv(x_zp, w_zp)
// Terms whose zero point is symmetric vanish and are skipped.
static instruction_ref handle_quant_bias(const operation& op,
                                         const instruction_ref& input,
                                         const instruction_ref& x,
                                         const instruction_ref& weights,
                                         const instruction_ref& x_zp,
                                         const instruction_ref& w_zp,
                                         onnx_parser::node_info& info)
{
    instruction_ref ret = input;
    // Evaluate each zero point once: is_symmetric_zero_point() evaluates the
    // literal and walks all of its elements, so the original repeated calls
    // (up to four per zero point) duplicated that work.
    const bool x_zp_asymmetric = not is_symmetric_zero_point(x_zp);
    const bool w_zp_asymmetric = not is_symmetric_zero_point(w_zp);

    if(x_zp_asymmetric)
    {
        auto out_zp_1 = info.add_common_op(op.name(), x_zp, weights);
        ret           = info.add_common_op("sub", ret, out_zp_1);
    }

    if(w_zp_asymmetric)
    {
        auto out_zp_2 = info.add_common_op(op.name(), x, w_zp);
        ret           = info.add_common_op("sub", ret, out_zp_2);
    }

    // Cross term only matters when both zero points are asymmetric.
    if(x_zp_asymmetric and w_zp_asymmetric)
    {
        auto x_zp_bc =
            info.add_instruction(qparam_broadcast_op(x_zp, x->get_shape().lens(), 0), x_zp);
        auto w_zp_bc = info.add_instruction(
            qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0), w_zp);

        auto out_zp_3 = info.add_instruction(op, x_zp_bc, w_zp_bc);

        ret = info.add_common_op("add", ret, out_zp_3);
    }
    return ret;
}

// For quantized convolutions with uint8 operands, re-center the operands
// (and their zero points, when asymmetric) into int8 range by adding -128.
// Non-quantized convolutions are left untouched.
static void handle_quant_inputs(const bool is_quant_conv,
                                instruction_ref& input,
                                instruction_ref& weights,
                                instruction_ref& input_zp,
                                instruction_ref& weight_zp,
                                onnx_parser::node_info& info)
{
    if(not is_quant_conv)
        return;

    const bool input_is_u8  = input->get_shape().type() == migraphx::shape::uint8_type;
    const bool weight_is_u8 = weights->get_shape().type() == migraphx::shape::uint8_type;

    // Shared -128 offset literal, created only when some operand needs it.
    instruction_ref offset_op;
    if(input_is_u8 or weight_is_u8)
    {
        offset_op = info.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {-128}});
    }

    if(input_is_u8)
    {
        shift_input_and_bias(
            info, offset_op, (not is_symmetric_zero_point(input_zp)), input, input_zp);
    }

    if(weight_is_u8)
    {
        shift_input_and_bias(
            info, offset_op, (not is_symmetric_zero_point(weight_zp)), weights, weight_zp);
    }
}

instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto op = make_op(opd.op_name);
auto values = op.to_value();
auto l0 = args[0];
auto weights = args[1];
auto l0_shape = l0->get_shape();
auto w_shape = weights->get_shape();
auto in_lens = l0_shape.max_lens();
auto op = make_op(opd.op_name);
auto values = op.to_value();
auto x = args[0];
auto weights = args[1];
auto x_shape = x->get_shape();
auto w_shape = weights->get_shape();
auto in_lens = x_shape.max_lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;

Expand Down Expand Up @@ -92,9 +274,9 @@ struct parse_convolution : op_parser<parse_convolution>

// check if image shape is dynamic
bool image_shape_dynamic = false;
if(l0_shape.dynamic())
if(x_shape.dynamic())
{
auto dyn_dims = l0_shape.dyn_dims();
auto dyn_dims = x_shape.dyn_dims();
std::for_each(dyn_dims.begin() + 2, dyn_dims.end(), [&](auto dyn_dim) {
if(not dyn_dim.is_fixed())
{
Expand Down Expand Up @@ -149,9 +331,31 @@ struct parse_convolution : op_parser<parse_convolution>

recalc_conv_attributes(values, kdims);

instruction_ref ret;
// parse a_zero_point and b_zero_point values
auto is_quant_conv = opd.op_name == "quant_convolution";

auto x_zp = get_zero_point(x, 2, is_quant_conv, info, args);
auto w_zp = get_zero_point(weights, 3, is_quant_conv, info, args);

op.from_value(values);
auto l1 = info.add_instruction(op, l0, args[1]);
return info.add_bias(args, l1, 1);

handle_quant_inputs(is_quant_conv, x, weights, x_zp, w_zp, info);

ret = info.add_instruction(op, x, weights);

// Handle quant_conv residuals between input/weights to avoid overflow
if(is_quant_conv)
{
ret = handle_quant_bias(op, ret, x, weights, x_zp, w_zp, info);
}
else
{
// Handle Convolution case with bias to output
ret = info.add_bias(args, ret, 1);
}

return ret;
}
};

Expand Down
9 changes: 9 additions & 0 deletions src/targets/gpu/compile_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ static bool use_lazy_inner(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
// When the inputs are broadcasted, it means the lambda will capture SGPRs
// when doing block/wave reduction. This can cause register spilling in
// the compiler when the lambda is evaluated at a later time although it
// shouldn't. Instead, use `inner` to workaround this issue in the
// compiler.
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](instruction_ref input) {
return input->get_shape().broadcasted();
}))
return false;
auto output = ins->outputs().front();
return contains(output->name(), "reduce") or output->name() == "@return";
}
Expand Down
28 changes: 28 additions & 0 deletions test/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,34 @@ TEST_CASE(scalar_input)
EXPECT(p1 == p2);
}

// A literal that is only scalar-like (multibroadcast of a one-element
// literal) should be folded into the fused pointwise module just like a true
// scalar input.
TEST_CASE(scalar_like_input)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 1}};
    migraphx::program p1;
    {
        auto* mm = p1.get_main_module();
        auto x   = mm->add_parameter("x", s);
        auto one = mm->add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1.0f}});
        auto bcast_one =
            mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), one);
        auto sum = mm->add_instruction(migraphx::make_op("add"), x, bcast_one);
        mm->add_return({sum});
    }
    run_pass(p1);
    migraphx::program p2;
    {
        auto* mm = p2.get_main_module();
        auto x   = mm->add_parameter("x", s);
        auto sum = add_pointwise(p2, "main:pointwise0", {x}, [=](auto* pm, const auto& inputs) {
            auto one = pm->add_literal(1.0f);
            return pm->add_instruction(migraphx::make_op("add"), inputs[0], one);
        });
        mm->add_return({sum});
    }
    EXPECT(p1 == p2);
}

TEST_CASE(contiguous_input)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
Expand Down
29 changes: 29 additions & 0 deletions test/onnx/conv_bad_bias_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
 conv_bad_bias_test:�
8
0
1
23"Conv*
dilations@@�*
strides@@�conv_bad_bias_testZ
0




 Z
1




Z
2


b
3




B
6 changes: 3 additions & 3 deletions test/onnx/convinteger_bias_test.onnx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
convinteger_bias_test:�
 convinteger_bias_test:�
?
0
1
Expand All @@ -19,11 +19,11 @@
Z
2



b
3




B
B
Loading

0 comments on commit ed8220a

Please sign in to comment.