diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 204ae3533c7b..7fc985bf67ab 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -299,50 +299,29 @@ static bool willBeContiguousSlice(OpFoldResult inputSize, OpFoldResult tileSize,
 }
 
 //===----------------------------------------------------------------------===//
-// OnlineAttentionOp
+// Attention Helpers
 //===----------------------------------------------------------------------===//
 
-FailureOr<SmallVector<Value>>
-OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
-  Location loc = getLoc();
-  Value query = getQuery();
-  Value key = getKey();
-  Value value = getValue();
-  std::optional<Value> mask = getMask();
-  Value oldAcc = getOutput();
-  Value oldMax = getMax();
-  Value oldSum = getSum();
-  Type elementType = getElementTypeOrSelf(getOutput().getType());
-  DictionaryAttr config = getDecompositionConfigAttr();
-
-  DictionaryAttr qkAttrs, pvAttrs;
-  if (config) {
-    qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
-    pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
-  }
-
-  FailureOr<AttentionOpDetail> maybeOpInfo =
-      AttentionOpDetail::get(getIndexingMapsArray());
-  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
-  AttentionOpDetail opInfo = maybeOpInfo.value();
-
-  SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
-      getIterationDomain(b), [](Range x) { return x.size; });
-
+Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
+                              Value key, Value scale,
+                              std::optional<Value> mask, AffineMap qMap,
+                              AffineMap kMap, AffineMap sMap,
+                              std::optional<AffineMap> maskMap,
+                              SmallVector<OpFoldResult> iterationDomain,
+                              Type sElementType, Region &elementwiseRegion,
+                              DictionaryAttr qkAttrs, bool lowPrecision) {
+  MLIRContext *ctx = b.getContext();
   // Since we use exp2 for attention instead of the original exp, we have to
   // multiply the scale by log2(e). We use exp2 instead of exp as most
   // platforms have better support for exp2 (we verified that we gain some
   // speedup on some GPUs).
-  Value scale = getScale();
   Value log2e = b.create<arith::ConstantOp>(
       loc, b.getFloatAttr(scale.getType(), M_LOG2E));
   scale = b.create<arith::MulFOp>(loc, scale, log2e);
 
   auto qETy = getElementTypeOrSelf(query.getType());
-  auto vETy = getElementTypeOrSelf(value.getType());
 
-  AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(),
-                                      /*symbolCount=*/0, getContext());
+  AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
+                                      /*symbolCount=*/0, ctx);
 
   // In the original algorithm, the scaling is done after the softmax:
   //        softmax(Q @ K.T * scale) @ V
@@ -352,43 +331,40 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   // iteration of the loop. This is only valid for f16 or f32 types as f8
   // is extremely limited on its dynamic range therefore this would
   // significantly affect numerics.
-  if (qETy.getIntOrFloatBitWidth() > 8) {
-    AffineMap qMap = getQueryMap();
+  if (!lowPrecision) {
     query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
                                                    query, scale);
   }
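A quick standalone check (an editor's illustration, not part of the patch) of the identity the log2(e) folding relies on: exp(x) == exp2(x * log2(e)), which is what lets the decomposition emit math.exp2 everywhere exp would appear. Assumes a <cmath> that defines M_LOG2E:

#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  // exp(x) == exp2(x * log2(e)); M_LOG2E is the same constant the
  // decomposition folds into the attention scale above.
  for (double x : {-8.0, -0.5, 0.0, 1.0, 12.75}) {
    double direct = std::exp(x);
    double viaExp2 = std::exp2(x * M_LOG2E);
    assert(std::abs(viaExp2 - direct) <= 1e-12 * std::max(1.0, direct));
  }
  return 0;
}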
 
-  // ---- Matmul 1 ----
+  // ---- QK Matmul ----
 
   // Get sizes for S.
-  AffineMap sMap = opInfo.getSMap();
   SmallVector<OpFoldResult> sSizes;
   for (AffineExpr dimExpr : sMap.getResults()) {
     int dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    sSizes.push_back(sizes[dim]);
+    sSizes.push_back(iterationDomain[dim]);
   }
 
   // S = Q @ K
   // SMap = QMap @ KMap
-  Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
-  Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+  Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, sElementType);
+  Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(sElementType));
   Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
-  s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
+  s = computeMatmul(b, loc, qMap, kMap, sMap, query, key, s);
   if (qkAttrs) {
-    s.getDefiningOp()->setDiscardableAttrs(qkAttrs);
+    s.getDefiningOp()->setAttrs(qkAttrs);
   }
 
-  s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);
+  s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s);
 
-  bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
   if (lowPrecision) {
     // For low bit-depth types we perform post Q @ K scaling. This is to avoid
     // losing numerical precision due to the low dynamic range of fp8 types
     // when pre-applying the scaling.
     AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
     AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
-                                        /*symbolCount=*/0, getContext());
+                                        /*symbolCount=*/0, ctx);
     s = elementwiseValueInPlace<arith::MulFOp>(b, loc, sMap, scaleMap, s,
                                                scale);
 
@@ -401,16 +377,176 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
         APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
             .convertToDouble();
     Value offset = b.create<arith::ConstantOp>(
-        loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / mx));
+        loc, b.getFloatAttr(sElementType, clAttentionSoftmaxMax / mx));
     s = elementwiseValueInPlace<arith::AddFOp>(b, loc, sMap, scaleMap, s,
                                                offset);
   }
 
   // S += mask
   if (mask != nullptr) {
-    s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value());
+    s = applyMask(b, loc, sMap, *maskMap, s, mask.value());
+  }
+
+  return s;
+}
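The two scaling paths above agree mathematically because the matmul is linear: (scale * Q) @ K.T == scale * (Q @ K.T). The branch only decides where rounding happens, since fp8's narrow dynamic range makes a pre-scaled Q lose precision. A tiny double-precision sketch of the commuted identity (illustrative only, not patch code):

#include <array>
#include <cassert>
#include <cmath>

int main() {
  constexpr int M = 2, K = 3, N = 2;
  std::array<double, M * K> q = {0.5, -1.25, 2.0, 1.5, 0.25, -0.75};
  std::array<double, N * K> k = {1.0, 0.5, -2.0, 0.125, 3.0, 1.0};
  const double scale = 0.125;

  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      double pre = 0.0, post = 0.0;
      for (int d = 0; d < K; ++d) {
        pre += (scale * q[i * K + d]) * k[j * K + d]; // scale Q first
        post += q[i * K + d] * k[j * K + d];          // scale S after
      }
      post *= scale;
      assert(std::abs(pre - post) < 1e-12);
    }
  }
  return 0;
}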
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
+  Location loc = getLoc();
+  Value query = getQuery();
+  Value key = getKey();
+  Value value = getValue();
+  std::optional<Value> mask = getMask();
+  DictionaryAttr config = getDecompositionConfigAttr();
+
+  DictionaryAttr qkAttrs, pvAttrs;
+  if (config) {
+    qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+    pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+  }
+  Value output = getOutput();
+
+  FailureOr<AttentionOpDetail> maybeOpInfo =
+      AttentionOpDetail::get(getIndexingMapsArray());
+  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+  AttentionOpDetail opInfo = maybeOpInfo.value();
+
+  SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+      getIterationDomain(b), [](Range x) { return x.size; });
+
+  AffineMap qMap = getQueryMap();
+  AffineMap kMap = getKeyMap();
+  AffineMap sMap = opInfo.getSMap();
+
+  auto qETy = getElementTypeOrSelf(query.getType());
+  bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
+
+  // We compute output of first matmul in f32.
+  Type f32Type = b.getF32Type();
+
+  // ---- QK Matmul + elementwise math ----
+  Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap,
+                                    kMap, sMap, getMaskMap(), sizes, f32Type,
+                                    getRegion(), qkAttrs, lowPrecision);
+
+  // ---- Softmax ----
+
+  AffineMap accMap = getOutputMap();
+
+  llvm::SmallBitVector projectedK2Dims(opInfo.getDomainRank(), false);
+  for (auto dim : opInfo.getK2Dims()) {
+    projectedK2Dims.set(dim);
+  }
+
+  AffineMap maxMap = projectDims(sMap, projectedK2Dims).dropZeroResults();
+  AffineMap sumMap = maxMap;
+
+  SmallVector<OpFoldResult> rowRedSize =
+      applyPermutationMap<OpFoldResult>(maxMap, sizes);
+
+  Value rowRedEmpty = b.create<tensor::EmptyOp>(loc, rowRedSize, f32Type);
+
+  Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf,
+                                          getElementTypeOrSelf(output), b, loc,
+                                          /*useOnlyFiniteValue=*/true);
+  Value maxInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, b, loc,
+                              /*useOnlyFiniteValue=*/true);
+  Value sumInit =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, b, loc);
+
+  Value accFill =
+      b.create<linalg::FillOp>(loc, ValueRange{accInit}, output).getResult(0);
+  Value maxFill =
+      b.create<linalg::FillOp>(loc, ValueRange{maxInit}, rowRedEmpty)
+          .getResult(0);
+  Value sumFill =
+      b.create<linalg::FillOp>(loc, ValueRange{sumInit}, rowRedEmpty)
+          .getResult(0);
+
+  // max = rowMax(S)
+  Value max = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, maxFill);
+
+  // P = exp2(S - max)
+  AffineMap pMap = sMap;
+  Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s);
+
+  // sum = rowSum(P)
+  Value sum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, sumFill);
+
+  // P = P / sum
+  p = elementwiseValueInPlace<arith::DivFOp>(b, loc, pMap, sumMap, p, sum);
+
+  // ---- Scale and truncate LHS to match RHS ----
+  SmallVector<OpFoldResult> sSizes;
+  for (AffineExpr dimExpr : sMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    sSizes.push_back(sizes[dim]);
+  }
+
+  auto pETy = getElementTypeOrSelf(p.getType());
+  auto vETy = getElementTypeOrSelf(value.getType());
+  if (pETy != vETy && isa<FloatType>(vETy)) {
+    Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
+    p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
+  }
+
+  // result = P @ V + acc
+  Value result =
+      computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, accFill);
+  if (pvAttrs) {
+    result.getDefiningOp()->setAttrs(pvAttrs);
+  }
+
+  return SmallVector<Value>{result};
+}
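The chain of generics built above is the max-subtracted softmax: rowMax, exp2(S - max), rowSum, then divide. Subtracting the row max keeps every exponent at or below zero, so exp2 cannot overflow. A one-row scalar sketch of the same step order (an illustration, not the generated IR):

#include <algorithm>
#include <cmath>
#include <vector>

// One row of the decomposed softmax: max -> exp2(s - max) -> sum -> divide.
// Assumes s was already multiplied by scale * log2(e), matching the
// exp2-based formulation used by the decomposition.
std::vector<double> softmaxRow(std::vector<double> s) {
  double max = *std::max_element(s.begin(), s.end());
  double sum = 0.0;
  for (double &e : s) {
    e = std::exp2(e - max); // every exponent <= 0, so no overflow
    sum += e;
  }
  for (double &e : s)
    e /= sum;
  return s;
}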
+
+//===----------------------------------------------------------------------===//
+// OnlineAttentionOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Value>>
+OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
+  Location loc = getLoc();
+  Value query = getQuery();
+  Value key = getKey();
+  Value value = getValue();
+  std::optional<Value> mask = getMask();
+  Value oldAcc = getOutput();
+  Value oldMax = getMax();
+  Value oldSum = getSum();
+  Type elementType = getElementTypeOrSelf(getOutput().getType());
+  DictionaryAttr config = getDecompositionConfigAttr();
+
+  DictionaryAttr qkAttrs, pvAttrs;
+  if (config) {
+    qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
+    pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
+  }
+
+  FailureOr<AttentionOpDetail> maybeOpInfo =
+      AttentionOpDetail::get(getIndexingMapsArray());
+  assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps");
+  AttentionOpDetail opInfo = maybeOpInfo.value();
+
+  SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
+      getIterationDomain(b), [](Range x) { return x.size; });
+
+  AffineMap qMap = getQueryMap();
+  AffineMap kMap = getKeyMap();
+  AffineMap sMap = opInfo.getSMap();
+
+  auto qETy = getElementTypeOrSelf(query.getType());
+  bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8;
+
+  // ---- QK Matmul + elementwise math ----
+  Value s = computeQKAndElementwise(
+      loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(),
+      sizes, elementType, getRegion(), qkAttrs, lowPrecision);
+
   // TODO: This decomposition should be in a separate op called
   // "online softmax".
   // ---- Online Softmax ----
 
@@ -441,7 +577,14 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap accMap = getOutputMap();
 
   // ---- Scale and truncate LHS to match RHS ----
+  SmallVector<OpFoldResult> sSizes;
+  for (AffineExpr dimExpr : sMap.getResults()) {
+    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    sSizes.push_back(sizes[dim]);
+  }
+
   auto pETy = getElementTypeOrSelf(p.getType());
+  auto vETy = getElementTypeOrSelf(value.getType());
   if (pETy != vETy && isa<FloatType>(vETy)) {
     Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
     p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision);
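OnlineAttentionOp carries running (max, sum, acc) results so that K/V can be visited tile by tile; whenever a tile raises the row max, previously accumulated values are rescaled by exp2(oldMax - newMax) before the tile's contribution is added, and division by sum happens once at the very end. A one-element scalar sketch of that standard online-softmax recurrence (an editor's illustration, not the decomposition itself):

#include <algorithm>
#include <cmath>

struct OnlineSoftmaxState {
  double max = -INFINITY; // running row max
  double sum = 0.0;       // running sum of exp2(s - max)
  double acc = 0.0;       // running unnormalized P @ V accumulator
};

// Consume one (score, value) element of a new tile; normalization by `sum`
// happens once after all tiles, as in the decomposed online attention.
void update(OnlineSoftmaxState &st, double score, double v) {
  double newMax = std::max(st.max, score);
  double norm = std::exp2(st.max - newMax); // rescale old contributions
  double p = std::exp2(score - newMax);
  st.sum = st.sum * norm + p;
  st.acc = st.acc * norm + p * v;
  st.max = newMax;
}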
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 329c79ca5297..3b46114abe5e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -475,6 +475,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
         ["getIndexingMapsForResults", "getIndexingMapsForOperands",
          "getStaticLoopRanges"]>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
       DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir
@@ -1,4 +1,79 @@
 // RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s
 
+// Spec to decompose custom op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>,
+    %rhs1 : tensor<?x?xf32>, %rhs2 : tensor<?x?xf32>, %scalar : f32,
+    %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>)
+    -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
+  %0:2 = iree_linalg_ext.custom_op {
+      indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
+                       affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
+                       affine_map<(d0, d1)[s0, s1] -> ()>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
+                       affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
+      iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
+                        #iree_linalg_ext.iterator_type<parallel>]}
+      ins(%lhs1, %rhs1, %rhs2, %scalar
+        : tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32)
+      outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
+    ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
+         %s : f32, %t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
+      %0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
+          outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %2 = linalg.generic {
+          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                           affine_map<(d0, d1) -> ()>,
+                           affine_map<(d0, d1) -> (d0, d1)>],
+          iterator_types = ["parallel", "parallel"]}
+          ins(%1, %s : tensor<?x?xf32>, f32) outs(%1 : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 :f32):
+          %3 = arith.addf %b0, %b2 : f32
+          linalg.yield %3 : f32
+      } -> tensor<?x?xf32>
+      iree_linalg_ext.yield %0, %2 : tensor<?x?xf32>, tensor<?x?xf32>
+  } -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
+  return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
+}
+
+// CHECK-LABEL: func @custom_op_decomposition(
+// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
+// CHECK: %[[MATMUL1:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] :
+// CHECK-SAME: outs(%[[INIT1]] :
+// CHECK: %[[MATMUL2:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] :
+// CHECK-SAME: outs(%[[INIT2]] :
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] :
+// CHECK-SAME: outs(%[[MATMUL2]] :
+// CHECK: return %[[MATMUL1]], %[[GENERIC]]
+
+// -----
+
+// Spec to decompose attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
 
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
 #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
@@ -8,6 +83,89 @@
 #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
 
 func.func @attention_f16(%query: tensor<192x1024x64xf16>,
+                         %key: tensor<192x1024x64xf16>,
+                         %value: tensor<192x1024x64xf16>,
+                         %output: tensor<192x1024x64xf32>)
+                         -> (tensor<192x1024x64xf32>) {
+  %scale = arith.constant 1.0 : f16
+
+  %out = iree_linalg_ext.attention
+        { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO] }
+        ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
+        outs(%output : tensor<192x1024x64xf32>) {
+        ^bb0(%score: f32):
+          iree_linalg_ext.yield %score: f32
+        }
+        -> tensor<192x1024x64xf32>
+
+  return %out : tensor<192x1024x64xf32>
+}
+
+// We just want to check if we are using the correct algorithm.
+// CHECK-LABEL: @attention_f16
+// Q = Q * scale
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// S = Q @ K
+// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// max = rowMax(S)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.maximumf
+// CHECK: linalg.yield
+// P = exp2(S - max)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.subf
+// CHECK: math.exp2
+// CHECK: linalg.yield
+// sum = rowSum(P)
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// P = P / sum
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.divf
+// CHECK: linalg.yield
+// truncf P : f32 to f16
+// CHECK: linalg.generic
+// CHECK-NOT: arith.extf
+// CHECK: arith.truncf
+// CHECK: linalg.yield
+// newAcc = P @ V
+// CHECK: linalg.generic
+// CHECK: arith.extf
+// CHECK: arith.extf
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+
+// -----
+
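Each `// -----` section in this test file now pairs a transform spec with its payload IR: the interpreter pass runs @__transform_main, which matches the op by name and invokes its AggregatedOpInterface decomposition. Reproducing one section outside of lit looks like this (file name illustrative, flags taken from the RUN line above):

  iree-opt --iree-transform-dialect-interpreter --canonicalize \
    --mlir-print-local-scope --split-input-file decompose_attention.mlir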
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
+#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
+#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
+#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
+#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
+#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
+#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
+
+func.func @online_attention_f16(%query: tensor<192x1024x64xf16>,
                          %key: tensor<192x1024x64xf16>,
                          %value: tensor<192x1024x64xf16>,
                          %output: tensor<192x1024x64xf32>,
@@ -30,7 +188,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
 
 // We just want to check if we are using the correct algorithm and the
 // correct number of extf/truncfs are emitted.
-// CHECK-LABEL: @attention_f16
+// CHECK-LABEL: @online_attention_f16
 // Q = Q * scale
 // CHECK: linalg.generic
 // CHECK: arith.mulf
@@ -83,6 +241,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
 
 // -----
 
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
 #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
 #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
@@ -90,7 +257,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
 #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
 #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
 
-func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+func.func @online_attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
                         %key: tensor<192x1024x64xf8E4M3FNUZ>,
                         %value: tensor<192x1024x64xf8E4M3FNUZ>,
                         %output: tensor<192x1024x64xf32>,
@@ -111,7 +278,7 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
   return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
 }
 
-// CHECK-LABEL: @attention_f8
+// CHECK-LABEL: @online_attention_f8
 // S = Q @ K
 // CHECK: linalg.generic
 // CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
@@ -176,6 +343,15 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 
 // -----
 
+// Spec to decompose online attention op.
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+
 #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
 #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
 #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
@@ -184,7 +360,7 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
 #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>
 
-func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
+func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
                                %key: tensor<192x1024x64xf8E4M3FNUZ>,
                                %value: tensor<192x1024x64xf8E4M3FNUZ>,
                                %mask: tensor<192x1024x1024xf8E4M3FNUZ>,
@@ -205,7 +381,7 @@ func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>,
   return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>
 }
 
-// CHECK-LABEL: @attention_f8_masked
+// CHECK-LABEL: @online_attention_f8_masked
 // S = Q @ K
 // CHECK: linalg.generic
 // CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
index efe463a65949..6ba9d5cd801d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel
@@ -20,9 +20,7 @@ iree_lit_test_suite(
             "conv2d_to_winograd.mlir",
             "convert_to_loops.mlir",
             "convert_to_online_attention.mlir",
-            "decompose_aggregate_op.mlir",
             "decompose_im2col.mlir",
-            "decompose_online_attention.mlir",
             "decompose_winograd.mlir",
             "distribution.mlir",
             "pad_contraction_to_block_size.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 3288c1443dfd..a912973cb2f7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -18,9 +18,7 @@ iree_lit_test_suite(
     "conv2d_to_winograd.mlir"
     "convert_to_loops.mlir"
     "convert_to_online_attention.mlir"
-    "decompose_aggregate_op.mlir"
     "decompose_im2col.mlir"
-    "decompose_online_attention.mlir"
    "decompose_winograd.mlir"
     "distribution.mlir"
     "pad_contraction_to_block_size.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir
deleted file mode 100644
index 80b0b7a693e3..000000000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir
+++ /dev/null
@@ -1,62 +0,0 @@
-// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s
-
-func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>,
-    %rhs1 : tensor<?x?xf32>, %rhs2 : tensor<?x?xf32>, %scalar : f32,
-    %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>)
-    -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
-  %0:2 = iree_linalg_ext.custom_op {
-      indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
-                       affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
-                       affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
-                       affine_map<(d0, d1)[s0, s1] -> ()>,
-                       affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
-                       affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
-      iterator_types = [#iree_linalg_ext.iterator_type<parallel>,
-                        #iree_linalg_ext.iterator_type<parallel>]}
-      ins(%lhs1, %rhs1, %rhs2, %scalar
-        : tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32)
-      outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
-    ^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
-         %s : f32, %t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
-      %0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
-          outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
-      %1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
-          outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
-      %2 = linalg.generic {
-          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                           affine_map<(d0, d1) -> ()>,
-                           affine_map<(d0, d1) -> (d0, d1)>],
-          iterator_types = ["parallel", "parallel"]}
-          ins(%1, %s : tensor<?x?xf32>, f32) outs(%1 : tensor<?x?xf32>) {
-        ^bb0(%b0 : f32, %b1 : f32, %b2 :f32):
-          %3 = arith.addf %b0, %b2 : f32
-          linalg.yield %3 : f32
-      } -> tensor<?x?xf32>
-      iree_linalg_ext.yield %0, %2 : tensor<?x?xf32>, tensor<?x?xf32>
-  } -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
-  return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32>
-}
-module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> ()
-    transform.yield
-  }
-}
-// CHECK-LABEL: func @custom_op_decomposition(
-// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32
-// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32>
-// CHECK: %[[MATMUL1:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] :
-// CHECK-SAME: outs(%[[INIT1]] :
-// CHECK: %[[MATMUL2:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] :
-// CHECK-SAME: outs(%[[INIT2]] :
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] :
-// CHECK-SAME: outs(%[[MATMUL2]] :
-// CHECK: return %[[MATMUL1]], %[[GENERIC]]