Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FourierNeuralOperators segfault #185

Open
avik-pal opened this issue Dec 13, 2024 · 4 comments
Open

FourierNeuralOperators segfault #185

avik-pal opened this issue Dec 13, 2024 · 4 comments

Comments

@avik-pal
Copy link
Collaborator

Unoptimized MLIR

module {
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x160 batch (bias add after the
  // first dot_general). Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x160 batch (bias add after a
  // 64x64 dot_general). Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x32x5 batch, summing the
  // dense-branch output with the real part of the inverse-FFT branch.
  // Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar2"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of GELU. In the visible caller it is applied elementwise via
  // `enzyme.batch` over a 64x32x5 batch. Computes
  //   y = x * logistic(c0 * x * (1 + 0.044715 * x^2)),  c0 = 1.59576917
  // Because logistic(2z) = (1 + tanh z) / 2 and c0 ~= 2*sqrt(2/pi), this is
  // the tanh approximation 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
  // Returns (y, x). The rank-0 `transpose ... dims = []` ops are identities.
  // Kept op-for-op as part of a crash reproducer; do not simplify.
  func.func private @gelu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // 1 + 0.044715 * x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // c0 * x * (1 + 0.044715 * x^2)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copy of the input (presumably for the reverse pass — confirm).
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x160 batch (bias add after a
  // 64x64 dot_general). Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x32x5 batch, summing the
  // dense-branch output with the real part of the inverse-FFT branch.
  // Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of GELU. In the visible caller it is applied elementwise via
  // `enzyme.batch` over a 64x32x5 batch. Computes
  //   y = x * logistic(c0 * x * (1 + 0.044715 * x^2)),  c0 = 1.59576917
  // Because logistic(2z) = (1 + tanh z) / 2 and c0 ~= 2*sqrt(2/pi), this is
  // the tanh approximation 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
  // Returns (y, x). The rank-0 `transpose ... dims = []` ops are identities.
  // Kept op-for-op as part of a crash reproducer; do not simplify.
  func.func private @gelu_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // 1 + 0.044715 * x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // c0 * x * (1 + 0.044715 * x^2)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copy of the input (presumably for the reverse pass — confirm).
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x160 batch (bias add after a
  // 64x64 dot_general). Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar5"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x32x5 batch, summing the
  // dense-branch output with the real part of the inverse-FFT branch.
  // Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar6"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of GELU. In the visible caller it is applied elementwise via
  // `enzyme.batch` over a 64x32x5 batch. Computes
  //   y = x * logistic(c0 * x * (1 + 0.044715 * x^2)),  c0 = 1.59576917
  // Because logistic(2z) = (1 + tanh z) / 2 and c0 ~= 2*sqrt(2/pi), this is
  // the tanh approximation 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
  // Returns (y, x). The rank-0 `transpose ... dims = []` ops are identities.
  // Kept op-for-op as part of a crash reproducer; do not simplify.
  func.func private @gelu_broadcast_scalar2(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // 1 + 0.044715 * x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // c0 * x * (1 + 0.044715 * x^2)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copy of the input (presumably for the reverse pass — confirm).
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`. In the visible caller it is applied
  // elementwise via `enzyme.batch` over a 64x160 batch (bias add after a
  // 64x64 dot_general). Returns (a + b, a, b).
  // `stablehlo.transpose ... dims = []` on a rank-0 tensor is an identity.
  // NOTE(review): the pass-through results look like they are retained for
  // the autodiff reverse pass — confirm. Kept op-for-op as part of a crash
  // reproducer; do not simplify.
  func.func private @"+_broadcast_scalar7"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`: returns (a + b, a, b).
  // NOTE(review): its `enzyme.batch` call site lies in the truncated part of
  // the caller; presumably used like the other `+_broadcast_scalar*` bodies
  // — confirm. `stablehlo.transpose ... dims = []` on a rank-0 tensor is an
  // identity. Kept op-for-op as part of a crash reproducer; do not simplify.
  func.func private @"+_broadcast_scalar8"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of GELU. Computes
  //   y = x * logistic(c0 * x * (1 + 0.044715 * x^2)),  c0 = 1.59576917
  // Because logistic(2z) = (1 + tanh z) / 2 and c0 ~= 2*sqrt(2/pi), this is
  // the tanh approximation 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
  // Returns (y, x). The rank-0 `transpose ... dims = []` ops are identities.
  // NOTE(review): its `enzyme.batch` call site lies in the truncated part of
  // the caller — confirm usage. Kept op-for-op as part of a crash reproducer.
  func.func private @gelu_broadcast_scalar3(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // 1 + 0.044715 * x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // c0 * x * (1 + 0.044715 * x^2)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copy of the input (presumably for the reverse pass — confirm).
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`: returns (a + b, a, b).
  // NOTE(review): its `enzyme.batch` call site lies in the truncated part of
  // the caller; presumably used like the other `+_broadcast_scalar*` bodies
  // — confirm. `stablehlo.transpose ... dims = []` on a rank-0 tensor is an
  // identity. Kept op-for-op as part of a crash reproducer; do not simplify.
  func.func private @"+_broadcast_scalar9"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of GELU. Computes
  //   y = x * logistic(c0 * x * (1 + 0.044715 * x^2)),  c0 = 1.59576917
  // Because logistic(2z) = (1 + tanh z) / 2 and c0 ~= 2*sqrt(2/pi), this is
  // the tanh approximation 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
  // Returns (y, x). The rank-0 `transpose ... dims = []` ops are identities.
  // NOTE(review): its `enzyme.batch` call site lies in the truncated part of
  // the caller — confirm usage. Kept op-for-op as part of a crash reproducer.
  func.func private @gelu_broadcast_scalar4(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %cst = stablehlo.constant dense<4.471500e-02> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.59576917> : tensor<f32>
    // 1 + 0.044715 * x^2
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.multiply %1, %cst : tensor<f32>
    %3 = stablehlo.add %2, %cst_1 : tensor<f32>
    // c0 * x * (1 + 0.044715 * x^2)
    %4 = stablehlo.multiply %cst_0, %0 : tensor<f32>
    %5 = stablehlo.multiply %4, %3 : tensor<f32>
    // x * logistic(...)
    %6 = stablehlo.logistic %5 : tensor<f32>
    %7 = stablehlo.multiply %0, %6 : tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copy of the input (presumably for the reverse pass — confirm).
    %9 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %8, %9 : tensor<f32>, tensor<f32>
  }
  // Scalar body of a broadcasted `+`: returns (a + b, a, b).
  // NOTE(review): its `enzyme.batch` call site lies in the truncated part of
  // the caller; presumably used like the other `+_broadcast_scalar*` bodies
  // — confirm. `stablehlo.transpose ... dims = []` on a rank-0 tensor is an
  // identity. Kept op-for-op as part of a crash reproducer; do not simplify.
  func.func private @"+_broadcast_scalar10"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    // Identity transposes of the rank-0 operands.
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<f32>) -> tensor<f32>
    // a + b
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copies of the inputs.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4, %5 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  // Scalar body of `abs2`: returns (|x| * |x|, x). For a real f32 input this
  // equals x * x. The rank-0 `transpose ... dims = []` ops are identities.
  // NOTE(review): its call site lies in the truncated part of the caller;
  // presumably it feeds the squared-magnitude sum of the `sumabs2first`
  // objective — confirm. Kept op-for-op as part of a crash reproducer.
  func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    // |x| * |x|
    %1 = stablehlo.abs %0 : tensor<f32>
    %2 = stablehlo.multiply %1, %1 : tensor<f32>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<f32>) -> tensor<f32>
    // Pass-through copy of the input.
    %4 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %3, %4 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %19 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    %20 = stablehlo.reshape %19 : (tensor<5x32x2xf32>) -> tensor<160x2xf32>
    %21 = stablehlo.transpose %20, dims = [1, 0] : (tensor<160x2xf32>) -> tensor<2x160xf32>
    %22 = stablehlo.dot_general %0, %21, contracting_dims = [1] x [0] : (tensor<64x2xf32>, tensor<2x160xf32>) -> tensor<64x160xf32>
    %23 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %24:3 = enzyme.batch @"+_broadcast_scalar"(%22, %23) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %25 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %26 = stablehlo.reshape %25 : (tensor<160x64xf32>) -> tensor<160x64xf32>
    %27 = stablehlo.transpose %26, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %28 = stablehlo.dot_general %2, %27, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %29 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %30:3 = enzyme.batch @"+_broadcast_scalar1"(%28, %29) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %31 = stablehlo.transpose %24#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %32 = stablehlo.reshape %31 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %33 = stablehlo.transpose %32, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %34 = stablehlo.transpose %33, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %35 = stablehlo.convert %34 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %36 = stablehlo.transpose %35, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %37 = stablehlo.fft %36, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %38 = stablehlo.transpose %37, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_1 = stablehlo.constant dense<0> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    %39 = stablehlo.dynamic_slice %38, %c, %c_1, %c_2, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %40 = stablehlo.transpose %39, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %41 = stablehlo.reshape %40 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %42 = stablehlo.transpose %41, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %43 = stablehlo.transpose %42, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %44 = stablehlo.transpose %4, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %45 = stablehlo.transpose %43, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %46 = stablehlo.dot_general %44, %45, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %47 = stablehlo.transpose %46, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %48 = stablehlo.transpose %47, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_4 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %49 = stablehlo.transpose %48, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %50 = stablehlo.reshape %49 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %51 = stablehlo.transpose %50, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %52 = stablehlo.pad %51, %cst_4, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %53 = stablehlo.transpose %52, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %54 = stablehlo.fft %53, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %55 = stablehlo.transpose %54, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %56 = stablehlo.real %55 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %57 = stablehlo.transpose %56, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %58 = stablehlo.transpose %30#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %59 = stablehlo.reshape %58 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %60 = stablehlo.transpose %59, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %61:3 = enzyme.batch @"+_broadcast_scalar2"(%60, %57) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %62:2 = enzyme.batch @gelu_broadcast_scalar(%61#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %63 = stablehlo.transpose %62#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %64 = stablehlo.reshape %63 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %66 = stablehlo.dot_general %5, %65, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %67 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %68:3 = enzyme.batch @"+_broadcast_scalar3"(%66, %67) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %69 = stablehlo.transpose %62#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %70 = stablehlo.convert %69 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %71 = stablehlo.transpose %70, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %72 = stablehlo.fft %71, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %73 = stablehlo.transpose %72, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c_6 = stablehlo.constant dense<0> : tensor<i64>
    %c_7 = stablehlo.constant dense<0> : tensor<i64>
    %c_8 = stablehlo.constant dense<0> : tensor<i64>
    %74 = stablehlo.dynamic_slice %73, %c_6, %c_7, %c_8, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %75 = stablehlo.transpose %74, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %76 = stablehlo.reshape %75 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %77 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %78 = stablehlo.transpose %77, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %79 = stablehlo.transpose %7, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %80 = stablehlo.transpose %78, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %81 = stablehlo.dot_general %79, %80, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %82 = stablehlo.transpose %81, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %83 = stablehlo.transpose %82, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_10 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %84 = stablehlo.transpose %83, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %85 = stablehlo.reshape %84 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %86 = stablehlo.transpose %85, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %87 = stablehlo.pad %86, %cst_10, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %88 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %89 = stablehlo.fft %88, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %90 = stablehlo.transpose %89, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %91 = stablehlo.real %90 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %92 = stablehlo.transpose %91, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %93 = stablehlo.transpose %68#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %94 = stablehlo.reshape %93 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %95 = stablehlo.transpose %94, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %96:3 = enzyme.batch @"+_broadcast_scalar4"(%95, %92) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %97:2 = enzyme.batch @gelu_broadcast_scalar1(%96#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %98 = stablehlo.transpose %97#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %99 = stablehlo.reshape %98 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %100 = stablehlo.transpose %99, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %101 = stablehlo.dot_general %8, %100, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %102 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %103:3 = enzyme.batch @"+_broadcast_scalar5"(%101, %102) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %104 = stablehlo.transpose %97#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %105 = stablehlo.convert %104 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %106 = stablehlo.transpose %105, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %107 = stablehlo.fft %106, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %108 = stablehlo.transpose %107, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c_12 = stablehlo.constant dense<0> : tensor<i64>
    %c_13 = stablehlo.constant dense<0> : tensor<i64>
    %c_14 = stablehlo.constant dense<0> : tensor<i64>
    %109 = stablehlo.dynamic_slice %108, %c_12, %c_13, %c_14, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %110 = stablehlo.transpose %109, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %111 = stablehlo.reshape %110 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %112 = stablehlo.transpose %111, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %113 = stablehlo.transpose %112, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_15 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %114 = stablehlo.transpose %10, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %115 = stablehlo.transpose %113, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %116 = stablehlo.dot_general %114, %115, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %117 = stablehlo.transpose %116, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %118 = stablehlo.transpose %117, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_16 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %119 = stablehlo.transpose %118, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %120 = stablehlo.reshape %119 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %121 = stablehlo.transpose %120, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %122 = stablehlo.pad %121, %cst_16, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %123 = stablehlo.transpose %122, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %124 = stablehlo.fft %123, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %125 = stablehlo.transpose %124, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %126 = stablehlo.real %125 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %127 = stablehlo.transpose %126, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %128 = stablehlo.transpose %103#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %129 = stablehlo.reshape %128 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %130 = stablehlo.transpose %129, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %131:3 = enzyme.batch @"+_broadcast_scalar6"(%130, %127) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %132:2 = enzyme.batch @gelu_broadcast_scalar2(%131#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<64x160xf32>
    %133 = stablehlo.transpose %132#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %134 = stablehlo.reshape %133 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %135 = stablehlo.transpose %134, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %136 = stablehlo.dot_general %11, %135, contracting_dims = [1] x [0] : (tensor<64x64xf32>, tensor<64x160xf32>) -> tensor<64x160xf32>
    %137 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<64xf32>) -> tensor<64x160xf32>
    %138:3 = enzyme.batch @"+_broadcast_scalar7"(%136, %137) {batch_shape = array<i64: 64, 160>} : (tensor<64x160xf32>, tensor<64x160xf32>) -> (tensor<64x160xf32>, tensor<64x160xf32>, tensor<64x160xf32>)
    %139 = stablehlo.transpose %132#0, dims = [1, 0, 2] : (tensor<64x32x5xf32>) -> tensor<32x64x5xf32>
    %140 = stablehlo.convert %139 : (tensor<32x64x5xf32>) -> tensor<32x64x5xcomplex<f32>>
    %141 = stablehlo.transpose %140, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %142 = stablehlo.fft %141, type =  FFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %143 = stablehlo.transpose %142, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %c_18 = stablehlo.constant dense<0> : tensor<i64>
    %c_19 = stablehlo.constant dense<0> : tensor<i64>
    %c_20 = stablehlo.constant dense<0> : tensor<i64>
    %144 = stablehlo.dynamic_slice %143, %c_18, %c_19, %c_20, sizes = [16, 64, 5] : (tensor<32x64x5xcomplex<f32>>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<16x64x5xcomplex<f32>>
    %145 = stablehlo.transpose %144, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %146 = stablehlo.reshape %145 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %147 = stablehlo.transpose %146, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %148 = stablehlo.transpose %147, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %cst_21 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x5x16xcomplex<f32>>
    %149 = stablehlo.transpose %13, dims = [2, 0, 1] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %150 = stablehlo.transpose %148, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %151 = stablehlo.dot_general %149, %150, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x5xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %152 = stablehlo.transpose %151, dims = [1, 2, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<64x5x16xcomplex<f32>>
    %153 = stablehlo.transpose %152, dims = [2, 0, 1] : (tensor<64x5x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %cst_22 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>>
    %154 = stablehlo.transpose %153, dims = [2, 1, 0] : (tensor<16x64x5xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %155 = stablehlo.reshape %154 : (tensor<5x64x16xcomplex<f32>>) -> tensor<5x64x16xcomplex<f32>>
    %156 = stablehlo.transpose %155, dims = [2, 1, 0] : (tensor<5x64x16xcomplex<f32>>) -> tensor<16x64x5xcomplex<f32>>
    %157 = stablehlo.pad %156, %cst_22, low = [0, 0, 0], high = [16, 0, 0], interior = [0, 0, 0] : (tensor<16x64x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %158 = stablehlo.transpose %157, dims = [2, 1, 0] : (tensor<32x64x5xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %159 = stablehlo.fft %158, type =  IFFT, length = [32] : (tensor<5x64x32xcomplex<f32>>) -> tensor<5x64x32xcomplex<f32>>
    %160 = stablehlo.transpose %159, dims = [2, 1, 0] : (tensor<5x64x32xcomplex<f32>>) -> tensor<32x64x5xcomplex<f32>>
    %161 = stablehlo.real %160 : (tensor<32x64x5xcomplex<f32>>) -> tensor<32x64x5xf32>
    %162 = stablehlo.transpose %161, dims = [1, 0, 2] : (tensor<32x64x5xf32>) -> tensor<64x32x5xf32>
    %163 = stablehlo.transpose %138#0, dims = [1, 0] : (tensor<64x160xf32>) -> tensor<160x64xf32>
    %164 = stablehlo.reshape %163 : (tensor<160x64xf32>) -> tensor<5x32x64xf32>
    %165 = stablehlo.transpose %164, dims = [2, 1, 0] : (tensor<5x32x64xf32>) -> tensor<64x32x5xf32>
    %166:3 = enzyme.batch @"+_broadcast_scalar8"(%165, %162) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>, tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %167:2 = enzyme.batch @gelu_broadcast_scalar3(%166#0) {batch_shape = array<i64: 64, 32, 5>} : (tensor<64x32x5xf32>) -> (tensor<64x32x5xf32>, tensor<64x32x5xf32>)
    %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<128x160xf32>
    %168 = stablehlo.transpose %167#0, dims = [2, 1, 0] : (tensor<64x32x5xf32>) -> tensor<5x32x64xf32>
    %169 = stablehlo.reshape %168 : (tensor<5x32x64xf32>) -> tensor<160x64xf32>
    %170 = stablehlo.transpose %169, dims = [1, 0] : (tensor<160x64xf32>) -> tensor<64x160xf32>
    %171 = stablehlo.dot_general %14, %170, contracting_dims = [1] x [0] : (tensor<128x64xf32>, tensor<64x160xf32>) -> tensor<128x160xf32>
    %172 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %173 = stablehlo.reshape %172 : (tensor<128xf32>) -> tensor<1x128xf32>
    %174 = stablehlo.transpose %173, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %175 = stablehlo.broadcast_in_dim %174, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x160xf32>
    %176:3 = enzyme.batch @"+_broadcast_scalar9"(%171, %175) {batch_shape = array<i64: 128, 160>} : (tensor<128x160xf32>, tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>, tensor<128x160xf32>)
    %177:2 = enzyme.batch @gelu_broadcast_scalar4(%176#0) {batch_shape = array<i64: 128, 160>} : (tensor<128x160xf32>) -> (tensor<128x160xf32>, tensor<128x160xf32>)
    %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<1x160xf32>
    %178 = stablehlo.transpose %177#0, dims = [1, 0] : (tensor<128x160xf32>) -> tensor<160x128xf32>
    %179 = stablehlo.reshape %178 : (tensor<160x128xf32>) -> tensor<160x128xf32>
    %180 = stablehlo.transpose %179, dims = [1, 0] : (tensor<160x128xf32>) -> tensor<128x160xf32>
    %181 = stablehlo.dot_general %16, %180, contracting_dims = [1] x [0] : (tensor<1x128xf32>, tensor<128x160xf32>) -> tensor<1x160xf32>
    %182 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<1xf32>) -> tensor<1x160xf32>
    %183:3 = enzyme.batch @"+_broadcast_scalar10"(%181, %182) {batch_shape = array<i64: 1, 160>} : (tensor<1x160xf32>, tensor<1x160xf32>) -> (tensor<1x160xf32>, tensor<1x160xf32>, tensor<1x160xf32>)
    %184 = stablehlo.transpose %183#0, dims = [1, 0] : (tensor<1x160xf32>) -> tensor<160x1xf32>
    %185 = stablehlo.reshape %184 : (tensor<160x1xf32>) -> tensor<5x32x1xf32>
    %186 = stablehlo.transpose %185, dims = [2, 1, 0] : (tensor<5x32x1xf32>) -> tensor<1x32x5xf32>
    %cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %187:2 = enzyme.batch @abs2_broadcast_scalar(%186) {batch_shape = array<i64: 1, 32, 5>} : (tensor<1x32x5xf32>) -> (tensor<1x32x5xf32>, tensor<1x32x5xf32>)
    %188 = stablehlo.reduce(%187#0 init: %cst_25) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x32x5xf32>, tensor<f32>) -> tensor<f32>
    %189 = stablehlo.transpose %188, dims = [] : (tensor<f32>) -> tensor<f32>
    %190 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %191 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %192 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %193 = stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %194 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %195 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %196 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %197 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %198 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %199 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %200 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %201 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %202 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %203 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %204 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %205 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %206 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %207 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %208 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    return %189, %190, %191, %192, %193, %194, %195, %196, %197, %198, %199, %200, %201, %202, %203, %204, %205, %206, %207, %208 : tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>
  }
  // Entry point of the reproducer. It returns 37 tensors: the first 18 are
  // derived from %57#19..%57#36 and the last 19 from %57#0..%57#18 (see the
  // return at the end of this function).
  //
  // Structure:
  //   1. Transpose every argument by reversing its dimension order (%0-%18).
  //   2. Materialize zero constants (%cst..%cst_16) and a scalar 1.0 (%cst_17).
  //   3. Transpose everything back (%19-%56), so the autodiff operands (except
  //      the scalar %38) have the same shapes as the @main arguments.
  //   4. Call enzyme.autodiff on the sumabs2first function (%57).
  //   5. Transpose each autodiff result twice (%58-%131) and return.
  // The paired transposes are net identities. They are kept verbatim because
  // this IR is a crash reproducer and the exact op sequence may matter.
  func.func @main(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
    // Step 1: reverse the dimension order of each argument
    // (1-D uses [0], 2-D uses [1, 0], 3-D uses [2, 1, 0]).
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %4 = stablehlo.transpose %arg4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %5 = stablehlo.transpose %arg5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %6 = stablehlo.transpose %arg6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %7 = stablehlo.transpose %arg7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %10 = stablehlo.transpose %arg10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %12 = stablehlo.transpose %arg12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %13 = stablehlo.transpose %arg13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %14 = stablehlo.transpose %arg14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %15 = stablehlo.transpose %arg15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %16 = stablehlo.transpose %arg16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %17 = stablehlo.transpose %arg17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %18 = stablehlo.transpose %arg18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    // Step 2: one zero-filled constant per value %0-%17, each with that value's
    // shape. The complex<f32> zeros (%cst_3, %cst_6, %cst_9, %cst_12) pair with
    // %4, %7, %10, %13. %cst_17 is the scalar 1.0.
    // NOTE(review): the debug-build assertion reported for this IR fires in
    // Type::getIntOrFloatBitWidth, which accepts only int/float types, so
    // complex-typed splats are a plausible trigger. Confirm which op hits it.
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x2xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_3 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_9 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<64x64xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_12 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<64x64x16xcomplex<f32>>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<128x64xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<1x128xf32>
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<1xf32>
    %cst_17 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    // Step 3a: transpose %0-%18 back to the @main argument layout. Each pair
    // with step 1 is a net identity.
    %19 = stablehlo.transpose %0, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %20 = stablehlo.transpose %1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %21 = stablehlo.transpose %2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %22 = stablehlo.transpose %3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %23 = stablehlo.transpose %4, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %24 = stablehlo.transpose %5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %25 = stablehlo.transpose %6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %26 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %27 = stablehlo.transpose %8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %28 = stablehlo.transpose %9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %29 = stablehlo.transpose %10, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %30 = stablehlo.transpose %11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %31 = stablehlo.transpose %12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %32 = stablehlo.transpose %13, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %33 = stablehlo.transpose %14, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %34 = stablehlo.transpose %15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %35 = stablehlo.transpose %16, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %36 = stablehlo.transpose %17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %37 = stablehlo.transpose %18, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    // Step 3b: the scalar 1.0 (%38) and the 18 zero constants (%39-%56),
    // transposed into the layout of the @main arguments %arg0-%arg17.
    %38 = stablehlo.transpose %cst_17, dims = [] : (tensor<f32>) -> tensor<f32>
    %39 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %40 = stablehlo.transpose %cst_0, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %41 = stablehlo.transpose %cst_1, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %42 = stablehlo.transpose %cst_2, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %43 = stablehlo.transpose %cst_3, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %44 = stablehlo.transpose %cst_4, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %45 = stablehlo.transpose %cst_5, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %46 = stablehlo.transpose %cst_6, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %47 = stablehlo.transpose %cst_7, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %48 = stablehlo.transpose %cst_8, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %49 = stablehlo.transpose %cst_9, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %50 = stablehlo.transpose %cst_10, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %51 = stablehlo.transpose %cst_11, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %52 = stablehlo.transpose %cst_12, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %53 = stablehlo.transpose %cst_13, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %54 = stablehlo.transpose %cst_14, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %55 = stablehlo.transpose %cst_15, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %56 = stablehlo.transpose %cst_16, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 4: the enzyme.autodiff call.
    // Operands: %19-%37 are the 19 primal inputs. Their activity list is 18
    // x enzyme_active, then enzyme_const for the 5x32x2 input %37.
    // %38 (1.0) and %39-%56 (zeros) follow. These are presumably the
    // differential seeds for the active entries of ret_activity: the leading
    // tensor<f32> result (enzyme_activenoneed) and the 18 enzyme_active results.
    // TODO confirm this operand layout against the Enzyme MLIR autodiff op docs.
    // Results: %57#0-#18 have the same types as %19-%37. %57#19-#36 have the
    // same types as the 18 active inputs %19-%36 and are presumably their
    // gradients.
    %57:37 = enzyme.autodiff @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>, 
tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>) -> (tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>)
    // Step 5a: first transpose of every autodiff result (reverse dimension
    // order). %58-%76 come from %57#0-#18; %77-%94 come from %57#19-#36.
    %58 = stablehlo.transpose %57#0, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %59 = stablehlo.transpose %57#1, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %60 = stablehlo.transpose %57#2, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %61 = stablehlo.transpose %57#3, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %62 = stablehlo.transpose %57#4, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %63 = stablehlo.transpose %57#5, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %64 = stablehlo.transpose %57#6, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %65 = stablehlo.transpose %57#7, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %66 = stablehlo.transpose %57#8, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %67 = stablehlo.transpose %57#9, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %68 = stablehlo.transpose %57#10, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %69 = stablehlo.transpose %57#11, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %70 = stablehlo.transpose %57#12, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %71 = stablehlo.transpose %57#13, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %72 = stablehlo.transpose %57#14, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %73 = stablehlo.transpose %57#15, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %74 = stablehlo.transpose %57#16, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %75 = stablehlo.transpose %57#17, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %76 = stablehlo.transpose %57#18, dims = [2, 1, 0] : (tensor<5x32x2xf32>) -> tensor<2x32x5xf32>
    %77 = stablehlo.transpose %57#19, dims = [1, 0] : (tensor<2x64xf32>) -> tensor<64x2xf32>
    %78 = stablehlo.transpose %57#20, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %79 = stablehlo.transpose %57#21, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %80 = stablehlo.transpose %57#22, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %81 = stablehlo.transpose %57#23, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %82 = stablehlo.transpose %57#24, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %83 = stablehlo.transpose %57#25, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %84 = stablehlo.transpose %57#26, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %85 = stablehlo.transpose %57#27, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %86 = stablehlo.transpose %57#28, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %87 = stablehlo.transpose %57#29, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %88 = stablehlo.transpose %57#30, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %89 = stablehlo.transpose %57#31, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %90 = stablehlo.transpose %57#32, dims = [2, 1, 0] : (tensor<16x64x64xcomplex<f32>>) -> tensor<64x64x16xcomplex<f32>>
    %91 = stablehlo.transpose %57#33, dims = [1, 0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
    %92 = stablehlo.transpose %57#34, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %93 = stablehlo.transpose %57#35, dims = [1, 0] : (tensor<128x1xf32>) -> tensor<1x128xf32>
    %94 = stablehlo.transpose %57#36, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 5b: transpose %77-%94 back (net identity with step 5a). These
    // become the first 18 return values.
    %95 = stablehlo.transpose %77, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %96 = stablehlo.transpose %78, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %97 = stablehlo.transpose %79, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %98 = stablehlo.transpose %80, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %99 = stablehlo.transpose %81, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %100 = stablehlo.transpose %82, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %101 = stablehlo.transpose %83, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %102 = stablehlo.transpose %84, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %103 = stablehlo.transpose %85, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %104 = stablehlo.transpose %86, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %105 = stablehlo.transpose %87, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %106 = stablehlo.transpose %88, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %107 = stablehlo.transpose %89, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %108 = stablehlo.transpose %90, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %109 = stablehlo.transpose %91, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %110 = stablehlo.transpose %92, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %111 = stablehlo.transpose %93, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %112 = stablehlo.transpose %94, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    // Step 5c: transpose %58-%76 back (net identity with step 5a). These
    // become the last 19 return values.
    %113 = stablehlo.transpose %58, dims = [1, 0] : (tensor<64x2xf32>) -> tensor<2x64xf32>
    %114 = stablehlo.transpose %59, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %115 = stablehlo.transpose %60, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %116 = stablehlo.transpose %61, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %117 = stablehlo.transpose %62, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %118 = stablehlo.transpose %63, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %119 = stablehlo.transpose %64, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %120 = stablehlo.transpose %65, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %121 = stablehlo.transpose %66, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %122 = stablehlo.transpose %67, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %123 = stablehlo.transpose %68, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %124 = stablehlo.transpose %69, dims = [1, 0] : (tensor<64x64xf32>) -> tensor<64x64xf32>
    %125 = stablehlo.transpose %70, dims = [0] : (tensor<64xf32>) -> tensor<64xf32>
    %126 = stablehlo.transpose %71, dims = [2, 1, 0] : (tensor<64x64x16xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>
    %127 = stablehlo.transpose %72, dims = [1, 0] : (tensor<128x64xf32>) -> tensor<64x128xf32>
    %128 = stablehlo.transpose %73, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %129 = stablehlo.transpose %74, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %130 = stablehlo.transpose %75, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %131 = stablehlo.transpose %76, dims = [2, 1, 0] : (tensor<2x32x5xf32>) -> tensor<5x32x2xf32>
    // Return order: %95-%112 (from %57#19-#36), then %113-%131 (from %57#0-#18).
    return %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127, %128, %129, %130, %131 : tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>
  }
}

Error message from a debug build:

enzymexlamlir-opt: external/llvm-project/mlir/lib/IR/Types.cpp:134: unsigned int mlir::Type::getIntOrFloatBitWidth() const: Assertion `isIntOrFloat() && "only integers and floats have a bitwidth"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: ./bazel-bin/enzymexlamlir-opt --enzyme-hlo-opt --enzyme-batch --enzyme envs/fno.mlir
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  enzymexlamlir-opt 0x000064943eed9fce
1  enzymexlamlir-opt 0x000064943eeda383
2  enzymexlamlir-opt 0x000064943eed7a84
3  enzymexlamlir-opt 0x000064943eed9af3
4  libc.so.6         0x00007709d944c1d0
5  libc.so.6         0x00007709d94a53f4
6  libc.so.6         0x00007709d944c120 gsignal + 32
7  libc.so.6         0x00007709d94334c3 abort + 223
8  libc.so.6         0x00007709d94333df
9  libc.so.6         0x00007709d9444177
10 enzymexlamlir-opt 0x000064943eba05c0
11 enzymexlamlir-opt 0x000064943b65a7c6
12 enzymexlamlir-opt 0x000064943b5a3f09
13 enzymexlamlir-opt 0x000064943b6c2f28
14 enzymexlamlir-opt 0x000064943e5ccfeb
15 enzymexlamlir-opt 0x000064943e5cda42
16 enzymexlamlir-opt 0x000064943af1b4d8
17 enzymexlamlir-opt 0x000064943e5d0d27
18 enzymexlamlir-opt 0x000064943e5cd769
19 enzymexlamlir-opt 0x000064943e5be495
20 enzymexlamlir-opt 0x000064943e5bf631
21 enzymexlamlir-opt 0x000064943e5c094b
22 enzymexlamlir-opt 0x000064943af1b4d8
23 enzymexlamlir-opt 0x000064943e5c06cf
24 enzymexlamlir-opt 0x000064943e5bf8fc
25 enzymexlamlir-opt 0x000064943e5bfab6
26 enzymexlamlir-opt 0x000064943b3e5548
27 enzymexlamlir-opt 0x000064943b5b2f6a
28 enzymexlamlir-opt 0x000064943ea6f1f9
29 enzymexlamlir-opt 0x000064943ea72eda
30 enzymexlamlir-opt 0x000064943af1b4d8
31 enzymexlamlir-opt 0x000064943ea7895f
32 enzymexlamlir-opt 0x000064943ea6f617
33 enzymexlamlir-opt 0x000064943ea6f959
34 enzymexlamlir-opt 0x000064943ea7180a
35 enzymexlamlir-opt 0x000064943ea715f9
36 enzymexlamlir-opt 0x000064943aefe9fb
37 enzymexlamlir-opt 0x000064943aeff1dd
38 enzymexlamlir-opt 0x000064943aeff8d0
39 enzymexlamlir-opt 0x000064943af00ae7
40 enzymexlamlir-opt 0x000064943edc5a42
41 enzymexlamlir-opt 0x000064943edc5198
42 enzymexlamlir-opt 0x000064943aeffa5e
43 enzymexlamlir-opt 0x000064943aeffd6c
44 enzymexlamlir-opt 0x000064943aeffffd
45 enzymexlamlir-opt 0x000064943aeb2642
46 libc.so.6         0x00007709d9434e08
47 libc.so.6         0x00007709d9434ecc __libc_start_main + 140
48 enzymexlamlir-opt 0x000064943aeb20b5
[1]    100225 IOT instruction (core dumped)  ./bazel-bin/enzymexlamlir-opt --enzyme-hlo-opt --enzyme-batch --enzyme 
@avik-pal
Copy link
Collaborator Author

xref: SciML/NeuralOperators.jl#52

@avik-pal
Copy link
Collaborator Author

Not yet fixed:

[602122] signal 11 (1): Segmentation fault
in expression starting at REPL[21]:1
unknown function (ip: 0x7f6f7e945b98)
_ZL8readBitsPKcmm at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZNK4mlir17DenseElementsAttr18IntElementIteratordeEv at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZNK4mlir17DenseElementsAttr13getSplatValueIN4llvm5APIntEEENSt9enable_ifIXoontsrSt10is_base_ofINS_9AttributeET_E5valuesrSt7is_sameIS6_S7_E5valueES7_E4typeEv at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZNK12_GLOBAL__N_111AddSimplify15matchAndRewriteEN4mlir9stablehlo5AddOpERNS1_15PatternRewriterE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZZN4mlir17PatternApplicator15matchAndRewriteEPNS_9OperationERNS_15PatternRewriterEN4llvm12function_refIFbRKNS_7PatternEEEENS6_IFvS9_EEENS6_IFNS5_13LogicalResultES9_EEEENKUlvE_clEv at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir17PatternApplicator15matchAndRewriteEPNS_9OperationERNS_15PatternRewriterEN4llvm12function_refIFbRKNS_7PatternEEEENS6_IFvS9_EEENS6_IFNS5_13LogicalResultES9_EEE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_126GreedyPatternRewriteDriver15processWorklistEv at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir28applyPatternsAndFoldGreedilyERNS_6RegionERKNS_23FrozenRewritePatternSetENS_19GreedyRewriteConfigEPb at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15ApplyPatternsOp10applyToOneERNS0_17TransformRewriterEPNS_9OperationERNS0_21ApplyToEachResultListERNS0_14TransformStateE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail20applyTransformToEachINS0_15ApplyPatternsOpERN4llvm14iterator_rangeINS4_20filter_iterator_implIPKPNS_9OperationEZNKS0_14TransformState13getPayloadOpsENS_5ValueEEUlS8_E_St26bidirectional_iterator_tagEEEEEENS_27DiagnosedSilenceableFailureET_RNS0_17TransformRewriterEOT0_RNS4_15SmallVectorImplINS0_21ApplyToEachResultListEEERSB_ at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformEachOpTraitINS0_15ApplyPatternsOpEE5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail35TransformOpInterfaceInterfaceTraits5ModelINS0_15ApplyPatternsOpEE5applyEPKNS2_7ConceptEPNS_9OperationERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformOpInterface5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform14TransformState14applyTransformENS0_20TransformOpInterfaceE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZL18applySequenceBlockRN4mlir5BlockENS_9transform22FailurePropagationModeERNS2_14TransformStateERNS2_16TransformResultsE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15NamedSequenceOp5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform6detail35TransformOpInterfaceInterfaceTraits5ModelINS0_15NamedSequenceOpEE5applyEPKNS2_7ConceptEPNS_9OperationERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform20TransformOpInterface5applyERNS0_17TransformRewriterERNS0_16TransformResultsERNS0_14TransformStateE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform14TransformState14applyTransformENS0_20TransformOpInterfaceE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform15applyTransformsEPNS_9OperationENS0_20TransformOpInterfaceERKNS_11RaggedArrayIN4llvm12PointerUnionIJS2_NS_9AttributeENS_5ValueEEEEEERKNS0_16TransformOptionsEbNS5_12function_refIFvRNS0_14TransformStateEEEENSG_IFNS5_13LogicalResultESI_EEE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir9transform27applyTransformNamedSequenceENS_11RaggedArrayIN4llvm12PointerUnionIJPNS_9OperationENS_9AttributeENS_5ValueEEEEEENS0_20TransformOpInterfaceENS_8ModuleOpERKNS0_16TransformOptionsE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_115InterpreterPass14runOnOperationEv at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor3runEPNS_4PassEPNS_9OperationENS_15AnalysisManagerEbj at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor11runPipelineERNS_13OpPassManagerEPNS_9OperationENS_15AnalysisManagerEbjPNS_16PassInstrumentorEPKNS_19PassInstrumentation18PipelineParentInfoE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager9runPassesEPNS_9OperationENS_15AnalysisManagerE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager3runEPNS_9OperationE at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/.julia/artifacts/a7d008f5ba52e8657b34453e594bdc0e79c1f11d/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/software/lux/Reactant.jl/src/mlir/libMLIR_h.jl:5853 [inlined]
run! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:74 [inlined]
#run_pass_pipeline!#1 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:267
run_pass_pipeline! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:262 [inlined]
#compile_mlir!#8 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:313
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:293 [inlined]
#6 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:288 [inlined]
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
unknown function (ip: 0x7f6ebc0a9f36)
#compile_mlir#5 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:286
compile_mlir at /mnt/software/lux/Reactant.jl/src/Compiler.jl:283
unknown function (ip: 0x7f6ebc0a822d)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:625
eval_body at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:539
jl_interpret_toplevel_thunk at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval at ./Base.jl:130 [inlined]
repleval at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:229
#112 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:192 [inlined]
with_logstate at ./logging/logging.jl:522
with_logger at ./logging/logging.jl:632 [inlined]
#111 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:193
unknown function (ip: 0x7f6ebc910abf)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052
unknown function (ip: 0x7f6f122ff822)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/builtins.c:831
#64 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:34
unknown function (ip: 0x7f6f12367acf)
jl_apply at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
start_task at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/task.c:1202
Allocations: 178283979 (Pool: 178276956; Big: 7023); GC: 105
[1]    602122 segmentation fault (core dumped)  julia --project=docs --threads=auto --check-bounds=yes

@avik-pal
Copy link
Collaborator Author

avik-pal commented Jan 2, 2025

(base) ➜  Enzyme-JAX git:(ap/common_simplifications) ./bazel-bin/enzymexlamlir-opt --enzyme-hlo-opt --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --enzyme --arith-raise  envs/fno_ad.mlir 
envs/fno_ad.mlir:2:177: error: 'complex.add' op operand #0 must be complex type with floating-point elements, but got 'tensor<16x64x64xcomplex<f32>>'
  func.func private @"Const{typeof(sumabs2first)}(Main.sumabs2first)_autodiff"(%arg0: tensor<2x64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64xf32>, %arg4: tensor<16x64x64xcomplex<f32>>, %arg5: tensor<64x64xf32>, %arg6: tensor<64xf32>, %arg7: tensor<16x64x64xcomplex<f32>>, %arg8: tensor<64x64xf32>, %arg9: tensor<64xf32>, %arg10: tensor<16x64x64xcomplex<f32>>, %arg11: tensor<64x64xf32>, %arg12: tensor<64xf32>, %arg13: tensor<16x64x64xcomplex<f32>>, %arg14: tensor<64x128xf32>, %arg15: tensor<128xf32>, %arg16: tensor<128x1xf32>, %arg17: tensor<1xf32>, %arg18: tensor<5x32x2xf32>) -> (tensor<f32>, tensor<2x64xf32>, tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x64xf32>, tensor<64xf32>, tensor<16x64x64xcomplex<f32>>, tensor<64x128xf32>, tensor<128xf32>, tensor<128x1xf32>, tensor<1xf32>, tensor<5x32x2xf32>) {
                                                                                                                                                                                ^
envs/fno_ad.mlir:2:177: note: see current operation: %504 = "complex.add"(%503, %arg24) <{fastmath = #arith.fastmath<none>}> : (tensor<16x64x64xcomplex<f32>>, tensor<16x64x64xcomplex<f32>>) -> tensor<16x64x64xcomplex<f32>>

Where is `complex.add` coming from?

@wsmoses
Copy link
Member

wsmoses commented Jan 2, 2025

The Enzyme pass creates the `complex.add` op. Running the `arith-raise{stablehlo=true}` pass right after it will raise the op to StableHLO and fix the error.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants