End-to-end walkthrough

Rob Suderman edited this page Dec 15, 2023 · 46 revisions

Welcome to the iree-amd-aie wiki!

Below is a demonstration of the end-to-end flow for compiling an ONNX model for execution on AIE, including custom dispatches.

ONNX Front-end Conversion

Model ingestion involves legalization from a number of different source languages. Each of these languages can be cross-converted from one to another, so supporting a wide array of front ends comes down to finding a fully representative path from the source language to the target. For ONNX, we target linalg via torch.
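One way to picture "finding a path" is as a search over a graph whose nodes are dialects and whose edges are available conversions. The sketch below runs a breadth-first search over such a graph. The graph is illustrative and not exhaustive. Only the onnx -> torch -> linalg route is taken from this page. The tosa and stablehlo edges are added only to show that alternative routes can exist. The sketch also assumes that every edge is lossless, which is exactly what "fully representative" requires of a real conversion.

```python
from collections import deque

# Illustrative, non-exhaustive conversion graph: each key lists the dialects it
# can be lowered into. onnx -> torch -> linalg is the route used on this page;
# the tosa/stablehlo edges only illustrate that alternative routes can exist.
CONVERSIONS = {
    "onnx": ["torch"],
    "torch": ["linalg", "tosa", "stablehlo"],
    "tosa": ["linalg"],
    "stablehlo": ["linalg"],
    "linalg": [],
}

def find_lowering_path(graph, source, target):
    """Breadth-first search for the shortest chain of conversions."""
    parents = {source: None}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node == target:
            # Walk the parent links back to the source, then reverse.
            path = []
            while node is not None:
                path.append(node)
                node = parents[node]
            return path[::-1]
        for nxt in graph.get(node, []):
            if nxt not in parents:
                parents[nxt] = node
                queue.append(nxt)
    return None  # no conversion path exists

print(find_lowering_path(CONVERSIONS, "onnx", "linalg"))  # ['onnx', 'torch', 'linalg']
```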


ONNX models are ingested using the onnx_importer, which converts the ONNX binary file into a series of torch.operator operations expressed in MLIR's torch dialect. Below is a sample of this ingested form for an onnx.Transpose operation.

Run: onnx_importer.py file.onnx > onnx.mlir

Results
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {
                       torch.onnx_meta.ir_version = 7 : si64,
                       torch.onnx_meta.opset_version = 13 : si64 } {
  %0 = torch.operator "onnx.Transpose"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32>
  return %0 : !torch.vtensor<[4,3,2],f32>
}

Once ingested, these ONNX operations can be translated into native torch.aten operators. This lets ONNX reuse the existing torch conversion pipelines to reach other representations, instead of requiring a custom ONNX pipeline.

torch-mlir-opt --convert-torch-onnx-to-torch onnx.mlir > torch.mlir

Results
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %0 = torch.aten.transpose.int %arg0, %int0, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32>
  return %0 : !torch.vtensor<[4,3,2],f32>
}
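The rewrite above works because of a small permutation identity. When onnx.Transpose has no `perm` attribute, it reverses the axes. torch.aten.transpose.int, by contrast, swaps exactly two dimensions. The following sketch uses plain Python lists, and the helper names are our own. It checks that for a rank-3 tensor, reversing the axes is the same as swapping dimensions 0 and 2. For rank 4, a single swap no longer reproduces the reversal.

```python
def onnx_default_perm(rank):
    # onnx.Transpose without a `perm` attribute reverses the axes.
    return list(reversed(range(rank)))

def swap_perm(rank, dim0, dim1):
    # torch.aten.transpose.int exchanges exactly two dimensions.
    perm = list(range(rank))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return perm

def apply_perm(shape, perm):
    # Output dimension i takes the size of input dimension perm[i].
    return [shape[p] for p in perm]

print(onnx_default_perm(3), swap_perm(3, 0, 2))   # [2, 1, 0] [2, 1, 0]
print(apply_perm([2, 3, 4], swap_perm(3, 0, 2)))  # [4, 3, 2]
print(onnx_default_perm(4), swap_perm(4, 0, 3))   # [3, 2, 1, 0] [3, 1, 2, 0]
```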

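When targeting AIE, we convert from the torch dialect to our general computational dialect, linalg. Unlike torch, linalg describes what each operation computes through composition rather than as a standalone named operation. In the output below, the transpose becomes a linalg.generic whose indexing maps and iterator types spell out how input elements map to output elements.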

torch-mlir-opt --convert-torch-to-linalg --torch-func-backend-type-conversion --cse --canonicalize --torch-finalizing-backend-type-conversion torch.mlir > linalg.mlir

Results
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
func.func @test_transpose_default(%arg0: tensor<2x3x4xf32>) -> tensor<4x3x2xf32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
  %0 = tensor.empty() : tensor<4x3x2xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4xf32>) outs(%0 : tensor<4x3x2xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<4x3x2xf32>
  return %1 : tensor<4x3x2xf32>
}
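As a reading aid, the linalg.generic above can be evaluated by hand. The iteration space is (d0, d1, d2) = 2 x 3 x 4. At each point, the op reads the input through #map and writes the output through #map1. The region body only yields the input element. The sketch below is our own Python rendering of those semantics on nested lists. It is not how MLIR executes the op.

```python
from itertools import product

def linalg_generic_transpose(inp, shape):
    """Evaluate the transpose linalg.generic on nested Python lists."""
    n0, n1, n2 = shape  # iteration space (d0, d1, d2)
    # tensor.empty() : tensor<4x3x2xf32>, indexed as out[d2][d1][d0]
    out = [[[None] * n0 for _ in range(n1)] for _ in range(n2)]
    # All three iterator types are "parallel", so any visiting order is valid.
    for d0, d1, d2 in product(range(n0), range(n1), range(n2)):
        value = inp[d0][d1][d2]   # input map  #map:  (d0, d1, d2) -> (d0, d1, d2)
        out[d2][d1][d0] = value   # output map #map1: (d0, d1, d2) -> (d2, d1, d0)
    return out

x = [[[100 * a + 10 * b + c for c in range(4)] for b in range(3)] for a in range(2)]
y = linalg_generic_transpose(x, (2, 3, 4))
print(len(y), len(y[0]), len(y[0][0]))  # 4 3 2
print(y[3][1][0])                       # 13, i.e. x[0][1][3]
```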

IREE-AMD-AIE compilation process

E2E compilation command

SAMPLES_DIR=<iree-amd-aie source dir>/tests/samples
iree-compile  --iree-hal-target-backends=amd-aie  \
 ${SAMPLES_DIR}/matmul_fill_static_i32.mlir \
  --iree-codegen-transform-dialect-library=${SAMPLES_DIR}/matmul_fill_spec_pad.mlir \
  --iree-amd-aie-peano-install-dir=<peano installation directory> \
  --iree-amd-aie-mlir-aie-install-dir=<mlir-aie installation directory> \
  --iree-amd-aie-vitis-install-dir=<vitis installation directory> \
  --iree-hal-dump-executable-files-to=$PWD  --iree-amd-aie-show-invoked-commands
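The same invocation can also be scripted. The following is a minimal Python sketch that assumes `iree-compile` is on `PATH`. Every flag is copied from the command above. The directory arguments in the usage lines are hypothetical placeholders to replace with your own source and installation paths.

```python
import shlex
import subprocess
from pathlib import Path

def iree_compile_command(samples_dir, peano_dir, mlir_aie_dir, vitis_dir, dump_dir):
    """Assemble the iree-compile invocation from this page as an argument list."""
    samples = Path(samples_dir)
    return [
        "iree-compile",
        "--iree-hal-target-backends=amd-aie",
        str(samples / "matmul_fill_static_i32.mlir"),
        f"--iree-codegen-transform-dialect-library={samples / 'matmul_fill_spec_pad.mlir'}",
        f"--iree-amd-aie-peano-install-dir={peano_dir}",
        f"--iree-amd-aie-mlir-aie-install-dir={mlir_aie_dir}",
        f"--iree-amd-aie-vitis-install-dir={vitis_dir}",
        f"--iree-hal-dump-executable-files-to={dump_dir}",
        "--iree-amd-aie-show-invoked-commands",
    ]

def run_iree_compile(cmd):
    # check=True raises subprocess.CalledProcessError if compilation fails.
    return subprocess.run(cmd, check=True)

# Hypothetical paths: substitute your own directories before running.
cmd = iree_compile_command("/path/to/iree-amd-aie/tests/samples", "/path/to/peano",
                           "/path/to/mlir-aie", "/path/to/vitis", Path.cwd())
print(shlex.join(cmd))  # the equivalent shell command line
```

Passing a list rather than a single string to `subprocess.run` avoids shell quoting problems with paths that contain spaces.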

Prior to AIE ingestion, the dispatch is represented as the loop-based computation shown below. In this example we show a matrix multiplication between a tensor<8x16xi32> and a tensor<16x8xi32>, resulting in a tensor<8x8xi32>.

Results
module {
  func.func @matmul_static_dispatch_0_matmul_8x8x16_i32() {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c16 = arith.constant 16 : index
    %c4 = arith.constant 4 : index
    %c0_i32 = arith.constant 0 : i32
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<8x16xi32>
    memref.assume_alignment %0, 64 : memref<8x16xi32>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<16x8xi32>
    memref.assume_alignment %1, 64 : memref<16x8xi32>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<8x8xi32>
    memref.assume_alignment %2, 64 : memref<8x8xi32>
    scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%c1, %c1) step (%c1, %c1) {
      %3 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
      %4 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg1)
      %subview = memref.subview %0[%3, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1], offset: ?>>
      %subview_0 = memref.subview %1[0, %4] [16, 8] [1, 1] : memref<16x8xi32> to memref<16x8xi32, strided<[8, 1], offset: ?>>
      %subview_1 = memref.subview %2[%3, %4] [8, 8] [1, 1] : memref<8x8xi32> to memref<8x8xi32, strided<[8, 1], offset: ?>>
      %alloc = memref.alloc() : memref<8x16xi32, 1>
      memref.copy %subview, %alloc : memref<8x16xi32, strided<[16, 1], offset: ?>> to memref<8x16xi32, 1>
      %alloc_2 = memref.alloc() : memref<16x8xi32, 1>
      memref.copy %subview_0, %alloc_2 : memref<16x8xi32, strided<[8, 1], offset: ?>> to memref<16x8xi32, 1>
      %alloc_3 = memref.alloc() : memref<8x8xi32, 1>
      scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
        %5 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg2)
        %6 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg3)
        %subview_4 = memref.subview %alloc[%5, 0] [4, 16] [1, 1] : memref<8x16xi32, 1> to memref<4x16xi32, strided<[16, 1], offset: ?>, 1>
        %subview_5 = memref.subview %alloc_2[0, %6] [16, 4] [1, 1] : memref<16x8xi32, 1> to memref<16x4xi32, strided<[8, 1], offset: ?>, 1>
        %subview_6 = memref.subview %alloc_3[%5, %6] [4, 4] [1, 1] : memref<8x8xi32, 1> to memref<4x4xi32, strided<[8, 1], offset: ?>, 1>
        %alloc_7 = memref.alloc() : memref<4x4xi32, 2>
        linalg.fill ins(%c0_i32 : i32) outs(%alloc_7 : memref<4x4xi32, 2>)
        scf.for %arg4 = %c0 to %c16 step %c4 {
          %subview_8 = memref.subview %subview_4[0, %arg4] [4, 4] [1, 1] : memref<4x16xi32, strided<[16, 1], offset: ?>, 1> to memref<4x4xi32, strided<[16, 1], offset: ?>, 1>
          %subview_9 = memref.subview %subview_5[%arg4, 0] [4, 4] [1, 1] : memref<16x4xi32, strided<[8, 1], offset: ?>, 1> to memref<4x4xi32, strided<[8, 1], offset: ?>, 1>
          %alloc_10 = memref.alloc() : memref<4x4xi32, 2>
          memref.copy %subview_8, %alloc_10 : memref<4x4xi32, strided<[16, 1], offset: ?>, 1> to memref<4x4xi32, 2>
          %alloc_11 = memref.alloc() : memref<4x4xi32, 2>
          memref.copy %subview_9, %alloc_11 : memref<4x4xi32, strided<[8, 1], offset: ?>, 1> to memref<4x4xi32, 2>
          linalg.matmul ins(%alloc_10, %alloc_11 : memref<4x4xi32, 2>, memref<4x4xi32, 2>) outs(%alloc_7 : memref<4x4xi32, 2>)
          memref.dealloc %alloc_10 : memref<4x4xi32, 2>
          memref.dealloc %alloc_11 : memref<4x4xi32, 2>
        }
        memref.copy %alloc_7, %subview_6 : memref<4x4xi32, 2> to memref<4x4xi32, strided<[8, 1], offset: ?>, 1>
        memref.dealloc %alloc_7 : memref<4x4xi32, 2>
        scf.yield
      }
      memref.copy %alloc_3, %subview_1 : memref<8x8xi32, 1> to memref<8x8xi32, strided<[8, 1], offset: ?>>
      memref.dealloc %alloc : memref<8x16xi32, 1>
      memref.dealloc %alloc_2 : memref<16x8xi32, 1>
      memref.dealloc %alloc_3 : memref<8x8xi32, 1>
      scf.yield
    }
    return
  }
}
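To check how the loop nest above decomposes the work, it can be replayed in plain Python. The sketch assumes that M and N are multiples of the tile sizes and ignores i32 overflow. It treats the memory-space-1 and memory-space-2 buffers as ordinary lists. (In the AIE output further down, space 1 buffers live in tile (0, 1) and space 2 buffers in the compute tiles.) The only claim it tests is that the tiling produces the same result as a direct matrix multiplication.

```python
def tiled_matmul(A, B, outer_tile=8, inner_tile=4, k_step=4):
    """Replay the tiled loop nest on Python lists (i32 overflow is ignored)."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]
    # Outer scf.parallel: one 8x8 block per iteration (a single iteration here).
    for i0 in range(0, M, outer_tile):
        for j0 in range(0, N, outer_tile):
            # Inner scf.parallel: a 2x2 grid of 4x4 blocks inside the 8x8 block.
            for i1 in range(i0, i0 + outer_tile, inner_tile):
                for j1 in range(j0, j0 + outer_tile, inner_tile):
                    # linalg.fill: zero the 4x4 accumulator (memref<4x4xi32, 2>).
                    acc = [[0] * inner_tile for _ in range(inner_tile)]
                    # scf.for %arg4 = 0 to 16 step 4: walk the reduction dimension.
                    for k0 in range(0, K, k_step):
                        # linalg.matmul on 4x4 slices of A and B.
                        for i in range(inner_tile):
                            for j in range(inner_tile):
                                for k in range(k0, k0 + k_step):
                                    acc[i][j] += A[i1 + i][k] * B[k][j1 + j]
                    # memref.copy the accumulator back into the result block.
                    for i in range(inner_tile):
                        for j in range(inner_tile):
                            C[i1 + i][j1 + j] = acc[i][j]
    return C

A = [[i + k for k in range(16)] for i in range(8)]   # 8x16
B = [[k - j for j in range(8)] for k in range(16)]   # 16x8
reference = [[sum(A[i][k] * B[k][j] for k in range(16)) for j in range(8)] for i in range(8)]
assert tiled_matmul(A, B) == reference
```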

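After ingestion, the outer loop structures remain, still computing a series of 4x4 i32 matrix multiplications. Asynchronous execution is handled by AIR. Each step is wrapped in an air.execute that returns an !air.async.token, and later operations list the tokens they depend on. Data movement between memory levels is expressed with air.channel.put and air.channel.get.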
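The token discipline can be mimicked with Python's asyncio to show the ordering it enforces. In this sketch, each air.execute region is modeled as a task, and the task stands in for the token it returns. An operation starts only after every task in its dependency list has finished. The operation names mirror the first few air.execute regions and the air.launch in the AIR function below. This is an analogy for reading the IR, not how the AIR runtime schedules work.

```python
import asyncio

async def run_with_tokens(ops):
    """Run (name, deps) pairs so each op starts only after its deps finish."""
    tokens = {}
    finished = []

    async def execute(name, deps):
        await asyncio.gather(*(tokens[d] for d in deps))  # wait on [%token, ...]
        finished.append(name)                              # the region body would run here

    for name, deps in ops:  # ops appear after their dependencies, as in SSA form
        tokens[name] = asyncio.create_task(execute(name, deps))
    await asyncio.gather(*tokens.values())
    return finished

# Dependencies taken from the top of the AIR function below.
ops = [
    ("async_token", []),                 # hal.interface.binding.subspan binding(0)
    ("async_token_0", ["async_token"]),  # memref.assume_alignment on that result
    ("async_token_1", []),               # binding(1)
    ("async_token_3", ["async_token_1"]),
    ("async_token_4", []),               # binding(2)
    ("async_token_6", ["async_token_4"]),
    ("launch", ["async_token_0", "async_token_3", "async_token_6"]),  # air.launch
]
order = asyncio.run(run_with_tokens(ops))
print(order)
```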

Results
func.func @matmul_static_dispatch_0_matmul_8x8x16_i32() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %async_token, %results = air.execute -> (memref<8x16xi32>) {
    %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<8x16xi32>
    air.execute_terminator %1 : memref<8x16xi32>
  }
  %async_token_0 = air.execute [%async_token] {
    memref.assume_alignment %results, 64 : memref<8x16xi32>
  }
  %async_token_1, %results_2 = air.execute -> (memref<16x8xi32>) {
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<16x8xi32>
    air.execute_terminator %1 : memref<16x8xi32>
  }
  %async_token_3 = air.execute [%async_token_1] {
    memref.assume_alignment %results_2, 64 : memref<16x8xi32>
  }
  %async_token_4, %results_5 = air.execute -> (memref<8x8xi32>) {
    %1 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<8x8xi32>
    air.execute_terminator %1 : memref<8x8xi32>
  }
  %async_token_6 = air.execute [%async_token_4] {
    memref.assume_alignment %results_5, 64 : memref<8x8xi32>
  }
  %0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c1, %arg3=%c1) args(%arg4=%results, %arg5=%results_2, %arg6=%results_5) : memref<8x16xi32>, memref<16x8xi32>, memref<8x8xi32> attributes {id = 1 : i32} {
    %c1_7 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c8 = arith.constant 8 : index
    %c0_8 = arith.constant 0 : index
    %async_token_9, %results_10 = air.execute -> (index) {
      %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%arg0]
      air.execute_terminator %5 : index
    }
    %async_token_11, %results_12 = air.execute -> (index) {
      %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%arg1]
      air.execute_terminator %5 : index
    }
    %async_token_13, %results_14 = air.execute -> (index) {
      %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%arg0]
      air.execute_terminator %5 : index
    }
    %async_token_15, %results_16 = air.execute -> (index) {
      %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%arg0]
      air.execute_terminator %5 : index
    }
    %1 = air.channel.put async [%async_token_15]  @channel_5[] (%arg4[%results_16, %c0_8] [%c8, %c16] [%c16, %c1_7]) {id = 1 : i32} : (memref<8x16xi32>)
    %2 = air.channel.put async [%async_token_11, %1]  @channel_5[] (%arg5[%c0_8, %results_12] [%c16, %c8] [%c8, %c1_7]) {id = 2 : i32} : (memref<16x8xi32>)
    %async_token_17, %results_18 = air.execute -> (index) {
      %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%arg0]
      air.execute_terminator %5 : index
    }
    %async_token_19, %results_20 = air.execute -> (index) {
      %5 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%arg1]
      air.execute_terminator %5 : index
    }
    %3 = air.channel.get async [%async_token_17, %async_token_19]  @channel_7[] (%arg6[%results_18, %results_20] [%c8, %c8] [%c8, %c1_7]) {id = 3 : i32} : (memref<8x8xi32>)
    %4 = air.segment @segment_0 async  attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 1 : i64, y_loc = 2 : i64, y_size = 4 : i64} {
      %c4 = arith.constant 4 : index
      %c2 = arith.constant 2 : index
      %c12 = arith.constant 12 : index
      %c8_21 = arith.constant 8 : index
      %c1_22 = arith.constant 1 : index
      %c16_23 = arith.constant 16 : index
      %c0_24 = arith.constant 0 : index
      %async_token_25, %results_26 = air.execute -> (memref<8x16xi32, 1>) {
        %alloc = memref.alloc() : memref<8x16xi32, 1>
        air.execute_terminator %alloc : memref<8x16xi32, 1>
      }
      %5 = air.wait_all async 
      %async_token_27, %results_28 = air.execute -> (memref<16x8xi32, 1>) {
        %alloc = memref.alloc() : memref<16x8xi32, 1>
        air.execute_terminator %alloc : memref<16x8xi32, 1>
      }
      %6 = air.channel.get async [%async_token_25]  @channel_5[] (%results_26[] [] []) {id = 4 : i32} : (memref<8x16xi32, 1>)
      %7 = air.channel.get async [%async_token_27, %6]  @channel_5[] (%results_28[] [] []) {id = 5 : i32} : (memref<16x8xi32, 1>)
      %async_token_29, %results_30 = air.execute -> (memref<8x8xi32, 1>) {
        %alloc = memref.alloc() : memref<8x8xi32, 1>
        air.execute_terminator %alloc : memref<8x8xi32, 1>
      }
      %8 = air.channel.put async [%5]  @channel_0[] (%results_26[%c0_24, %c0_24] [%c4, %c4] [%c16_23, %c1_22]) {id = 6 : i32, unrolled_iteration = 0 : i32} : (memref<8x16xi32, 1>)
      %9 = air.channel.put async [%8]  @channel_0[] (%results_26[%c0_24, %c4] [%c4, %c4] [%c16_23, %c1_22]) {id = 6 : i32, unrolled_iteration = 1 : i32} : (memref<8x16xi32, 1>)
      %10 = air.channel.put async [%9]  @channel_0[] (%results_26[%c0_24, %c8_21] [%c4, %c4] [%c16_23, %c1_22]) {id = 6 : i32, unrolled_iteration = 2 : i32} : (memref<8x16xi32, 1>)
      %11 = air.channel.put async [%10]  @channel_0[] (%results_26[%c0_24, %c12] [%c4, %c4] [%c16_23, %c1_22]) {id = 6 : i32, unrolled_iteration = 3 : i32} : (memref<8x16xi32, 1>)
      %12 = air.channel.put async [%5]  @channel_1[] (%results_26[%c4, %c0_24] [%c4, %c4] [%c16_23, %c1_22]) {id = 7 : i32, unrolled_iteration = 0 : i32} : (memref<8x16xi32, 1>)
      %13 = air.channel.put async [%12]  @channel_1[] (%results_26[%c4, %c4] [%c4, %c4] [%c16_23, %c1_22]) {id = 7 : i32, unrolled_iteration = 1 : i32} : (memref<8x16xi32, 1>)
      %14 = air.channel.put async [%13]  @channel_1[] (%results_26[%c4, %c8_21] [%c4, %c4] [%c16_23, %c1_22]) {id = 7 : i32, unrolled_iteration = 2 : i32} : (memref<8x16xi32, 1>)
      %15 = air.channel.put async [%14]  @channel_1[] (%results_26[%c4, %c12] [%c4, %c4] [%c16_23, %c1_22]) {id = 7 : i32, unrolled_iteration = 3 : i32} : (memref<8x16xi32, 1>)
      %16 = air.channel.put async [%7]  @channel_2[] (%results_28[%c0_24, %c0_24] [%c4, %c4] [%c8_21, %c1_22]) {id = 8 : i32, unrolled_iteration = 0 : i32} : (memref<16x8xi32, 1>)
      %17 = air.channel.put async [%16]  @channel_2[] (%results_28[%c4, %c0_24] [%c4, %c4] [%c8_21, %c1_22]) {id = 8 : i32, unrolled_iteration = 1 : i32} : (memref<16x8xi32, 1>)
      %18 = air.channel.put async [%17]  @channel_2[] (%results_28[%c8_21, %c0_24] [%c4, %c4] [%c8_21, %c1_22]) {id = 8 : i32, unrolled_iteration = 2 : i32} : (memref<16x8xi32, 1>)
      %19 = air.channel.put async [%18]  @channel_2[] (%results_28[%c12, %c0_24] [%c4, %c4] [%c8_21, %c1_22]) {id = 8 : i32, unrolled_iteration = 3 : i32} : (memref<16x8xi32, 1>)
      %20 = air.channel.put async [%7]  @channel_3[] (%results_28[%c0_24, %c4] [%c4, %c4] [%c8_21, %c1_22]) {id = 9 : i32, unrolled_iteration = 0 : i32} : (memref<16x8xi32, 1>)
      %21 = air.channel.put async [%20]  @channel_3[] (%results_28[%c4, %c4] [%c4, %c4] [%c8_21, %c1_22]) {id = 9 : i32, unrolled_iteration = 1 : i32} : (memref<16x8xi32, 1>)
      %22 = air.channel.put async [%21]  @channel_3[] (%results_28[%c8_21, %c4] [%c4, %c4] [%c8_21, %c1_22]) {id = 9 : i32, unrolled_iteration = 2 : i32} : (memref<16x8xi32, 1>)
      %23 = air.channel.put async [%22]  @channel_3[] (%results_28[%c12, %c4] [%c4, %c4] [%c8_21, %c1_22]) {id = 9 : i32, unrolled_iteration = 3 : i32} : (memref<16x8xi32, 1>)
      %24 = scf.parallel (%arg7, %arg8) = (%c0_24, %c0_24) to (%c2, %c2) step (%c1_22, %c1_22) init (%async_token_29) -> !air.async.token {
        %async_token_34, %results_35 = air.execute -> (index) {
          %28 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%arg7]
          air.execute_terminator %28 : index
        }
        %async_token_36, %results_37 = air.execute -> (index) {
          %28 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%arg8]
          air.execute_terminator %28 : index
        }
        %27 = air.channel.get async [%async_token_29, %async_token_36, %async_token_34]  @channel_6[%arg7, %arg8] (%results_30[%results_35, %results_37] [%c4, %c4] [%c8_21, %c1_22]) {id = 10 : i32} : (memref<8x8xi32, 1>)
        scf.reduce(%27)  : !air.async.token {
        ^bb0(%arg9: !air.async.token, %arg10: !air.async.token):
          %28 = air.wait_all async [%arg9, %arg10] 
          scf.reduce.return %28 : !air.async.token
        }
        scf.yield
      }
      %25 = air.herd @herd_0 async [%7, %async_token_29]  tile (%arg7, %arg8) in (%arg9=%c1_22, %arg10=%c4) attributes {id = 3 : i32, x_loc = 0 : i64, y_loc = 2 : i64} {
        %c4_34 = arith.constant 4 : index
        %c1_35 = arith.constant 1 : index
        %c16_36 = arith.constant 16 : index
        %c0_37 = arith.constant 0 : index
        %c0_i32 = arith.constant 0 : i32
        %c8_38 = arith.constant 8 : index
        %c2_39 = arith.constant 2 : index
        %27 = arith.remsi %arg8, %c2_39 : index
        %28 = arith.divsi %arg8, %c2_39 : index
        %async_token_40, %results_41 = air.execute -> (memref<4x4xi32, 2>) {
          %alloc = memref.alloc() : memref<4x4xi32, 2>
          air.execute_terminator %alloc : memref<4x4xi32, 2>
        }
        %async_token_42 = air.execute [%async_token_40] {
          scf.for %arg11 = %c0_37 to %c4_34 step %c1_35 {
            scf.for %arg12 = %c0_37 to %c4_34 step %c1_35 {
              memref.store %c0_i32, %results_41[%arg11, %arg12] : memref<4x4xi32, 2>
            }
          }
        }
        %async_token_43, %results_44 = air.execute [%async_token_42] -> (memref<4x4xi32, 2>) {
          %alloc = memref.alloc() : memref<4x4xi32, 2>
          air.execute_terminator %alloc : memref<4x4xi32, 2>
        }
        %async_token_45, %results_46 = air.execute [%async_token_43] -> (memref<4x4xi32, 2>) {
          %alloc = memref.alloc() : memref<4x4xi32, 2>
          air.execute_terminator %alloc : memref<4x4xi32, 2>
        }
        %async_token_47, %results_48 = air.execute [%async_token_45] -> (memref<4x4xi32, 2>) {
          %alloc = memref.alloc() : memref<4x4xi32, 2>
          air.execute_terminator %alloc : memref<4x4xi32, 2>
        }
        %async_token_49, %results_50 = air.execute [%async_token_45] -> (memref<4x4xi32, 2>) {
          %alloc = memref.alloc() : memref<4x4xi32, 2>
          air.execute_terminator %alloc : memref<4x4xi32, 2>
        }
        %29:4 = scf.for %arg11 = %c0_37 to %c16_36 step %c8_38 iter_args(%arg12 = %async_token_47, %arg13 = %async_token_49, %arg14 = %async_token_49, %arg15 = %async_token_49) -> (!air.async.token, !air.async.token, !air.async.token, !air.async.token) {
          %31 = affine.if affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>()[%28, %27] -> !air.async.token {
            %36 = air.channel.get async [%arg15, %arg12, %async_token_47]  @channel_0[%28, %27] (%results_48[] [] []) {id = 11 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          } else {
            %36 = air.channel.get async [%arg15, %arg12, %async_token_47]  @channel_1[%28, %27] (%results_48[] [] []) {id = 12 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          }
          %32 = affine.if affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 == 0)>()[%28, %27] -> !air.async.token {
            %36 = air.channel.get async [%arg15, %arg12, %async_token_49]  @channel_2[%28, %27] (%results_50[] [] []) {id = 13 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          } else {
            %36 = air.channel.get async [%arg15, %arg12, %async_token_49]  @channel_3[%28, %27] (%results_50[] [] []) {id = 14 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          }
          %async_token_52 = air.execute [%arg14, %32, %31] {
            scf.for %arg16 = %c0_37 to %c4_34 step %c1_35 {
              scf.for %arg17 = %c0_37 to %c4_34 step %c1_35 {
                scf.for %arg18 = %c0_37 to %c4_34 step %c1_35 {
                  %36 = memref.load %results_48[%arg16, %arg18] : memref<4x4xi32, 2>
                  %37 = memref.load %results_50[%arg18, %arg17] : memref<4x4xi32, 2>
                  %38 = memref.load %results_41[%arg16, %arg17] : memref<4x4xi32, 2>
                  %39 = arith.muli %36, %37 : i32
                  %40 = arith.addi %38, %39 : i32
                  memref.store %40, %results_41[%arg16, %arg17] : memref<4x4xi32, 2>
                }
              }
            }
          }
          %async_token_53 = air.execute {
            memref.dealloc %results_48 : memref<4x4xi32, 2>
          }
          %async_token_54 = air.execute {
            memref.dealloc %results_50 : memref<4x4xi32, 2>
          }
          %33 = affine.if affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)>()[%28, %27] -> !air.async.token {
            %36 = air.channel.get async [%32, %31, %arg13]  @channel_0[%28, %27] (%results_46[] [] []) {id = 11 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          } else {
            %36 = air.channel.get async [%32, %31, %arg13]  @channel_1[%28, %27] (%results_46[] [] []) {id = 12 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          }
          %34 = affine.if affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 == 0)>()[%28, %27] -> !air.async.token {
            %36 = air.channel.get async [%32, %31, %arg13]  @channel_2[%28, %27] (%results_44[] [] []) {id = 13 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          } else {
            %36 = air.channel.get async [%32, %31, %arg13]  @channel_3[%28, %27] (%results_44[] [] []) {id = 14 : i32} : (memref<4x4xi32, 2>)
            affine.yield %36 : !air.async.token
          }
          %async_token_55 = air.execute [%async_token_52, %34, %33] {
            scf.for %arg16 = %c0_37 to %c4_34 step %c1_35 {
              scf.for %arg17 = %c0_37 to %c4_34 step %c1_35 {
                scf.for %arg18 = %c0_37 to %c4_34 step %c1_35 {
                  %36 = memref.load %results_46[%arg16, %arg18] : memref<4x4xi32, 2>
                  %37 = memref.load %results_44[%arg18, %arg17] : memref<4x4xi32, 2>
                  %38 = memref.load %results_41[%arg16, %arg17] : memref<4x4xi32, 2>
                  %39 = arith.muli %36, %37 : i32
                  %40 = arith.addi %38, %39 : i32
                  memref.store %40, %results_41[%arg16, %arg17] : memref<4x4xi32, 2>
                }
              }
            }
          }
          %async_token_56 = air.execute {
            memref.dealloc %results_46 : memref<4x4xi32, 2>
          }
          %async_token_57 = air.execute {
            memref.dealloc %results_44 : memref<4x4xi32, 2>
          }
          %35 = air.wait_all async [%33, %34] 
          scf.yield %async_token_52, %async_token_55, %async_token_55, %35 : !air.async.token, !air.async.token, !air.async.token, !air.async.token
        }
        %30 = air.channel.put async [%29#1]  @channel_6[%28, %27] (%results_41[] [] []) {id = 15 : i32} : (memref<4x4xi32, 2>)
        %async_token_51 = air.execute [%30] {
          memref.dealloc %results_41 : memref<4x4xi32, 2>
        }
        air.herd_terminator
      }
      %26 = air.channel.put async [%25]  @channel_7[] (%results_30[] [] []) {id = 16 : i32} : (memref<8x8xi32, 1>)
      %async_token_31 = air.execute [%5] {
        memref.dealloc %results_26 : memref<8x16xi32, 1>
      }
      %async_token_32 = air.execute [%7] {
        memref.dealloc %results_28 : memref<16x8xi32, 1>
      }
      %async_token_33 = air.execute [%26] {
        memref.dealloc %results_30 : memref<8x8xi32, 1>
      }
      air.segment_terminator
    }
    air.launch_terminator
  }
  return
}
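The herd indexing in @herd_0 is easier to follow when replayed. The sketch below reflects our own reading of the IR. The 1x4 herd index %arg8 is folded into a 2x2 logical grid by arith.divsi and arith.remsi. The affine.if conditions then pick which L2 channel feeds each core: channel_0 and channel_1 carry A rows 0-3 and 4-7, and channel_2 and channel_3 carry B columns 0-3 and 4-7. The parallel channel_6 gets place each core's 4x4 result at (row * 4, col * 4) in the 8x8 L2 buffer.

```python
def herd_assignment(num_cores=4):
    """Map each core of the 1x4 herd to its A/B input channels and C block."""
    table = {}
    for core in range(num_cores):  # %arg8, the herd's second tile index (0..3)
        col = core % 2             # %27 = arith.remsi %arg8, %c2
        row = core // 2            # %28 = arith.divsi %arg8, %c2
        # affine.if (s0 == 0, ...)[%28, %27]: row 0 reads A from channel_0, row 1 from channel_1.
        a_channel = "channel_0" if row == 0 else "channel_1"
        # affine.if (..., s1 == 0)[%28, %27]: column 0 reads B from channel_2, column 1 from channel_3.
        b_channel = "channel_2" if col == 0 else "channel_3"
        # channel_6[%28, %27] lands at results_30[row * 4, col * 4].
        c_offset = (row * 4, col * 4)
        table[core] = (row, col, a_channel, b_channel, c_offset)
    return table

for core, entry in herd_assignment().items():
    print(core, entry)
```

Each core reads A rows 4*row to 4*row+3 and B columns 4*col to 4*col+3, which is exactly the 4x4 block of C it writes back. The four cores therefore cover the 8x8 result without overlap.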

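Compilation continues by rewriting the AIR operations into their AIE equivalents. The herd becomes one AIE.core per compute tile, and channel transfers become DMA programs (AIE.mem blocks built from AIE.dmaBd descriptors). Synchronization is expressed with AIE.lock acquire/release pairs. Note that this excerpt was captured from a larger 64x64x256 variant of the matmul (see the export name), so its buffer and tile sizes differ from the 8x8x16 example above: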
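In the AIE module below, each compute tile's DMA (AIE.mem) and core (AIE.core) coordinate through pairs of locks. For tile (0, 5), the S2MM channel acquires %lock_0_5_21 ({init = 2}), fills one of two 16x16 buffers, and releases %lock_0_5_22 ({init = 0}). The core acquires %lock_0_5_22 before computing and releases %lock_0_5_21 afterwards. The sketch below models that handshake with Python counting semaphores. Treating AcquireGreaterEqual/Release as semaphore acquire/release is an analogy, not the hardware semantics. The sketch also models a single input stream, while the real core waits on both its A and B streams.

```python
import threading

def double_buffered_stream(tiles, num_buffers=2):
    """Model one DMA-to-core stream guarded by an "empty" and a "full" lock."""
    buffers = [None] * num_buffers
    empty = threading.Semaphore(num_buffers)  # like %lock_0_5_21 {init = 2}: free buffers
    full = threading.Semaphore(0)             # like %lock_0_5_22 {init = 0}: filled buffers
    consumed = []

    def dma():  # the S2MM channel in AIE.mem
        for i, tile in enumerate(tiles):
            empty.acquire()                  # AIE.useLock(%lock_0_5_21, AcquireGreaterEqual, 1)
            buffers[i % num_buffers] = tile  # the two AIE.dmaBd blocks alternate buffers
            full.release()                   # AIE.useLock(%lock_0_5_22, Release, 1)

    def core():  # the AIE.core loop
        for i in range(len(tiles)):
            full.acquire()                   # AIE.useLock(%lock_0_5_22, AcquireGreaterEqual, 1)
            consumed.append(buffers[i % num_buffers])  # compute on the filled buffer
            empty.release()                  # AIE.useLock(%lock_0_5_21, Release, 1)

    workers = [threading.Thread(target=dma), threading.Thread(target=core)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return consumed

print(double_buffered_stream(list(range(6))))  # [0, 1, 2, 3, 4, 5]
```

Because the "empty" lock starts at 2, the DMA can fill the second buffer while the core is still computing on the first. This is the ping-pong pattern visible in the alternating ^bb1/^bb2 buffer descriptors.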

Results
#executable_target_elf = #hal.executable.target<"amd-aie", "elf", {target_arch = "chip-tbd"}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
#translation = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @__transform_main>
#device_target_amd_aie = #hal.device.target<"amd-aie", {executable_targets = [#executable_target_elf], legacy_sync}>
module attributes {hal.device.targets = [#device_target_amd_aie]} {
  hal.executable private @matmul_static_dispatch_0 {
    hal.executable.variant public @elf target(#executable_target_elf) {
      hal.executable.export public @matmul_static_dispatch_0_matmul_64x64x256_i32 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
      ^bb0(%arg0: !hal.device):
        %c2 = arith.constant 2 : index
        %c1 = arith.constant 1 : index
        hal.return %c2, %c2, %c1 : index, index, index
      }
      builtin.module {
        AIE.device(ipu) {
          %tile_0_0 = AIE.tile(0, 0)
          %tile_0_1 = AIE.tile(0, 1)
          %tile_0_2 = AIE.tile(0, 2)
          %tile_0_3 = AIE.tile(0, 3)
          %tile_0_4 = AIE.tile(0, 4)
          %tile_0_5 = AIE.tile(0, 5)
          %lock_0_1 = AIE.lock(%tile_0_1, 5) {init = 4 : i32}
          %lock_0_1_0 = AIE.lock(%tile_0_1, 4) {init = 0 : i32}
          %lock_0_1_1 = AIE.lock(%tile_0_1, 3) {init = 2 : i32}
          %lock_0_1_2 = AIE.lock(%tile_0_1, 2) {init = 0 : i32}
          %lock_0_1_3 = AIE.lock(%tile_0_1, 1) {init = 2 : i32}
          %lock_0_1_4 = AIE.lock(%tile_0_1, 0) {init = 0 : i32}
          %lock_0_2 = AIE.lock(%tile_0_2, 5) {init = 2 : i32}
          %lock_0_2_5 = AIE.lock(%tile_0_2, 4) {init = 0 : i32}
          %lock_0_2_6 = AIE.lock(%tile_0_2, 3) {init = 2 : i32}
          %lock_0_2_7 = AIE.lock(%tile_0_2, 2) {init = 0 : i32}
          %lock_0_2_8 = AIE.lock(%tile_0_2, 1) {init = 1 : i32}
          %lock_0_2_9 = AIE.lock(%tile_0_2, 0) {init = 0 : i32}
          %lock_0_3 = AIE.lock(%tile_0_3, 5) {init = 2 : i32}
          %lock_0_3_10 = AIE.lock(%tile_0_3, 4) {init = 0 : i32}
          %lock_0_3_11 = AIE.lock(%tile_0_3, 3) {init = 2 : i32}
          %lock_0_3_12 = AIE.lock(%tile_0_3, 2) {init = 0 : i32}
          %lock_0_3_13 = AIE.lock(%tile_0_3, 1) {init = 1 : i32}
          %lock_0_3_14 = AIE.lock(%tile_0_3, 0) {init = 0 : i32}
          %lock_0_4 = AIE.lock(%tile_0_4, 5) {init = 2 : i32}
          %lock_0_4_15 = AIE.lock(%tile_0_4, 4) {init = 0 : i32}
          %lock_0_4_16 = AIE.lock(%tile_0_4, 3) {init = 2 : i32}
          %lock_0_4_17 = AIE.lock(%tile_0_4, 2) {init = 0 : i32}
          %lock_0_4_18 = AIE.lock(%tile_0_4, 1) {init = 1 : i32}
          %lock_0_4_19 = AIE.lock(%tile_0_4, 0) {init = 0 : i32}
          %lock_0_5 = AIE.lock(%tile_0_5, 5) {init = 2 : i32}
          %lock_0_5_20 = AIE.lock(%tile_0_5, 4) {init = 0 : i32}
          %lock_0_5_21 = AIE.lock(%tile_0_5, 3) {init = 2 : i32}
          %lock_0_5_22 = AIE.lock(%tile_0_5, 2) {init = 0 : i32}
          %lock_0_5_23 = AIE.lock(%tile_0_5, 1) {init = 1 : i32}
          %lock_0_5_24 = AIE.lock(%tile_0_5, 0) {init = 0 : i32}
          %buffer_0_1 = AIE.buffer(%tile_0_1) {sym_name = "buf22"} : memref<32x256xi32, 1>
          %buffer_0_1_25 = AIE.buffer(%tile_0_1) {sym_name = "buf21"} : memref<256x32xi32, 1>
          %buffer_0_1_26 = AIE.buffer(%tile_0_1) {sym_name = "buf20"} : memref<32x32xi32, 1>
          %buffer_0_5 = AIE.buffer(%tile_0_5) {sym_name = "buf19"} : memref<16x16xi32, 2>
          %buffer_0_5_27 = AIE.buffer(%tile_0_5) {sym_name = "buf18"} : memref<16x16xi32, 2>
          %buffer_0_5_28 = AIE.buffer(%tile_0_5) {sym_name = "buf17"} : memref<16x16xi32, 2>
          %buffer_0_5_29 = AIE.buffer(%tile_0_5) {sym_name = "buf16"} : memref<16x16xi32, 2>
          %buffer_0_5_30 = AIE.buffer(%tile_0_5) {sym_name = "buf15"} : memref<16x16xi32, 2>
          %buffer_0_4 = AIE.buffer(%tile_0_4) {sym_name = "buf14"} : memref<16x16xi32, 2>
          %buffer_0_4_31 = AIE.buffer(%tile_0_4) {sym_name = "buf13"} : memref<16x16xi32, 2>
          %buffer_0_4_32 = AIE.buffer(%tile_0_4) {sym_name = "buf12"} : memref<16x16xi32, 2>
          %buffer_0_4_33 = AIE.buffer(%tile_0_4) {sym_name = "buf11"} : memref<16x16xi32, 2>
          %buffer_0_4_34 = AIE.buffer(%tile_0_4) {sym_name = "buf10"} : memref<16x16xi32, 2>
          %buffer_0_3 = AIE.buffer(%tile_0_3) {sym_name = "buf9"} : memref<16x16xi32, 2>
          %buffer_0_3_35 = AIE.buffer(%tile_0_3) {sym_name = "buf8"} : memref<16x16xi32, 2>
          %buffer_0_3_36 = AIE.buffer(%tile_0_3) {sym_name = "buf7"} : memref<16x16xi32, 2>
          %buffer_0_3_37 = AIE.buffer(%tile_0_3) {sym_name = "buf6"} : memref<16x16xi32, 2>
          %buffer_0_3_38 = AIE.buffer(%tile_0_3) {sym_name = "buf5"} : memref<16x16xi32, 2>
          %buffer_0_2 = AIE.buffer(%tile_0_2) {sym_name = "buf4"} : memref<16x16xi32, 2>
          %buffer_0_2_39 = AIE.buffer(%tile_0_2) {sym_name = "buf3"} : memref<16x16xi32, 2>
          %buffer_0_2_40 = AIE.buffer(%tile_0_2) {sym_name = "buf2"} : memref<16x16xi32, 2>
          %buffer_0_2_41 = AIE.buffer(%tile_0_2) {sym_name = "buf1"} : memref<16x16xi32, 2>
          %buffer_0_2_42 = AIE.buffer(%tile_0_2) {sym_name = "buf0"} : memref<16x16xi32, 2>
          %mem_0_5 = AIE.mem(%tile_0_5) {
            %0 = AIE.dmaStart(S2MM, 0, ^bb1, ^bb7)
          ^bb1:  // 2 preds: ^bb0, ^bb2
            AIE.useLock(%lock_0_5_21, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_5_29 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_5_22, Release, 1)
            AIE.nextBd ^bb2
          ^bb2:  // pred: ^bb1
            AIE.useLock(%lock_0_5_21, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_5_28 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_5_22, Release, 1)
            AIE.nextBd ^bb1
          ^bb3:  // pred: ^bb4
            AIE.end
          ^bb4:  // pred: ^bb7
            %1 = AIE.dmaStart(S2MM, 1, ^bb5, ^bb3)
          ^bb5:  // 2 preds: ^bb4, ^bb6
            AIE.useLock(%lock_0_5, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_5_30 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_5_20, Release, 1)
            AIE.nextBd ^bb6
          ^bb6:  // pred: ^bb5
            AIE.useLock(%lock_0_5, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_5_27 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_5_20, Release, 1)
            AIE.nextBd ^bb5
          ^bb7:  // pred: ^bb0
            %2 = AIE.dmaStart(MM2S, 0, ^bb8, ^bb4)
          ^bb8:  // 2 preds: ^bb7, ^bb8
            AIE.useLock(%lock_0_5_24, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_5 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_5_23, Release, 1)
            AIE.nextBd ^bb8
          }
          %core_0_5 = AIE.core(%tile_0_5) {
            %c32 = arith.constant 32 : index
            %c0_i32 = arith.constant 0 : i32
            %c256 = arith.constant 256 : index
            %c16 = arith.constant 16 : index
            %c1 = arith.constant 1 : index
            %c0 = arith.constant 0 : index
            cf.br ^bb1
          ^bb1:  // 2 preds: ^bb0, ^bb1
            scf.for %arg0 = %c0 to %c16 step %c1 {
              scf.for %arg1 = %c0 to %c16 step %c1 {
                memref.store %c0_i32, %buffer_0_5[%arg0, %arg1] : memref<16x16xi32, 2>
              }
            }
            scf.for %arg0 = %c0 to %c256 step %c32 {
              AIE.useLock(%lock_0_5_22, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_5_20, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_5_29[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_5_30[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_5[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_5[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_5_21, Release, 1)
              AIE.useLock(%lock_0_5, Release, 1)
              AIE.useLock(%lock_0_5_22, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_5_20, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_5_28[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_5_27[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_5[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_5[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_5_21, Release, 1)
              AIE.useLock(%lock_0_5, Release, 1)
            }
            AIE.useLock(%lock_0_5_23, AcquireGreaterEqual, 1)
            AIE.useLock(%lock_0_5_24, Release, 1)
            cf.br ^bb1
          } {elf_file = "segment_0_core_0_5.elf"}
          %mem_0_4 = AIE.mem(%tile_0_4) {
            %0 = AIE.dmaStart(S2MM, 0, ^bb1, ^bb7)
          ^bb1:  // 2 preds: ^bb0, ^bb2
            AIE.useLock(%lock_0_4_16, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_4_33 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_4_17, Release, 1)
            AIE.nextBd ^bb2
          ^bb2:  // pred: ^bb1
            AIE.useLock(%lock_0_4_16, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_4_32 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_4_17, Release, 1)
            AIE.nextBd ^bb1
          ^bb3:  // pred: ^bb4
            AIE.end
          ^bb4:  // pred: ^bb7
            %1 = AIE.dmaStart(S2MM, 1, ^bb5, ^bb3)
          ^bb5:  // 2 preds: ^bb4, ^bb6
            AIE.useLock(%lock_0_4, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_4_34 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_4_15, Release, 1)
            AIE.nextBd ^bb6
          ^bb6:  // pred: ^bb5
            AIE.useLock(%lock_0_4, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_4_31 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_4_15, Release, 1)
            AIE.nextBd ^bb5
          ^bb7:  // pred: ^bb0
            %2 = AIE.dmaStart(MM2S, 0, ^bb8, ^bb4)
          ^bb8:  // 2 preds: ^bb7, ^bb8
            AIE.useLock(%lock_0_4_19, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_4 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_4_18, Release, 1)
            AIE.nextBd ^bb8
          }
          %core_0_4 = AIE.core(%tile_0_4) {
            %c32 = arith.constant 32 : index
            %c0_i32 = arith.constant 0 : i32
            %c256 = arith.constant 256 : index
            %c16 = arith.constant 16 : index
            %c1 = arith.constant 1 : index
            %c0 = arith.constant 0 : index
            cf.br ^bb1
          ^bb1:  // 2 preds: ^bb0, ^bb1
            scf.for %arg0 = %c0 to %c16 step %c1 {
              scf.for %arg1 = %c0 to %c16 step %c1 {
                memref.store %c0_i32, %buffer_0_4[%arg0, %arg1] : memref<16x16xi32, 2>
              }
            }
            scf.for %arg0 = %c0 to %c256 step %c32 {
              AIE.useLock(%lock_0_4_17, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_4_15, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_4_33[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_4_34[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_4[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_4[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_4_16, Release, 1)
              AIE.useLock(%lock_0_4, Release, 1)
              AIE.useLock(%lock_0_4_17, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_4_15, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_4_32[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_4_31[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_4[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_4[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_4_16, Release, 1)
              AIE.useLock(%lock_0_4, Release, 1)
            }
            AIE.useLock(%lock_0_4_18, AcquireGreaterEqual, 1)
            AIE.useLock(%lock_0_4_19, Release, 1)
            cf.br ^bb1
          } {elf_file = "segment_0_core_0_4.elf"}
          %mem_0_3 = AIE.mem(%tile_0_3) {
            %0 = AIE.dmaStart(S2MM, 0, ^bb1, ^bb7)
          ^bb1:  // 2 preds: ^bb0, ^bb2
            AIE.useLock(%lock_0_3_11, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_3_37 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_3_12, Release, 1)
            AIE.nextBd ^bb2
          ^bb2:  // pred: ^bb1
            AIE.useLock(%lock_0_3_11, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_3_36 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_3_12, Release, 1)
            AIE.nextBd ^bb1
          ^bb3:  // pred: ^bb4
            AIE.end
          ^bb4:  // pred: ^bb7
            %1 = AIE.dmaStart(S2MM, 1, ^bb5, ^bb3)
          ^bb5:  // 2 preds: ^bb4, ^bb6
            AIE.useLock(%lock_0_3, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_3_38 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_3_10, Release, 1)
            AIE.nextBd ^bb6
          ^bb6:  // pred: ^bb5
            AIE.useLock(%lock_0_3, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_3_35 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_3_10, Release, 1)
            AIE.nextBd ^bb5
          ^bb7:  // pred: ^bb0
            %2 = AIE.dmaStart(MM2S, 0, ^bb8, ^bb4)
          ^bb8:  // 2 preds: ^bb7, ^bb8
            AIE.useLock(%lock_0_3_14, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_3 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_3_13, Release, 1)
            AIE.nextBd ^bb8
          }
          %core_0_3 = AIE.core(%tile_0_3) {
            %c32 = arith.constant 32 : index
            %c0_i32 = arith.constant 0 : i32
            %c256 = arith.constant 256 : index
            %c16 = arith.constant 16 : index
            %c1 = arith.constant 1 : index
            %c0 = arith.constant 0 : index
            cf.br ^bb1
          ^bb1:  // 2 preds: ^bb0, ^bb1
            scf.for %arg0 = %c0 to %c16 step %c1 {
              scf.for %arg1 = %c0 to %c16 step %c1 {
                memref.store %c0_i32, %buffer_0_3[%arg0, %arg1] : memref<16x16xi32, 2>
              }
            }
            scf.for %arg0 = %c0 to %c256 step %c32 {
              AIE.useLock(%lock_0_3_12, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_3_10, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_3_37[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_3_38[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_3[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_3[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_3_11, Release, 1)
              AIE.useLock(%lock_0_3, Release, 1)
              AIE.useLock(%lock_0_3_12, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_3_10, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_3_36[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_3_35[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_3[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_3[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_3_11, Release, 1)
              AIE.useLock(%lock_0_3, Release, 1)
            }
            AIE.useLock(%lock_0_3_13, AcquireGreaterEqual, 1)
            AIE.useLock(%lock_0_3_14, Release, 1)
            cf.br ^bb1
          } {elf_file = "segment_0_core_0_3.elf"}
          %mem_0_2 = AIE.mem(%tile_0_2) {
            %0 = AIE.dmaStart(S2MM, 0, ^bb1, ^bb7)
          ^bb1:  // 2 preds: ^bb0, ^bb2
            AIE.useLock(%lock_0_2_6, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_2_41 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_2_7, Release, 1)
            AIE.nextBd ^bb2
          ^bb2:  // pred: ^bb1
            AIE.useLock(%lock_0_2_6, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_2_40 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_2_7, Release, 1)
            AIE.nextBd ^bb1
          ^bb3:  // pred: ^bb4
            AIE.end
          ^bb4:  // pred: ^bb7
            %1 = AIE.dmaStart(S2MM, 1, ^bb5, ^bb3)
          ^bb5:  // 2 preds: ^bb4, ^bb6
            AIE.useLock(%lock_0_2, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_2_42 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_2_5, Release, 1)
            AIE.nextBd ^bb6
          ^bb6:  // pred: ^bb5
            AIE.useLock(%lock_0_2, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_2_39 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_2_5, Release, 1)
            AIE.nextBd ^bb5
          ^bb7:  // pred: ^bb0
            %2 = AIE.dmaStart(MM2S, 0, ^bb8, ^bb4)
          ^bb8:  // 2 preds: ^bb7, ^bb8
            AIE.useLock(%lock_0_2_9, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_2 : memref<16x16xi32, 2>, 0, 256>, 0)
            AIE.useLock(%lock_0_2_8, Release, 1)
            AIE.nextBd ^bb8
          }
          %core_0_2 = AIE.core(%tile_0_2) {
            %c32 = arith.constant 32 : index
            %c0_i32 = arith.constant 0 : i32
            %c256 = arith.constant 256 : index
            %c16 = arith.constant 16 : index
            %c1 = arith.constant 1 : index
            %c0 = arith.constant 0 : index
            cf.br ^bb1
          ^bb1:  // 2 preds: ^bb0, ^bb1
            scf.for %arg0 = %c0 to %c16 step %c1 {
              scf.for %arg1 = %c0 to %c16 step %c1 {
                memref.store %c0_i32, %buffer_0_2[%arg0, %arg1] : memref<16x16xi32, 2>
              }
            }
            scf.for %arg0 = %c0 to %c256 step %c32 {
              AIE.useLock(%lock_0_2_7, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_2_5, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_2_41[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_2_42[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_2[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_2[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_2_6, Release, 1)
              AIE.useLock(%lock_0_2, Release, 1)
              AIE.useLock(%lock_0_2_7, AcquireGreaterEqual, 1)
              AIE.useLock(%lock_0_2_5, AcquireGreaterEqual, 1)
              scf.for %arg1 = %c0 to %c16 step %c1 {
                scf.for %arg2 = %c0 to %c16 step %c1 {
                  scf.for %arg3 = %c0 to %c16 step %c1 {
                    %0 = memref.load %buffer_0_2_40[%arg1, %arg3] : memref<16x16xi32, 2>
                    %1 = memref.load %buffer_0_2_39[%arg3, %arg2] : memref<16x16xi32, 2>
                    %2 = memref.load %buffer_0_2[%arg1, %arg2] : memref<16x16xi32, 2>
                    %3 = arith.muli %0, %1 : i32
                    %4 = arith.addi %2, %3 : i32
                    memref.store %4, %buffer_0_2[%arg1, %arg2] : memref<16x16xi32, 2>
                  }
                }
              }
              AIE.useLock(%lock_0_2_6, Release, 1)
              AIE.useLock(%lock_0_2, Release, 1)
            }
            AIE.useLock(%lock_0_2_8, AcquireGreaterEqual, 1)
            AIE.useLock(%lock_0_2_9, Release, 1)
            cf.br ^bb1
          } {elf_file = "segment_0_core_0_2.elf"}
          AIE.flow(%tile_0_0, DMA : 0, %tile_0_1, DMA : 0)
          AIE.flow(%tile_0_0, DMA : 1, %tile_0_1, DMA : 1)
          AIE.flow(%tile_0_1, DMA : 0, %tile_0_0, DMA : 0)
          AIE.flow(%tile_0_1, DMA : 1, %tile_0_2, DMA : 0)
          AIE.flow(%tile_0_1, DMA : 1, %tile_0_3, DMA : 0)
          AIE.flow(%tile_0_1, DMA : 2, %tile_0_4, DMA : 0)
          AIE.flow(%tile_0_1, DMA : 2, %tile_0_5, DMA : 0)
          AIE.flow(%tile_0_1, DMA : 3, %tile_0_2, DMA : 1)
          AIE.flow(%tile_0_1, DMA : 3, %tile_0_4, DMA : 1)
          AIE.flow(%tile_0_1, DMA : 4, %tile_0_3, DMA : 1)
          AIE.flow(%tile_0_1, DMA : 4, %tile_0_5, DMA : 1)
          AIE.flow(%tile_0_2, DMA : 0, %tile_0_1, DMA : 2)
          AIE.flow(%tile_0_4, DMA : 0, %tile_0_1, DMA : 3)
          AIE.flow(%tile_0_3, DMA : 0, %tile_0_1, DMA : 4)
          AIE.flow(%tile_0_5, DMA : 0, %tile_0_1, DMA : 5)
          %memTileDMA_0_1 = AIE.memTileDMA(%tile_0_1) {
            %0 = AIE.dmaStart(S2MM, 0, ^bb1, ^bb21)
          ^bb1:  // 2 preds: ^bb0, ^bb1
            AIE.useLock(%lock_0_1_3, AcquireGreaterEqual, 2)
            AIE.dmaBd(<%buffer_0_1 : memref<32x256xi32, 1>, 0, 8192>, 0)
            AIE.useLock(%lock_0_1_4, Release, 2)
            AIE.nextBd ^bb1
          ^bb2:  // pred: ^bb3
            AIE.end
          ^bb3:  // pred: ^bb5
            %1 = AIE.dmaStart(S2MM, 1, ^bb4, ^bb2)
          ^bb4:  // 2 preds: ^bb3, ^bb4
            AIE.useLock(%lock_0_1_1, AcquireGreaterEqual, 2)
            AIE.dmaBd(<%buffer_0_1_25 : memref<256x32xi32, 1>, 0, 8192>, 0)
            AIE.useLock(%lock_0_1_2, Release, 2)
            AIE.nextBd ^bb4
          ^bb5:  // pred: ^bb7
            %2 = AIE.dmaStart(S2MM, 2, ^bb6, ^bb3)
          ^bb6:  // 2 preds: ^bb5, ^bb6
            AIE.useLock(%lock_0_1, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1_26 : memref<32x32xi32, 1>, 0, 256>, 0, [<16, 32>, <16, 1>])
            AIE.useLock(%lock_0_1_0, Release, 1)
            AIE.nextBd ^bb6
          ^bb7:  // pred: ^bb9
            %3 = AIE.dmaStart(S2MM, 3, ^bb8, ^bb5)
          ^bb8:  // 2 preds: ^bb7, ^bb8
            AIE.useLock(%lock_0_1, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1_26 : memref<32x32xi32, 1>, 2048, 256>, 0, [<16, 32>, <16, 1>])
            AIE.useLock(%lock_0_1_0, Release, 1)
            AIE.nextBd ^bb8
          ^bb9:  // pred: ^bb11
            %4 = AIE.dmaStart(S2MM, 4, ^bb10, ^bb7)
          ^bb10:  // 2 preds: ^bb9, ^bb10
            AIE.useLock(%lock_0_1, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1_26 : memref<32x32xi32, 1>, 64, 256>, 0, [<16, 32>, <16, 1>])
            AIE.useLock(%lock_0_1_0, Release, 1)
            AIE.nextBd ^bb10
          ^bb11:  // pred: ^bb13
            %5 = AIE.dmaStart(S2MM, 5, ^bb12, ^bb9)
          ^bb12:  // 2 preds: ^bb11, ^bb12
            AIE.useLock(%lock_0_1, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1_26 : memref<32x32xi32, 1>, 2112, 256>, 0, [<16, 32>, <16, 1>])
            AIE.useLock(%lock_0_1_0, Release, 1)
            AIE.nextBd ^bb12
          ^bb13:  // pred: ^bb15
            %6 = AIE.dmaStart(MM2S, 0, ^bb14, ^bb11)
          ^bb14:  // 2 preds: ^bb13, ^bb14
            AIE.useLock(%lock_0_1_0, AcquireGreaterEqual, 4)
            AIE.dmaBd(<%buffer_0_1_26 : memref<32x32xi32, 1>, 0, 1024>, 0)
            AIE.useLock(%lock_0_1, Release, 4)
            AIE.nextBd ^bb14
          ^bb15:  // pred: ^bb17
            %7 = AIE.dmaStart(MM2S, 1, ^bb16, ^bb13)
          ^bb16:  // 2 preds: ^bb15, ^bb16
            AIE.useLock(%lock_0_1_4, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1 : memref<32x256xi32, 1>, 0, 4096>, 0, [<16, 16>, <16, 256>, <16, 1>])
            AIE.useLock(%lock_0_1_3, Release, 1)
            AIE.nextBd ^bb16
          ^bb17:  // pred: ^bb19
            %8 = AIE.dmaStart(MM2S, 2, ^bb18, ^bb15)
          ^bb18:  // 2 preds: ^bb17, ^bb18
            AIE.useLock(%lock_0_1_4, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1 : memref<32x256xi32, 1>, 16384, 4096>, 0, [<16, 16>, <16, 256>, <16, 1>])
            AIE.useLock(%lock_0_1_3, Release, 1)
            AIE.nextBd ^bb18
          ^bb19:  // pred: ^bb21
            %9 = AIE.dmaStart(MM2S, 3, ^bb20, ^bb17)
          ^bb20:  // 2 preds: ^bb19, ^bb20
            AIE.useLock(%lock_0_1_2, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1_25 : memref<256x32xi32, 1>, 0, 4096>, 0, [<16, 512>, <16, 32>, <16, 1>])
            AIE.useLock(%lock_0_1_1, Release, 1)
            AIE.nextBd ^bb20
          ^bb21:  // pred: ^bb0
            %10 = AIE.dmaStart(MM2S, 4, ^bb22, ^bb19)
          ^bb22:  // 2 preds: ^bb21, ^bb22
            AIE.useLock(%lock_0_1_2, AcquireGreaterEqual, 1)
            AIE.dmaBd(<%buffer_0_1_25 : memref<256x32xi32, 1>, 64, 4096>, 0, [<16, 512>, <16, 32>, <16, 1>])
            AIE.useLock(%lock_0_1_1, Release, 1)
            AIE.nextBd ^bb22
          }
          AIE.shimDMAAllocation @airMemcpyId16(S2MM, 0, 0)
          memref.global "public" @airMemcpyId16 : memref<32x32xi32, 1>
          AIE.shimDMAAllocation @airMemcpyId4(MM2S, 0, 0)
          memref.global "public" @airMemcpyId4 : memref<32x256xi32, 1>
          AIE.shimDMAAllocation @airMemcpyId5(MM2S, 1, 0)
          memref.global "public" @airMemcpyId5 : memref<256x32xi32, 1>
          func.func @matmul_static_dispatch_0_matmul_64x64x256_i32(%arg0: memref<64x256xi32>, %arg1: memref<256x64xi32>, %arg2: memref<64x64xi32>) {
            %c0_i32 = arith.constant 0 : i32
            %c1_i32 = arith.constant 1 : i32
            %c32_i32 = arith.constant 32 : i32
            %c256_i32 = arith.constant 256 : i32
            %c2_i32 = arith.constant 2 : i32
            %c64_i32 = arith.constant 64 : i32
            %c2048_i32 = arith.constant 2048 : i32
            memref.assume_alignment %arg0, 64 : memref<64x256xi32>
            memref.assume_alignment %arg1, 64 : memref<256x64xi32>
            memref.assume_alignment %arg2, 64 : memref<64x64xi32>
            AIEX.ipu.dma_memcpy_nd(%c0_i32, %c0_i32, %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] [%c1_i32, %c1_i32, %c32_i32, %c256_i32] [%c0_i32, %c0_i32, %c256_i32]) {id = 1 : i32, metadata = @airMemcpyId4} : (i32, i32, memref<64x256xi32>, [i32, i32, i32, i32], [i32, i32, i32, i32], [i32, i32, i32])
            AIEX.ipu.dma_memcpy_nd(%c0_i32, %c0_i32, %arg0[%c0_i32, %c0_i32, %c0_i32, %c0_i32] [%c1_i32, %c1_i32, %c32_i32, %c256_i32] [%c0_i32, %c0_i32, %c256_i32]) {id = 2 : i32, metadata = @airMemcpyId4} : (i32, i32, memref<64x256xi32>, [i32, i32, i32, i32], [i32, i32, i32, i32], [i32, i32, i32])
            AIEX.ipu.dma_memcpy_nd(%c0_i32, %c0_i32, %arg0[%c0_i32, %c0_i32, %c32_i32, %c0_i32] [%c1_i32, %c1_i32, %c32_i32, %c256_i32] [%c0_i32, %c0_i32, %c256_i32]) {id = 3 : i32, metadata = @airMemcpyId4} : (i32, i32, memref<64x256xi32>, [i32, i32, i32, i32], [i32, i32, i32, i32], [i32, i32, i32])
            AIEX.ipu.dma_memcpy_nd(%c0_i32, %c0_i32, %arg0[%c0_i32, %c0_i32, %c32_i32, %c0_i32] [%c1_i32, %c1_i32, %c32_i32, %c256_i32] [%c0_i32, %c0_i32, %c256_i32]) {id = 4 : i32, metadata = @airMemcpyId4} : (i32, i32, memref<64x256xi32>, [i32, i32, i32, i32], [i32, i32, i32, i32], [i32, i32, i32])
            AIEX.ipu.dma_memcpy_nd(%c0_i32, %c0_i32, %arg1[%c0_i32, %c0_i32, %c0_i32, %c0_i32] [%c2_i32, %c2_i32, %c256_i32, %c32_i32] [%c0_i32, %c32_i32, %c64_i32]) {id = 5 : i32, metadata = @airMemcpyId5} : (i32, i32, memref<256x64xi32>, [i32, i32, i32, i32], [i32, i32, i32, i32], [i32, i32, i32])
            AIEX.ipu.dma_memcpy_nd(%c0_i32, %c0_i32, %arg2[%c0_i32, %c0_i32, %c0_i32, %c0_i32] [%c2_i32, %c2_i32, %c32_i32, %c32_i32] [%c2048_i32, %c32_i32, %c64_i32]) {id = 6 : i32, metadata = @airMemcpyId16} : (i32, i32, memref<64x64xi32>, [i32, i32, i32, i32], [i32, i32, i32, i32], [i32, i32, i32])
            AIEX.ipu.sync {channel = 0 : i32, column = 0 : i32, column_num = 1 : i32, direction = 0 : i32, row = 0 : i32, row_num = 1 : i32}
            return
          }
        } {sym_name = "segment_0"}
      }
    }
  }
  func.func @matmul_static(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %c65536 = arith.constant 65536 : index
    %c16384 = arith.constant 16384 : index
    %c0 = arith.constant 0 : index
    %c268435488_i32 = arith.constant 268435488 : i32
    %c1_i32 = arith.constant 1 : i32
    %c64 = arith.constant 64 : index
    %c256 = arith.constant 256 : index
    hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input 0") shape([%c64, %c256]) type(%c268435488_i32) encoding(%c1_i32)
    %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<64x256xi32> in !stream.resource<external>{%c65536}
    hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("input 1") shape([%c256, %c64]) type(%c268435488_i32) encoding(%c1_i32)
    %1 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<256x64xi32> in !stream.resource<external>{%c65536}
    %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource<external>{%c16384} => !stream.timepoint
    %2 = stream.cmd.execute await(%result_timepoint) => with(%0 as %arg2: !stream.resource<external>{%c65536}, %1 as %arg3: !stream.resource<external>{%c65536}, %result as %arg4: !stream.resource<external>{%c16384}) {
      stream.cmd.dispatch @matmul_static_dispatch_0::@elf::@matmul_static_dispatch_0_matmul_64x64x256_i32 {
        ro %arg2[%c0 for %c65536] : !stream.resource<external>{%c65536},
        ro %arg3[%c0 for %c65536] : !stream.resource<external>{%c65536},
        wo %arg4[%c0 for %c16384] : !stream.resource<external>{%c16384}
      } attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]}
    } => !stream.timepoint
    %3 = stream.timepoint.await %2 => %result : !stream.resource<external>{%c16384}
    %4 = stream.tensor.export %3 : tensor<64x64xi32> in !stream.resource<external>{%c16384} -> !hal.buffer_view
    return %4 : !hal.buffer_view
  }
}
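Without the lock and DMA plumbing, each `AIE.core` region above does the following. It zeroes a 16x16 `i32` accumulator. Then, inside `scf.for %arg0 = %c0 to %c256 step %c32`, it multiplies two pairs of 16x16 tiles per iteration: first the "ping" buffers (for example `%buffer_0_2_41`/`%buffer_0_2_42`), then the "pong" buffers (`%buffer_0_2_40`/`%buffer_0_2_39`).

The plain-Python sketch below models that schedule and checks it against a reference matmul. It rests on two assumptions:

- **Core layout.** Based on the `AIE.flow` ops and the memtile write offsets, cores `0_2`, `0_3`, `0_4` and `0_5` own the top-left, top-right, bottom-left and bottom-right 16x16 quadrants of a 32x32 block.
- **Block iteration.** The runtime sequence covers the 64x64 result as a 2x2 grid of such 32x32 blocks.

The names `core_body`, `matmul_tiled` and `CORE_OFFSETS` are illustrative only.

```python
import random

# Problem shape from @matmul_static_dispatch_0_matmul_64x64x256_i32.
M, N, K = 64, 64, 256
BLOCK = 32  # assumed: one segment run produces one 32x32 block of C
TILE = 16   # each AIE core accumulates one 16x16 tile of that block

# Assumed (row, col) offset of each core's tile inside a 32x32 block,
# inferred from the AIE.flow ops and the memtile write offsets.
CORE_OFFSETS = {
    "core_0_2": (0, 0),
    "core_0_3": (0, 16),
    "core_0_4": (16, 0),
    "core_0_5": (16, 16),
}


def matmul_ref(A, B):
    """Reference C = A @ B on nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]


def core_body(A, B, row0, col0):
    """Model one AIE.core region: zero a 16x16 accumulator, then walk K in
    steps of 32, multiplying the ping tile pair and then the pong tile pair."""
    acc = [[0] * TILE for _ in range(TILE)]
    for k0 in range(0, K, 2 * TILE):    # scf.for %arg0 = 0 to 256 step 32
        for kt in (k0, k0 + TILE):      # ping buffers, then pong buffers
            for i in range(TILE):
                for j in range(TILE):
                    for k in range(kt, kt + TILE):
                        acc[i][j] += A[row0 + i][k] * B[k][col0 + j]
    return acc


def matmul_tiled(A, B):
    """Assemble C block by block, each 32x32 block split across four cores."""
    C = [[0] * N for _ in range(M)]
    for bi in range(0, M, BLOCK):
        for bj in range(0, N, BLOCK):
            for di, dj in CORE_OFFSETS.values():
                tile = core_body(A, B, bi + di, bj + dj)
                for i in range(TILE):
                    C[bi + di + i][bj + dj:bj + dj + TILE] = tile[i]
    return C


if __name__ == "__main__":
    random.seed(0)
    # Small values keep sums far from the i32 limits (Python ints never wrap).
    A = [[random.randint(-4, 4) for _ in range(K)] for _ in range(M)]
    B = [[random.randint(-4, 4) for _ in range(N)] for _ in range(K)]
    assert matmul_tiled(A, B) == matmul_ref(A, B)
    print("tiled schedule matches the reference")
```

The two results agree because both sum the same products over all of K. On the device, however, `arith.muli` and `arith.addi` wrap at 32 bits, while Python integers do not.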
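At this point, further lowering is handled by `mlir-aie`, and the per-core code is handed off to Peano, the LLVM-based AIE compiler (`llvm-aie`). Its `opt`, `llc` and `clang` invocations appear in the full generation log below.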
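Two kinds of DMA description in the IR above say how each transfer walks memory:

- the `AIE.dmaBd` access patterns in `%memTileDMA_0_1`;
- the sizes and strides of the `AIEX.ipu.dma_memcpy_nd` ops in the runtime sequence.

The sketch below expands such a pattern into flat element offsets, so you can inspect the resulting tiles. Its reading of the notation is an assumption: it is consistent with the shapes in this IR but is not taken from the mlir-aie documentation. Specifically, it assumes that:

- sizes and strides are listed outermost first and measured in elements;
- the three strides of `dma_memcpy_nd` omit an implicit innermost stride of 1.

`expand` and `to_rc` are hypothetical helpers. The buffer-descriptor base offsets (such as `2048` or `16384`) are deliberately left out, because their unit is not stated here.

```python
def expand(offset, sizes, strides):
    """Flat element offsets visited by a nested access pattern.

    Assumption: sizes and strides are listed outermost first and are
    measured in elements, not bytes.
    """
    if not sizes:
        return [offset]
    visited = []
    for i in range(sizes[0]):
        visited.extend(expand(offset + i * strides[0], sizes[1:], strides[1:]))
    return visited


def to_rc(flat, row_len):
    """Map a flat row-major element offset to (row, col)."""
    return divmod(flat, row_len)


# Memtile -> core A stream over memref<32x256xi32>:
#   dims [<16, 16>, <16, 256>, <16, 1>], length 4096
a_stream = expand(0, [16, 16, 16], [16, 256, 1])
# Every 256 consecutive elements form one 16x16 tile; tile t starts at column 16*t.
a_tile_starts = [to_rc(a_stream[t * 256], 256) for t in range(16)]

# Memtile -> core B stream over memref<256x32xi32>:
#   dims [<16, 512>, <16, 32>, <16, 1>], length 4096
b_stream = expand(0, [16, 16, 16], [512, 32, 1])
b_tile_starts = [to_rc(b_stream[t * 256], 32) for t in range(16)]

# Runtime sequence, output transfer (id = 6) over memref<64x64xi32>:
#   sizes [2, 2, 32, 32], strides [2048, 32, 64] plus an implicit innermost 1
c_stream = expand(0, [2, 2, 32, 32], [2048, 32, 64, 1])
c_block_starts = [to_rc(c_stream[b * 1024], 64) for b in range(4)]

# Runtime sequence, A transfer (id = 3) over memref<64x256xi32>:
#   offsets [0, 0, 32, 0] -> base 32 * 256, sizes [1, 1, 32, 256], strides [0, 0, 256]
a_host = expand(32 * 256, [1, 1, 32, 256], [0, 0, 256, 1])

print(a_tile_starts[:3])  # [(0, 0), (0, 16), (0, 32)]
print(b_tile_starts[:3])  # [(0, 0), (16, 0), (32, 0)]
print(c_block_starts)     # [(0, 0), (0, 32), (32, 0), (32, 32)]
```

Read this way, the three streams behave as follows:

- **A stream:** delivers the 16 K-tiles of rows 0-15 from left to right.
- **B stream:** delivers the 16 K-tiles of columns 0-15 from top to bottom.
- **Output transfer:** writes the four 32x32 blocks of the 64x64 result in row-major block order.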

Custom Dispatches

The example below shows how an externally written kernel can be plugged into IREE with `hal.dispatch.extern`. Further examples and details can be found here:

https://github.com/openxla/iree/tree/main/samples/custom_dispatch (Vulkan, CPU, and CUDA examples)

Example

#spirv_target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixKHR], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, api=Vulkan, AMD:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, min_subgroup_size = 32, max_subgroup_size = 64, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>

module {

  func.func @forward(%arg0: tensor<1x32000xf16>) -> tensor<1xi64> {
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c1 : tensor<1x32000xf16>
    %dim_i32 = arith.index_cast %dim : index to i32
    %4 = hal.dispatch.extern "main"[%dim](%dim_i32, %arg0) : (i32, tensor<1x32000xf16>) -> tensor<1xi64>
      count(%device: !hal.device, %workload: index) -> (index, index, index) {
        %c1_0 = arith.constant 1 : index
        hal.return %c1_0, %c1_0, %c1_0 : index, index, index
      }   
      layout(#hal.pipeline.layout<push_constants = 1, sets = [
        <0, bindings = [
            <0, storage_buffer, ReadOnly>,
            <1, storage_buffer>
        ]>
      ]>)
      bindings([
        #hal.interface.binding<0, 0>, 
        #hal.interface.binding<0, 1>
      ])  
      objects(#hal.executable.objects<{
        // Per-target specification of the kernel to use. The compiler will automatically select which one
        // to use based on the target being compiled for, or the runtime can pick when multi-targeting.
        #spirv_target = [ 
          #hal.executable.object<{
            // Path to the .spv/.hsaco, or can just embed the binary blob directly.
            path = "/home/quinn/one_workgroup_argmax_subgroup_f16.spv"
          }>
        ]
      }>)
    return %4 : tensor<1xi64>
  }
}
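The source of `one_workgroup_argmax_subgroup_f16.spv` is not part of this page. Three clues suggest it returns the index of the largest of 32000 logits: its file name, the `tensor<1x32000xf16> -> tensor<1xi64>` signature, and `subgroup_size = 64` in the target. This is an inference, not something the page confirms.

A host-side reference like the one below can be used to check the dispatch's output. `argmax_subgroup` is a hypothetical model of a one-workgroup, subgroup-reduction strategy, not the actual shader. It ignores NaN handling.

```python
import math


def argmax_reference(values):
    """Index of the first maximum (Python's max returns the first maximal item)."""
    return max(range(len(values)), key=values.__getitem__)


def argmax_subgroup(values, subgroup_size=64):
    """Hypothetical one-workgroup strategy: each of `subgroup_size` lanes scans
    a strided slice, then the per-lane (value, index) candidates are reduced."""
    candidates = []
    for lane in range(subgroup_size):
        best_idx, best_val = None, -math.inf
        for i in range(lane, len(values), subgroup_size):
            if best_idx is None or values[i] > best_val:
                best_idx, best_val = i, values[i]
        if best_idx is not None:
            candidates.append((best_val, best_idx))
    # Cross-lane reduction: larger value wins; the lower index breaks ties.
    best_val, best_idx = candidates[0]
    for val, idx in candidates[1:]:
        if val > best_val or (val == best_val and idx < best_idx):
            best_val, best_idx = val, idx
    return best_idx


if __name__ == "__main__":
    # Deterministic stand-in for a 1x32000 row of logits.
    logits = [((i * 7919) % 32000) / 8.0 for i in range(32000)]
    assert argmax_subgroup(logits) == argmax_reference(logits)
    print(argmax_reference(logits))
```

Both functions return the lowest index among equal maxima. A real kernel may break ties differently; if so, the comparison should accept any index that holds the maximum value.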

Full generation log

Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/opt -O2 --inline-threshold=10 /tmp/matmul_static_dispatch_0-53cc41.bc -o /tmp/matmul_static_dispatch_0-48121d.opt.bc
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/llc -O2 --march=aie2 --function-sections --filetype=obj /tmp/matmul_static_dispatch_0-48121d.opt.bc -o /tmp/matmul_static_dispatch_0-b18ca6.o
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/clang -O2 --target=aie2-none-elf /tmp/matmul_static_dispatch_0-b18ca6.o /proj/xcohdstaff6/abhvarma/mlir-aie/install/aie_runtime_lib/AIE2/me_basic.o /tmp/mravisha/llvm-aie/install/RelWithDebInfo/lib/aie2-none-unknown-elf/libc.a -Wl,--gc-sections -Wl,-T,/tmp/segment_0_core_0_2.elf-2af211.ld.script -o /tmp/elf-a040f9/segment_0_core_0_2.elf
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/clang -O2 --target=aie2-none-elf /tmp/matmul_static_dispatch_0-b18ca6.o /proj/xcohdstaff6/abhvarma/mlir-aie/install/aie_runtime_lib/AIE2/me_basic.o /tmp/mravisha/llvm-aie/install/RelWithDebInfo/lib/aie2-none-unknown-elf/libc.a -Wl,--gc-sections -Wl,-T,/tmp/segment_0_core_0_3.elf-374038.ld.script -o /tmp/elf-a040f9/segment_0_core_0_3.elf
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/clang -O2 --target=aie2-none-elf /tmp/matmul_static_dispatch_0-b18ca6.o /proj/xcohdstaff6/abhvarma/mlir-aie/install/aie_runtime_lib/AIE2/me_basic.o /tmp/mravisha/llvm-aie/install/RelWithDebInfo/lib/aie2-none-unknown-elf/libc.a -Wl,--gc-sections -Wl,-T,/tmp/segment_0_core_0_4.elf-1702a2.ld.script -o /tmp/elf-a040f9/segment_0_core_0_4.elf
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/clang -O2 --target=aie2-none-elf /tmp/matmul_static_dispatch_0-b18ca6.o /proj/xcohdstaff6/abhvarma/mlir-aie/install/aie_runtime_lib/AIE2/me_basic.o /tmp/mravisha/llvm-aie/install/RelWithDebInfo/lib/aie2-none-unknown-elf/libc.a -Wl,--gc-sections -Wl,-T,/tmp/segment_0_core_0_5.elf-f5381d.ld.script -o /tmp/elf-a040f9/segment_0_core_0_5.elf
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/clang++ -fPIC -c -std=c++17 -D__AIEARCH__=20 -D__AIESIM__ -D__CDO__ -D__PS_INIT_AIE__ -D__LOCK_FENCE_MODE__=2 -DAIE_OPTION_SCALAR_FLOAT_ON_VECTOR -DAIE2_FP32_EMULATION_ACCURACY_FAST -Wno-deprecated-declarations -I/tmp/elf-a040f9 -I/proj/xcohdstaff6/abhvarma/mlir-aie/install/runtime_lib/x86_64/xaiengine/cdo/include -I/proj/xbuilds/SWIP/2023.2_1013_2256/installs/lin64/Vitis/2023.2/aietools/include -o /tmp/elf-a040f9/gen_cdo.o /proj/xcohdstaff6/abhvarma/mlir-aie/install/data/generated-source/gen_cdo.cpp
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/clang++ -fPIC -c -std=c++17 -I/tmp/elf-a040f9 -I/proj/xcohdstaff6/abhvarma/mlir-aie/install/runtime_lib/x86_64/xaiengine/cdo/include -I/proj/xbuilds/SWIP/2023.2_1013_2256/installs/lin64/Vitis/2023.2/aietools/include -o /tmp/elf-a040f9/cdo_main.o /proj/xcohdstaff6/abhvarma/mlir-aie/install/data/generated-source/cdo_main.cpp
Running command : /tmp/mravisha/llvm-aie/install/RelWithDebInfo/bin/clang++ -L/proj/xcohdstaff6/abhvarma/mlir-aie/install/runtime_lib/x86_64/xaiengine/cdo -L/proj/xbuilds/SWIP/2023.2_1013_2256/installs/lin64/Vitis/2023.2/aietools/lib/lnx64.o -lxaienginecdo -lcdo_driver -o /tmp/elf-a040f9/cdo_main /tmp/elf-a040f9/gen_cdo.o /tmp/elf-a040f9/cdo_main.o
Running command : LD_LIBRARY_PATH=/proj/xcohdstaff6/abhvarma/mlir-aie/install/runtime_lib/x86_64/xaiengine/cdo:/proj/xbuilds/SWIP/2023.2_1013_2256/installs/lin64/Vitis/2023.2/aietools/lib/lnx64.o /tmp/elf-a040f9/cdo_main --work-dir-path /tmp/elf-a040f9/
Generating: /tmp/elf-a040f9/aie_cdo_error_handling.bin
Generating: /tmp/elf-a040f9/aie_cdo_elfs.bin
Generating: /tmp/elf-a040f9/aie_cdo_init.bin
Generating: /tmp/elf-a040f9/aie_cdo_enable.bin
Running command : /proj/xbuilds/SWIP/2023.2_1013_2256/installs/lin64/Vitis/2023.2/bin/bootgen -arch versal -image /tmp/elf-a040f9/design.bif -o /tmp/elf-a040f9/design.pdi -w


****** Bootgen v2023.2
  **** Build date : Oct 11 2023-12:50:27
    ** Copyright 1986-2022 Xilinx, Inc. All Rights Reserved.
    ** Copyright 2022-2023 Advanced Micro Devices, Inc. All Rights Reserved.


[INFO]   : Bootimage generated successfully

Running command : /proj/xbuilds/SWIP/2023.2_1013_2256/installs/lin64/Vitis/2023.2/bin/xclbinutil --input /proj/xcohdstaff6/abhvarma/mlir-aie/install/data/1x4.xclbin --add-kernel /tmp/elf-a040f9/kernels.json --add-replace-section AIE_PARTITION:JSON:/tmp/elf-a040f9/aie_partition.json --force --output /tmp/elf-a040f9/final.xclbin
XRT Build Version: 2.16.0 (Vitis)
       Build Date: 2023-07-13 16:00:55
          Hash ID: 157faa07876c55bb8aa8ec51b28608a6a0f6638e
Reading xclbin file into memory.  File: /proj/xcohdstaff6/abhvarma/mlir-aie/install/data/1x4.xclbin

Section 'AIE_PARTITION'(32) was successfully removed

Section: 'AIE_PARTITION'(32) was successfully added.
Size   : 16960 bytes
Format : JSON
File   : '/tmp/elf-a040f9/aie_partition.json'

Section 'GROUP_TOPOLOGY'(26) was successfully removed

Section 'GROUP_CONNECTIVITY'(27) was successfully removed
Successfully wrote (82484 bytes) to the output file: /tmp/elf-a040f9/final.xclbin
Leaving xclbinutil.
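The log above traces the backend packaging, in four stages:

1. The `llvm-aie` build of `opt` and `llc` compiles the dispatch to an object file.
2. `clang` links one ELF per core.
3. A generated `cdo_main` program writes the `aie_cdo_*.bin` configuration files.
4. `bootgen` and `xclbinutil` produce `design.pdi` and `final.xclbin`.

When comparing logs between builds, it can help to reduce each `Running command : ` line to the tool name and its output file. The sketch below does that. It assumes only the line prefix and the `-o`/`--output` flags as they appear in this log; `summarize_log` is a hypothetical name.

```python
import os
import shlex

PREFIX = "Running command : "


def summarize_log(text):
    """Return (tool, output) pairs for each 'Running command : ...' line.

    output is the argument after -o or --output, or None if the step has neither.
    """
    steps = []
    for line in text.splitlines():
        if not line.startswith(PREFIX):
            continue
        argv = shlex.split(line[len(PREFIX):])
        # Drop leading environment assignments such as LD_LIBRARY_PATH=...
        while argv and "=" in argv[0] and not argv[0].startswith("-"):
            argv.pop(0)
        if not argv:
            continue
        output = None
        for flag in ("-o", "--output"):
            if flag in argv and argv.index(flag) + 1 < len(argv):
                output = argv[argv.index(flag) + 1]
                break
        steps.append((os.path.basename(argv[0]), output))
    return steps
```

Applied to the log above, it yields twelve steps:

- `opt` and `llc`;
- four `clang` links (`segment_0_core_0_2.elf` through `segment_0_core_0_5.elf`);
- three `clang++` steps that build `cdo_main`;
- `cdo_main` itself, which has no `-o`, so its output is `None`;
- `bootgen` and `xclbinutil`.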