Skip to content

Commit

Permalink
Move rewrite_low_precision after rewrite_reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
gyulaz-htec committed Mar 4, 2024
1 parent 1455b0b commit 2e0a5c3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,8 +598,9 @@ MIGRAPHX_PRED_MATCHER(same_inputs, instruction_ref ins)
{
if(ins->inputs().empty())
return false;
auto s = ins->inputs().front();
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x == s; });
auto input = ins->inputs().front();
return std::all_of(
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x == input; });
}

MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
Expand Down
2 changes: 0 additions & 2 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/rewrite_low_precision.hpp>
#include <set>

namespace migraphx {
Expand All @@ -58,7 +57,6 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
run_passes(prog,
{normalize_ops{},
rewrite_low_precision{},
optimize_module{{"quantizelinear", "dequantizelinear"}},
quantize_fp16_pass{ins_names},
optimize_module{{"quantizelinear", "dequantizelinear"}}});
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_low_precision.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_reduce.hpp>
#include <migraphx/rewrite_quantization.hpp>
Expand Down Expand Up @@ -154,6 +155,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{},
rewrite_reduce{},
rewrite_low_precision{},
dead_code_elimination{},
optimize_module{},
fuse_pointwise{},
Expand Down

0 comments on commit 2e0a5c3

Please sign in to comment.