
Rewrite reduce mean/variance #2883

Merged · 49 commits · Apr 27, 2024
Conversation

pfultz2 (Collaborator) commented Mar 12, 2024

Rewrites mean/variance to use reduce_mean(x) and reduce_mean(x*x) so both can be fused into the same reduction.

@pfultz2 pfultz2 requested a review from causten as a code owner March 12, 2024 21:14
migraphx-bot (Collaborator) commented Mar 12, 2024

| Test | Batch | Rate new (52fe8e) | Rate old (ee68f7) | Diff |
|---|---|---|---|---|
| torchvision-resnet50 | 64 | 2,824.97 | 2,821.87 | 0.11% |
| torchvision-resnet50_fp16 | 64 | 6,405.88 | 6,407.36 | -0.02% |
| torchvision-densenet121 | 32 | 2,092.78 | 2,096.07 | -0.16% |
| torchvision-densenet121_fp16 | 32 | 3,703.81 | 3,687.09 | 0.45% |
| torchvision-inceptionv3 | 32 | 1,604.11 | 1,605.13 | -0.06% |
| torchvision-inceptionv3_fp16 | 32 | 2,555.85 | 2,551.45 | 0.17% |
| cadene-inceptionv4 | 16 | 718.11 | 717.52 | 0.08% |
| cadene-resnext64x4 | 16 | 680.69 | 680.72 | -0.00% |
| slim-mobilenet | 64 | 5,947.46 | 5,944.92 | 0.04% |
| slim-nasnetalarge | 64 | 154.18 | 154.07 | 0.07% |
| slim-resnet50v2 | 64 | 2,589.24 | 2,583.14 | 0.24% |
| bert-mrpc-onnx | 8 | 920.32 | 921.72 | -0.15% |
| bert-mrpc-tf | 1 | 396.61 | 397.23 | -0.16% |
| pytorch-examples-wlang-gru | 1 | 403.27 | 395.90 | 1.86% |
| pytorch-examples-wlang-lstm | 1 | 428.73 | 374.49 | 14.48% 🔆 |
| torchvision-resnet50_1 | 1 | 608.81 | 609.69 | -0.14% |
| cadene-dpn92_1 | 1 | 390.69 | 390.37 | 0.08% |
| cadene-resnext101_1 | 1 | 332.01 | 333.18 | -0.35% |
| onnx-taau-downsample | 1 | 306.56 | 307.34 | -0.25% |
| dlrm-criteoterabyte | 1 | 28.87 | 28.87 | 0.02% |
| dlrm-criteoterabyte_fp16 | 1 | 48.27 | 48.32 | -0.10% |
| agentmodel | 1 | 7,222.46 | 7,330.30 | -1.47% |
| unet_fp16 | 2 | 57.55 | 57.54 | 0.02% |
| resnet50v1_fp16 | 1 | 906.34 | 911.10 | -0.52% |
| resnet50v1_int8 | 1 | 816.74 | 813.46 | 0.40% |
| bert_base_cased_fp16 | 64 | 1,034.77 | 1,033.99 | 0.08% |
| bert_large_uncased_fp16 | 32 | 300.48 | 300.56 | -0.03% |
| bert_large_fp16 | 1 | 159.95 | 160.04 | -0.05% |
| distilgpt2_fp16 | 16 | 1,854.32 | 1,854.56 | -0.01% |
| yolov5s | 1 | 475.96 | 474.48 | 0.31% |
| tinyllama | 1 | 32.99 | 32.97 | 0.04% |
| vicuna-fastchat | 1 | 157.86 | 158.01 | -0.10% |
| whisper-tiny-encoder | 1 | 348.54 | 346.15 | 0.69% |
| whisper-tiny-decoder | 1 | 396.66 | 396.01 | 0.16% |

Check results before merge 🔆

migraphx-bot (Collaborator) commented Mar 12, 2024


- ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
- ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
- ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
- ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
- ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
- ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
- ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
- ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
- ✅ agentmodel: PASSED: MIGraphX meets tolerance
- ✅ unet: PASSED: MIGraphX meets tolerance
- ✅ resnet50v1: PASSED: MIGraphX meets tolerance
- ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
- 🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
- ✅ bert_large: PASSED: MIGraphX meets tolerance
- ✅ yolov5s: PASSED: MIGraphX meets tolerance
- ✅ tinyllama: PASSED: MIGraphX meets tolerance
- ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
- ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
- ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
- ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@pfultz2 pfultz2 requested a review from a team as a code owner March 29, 2024 00:33
```cpp
});
auto preduce = m.insert_instruction(last, parallel_reduce{op}, inputs);
int i = 0;
for(auto reduce : reduces)
```
Collaborator:

Might be a slight nit-pick in naming, but rename "reduces" to something like "reductions". That makes this a bit more readable and makes it clearer when you grab a reduce from the list of reductions.

Same goes for "preduce": rename it to parallel_reductions, since preduce might be read as pointer-to-reduce.

Collaborator (Author):

I use the name reduce since the operator is named reduce. It seems confusing to call it reduction when the operator is not named reduction. I guess I could make these single letters so you can interpret them however makes sense to you.

```cpp
}
EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(reduce_reduce_mismatch_axis)
```
Collaborator:

This test case is good, and I'm sure it can be massaged to create the other tests that handle cases where matching fails, as well as some of the other coverage warnings.

TedThemistokleous (Collaborator) left a comment:

  • Add additional test coverage in rewrite_reduce.cpp
  • Improve readability in prepare_reduce.cpp

TedThemistokleous (Collaborator) left a comment:
Looks good.

@pfultz2 pfultz2 merged commit 56d341d into develop Apr 27, 2024
39 of 42 checks passed
@pfultz2 pfultz2 deleted the reduce-mean-variance branch April 27, 2024 02:19
5 participants