Skip to content

Commit

Permalink
pnnx match onnx zeros ones (#5832)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Dec 19, 2024
1 parent ad6d84e commit 57fac4a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level2/torch_full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,6 @@ pnnx.Output output 1 0 out
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_onnx, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_onnx, 21)

} // namespace pnnx
26 changes: 26 additions & 0 deletions tools/pnnx/src/pass_level2/torch_ones.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,30 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_ones, 20)

class torch_ones_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 size
ConstantOfShape op_0 1 1 size out value=1.0
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.ones";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
{
op->params["dtype"] = Parameter();
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_ones_onnx, 20)

} // namespace pnnx
26 changes: 26 additions & 0 deletions tools/pnnx/src/pass_level2/torch_zeros.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,30 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_zeros, 20)

class torch_zeros_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 size
ConstantOfShape op_0 1 1 size out value=0.0
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.zeros";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
{
op->params["dtype"] = Parameter();
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_zeros_onnx, 20)

} // namespace pnnx

0 comments on commit 57fac4a

Please sign in to comment.