Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify accelerator and non-accelator paths #1544

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef MLIR_LIB_DIALECT_ROCK_TRANSFORMS_MLIR_ACCEL_EMITTER_H
#define MLIR_LIB_DIALECT_ROCK_TRANSFORMS_MLIR_ACCEL_EMITTER_H

#include "mlir/Dialect/Rock/IR/FmaInsnGroup.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Rock/IR/MfmaInsnGroup.h"
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
Expand Down Expand Up @@ -86,7 +87,7 @@ struct AccelEmitter {
/// Select the right accelerator based on the set of features and architecture
static std::unique_ptr<AccelEmitter>
select(GemmFeatures features, Type dataTypeA, Type dataTypeB, StringRef arch,
RockAccelTuningParamAttrInterface tuningParams);
RockTuningParamAttrInterface tuningParams);

/// Emit the actual intrinsic in the threadwise operation
virtual void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA,
Expand Down Expand Up @@ -150,15 +151,15 @@ struct AccelEmitter {

virtual ~AccelEmitter() {}

enum AccelEmitterKind { AEK_MFMAEmitter, AEK_WMMAEmitter };
enum AccelEmitterKind { AEK_MFMAEmitter, AEK_WMMAEmitter, AEK_FMAEmitter};

AccelEmitterKind getKind() const { return kind; }

protected:
AccelEmitter(StringRef arch, RockAccelTuningParamAttrInterface tuningParams,
AccelEmitter(StringRef arch, RockTuningParamAttrInterface tuningParams,
AccelEmitterParams accelEmitterParams, AccelEmitterKind kind);

RockAccelTuningParamAttrInterface tuningParams;
RockTuningParamAttrInterface tuningParams;
AccelEmitterParams accelEmitterParams;
int64_t waveSize;

Expand All @@ -170,7 +171,7 @@ struct AccelEmitter {
struct MfmaEmitter : public AccelEmitter {

MfmaEmitter(MfmaInsnGroup mfmaGroup, StringRef arch,
RockAccelTuningParamAttrInterface tuningParams);
RockTuningParamAttrInterface tuningParams);

void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB,
Value bufferC, ValueRange regCOffset) override;
Expand Down Expand Up @@ -208,7 +209,7 @@ struct MfmaEmitter : public AccelEmitter {
/// Initialize the emitter parameters for mfma
AccelEmitterParams
initAccelEmitterParams(MfmaInsnGroup mfmaGroup,
RockAccelTuningParamAttrInterface tuningParams);
RockTuningParamAttrInterface tuningParams);

MfmaInsnGroup mfmaGroup;
};
Expand All @@ -217,7 +218,7 @@ struct MfmaEmitter : public AccelEmitter {
struct WmmaEmitter : public AccelEmitter {

WmmaEmitter(WmmaInsn wmmaInsn, StringRef arch,
RockAccelTuningParamAttrInterface tuningParams);
RockTuningParamAttrInterface tuningParams);

void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB,
Value bufferC, ValueRange regCOffset) override;
Expand Down Expand Up @@ -249,11 +250,50 @@ struct WmmaEmitter : public AccelEmitter {
/// Initialize the emitter parameters for wmma
AccelEmitterParams
initAccelEmitterParams(WmmaInsn wmmaInsn,
RockAccelTuningParamAttrInterface tuningParams);
RockTuningParamAttrInterface tuningParams);

// Specifc wmma parameters
WmmaInsn wmmaInsn;
};

// Accel emitter implementation for fma

struct FmaEmitter : public AccelEmitter {

FmaEmitter(FmaInsn fmaInsn, StringRef arch,
RockTuningParamAttrInterface tuningParams);

void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB,
Value bufferC, ValueRange regCOffset) override;

virtual Value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to repeat virtual here

wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
int64_t blockSize, int64_t dInCopyPerThread,
StringRef dName, bool rotateDWithK,
bool doSplitKAcrossThreadsFirst = false) const override;

