-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unify accelerator and non-accelerator paths #1544
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
//===- FmaInsnGroup.h - MLIR to C++ for Rock conversion | ||
//---------------===// | ||
// | ||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
// This file declares the code selection logic for FMA instructions. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_FMA_INSN_GROUP_H | ||
#define MLIR_FMA_INSN_GROUP_H | ||
|
||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" | ||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/ADT/SmallString.h" | ||
#include "llvm/ADT/StringMap.h" | ||
|
||
namespace mlir { | ||
namespace rock { | ||
|
||
struct FmaInsn { | ||
Type argTypeA; | ||
Type argTypeB; | ||
Type retType; | ||
|
||
public: | ||
static FailureOr<FmaInsn> select(Type elementTypeA, Type elementTypeB, StringRef arch); | ||
}; | ||
|
||
|
||
|
||
} // namespace rock | ||
} // namespace mlir | ||
|
||
#endif // MLIR_FMA_INSN_GROUP_H | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: newline |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -390,9 +390,10 @@ def Rock_GridwiseGemmOp : | |
MemRefRankOf<GemmInputTypes, [3]>:$b, | ||
MemRefRankOf<GemmAccumulatorTypes, [3]>:$c, | ||
Rock_GemmFeaturesAttr:$features, | ||
StrAttr:$arch, | ||
I32Attr:$numCU, | ||
I32Attr:$gridSize, | ||
Rock_GeneralGemmParamsAttr:$params)> { | ||
RockTuningParamAttrInterface:$params)> { | ||
let summary = "Gridwise GEMM"; | ||
let description = [{ | ||
The `rock.gridwise_gemm` op computes gridwise GEMM. | ||
|
@@ -415,7 +416,7 @@ def Rock_GridwiseGemmAccelOp : | |
StoreMethodAttr:$storeMethod, | ||
I32Attr:$blockSize, | ||
I32Attr:$gridSize, | ||
RockAccelTuningParamAttrInterface:$params)> { | ||
RockTuningParamAttrInterface:$params)> { | ||
let summary = "Gridwise GEMM accelerated version"; | ||
let description = [{ | ||
The `rock.gridwise_gemm` op computes gridwise GEMM with acceleration. | ||
|
@@ -1116,7 +1117,9 @@ def Rock_BlockwiseGemmOp: | |
I32Attr:$inNPerThread, | ||
UnitAttr:$rotateMWithK, | ||
UnitAttr:$rotateNWithK, | ||
Rock_GeneralGemmParamsAttr:$params | ||
StrAttr:$arch, | ||
Rock_GemmFeaturesAttr:$features, | ||
RockTuningParamAttrInterface:$params | ||
)> { | ||
let summary = "Blockwise GEMM non accelerated version"; | ||
let description = [{ | ||
|
@@ -1168,7 +1171,7 @@ def Rock_BlockwiseGemmAccelOp: | |
StrAttr:$arch, | ||
Rock_GemmFeaturesAttr:$features, | ||
I32Attr:$blockSize, | ||
RockAccelTuningParamAttrInterface:$params)>{ | ||
RockTuningParamAttrInterface:$params)>{ | ||
let summary = "Blockwise GEMM accelerated version"; | ||
let description = [{ | ||
The `rock.block_gemm_v2` op does GEMM at workgroup (block) level. | ||
|
@@ -1215,7 +1218,7 @@ def Rock_ThreadwiseAccelGemmOp: | |
Arg<MemRefOf<NativeMemoryOpTypes>, "dest register view C", [MemRead, MemWrite]>:$matrixC, Variadic<Index>:$computeIndices, | ||
StrAttr:$arch, | ||
Rock_GemmFeaturesAttr:$features, | ||
RockAccelTuningParamAttrInterface:$params)> { | ||
RockTuningParamAttrInterface:$params)> { | ||
let summary = "Accelerated GEMM"; | ||
let description = [{ | ||
The `rock.accel_gemm` op is an abstraction of doing GEMM based on an accelerator. | ||
|
@@ -1229,6 +1232,21 @@ def Rock_ThreadwiseAccelGemmOp: | |
}]; | ||
let hasVerifier = 1; | ||
} | ||
// threadwise_gemmv2 | ||
def Rock_ThreadwiseGemmOpv2: | ||
Rock_Op<"threadwise_gemmv2">, | ||
Arguments<(ins Arg<MemRefOf<NativeMemoryOpTypes>, "source register view A", [MemRead]>:$matrixA, | ||
Arg<MemRefOf<NativeMemoryOpTypes>, "source register view B", [MemRead]>:$matrixB, | ||
Arg<MemRefOf<NativeMemoryOpTypes>, "dest register view C", [MemRead, MemWrite]>:$matrixC, Variadic<Index>:$computeIndices, | ||
StrAttr:$arch, | ||
Rock_GemmFeaturesAttr:$features, | ||
RockTuningParamAttrInterface:$params)> { | ||
let assemblyFormat = [{ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a temporary thing, right? Why's this needed? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The idea is that this op will replace threadwise_gemm and threadwise_accel_gemm. I am creating a new operation so that I don't have to delete the existing code. |
||
$matrixC `+` `` `=` $matrixA `*` $matrixB `at` `[` $computeIndices `]` `features` `=` $features attr-dict | ||
`:` type($matrixC) `+` `` `=` type($matrixA) `*` type($matrixB) | ||
}]; | ||
let hasVerifier = 1; | ||
} | ||
|
||
// blockwise_broadcasting_reduction | ||
def Rock_BlockwiseBroadcastReduceOp: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#include "mlir/Dialect/Rock/IR/FmaInsnGroup.h" | ||
|
||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" | ||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" | ||
#include "mlir/Dialect/Rock/utility/AmdArchDb.h" | ||
#include "mlir/Dialect/Rock/utility/math.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
|
||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/ErrorHandling.h" | ||
#include <cstdint> | ||
|
||
#define DEBUG_TYPE "rock-fma-insn-group" | ||
|
||
using namespace mlir; | ||
using namespace mlir::rock; | ||
|
||
/// Returns the result element type paired with `inputType`: i32 when the
/// input is an 8-bit integer, f32 for every other input type.
static Type getRetType(Type inputType) {
  Builder b(inputType.getContext());
  if (inputType.isInteger(8))
    return b.getI32Type();

  return b.getF32Type();
}
|
||
/// Builds the FmaInsn description for the given operand element types.
///
/// The result type is derived from `elementTypeA` alone via getRetType();
/// `elementTypeB` and `arch` are recorded/logged but do not influence the
/// result type here. This implementation always returns a value.
FailureOr<FmaInsn> FmaInsn::select(mlir::Type elementTypeA,
                                   mlir::Type elementTypeB, StringRef arch) {
  LLVM_DEBUG({
    llvm::dbgs() << "Invoke FMA group selection:\n"
                 << "elementTypeA: " << elementTypeA << "\n"
                 << "elementTypeB: " << elementTypeB << "\n"
                 << "arch: " << arch << "\n";
  });

  FmaInsn selected{elementTypeA, elementTypeB, getRetType(elementTypeA)};
  return selected;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to repeat
virtual
here