This repository has been archived by the owner on Jun 27, 2024. It is now read-only.

Added 64-bit support for CUDA Calls. #147

Open. bviyer wants to merge 2 commits into base: main

@bviyer bviyer commented Jun 15, 2023

The IREE compiler was converting 64-bit data types to 32-bit, and to support that, several CUDA function calls were converted to 32-bit floating point (`float`). This patch reverses that change and allows the 64-bit data type (`double`).

cudnn.graph @conv2d(%x: !cudnn.tensor<8x32x4x4xf32, NHWC>,
%w: !cudnn.tensor<32x32x1x1xf32, KHWC>)
-> !cudnn.tensor<8x32x4x4xf32, NHWC> {
cudnn.graph @conv2d(%x: !cudnn.tensor<8x32x4x4xf64, NHWC>,

Tensors should stay in f32; only alpha/beta should be updated to f64.

@bviyer bviyer Jun 15, 2023

Fixed. Pretty much reverted this file; see af73d6b.
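To make the review point above concrete, here is a minimal host-side sketch of the split being asked for: tensor elements stay `float` (f32) while the `alpha`/`beta` scaling factors are `double` (f64). The function name `blendF32WithF64Scalars` and the `std::vector` storage are hypothetical and exist only for illustration; this is not the cuDNN or IREE API, just the usual `alpha * result + beta * prior` blend written out by hand.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, for illustration only (not part of cuDNN or IREE).
// Tensor data stays in f32; only the scaling factors are f64.
// Computes y[i] = alpha * x[i] + beta * y[i] over the shared prefix of x and y.
inline void blendF32WithF64Scalars(const std::vector<float>& x,
                                   std::vector<float>& y,
                                   double alpha, double beta) {
  const std::size_t n = x.size() < y.size() ? x.size() : y.size();
  for (std::size_t i = 0; i < n; ++i) {
    // Accumulate in double, then narrow back to the f32 element type.
    const double blended = alpha * static_cast<double>(x[i]) +
                           beta * static_cast<double>(y[i]);
    y[i] = static_cast<float>(blended);
  }
}
```

With `x = {1, 2}`, `y = {10, 20}`, `alpha = 2.0`, and `beta = 0.5`, the call leaves `y` equal to `{7, 14}`. Only the scalar parameters changed width; the element type of both tensors is still `float`.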

@@ -576,13 +576,13 @@ struct ConvertCudnnBinaryOp : public CudnnOpConversionPattern<T> {
MLIRContext *ctx = rewriter.getContext();
ImplicitLocOpBuilder b(op->getLoc(), rewriter);

auto f32 = rewriter.getF32Type();
auto newType = rewriter.getF64Type();

newType => f64 (and in a few other places as well)

@bviyer bviyer Jun 15, 2023

Fixed (see af73d6b).

@ezhulenev

This still fails at IREE head with `INTERNAL; import function signature mismatch between module and source cudnn; expected 0rfrfi_r but got 0rFrFi_r; resolving module 'module' imports; creating VM context; creating run context`. What is the PR that fixes the problem on the IREE side?
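For readers decoding the error above, the two signature strings differ only in letter case at two positions. The sketch below compares them character by character. It assumes the IREE VM calling-convention letters mean `i` = i32, `I` = i64, `f` = f32, `F` = f64, `r` = ref, and `v` = void; that mapping is an assumption here and should be checked against the IREE sources. The helper names `cconvTypeName` and `diffCconv` are hypothetical.

```cpp
#include <cstddef>
#include <string>
#include <string_view>
#include <vector>

// ASSUMPTION: letter meanings follow the IREE VM calling convention as
// understood here (lowercase = 32-bit, uppercase = 64-bit). Verify upstream.
inline std::string cconvTypeName(char c) {
  switch (c) {
    case 'i': return "i32";
    case 'I': return "i64";
    case 'f': return "f32";
    case 'F': return "f64";
    case 'r': return "ref";
    case 'v': return "void";
    default:  return std::string("unknown(") + c + ")";
  }
}

// Hypothetical helper: returns one readable line per position at which the
// two signature strings disagree.
inline std::vector<std::string> diffCconv(std::string_view expected,
                                          std::string_view actual) {
  std::vector<std::string> diffs;
  if (expected.size() != actual.size()) {
    diffs.push_back("length mismatch");
    return diffs;
  }
  for (std::size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != actual[i]) {
      diffs.push_back("position " + std::to_string(i) + ": expected " +
                      cconvTypeName(expected[i]) + ", got " +
                      cconvTypeName(actual[i]));
    }
  }
  return diffs;
}
```

Calling `diffCconv("0rfrfi_r", "0rFrFi_r")` reports mismatches at positions 2 and 4, each `f32` versus `f64`. Under the assumed letter mapping, that matches the two scalar parameters (alpha/beta) whose width this PR changes, so the module and the cudnn source disagree until both sides use the same width.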

@bviyer

bviyer commented Jun 16, 2023

> This still fails at IREE head with `INTERNAL; import function signature mismatch between module and source cudnn; expected 0rfrfi_r but got 0rFrFi_r; resolving module 'module' imports; creating VM context; creating run context`. What is the PR that fixes the problem on the IREE side?

iree-org/iree#14114

@bviyer bviyer requested a review from ezhulenev June 16, 2023 16:16