virtual RegsAsMatrixSubTiles createAccelGemmOperandTransforms(
OpBuilder &b, Location loc, int64_t kIters,
ArrayRef<int64_t> bidGridLengths, int64_t blockSize,
int64_t dInCopyPerThread, StringRef dName, bool isKContigousDim,
bool rotateDWithK,
bool doSplitKAcrossThreadsFirst = false) const override;

RegsAsMatrixSubTiles computeOutputTransforms(
OpBuilder &b, Location loc, int64_t mLen, int64_t nLen, int64_t blockSize,
ArrayRef<int64_t> bidGridLengths, int64_t inMPerThread,
int64_t inNPerThread, bool doSwapThreadIterSubDimsForM = false,
bool doSwapThreadIterSubDimsForN = false) override;

private:
// Initialize the emitter parameters for fma
AccelEmitterParams
initAccelEmitterParams(FmaInsn fmaInsn,
RockTuningParamAttrInterface tuningParams);

// Specific fma parameters
FmaInsn fmaInsn;
};
} // namespace accel
} // namespace rock
} // namespace mlir
Expand Down
39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/FmaInsnGroup.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- FmaInsnGroup.h - MLIR to C++ for Rock conversion
//---------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// This file implements code selection logic for Fma instructions.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_FMA_INSN_GROUP_H
#define MLIR_FMA_INSN_GROUP_H

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringMap.h"

namespace mlir {
namespace rock {

struct FmaInsn {
Type argTypeA;
Type argTypeB;
Type retType;

public:
static FailureOr<FmaInsn> select(Type elementTypeA, Type elementTypeB, StringRef arch);
};



} // namespace rock
} // namespace mlir

#endif // MLIR_FMA_INSN_GROUP_H
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: newline

