Skip to content

Commit

Permalink
Reimplement getfield to allow customization on external types
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jul 23, 2024
1 parent 9eabd69 commit dffa473
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ macro code_hlo(options, maybe_call=nothing)
end
end

traced_getfield(obj, field) = Base.getfield(obj, field)

function compile(f, args; pipeline_options="", client=nothing)
N = length(args)
ctx = MLIR.IR.Context()
Expand Down Expand Up @@ -333,7 +335,7 @@ function compile(f, args; pipeline_options="", client=nothing)
end
res = :(args[$(path[2])])
for p in path[3:end]
res = :(Base.getfield($res, $(Meta.quot(p))))
res = :(traced_getfield($res, $(Meta.quot(p))))
end
sym = Symbol("sbuf_$i")
sbuf = :($sym = XLA.synced_buffer($res.data))
Expand Down Expand Up @@ -375,7 +377,7 @@ function compile(f, args; pipeline_options="", client=nothing)
path = path[3:end]
end
for p in path
res = :(Base.getfield($res, $(Meta.quot(p))))
res = :(traced_getfield($res, $(Meta.quot(p))))
end
res = :($res.data = $(Symbol("concrete_res_$(idx)")))
push!(delinearized_results, res)
Expand All @@ -401,12 +403,12 @@ function compile(f, args; pipeline_options="", client=nothing)
path = path[3:end]
end
for p in path
res = :(Base.getfield($res, $(Meta.quot(p))))
res = :(traced_getfield($res, $(Meta.quot(p))))
end

argres = :(args[argpath[2]])
for p in argpath[3:end]
argres = :(Base.getfield($argres, $(Meta.quot(p))))
argres = :(traced_getfield($argres, $(Meta.quot(p))))
end

res = :($res.data = $argres.data)
Expand Down

0 comments on commit dffa473

Please sign in to comment.