Skip to content

Commit

Permalink
fix: handle traced array returns inside objects (#417)
Browse files Browse the repository at this point in the history
* fix: handle traced array returns inside objects

* test: add #416 as a test

* fix: propagate track_numbers correctly

* fix: aliasing and add a test

* test: use updated API for the tests

* feat: cache new arrays

* fix: traced_getfield
  • Loading branch information
avik-pal authored Dec 24, 2024
1 parent 0b6dafc commit 057e6b8
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 91 deletions.
120 changes: 84 additions & 36 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,27 @@ import ..Reactant:
ConcreteRNumber,
TracedRArray,
TracedRNumber,
RArray,
RNumber,
OrderedIdDict,
make_tracer,
TracedToConcrete,
append_path,
TracedType

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
) = Base.getindex(obj, field)
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
(isbitstype(T) || obj isa RArray) && return Base.getfield(obj, field)
return Base.getindex(obj, field)
end

@inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, field, val)
@inline function traced_setfield!(
@nospecialize(obj::AbstractArray{T}), field, val
) where {T}
(isbitstype(T) || obj isa RArray) && return Base.setfield!(obj, field, val)
return Base.setindex!(obj, val, field)
end

function create_result(tocopy::T, path, result_stores) where {T}
if !isstructtype(typeof(tocopy))
Expand Down Expand Up @@ -573,32 +584,32 @@ function codegen_flatten!(linear_args, result_stores)
push!(flatten_code, :($usbuf = $flatcode.data))
push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf)))

# TODO
respaths = ((p for p in arg.paths if p[1] != :args)...,)
# TODO: unused for the time being
# respaths = ((p for p in arg.paths if p[1] == :result || p[1] == :resargs)...,)

# resarg = false
for respath in respaths
if respath[1] == :result
flatcode = :result
respath = respath[2:end]
result_stores[respath] = usbuf
resarg = true
else
@assert respath[1] == :resargs
if respath[2] != path[2]
continue
end
# flatcode = :(args[$(respath[2])])
path = path[3:end]
end
# for p in path
# flatcode = :(traced_getfield($flatcode, $(Meta.quot(p))))
# end
# resarg = true
# flatcode = :($flatcode.data = $usbuf)
# @show flatcode
# push!(flatten_code, res)
end
# for respath in respaths
# if respath[1] == :result
# flatcode = :result
# respath = respath[2:end]
# result_stores[respath] = usbuf
# resarg = true
# else
# @assert respath[1] == :resargs
# if respath[2] != path[2]
# continue
# end
# # flatcode = :(args[$(respath[2])])
# path = path[3:end]
# end
# # for p in path
# # flatcode = :(traced_getfield($flatcode, $(Meta.quot(p))))
# # end
# # resarg = true
# # flatcode = :($flatcode.data = $usbuf)
# # @show flatcode
# # push!(flatten_code, res)
# end
# if resarg
# push!(resarg_code, :($usbuf = $flatcode.data))
# end
Expand All @@ -620,11 +631,16 @@ function codegen_unflatten!(
concrete_result,
result_stores,
)
unflatten_code = Expr[]
cache_dict = gensym("cache_dict")
unflatten_code = Expr[:(
$cache_dict = $(IdDict{
Union{TracedRArray,TracedRNumber},Union{ConcreteRArray,ConcreteRNumber}
}())
),]

# mutate the result stores to point to the correct concrete results
for (concrete_res_name, result) in zip(concretized_res_names, linear_results)
paths = ((p for p in result.paths if p[1] != :args)...,)
paths = ((p for p in result.paths if p[1] == :result || p[1] == :resargs)...,)
for path in paths
if path[1] == :result
unflatcode = :result
Expand All @@ -635,15 +651,47 @@ function codegen_unflatten!(
@assert path[1] == :resargs
unflatcode = :(args[$(path[2])])
path = path[3:end]
end

# unroll path tree
for p in path
unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p))))
end
unflatcode = :($unflatcode.data = $concrete_res_name)
for p in path[1:(end - 1)]
unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p))))
end

push!(unflatten_code, unflatcode)
if length(path) > 0
final_val = gensym("final_val")
clocal = gensym("clocal")
unflatcode = quote
$final_val = traced_getfield($unflatcode, $(Meta.quot(path[end])))
if $final_val isa TracedRArray
$clocal = if haskey($cache_dict, $final_val)
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcreteRArray{
eltype($final_val),ndims($final_val)
}(
$concrete_res_name, size($final_val)
)
$cache_dict[$final_val]
end
traced_setfield!($unflatcode, $(Meta.quot(path[end])), $clocal)
elseif $final_val isa TracedRNumber
$clocal = if haskey($cache_dict, $final_val)
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcreteRNumber{eltype($final_val)}(
$concrete_res_name
)
$cache_dict[$final_val]
end
traced_setfield!($unflatcode, $(Meta.quot(path[end])), $clocal)
else
traced_setfield!($final_val, :data, $concrete_res_name)
end
end
else
unflatcode = :($unflatcode.data = $concrete_res_name)
end
push!(unflatten_code, unflatcode)
end
end
end

Expand Down
15 changes: 12 additions & 3 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ function overload_autodiff(
primf = f.val
primargs = ((v.val for v in args)...,)

fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant.TracedUtils.make_mlir_fn(
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
primf, primargs, (), string(f) * "_autodiff", false
)

Expand Down Expand Up @@ -302,7 +302,7 @@ function overload_autodiff(
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
push!(ad_inputs, cst)
end
else
elseif TracedUtils.has_argidx(a)
idx, path = TracedUtils.get_argidx(a)
if idx == 1 && fnwrap
act = act_from_type(f, reverse, true)
Expand All @@ -322,6 +322,12 @@ function overload_autodiff(
end
TracedUtils.push_val!(ad_inputs, args[idx].dval, path[3:end])
end
else
act = act_from_type(Enzyme.Const, reverse, true)
push!(ret_activity, act)
if act != enzyme_out && act != enzyme_outnoneed
continue
end
end
end

Expand Down Expand Up @@ -385,7 +391,7 @@ function overload_autodiff(
end
residx += 1
end
else
elseif TracedUtils.has_argidx(a)
idx, path = TracedUtils.get_argidx(a)
if idx == 1 && fnwrap
TracedUtils.set!(
Expand All @@ -405,6 +411,9 @@ function overload_autodiff(
)
residx += 1
end
else
TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx)))
residx += 1
end
end

Expand Down
12 changes: 12 additions & 0 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,18 @@ function get_argidx(x)
throw(AssertionError("No path found for $x"))
end

function has_argidx(x)
for path in x.paths
if length(path) == 0
continue
end
if path[1] == :args
return true
end
end
return false
end

function set!(x, path, tostore; emptypath=false)
for p in path
x = Reactant.Compiler.traced_getfield(x, p)
Expand Down
Loading

0 comments on commit 057e6b8

Please sign in to comment.