Skip to content

Commit

Permalink
Fixes to parse DynamicQuantizeLinear (#2896)
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous authored and Chris Austen committed Mar 19, 2024
1 parent aee01f8 commit 233c382
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 36 deletions.
44 changes: 22 additions & 22 deletions src/onnx/parse_dynamicquantizelinear.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -98,41 +98,41 @@ struct parse_dynamicquantizelinear : op_parser<parse_dynamicquantizelinear>
if(x_shape.dynamic())
MIGRAPHX_THROW("DYNAMICQUANTIZELINEAR: dynamic shapes are not supported");

auto x_reshaped =
(x_shape.lens().size() == 1)
? x
: info.add_instruction(
migraphx::make_op("reshape", {{"dims", {x_shape.elements()}}}), x);

auto lit_0 = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0}});
x_reshaped =
info.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, lit_0);

// 1. Computing y_scale
// Note: currently, DynamicQuantizeLinear only has uint8 quantization:
const auto x_max = std::numeric_limits<uint8_t>::max();
const auto x_min = std::numeric_limits<uint8_t>::min();

auto q_range =
info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_max - x_min}});
const auto type_max = std::numeric_limits<uint8_t>::max();
const auto type_min = std::numeric_limits<uint8_t>::min();
std::vector<size_t> axes(x_shape.lens().size());
std::iota(axes.begin(), axes.end(), 0);

// maximum(0, max(x))
auto max_x =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
auto reduce_max_x =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", axes}}), x);
auto max_x = info.add_common_op("max", lit_0, reduce_max_x);

// minimum(0, min(x))
auto min_x =
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);
auto reduce_min_x =
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", axes}}), x);
auto min_x = info.add_common_op("min", lit_0, reduce_min_x);

auto q_range = info.add_literal(migraphx::literal{
migraphx::shape{x_type, max_x->get_shape().lens()}, {type_max - type_min}});
auto q_min = info.add_literal(
migraphx::literal{migraphx::shape{x_type, max_x->get_shape().lens()}, {type_min}});
auto q_max = info.add_literal(
migraphx::literal{migraphx::shape{x_type, max_x->get_shape().lens()}, {type_max}});

// y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
auto sub0 = info.add_common_op("sub", max_x, min_x);
auto y_scale = info.add_common_op("div", sub0, q_range);

// 2. Computing y_zero_point
// intermediate_zero_point = qmin - min(x) / y_scale
auto q_min = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_min}});
auto q_max = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {x_max}});
auto sub1 = info.add_common_op("sub", q_min, min_x);
auto interm_zp = info.add_common_op("div", sub1, y_scale);
auto div1 = info.add_common_op("div", min_x, y_scale);
auto interm_zp = info.add_common_op("sub", q_min, div1);

// y_zero_point = cast(round(saturate(itermediate_zero_point)))
auto saturate = info.add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max);
auto round = info.add_instruction(migraphx::make_op("nearbyint"), saturate);
Expand Down
34 changes: 20 additions & 14 deletions test/onnx/parse/dynamicquantizelinear_2d_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -32,25 +32,31 @@ TEST_CASE(dynamicquantizelinear_2d_test)
auto x_type = migraphx::shape::float_type;
auto x = mm->add_parameter("x", {x_type, x_dims});

auto l0 = mm->add_literal({0.f});
auto x_reshaped = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), x);
x_reshaped = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, l0);
auto l0 = mm->add_literal({0.f});

auto q_range = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::max()}});
std::vector<size_t> axes(x->get_shape().lens().size());
std::iota(axes.begin(), axes.end(), 0);

auto max_x = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
auto min_x = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);
auto reduce_max_x = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", axes}}), x);
auto max_x = add_common_op(*mm, migraphx::make_op("max"), {l0, reduce_max_x});

auto reduce_min_x = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", axes}}), x);
auto min_x = add_common_op(*mm, migraphx::make_op("min"), {l0, reduce_min_x});

auto q_range = mm->add_literal(migraphx::literal{
migraphx::shape{x_type, max_x->get_shape().lens()},
{std::numeric_limits<uint8_t>::max() - std::numeric_limits<uint8_t>::min()}});
auto q_min = mm->add_literal(migraphx::literal{
migraphx::shape{x_type, max_x->get_shape().lens()}, {std::numeric_limits<uint8_t>::min()}});
auto q_max = mm->add_literal(migraphx::literal{
migraphx::shape{x_type, min_x->get_shape().lens()}, {std::numeric_limits<uint8_t>::max()}});

auto sub0 = mm->add_instruction(migraphx::make_op("sub"), max_x, min_x);
auto y_scale = mm->add_instruction(migraphx::make_op("div"), sub0, q_range);

auto q_min = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::min()}});
auto q_max = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::max()}});
auto sub1 = mm->add_instruction(migraphx::make_op("sub"), q_min, min_x);
auto interm_zp = mm->add_instruction(migraphx::make_op("div"), sub1, y_scale);
auto div1 = add_common_op(*mm, migraphx::make_op("div"), {min_x, y_scale});
auto interm_zp = add_common_op(*mm, migraphx::make_op("sub"), {q_min, div1});

auto saturate = mm->add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max);
auto round = mm->add_instruction(migraphx::make_op("nearbyint"), saturate);
auto y_zero_point = mm->add_instruction(
Expand Down

0 comments on commit 233c382

Please sign in to comment.