[LinalgExt] Implement AggregateOpInterface for AttentionOp (iree-org#18890)

- Adds AggregateOpInterface for AttentionOp
- Moves all aggregate interface tests to IR/test/decompose_aggregate_op
Groverkss authored Oct 29, 2024
1 parent b31b033 commit 3cf5b65
Showing 8 changed files with 374 additions and 118 deletions.
@@ -299,50 +299,29 @@ static bool willBeContiguousSlice(OpFoldResult inputSize, OpFoldResult tileSize,
}

//===----------------------------------------------------------------------===//
-// OnlineAttentionOp
+// Attention Helpers
//===----------------------------------------------------------------------===//

-FailureOr<SmallVector<Value>>
-OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
-Location loc = getLoc();
-Value query = getQuery();
-Value key = getKey();
-Value value = getValue();
-std::optional<Value> mask = getMask();
-Value oldAcc = getOutput();
-Value oldMax = getMax();
-Value oldSum = getSum();
-Type elementType = getElementTypeOrSelf(getOutput().getType());
-DictionaryAttr config = getDecompositionConfigAttr();
-
-DictionaryAttr qkAttrs, pvAttrs;
-if (config) {
-qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
-pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
-}
-
-FailureOr<AttentionOpDetail> maybeOpInfo =
-AttentionOpDetail::get(getIndexingMapsArray());
-assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
-AttentionOpDetail opInfo = maybeOpInfo.value();
-
-SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
-getIterationDomain(b), [](Range x) { return x.size; });

+Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
+Value key, Value scale, std::optional<Value> mask,
+AffineMap qMap, AffineMap kMap, AffineMap sMap,
+std::optional<AffineMap> maskMap,
+SmallVector<OpFoldResult> iterationDomain,
+Type sElementType, Region &elementwiseRegion,
+DictionaryAttr qkAttrs, bool lowPrecision) {
+MLIRContext *ctx = b.getContext();
// Since we use exp2 for attention instead of the original exp, we have to
// multiply the scale by log2(e). We use exp2 instead of exp as most platforms
// have better support for exp2 (we verified that we gain some speedup on
// some GPUs).
-Value scale = getScale();
Value log2e = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(scale.getType(), M_LOG2E));
scale = b.create<arith::MulFOp>(loc, scale, log2e);

-auto qETy = getElementTypeOrSelf(query.getType());
-auto vETy = getElementTypeOrSelf(value.getType());

-AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(),
-/*symbolCount=*/0, getContext());
+AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
+/*symbolCount=*/0, ctx);

// In the original algorithm, the scaling is done after the softmax:
// softmax(Q @ K.T * scale) @ V
@@ -352,43 +331,40 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
// iteration of the loop. This is only valid for f16 or f32 types as f8
// is extremely limited on its dynamic range therefore this would
// significantly affect numerics.
-if (qETy.getIntOrFloatBitWidth() > 8) {
-AffineMap qMap = getQueryMap();
+if (!lowPrecision) {
query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
query, scale);
}

-// ---- Matmul 1 ----
+// ---- QK Matmul ----

// Get sizes for S.
-AffineMap sMap = opInfo.getSMap();
SmallVector<OpFoldResult> sSizes;
for (AffineExpr dimExpr : sMap.getResults()) {
int dim = cast<AffineDimExpr>(dimExpr).getPosition();
-sSizes.push_back(sizes[dim]);
+sSizes.push_back(iterationDomain[dim]);
}

// S = Q @ K
// SMap = QMap @ KMap
-Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
-Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, sElementType);
+Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(sElementType));
Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);

-s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
+s = computeMatmul(b, loc, qMap, kMap, sMap, query, key, s);
if (qkAttrs) {
-s.getDefiningOp()->setDiscardableAttrs(qkAttrs);
+s.getDefiningOp()->setAttrs(qkAttrs);
}

-s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);
+s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s);

-bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
if (lowPrecision) {
// For low bit-depth types we perform post Q @ K scaling. This is to avoid
// losing numerical precision due to the low dynamic range of fp8 types when
// pre-applying the scaling.
AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
-/*symbolCount=*/0, getContext());
+/*symbolCount=*/0, ctx);
s = elementwiseValueInPlace<arith::MulFOp>(b, loc, sMap, scaleMap, s,
scale);

@@ -401,16 +377,176 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();
Value offset = b.create<arith::ConstantOp>(
-loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / mx));
+loc, b.getFloatAttr(sElementType, clAttentionSoftmaxMax / mx));
s = elementwiseValueInPlace<arith::AddFOp>(b, loc, sMap, scaleMap, s,
offset);
}

// S += mask
if (mask != nullptr) {
-s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value());
+s = applyMask(b, loc, sMap, *maskMap, s, mask.value());
}

+return s;
+}
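The log2(e) fold at the top of this helper is a change of exponent base; a minimal derivation in LaTeX of the identity the code comment appeals to (notation mine, not from the diff):

\[ e^{x} = 2^{x \log_2 e} \;\Longrightarrow\; \exp\big(\mathrm{scale}\cdot QK^T\big) = \mathrm{exp2}\big((\mathrm{scale}\cdot\log_2 e)\cdot QK^T\big) \]

Folding log2(e) into the scale once up front keeps every later exponential as exp2, which the comment notes is better supported on the GPUs that were measured.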

+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
+Location loc = getLoc();
+Value query = getQuery();
+Value key = getKey();
+Value value = getValue();
+std::optional<Value> mask = getMask();
+DictionaryAttr config = getDecompositionConfigAttr();
+
+DictionaryAttr qkAttrs, pvAttrs;
+if (config) {
+qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+}
+Value output = getOutput();
+
+FailureOr<AttentionOpDetail> maybeOpInfo =
+AttentionOpDetail::get(getIndexingMapsArray());
+assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+AttentionOpDetail opInfo = maybeOpInfo.value();
+
+SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+getIterationDomain(b), [](Range x) { return x.size; });
+
+AffineMap qMap = getQueryMap();
+AffineMap kMap = getKeyMap();
+AffineMap sMap = opInfo.getSMap();
+
+auto qETy = getElementTypeOrSelf(query.getType());
+bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
+
+// We compute output of first matmul in f32.
+Type f32Type = b.getF32Type();
+
+// ---- QK Matmul + elementwise math ----
+Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap,
+kMap, sMap, getMaskMap(), sizes, f32Type,
+getRegion(), qkAttrs, lowPrecision);
+
+// ---- Softmax ----
+
+AffineMap accMap = getOutputMap();
+
+llvm::SmallBitVector projectedK2Dims(opInfo.getDomainRank(), false);
+for (auto dim : opInfo.getK2Dims()) {
+projectedK2Dims.set(dim);
+}
+
+AffineMap maxMap = projectDims(sMap, projectedK2Dims).dropZeroResults();
+AffineMap sumMap = maxMap;
+
+SmallVector<OpFoldResult> rowRedSize =
+applyPermutationMap<OpFoldResult>(maxMap, sizes);
+
+Value rowRedEmpty = b.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
+
+Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf,
+getElementTypeOrSelf(output), b, loc,
+/*useOnlyFiniteValue=*/true);
+Value maxInit =
+arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, b, loc,
+/*useOnlyFiniteValue=*/true);
+Value sumInit =
+arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, b, loc);
+
+Value accFill =
+b.create<linalg::FillOp>(loc, ValueRange{accInit}, output).getResult(0);
+Value maxFill =
+b.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
+.getResult(0);
+Value sumFill =
+b.create<linalg::FillOp>(loc, ValueRange{sumInit}, rowRedEmpty)
+.getResult(0);
+
+// max = rowMax(S)
+Value max = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, maxFill);
+
+// P = exp2(S - max)
+AffineMap pMap = sMap;
+Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s);
+
+// sum = rowSum(P)
+Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);
+
+// P = P / sum
+p = elementwiseValueInPlace<arith::DivFOp>(b, loc, pMap, sumMap, p, sum);
+
+// ---- Scale and truncate LHS to match RHS ----
+SmallVector<OpFoldResult> sSizes;
+for (AffineExpr dimExpr : sMap.getResults()) {
+int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+sSizes.push_back(sizes[dim]);
+}
+
+auto pETy = getElementTypeOrSelf(p.getType());
+auto vETy = getElementTypeOrSelf(value.getType());
+if (pETy != vETy && isa<FloatType>(vETy)) {
+Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
+p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
+}
+
+// result = P @ V + acc
+Value result =
+computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, accFill);
+if (pvAttrs) {
+result.getDefiningOp()->setAttrs(pvAttrs);
+}
+
+return SmallVector<Value>{result};
+}
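Taken together, the added AttentionOp::decomposeOperation materializes the standard numerically safe softmax-attention pipeline; a LaTeX sketch of the computed quantities (f denotes the optional elementwise region applied after the QK matmul, exp2 as above, rowMax/rowSum reduce over the K2 dimensions; notation mine):

\[ S = f\big(\mathrm{scale}\cdot QK^T\big), \qquad m = \operatorname{rowMax}(S), \qquad P = \operatorname{exp2}(S - m), \]
\[ \sigma = \operatorname{rowSum}(P), \qquad O = (P/\sigma)\,V + \mathrm{acc}_0, \]

where acc_0 is the output tensor filled with the additive identity by the linalg::FillOp above.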

+//===----------------------------------------------------------------------===//
+// OnlineAttentionOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Value>>
+OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
+Location loc = getLoc();
+Value query = getQuery();
+Value key = getKey();
+Value value = getValue();
+std::optional<Value> mask = getMask();
+Value oldAcc = getOutput();
+Value oldMax = getMax();
+Value oldSum = getSum();
+Type elementType = getElementTypeOrSelf(getOutput().getType());
+DictionaryAttr config = getDecompositionConfigAttr();
+
+DictionaryAttr qkAttrs, pvAttrs;
+if (config) {
+qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+}
+
+FailureOr<AttentionOpDetail> maybeOpInfo =
+AttentionOpDetail::get(getIndexingMapsArray());
+assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+AttentionOpDetail opInfo = maybeOpInfo.value();
+
+SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+getIterationDomain(b), [](Range x) { return x.size; });
+
+AffineMap qMap = getQueryMap();
+AffineMap kMap = getKeyMap();
+AffineMap sMap = opInfo.getSMap();
+
+auto qETy = getElementTypeOrSelf(query.getType());
+bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
+
+// ---- QK Matmul + elementwise math ----
+Value s = computeQKAndElementwise(
+loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
+sizes, elementType, getRegion(), qkAttrs, lowPrecision);
+
// TODO: This decomposition should be in a separate op called
// "online softmax".
// ---- Online Softmax ----
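The update that the rest of this function builds (partly collapsed in this view) is the usual online-softmax recurrence over K2 tiles, expressed in terms of the op's oldMax, oldSum, and oldAcc operands; a LaTeX sketch assuming the standard formulation (notation mine):

\[ m_{\mathrm{new}} = \max\big(m_{\mathrm{old}},\ \operatorname{rowMax}(S)\big), \qquad \alpha = \operatorname{exp2}(m_{\mathrm{old}} - m_{\mathrm{new}}), \]
\[ P = \operatorname{exp2}(S - m_{\mathrm{new}}), \qquad \Sigma_{\mathrm{new}} = \alpha\,\Sigma_{\mathrm{old}} + \operatorname{rowSum}(P), \qquad \mathrm{acc}_{\mathrm{new}} = \alpha\,\mathrm{acc}_{\mathrm{old}} + P\,V. \]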
@@ -441,7 +577,14 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap accMap = getOutputMap();

// ---- Scale and truncate LHS to match RHS ----
+SmallVector<OpFoldResult> sSizes;
+for (AffineExpr dimExpr : sMap.getResults()) {
+int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+sSizes.push_back(sizes[dim]);
+}

auto pETy = getElementTypeOrSelf(p.getType());
+auto vETy = getElementTypeOrSelf(value.getType());
if (pETy != vETy && isa<FloatType>(vETy)) {
Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
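With decomposeOperation in place on both ops, a pass can lower them through the aggregate interface rather than pattern-matching attention directly. A minimal sketch of such a caller, assuming the upstream mlir::linalg::AggregatedOpInterface that the ODS change below declares (the helper name and error handling are illustrative, not part of this commit):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Decompose one aggregate op into its constituent ops and replace it.
static LogicalResult decomposeAggregate(RewriterBase &rewriter,
                                        Operation *op) {
  auto aggregate = dyn_cast<linalg::AggregatedOpInterface>(op);
  if (!aggregate)
    return failure(); // Op does not implement the interface.
  rewriter.setInsertionPoint(op);
  // decomposeOperation emits the decomposed IR at the insertion point and
  // returns the values that replace the op's results.
  FailureOr<SmallVector<Value>> replacements =
      aggregate.decomposeOperation(rewriter);
  if (failed(replacements))
    return failure();
  rewriter.replaceOp(op, *replacements);
  return success();
}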
@@ -475,6 +475,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
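Declaring the interface in ODS makes TableGen emit the corresponding method declaration on the generated AttentionOp class, which the new definition in the .cpp diff above implements; roughly (a sketch of the generated declaration, not verbatim TableGen output):

FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &b);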
@@ -17,6 +17,7 @@ iree_lit_test_suite(
srcs = enforce_glob(
[
"canonicalize.mlir",
"decompose_aggregate_op.mlir",
"invalid.mlir",
"roundtrip.mlir",
],
@@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"canonicalize.mlir"
"decompose_aggregate_op.mlir"
"invalid.mlir"
"roundtrip.mlir"
TOOLS
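Both build systems must list the new test explicitly: the Bazel enforce_glob and the CMake SRCS are kept in sync. Once registered, the relocated test runs with the rest of the lit suite, e.g. via something like ctest -R decompose_aggregate_op from an IREE build directory (the exact test-target name is an assumption, not shown in this diff).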
