Skip to content

Commit

Permalink
Support compilation of closures (#36)
Browse files Browse the repository at this point in the history
* Add `make_tracer` case for closures

* Add closure case in `traced_type`

Also remove `make_tracer` case for closures because it can be treated by the struct case

* Test on mock closure

* Fix Julia codegen for traced closures

* Replace vector comprehension for tuple

* Skip condition in CI

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
mofeing and Sergio Sánchez Ramírez authored Jul 3, 2024
1 parent d2278e9 commit 989053c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 11 deletions.
63 changes: 52 additions & 11 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,19 @@ end
end

if T <: Function
return T
# functions are directly returned
if sizeof(T) == 0
return T
end

# in closures, enclosured variables need to be traced
N = fieldcount(T)
traced_fieldtypes = ntuple(Val(N)) do i
return traced_type(fieldtype(T, i), seen, Val(mode))
end

# closure are struct types with the types of enclosured vars as type parameters
return Core.apply_type(T.name.wrapper, traced_fieldtypes...)
end

if T <: DataType
Expand Down Expand Up @@ -725,8 +737,15 @@ function create_result(::Type{MakeStruct{AT,tocopy}}, path, result_stores) where
return Expr(:new, AT, elems...)
end

struct Thunk{linear_results_paths,linear_args_paths,preserved_args_paths,concrete_result_ty}
struct Thunk{
linear_results_paths,
linear_args_paths,
preserved_args_paths,
concrete_result_ty,
closure_ty,
}
exec::XLA.LoadedExecutable
fnwrap::closure_ty
end

@generated function (
Expand All @@ -735,10 +754,18 @@ end
Val{linear_args_paths},
Val{preserved_args_paths},
concrete_result_ty,
closure_ty,
}
)(
args::Vararg{Any,N}
) where {linear_results_paths,linear_args_paths,preserved_args_paths,N,concrete_result_ty}
) where {
linear_results_paths,
linear_args_paths,
preserved_args_paths,
N,
concrete_result_ty,
closure_ty,
}
arg_syncs = Expr[]
topres = Symbol[]
linearized_args = Union{Symbol,Expr}[]
Expand Down Expand Up @@ -859,6 +886,13 @@ end
resexpr = create_result(concrete_result_ty, (), result_stores)
expr = quote
Base.@_inline_meta
$(
# if `f` is a closure, then prepend the closure into `args`
# the closure fields will be correctly extracted from it as the tracer has already passed through it
if !(closure_ty <: Nothing)
:(args = (thunk.fnwrap, args...))
end
)
$exec_call
$(concretize...)
# Needs to store into result
Expand All @@ -870,17 +904,27 @@ end
end

function generate_jlfunc(
concrete_result, client, mod, Nargs, linear_args, linear_results, preserved_args
)
concrete_result,
client,
mod,
linear_args,
linear_results,
preserved_args,
fnwrap::closure_ty,
) where {closure_ty}
linear_results_paths = (map(x -> x.paths, linear_results)...,)
linear_args_paths = (map(x -> x.paths, linear_args)...,)
preserved_args_paths = (map(x -> (x[1].paths, x[2]), preserved_args)...,)
exec = XLA.Compile(client, mod)
v = make_valable(concrete_result)
return Thunk{
Val{linear_results_paths},Val{linear_args_paths},Val{preserved_args_paths},v
Val{linear_results_paths},
Val{linear_args_paths},
Val{preserved_args_paths},
v,
closure_ty,
}(
exec
exec, fnwrap
)
end

Expand Down Expand Up @@ -1068,7 +1112,6 @@ function compile(
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
mod, f, args, (), "main", true
)
@assert !fnwrapped

concrete_seen = IdDict()

Expand Down Expand Up @@ -1130,16 +1173,14 @@ function compile(
MLIR.API.mlirOperationDestroy(func2.operation)
func2.operation = MLIR.API.MlirOperation(C_NULL)

# println(string(mod))

return generate_jlfunc(
concrete_result,
client,
mod,
N,
linear_args,
linear_results2,
preserved_args,
fnwrapped ? f : nothing,
)
end
end
Expand Down
13 changes: 13 additions & 0 deletions test/closure.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Reactant

muler(x) = y -> x * y

@testset "closure" begin
x = Reactant.ConcreteRArray(ones(2, 2))
y = Reactant.ConcreteRArray(ones(2, 2))

f = muler(x)
g = Reactant.compile(f, (y,))

@test g(y) x * y
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ include("basic.jl")
include("bcast.jl")
include("nn.jl")
include("struct.jl")
include("closure.jl")
include("compile.jl")
include("nn_lux.jl")

0 comments on commit 989053c

Please sign in to comment.