From dffa473a371167c59fdada186d6c902b797618c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 23 Jul 2024 19:37:48 +0200 Subject: [PATCH] Reimplement `getfield` to allow customization on external types --- src/Compiler.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 6a386b9bf..8f78e6751 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -269,6 +269,8 @@ macro code_hlo(options, maybe_call=nothing) end end +traced_getfield(obj, field) = Base.getfield(obj, field) + function compile(f, args; pipeline_options="", client=nothing) N = length(args) ctx = MLIR.IR.Context() @@ -333,7 +335,7 @@ function compile(f, args; pipeline_options="", client=nothing) end res = :(args[$(path[2])]) for p in path[3:end] - res = :(Base.getfield($res, $(Meta.quot(p)))) + res = :(traced_getfield($res, $(Meta.quot(p)))) end sym = Symbol("sbuf_$i") sbuf = :($sym = XLA.synced_buffer($res.data)) @@ -375,7 +377,7 @@ function compile(f, args; pipeline_options="", client=nothing) path = path[3:end] end for p in path - res = :(Base.getfield($res, $(Meta.quot(p)))) + res = :(traced_getfield($res, $(Meta.quot(p)))) end res = :($res.data = $(Symbol("concrete_res_$(idx)"))) push!(delinearized_results, res) @@ -401,12 +403,12 @@ function compile(f, args; pipeline_options="", client=nothing) path = path[3:end] end for p in path - res = :(Base.getfield($res, $(Meta.quot(p)))) + res = :(traced_getfield($res, $(Meta.quot(p)))) end argres = :(args[argpath[2]]) for p in argpath[3:end] - argres = :(Base.getfield($argres, $(Meta.quot(p)))) + argres = :(traced_getfield($argres, $(Meta.quot(p)))) end res = :($res.data = $argres.data)