Skip to content

Commit

Permalink
add Cached object with custom equality and hash for use a dict key
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 15, 2024
1 parent 5922ef0 commit 68559c5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
56 changes: 50 additions & 6 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,23 @@ end

function ReactantCore.traced_call(f, args...)
seen_cache = Reactant.OrderedIdDict()
cache_key = (f, make_tracer(
make_tracer(
seen_cache,
args,
(),
CallCache;
(), # we have to insert something here, but we remove it immediately below.
TracedTrack;
toscalar=false,
track_numbers=(), # TODO: track_numbers?
))
)
linear_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_cache
v isa TracedType || continue
push!(linear_args, v.mlir_data)
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:end-1]
end

cache_key = Cached((f, args...))
if haskey(Reactant.Compiler.callcache[], cache_key)
# cache lookup:
(; f_name, mlir_result_types, traced_result) = Reactant.Compiler.callcache[][cache_key]
Expand Down Expand Up @@ -175,7 +178,7 @@ function ReactantCore.traced_call(f, args...)
traced_result = make_tracer(
seen_results,
traced_result,
nothing, # we have to insert something here, but we remove it immediately below.
(), # we have to insert something here, but we remove it immediately below.
TracedSetPath;
toscalar=false,
track_numbers=(),
Expand All @@ -185,7 +188,7 @@ function ReactantCore.traced_call(f, args...)
v isa TracedType || continue
# this mutates `traced_result`, which is what we want:
v.mlir_data = MLIR.IR.result(call_op, i)
# make tracer inserted `nothing` into the path, here we remove it:
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:end-1]
i += 1
end
Expand Down Expand Up @@ -220,3 +223,44 @@ function get_region_removing_missing_values(compiled_fn, insertions)
end
return region
end

struct Cached
obj
end
Base.:(==)(a::Cached, b::Cached) = recursive_equal(a.obj, b.obj)
Base.hash(a::Cached, h::UInt) = recursive_hash(a.obj, h)

recursive_equal(a, b) = false
function recursive_equal(a::T, b::T) where {T}
fn = fieldnames(T)
isempty(fn) && return a == b
for name in fn
!recursive_equal(getfield(a, name), getfield(b, name)) && return false
end
return true
end
function recursive_equal(a::T, b::T) where {T<:AbstractArray}
for (el_a, el_b) in zip(a, b)
!recursive_equal(el_a, el_b) && return false
end
return true
end
recursive_equal(a::T, b::T) where {T<:TracedRArray} = MLIR.IR.type(a.mlir_data) == MLIR.IR.type(b.mlir_data)


function recursive_hash(a::T, h::UInt) where T
fn = fieldnames(T)
isempty(fn) && return hash(a, h)
h = hash(T, h) # include type in the hash
for name in fn
h = recursive_hash(getfield(a, name), h)
end
return h
end
function recursive_hash(a::AbstractArray, h::UInt)
for el in a
h = recursive_hash(el, h)
end
return h
end
recursive_hash(a::TracedRArray, h::UInt) = hash(MLIR.IR.type(a.mlir_data), h)
13 changes: 0 additions & 13 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
TracedToConcrete = 3
ArrayToConcrete = 4
TracedSetPath = 5
CallCache = 6
end

for T in (
Expand Down Expand Up @@ -383,12 +382,6 @@ function make_tracer(
if mode == ConcreteToTraced
throw("Cannot trace existing trace type")
end
if mode == CallCache
if !haskey(seen, prev)
seen[prev] = prev
end
return MLIR.IR.type(prev.mlir_data)
end
if mode == TracedTrack
prev.paths = (prev.paths..., path)
if !haskey(seen, prev)
Expand Down Expand Up @@ -435,12 +428,6 @@ function make_tracer(
if mode == ConcreteToTraced
throw("Cannot trace existing trace type")
end
if mode == CallCache
if !haskey(seen, prev)
seen[prev] = prev
end
return MLIR.IR.type(prev.mlir_data)
end
if mode == TracedTrack
prev.paths = (prev.paths..., path)
if !haskey(seen, prev)
Expand Down

0 comments on commit 68559c5

Please sign in to comment.