Skip to content

Commit

Permalink
Implement a working VJP function
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 29, 2024
1 parent 50f64a1 commit fb7ea0a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
44 changes: 26 additions & 18 deletions ext/LuxReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,38 @@ end
function Lux.__to_reactant_adaptor(

Check warning on line 67 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L67

Added line #L67 was not covered by tests
to::Lux.ToReactantAdaptor{FST}, model::AbstractExplicitLayer,
input_prototype, ps, st, eltype_adaptor) where {FST}
output = first(model(input_prototype, ps, st))
concrete_output = __make_concrete_array(output)

Check warning on line 71 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L70-L71

Added lines #L70 - L71 were not covered by tests

concrete_input = __make_concrete_array(input_prototype)
cmodel = __make_concrete_array(model)
cps = __make_concrete_array(ps)
cst = __make_concrete_array(st)

Check warning on line 76 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L73-L76

Added lines #L73 - L76 were not covered by tests

csmodel = Lux.StatefulLuxLayer{FST}(cmodel, cps, cst)

Check warning on line 78 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L78

Added line #L78 was not covered by tests

fwd = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))

bwd = try
enzyme_grad_fn = (m, x) -> begin
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(m.ps)
st = ifelse(FST, m.st, m.st_any)
Enzyme.autodiff(
Enzyme.Reverse, (m, x, ps, st) -> first(LuxCore.apply(m, x, ps, st)),
Enzyme.Duplicated, Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st))
return (; ps=dps), dx
fwd_fn = Reactant.compile((m, x) -> m(x), (csmodel, concrete_input))

Check warning on line 80 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L80

Added line #L80 was not covered by tests

function enzyme_vjp_fn(m, x, y, dy)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(m.ps)
st_m = ifelse(FST, m.st, m.st_any)

Check warning on line 85 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L82-L85

Added lines #L82 - L85 were not covered by tests

function wrapper_fn!(y, model, x, ps, st)
copyto!(y, first(LuxCore.apply(model, x, ps, st)))
return nothing

Check warning on line 89 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L87-L89

Added lines #L87 - L89 were not covered by tests
end

Reactant.compile(enzyme_grad_fn, (csmodel, concrete_input))
Enzyme.autodiff(Enzyme.Reverse, wrapper_fn!, Enzyme.Const, Enzyme.Duplicated(y, dy),

Check warning on line 92 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L92

Added line #L92 was not covered by tests
Enzyme.Const(m.model), Enzyme.Duplicated(x, dx),
Enzyme.Duplicated(m.ps, dps), Enzyme.Const(st_m))
return dx, dps

Check warning on line 95 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L95

Added line #L95 was not covered by tests
end

vjp_fn = try
concrete_output2 = __make_concrete_array(deepcopy(output))
Reactant.compile(

Check warning on line 100 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L98-L100

Added lines #L98 - L100 were not covered by tests
enzyme_vjp_fn, (csmodel, concrete_input, concrete_output, concrete_output2))
catch err
to.force_compile_backward && rethrow(err)
@error """

Check warning on line 104 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
Expand All @@ -101,11 +111,9 @@ function Lux.__to_reactant_adaptor(
nothing

Check warning on line 111 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L111

Added line #L111 was not covered by tests
end

# TODO: Add compiled types to the layer type information. That way we can check
# if the model is being executed with the correct types.
return Lux.ReactantLayer{FST, Lux.__recursive_eltype(input_prototype)}(

Check warning on line 114 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L114

Added line #L114 was not covered by tests
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd,
bwd, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
to, input_prototype, concrete_input, cps, cst, model, cmodel, fwd_fn,
vjp_fn, eltype_adaptor, fmapstructure(Lux.__size, input_prototype))
end

# TODO: We currently maintain 2 copies of the parameters; this is not ideal.
Expand Down Expand Up @@ -183,7 +191,7 @@ end

# Generic four-argument method: forwards every argument unchanged to
# `__graceful_type_mismatch_error`.
function Lux.__apply_reactant(l, x, ps, st)
    return __graceful_type_mismatch_error(l, x, ps, st)
end

Check warning on line 192 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L192

Added line #L192 was not covered by tests

@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd(csmodel, x)
# Method for `Lux.ReactantLayer`: calls the function stored in the layer's `fwd_fn` field
# with `csmodel` and `x` and returns its result.
@inline Lux.__apply_reactant(l::Lux.ReactantLayer, csmodel, x) = l.fwd_fn(csmodel, x)

Check warning on line 194 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L194

Added line #L194 was not covered by tests

# Don't inline, else types don't get displayed in the stack trace
function __graceful_type_mismatch_error(

Check warning on line 197 in ext/LuxReactantExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxReactantExt.jl#L197

Added line #L197 was not covered by tests
Expand Down
7 changes: 4 additions & 3 deletions src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ end

# TODO: Add a ChainRules rrule that calls the `vjp_fn` function, i.e. uses Enzyme for the
# gradient computation
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType, F, B,
# TODO: Inference won't work out of the box; we will have to compile that separately.
@concrete struct ReactantLayer{FST, T, inType, inCType, psType, stType,
L <: AbstractExplicitLayer, AD <: ToReactantAdaptor} <: AbstractExplicitLayer
adaptor::AD
input_prototype::inType
Expand All @@ -254,8 +255,8 @@ end
concrete_st::stType
layer::L
clayer
fwd::F
bwd::B
fwd_fn
vjp_fn
eltype_adaptor
input_structure
end
Expand Down

0 comments on commit fb7ea0a

Please sign in to comment.