Commit eb2b348: Merge branch 'openxla:main' into shark

saienduri authored Feb 19, 2024
2 parents 76479f0 + da98215
Showing 31 changed files with 816 additions and 359 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -219,6 +219,7 @@ iree_compiler_cc_library(
":PassesIncGen",
# Dialects
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",

1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -238,6 +238,7 @@ iree_cc_library(
MLIRVectorTransforms
iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Codegen::TransformStrategies::Common::TransformStrategies

@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h"
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
@@ -65,6 +66,7 @@ void registerTransformDialectTranslationDependentDialects(
mlir::iree_compiler::IREE::VectorExt::IREEVectorExtDialect,
mlir::iree_compiler::IREE::Flow::FlowDialect,
mlir::iree_compiler::IREE::Codegen::IREECodegenDialect,
mlir::iree_compiler::IREE::GPU::IREEGPUDialect,
arith::ArithDialect,
affine::AffineDialect,
bufferization::BufferizationDialect,

@@ -10,6 +10,7 @@
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

#define DEBUG_TYPE "iree-amdgpu-distribute-contract"

@@ -19,18 +20,74 @@ namespace {
using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

enum class ContractKind { MK_KN_MN, UNKNOWN };

// Gets the kind of a contract op with the given indexing |maps|.
ContractKind inferContractKind(MLIRContext *ctx, SmallVector<AffineMap> maps) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
AffineExpr m, n, k;
bindDims(ctx, m, n, k);
if (maps == infer({{m, k}, {k, n}, {m, n}}))
return ContractKind::MK_KN_MN;
return ContractKind::UNKNOWN;
}
/// A class for querying information about a contract op.
class ContractOpDetail {
public:
enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN };

explicit ContractOpDetail(vector::ContractionOp op) {
opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray());
}

OpKind getOpKind() const { return opKind; }

// Returns the (LHS M, RHS N) dimension index pair.
std::optional<std::pair<int, int>> getOperandMNIndex() const {
switch (opKind) {
case OpKind::MK_KN_MN:
return std::make_pair(0, 1);
case OpKind::MK_NK_MN:
return std::make_pair(0, 0);
case OpKind::UNKNOWN:
break;
}
return std::nullopt;
}

// Returns the (LHS K, RHS K) dimension index pair.
std::optional<std::pair<int, int>> getOperandKIndex() const {
switch (opKind) {
case OpKind::MK_KN_MN:
return std::make_pair(1, 0);
case OpKind::MK_NK_MN:
return std::make_pair(1, 1);
case OpKind::UNKNOWN:
break;
}
return std::nullopt;
}

// Returns the result (M, N) dimension index pair.
std::optional<std::pair<int, int>> getResultMNIndex() const {
switch (opKind) {
case OpKind::MK_KN_MN:
case OpKind::MK_NK_MN:
return std::make_pair(0, 1);
default:
break;
}
return std::nullopt;
}

private:
// Gets the kind of a contract op with the given indexing |maps|.
OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [&](MapList m) {
return AffineMap::inferFromExprList(m, ctx);
};
AffineExpr m, n, k;
bindDims(ctx, m, n, k);
if (maps == infer({{m, k}, {k, n}, {m, n}}))
return OpKind::MK_KN_MN;
if (maps == infer({{m, k}, {n, k}, {m, n}}))
return OpKind::MK_NK_MN;
return OpKind::UNKNOWN;
}

private:
OpKind opKind = OpKind::UNKNOWN;
};

/// Distributes `vector.contract` ops with nested layouts.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
@@ -83,9 +140,8 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
mfmaParams.blocks = mfmaAttr.getBlockSize();

// Infer the contract kind so that we know how to correlate M/N/K dims.
ContractKind contractKind =
inferContractKind(getContext(), contractOp.getIndexingMapsArray());
if (contractKind == ContractKind::UNKNOWN) {
ContractOpDetail opDetail(contractOp);
if (opDetail.getOpKind() == ContractOpDetail::OpKind::UNKNOWN) {
return rewriter.notifyMatchFailure(contractOp, "unknown contract kind");
}

Expand Down Expand Up @@ -146,7 +202,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

// Get the k batch size for LHS and RHS vector.
std::optional<int64_t> kBatch =
getKBatchSize(contractKind, lhsLayout, rhsLayout);
getKBatchSize(opDetail, lhsLayout, rhsLayout);
LLVM_DEBUG(llvm::dbgs() << "k batch size = " << kBatch << "\n");
if (!kBatch) {
return rewriter.notifyMatchFailure(contractOp,
@@ -156,15 +212,12 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
// Perform contraction by doing separate outer product with amdgpu.mfma
// operation and accumulate to the same vector.
for (int k = 0; k < kBatch; ++k) {
// Get the batch offsets for LHS and RHS. For the K dimension it's the
// Fills the batch offsets for LHS and RHS. For the K dimension it's the
// induction variable; for the M/N dimension we need to extract from the
// result batch offsets.
if (!getOperandBatchOffsets(contractKind, k, originalResultBatchOffsets,
resultLayout, lhsBatchOffsets,
rhsBatchOffsets, lhsLayout, rhsLayout)) {
return rewriter.notifyMatchFailure(
contractOp, "cannot deduce lhs/rhs batch offsets");
}
fillOperandBatchOffsets(opDetail, k, originalResultBatchOffsets,
resultLayout, lhsBatchOffsets, rhsBatchOffsets,
lhsLayout, rhsLayout);
LLVM_DEBUG({
llvm::dbgs() << "current lhs batch offsets: [";
llvm::interleaveComma(lhsBatchOffsets, llvm::dbgs());
@@ -190,47 +243,41 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
}

// Gets the batch size for matmul K dimensions.
std::optional<int64_t> getKBatchSize(ContractKind kind,
std::optional<int64_t> getKBatchSize(const ContractOpDetail &opDetail,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout) const {
int64_t lhsKBatch = 0, rhsKBatch = 0;
if (kind == ContractKind::MK_KN_MN) {
lhsKBatch = lhsLayout.getBatchesPerSubgroup()[1];
rhsKBatch = rhsLayout.getBatchesPerSubgroup()[0];
} else {
return std::nullopt;
}
auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
int64_t lhsKBatch = lhsLayout.getBatchesPerSubgroup()[lhsK];
int64_t rhsKBatch = rhsLayout.getBatchesPerSubgroup()[rhsK];

if (lhsKBatch != rhsKBatch)
return std::nullopt;
return lhsKBatch;
}

// Given a contract op's batch |resultOffsets|, gets its batch offsets for
// Given a contract op's batch |resultOffsets|, fills its batch offsets for
// both LHS and RHS.
bool getOperandBatchOffsets(ContractKind kind, int64_t kOffset,
ArrayRef<int64_t> resultOffsets,
NestedLayoutAttr resultLayout,
SmallVector<int64_t, 2> &lhsOffsets,
SmallVector<int64_t, 2> &rhsOffsets,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout) const {
void fillOperandBatchOffsets(const ContractOpDetail &opDetail,
int64_t kOffset, ArrayRef<int64_t> resultOffsets,
NestedLayoutAttr resultLayout,
SmallVector<int64_t, 2> &lhsOffsets,
SmallVector<int64_t, 2> &rhsOffsets,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout) const {
auto [lhsM, rhsN] = *opDetail.getOperandMNIndex();
auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
auto [resultM, resultN] = *opDetail.getResultMNIndex();
// resultOffsets contains batch indices into the C/D vector. It is a 2-D
// index for both M and N. We need to split out for M and N, and add index
// for K.
if (kind == ContractKind::MK_KN_MN) {
lhsOffsets[0] = resultOffsets[0];
lhsOffsets[1] = kOffset;
rhsOffsets[0] = kOffset;
rhsOffsets[1] = resultOffsets[1];
} else {
return false;
}
lhsOffsets[lhsM] = resultOffsets[resultM];
lhsOffsets[lhsK] = kOffset;
rhsOffsets[rhsN] = resultOffsets[resultN];
rhsOffsets[rhsK] = kOffset;

// Now apply permutation on LHS/RHS according to their batch order.
applyPermutationToVector(lhsOffsets, lhsLayout.getBatchOrder());
applyPermutationToVector(rhsOffsets, rhsLayout.getBatchOrder());
return true;
}

struct MFMAParameters {
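Aside from mechanical renames, the net effect of this file's changes is that the distribution pattern now recognizes a second contraction layout, MK_NK_MN, i.e. a matmul whose RHS is stored transposed (C[m, n] += A[m, k] * B[n, k]), in addition to the existing row-major MK_KN_MN case, and the batch-offset logic becomes table-driven via the getOperandMNIndex/getOperandKIndex/getResultMNIndex queries instead of hard-coding one case. Below is a minimal standalone sketch of the same indexing-map matching idea; the free function classifyContraction and its enum are hypothetical illustrations, not part of the patch.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

enum class MatmulKind { MK_KN_MN, MK_NK_MN, Unknown };

// Compares the contraction's (lhs, rhs, acc) indexing maps against the two
// canonical matmul forms, mirroring ContractOpDetail::inferOpKind above.
static MatmulKind classifyContraction(MLIRContext *ctx,
                                      SmallVector<AffineMap> maps) {
  AffineExpr m, n, k;
  bindDims(ctx, m, n, k);
  auto infer = [&](ArrayRef<ArrayRef<AffineExpr>> exprs) {
    return AffineMap::inferFromExprList(exprs, ctx);
  };
  // C[m, n] += A[m, k] * B[k, n]: plain row-major matmul.
  if (maps == infer({{m, k}, {k, n}, {m, n}}))
    return MatmulKind::MK_KN_MN;
  // C[m, n] += A[m, k] * B[n, k]: matmul with a transposed RHS.
  if (maps == infer({{m, k}, {n, k}, {m, n}}))
    return MatmulKind::MK_NK_MN;
  return MatmulKind::Unknown;
}

For the transposed-RHS kind, getOperandMNIndex() returns (0, 0) and getOperandKIndex() returns (1, 1), so fillOperandBatchOffsets writes the K induction variable into dimension 1 of both operands and copies the result's M/N batch offsets into dimension 0 of the LHS/RHS respectively, before applying each operand's batch-order permutation.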

@@ -155,7 +155,8 @@ static void populateWarpAndThreadIndices(RewriterBase &rewriter, Value threadId,
int64_t rank = vectorLayout.getBatchOrder().size();
// The delinearized thread IDs are returned from outer most to inner most,
// i.e. before applying the layout described dimensions ordering.
ValueRange threadIds = vectorLayout.computeThreadIds(threadId, rewriter);
SmallVector<Value> threadIds =
vectorLayout.computeThreadIds(threadId, rewriter);

// Subgroup and thread (lane) indices normalized to the order in which
// they are used by each dimension.
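The last hunk shown changes only the declared type of threadIds. Assuming NestedLayoutAttr::computeThreadIds returns its values by value (which the new SmallVector<Value> binding implies), the old code bound that temporary to a ValueRange, which is a non-owning view, so the range would dangle as soon as the temporary was destroyed. Below is a small sketch of the pitfall under that assumption; makeValues is a hypothetical stand-in for any such helper.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"

using namespace mlir;

// Hypothetical helper that returns its results by value.
SmallVector<Value> makeValues();

void example() {
  // Dangling: ValueRange is only a view over the temporary SmallVector, and
  // that temporary is destroyed at the end of this statement, so any later
  // read through `view` touches freed storage.
  ValueRange view = makeValues();
  (void)view;

  // Safe: own the returned storage, as the patch now does for threadIds.
  SmallVector<Value> owned = makeValues();
  (void)owned;
}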