one-shot-bufferize pass generates memref.alloc()s in GPU kernels code and breaks the pipeline #360

Open
zhczhong opened this issue Sep 25, 2024 · 5 comments


zhczhong commented Sep 25, 2024

The input code is as follows, and the InsertGPUAllocs pass cannot handle this case properly: "gpu.dealloc"(%51) : (memref<16x16xf16>) -> () should be inserted inside the kernel code, but it is currently inserted outside the gpu.launch region, which breaks dominance (see the sketch after the reproducer).

Reproducer

gc-opt --gc-gpu-pipeline file.mlir

file.mlir:

module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 4 : i32>, #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>, #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>, #dlti.dl_entry<"L3_cache_size_in_bytes", 1966080 : i32>, #dlti.dl_entry<"max_vector_width", 512 : i32>>>} {
  func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<128x1024xf16>, %arg3: memref<128x1024xf16>) attributes {compiletime_const_args_index = [1 : i32, 2 : i32]} {
    %0 = bufferization.to_tensor %arg0 restrict : memref<128x1024xf16>
    %1 = bufferization.to_tensor %arg1 restrict : memref<1024x1024xf16>
    %2 = bufferization.to_tensor %arg2 restrict : memref<128x1024xf16>
    %3 = tensor.empty() : tensor<1024x1024xf16>
    %transposed = linalg.transpose ins(%1 : tensor<1024x1024xf16>) outs(%3 : tensor<1024x1024xf16>) permutation = [1, 0]
    %4 = tensor.empty() : tensor<128x1024xf16>
    %cst = arith.constant 0.000000e+00 : f16
    %5 = linalg.fill ins(%cst : f16) outs(%4 : tensor<128x1024xf16>) -> tensor<128x1024xf16>
    %6 = linalg.matmul ins(%0, %transposed : tensor<128x1024xf16>, tensor<1024x1024xf16>) outs(%5 : tensor<128x1024xf16>) -> tensor<128x1024xf16>
    %7 = tensor.empty() : tensor<128x1024xf16>
    %8 = linalg.add ins(%6, %2 : tensor<128x1024xf16>, tensor<128x1024xf16>) outs(%7 : tensor<128x1024xf16>) -> tensor<128x1024xf16>
    %9 = tensor.empty() : tensor<128x1024xf16>
    %cst_0 = arith.constant 0.000000e+00 : f16
    %10 = linalg.fill ins(%cst_0 : f16) outs(%9 : tensor<128x1024xf16>) -> tensor<128x1024xf16>
    %res = tensor.empty() : tensor<128x1024xf16>
    %11 = linalg.max ins(%8, %10 : tensor<128x1024xf16>, tensor<128x1024xf16>) outs(%res : tensor<128x1024xf16>) -> tensor<128x1024xf16>
    bufferization.materialize_in_destination %11 in restrict writable %arg3 : (tensor<128x1024xf16>, memref<128x1024xf16>) -> ()
    return
  }
}
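
For clarity, here is a minimal hand-written sketch of the expected gpu.dealloc placement (this is not the output of any pass; the function and value names are illustrative). The temporary buffer is created inside the gpu.launch region, so its dealloc must also sit inside that region, before gpu.terminator, rather than after the launch where the value is no longer visible:

// Hand-written sketch only; %tmp mirrors the 16x16 temporary from the dump below.
func.func @dealloc_placement_sketch() {
  %c1 = arith.constant 1 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    %tmp = gpu.alloc () : memref<16x16xf16>
    // ... kernel body uses %tmp ...
    gpu.dealloc %tmp : memref<16x16xf16>  // expected placement: inside the region
    gpu.terminator
  }
  // A "gpu.dealloc"(%tmp) here, outside the region, is what the pass currently
  // produces; %tmp does not dominate this point, hence the verifier error.
  return
}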
Long log (full dump attached as test.txt):

/home/jovyan/code/graph-compiler/build/gpu_input.mlir:23:13: error: operand #0 does not dominate this use
      %20 = linalg.fill ins(%cst_0 : f16) outs(%extracted_slice_9 : tensor<16x16xf16>) -> tensor<16x16xf16>
            ^
/home/jovyan/code/graph-compiler/build/gpu_input.mlir:23:13: note: see current operation: "gpu.dealloc"(%159) : (memref<16x16xf16>) -> ()
/home/jovyan/code/graph-compiler/build/gpu_input.mlir:23:13: note: operand defined here (op in a child region)
// -----// IR Dump After InsertGPUAllocs Failed (insert-gpu-allocs) //----- //
"func.func"() <{function_type = (memref<1024x1024xf16>, memref<256x1024xf16>, memref<256x1024xf16>, memref<256x1024xf16>, memref<256x1024xf16>, memref<256x1024xf16>, memref<i8>) -> (), sym_name = "entry"}> ({
^bb0(%arg0: memref<1024x1024xf16>, %arg1: memref<256x1024xf16>, %arg2: memref<256x1024xf16>, %arg3: memref<256x1024xf16>, %arg4: memref<256x1024xf16>, %arg5: memref<256x1024xf16>, %arg6: memref<i8>):
  %0 = "arith.constant"() <{value = dense<0.000000e+00> : vector<256xf16>}> : () -> vector<256xf16>
  %1 = "arith.constant"() <{value = 8 : index}> : () -> index
  %2 = "arith.constant"() <{value = 16 : index}> : () -> index
  %3 = "arith.constant"() <{value = 1024 : index}> : () -> index
  %4 = "arith.constant"() <{value = 256 : index}> : () -> index
  %5 = "arith.constant"() <{value = 0 : index}> : () -> index
  %6 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %7 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %8 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %9 = "arith.constant"() <{value = 1 : index}> : () -> index
  %10 = "affine.apply"(%4, %5, %2) <{map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>}> : (index, index, index) -> index
  %11 = "affine.apply"(%3, %5, %2) <{map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>}> : (index, index, index) -> index
  "gpu.launch"(%10, %11, %9, %9, %9, %9) <{operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0>}> ({
  ^bb0(%arg41: index, %arg42: index, %arg43: index, %arg44: index, %arg45: index, %arg46: index, %arg47: index, %arg48: index, %arg49: index, %arg50: index, %arg51: index, %arg52: index):
    %131 = "affine.apply"(%arg41, %2, %5) <{map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>}> : (index, index, index) -> index
    %132 = "affine.apply"(%arg42, %2, %5) <{map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>}> : (index, index, index) -> index
    %133 = "xegpu.create_nd_tdesc"(%6, %131, %132) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %134 = "xegpu.update_nd_offset"(%133, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %135 = "vector.shape_cast"(%0) : (vector<256xf16>) -> vector<16x16xf16>
    "xegpu.store_nd"(%135, %134) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %136 = "xegpu.create_nd_tdesc"(%6, %131, %132) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %137 = "xegpu.update_nd_offset"(%136, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %138 = "xegpu.update_nd_offset"(%136, %1, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %139 = "xegpu.load_nd"(%137) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16xf16>
    %140 = "xegpu.load_nd"(%138) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16xf16>
    %141 = "arith.extf"(%139) : (vector<8x16xf16>) -> vector<8x16xf32>
    %142 = "arith.extf"(%140) : (vector<8x16xf16>) -> vector<8x16xf32>
    %143 = "xegpu.create_nd_tdesc"(%arg1, %131) <{const_offsets = array<i64: -9223372036854775808, 0>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 1, 0, 0>}> : (memref<256x1024xf16>, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %144 = "xegpu.update_nd_offset"(%143, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %145 = "xegpu.create_nd_tdesc"(%arg0, %132) <{const_offsets = array<i64: 0, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 1, 0, 0>}> : (memref<1024x1024xf16>, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %146 = "xegpu.update_nd_offset"(%145, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %147:4 = "scf.for"(%5, %3, %2, %141, %142, %144, %146) ({
    ^bb0(%arg53: index, %arg54: vector<8x16xf32>, %arg55: vector<8x16xf32>, %arg56: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, %arg57: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>):
      %172 = "arith.remui"(%arg53, %3) : (index, index) -> index
      %173 = "arith.cmpi"(%172, %5) <{predicate = 0 : i64}> : (index, index) -> i1
      "scf.if"(%173) ({
        "gpu.barrier"() : () -> ()
        "scf.yield"() : () -> ()
      }, {
      }) : (i1) -> ()
      %174 = "xegpu.load_nd"(%arg56) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
      %175 = "xegpu.load_nd"(%arg57) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>, packed}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16x2xf16>
      %176 = "xegpu.update_nd_offset"(%arg56, %5, %2) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
      %177 = "xegpu.update_nd_offset"(%arg57, %2, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
      "xegpu.prefetch_nd"(%176) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
      "xegpu.prefetch_nd"(%177) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
      %178 = "vector.shape_cast"(%174) : (vector<16x16xf16>) -> vector<256xf16>
      %179 = "vector.extract_strided_slice"(%178) <{offsets = [0], sizes = [128], strides = [1]}> : (vector<256xf16>) -> vector<128xf16>
      %180 = "vector.shape_cast"(%179) : (vector<128xf16>) -> vector<8x8x2xf16>
      %181 = "vector.extract_strided_slice"(%178) <{offsets = [128], sizes = [128], strides = [1]}> : (vector<256xf16>) -> vector<128xf16>
      %182 = "vector.shape_cast"(%181) : (vector<128xf16>) -> vector<8x8x2xf16>
      %183 = "xegpu.dpas"(%180, %175, %arg54) : (vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32>) -> vector<8x16xf32>
      %184 = "xegpu.dpas"(%182, %175, %arg55) : (vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32>) -> vector<8x16xf32>
      "scf.yield"(%183, %184, %176, %177) : (vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    }) : (index, index, index, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> (vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>)
    %148 = "arith.truncf"(%147#0) : (vector<8x16xf32>) -> vector<8x16xf16>
    %149 = "arith.truncf"(%147#1) : (vector<8x16xf32>) -> vector<8x16xf16>
    "xegpu.store_nd"(%148, %137) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    "xegpu.store_nd"(%149, %138) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %150 = "xegpu.create_nd_tdesc"(%6, %131, %132) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %151 = "xegpu.update_nd_offset"(%150, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %152 = "xegpu.load_nd"(%151) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %153 = "xegpu.create_nd_tdesc"(%arg2, %131, %132) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %154 = "xegpu.update_nd_offset"(%153, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %155 = "xegpu.load_nd"(%154) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %156 = "arith.addf"(%152, %155) <{fastmath = #arith.fastmath<none>}> : (vector<16x16xf16>, vector<16x16xf16>) -> vector<16x16xf16>
    %157 = "xegpu.create_nd_tdesc"(%7, %131, %132) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %158 = "xegpu.update_nd_offset"(%157, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    "xegpu.store_nd"(%156, %158) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %159 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<16x16xf16>
    %160 = "xegpu.create_nd_tdesc"(%159, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 16, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<16x16xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %161 = "xegpu.update_nd_offset"(%160, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %162 = "vector.shape_cast"(%0) : (vector<256xf16>) -> vector<16x16xf16>
    "xegpu.store_nd"(%162, %161) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %163 = "xegpu.create_nd_tdesc"(%7, %131, %132) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %164 = "xegpu.update_nd_offset"(%163, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %165 = "xegpu.load_nd"(%164) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %166 = "xegpu.create_nd_tdesc"(%159, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 16, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<16x16xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %167 = "xegpu.update_nd_offset"(%166, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %168 = "xegpu.load_nd"(%167) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %169 = "arith.maximumf"(%165, %168) <{fastmath = #arith.fastmath<none>}> : (vector<16x16xf16>, vector<16x16xf16>) -> vector<16x16xf16>
    %170 = "xegpu.create_nd_tdesc"(%8, %131, %132) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %171 = "xegpu.update_nd_offset"(%170, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    "xegpu.store_nd"(%169, %171) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    "gpu.terminator"() : () -> ()
  }) {SCFToGPU_visited, workgroup_attributions = 0 : i64} : (index, index, index, index, index, index) -> ()
  %12 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %13 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %14 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %15 = "arith.constant"() <{value = 1 : index}> : () -> index
  %16 = "affine.apply"(%4, %5, %2) <{map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>}> : (index, index, index) -> index
  %17 = "affine.apply"(%3, %5, %2) <{map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>}> : (index, index, index) -> index
  "gpu.launch"(%16, %17, %15, %15, %15, %15) <{operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0>}> ({
  ^bb0(%arg24: index, %arg25: index, %arg26: index, %arg27: index, %arg28: index, %arg29: index, %arg30: index, %arg31: index, %arg32: index, %arg33: index, %arg34: index, %arg35: index):
    %77 = "affine.apply"(%arg24, %2, %5) <{map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>}> : (index, index, index) -> index
    %78 = "affine.apply"(%arg25, %2, %5) <{map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>}> : (index, index, index) -> index
    %79 = "xegpu.create_nd_tdesc"(%12, %77, %78) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %80 = "xegpu.update_nd_offset"(%79, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %81 = "vector.shape_cast"(%0) : (vector<256xf16>) -> vector<16x16xf16>
    "xegpu.store_nd"(%81, %80) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %82 = "xegpu.create_nd_tdesc"(%12, %77, %78) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %83 = "xegpu.update_nd_offset"(%82, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %84 = "xegpu.update_nd_offset"(%82, %1, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %85 = "xegpu.load_nd"(%83) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16xf16>
    %86 = "xegpu.load_nd"(%84) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16xf16>
    %87 = "arith.extf"(%85) : (vector<8x16xf16>) -> vector<8x16xf32>
    %88 = "arith.extf"(%86) : (vector<8x16xf16>) -> vector<8x16xf32>
    %89 = "xegpu.create_nd_tdesc"(%8, %77) <{const_offsets = array<i64: -9223372036854775808, 0>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 1, 0, 0>}> : (memref<256x1024xf16>, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %90 = "xegpu.update_nd_offset"(%89, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %91 = "xegpu.create_nd_tdesc"(%arg0, %78) <{const_offsets = array<i64: 0, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 1, 0, 0>}> : (memref<1024x1024xf16>, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %92 = "xegpu.update_nd_offset"(%91, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %93:4 = "scf.for"(%5, %3, %2, %87, %88, %90, %92) ({
    ^bb0(%arg36: index, %arg37: vector<8x16xf32>, %arg38: vector<8x16xf32>, %arg39: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, %arg40: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>):
      %118 = "arith.remui"(%arg36, %3) : (index, index) -> index
      %119 = "arith.cmpi"(%118, %5) <{predicate = 0 : i64}> : (index, index) -> i1
      "scf.if"(%119) ({
        "gpu.barrier"() : () -> ()
        "scf.yield"() : () -> ()
      }, {
      }) : (i1) -> ()
      %120 = "xegpu.load_nd"(%arg39) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
      %121 = "xegpu.load_nd"(%arg40) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>, packed}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16x2xf16>
      %122 = "xegpu.update_nd_offset"(%arg39, %5, %2) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
      %123 = "xegpu.update_nd_offset"(%arg40, %2, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
      "xegpu.prefetch_nd"(%122) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
      "xegpu.prefetch_nd"(%123) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
      %124 = "vector.shape_cast"(%120) : (vector<16x16xf16>) -> vector<256xf16>
      %125 = "vector.extract_strided_slice"(%124) <{offsets = [0], sizes = [128], strides = [1]}> : (vector<256xf16>) -> vector<128xf16>
      %126 = "vector.shape_cast"(%125) : (vector<128xf16>) -> vector<8x8x2xf16>
      %127 = "vector.extract_strided_slice"(%124) <{offsets = [128], sizes = [128], strides = [1]}> : (vector<256xf16>) -> vector<128xf16>
      %128 = "vector.shape_cast"(%127) : (vector<128xf16>) -> vector<8x8x2xf16>
      %129 = "xegpu.dpas"(%126, %121, %arg37) : (vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32>) -> vector<8x16xf32>
      %130 = "xegpu.dpas"(%128, %121, %arg38) : (vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32>) -> vector<8x16xf32>
      "scf.yield"(%129, %130, %122, %123) : (vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    }) : (index, index, index, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> (vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>)
    %94 = "arith.truncf"(%93#0) : (vector<8x16xf32>) -> vector<8x16xf16>
    %95 = "arith.truncf"(%93#1) : (vector<8x16xf32>) -> vector<8x16xf16>
    "xegpu.store_nd"(%94, %83) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    "xegpu.store_nd"(%95, %84) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %96 = "xegpu.create_nd_tdesc"(%12, %77, %78) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %97 = "xegpu.update_nd_offset"(%96, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %98 = "xegpu.load_nd"(%97) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %99 = "xegpu.create_nd_tdesc"(%arg3, %77, %78) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %100 = "xegpu.update_nd_offset"(%99, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %101 = "xegpu.load_nd"(%100) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %102 = "arith.addf"(%98, %101) <{fastmath = #arith.fastmath<none>}> : (vector<16x16xf16>, vector<16x16xf16>) -> vector<16x16xf16>
    %103 = "xegpu.create_nd_tdesc"(%13, %77, %78) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %104 = "xegpu.update_nd_offset"(%103, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    "xegpu.store_nd"(%102, %104) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %105 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<16x16xf16>
    %106 = "xegpu.create_nd_tdesc"(%105, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 16, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<16x16xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %107 = "xegpu.update_nd_offset"(%106, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %108 = "vector.shape_cast"(%0) : (vector<256xf16>) -> vector<16x16xf16>
    "xegpu.store_nd"(%108, %107) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %109 = "xegpu.create_nd_tdesc"(%13, %77, %78) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %110 = "xegpu.update_nd_offset"(%109, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %111 = "xegpu.load_nd"(%110) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %112 = "xegpu.create_nd_tdesc"(%105, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 16, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<16x16xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %113 = "xegpu.update_nd_offset"(%112, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %114 = "xegpu.load_nd"(%113) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %115 = "arith.maximumf"(%111, %114) <{fastmath = #arith.fastmath<none>}> : (vector<16x16xf16>, vector<16x16xf16>) -> vector<16x16xf16>
    %116 = "xegpu.create_nd_tdesc"(%14, %77, %78) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %117 = "xegpu.update_nd_offset"(%116, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    "xegpu.store_nd"(%115, %117) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    "gpu.terminator"() : () -> ()
  }) {SCFToGPU_visited, workgroup_attributions = 0 : i64} : (index, index, index, index, index, index) -> ()
  %18 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %19 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<256x1024xf16>
  %20 = "arith.constant"() <{value = 1 : index}> : () -> index
  %21 = "affine.apply"(%4, %5, %2) <{map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>}> : (index, index, index) -> index
  %22 = "affine.apply"(%3, %5, %2) <{map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>}> : (index, index, index) -> index
  "gpu.launch"(%21, %22, %20, %20, %20, %20) <{operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0>}> ({
  ^bb0(%arg7: index, %arg8: index, %arg9: index, %arg10: index, %arg11: index, %arg12: index, %arg13: index, %arg14: index, %arg15: index, %arg16: index, %arg17: index, %arg18: index):
    %23 = "affine.apply"(%arg7, %2, %5) <{map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>}> : (index, index, index) -> index
    %24 = "affine.apply"(%arg8, %2, %5) <{map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>}> : (index, index, index) -> index
    %25 = "xegpu.create_nd_tdesc"(%18, %23, %24) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %26 = "xegpu.update_nd_offset"(%25, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %27 = "vector.shape_cast"(%0) : (vector<256xf16>) -> vector<16x16xf16>
    "xegpu.store_nd"(%27, %26) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %28 = "xegpu.create_nd_tdesc"(%18, %23, %24) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %29 = "xegpu.update_nd_offset"(%28, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %30 = "xegpu.update_nd_offset"(%28, %1, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %31 = "xegpu.load_nd"(%29) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16xf16>
    %32 = "xegpu.load_nd"(%30) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16xf16>
    %33 = "arith.extf"(%31) : (vector<8x16xf16>) -> vector<8x16xf32>
    %34 = "arith.extf"(%32) : (vector<8x16xf16>) -> vector<8x16xf32>
    %35 = "xegpu.create_nd_tdesc"(%14, %23) <{const_offsets = array<i64: -9223372036854775808, 0>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 1, 0, 0>}> : (memref<256x1024xf16>, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %36 = "xegpu.update_nd_offset"(%35, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %37 = "xegpu.create_nd_tdesc"(%arg0, %24) <{const_offsets = array<i64: 0, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 1, 0, 0>}> : (memref<1024x1024xf16>, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %38 = "xegpu.update_nd_offset"(%37, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %39:4 = "scf.for"(%5, %3, %2, %33, %34, %36, %38) ({
    ^bb0(%arg19: index, %arg20: vector<8x16xf32>, %arg21: vector<8x16xf32>, %arg22: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, %arg23: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>):
      %64 = "arith.remui"(%arg19, %3) : (index, index) -> index
      %65 = "arith.cmpi"(%64, %5) <{predicate = 0 : i64}> : (index, index) -> i1
      "scf.if"(%65) ({
        "gpu.barrier"() : () -> ()
        "scf.yield"() : () -> ()
      }, {
      }) : (i1) -> ()
      %66 = "xegpu.load_nd"(%arg22) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
      %67 = "xegpu.load_nd"(%arg23) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>, packed}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<8x16x2xf16>
      %68 = "xegpu.update_nd_offset"(%arg22, %5, %2) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
      %69 = "xegpu.update_nd_offset"(%arg23, %2, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
      "xegpu.prefetch_nd"(%68) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
      "xegpu.prefetch_nd"(%69) <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
      %70 = "vector.shape_cast"(%66) : (vector<16x16xf16>) -> vector<256xf16>
      %71 = "vector.extract_strided_slice"(%70) <{offsets = [0], sizes = [128], strides = [1]}> : (vector<256xf16>) -> vector<128xf16>
      %72 = "vector.shape_cast"(%71) : (vector<128xf16>) -> vector<8x8x2xf16>
      %73 = "vector.extract_strided_slice"(%70) <{offsets = [128], sizes = [128], strides = [1]}> : (vector<256xf16>) -> vector<128xf16>
      %74 = "vector.shape_cast"(%73) : (vector<128xf16>) -> vector<8x8x2xf16>
      %75 = "xegpu.dpas"(%72, %67, %arg20) : (vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32>) -> vector<8x16xf32>
      %76 = "xegpu.dpas"(%74, %67, %arg21) : (vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32>) -> vector<8x16xf32>
      "scf.yield"(%75, %76, %68, %69) : (vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    }) : (index, index, index, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> (vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>)
    %40 = "arith.truncf"(%39#0) : (vector<8x16xf32>) -> vector<8x16xf16>
    %41 = "arith.truncf"(%39#1) : (vector<8x16xf32>) -> vector<8x16xf16>
    "xegpu.store_nd"(%40, %29) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    "xegpu.store_nd"(%41, %30) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %42 = "xegpu.create_nd_tdesc"(%18, %23, %24) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %43 = "xegpu.update_nd_offset"(%42, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %44 = "xegpu.load_nd"(%43) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %45 = "xegpu.create_nd_tdesc"(%arg4, %23, %24) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %46 = "xegpu.update_nd_offset"(%45, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %47 = "xegpu.load_nd"(%46) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %48 = "arith.addf"(%44, %47) <{fastmath = #arith.fastmath<none>}> : (vector<16x16xf16>, vector<16x16xf16>) -> vector<16x16xf16>
    %49 = "xegpu.create_nd_tdesc"(%19, %23, %24) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %50 = "xegpu.update_nd_offset"(%49, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    "xegpu.store_nd"(%48, %50) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %51 = "gpu.alloc"() <{operandSegmentSizes = array<i32: 0, 0, 0>}> : () -> memref<16x16xf16>
    %52 = "xegpu.create_nd_tdesc"(%51, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 16, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<16x16xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %53 = "xegpu.update_nd_offset"(%52, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %54 = "vector.shape_cast"(%0) : (vector<256xf16>) -> vector<16x16xf16>
    "xegpu.store_nd"(%54, %53) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    %55 = "xegpu.create_nd_tdesc"(%19, %23, %24) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %56 = "xegpu.update_nd_offset"(%55, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %57 = "xegpu.load_nd"(%56) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %58 = "xegpu.create_nd_tdesc"(%51, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 16, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<16x16xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %59 = "xegpu.update_nd_offset"(%58, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %60 = "xegpu.load_nd"(%59) : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> vector<16x16xf16>
    %61 = "arith.maximumf"(%57, %60) <{fastmath = #arith.fastmath<none>}> : (vector<16x16xf16>, vector<16x16xf16>) -> vector<16x16xf16>
    %62 = "xegpu.create_nd_tdesc"(%arg5, %23, %24) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>, const_strides = array<i64: 1024, 1>, operandSegmentSizes = array<i32: 1, 2, 0, 0>}> : (memref<256x1024xf16>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    %63 = "xegpu.update_nd_offset"(%62, %5, %5) <{const_offsets = array<i64: -9223372036854775808, -9223372036854775808>}> : (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>, index, index) -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>
    "xegpu.store_nd"(%61, %63) <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_scope =  global, array_length = 1 : i64, boundary_check = true>>) -> ()
    "gpu.terminator"() : () -> ()
  }) {SCFToGPU_visited, workgroup_attributions = 0 : i64} : (index, index, index, index, index, index) -> ()
  "gpu.dealloc"(%6) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%7) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%159) : (memref<16x16xf16>) -> ()
  "gpu.dealloc"(%8) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%12) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%13) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%105) : (memref<16x16xf16>) -> ()
  "gpu.dealloc"(%14) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%18) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%19) : (memref<256x1024xf16>) -> ()
  "gpu.dealloc"(%51) : (memref<16x16xf16>) -> ()
  "func.return"() : () -> ()
}) : () -> ()
dchigarev commented Sep 25, 2024

I think the core problem is that there's a memref.alloc() in the kernel code. Even if insert-gpu-allocs could insert gpu.dealloc into the kernel code, the pipeline would still fail later at the gpu-to-spv stage (allocs are illegal operations inside a GPU kernel).

This alloc is produced by the one-shot-bufferization pass and isn't hoisted by either the buffer-hoisting or the buffer-loop-hoisting pass. We should deal with this issue at the beginning of our pipeline. There are several options:

  1. Generalize the named linalg ops and see whether that helps the one-shot-bufferization pass avoid allocs in loops.
    Potential problem: we don't know how well the linalg-to-xegpu pass handles linalg.generics.
    Tried this. Doesn't work.
  2. Assign the SLM address space to memrefs allocated in kernels (addr_space=3) and hope that the gpu-to-spv pass lowers them correctly (a minimal sketch follows after this list).
  3. Fuse linalg.fill into linalg.max using this pass from TPP. This fixes this particular case, but the general problem of allocs in GPU kernels remains. Tried this on the example and it works.
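For reference, a minimal sketch of what option 2 could look like (buffer name and shape are illustrative, and whether gpu-to-spv actually lowers address space 3 to SLM correctly is exactly the open question):

// Hypothetical: the in-kernel temporary tagged with the workgroup (SLM) address space 3
// instead of the default global space.
%tmp = memref.alloc() : memref<16x16xf16, 3>
// ... the kernel body uses %tmp wherever the plain alloc result was used ...
memref.dealloc %tmp : memref<16x16xf16, 3>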

Any other options that I forgot to mention? @kurapov-peter

@kurapov-peter

Right, the fill should be handled specially to avoid unnecessary allocations. In most cases those temporaries should be placed in registers, falling back to SLM when there are not enough. Hoisting allocations for all the groups can provide functional correctness, but there's little value in it since it would produce dead-slow kernels (there might be cases where that's necessary, but I'd rather see them first).
There are two things we should be doing:

  1. Create a separate internal allocation-handling pass that ensures we are not producing incorrect IR for the lowering to consume. This is where we'd manage the SLM allocations, for example.
  2. Prepare the input IR for GPU lowering in a way that avoids unnecessary allocations in the first place. This can be done by improving bufferization capabilities, improving named-op semantics, and combining upstream transformations such as generalization + fusion downstream.

For simple cases such as MLP, the latter should suffice, as you mention. For generic cases, we'll need additional handling similar to what IREE does with multi-buffering and other optimizations, for example.

SLM allocations/deallocations should adhere to the semantic restrictions and land in the gpu.launch memory attributions (workgroup). In that sense, we just need to agree on how those are handled between passes. I would guess it should be easier to analyze SLM allocations when they are tied to a kernel (for example, placed inside it). If we had to split or create two or more kernels within a single module, we'd need additional information attached to the ops.
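For illustration, a minimal sketch (kernel and buffer names are made up) of how such a temporary could end up as a workgroup memory attribution once the kernel is outlined to a gpu.func:

gpu.module @kernels {
  // Hypothetical outlined kernel: the 16x16 temporary lives in workgroup (SLM) memory
  // through an attribution instead of an alloc inside the body.
  gpu.func @entry_kernel(%out: memref<256x1024xf16>)
      workgroup(%tmp: memref<16x16xf16, 3>)
      kernel {
    // ... loads/stores that previously targeted the alloc'ed buffer now use %tmp ...
    gpu.return
  }
}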

dchigarev changed the title from "InsertGPUAllocs does not properly insert dealloc for temporary buffers in kernel code" to "one-shot-bufferize pass generates memref.alloc()s in GPU kernels code and breaks the pipeline" on Oct 10, 2024
dchigarev commented Oct 10, 2024

  1. Fuse linalg.fill into linalg.max

Using the upstream passes for the fusion (--linalg-generalize-named-ops + --linalg-fuse-elementwise-ops) indeed generates code that doesn't cause any memref.alloc()s to appear in the GPU kernel:

reproducer after applying the passes above
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 4 : i32>, #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>, #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>, #dlti.dl_entry<"L3_cache_size_in_bytes", 1966080 : i32>, #dlti.dl_entry<"max_vector_width", 512 : i32>>>} {
  func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<128x1024xf16>, %arg3: memref<128x1024xf16>) attributes {compiletime_const_args_index = [1 : i32, 2 : i32]} {
    %cst = arith.constant 0.000000e+00 : f16
    %0 = bufferization.to_tensor %arg0 restrict : memref<128x1024xf16>
    %1 = bufferization.to_tensor %arg1 restrict : memref<1024x1024xf16>
    %2 = bufferization.to_tensor %arg2 restrict : memref<128x1024xf16>
    %3 = tensor.empty() : tensor<128x1024xf16>
    %4 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%3 : tensor<128x1024xf16>) {
    ^bb0(%out: f16):
      linalg.yield %cst : f16
    } -> tensor<128x1024xf16>
    %5 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%0, %1 : tensor<128x1024xf16>, tensor<1024x1024xf16>) outs(%4 : tensor<128x1024xf16>) {
    ^bb0(%in: f16, %in_0: f16, %out: f16):
      %8 = arith.mulf %in, %in_0 : f16
      %9 = arith.addf %out, %8 : f16
      linalg.yield %9 : f16
    } -> tensor<128x1024xf16>
    %6 = tensor.empty() : tensor<128x1024xf16>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %2 : tensor<128x1024xf16>, tensor<128x1024xf16>) outs(%6 : tensor<128x1024xf16>) {
    ^bb0(%in: f16, %in_0: f16, %out: f16):
      %8 = arith.addf %in, %in_0 : f16
      %9 = arith.maximumf %8, %cst : f16
      linalg.yield %9 : f16
    } -> tensor<128x1024xf16>
    bufferization.materialize_in_destination %7 in restrict writable %arg3 : (tensor<128x1024xf16>, memref<128x1024xf16>) -> ()
    return
  }
}

The linalg-to-xegpu pass doesn't handle linalg.generics very well, so we have to convert them back to named ops where possible. Applying the --linalg-specialize-generic-ops pass converts some of the generics back, but not all of them:

mlir after specializing generic ops
#map = affine_map<(d0, d1) -> (d0, d1)>
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 4 : i32>, #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>, #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>, #dlti.dl_entry<"L3_cache_size_in_bytes", 1966080 : i32>, #dlti.dl_entry<"max_vector_width", 512 : i32>>>} {
  func.func @entry(%arg0: memref<128x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<128x1024xf16>, %arg3: memref<128x1024xf16>) attributes {compiletime_const_args_index = [1 : i32, 2 : i32]} {
    %cst = arith.constant 0.000000e+00 : f16
    %0 = bufferization.to_tensor %arg0 restrict : memref<128x1024xf16>
    %1 = bufferization.to_tensor %arg1 restrict : memref<1024x1024xf16>
    %2 = bufferization.to_tensor %arg2 restrict : memref<128x1024xf16>
    %3 = tensor.empty() : tensor<128x1024xf16>
    %4 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%3 : tensor<128x1024xf16>) {
    ^bb0(%out: f16):
      linalg.yield %cst : f16
    } -> tensor<128x1024xf16>
    %5 = linalg.matmul_transpose_b ins(%0, %1 : tensor<128x1024xf16>, tensor<1024x1024xf16>) outs(%4 : tensor<128x1024xf16>) -> tensor<128x1024xf16>
    %6 = tensor.empty() : tensor<128x1024xf16>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %2 : tensor<128x1024xf16>, tensor<128x1024xf16>) outs(%6 : tensor<128x1024xf16>) {
    ^bb0(%in: f16, %in_0: f16, %out: f16):
      %8 = arith.addf %in, %in_0 : f16
      %9 = arith.maximumf %8, %cst : f16
      linalg.yield %9 : f16
    } -> tensor<128x1024xf16>
    bufferization.materialize_in_destination %7 in restrict writable %arg3 : (tensor<128x1024xf16>, memref<128x1024xf16>) -> ()
    return
  }
}

The first of the remaining generics is a linalg.fill that was not specialized back because of strict restrictions on what the pass considers a fill op (this can be fixed with this patch).

The only way to turn the second generic back into named ops is to split it into linalg.add + linalg.max (which largely defeats the fusion). Instead, I was able to write a patch that lowers this specific generic directly in the linalg-to-xegpu pass (it simply matches generics whose body contains arith.addf + arith.maximumf and converts them into vector form; a rough sketch is below).
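For reference, a rough sketch of the vector form such a match could produce for a 16x16 tile (value names are illustrative and not taken from the actual patch; %matmul_tile and %bias_tile are assumed to be tiles already loaded as vectors):

%zero = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%sum  = arith.addf %matmul_tile, %bias_tile : vector<16x16xf16>
%relu = arith.maximumf %sum, %zero : vector<16x16xf16>
// %relu is then written back through the corresponding xegpu.store_nd.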

@kurapov-peter I'm wondering what our next steps should be. Are we okay with keeping some linalg ops as generics during our pipeline (I thought you said we'd better stick to named ops)? How should we handle linalg.generics in the linalg-to-xegpu pass? Should it be a general mechanism able to convert arbitrary linalg generics, or is it okay to only match known patterns and add more cases as we encounter them?

UPD: discussed offline and decided to gradually apply linalg-fuse-elementwise-ops only to the problematic pattern (linalg.fill + linalg.add + linalg.max) by generalizing it first via transform.structured.generalize
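A minimal transform-IR sketch of that direction, assuming the transform interpreter is run over the module (only the linalg.fill matching is shown; matching the add/max ops and the follow-up elementwise fusion are left out):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Match only the ops that form the problematic pattern and generalize them so that
    // elementwise fusion can later merge them into a single linalg.generic.
    %fills = transform.structured.match ops{["linalg.fill"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %generalized = transform.structured.generalize %fills
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}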

dchigarev commented Oct 16, 2024

UPD: discussed offline and decided to gradually apply linalg-fuse-elementwise-ops only to the problematic pattern (linalg.fill + linalg.add + linalg.max) by generalizing it first via transform.structured.generalize

Unfortunately, this won't solve the problem completely. Besides the case with linalg.broadcast described in #382, @AndreyPavlenko also stumbled upon another memref.alloc-in-a-loop case while developing the nested-tiling pass.

It appears that with nested tiling the one-shot-bufferize pass is more likely to produce allocations inside a loop (and thus inside a GPU kernel). Even a simple matmul module that used to work fine is now affected by this problem:

Simple matmul module with nested tiling
func.func @entry(%arg0: memref<64x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<64x128xf16>, %arg3: memref<i8>) {
  %cst = arith.constant 0.000000e+00 : f16
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x128xf16>
  scf.forall (%arg4, %arg5) = (0, 0) to (64, 128) step (64, 128) {
    %subview = memref.subview %arg0[%arg4, 0] [64, 128] [1, 1] : memref<64x128xf16> to memref<64x128xf16, strided<[128, 1], offset: ?>>
    %subview_0 = memref.subview %arg1[%arg5, 0] [128, 128] [1, 1] : memref<128x128xf16> to memref<128x128xf16, strided<[128, 1], offset: ?>>
    %subview_1 = memref.subview %alloc[%arg4, %arg5] [64, 128] [1, 1] : memref<64x128xf16> to memref<64x128xf16, strided<[128, 1], offset: ?>>
    %subview_2 = memref.subview %arg0[%arg4, %arg5] [64, 128] [1, 1] : memref<64x128xf16> to memref<64x128xf16, strided<[128, 1], offset: ?>>
    %subview_3 = memref.subview %arg2[%arg4, %arg5] [64, 128] [1, 1] : memref<64x128xf16> to memref<64x128xf16, strided<[128, 1], offset: ?>>
    scf.forall (%arg6, %arg7) = (0, 0) to (64, 128) step (8, 16) {
      %subview_5 = memref.subview %subview[%arg6, 0] [8, 128] [1, 1] : memref<64x128xf16, strided<[128, 1], offset: ?>> to memref<8x128xf16, strided<[128, 1], offset: ?>>
      %subview_6 = memref.subview %subview_0[%arg7, 0] [16, 128] [1, 1] : memref<128x128xf16, strided<[128, 1], offset: ?>> to memref<16x128xf16, strided<[128, 1], offset: ?>>
      %subview_7 = memref.subview %subview_1[%arg6, %arg7] [8, 16] [1, 1] : memref<64x128xf16, strided<[128, 1], offset: ?>> to memref<8x16xf16, strided<[128, 1], offset: ?>>
      // allocating buffer for 'linalg.fill' result (also matmul result)
      %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<8x16xf16>
      linalg.fill ins(%cst : f16) outs(%alloc_8 : memref<8x16xf16>)
      linalg.matmul_transpose_b ins(%subview_5, %subview_6 : memref<8x128xf16, strided<[128, 1], offset: ?>>, memref<16x128xf16, strided<[128, 1], offset: ?>>) outs(%alloc_8 : memref<8x16xf16>)
      %subview_9 = memref.subview %subview_2[%arg6, %arg7] [8, 16] [1, 1] : memref<64x128xf16, strided<[128, 1], offset: ?>> to memref<8x16xf16, strided<[128, 1], offset: ?>>
      %subview_10 = memref.subview %subview_3[%arg6, %arg7] [8, 16] [1, 1] : memref<64x128xf16, strided<[128, 1], offset: ?>> to memref<8x16xf16, strided<[128, 1], offset: ?>>
      linalg.add ins(%alloc_8, %subview_9 : memref<8x16xf16>, memref<8x16xf16, strided<[128, 1], offset: ?>>) outs(%subview_10 : memref<8x16xf16, strided<[128, 1], offset: ?>>)
      %subview_11 = memref.subview %subview_3[%arg6, %arg7] [8, 16] [1, 1] : memref<64x128xf16, strided<[128, 1], offset: ?>> to memref<8x16xf16, strided<[128, 1], offset: ?>>
      memref.copy %subview_10, %subview_11 : memref<8x16xf16, strided<[128, 1], offset: ?>> to memref<8x16xf16, strided<[128, 1], offset: ?>>
    }
    %subview_4 = memref.subview %arg2[%arg4, %arg5] [64, 128] [1, 1] : memref<64x128xf16> to memref<64x128xf16, strided<[128, 1], offset: ?>>
    memref.copy %subview_3, %subview_4 : memref<64x128xf16, strided<[128, 1], offset: ?>> to memref<64x128xf16, strided<[128, 1], offset: ?>>
  }
  memref.copy %arg2, %arg2 : memref<64x128xf16> to memref<64x128xf16>
  return
}

@kurapov-peter what do you think we should do about this problem right now? Would lowering the memref.alloc() : memref<8x16xf16> + linalg.fill(%cst) pair into a vector constant (something like %res = arith.constant dense<0.000000e+00> : vector<8x16xf16>) in the linalg-to-xegpu pass be enough for now? It should work at least for static shapes, but it would require getting rid of the xegpu.store/load round-trip for intermediate results, so that the zero tensor can be forwarded from the matmul result directly to the next operation (add). Or should we rather start with the SLM multi-buffering approach right away?
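For clarity, a hedged sketch of that rewrite applied to the snippet above (SSA names are reused for illustration only):

// Before (inside the innermost scf.forall):
//   %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<8x16xf16>
//   linalg.fill ins(%cst : f16) outs(%alloc_8 : memref<8x16xf16>)
// Hypothetical replacement: keep the accumulator in registers as a vector constant
// and feed it straight into the matmul lowering, with no intermediate buffer.
%acc_init = arith.constant dense<0.000000e+00> : vector<8x16xf16>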

@kurapov-peter

I think it's time to address the problem. Let's start with a pass that would put the allocs into SLM.
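A before/after sketch of what such a pass could do with the temporary from the example above (assuming the GPU address-space attribute is used to mark SLM; the exact marking is up for discussion):

// Before: temporary created by one-shot-bufferize inside the kernel region.
%alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<8x16xf16>
// After the hypothetical alloc-to-SLM pass: the same buffer, placed in workgroup memory.
%alloc_slm = memref.alloc() : memref<8x16xf16, #gpu.address_space<workgroup>>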
