Commit: responding to comments
Signed-off-by: Alexandre Eichenberger <[email protected]>
AlexandreEichenberger committed Oct 18, 2024
1 parent ea57ac9 commit a2230f1
Showing 5 changed files with 119 additions and 141 deletions.
@@ -111,7 +111,8 @@ Value ZTensorHelper::getPreTransformedDescPtr(zdnn_data_types zDNNDataType,
   Type llvmZTensorDescStructTy = getZTensorDescStructTy(context);
   Value one = create.llvm.constant(llvmI64Ty, static_cast<int64_t>(1));
 
-  // TODO: evaluate if a heap alloc would not be better.
+  // Alloca is fine for LLVM structs; if we were to use a heap alloc, we would
+  // also have to manually insert free calls. So alloca makes sense here.
   Value preTransformedDescPtr = create.llvm._alloca(
       krnl::getPointerType(context, llvmZTensorDescStructTy),
       llvmZTensorDescStructTy, one,
@@ -155,7 +156,6 @@ Value ZTensorHelper::getTransformedDescPtr(
   Type llvmZTensorDescStructTy = getZTensorDescStructTy(context);
   Value one = create.llvm.constant(llvmI64Ty, static_cast<int64_t>(1));
 
-  // TODO: evaluate if a heap alloc would not be better.
   Value transformedDescPtr = create.llvm._alloca(
       krnl::getPointerType(context, llvmZTensorDescStructTy),
       llvmZTensorDescStructTy, one,
@@ -217,7 +217,6 @@ ZTensor ZTensorHelper::getZTensor(Value bufferPtr, zdnn_data_types dataType,
   Value transformedDescPtr =
       getTransformedDescPtr(preTransformedDescPtr, isConcat, concatInfo);
   // Create the input zTensor.
-  // TODO: evaluate if a heap alloc would not be better.
   Value alloc =
       create.llvm._alloca(krnl::getPointerType(context, llvmZTensorStructTy),
           llvmZTensorStructTy, one,
@@ -253,7 +252,6 @@ ZTensor ZTensorHelper::getZTensor(Value preTransformedDescPtr,
   Type llvmZTensorStructTy = getZTensorStructTy(context);
   Value one =
       create.llvm.constant(rewriter.getI64Type(), static_cast<int64_t>(1));
-  // TODO: evaluate if a heap alloc would not be better.
   Value alloc =
       create.llvm._alloca(krnl::getPointerType(context, llvmZTensorStructTy),
           llvmZTensorStructTy, one,
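The new comment above captures the trade-off. As a plain C++ analogy of what the lowered IR does (a hedged sketch: the real code emits an LLVM-dialect alloca through create.llvm._alloca, and ZTensorDesc below is a hypothetical stand-in for the zDNN descriptor struct):

#include <cstdlib>

struct ZTensorDesc { /* descriptor fields elided */ };

void stackDescriptor() {
  // What alloca amounts to: a stack slot, reclaimed automatically on return.
  ZTensorDesc desc;
  // ... initialize and use desc ...
}

void heapDescriptor() {
  // The rejected alternative: every exit path needs a matching free.
  auto *desc = static_cast<ZTensorDesc *>(std::malloc(sizeof(ZTensorDesc)));
  // ... initialize and use desc ...
  std::free(desc); // must be inserted manually on every path
}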
25 changes: 7 additions & 18 deletions src/Conversion/ONNXToKrnl/Math/Gemm.cpp
@@ -205,14 +205,6 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
       MemRefType bTileType =
           MemRefType::get({kCacheTile, jCacheTile}, elementType);
       SmallVector<IndexExpr, 1> empty;
-      // Allocate here on heap, only when no parallelism.
-      Value aBuff, bBuff, rBuff;
-      if (!enableParallel) {
-        aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
-        bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
-        if (mustTileR)
-          rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
-      }
 
       // 3) introduce the loops and permute them
       // I, J, K loop.
@@ -255,11 +247,10 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
           {I, J, K},
           [&](const KrnlBuilder &createKrnl, ValueRange i1_j1_indices) {
             Value i1(i1_j1_indices[0]), j1(i1_j1_indices[1]);
-            if (enableParallel) {
-              aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
-              bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
-              rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
-            }
+            // If parallel, these allocs stay inside the loop; otherwise they migrate out.
+            Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
+            Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
+            Value rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
             createKrnl.copyToBuffer(rBuff, R, {i1, j1}, zeroVal, false);
             createKrnl.iterateIE({}, {kk1}, {}, {},
                 [&](const KrnlBuilder &createKrnl, ValueRange k1_index) {
@@ -321,11 +312,9 @@ struct ONNXGemmOpLowering : public OpConversionPattern<GemmOp> {
           {J, K, I},
           [&](const KrnlBuilder &createKrnl, ValueRange j1_k1_indices) {
             Value j1(j1_k1_indices[0]), k1(j1_k1_indices[1]);
-            // If parallel, allocate on stack inside the parallel region.
-            if (enableParallel) {
-              aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
-              bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
-            }
+            // If parallel, these allocs stay inside the loop; otherwise they migrate out.
+            Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
+            Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
             if (bTrans)
               createKrnl.copyToBuffer(bBuff, B, {j1, k1}, zeroVal, true);
             else
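Both hunks above apply the same simplification: instead of allocating the tile buffers before the loop when sequential and inside it when parallel, the buffers are now allocated unconditionally in the loop body. Each parallel thread then gets its own buffers, and in the sequential case the loop-invariant allocs can migrate out later, per the new comments. A condensed sketch of the resulting pattern (hedged: origLoops, tiledLoops, lbs, and ubs are placeholder names for the bounds set up earlier, and the body is abbreviated):

createKrnl.iterateIE(origLoops, tiledLoops, lbs, ubs,
    [&](const KrnlBuilder &createKrnl, ValueRange indices) {
      // Unconditional per-iteration allocation: correct when parallel (one
      // buffer per thread), hoistable when sequential.
      Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN);
      Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN);
      // ... copy tiles into the buffers and run the blocked matmul ...
    });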
22 changes: 4 additions & 18 deletions src/Conversion/ONNXToKrnl/Math/Reduction.cpp
@@ -1150,18 +1150,11 @@ struct ONNXReductionOpLowering : public OpConversionPattern<ONNXReductionOp> {
               "not enough work for reduction h-simd");
         }
       }
-      Value tmpAlloc;
-      if (!enableParallel) {
-        // No parallel, alloc once outside.
-        tmpAlloc = create.mem.alignedAlloc(tmpType);
-      }
       create.krnl.iterateIE(outLoopDef, outLoopDef, lbs, flatOutDims,
           [&](const KrnlBuilder &ck, ValueRange outLoopInd) {
             MDBuilder create(ck);
-            if (enableParallel) {
-              // Allocate temp inside loop because of parallel.
-              tmpAlloc = create.mem.alignedAlloc(tmpType);
-            }
+            // When parallel, the alloc stays inside the loop; otherwise it migrates out.
+            Value tmpAlloc = create.mem.alignedAlloc(tmpType);
             Value identity = getIdentityValue<ONNXReductionOp>(
                 rewriter, create.getLoc(), elementType);
             Value initVec = create.vec.splat(vecType, identity);
@@ -1311,18 +1304,11 @@ struct ONNXReductionOpLowering : public OpConversionPattern<ONNXReductionOp> {
               "not enough work for reduction shuffle h-simd");
         }
       }
-      Value tmpBlockedAlloc;
-      if (!enableParallel) {
-        // Sequential, can allocate before loop.
-        tmpBlockedAlloc = create.mem.alignedAlloc(tmpBlockedType);
-      }
       create.krnl.iterateIE(outLoopDef, optimizedOutLoopDef, lbs, flatOutDims,
           [&](const KrnlBuilder &ck, ValueRange blockedOutLoopInd) {
             MDBuilder create(ck);
-            if (enableParallel) {
-              // Create temp inside loop because of parallel.
-              tmpBlockedAlloc = create.mem.alignedAlloc(tmpBlockedType);
-            }
+            // When parallel, the alloc stays inside the loop; otherwise it migrates out.
+            Value tmpBlockedAlloc = create.mem.alignedAlloc(tmpBlockedType);
             Value identity = getIdentityValue<ONNXReductionOp>(
                 rewriter, create.getLoc(), elementType);
             Value initVec = create.vec.splat(vecType, identity);
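The reduction lowerings adopt the same pattern as Gemm above. In IR terms the effect is roughly the following (a hedged sketch, not the exact output; the updated FileCheck tests below show the real IR):

krnl.iterate(...) {
  %tmp = memref.alloc() : memref<4x4xf32>  // one buffer per thread if parallel
  // ... reduce into %tmp ...
}

and, in the sequential case, the later migration the comments refer to would hoist the loop-invariant alloc:

%tmp = memref.alloc() : memref<4x4xf32>
krnl.iterate(...) {
  // ... reduce into %tmp ...
}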
@@ -320,9 +320,9 @@ func.func private @gpt2_original(%arg0 : tensor<?x?x768xf32>) -> tensor<?x?x1xf32>
 // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[RES_]]([[RES_]]_3) : (memref<?x?x1xf32>, memref<2xindex>) -> memref<?x?xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){
-// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}}
 // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index
 // CHECK: scf.if [[VAR_9_]] {
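These test updates are mechanical consequences of the change: the 4x4 temporary is now allocated inside krnl.iterate, and since its order relative to krnl.get_induction_var_value is not fixed, the directives switch from CHECK to CHECK-DAG, which matches a group of adjacent directives against lines in any order. A minimal FileCheck illustration (hypothetical ops, not from this test):

// Input may contain, in either order:
//   %a = "op.first"() : () -> i32
//   %b = "op.second"() : () -> i32
// Both orders satisfy:
// CHECK-DAG: %a = "op.first"
// CHECK-DAG: %b = "op.second"
// A CHECK-NOT placed between two CHECK-DAG groups keeps the groups from
// matching across each other; onnx-mlir's test generator emits the dummy
// pattern "CHECK-NOT: separator of consecutive DAGs" for exactly that purpose.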
@@ -431,9 +431,9 @@ func.func private @gpt2_no_keepdims(%arg0 : tensor<?x?x768xf32>) -> tensor<*xf32>
 // CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[VAR_4_]] : i64 to f32
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){
-// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}}
 // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index
 // CHECK: scf.if [[VAR_9_]] {
@@ -553,9 +553,9 @@ func.func private @gpt2_reduce2(%arg0 : tensor<?x?x96x8xf32>) -> tensor<*xf32> {
 // CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref<?x?x1x1xf32>, memref<2xindex>) -> memref<?x?xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){
-// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}}
 // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index
 // CHECK: scf.if [[VAR_9_]] {
@@ -677,9 +677,9 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor<?x?x97x8xf32>) -> tensor<*xf32>
 // CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref<?x?x1x1xf32>, memref<2xindex>) -> memref<?x?xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){
-// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}}
 // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index
 // CHECK: scf.if [[VAR_9_]] {
@@ -802,9 +802,9 @@ func.func private @gpt2_no_simd_as_not_mult_of_VL(%arg0 : tensor<?x?x97x9xf32>) -> tensor<*xf32>
 // CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref<?x?x1x1xf32>, memref<2xindex>) -> memref<?x?xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){
-// CHECK: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}}
 // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index
 // CHECK: scf.if [[VAR_9_]] {
@@ -922,9 +922,10 @@ func.func private @test_reducemax_v13_bis(%arg0 : tensor<1028x256xf32>) -> tensor<*xf32>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1028xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 1028){
-// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
+// CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]])
 // CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]])
 // CHECK-DAG: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]])
@@ -995,9 +996,9 @@ func.func private @test_reducemax_v13_small(%arg0 : tensor<7x8xf32>) -> tensor<*xf32>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<7xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 7){
-// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]])
 // CHECK: [[VAR_3_:%.+]] = arith.cmpi slt, [[VAR_2_]], [[CST_0_]] : index
 // CHECK: scf.if [[VAR_3_]] {
@@ -1079,9 +1080,9 @@ func.func private @test_reducemax_int_v13(%arg0 : tensor<128x256x768xi32>) -> tensor<*xi32>
 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<128x256xi32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
-// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1x32xi32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 128, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){
-// CHECK: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1x32xi32>
 // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32>
 // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
@@ -1138,9 +1139,10 @@ func.func private @bertsquad10_same_pattern(%arg0 : tensor<?x256x768xf32>) -> tensor<*xf32>
 // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[RES_]]([[RES_]]_1) : (memref<?x256x1xf32>, memref<2xindex>) -> memref<?x256xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){
-// CHECK: [[VAR_6_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_6_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
+// CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_7_:%.+]] = affine.apply [[MAP_2_]]([[VAR_6_]]#1)
 // CHECK-DAG: [[VAR_8_:%.+]] = affine.apply [[MAP_3_]]([[VAR_6_]]#1)
 // CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_4_]]([[VAR_6_]]#1)
@@ -1220,9 +1222,10 @@ func.func private @bertsquad10_const_pattern(%arg0 : tensor<1x256x768xf32>) -> tensor<*xf32>
 // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[RES_]]([[RES_]]_1) : (memref<1x256x1xf32>, memref<2xindex>) -> memref<1x256xf32>
 // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
-// CHECK: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){
-// CHECK: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32>
+// CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]#1)
 // CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#1)
 // CHECK-DAG: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#1)
(Diff of the fifth changed file was not rendered in this capture.)