Add GroupNorm and LayerNorm onnx parsing #2242
Conversation
Codecov Report
```diff
@@           Coverage Diff            @@
##           develop    #2242   +/-  ##
===========================================
+ Coverage    91.30%   91.33%   +0.03%
===========================================
  Files          436      438       +2
  Lines        16345    16425      +80
===========================================
+ Hits         14923    15001      +78
- Misses        1422     1424       +2
```
Force-pushed from fd748bb to 39571f1.
Would be good to add that, but I'm not sure about the priority. We already have a matcher that pattern-matches primitives to LayerNorm. Perhaps it can be tweaked to match GroupNorm as well.
Force-pushed from 39571f1 to 59bb7b1.
Looks good overall. It doesn't support dynamic shapes; we'll have to make updates for that later on. Please add some fp16 `verify_onnx` tests.
```cpp
if(skipped_axes > 0)
{
    auto x_dims = x_shape.lens();
    scale_bcast = info.add_instruction(
        make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), scale);
    if(not skip_bias)
    {
        bias_bcast = info.add_instruction(
            make_op("broadcast", {{"axis", skipped_axes}, {"out_lens", x_dims}}), bias);
    }
}
```
Note that this will not work if `x_shape` is dynamic. That is probably outside the scope of this PR, though, so it should be fine.
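The effect of the `broadcast` op in the snippet above can be illustrated with a small pure-Python sketch. This is illustrative only: `broadcast_1d` is a hypothetical helper that assumes the op places a 1-D tensor along `axis` of the output shape and repeats its values over every other axis.

```python
import itertools

def broadcast_1d(vec, axis, out_lens):
    """Place 1-D `vec` along `axis` of an output of shape `out_lens`,
    repeating its values over every other axis (an assumed model of
    broadcast with {"axis": axis, "out_lens": out_lens} on 1-D input)."""
    assert len(vec) == out_lens[axis]
    return {idx: vec[idx[axis]]
            for idx in itertools.product(*(range(n) for n in out_lens))}

# A per-channel scale of length 2 broadcast into a (2, 2, 3) tensor at axis 1;
# each output value depends only on the axis-1 index.
scale = [10.0, 20.0]
out = broadcast_1d(scale, 1, (2, 2, 3))
print(out[(0, 0, 2)])  # 10.0
print(out[(1, 1, 0)])  # 20.0
```

This is why `skipped_axes` is passed as the `axis`: the 1-D scale/bias line up with the first non-skipped axis of `x`.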
Force-pushed from 59bb7b1 to 7b39645.
This PR extends normalization support with GroupNorm and LayerNorm.
Since the two are very similar in implementation and usage, they are grouped into one PR.
I saw that there was already support for LayerNorm. This PR extends the parsing of the ONNX node, but still uses built-in primitives for the calculation.
Similar support for GroupNorm has not been added yet. Is it needed? dnnl supports it. If so, I would recommend a follow-up PR instead of extending this one.
Also, MIOpen has just introduced LayerNorm as a primitive, but it is in beta.
Resolves migraphx-benchmark#94 and migraphx-benchmark#100
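The similarity mentioned above can be sketched in a few lines of pure Python (a simplified model, not the MIGraphX implementation): GroupNorm is essentially LayerNorm applied independently to each group of channels, with the affine scale/bias applied element-wise afterwards.

```python
import math

def layer_norm(x, scale, bias, eps=1e-5):
    """Normalize `x` to zero mean / unit variance, then apply scale and bias."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * s + b for v, s, b in zip(x, scale, bias)]

def group_norm(x, scale, bias, num_groups, eps=1e-5):
    """GroupNorm over a flat channel vector: LayerNorm per channel group.
    (Simplification: spatial dimensions are folded away.)"""
    size = len(x) // num_groups
    out = []
    for g in range(num_groups):
        lo, hi = g * size, (g + 1) * size
        out += layer_norm(x[lo:hi], scale[lo:hi], bias[lo:hi], eps)
    return out

print(layer_norm([1.0, 2.0, 3.0], [1.0] * 3, [0.0] * 3))
# roughly [-1.2247, 0.0, 1.2247]
```

This shared structure is also why a single primitive-matcher could plausibly cover both ops, as suggested earlier in the thread.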