diff --git a/src/utils.jl b/src/utils.jl index 9b751d14f..52d1a1e61 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -148,7 +148,8 @@ function should_rewrite_ft(@nospecialize(ft)) ft <: typeof(Base.typemin) || ft <: typeof(Base.getproperty) || ft <: typeof(Base.vect) || - ft <: typeof(Base.eltype) + ft <: typeof(Base.eltype) || + ft <: typeof(Base.argtail) return false end @@ -168,6 +169,20 @@ function is_reactant_method(mi::Core.MethodInstance) return mt === REACTANT_METHOD_TABLE end +@generated function applyiterate_with_reactant( + iteratefn, applyfn, args::Vararg{Any,N} +) where {N} + @assert iteratefn == typeof(Base.iterate) + newargs = Vector{Expr}(undef, N) + for i in 1:N + @inbounds newargs[i] = :(args[$i]...) + end + quote + Base.@_inline_meta + call_with_reactant(applyfn, $(newargs...)) + end +end + function rewrite_inst(inst, ir, interp) if Meta.isexpr(inst, :call) # Even if type unstable we do not want (or need) to replace intrinsic @@ -176,7 +191,13 @@ function rewrite_inst(inst, ir, interp) if ft == typeof(Core.kwcall) ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) end - if should_rewrite_ft(ft) + if ft == typeof(Core._apply_iterate) + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) + if should_rewrite_ft(ft) + rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...) + return true, rep + end + elseif should_rewrite_ft(ft) rep = Expr(:call, call_with_reactant, inst.args...) return true, rep end @@ -353,7 +374,7 @@ function call_with_reactant_generator( # look up the method match builtin_error = :(throw( - AssertionError("Unsupported call_with_reactant of builtin $redub_arguments") + AssertionError("Unsupported call_with_reactant of builtin $(args[1])") )) if args[1] <: Core.Builtin