28 changes: 23 additions & 5 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,10 @@ def Rock_GridwiseGemmOp :
MemRefRankOf<GemmInputTypes, [3]>:$b,
MemRefRankOf<GemmAccumulatorTypes, [3]>:$c,
Rock_GemmFeaturesAttr:$features,
StrAttr:$arch,
I32Attr:$numCU,
I32Attr:$gridSize,
Rock_GeneralGemmParamsAttr:$params)> {
RockTuningParamAttrInterface:$params)> {
let summary = "Gridwise GEMM";
let description = [{
The `rock.gridwise_gemm` op computes gridwise GEMM.
Expand All @@ -415,7 +416,7 @@ def Rock_GridwiseGemmAccelOp :
StoreMethodAttr:$storeMethod,
I32Attr:$blockSize,
I32Attr:$gridSize,
RockAccelTuningParamAttrInterface:$params)> {
RockTuningParamAttrInterface:$params)> {
let summary = "Gridwise GEMM accelerated version";
let description = [{
The `rock.gridwise_gemm` op computes gridwise GEMM with acceleration.
Expand Down Expand Up @@ -1116,7 +1117,9 @@ def Rock_BlockwiseGemmOp:
I32Attr:$inNPerThread,
UnitAttr:$rotateMWithK,
UnitAttr:$rotateNWithK,
Rock_GeneralGemmParamsAttr:$params
StrAttr:$arch,
Rock_GemmFeaturesAttr:$features,
RockTuningParamAttrInterface:$params
)> {
let summary = "Blockwise GEMM non accelerated version";
let description = [{
Expand Down Expand Up @@ -1168,7 +1171,7 @@ def Rock_BlockwiseGemmAccelOp:
StrAttr:$arch,
Rock_GemmFeaturesAttr:$features,
I32Attr:$blockSize,
RockAccelTuningParamAttrInterface:$params)>{
RockTuningParamAttrInterface:$params)>{
let summary = "Blockwise GEMM accelerated version";
let description = [{
The `rock.block_gemm_v2` op does GEMM at workgroup (block) level.
Expand Down Expand Up @@ -1215,7 +1218,7 @@ def Rock_ThreadwiseAccelGemmOp:
Arg<MemRefOf<NativeMemoryOpTypes>, "dest register view C", [MemRead, MemWrite]>:$matrixC, Variadic<Index>:$computeIndices,
StrAttr:$arch,
Rock_GemmFeaturesAttr:$features,
RockAccelTuningParamAttrInterface:$params)> {
RockTuningParamAttrInterface:$params)> {
let summary = "Accelerated GEMM";
let description = [{
The `rock.accel_gemm` op is an abstraction of doing GEMM based on an accelerator.
Expand All @@ -1229,6 +1232,21 @@ def Rock_ThreadwiseAccelGemmOp:
}];
let hasVerifier = 1;
}
// threadwise_gemmv2
def Rock_ThreadwiseGemmOpv2:
Rock_Op<"threadwise_gemmv2">,
Arguments<(ins Arg<MemRefOf<NativeMemoryOpTypes>, "source register view A", [MemRead]>:$matrixA,
Arg<MemRefOf<NativeMemoryOpTypes>, "source register view B", [MemRead]>:$matrixB,
Arg<MemRefOf<NativeMemoryOpTypes>, "dest register view C", [MemRead, MemWrite]>:$matrixC, Variadic<Index>:$computeIndices,
StrAttr:$arch,
Rock_GemmFeaturesAttr:$features,
RockTuningParamAttrInterface:$params)> {
let assemblyFormat = [{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a temporary thing, right? Why's this needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is that this OP will replace threadwise_gemm and threadwise_accel_gemm. I am creating a new operation so I don't have to delete the existing code.

$matrixC `+` `` `=` $matrixA `*` $matrixB `at` `[` $computeIndices `]` `features` `=` $features attr-dict
`:` type($matrixC) `+` `` `=` type($matrixA) `*` type($matrixB)
}];
let hasVerifier = 1;
}

// blockwise_broadcasting_reduction
def Rock_BlockwiseBroadcastReduceOp:
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Rock/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_rocmlir_dialect_library(MLIRRockOps
RockWriterOpInterface.cpp
MfmaInsnGroup.cpp
WmmaInsnGroup.cpp
FmaInsnGroup.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rock
Expand Down
36 changes: 36 additions & 0 deletions mlir/lib/Dialect/Rock/IR/FmaInsnGroup.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "mlir/Dialect/Rock/IR/FmaInsnGroup.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Rock/utility/AmdArchDb.h"
#include "mlir/Dialect/Rock/utility/math.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <cstdint>

#define DEBUG_TYPE "rock-fma-insn-group"

using namespace mlir;
using namespace mlir::rock;

static Type getRetType(Type inputType) {
Builder b(inputType.getContext());
if (inputType.isInteger(8))
return b.getI32Type();

return b.getF32Type();;
}

FailureOr<FmaInsn> FmaInsn::select(mlir::Type elementTypeA, mlir::Type elementTypeB, StringRef arch ){
LLVM_DEBUG(llvm::dbgs() << "Invoke FMA group selection:\n"
<< "elementTypeA: " << elementTypeA << "\n"
<< "elementTypeB: " << elementTypeB << "\n"
<< "arch: " << arch << "\n");

Type retType = getRetType(elementTypeA);

return FmaInsn{elementTypeA, elementTypeB, retType};
}
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,14 @@ LogicalResult ThreadwiseGemmOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ThreadwiseGemmOpv2
//===----------------------------------------------------------------------===//
LogicalResult ThreadwiseGemmOpv2::verify() {
//TO-DO
return success();
}

//===----------------------------------------------------------------------===//
// ThreadwiseAccelGemmOp
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ struct BlockwiseGemmRewritePattern
int64_t mC = bufferCType.getShape()[0];
int64_t nC = bufferCType.getShape()[1];

GeneralGemmParamsAttr params = op.getParams();
GeneralGemmParamsAttr params = op.getParams().cast<GeneralGemmParamsAttr>();
uint32_t blockSize = params.getBlockSize();
int64_t kPerThread = params.getKPerThread();
int64_t mPerThread = params.getMPerThread();
Expand Down Expand Up @@ -382,8 +382,8 @@ struct BlockwiseGemmRewritePattern
Value reshapedBRegisters = reshapeBuffer(
b, loc, threadBAllocOp, {"k", "n", "kpack"}, {kPerThread, nC, kPack});
// Actually do the gemm - this goes inside the look over kOffset
b.create<ThreadwiseGemmOp>(loc, reshapedARegisters, reshapedBRegisters,
op.getMatrixC());
b.create<ThreadwiseGemmOpv2>(loc, reshapedARegisters, reshapedBRegisters, op.getMatrixC(),
ValueRange{zeroConstantOp,zeroConstantOp,zeroConstantOp,zeroConstantOp}, op.getArchAttr(), op.getFeaturesAttr(), op.getParamsAttr());

return success();
}
Expand All @@ -402,7 +402,7 @@ struct BlockwiseGemmAccelRewritePattern
Location loc = op.getLoc();

StringAttr arch = op.getArchAttr();
RockAccelTuningParamAttrInterface tuningParams = op.getParams();
RockAccelTuningParamAttrInterface tuningParams = op.getParams().cast<RockAccelTuningParamAttrInterface>();
int64_t kpackPerBlock = tuningParams.getKpackPerBlock();
int64_t mPerWave = tuningParams.getMPerWave();
int64_t nPerWave = tuningParams.getNPerWave();
Expand Down Expand Up @@ -503,7 +503,8 @@ struct BlockwiseGemmAccelRewritePattern
Value viewC = accelEmitterPtr->generateThreadwiseViewBufferC(
b, loc, adaptor.getMatrixC());
Value k = kLoop.getInductionVar();
b.create<ThreadwiseAccelGemmOp>(loc, viewA, viewB, viewC,

b.create<ThreadwiseGemmOpv2>(loc, viewA, viewB, viewC,
ValueRange{i, j, k}, arch,
op.getFeaturesAttr(), tuningParams);
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,10 @@ GemmRewritePattern::matchAndRewrite(GemmOp op, GemmOpAdaptor adaptor,
rw.create<GridwiseGemmAccelOp>(
loc, a, b, accumulator, op.getArchAttr(), numCUAttr,
op.getFeaturesAttr(), op.getStoreMethodAttr(), blockSize, gridSize,
params.cast<RockAccelTuningParamAttrInterface>());
params.cast<RockTuningParamAttrInterface>());
} else {
rw.create<GridwiseGemmOp>(loc, a, b, accumulator, op.getFeaturesAttr(),
numCUAttr, gridSize,
op.getArchAttr(), numCUAttr, gridSize,
params.cast<GeneralGemmParamsAttr>());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<GridwiseGemmOp> {

// Obtain critical tuning parameters.
uint32_t gridSize = op.getGridSize();
GeneralGemmParamsAttr tuningParams = op.getParams();
GeneralGemmParamsAttr tuningParams = op.getParams().cast<GeneralGemmParamsAttr>();
int64_t kpack = tuningParams.getKpack();
// TODO: kPerBlock, as defined in parameter selection etc,
// is in units of kPack, not individual k. This should be changed
Expand Down Expand Up @@ -639,7 +639,8 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<GridwiseGemmOp> {
b.getI32IntegerAttr(copyMPerThread),
b.getI32IntegerAttr(copyNPerThread),
rotateMWithK ? b.getUnitAttr() : nullptr,
rotateNWithK ? b.getUnitAttr() : nullptr, op.getParamsAttr());
rotateNWithK ? b.getUnitAttr() : nullptr,
op.getArchAttr(), op.getFeaturesAttr(), op.getParamsAttr());

// LDS barrier.
// This barrier prevents halo part of outputs having weird values.
Expand Down Expand Up @@ -2427,7 +2428,7 @@ struct GridwiseGemmAccelRewritePattern
StringRef arch = op.getArch();
uint32_t blockSize = op.getBlockSize();
uint32_t gridSize = op.getGridSize();
RockAccelTuningParamAttrInterface tuningParams = op.getParams();
RockAccelTuningParamAttrInterface tuningParams = op.getParams().cast<RockAccelTuningParamAttrInterface>();
int64_t kpack = tuningParams.getKpack();
// TODO: kPerBlock, as defined in parameter selection etc,
// is in units of kPack, not individual k. This should be changed
Expand Down
Loading