Infinite recursion on unique(::Vector{Symbol}) within Reactant #493

Closed
mofeing opened this issue Jan 7, 2025 · 4 comments · Fixed by #505
mofeing (Collaborator) commented Jan 7, 2025

Reproducer

using Reactant

Reactant.compile(()) do
    x = [:a, :b, :c]
    unique(x)
end

Error

ERROR: StackOverflowError:
Stacktrace:
      [1] _unique_dims
        @ ./multidimensional.jl:1726 [inlined]
      [2] unique
        @ ./multidimensional.jl:1724 [inlined]
      [3] call_with_reactant(::typeof(unique), ::Vector{Symbol})
        @ Reactant ~/.julia/packages/Reactant/oG9qp/src/utils.jl:0
      [4] _unique_dims
        @ ./multidimensional.jl:1726 [inlined]
      [5] unique
        @ ./multidimensional.jl:1724 [inlined]
      [6] unique(none::Vector{Symbol})
        @ Reactant ./<missing>:0
--- the above 6 lines are repeated 19956 more times ---
 [119743] _unique_dims
        @ ./multidimensional.jl:1726 [inlined]
 [119744] unique
        @ ./multidimensional.jl:1724 [inlined]
 [119745] call_with_reactant(::typeof(unique), ::Vector{Symbol})
        @ Reactant ~/.julia/packages/Reactant/oG9qp/src/utils.jl:0
 [119746] _unique_dims
        @ ./multidimensional.jl:1726 [inlined]
 [119747] unique
        @ ./multidimensional.jl:1724 [inlined]
 [119748] #21
        @ ./REPL[20]:3 [inlined]
 [119749] (::var"#21#22")()
        @ Reactant ./<missing>:0
 [119750] GenericMemory
        @ ./boot.jl:516 [inlined]
 [119751] Array
        @ ./boot.jl:578 [inlined]
 [119752] vect
        @ ./array.jl:161 [inlined]
 [119753] #21
        @ ./REPL[20]:2 [inlined]
 [119754] call_with_reactant(redub_arguments#232::var"#21#22")
        @ Reactant ~/.julia/packages/Reactant/oG9qp/src/utils.jl:0
 [119755] make_mlir_fn(f::Function, args::Tuple{}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
        @ Reactant.TracedUtils ~/.julia/packages/Reactant/oG9qp/src/TracedUtils.jl:184
 [119756] make_mlir_fn
        @ ~/.julia/packages/Reactant/oG9qp/src/TracedUtils.jl:86 [inlined]
 [119757] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{}; optimize::Bool, no_nan::Bool)
        @ Reactant.Compiler ~/.julia/packages/Reactant/oG9qp/src/Compiler.jl:348
 [119758] compile_mlir!
        @ ~/.julia/packages/Reactant/oG9qp/src/Compiler.jl:339 [inlined]
 [119759] compile_xla(f::Function, args::Tuple{}; client::Nothing, optimize::Bool, no_nan::Bool)
        @ Reactant.Compiler ~/.julia/packages/Reactant/oG9qp/src/Compiler.jl:844
 [119760] compile_xla
        @ ~/.julia/packages/Reactant/oG9qp/src/Compiler.jl:835 [inlined]
 [119761] compile(f::Function, args::Tuple{}; client::Nothing, optimize::Bool, sync::Bool, no_nan::Bool)
        @ Reactant.Compiler ~/.julia/packages/Reactant/oG9qp/src/Compiler.jl:870

Log from debug interpreter

"args (var\"#23#24\",)"
"ir 2 1 ─ %1  = \$(Expr(:foreigncall, :(:jl_alloc_genericmemory), Ref{Memory{Symbol}}, svec(Any, Int64), 0, :(:ccall), Memory{Symbol}, 3, 3))::Memory{Symbol}\n  │   %2  = Core.memoryrefnew(%1)::MemoryRef{Symbol}\n  └── %3  = %new(Vector{Symbol}, %2, (3,))::Vector{Symbol}\n  2 ┄ %4  = φ (#1 => 1, #6 => %15)::Int64\n  │   %5  = φ (#1 => 1, #6 => %16)::Int64\n  │   %6  = Base.getfield((:a, :b, :c), %4, false)::Symbol\n  │   %7  = Base.getfield(%3, :ref)::MemoryRef{Symbol}\n  │   %8  = Base.memoryrefnew(%7, %4, false)::MemoryRef{Symbol}\n  │         Base.memoryrefset!(%8, %6, :not_atomic, false)::Symbol\n  │   %10 = (%5 === 3)::Bool\n  └──       goto #4 if not %10\n  3 ─       goto #5\n  4 ─ %13 = Base.add_int(%5, 1)::Int64\n  └──       goto #5\n  5 ┄ %15 = φ (#4 => %13)::Int64\n  │   %16 = φ (#4 => %13)::Int64\n  │   %17 = φ (#3 => true, #4 => false)::Bool\n  │   %18 = Base.not_int(%17)::Bool\n  └──       goto #7 if not %18\n  6 ─       goto #2\n  7 ─       goto #8\n3 8 ─ %22 = invoke Base.unique(%3::Vector{Symbol})::Vector{Symbol}\n  └──       return %22\n  "
"src CodeInfo(\n    @ REPL[22]:2 within `#23`\n   ┌ @ array.jl:161 within `vect`\n   │┌ @ boot.jl:578 within `Array`\n   ││┌ @ boot.jl:516 within `GenericMemory`\n1 ─│││ %1  = \$(Expr(:foreigncall, :(:jl_alloc_genericmemory), Ref{Memory{Symbol}}, svec(Any, Int64), 0, :(:ccall), Memory{Symbol}, 3, 3))::Memory{Symbol}\n│  ││└\n│  ││ @ boot.jl:579 within `Array`\n│  ││┌ @ boot.jl:522 within `memoryref`\n│  │││ %2  = Core.memoryrefnew(%1)::MemoryRef{Symbol}\n│  ││└\n└──││ %3  = %new(Vector{Symbol}, %2, (3,))::Vector{Symbol}\n   │└\n   │ @ array.jl:162 within `vect`\n2 ┄│ %4  = φ (#1 => 1, #6 => %15)::Int64\n│  │ %5  = φ (#1 => 1, #6 => %16)::Int64\n│  │ @ array.jl:163 within `vect`\n│  │┌ @ tuple.jl:33 within `__safe_getindex`\n│  ││ %6  = Base.getfield((:a, :b, :c), %4, false)::Symbol\n│  │└\n│  │┌ @ array.jl:1001 within `__safe_setindex!`\n│  ││┌ @ Base.jl:49 within `getproperty`\n│  │││ %7  = Base.getfield(%3, :ref)::MemoryRef{Symbol}\n│  ││└\n│  ││ %8  = Base.memoryrefnew(%7, %4, false)::MemoryRef{Symbol}\n│  ││       Base.memoryrefset!(%8, %6, :not_atomic, false)::Symbol\n│  │└\n│  │ @ array.jl:164 within `vect`\n│  │┌ @ range.jl:908 within `iterate`\n│  ││┌ @ promotion.jl:639 within `==`\n│  │││ %10 = (%5 === 3)::Bool\n│  ││└\n└──││       goto #4 if not %10\n3 ─││       goto #5\n   ││ @ range.jl:909 within `iterate`\n   ││┌ @ int.jl:87 within `+`\n4 ─│││ %13 = Base.add_int(%5, 1)::Int64\n└──│││       goto #5\n   │└└\n5 ┄│ %15 = φ (#4 => %13)::Int64\n│  │ %16 = φ (#4 => %13)::Int64\n│  │ %17 = φ (#3 => true, #4 => false)::Bool\n│  │ %18 = Base.not_int(%17)::Bool\n└──│       goto #7 if not %18\n6 ─│       goto #2\n   │ @ array.jl:165 within `vect`\n7 ─│       goto #8\n   └\n    @ REPL[22]:3 within `#23`\n   ┌ @ multidimensional.jl:1724 within `unique`\n   │┌ @ multidimensional.jl:1724 within `#unique#715`\n   ││┌ @ multidimensional.jl:1726 within `_unique_dims`\n8 ─│││ %22 = invoke Reactant.call_with_reactant(Base.unique::typeof(unique), 
%3::Vector{Symbol})::Any\n└──│││       return %22\n   └└└\n)"
"code_info CodeInfo(\n    @ REPL[22]:2 within `#23`\n   ┌ @ array.jl:161 within `vect`\n   │┌ @ boot.jl:578 within `Array`\n   ││┌ @ boot.jl:516 within `GenericMemory`\n1 ─│││ %1 = Core.getfield(##redub_arguments#232, 1)\n│  │││      (Reactant.safe_print)(\"fn arg[1]\", %1)\n│  │││ %3 = (Reactant.make_oc_ref)(Base.RefValue{Core.OpaqueClosure}(#undef), Tuple{}, Vector{Symbol}, CodeInfo(\n    @ REPL[22]:2 within `#23`\n   ┌ @ array.jl:161 within `vect`\n   │┌ @ boot.jl:578 within `Array`\n   ││┌ @ boot.jl:516 within `GenericMemory`\n1 ─│││ %1  = \$(Expr(:foreigncall, :(:jl_alloc_genericmemory), Ref{Memory{Symbol}}, svec(Any, Int64), 0, :(:ccall), Memory{Symbol}, 3, 3))::Memory{Symbol}\n│  ││└\n│  ││ @ boot.jl:579 within `Array`\n│  ││┌ @ boot.jl:522 within `memoryref`\n│  │││ %2  = Core.memoryrefnew(%1)::MemoryRef{Symbol}\n│  ││└\n└──││ %3  = %new(Vector{Symbol}, %2, (3,))::Vector{Symbol}\n   │└\n   │ @ array.jl:162 within `vect`\n2 ┄│ %4  = φ (#1 => 1, #6 => %15)::Int64\n│  │ %5  = φ (#1 => 1, #6 => %16)::Int64\n│  │ @ array.jl:163 within `vect`\n│  │┌ @ tuple.jl:33 within `__safe_getindex`\n│  ││ %6  = Base.getfield((:a, :b, :c), %4, false)::Symbol\n│  │└\n│  │┌ @ array.jl:1001 within `__safe_setindex!`\n│  ││┌ @ Base.jl:49 within `getproperty`\n│  │││ %7  = Base.getfield(%3, :ref)::MemoryRef{Symbol}\n│  ││└\n│  ││ %8  = Base.memoryrefnew(%7, %4, false)::MemoryRef{Symbol}\n│  ││       Base.memoryrefset!(%8, %6, :not_atomic, false)::Symbol\n│  │└\n│  │ @ array.jl:164 within `vect`\n│  │┌ @ range.jl:908 within `iterate`\n│  ││┌ @ promotion.jl:639 within `==`\n│  │││ %10 = (%5 === 3)::Bool\n│  ││└\n└──││       goto #4 if not %10\n3 ─││       goto #5\n   ││ @ range.jl:909 within `iterate`\n   ││┌ @ int.jl:87 within `+`\n4 ─│││ %13 = Base.add_int(%5, 1)::Int64\n└──│││       goto #5\n   │└└\n5 ┄│ %15 = φ (#4 => %13)::Int64\n│  │ %16 = φ (#4 => %13)::Int64\n│  │ %17 = φ (#3 => true, #4 => false)::Bool\n│  │ %18 = Base.not_int(%17)::Bool\n└──│       goto #7 if not %18\n6 
─│       goto #2\n   │ @ array.jl:165 within `vect`\n7 ─│       goto #8\n   └\n    @ REPL[22]:3 within `#23`\n   ┌ @ multidimensional.jl:1724 within `unique`\n   │┌ @ multidimensional.jl:1724 within `#unique#715`\n   ││┌ @ multidimensional.jl:1726 within `_unique_dims`\n8 ─│││ %22 = invoke Reactant.call_with_reactant(Base.unique::typeof(unique), %3::Vector{Symbol})::Any\n└──│││       return %22\n   └└└\n), 0, false, %1)\n│  │││ %4 = (%3)()\n│  │││      (Reactant.safe_print)(\"ocres\", %4)\n└──│││      return %4\n   └└└\n)"
"fn arg[1] #23"

mofeing (Collaborator, Author) commented Jan 8, 2025

Okay, if I replace the code with this:

x = [:a, :b, :c]
Reactant.compile(()) do
    unique(x)
end

then the generated IR is smaller and the error still reproduces:

line 1

args (var\"#15#16\",)

line 2

src CodeInfo(
    @ REPL[26]:2 within `#15`
1 ─ %1 = Main.x::Any
│   %2 = (Reactant.call_with_reactant)(Main.unique, %1)::Any
└──      return %2
)

line 3

code_info CodeInfo(
    @ REPL[26]:2 within `#15`
1 ─ %1 = Core.getfield(##redub_arguments#232, 1)
│        (Reactant.safe_print)("fn arg[1]", %1)
│   %3 = (Reactant.make_oc_ref)(Base.RefValue{Core.OpaqueClosure}(#undef), Tuple{}, Any, CodeInfo(
    @ REPL[26]:2 within `#15`
1 ─ %1 = Main.x::Any
│   %2 = (Reactant.call_with_reactant)(Main.unique, %1)::Any
└──      return %2
), 0, false, %1)
│   %4 = (%3)()
│        (Reactant.safe_print)("ocres", %4)
└──      return %4
)

line 4

fn arg[1] #15
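In the reduced reproducer, x is a non-constant global, so inference can only type the read of Main.x as Any. That matches the IR above, where the unique call is already routed through call_with_reactant. The sketch below is plain Julia (no Reactant) and only shows that inference difference; the names xs_global, XS_CONST, read_untyped_global and read_const_global are hypothetical, chosen for illustration.

```julia
# Hypothetical names; plain Julia, no Reactant needed.

# Non-constant, untyped global, like `x` in the reduced reproducer
xs_global = [:a, :b, :c]
read_untyped_global() = xs_global

# Constant global for comparison: inference knows its concrete type
const XS_CONST = [:a, :b, :c]
read_const_global() = XS_CONST

# Base.return_types runs Julia's native type inference on a call signature
untyped_rt = only(Base.return_types(read_untyped_global, ()))
const_rt = only(Base.return_types(read_const_global, ()))

println("untyped global: ", untyped_rt)
println("const global:   ", const_rt)
```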

mofeing (Collaborator, Author) commented Jan 8, 2025

Reactant.jl/src/utils.jl

Lines 444 to 450 in 47f363b

# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter
# In particular this entails two pieces:
# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance
# 2) Post type inference (using of course the reactant interpreter), all type unstable call functions are
# replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia
# using a custom interpreter in type unstable code.
# `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)`

This says that calls to type-unstable code are replaced with call_with_reactant, but unique(::Vector{Symbol}) looks type-stable to me...
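One way to check that impression outside Reactant is to ask Julia's native inference directly. This is a plain-Julia sketch and says nothing definitive about Reactant's interpreter, although the ir dump in the first comment also shows invoke Base.unique(%3::Vector{Symbol})::Vector{Symbol}.

```julia
# Julia's native inference for unique on a Vector{Symbol} (no Reactant involved)
rts = Base.return_types(unique, (Vector{Symbol},))
println(rts)

# Outside Reactant the call itself behaves normally
result = unique([:a, :b, :c, :a])
println(result)
```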

@wsmoses could it be that it's rewriting a call that it shouldn't?

My bet is that...

  1. unique(::Vector{Symbol}) calls unique(::Colon, ::typeof(unique), ::Vector{Symbol})
  2. ...which calls _unique_dims(::Vector{Symbol}, ::Core.Const(Colon()))
  3. ...which ultimately calls invoke(::Core.Const(unique), Tuple{Any}, ::Vector{Symbol})

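and that rewrite_inst within call_with_reactant_generator is replacing this last invoke with call_with_reactant(unique, x). That drops the Tuple{Any} signature the invoke was pinned to, so dispatch selects the same unique method as in step 1 again, making it an infinite recursion 🫠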
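The suspected failure mode can be modelled in plain Julia without Reactant. In Base, _unique_dims(A, ::Colon) uses invoke(unique, Tuple{Any}, A) to skip the AbstractArray method and reach the generic iterator fallback; if that invoke becomes an ordinary call, dispatch picks the array method again. The functions below (toy_unique, rewritten_unique and their helpers) are hypothetical stand-ins, not Base or Reactant code, and the "rewrite" is done by hand rather than by rewrite_inst.

```julia
# Hypothetical stand-ins; none of these functions exist in Base or Reactant.

# Generic fallback, in the role of Base's `unique(itr)`
function toy_unique(itr)
    T = eltype(itr)
    out = T[]
    seen = Set{T}()
    for x in itr
        if !(x in seen)
            push!(seen, x)
            push!(out, x)
        end
    end
    return out
end

# Array method, in the role of `unique(A::AbstractArray; dims=:)`
toy_unique(A::AbstractArray) = toy_unique_dims(A)

# In the role of `_unique_dims(A, ::Colon)`: `invoke` forces the `Tuple{Any}`
# method, so the array method above is not selected a second time
toy_unique_dims(A::AbstractArray) = invoke(toy_unique, Tuple{Any}, A)

# The same pair after the suspected rewrite: the `invoke` became an ordinary call,
# so dispatch picks the array method again and the two functions call each other forever
rewritten_unique(A::AbstractArray) = rewritten_unique_dims(A)
rewritten_unique_dims(A::AbstractArray) = rewritten_unique(A)

deduped = toy_unique([:a, :b, :a, :c])
println(deduped)

overflowed = try
    rewritten_unique([:a, :b, :c])
    false
catch err
    err isa StackOverflowError
end
println("rewritten version overflowed the stack: ", overflowed)
```

The model only reproduces the shape of the stack trace (unique, _unique_dims, unique, ...); it does not show which instruction Reactant actually rewrites.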

mofeing (Collaborator, Author) commented Jan 15, 2025

I finally fixed it by doing something similar to what @avik-pal did for other Base methods: create a method with the same signature in Reactant's method table with @reactant_overlay, mark it as @noinline, and inside it call the original method through Base.inferencebarrier.
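The overlay itself depends on Reactant internals (@reactant_overlay and Reactant's method table) and is not reproduced here. The snippet below only illustrates the other ingredient in plain Julia: Base.inferencebarrier, an unexported Base helper, returns its argument unchanged but hides its type from inference, so a call made through it is inferred as Any. Presumably that is what makes Reactant treat the call to the original method as type-unstable, per the utils.jl comment quoted earlier, but the thread does not spell the mechanism out. direct_sum and hidden_sum are hypothetical names.

```julia
# Hypothetical names; plain Julia, no Reactant needed.

# Ordinary call: inference can see that `sum` is being called
direct_sum(v) = sum(v)

# Same call routed through `Base.inferencebarrier`: the callee's type is hidden
hidden_sum(v) = Base.inferencebarrier(sum)(v)

v = [1, 2, 3]
println(direct_sum(v), " ", hidden_sum(v))  # same value at runtime

# Inference sees through the direct call, but not through the barrier
direct_rt = only(Base.return_types(direct_sum, (Vector{Int},)))
hidden_rt = only(Base.return_types(hidden_sum, (Vector{Int},)))
println(direct_rt, " vs ", hidden_rt)
```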
