
Commit
incorporate review feedback
apwojcik committed Oct 8, 2023
1 parent 37549f3 commit d635594
Showing 1 changed file with 93 additions and 107 deletions.
src/targets/cpu/include/migraphx/cpu/dnnl.hpp (200 changes: 93 additions & 107 deletions)
@@ -91,118 +91,25 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
}
};

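// Type-erasing adapter: wraps a callable that only needs the argument list and
// supplies the unused context parameter expected by the execute signature.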
template <class F>
struct execute_wrapper
{
F f;
argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};

template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
return {std::move(f)};
}

template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived>
{
std::vector<post_op> post_ops;
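// Type-erased execute function: assigned once the shapes are known and invoked
// on every evaluation.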
std::function<argument(context& ctx, const std::vector<argument>& args)> execute;

class executable
{
std::unordered_map<int, dnnl::memory::desc> md;
Primitive prim;
std::vector<int> arg_lookup;
#ifdef _DEBUG
const dnnl_op& self;
const Derived& derived;
std::string name;
dnnl::primitive_attr prim_attr;
const std::vector<shape>& inputs;
const shape& output_shape;
#endif
public:
// clang-format off
executable(const dnnl_op& op, const shape& out_shape, const std::vector<shape>& in_shapes)
: md{op.to_memory_desc(out_shape, in_shapes)},
prim{op.get_primitive(md)},
arg_lookup{op.create_arg_map(in_shapes.size())}
#ifdef _DEBUG
, self{op},
derived{static_cast<const Derived&>(op)},
name{derived.name()},
prim_attr{op.get_primitive_attr(md)},
inputs{in_shapes},
output_shape{out_shape}
#endif
// clang-format on
{
}

argument operator()(context&, const std::vector<argument>& args)
{
#ifdef _DEBUG
// Check that the memory descriptors have not changed
auto debug_args = args;
debug_args.pop_back();
auto debug_md = self.to_memory_desc(output_shape, to_shapes(debug_args));
for(auto&& p : debug_md)
{
if(md.count(p.first) == 0)
MIGRAPHX_THROW(name +
": Missing memory descriptor for: " + std::to_string(p.first));
if(p.second == md.at(p.first))
continue;
MIGRAPHX_THROW(name +
": Memory descriptor has changed for: " + std::to_string(p.first));
}
// Check post_ops args are correct
auto pos = prim_attr.get_post_ops();
auto prim_input_size = inputs.size() - self.get_extra_post_op_args();
int j = 0;
for(int i = 0; i < pos.len(); i++)
{
auto arg = j + prim_input_size;
auto kind = pos.kind(i);
std::string mesg =
"Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
try
{
dnnl::algorithm algo;
dnnl::memory::desc mdesc;
float scale = 0;
float alpha = 0;
float beta = 0;
if(kind == dnnl::primitive::kind::binary)
{
pos.get_params_binary(i, algo, mdesc);
if(mdesc != md.at(arg_lookup.at(arg)))
MIGRAPHX_THROW(mesg +
"Memory descriptor doesn't match for binary post op");
j++;
}
else if(kind == dnnl::primitive::kind::eltwise)
{
pos.get_params_eltwise(i, scale, algo, alpha, beta);
}
else if(kind == dnnl::primitive::kind::sum)
{
pos.get_params_sum(i, scale);
algo = dnnl::algorithm::binary_add;
}
else
{
MIGRAPHX_THROW("Unknown kind");
}
if(to_dnnl_algo(self.post_ops[i].algo) != algo)
MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
self.post_ops[i].algo + " != " + to_string(algo));
}
catch(const dnnl::error& e)
{
MIGRAPHX_THROW(mesg + "Failed to get post ops argument: " + e.what());
}
}
#endif
std::unordered_map<int, dnnl::memory> m;
m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
for(int i = 0; i < args.size() - 1; i++)
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
return args.back();
}
};

template <class Self, class F>
static auto reflect_base(Self& self, F f)
{
@@ -406,7 +313,86 @@ struct dnnl_op : auto_register_op<Derived>
{
// Compensate for allocation
inputs.pop_back();
execute = executable{*this, output_shape, inputs};
const auto& self = static_cast<const Derived&>(*this);
auto name = self.name();
auto md = to_memory_desc(output_shape, inputs);
auto prim = get_primitive(md);
auto arg_lookup = create_arg_map(inputs.size());
#ifndef NDEBUG
auto prim_attr = get_primitive_attr(md);
#endif
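// Everything the lambda needs (descriptors, primitive, argument map) is built
// once here and captured by value, so each call only executes the primitive.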
execute = make_execute_wrapper([=](const std::vector<argument>& args) {
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto debug_args = args;
debug_args.pop_back();
auto debug_md = to_memory_desc(output_shape, to_shapes(debug_args));
for(auto&& p : debug_md)
{
if(md.count(p.first) == 0)
MIGRAPHX_THROW(name +
": Missing memory descriptor for: " + std::to_string(p.first));
if(p.second == md.at(p.first))
continue;
MIGRAPHX_THROW(name +
": Memory descriptor has changed for: " + std::to_string(p.first));
}
// Check post_ops args are correct
auto pos = prim_attr.get_post_ops();
auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
int j = 0;
for(int i = 0; i < pos.len(); i++)
{
auto arg = j + prim_input_size;
auto kind = pos.kind(i);
std::string mesg =
"Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
try
{
dnnl::algorithm algo;
dnnl::memory::desc mdesc;
float scale = 0;
float alpha = 0;
float beta = 0;
if(kind == dnnl::primitive::kind::binary)
{
pos.get_params_binary(i, algo, mdesc);
if(mdesc != md.at(arg_lookup.at(arg)))
MIGRAPHX_THROW(mesg +
"Memory descriptor doesn't match for binary post op");
j++;
}
else if(kind == dnnl::primitive::kind::eltwise)
{
pos.get_params_eltwise(i, scale, algo, alpha, beta);
}
else if(kind == dnnl::primitive::kind::sum)
{
pos.get_params_sum(i, scale);
algo = dnnl::algorithm::binary_add;
}
else
{
MIGRAPHX_THROW("Unknown kind");
}
if(to_dnnl_algo(post_ops[i].algo) != algo)
MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
post_ops[i].algo + " != " + to_string(algo));
}
catch(const dnnl::error& e)
{
MIGRAPHX_THROW(mesg + "Failed to get post ops argument: " + e.what());
}
}
#endif
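// Bind the destination buffer and each input to its DNNL argument id, then
// execute the primitive on the DNNL stream and return the output argument.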
std::unordered_map<int, dnnl::memory> m;
m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
for(int i = 0; i < args.size() - 1; i++)
m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
prim.execute(get_dnnl_context().stream, m);
return args.back();
});
}
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{
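For context, the pattern adopted above (a small wrapper template plus a factory
function feeding a std::function member) can be exercised in isolation. The
sketch below is illustrative only: context, argument, and the captured string
are simplified stand-ins for the migraphx types, not the real API.

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for the migraphx types (hypothetical, for illustration).
struct context {};
using argument = std::string;

// Mirrors execute_wrapper from the diff: adapts a callable that only needs the
// arguments to the (context, arguments) signature.
template <class F>
struct execute_wrapper
{
    F f;
    argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};

// Factory so the callable's type is deduced, as in the diff.
template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
    return {std::move(f)};
}

int main()
{
    // State computed once up front and captured by value, the way the commit
    // captures the compiled primitive and memory descriptors.
    std::string prefix = "result: ";
    std::function<argument(context&, const std::vector<argument>&)> execute =
        make_execute_wrapper([=](const std::vector<argument>& args) {
            return prefix + args.back(); // reuse the captured state on every call
        });

    context ctx;
    std::cout << execute(ctx, {"a", "b"}) << "\n"; // prints "result: b"
    return 0;
}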
