Skip to content

Commit

Permalink
do recursive bit inside foreachfield codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
piever committed Jul 17, 2020
1 parent 68eaf79 commit 963be22
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,40 @@ else
const _getproperty = getproperty
end

function _foreachfield(names, L)
array_names_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_names_types(C)
array_names_types(::Type{NamedTuple{names, types}}) where {names, types} = zip(names, types.parameters)
array_names_types(::Type{T}) where {T<:Tuple} = enumerate(T.parameters)

function apply_f_to_vars_fields(names_types, vars)
exprs = Expr[]
for (name, type) in names_types
sym = QuoteNode(name)
args = [Expr(:call, :_getproperty, var, sym) for var in vars]
expr = if type <: StructArray
apply_f_to_vars_fields(array_names_types(type), args)
else
Expr(:call, :f, args...)
end
push!(exprs, expr)
end
return Expr(:block, exprs...)
end

function _foreachfield(names_types, L)
vars = ntuple(i -> gensym(), L)
exprs = Expr[]
for (i, v) in enumerate(vars)
push!(exprs, Expr(:(=), v, Expr(:call, :getfield, :xs, i)))
end
for field in names
sym = QuoteNode(field)
args = [Expr(:call, :_getproperty, var, sym) for var in vars]
push!(exprs, Expr(:call, :f, args...))
end
push!(exprs, apply_f_to_vars_fields(names_types, vars))
push!(exprs, :(return nothing))
return Expr(:block, exprs...)
end

@generated foreachfield_gen(::NamedTuple{names}, f, xs::Vararg{Any, L}) where {names, L} =
_foreachfield(names, L)
@generated foreachfield_gen(::NTuple{N, Any}, f, xs::Vararg{Any, L}) where {N, L} =
_foreachfield(Base.OneTo(N), L)
@generated foreachfield_gen(::NT, f, xs::Vararg{Any, L}) where {NT<:NamedTuple, L} =
_foreachfield(array_names_types(NT), L)
@generated foreachfield_gen(::T, f, xs::Vararg{Any, L}) where {T<:Tuple, L} =
_foreachfield(array_names_types(T), L)

foreachfield(f, x::StructArray, xs...) = foreachfield_gen(fieldarrays(x), f, x, xs...)

Expand Down

0 comments on commit 963be22

Please sign in to comment.