Skip to content

Commit

Permalink
Rewrite apply iterate (#479)
Browse files Browse the repository at this point in the history
* Rewrite apply iterate

* fix

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Jan 5, 2025
1 parent 0b1ee70 commit ab61f24
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ function should_rewrite_ft(@nospecialize(ft))
ft <: typeof(Base.typemin) ||
ft <: typeof(Base.getproperty) ||
ft <: typeof(Base.vect) ||
ft <: typeof(Base.eltype)
ft <: typeof(Base.eltype) ||
ft <: typeof(Base.argtail)
return false
end

Expand All @@ -168,6 +169,20 @@ function is_reactant_method(mi::Core.MethodInstance)
return mt === REACTANT_METHOD_TABLE
end

@generated function applyiterate_with_reactant(
iteratefn, applyfn, args::Vararg{Any,N}
) where {N}
@assert iteratefn == typeof(Base.iterate)
newargs = Vector{Expr}(undef, N)
for i in 1:N
@inbounds newargs[i] = :(args[$i]...)
end
quote
Base.@_inline_meta
call_with_reactant(applyfn, $(newargs...))
end
end

function rewrite_inst(inst, ir, interp)
if Meta.isexpr(inst, :call)
# Even if type unstable we do not want (or need) to replace intrinsic
Expand All @@ -176,7 +191,13 @@ function rewrite_inst(inst, ir, interp)
if ft == typeof(Core.kwcall)
ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir))
end
if should_rewrite_ft(ft)
if ft == typeof(Core._apply_iterate)
ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir))
if should_rewrite_ft(ft)
rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...)
return true, rep
end
elseif should_rewrite_ft(ft)
rep = Expr(:call, call_with_reactant, inst.args...)
return true, rep
end
Expand Down Expand Up @@ -353,7 +374,7 @@ function call_with_reactant_generator(

# look up the method match
builtin_error = :(throw(
AssertionError("Unsupported call_with_reactant of builtin $redub_arguments")
AssertionError("Unsupported call_with_reactant of builtin $(args[1])")
))

if args[1] <: Core.Builtin
Expand Down

0 comments on commit ab61f24

Please sign in to comment.