Added 64-bit support for CUDA Calls. #147
base: main
The IREE compiler was converting 64-bit data types to 32-bit, and to accommodate that, several CUDA function calls had been changed to use 32-bit floating point (`float`). This patch reverses that change and allows the 64-bit type (`double`).
- cudnn.graph @conv2d(%x: !cudnn.tensor<8x32x4x4xf32, NHWC>,
-                     %w: !cudnn.tensor<32x32x1x1xf32, KHWC>)
-                     -> !cudnn.tensor<8x32x4x4xf32, NHWC> {
+ cudnn.graph @conv2d(%x: !cudnn.tensor<8x32x4x4xf64, NHWC>,
Tensors should stay in f32; only alpha/beta should be updated to f64.
Fixed; I pretty much reverted this file. See af73d6b.
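For readers following the thread, the distinction between tensor element type and scaling-factor type can be sketched in plain C++. This is only an illustration under stated assumptions: it is not the cuDNN API and not code from this patch, and the helper name `scaleAndAccumulate` is hypothetical. It keeps the tensor data in `float` (f32) while taking `alpha` and `beta` as `double` (f64).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of the patch or of cuDNN: blends two f32
// buffers using f64 scaling factors, out[i] = alpha * x[i] + beta * y[i].
std::vector<float> scaleAndAccumulate(double alpha, const std::vector<float> &x,
                                      double beta, const std::vector<float> &y) {
  assert(x.size() == y.size() && "operands must have the same element count");
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    // The arithmetic runs in double because alpha and beta are double;
    // the result is narrowed back to f32 for storage.
    out[i] = static_cast<float>(alpha * x[i] + beta * y[i]);
  }
  return out;
}
```

Calling `scaleAndAccumulate(2.0, {1.0f, 2.0f}, 0.5, {4.0f, 8.0f})` returns a `std::vector<float>`, so only the scalars are widened to 64 bits and the tensor storage type is unchanged.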
@@ -576,13 +576,13 @@ struct ConvertCudnnBinaryOp : public CudnnOpConversionPattern<T> {
     MLIRContext *ctx = rewriter.getContext();
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);

-    auto f32 = rewriter.getF32Type();
+    auto newType = rewriter.getF64Type();
Rename `newType` to `f64` (here and in a few other places as well).
Fixed (see af73d6b).
This still fails at IREE head with