This repository has been archived by the owner on Jun 27, 2024. It is now read-only.

Added 64-bit support for CUDA Calls. #147

Status: Open. Wants to merge 2 commits into main.
Changes from 1 commit
@@ -33,7 +33,7 @@ def CUBLAS_Dialect : Dialect {
 //===----------------------------------------------------------------------===//

 // TODO(ezhulenev): Add all supported data types.
-def CUBLAS_AnyDataType : AnyTypeOf<[F32], "any supported cuBLAS data type">;
+def CUBLAS_AnyDataType : AnyTypeOf<[F32,F64], "any supported cuBLAS data type">;

 def CUBLAS_Tensor2D : 2DTensorOf<[CUBLAS_AnyDataType]>;
@@ -71,8 +71,8 @@ def CUBLAS_GemmOp : CUBLAS_Op<"gemm", [
                               "CublasOperation::N">:$transa,
     DefaultValuedOptionalAttr<CUBLAS_OperationAttr,
                               "CublasOperation::N">:$transb,
-    DefaultValuedOptionalAttr<F32Attr, "1.0f">:$alpha,
-    DefaultValuedOptionalAttr<F32Attr, "0.0f">:$beta,
+    DefaultValuedOptionalAttr<F64Attr, "1.0f">:$alpha,
+    DefaultValuedOptionalAttr<F64Attr, "0.0f">:$beta,
     OptionalAttr<CUBLAS_GemmAlgoAttr>:$algo,
     OptionalAttr<CUBLAS_DataTypeAttr>:$computeType,
     Variadic<Index>:$argument_dims,
@@ -130,10 +130,10 @@ func::FuncOp CudnnAPI::getPointwiseUnaryFunction(PatternRewriter &rewriter,
                                                  std::string_view op) {
   MLIRContext *ctx = module->getContext();
   auto tensor = CudnnTensorType::get(ctx);
-  auto f32 = Float32Type::get(ctx);
+  auto newType = Float64Type::get(ctx);
   auto i32 = IntegerType::get(ctx, 32);

-  SmallVector<Type> args = {/*x=*/tensor, /*alpha=*/f32, /*is_virtual=*/i32};
+  SmallVector<Type> args = {/*x=*/tensor, /*alpha=*/newType, /*is_virtual=*/i32};
   SmallVector<Type> rets = {/*y=*/tensor};
   auto functionType = FunctionType::get(ctx, args, rets);

@@ -146,11 +146,11 @@ func::FuncOp CudnnAPI::getPointwiseBinaryFunction(PatternRewriter &rewriter,
                                                   std::string_view op) {
   MLIRContext *ctx = module->getContext();
   auto tensor = CudnnTensorType::get(ctx);
-  auto f32 = Float32Type::get(ctx);
+  auto newType = Float64Type::get(ctx);
   auto i32 = IntegerType::get(ctx, 32);

-  SmallVector<Type> args = {/*x=*/tensor, /*alpha=*/f32, /*b=*/tensor,
-                            /*alpha2=*/f32, /*is_virtual=*/i32};
+  SmallVector<Type> args = {/*x=*/tensor, /*alpha=*/newType, /*b=*/tensor,
+                            /*alpha2=*/newType, /*is_virtual=*/i32};
   SmallVector<Type> rets = {/*y=*/tensor};
   auto functionType = FunctionType::get(ctx, args, rets);

@@ -535,11 +535,11 @@ struct ConvertCudnnUnaryOp : public CudnnOpConversionPattern<T> {
     MLIRContext *ctx = rewriter.getContext();
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);

-    auto f32 = rewriter.getF32Type();
+    auto newType = rewriter.getF64Type();

     SmallVector<Value> args = {
         adaptor.getX(),
-        b.create<arith::ConstantFloatOp>(adaptor.getAlpha(), f32),
+        b.create<arith::ConstantFloatOp>(adaptor.getAlpha(), newType),
         b.create<arith::ConstantIntOp>(IsVirtual(op.getY()), 32),
     };

@@ -576,13 +576,13 @@ struct ConvertCudnnBinaryOp : public CudnnOpConversionPattern<T> {
     MLIRContext *ctx = rewriter.getContext();
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);

-    auto f32 = rewriter.getF32Type();
+    auto newType = rewriter.getF64Type();

[Review comment · Contributor]
newType => f64 (and in a few other places as well)

[Reply · @bviyer (Author) · Jun 15, 2023]
Fixed. (see af73d6b)
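
As a sketch of the rename the reviewer is asking for (illustrative only, not the committed code; the actual fix is in af73d6b), the variable would simply be named after the type it holds:

    // Hypothetical post-review version of the ConvertCudnnBinaryOp body:
    // identical logic, with the F64 type bound to `f64` instead of `newType`.
    auto f64 = rewriter.getF64Type();
    SmallVector<Value> args = {
        adaptor.getX(),
        b.create<arith::ConstantFloatOp>(adaptor.getAlpha(), f64),
        adaptor.getB(),
        b.create<arith::ConstantFloatOp>(adaptor.getAlpha2(), f64),
        b.create<arith::ConstantIntOp>(IsVirtual(op.getY()), 32),
    };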

     SmallVector<Value> args = {
         adaptor.getX(),
-        b.create<arith::ConstantFloatOp>(adaptor.getAlpha(), f32),
+        b.create<arith::ConstantFloatOp>(adaptor.getAlpha(), newType),
         adaptor.getB(),
-        b.create<arith::ConstantFloatOp>(adaptor.getAlpha2(), f32),
+        b.create<arith::ConstantFloatOp>(adaptor.getAlpha2(), newType),
         b.create<arith::ConstantIntOp>(IsVirtual(op.getY()), 32),
     };
10 changes: 5 additions & 5 deletions compiler/src/openxla/compiler/nvgpu/Dialect/CUDNN/IR/CUDNNOps.td
@@ -51,7 +51,7 @@ class CUDNN_UnaryOp<string mnemonic, list<Trait> traits = []> :
     CUDNN_Op<mnemonic, !listconcat([Pure], traits)>,
     Arguments<(ins
       CUDNN_TensorType:$x,
-      DefaultValuedOptionalAttr<F32Attr, "1.0f">:$alpha)>,
+      DefaultValuedOptionalAttr<F64Attr, "1.0f">:$alpha)>,
     Results<(outs CUDNN_TensorType:$y)> {
   let assemblyFormat = [{
     `(` $x `)`

@@ -78,8 +78,8 @@ class CUDNN_BinaryOp<string mnemonic, list<Trait> traits = []> :
     Arguments<(ins
       CUDNN_TensorType:$x,
       CUDNN_TensorType:$b,
-      DefaultValuedOptionalAttr<F32Attr, "1.0f">:$alpha,
-      DefaultValuedOptionalAttr<F32Attr, "1.0f">:$alpha2)>,
+      DefaultValuedOptionalAttr<F64Attr, "1.0f">:$alpha,
+      DefaultValuedOptionalAttr<F64Attr, "1.0f">:$alpha2)>,
     Results<(outs CUDNN_TensorType:$y)> {
   let assemblyFormat = [{
     `(` $x `,` $b `)`

@@ -160,8 +160,8 @@ def CUDNN_ConvolutionOp : CUDNN_Op<"convolution", [Pure]> {
   let arguments = (ins
     CUDNN_TensorType:$x,
     CUDNN_TensorType:$w,
-    F32Attr:$alpha,
-    F32Attr:$beta,
+    F64Attr:$alpha,
+    F64Attr:$beta,
     I32Attr:$spatial_dim_count,
     DenseI64ArrayAttr:$spatial_stride,
     DenseI64ArrayAttr:$pre_padding,
4 changes: 2 additions & 2 deletions runtime/src/openxla/runtime/nvgpu/cudnn/cudnn_api.cpp
@@ -392,7 +392,7 @@ StatusOr<vm::ref<CudnnTensor>> CreateTensor(

 StatusOr<iree::vm::ref<CudnnTensor>> CreatePointwiseUnary(
     openxla_cudnn_dynamic_symbols_t* syms, cudnnPointwiseMode_t mode,
-    CudnnTensor& x, float alpha, int64_t uid, int64_t alignment,
+    CudnnTensor& x, double alpha, int64_t uid, int64_t alignment,
     bool is_virtual) {
   ScopedCudnnStubs stubs(syms);

@@ -434,7 +434,7 @@ StatusOr<iree::vm::ref<CudnnTensor>> CreatePointwiseUnary(

 StatusOr<iree::vm::ref<CudnnTensor>> CreatePointwiseBinary(
     openxla_cudnn_dynamic_symbols_t* syms, cudnnPointwiseMode_t mode,
-    CudnnTensor& x, float alpha, CudnnTensor& b, float alpha2, int64_t uid,
+    CudnnTensor& x, double alpha, CudnnTensor& b, double alpha2, int64_t uid,
     int64_t alignment, bool is_virtual) {
   ScopedCudnnStubs stubs(syms);
4 changes: 2 additions & 2 deletions runtime/src/openxla/runtime/nvgpu/cudnn/cudnn_api.h
@@ -216,13 +216,13 @@ iree::StatusOr<iree::vm::ref<CudnnTensor>> CreateTensor(

 // Creates a pointwise unary operation.
 iree::StatusOr<iree::vm::ref<CudnnTensor>> CreatePointwiseUnary(
     openxla_cudnn_dynamic_symbols_t* syms, cudnnPointwiseMode_t mode,
-    CudnnTensor& x, float alpha, int64_t uid, int64_t alignment,
+    CudnnTensor& x, double alpha, int64_t uid, int64_t alignment,
     bool is_virtual);

 // Creates a pointwise binary operation.
 iree::StatusOr<iree::vm::ref<CudnnTensor>> CreatePointwiseBinary(
     openxla_cudnn_dynamic_symbols_t* syms, cudnnPointwiseMode_t mode,
-    CudnnTensor& x, float alpha, CudnnTensor& b, float alpha2, int64_t uid,
+    CudnnTensor& x, double alpha, CudnnTensor& b, double alpha2, int64_t uid,
     int64_t alignment, bool is_virtual);

 // Creates a relu operation.
16 changes: 8 additions & 8 deletions runtime/src/openxla/runtime/nvgpu/cudnn/cudnn_module.cpp
@@ -99,22 +99,22 @@ class CudnnModuleState {

   // Creates a pointwise relu operation and returns result tensor.
   StatusOr<vm::ref<CudnnTensor>> Relu(const vm::ref<CudnnTensor> input,
-                                      float lower_clip, float upper_clip,
+                                      double lower_clip, double upper_clip,
                                       int64_t uid, int64_t alignment,
                                       int32_t is_virtual);

   // Creates a pointwise unary operation and returns a result tensor.
   template <cudnnPointwiseMode_t mode>
   StatusOr<vm::ref<CudnnTensor>> PointwiseUnary(const vm::ref<CudnnTensor> x,
-                                                float alpha,
+                                                double alpha,
                                                 int32_t is_virtual);

   // Creates a pointwise binary operation and returns a result tensor.
   template <cudnnPointwiseMode_t mode>
   StatusOr<vm::ref<CudnnTensor>> PointwiseBinary(const vm::ref<CudnnTensor> x,
-                                                 float alpha,
+                                                 double alpha,
                                                  const vm::ref<CudnnTensor> b,
-                                                 float alpha2,
+                                                 double alpha2,
                                                  int32_t is_virtual);

   // Creates a bias operation and returns a result tensor.

@@ -193,23 +193,23 @@ Status CudnnModuleState::PrintGraphDebug(
 }

 StatusOr<vm::ref<CudnnTensor>> CudnnModuleState::Relu(
-    const vm::ref<CudnnTensor> input, float lower_clip, float upper_clip,
+    const vm::ref<CudnnTensor> input, double lower_clip, double upper_clip,
     int64_t uid, int64_t alignment, int32_t is_virtual) {
   return CreateRelu(syms_, *input, lower_clip, upper_clip, uid, alignment,
                     is_virtual);
 }

 template <cudnnPointwiseMode_t mode>
 StatusOr<vm::ref<CudnnTensor>> CudnnModuleState::PointwiseUnary(
-    const vm::ref<CudnnTensor> x, float alpha, int32_t is_virtual) {
+    const vm::ref<CudnnTensor> x, double alpha, int32_t is_virtual) {
   return CreatePointwiseUnary(syms_, mode, *x, alpha, uid_++, kAlignment,
                               is_virtual);
 }

 template <cudnnPointwiseMode_t mode>
 StatusOr<vm::ref<CudnnTensor>> CudnnModuleState::PointwiseBinary(
-    const vm::ref<CudnnTensor> x, float alpha, const vm::ref<CudnnTensor> b,
-    float alpha2, int32_t is_virtual) {
+    const vm::ref<CudnnTensor> x, double alpha, const vm::ref<CudnnTensor> b,
+    double alpha2, int32_t is_virtual) {
   return CreatePointwiseBinary(syms_, mode, *x, alpha, *b, alpha2, uid_++,
                                kAlignment, is_virtual);
 }
28 changes: 14 additions & 14 deletions runtime/src/openxla/runtime/nvgpu/cudnn/test/conv2d.mlir
@@ -12,34 +12,34 @@ util.initializer {
   util.initializer.return
 }

-cudnn.graph @conv2d(%x: !cudnn.tensor<8x32x4x4xf32, NHWC>,
-                    %w: !cudnn.tensor<32x32x1x1xf32, KHWC>)
-                    -> !cudnn.tensor<8x32x4x4xf32, NHWC> {
+cudnn.graph @conv2d(%x: !cudnn.tensor<8x32x4x4xf64, NHWC>,

[Review comment · Contributor]
Tensors should stay in f32; only alpha/beta should be updated to f64.

[Reply · @bviyer (Author) · Jun 15, 2023]
Fixed. Pretty much reverted this file; see af73d6b
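
For illustration, what the reviewer is asking for would look roughly like this: tensor element types stay f32 and only the alpha/beta attributes (now F64Attr) carry double-precision values. This is a sketch based on the pre-change file, not the committed fix (see af73d6b):

    cudnn.graph @conv2d(%x: !cudnn.tensor<8x32x4x4xf32, NHWC>,
                        %w: !cudnn.tensor<32x32x1x1xf32, KHWC>)
                        -> !cudnn.tensor<8x32x4x4xf32, NHWC> {
      // alpha/beta are floating-point attributes (f64 after this PR);
      // the tensor element type remains f32.
      %0 = cudnn.convolution(%x, %w) alpha=1.0 beta=0.0
             spatial_dim_count=2
             spatial_stride=[1,1]
             pre_padding=[0,0]
             post_padding=[0,0]
             dilation=[1,1]
        : (!cudnn.tensor<8x32x4x4xf32, NHWC>, !cudnn.tensor<32x32x1x1xf32, KHWC>)
        -> !cudnn.tensor<8x32x4x4xf32, NHWC>
      cudnn.return %0: !cudnn.tensor<8x32x4x4xf32, NHWC>
    }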
+                    %w: !cudnn.tensor<32x32x1x1xf64, KHWC>)
+                    -> !cudnn.tensor<8x32x4x4xf64, NHWC> {
   %0 = cudnn.convolution(%x, %w) alpha=1.0 beta=0.0
          spatial_dim_count=2
          spatial_stride=[1,1]
          pre_padding=[0,0]
          post_padding=[0,0]
          dilation=[1,1]
-    : (!cudnn.tensor<8x32x4x4xf32, NHWC>, !cudnn.tensor<32x32x1x1xf32, KHWC>)
-    -> !cudnn.tensor<8x32x4x4xf32, NHWC>
-  cudnn.return %0: !cudnn.tensor<8x32x4x4xf32, NHWC>
+    : (!cudnn.tensor<8x32x4x4xf64, NHWC>, !cudnn.tensor<32x32x1x1xf64, KHWC>)
+    -> !cudnn.tensor<8x32x4x4xf64, NHWC>
+  cudnn.return %0: !cudnn.tensor<8x32x4x4xf64, NHWC>
 }

-util.global @x : tensor<8x4x4x32xf32> = dense<1.0> : tensor<8x4x4x32xf32>
-util.global @w : tensor<32x1x1x32xf32> = dense<1.0> : tensor<32x1x1x32xf32>
+util.global @x : tensor<8x4x4x32xf64> = dense<1.0> : tensor<8x4x4x32xf64>
+util.global @w : tensor<32x1x1x32xf64> = dense<1.0> : tensor<32x1x1x32xf64>

 // CHECK: EXEC @main
 // CHECK: result[0]: hal.buffer_view
-// CHECK: 8x4x4x32xf32
+// CHECK: 8x4x4x32xf64
 // CHECK: [32 32 32 32 32
-func.func @main() -> tensor<8x4x4x32xf32> {
-  %x = util.global.load @x : tensor<8x4x4x32xf32>
-  %w = util.global.load @w : tensor<32x1x1x32xf32>
+func.func @main() -> tensor<8x4x4x32xf64> {
+  %x = util.global.load @x : tensor<8x4x4x32xf64>
+  %w = util.global.load @w : tensor<32x1x1x32xf64>
   %handle = util.global.load @handle : !cudnn.handle

   %0 = cudnn.call handle(%handle) @conv2d(%x, %w)
-    : (tensor<8x4x4x32xf32>, tensor<32x1x1x32xf32>) -> tensor<8x4x4x32xf32>
+    : (tensor<8x4x4x32xf64>, tensor<32x1x1x32xf64>) -> tensor<8x4x4x32xf64>

-  return %0 : tensor<8x4x4x32xf32>
+  return %0 : tensor<8x4x4x32xf64>
 }