
FP8 2D forward convolution using rocMLIR #2507

Merged
merged 128 commits into from
Dec 7, 2023
Conversation

umangyadav
Member

@umangyadav umangyadav commented Dec 3, 2023

FP8 convolutions that use rocMLIR.

Uses rocMLIR at SHA 5085343bca363109ae9ebabb7ca2b65c52bc861c.

ROCm/rocMLIR#1336

The regular convolution takes both inputs as FP8 and generates FP8 output. Internally, the hardware accumulates in FP32, and the final result is downcast back to FP8.

quant_convolution takes both inputs as FP8 and generates FP32 output. This version can use the QDQ quantization scheme, applying scales to downcast the FP32 output to FP8.
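As a rough illustration of the QDQ downcast step described above: divide the FP32 output by a scale, saturate to the FP8 dynamic range, and rescale. This is a hypothetical sketch, not MIGraphX or rocMLIR code; it assumes the OCP e4m3 format (max finite value 448) and only models saturation, not actual FP8 rounding, which the hardware performs.

```python
import numpy as np

# Max finite magnitude of OCP e4m3 FP8 (assumption for illustration;
# AMD's e4m3fnuz variant has a different range).
FP8_E4M3_MAX = 448.0

def qdq_downcast(fp32_out: np.ndarray, scale: float) -> np.ndarray:
    """Simplified QDQ: scale FP32 values toward FP8 range, saturate,
    then rescale back. Real FP8 rounding is omitted."""
    scaled = fp32_out / scale
    saturated = np.clip(scaled, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return saturated * scale

x = np.array([1000.0, -3.5, 0.25], dtype=np.float32)
print(qdq_downcast(x, scale=2.0))  # 1000/2 = 500 saturates to 448 -> 896.0
```

A well-chosen scale keeps most values inside the representable range so only outliers saturate.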

rocMLIR FP8 convolutions are limited to 2D forward convolutions, so tests have been added only for those.
Tests for 3D convolutions, 1D convolutions, and backward convolutions (transposed convolutions) are therefore not enabled.

Convert fusion with mlir-convolution does not compile on non-MI300 hardware, so it has been disabled.

For now, I've kept both versions of the convolution (quant and regular). If there turns out to be no use for the regular FP8 convolution, it can be removed later.

Testing: make check passes on MI300.

depends on #2473

@migraphx-bot
Collaborator

migraphx-bot commented Dec 3, 2023


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ torchvision-inceptionv3_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ slim-vgg16_1: PASSED: MIGraphX meets tolerance

     ✅ slim-mobilenet_1: PASSED: MIGraphX meets tolerance

     ✅ slim-inceptionv4_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

🔴bert_base_cased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large_uncased_fp16: PASSED: MIGraphX meets tolerance

     ✅ bert_large: PASSED: MIGraphX meets tolerance

🔴distilgpt2_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

requirements.txt Outdated
@umangyadav umangyadav changed the base branch from develop to rocblas_fp8 December 6, 2023 00:07
Base automatically changed from rocblas_fp8 to develop December 6, 2023 01:20
@TedThemistokleous
Collaborator

LGTM

@causten causten merged commit 6a72e8f into develop Dec 7, 2023
44 checks passed
@causten causten deleted the rocblas_mlir_fp8 branch December 7, 2023 02:49
Labels
FP8 issues related to FP8 implementation
Projects
None yet
Development


5 participants