diff --git a/src/targets/cpu/include/migraphx/cpu/dnnl.hpp b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp
index 821411eae81..d063bd34f25 100644
--- a/src/targets/cpu/include/migraphx/cpu/dnnl.hpp
+++ b/src/targets/cpu/include/migraphx/cpu/dnnl.hpp
@@ -91,118 +91,25 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
     }
 };
 
+template <class F>
+struct execute_wrapper
+{
+    F f;
+    argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
+};
+
+template <class F>
+execute_wrapper<F> make_execute_wrapper(F f)
+{
+    return {std::move(f)};
+}
+
 template <class Derived, class Primitive>
 struct dnnl_op : auto_register_op<Derived>
 {
     std::vector<post_op> post_ops;
     std::function<argument(context&, const std::vector<argument>& args)> execute;
-    class executable
-    {
-        std::unordered_map<int, dnnl::memory::desc> md;
-        Primitive prim;
-        std::vector<int> arg_lookup;
-#ifdef _DEBUG
-        const dnnl_op& self;
-        const Derived& derived;
-        std::string name;
-        dnnl::primitive_attr prim_attr;
-        const std::vector<shape>& inputs;
-        const shape& output_shape;
-#endif
-    public:
-        // clang-format off
-        executable(const dnnl_op& op, const shape& out_shape, const std::vector<shape>& in_shapes)
-            : md{op.to_memory_desc(out_shape, in_shapes)},
-              prim{op.get_primitive(md)},
-              arg_lookup{op.create_arg_map(in_shapes.size())}
-#ifdef _DEBUG
-              , self{op},
-              derived{static_cast<const Derived&>(op)},
-              name{derived.name()},
-              prim_attr{op.get_primitive_attr(md)},
-              inputs{in_shapes},
-              output_shape{out_shape}
-#endif
-        // clang-format on
-        {
-        }
-
-        argument operator()(context&, const std::vector<argument>& args)
-        {
-#ifdef _DEBUG
-            // Check that the memory descriptors have not changed
-            auto debug_args = args;
-            debug_args.pop_back();
-            auto debug_md = self.to_memory_desc(output_shape, to_shapes(debug_args));
-            for(auto&& p : debug_md)
-            {
-                if(md.count(p.first) == 0)
-                    MIGRAPHX_THROW(name +
-                                   ": Missing memory descriptor for: " + std::to_string(p.first));
-                if(p.second == md.at(p.first))
-                    continue;
-                MIGRAPHX_THROW(name +
-                               ": Memory descriptor has changed for: " + std::to_string(p.first));
-            }
-            // Check post_ops args are correct
-            auto pos             = prim_attr.get_post_ops();
-            auto prim_input_size = inputs.size() - self.get_extra_post_op_args();
-            int j                = 0;
-            for(int i = 0; i < pos.len(); i++)
-            {
-                auto arg  = j + prim_input_size;
-                auto kind = pos.kind(i);
-                std::string mesg =
-                    "Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
-                try
-                {
-                    dnnl::algorithm algo;
-                    dnnl::memory::desc mdesc;
-                    float scale = 0;
-                    float alpha = 0;
-                    float beta  = 0;
-                    if(kind == dnnl::primitive::kind::binary)
-                    {
-                        pos.get_params_binary(i, algo, mdesc);
-                        if(mdesc != md.at(arg_lookup.at(arg)))
-                            MIGRAPHX_THROW(mesg +
-                                           "Memory descriptor doesn't match for binary post op");
-                        j++;
-                    }
-                    else if(kind == dnnl::primitive::kind::eltwise)
-                    {
-                        pos.get_params_eltwise(i, scale, algo, alpha, beta);
-                    }
-                    else if(kind == dnnl::primitive::kind::sum)
-                    {
-                        pos.get_params_sum(i, scale);
-                        algo = dnnl::algorithm::binary_add;
-                    }
-                    else
-                    {
-                        MIGRAPHX_THROW("Unknown kind");
-                    }
-                    if(to_dnnl_algo(self.post_ops[i].algo) != algo)
-                        MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
-                                       self.post_ops[i].algo + " != " + to_string(algo));
-                }
-                catch(const dnnl::error& e)
-                {
-                    MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what());
-                }
-            }
-#endif
-            std::unordered_map<int, dnnl::memory> m;
-            m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
-                to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
-            for(int i = 0; i < args.size() - 1; i++)
-                m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
-            prim.execute(get_dnnl_context().stream, m);
-            return args.back();
-        }
-    };
-
     template <class Self, class F>
     static auto
     reflect_base(Self& self, F f)
     {
@@ -406,7 +313,86 @@ struct dnnl_op : auto_register_op<Derived>
     {
         // Compensate for allocation
        inputs.pop_back();
-        execute = executable{*this, output_shape, inputs};
+        const auto& self = static_cast<const Derived&>(*this);
+        auto name       = self.name();
+        auto md         = to_memory_desc(output_shape, inputs);
+        auto prim       = get_primitive(md);
+        auto arg_lookup = create_arg_map(inputs.size());
+#ifndef NDEBUG
+        auto prim_attr = get_primitive_attr(md);
+#endif
+        execute = make_execute_wrapper([=](const std::vector<argument>& args) {
+#ifndef NDEBUG
+            // Check that the memory descriptors have not changed
+            auto debug_args = args;
+            debug_args.pop_back();
+            auto debug_md = to_memory_desc(output_shape, to_shapes(debug_args));
+            for(auto&& p : debug_md)
+            {
+                if(md.count(p.first) == 0)
+                    MIGRAPHX_THROW(name +
+                                   ": Missing memory descriptor for: " + std::to_string(p.first));
+                if(p.second == md.at(p.first))
+                    continue;
+                MIGRAPHX_THROW(name +
+                               ": Memory descriptor has changed for: " + std::to_string(p.first));
+            }
+            // Check post_ops args are correct
+            auto pos             = prim_attr.get_post_ops();
+            auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
+            int j                = 0;
+            for(int i = 0; i < pos.len(); i++)
+            {
+                auto arg  = j + prim_input_size;
+                auto kind = pos.kind(i);
+                std::string mesg =
+                    "Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
+                try
+                {
+                    dnnl::algorithm algo;
+                    dnnl::memory::desc mdesc;
+                    float scale = 0;
+                    float alpha = 0;
+                    float beta  = 0;
+                    if(kind == dnnl::primitive::kind::binary)
+                    {
+                        pos.get_params_binary(i, algo, mdesc);
+                        if(mdesc != md.at(arg_lookup.at(arg)))
+                            MIGRAPHX_THROW(mesg +
+                                           "Memory descriptor doesn't match for binary post op");
+                        j++;
+                    }
+                    else if(kind == dnnl::primitive::kind::eltwise)
+                    {
+                        pos.get_params_eltwise(i, scale, algo, alpha, beta);
+                    }
+                    else if(kind == dnnl::primitive::kind::sum)
+                    {
+                        pos.get_params_sum(i, scale);
+                        algo = dnnl::algorithm::binary_add;
+                    }
+                    else
+                    {
+                        MIGRAPHX_THROW("Unknown kind");
+                    }
+                    if(to_dnnl_algo(post_ops[i].algo) != algo)
+                        MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
+                                       post_ops[i].algo + " != " + to_string(algo));
+                }
+                catch(const dnnl::error& e)
+                {
+                    MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what());
+                }
+            }
+#endif
+            std::unordered_map<int, dnnl::memory> m;
+            m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
+                to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
+            for(int i = 0; i < args.size() - 1; i++)
+                m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
+            prim.execute(get_dnnl_context().stream, m);
+            return args.back();
+        });
     }
     std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
     {
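
This diff replaces the hand-written `executable` functor class with a lambda: everything that used to live in the class (the memory-descriptor map, the primitive, the argument lookup table) is computed once in the finalize step and captured, and `execute_wrapper` merely adapts the one-argument lambda to the two-argument signature of the `execute` member. As a standalone illustration of that pattern, here is a minimal sketch; `plan_add` and the `context*` placeholder are hypothetical names for this example only, not part of MIGraphX or oneDNN:

```cpp
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

struct context; // opaque stand-in for the execution-context parameter

// Adapts a one-argument callable to the two-argument signature stored in
// the std::function member, discarding the context it does not need.
template <class F>
struct execute_wrapper
{
    F f;
    int operator()(context*, const std::vector<int>& args) const { return f(args); }
};

template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
    return {std::move(f)};
}

// Hypothetical planning step: do the setup once (here just fixing an
// offset), capture the result, and hand back a cheap callable.
std::function<int(context*, const std::vector<int>&)> plan_add(int offset)
{
    return make_execute_wrapper([=](const std::vector<int>& args) {
        return std::accumulate(args.begin(), args.end(), offset);
    });
}

int main()
{
    auto execute = plan_add(10);
    std::cout << execute(nullptr, {1, 2, 3}) << '\n'; // prints 16
}
```

The design point is that the expensive work happens once in the planning step, while each call through the stored `std::function` only runs the captured body; this mirrors how the diff builds `md`, `prim`, and `arg_lookup` before constructing the lambda, leaving only the debug checks and memory binding on the execution path.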