Explicit Enzyme rules #349

Open · wants to merge 5 commits into master from dw/enzyme

Conversation

@devmotion (Member) commented Nov 8, 2024

I started implementing the Enzyme rules for find_alpha directly, without importing the ChainRules definitions. This should be more performant and less brittle, given that @import_frule and @import_rrule are not very polished (I suspect there are actually a few bugs, but I haven't checked with an MWE yet) and are even advised against in their docstrings.

Tests with EnzymeTestUtils pass on Julia 1.10. I assume Julia 1.6 is too old for Enzyme, and that Julia 1.11 is only properly supported by Enzyme 0.13.
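
A hedged sketch of what such checks look like with EnzymeTestUtils (the exact activities and values in the test suite may differ):

using Bijectors, Enzyme, EnzymeTestUtils

# Compare the custom forward- and reverse-mode rules against finite differences.
test_forward(Bijectors.find_alpha, Duplicated, (randn(), Duplicated), (expm1(randn()), Duplicated), (randn(), Duplicated))
test_reverse(Bijectors.find_alpha, Active, (randn(), Active), (expm1(randn()), Active), (randn(), Active))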

Benchmark

A simple benchmark:

master

julia> using Bijectors, Enzyme, Chairmarks

julia> @be (randn(), expm1(randn()), randn()) autodiff(Forward, Bijectors.find_alpha, Duplicated, Duplicated(_[1], 1.0), Duplicated(_[2], 1.0), Duplicated(_[3], 1.0))
Benchmark: 3970 samples with 22 evaluations
 min    787.909 ns (11 allocs: 256 bytes)
 median 1.034 μs (11 allocs: 256 bytes)
 mean   1.059 μs (11 allocs: 256 bytes)
 max    5.146 μs (11 allocs: 256 bytes)

julia> @be (randn(), expm1(randn()), randn()) autodiff(Reverse, Bijectors.find_alpha, Active, Active(_[1]), Active(_[2]), Active(_[3]))
Benchmark: 3628 samples with 33 evaluations
 min    494.970 ns (12 allocs: 288 bytes)
 median 698.227 ns (12 allocs: 288 bytes)
 mean   767.860 ns (12 allocs: 288 bytes)
 max    38.293 μs (12 allocs: 288 bytes)

julia> @be [randn(), expm1(randn()), randn()] Enzyme.gradient(Forward, x -> Bijectors.find_alpha(x[1], x[2], x[3]), _)
Benchmark: 3270 samples with 8 evaluations
 min    2.771 μs (37 allocs: 1.074 KiB)
 median 3.464 μs (37 allocs: 1.074 KiB)
 mean   3.554 μs (37 allocs: 1.074 KiB)
 max    5.828 μs (37 allocs: 1.074 KiB)

julia> @be [randn(), expm1(randn()), randn()] Enzyme.gradient(Reverse, x -> Bijectors.find_alpha(x[1], x[2], x[3]), _)
Benchmark: 3007 samples with 47 evaluations
 min    469.851 ns (13 allocs: 368 bytes)
 median 649.809 ns (13 allocs: 368 bytes)
 mean   661.067 ns (13 allocs: 368 bytes)
 max    1.335 μs (13 allocs: 368 bytes)

This PR

julia> using Bijectors, Enzyme, Chairmarks

julia> @be (randn(), expm1(randn()), randn()) autodiff(Forward, Bijectors.find_alpha, Duplicated, Duplicated(_[1], 1.0), Duplicated(_[2], 1.0), Duplicated(_[3], 1.0))
Benchmark: 3733 samples with 42 evaluations
 min    366.071 ns (2 allocs: 64 bytes)
 median 572.405 ns (2 allocs: 64 bytes)
 mean   586.341 ns (2 allocs: 64 bytes)
 max    1.141 μs (2 allocs: 64 bytes)

julia> @be (randn(), expm1(randn()), randn()) autodiff(Reverse, Bijectors.find_alpha, Active, Active(_[1]), Active(_[2]), Active(_[3]))
Benchmark: 3924 samples with 39 evaluations
 min    364.333 ns (2 allocs: 32 bytes)
 median 587.615 ns (2 allocs: 32 bytes)
 mean   603.323 ns (2 allocs: 32 bytes)
 max    1.327 μs (2 allocs: 32 bytes)

julia> @be [randn(), expm1(randn()), randn()] Enzyme.gradient(Forward, x -> Bijectors.find_alpha(x[1], x[2], x[3]), _)
Benchmark: 3223 samples with 27 evaluations
 min    839.519 ns (11.30 allocs: 552.296 bytes)
 median 1.026 μs (11.30 allocs: 552.296 bytes)
 mean   1.056 μs (11.30 allocs: 552.296 bytes)
 max    1.755 μs (11.30 allocs: 552.296 bytes)

julia> @be [randn(), expm1(randn()), randn()] Enzyme.gradient(Reverse, x -> Bijectors.find_alpha(x[1], x[2], x[3]), _)
Benchmark: 2907 samples with 56 evaluations
 min    371.286 ns (3 allocs: 112 bytes)
 median 562.500 ns (3 allocs: 112 bytes)
 mean   574.438 ns (3 allocs: 112 bytes)
 max    2.377 μs (3 allocs: 112 bytes)

@devmotion force-pushed the dw/enzyme branch 2 times, most recently from bc56823 to b1049ed on November 8, 2024 at 16:12
@devmotion marked this pull request as ready for review on November 8, 2024 at 17:11
@yebai (Member) commented Nov 9, 2024

@torfjelde can you help review this PR since you worked on the rules for ReverseDiff?

@yebai (Member) commented Nov 9, 2024

@wsmoses, you might want to take a look since the added rules are for Enzyme.

_muladd_partial(x::Real, y::BatchDuplicated, ::Nothing) = map(Base.Fix1(*, x), y.dval)

function EnzymeRules.forward(
::Const{typeof(find_alpha)},
@wsmoses (Collaborator):

this is using the rule system for Enzyme <=0.12, which was changed in Enzyme 0.13 [or equivalently EnzymeCore 0.7 and 0.8, IIRC]. We should probably use the latter.

The breaking change of relevance is that forward mode rules now have a config as the first arg [similar to ReverseMode], with utility functions that specify whether the original return and/or derivative returns are requested.

For example: https://github.com/EnzymeAD/Enzyme.jl/blob/31df08b376634e2926e1d147c67284b20c493c22/src/internal_rules.jl#L858
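
For reference, a minimal sketch of the Enzyme >= 0.13 forward-rule shape, using a hypothetical scalar function f and derivative helper df (this is not the actual find_alpha rule from this PR):

using EnzymeCore
using EnzymeCore.EnzymeRules

f(x::Real) = exp(x) - x    # hypothetical function, for illustration only
df(x::Real) = exp(x) - 1   # its analytical derivative

function EnzymeRules.forward(
    config::EnzymeRules.FwdConfig,
    ::Const{typeof(f)},
    ::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}},
    x::Duplicated{<:Real},
)
    # The config, not the return activity, determines what must be returned.
    Ω = f(x.val)
    Ω̇ = df(x.val) * x.dval
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        return Duplicated(Ω, Ω̇)
    elseif EnzymeRules.needs_shadow(config)
        return Ω̇
    elseif EnzymeRules.needs_primal(config)
        return Ω
    else
        return nothing
    end
end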

@devmotion (Member, Author):

> this is using the rule system for Enzyme <=0.12, which was changed in Enzyme 0.13.

Yes, I know; this is intentional, since the master branch uses Enzyme 0.12 and I wanted to keep the update to Enzyme 0.13 separate from the switch from the imported to the explicit rules.

@devmotion (Member, Author):

> The breaking change of relevance is that forward mode rules now have a config as the first arg [similar to ReverseMode], with utility functions that specify whether the original return and/or derivative returns are requested.

I saw this, and one thing that confused me: what's the purpose of the return type/activity (Const/Duplicated/DuplicatedNoNeed) if what's returned is determined solely by needs_primal(config) and needs_shadow(config)? Why does the argument exist at all if it is ignored?

@wsmoses (Collaborator):

For example, it was previously impossible to represent a function with a constant return whose primal result was not required by Enzyme.

@wsmoses (Collaborator):

Simultaneously, marking a return DuplicatedNoNeed is no longer how to specify that the primal is not needed (e.g., outside of rules you would do autodiff(ForwardWithPrimal, Duplicated, …) when you do want the primal).
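
A hedged illustration of the mode-based selection (sin is just a stand-in function here):

using Enzyme

# Plain Forward returns only the derivative; ForwardWithPrimal additionally
# returns the primal value, so DuplicatedNoNeed is no longer needed for this.
derivative, primal = autodiff(ForwardWithPrimal, sin, Duplicated(1.0, 1.0))
# derivative ≈ cos(1.0), primal ≈ sin(1.0)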

@devmotion (Member, Author):

Hmm, I think I still don't fully understand. Based on the sort! example it seems the only use case would be to dispatch on batched/non-batched mode? But that could be done solely by dispatching on argument xs, it seems? Both versions (batched/non-batched) completely ignore whether the return type is Const/Duplicated/...

@wsmoses (Collaborator):

So that's one use of it.

But basically they represent different (though related) pieces of information.

Const:
- The result is not differentiated with respect to.
- It can either be required to be returned (needs primal) or not.

Duplicated:
- The result is differentiated with respect to.
- The primal can be required to be returned or not (previously this was represented by DuplicatedNoNeed vs. Duplicated; now Duplicated is always passed in?).
- The shadow can be required to be returned or not (needs shadow). Prior to 0.13 it was always required to be returned, even if not needed.

return Ω
end

Ω̇ = if wt_y isa Const && wt_u_hat isa Const && b isa Const
@wsmoses (Collaborator):

with the bump to 0.13 you can probably do !needs_shadow || (the const conditions)

::Const{typeof(find_alpha)},
::Type{<:Const},
::Nothing,
wt_y::Union{Const,Active},
@wsmoses (Collaborator):

at this point why bother having an annotation at all here?

@devmotion (Member, Author):

I want to restrict when and with which argument types the function is called, to be sure that it actually does the correct thing and does not silently give wrong results. Additionally, IMO it makes the code easier to read and understand (non-annotated arguments and very generic types were quite confusing to me when I tried to read and learn from the rules in Enzyme).

config::EnzymeRules.ConfigWidth{1},
::Const{typeof(find_alpha)},
::Type{<:Const},
::Nothing,
@wsmoses (Collaborator):

this relies on some type stability guarantees, I'd probably just mark it as tape, without annotation

@wsmoses (Collaborator):

it is a fair assumption to make, but also might as well leave it

wt_u_hat::Union{Const,Active},
b::Union{Const,Active},
)
# Trivial case: Nothing to be differentiated (return activity is `Const`)
@wsmoses (Collaborator):

These should also likely be extended for batch size != 1.

Basically, in these cases active returns are just an ntuple of the relevant derivatives (in this case zero), as sketched below.
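
A minimal sketch of that shape, with a hypothetical helper (not part of this PR):

using EnzymeCore
using EnzymeCore.EnzymeRules

# Hypothetical helper: under batch width N, the reverse rule returns an
# NTuple of N partials for each Active scalar argument; with a Const return
# activity all of them are zero.
zero_partials(config, ::Active{T}) where {T<:Real} =
    ntuple(_ -> zero(T), Val(EnzymeRules.width(config)))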

@devmotion (Member, Author):

I added support for batches in reverse mode (or at least I tried to 😄). How would I go about testing it? AFAICT EnzymeTestUtils.test_reverse won't test batched mode since all activities are either Const or Active.

@wsmoses (Collaborator):

The easiest way is probably to make a wrapper function that takes and returns variables by reference and then calls this function. Then you can use BatchDuplicated.
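
A hedged sketch of such a wrapper (the function name and seed values are made up for illustration):

using Bijectors, Enzyme

# Pass the scalar inputs and the output by Ref so that BatchDuplicated
# annotations (here with batch width 2) can be used in reverse mode.
function find_alpha_ref!(out, wt_y, wt_u_hat, b)
    out[] = Bijectors.find_alpha(wt_y[], wt_u_hat[], b[])
    return nothing
end

out, douts = Ref(0.0), (Ref(1.0), Ref(2.0))  # one output seed per batch element
wt_y, dwt_y = Ref(randn()), (Ref(0.0), Ref(0.0))
wt_u_hat, dwt_u_hat = Ref(expm1(randn())), (Ref(0.0), Ref(0.0))
b, db = Ref(randn()), (Ref(0.0), Ref(0.0))

autodiff(
    Reverse,
    find_alpha_ref!,
    Const,
    BatchDuplicated(out, douts),
    BatchDuplicated(wt_y, dwt_y),
    BatchDuplicated(wt_u_hat, dwt_u_hat),
    BatchDuplicated(b, db),
)
# Afterwards dwt_y[i][] etc. hold the partials scaled by the i-th output seed.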

@devmotion (Member, Author):

I tried to add such a test in 22471da but now Julia aborts in CI due to an LLVM assertion error: https://github.com/TuringLang/Bijectors.jl/actions/runs/11788862392/job/32836666863?pr=349#step:5:488. Do you have any idea what is going on and what the problem is?

nothing
else
# Store the partial derivatives
∂find_alpha(Ω, wt_y, wt_u_hat, b)
@wsmoses (Collaborator):

I know it's legal to run ∂find_alpha here, but why not move that to the reverse pass, and potentially cache Ω (or even recompute it if desirable)?

@wsmoses (Collaborator):

this then lets you simplify the reverse rule too, since you can make ∂find_alpha take ΔΩ_val as an argument and just return it
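
A hedged sketch of that restructuring for the all-Active, width-1 case, using a hypothetical helper find_alpha_partials that stands in for this PR's ∂find_alpha (assumed to be defined elsewhere and to return the three partial derivatives given Ω and the primal inputs):

using Bijectors
using EnzymeCore
using EnzymeCore.EnzymeRules

function EnzymeRules.augmented_primal(
    config::EnzymeRules.ConfigWidth{1},
    ::Const{typeof(Bijectors.find_alpha)},
    ::Type{<:Active},
    wt_y::Active{<:Real},
    wt_u_hat::Active{<:Real},
    b::Active{<:Real},
)
    Ω = Bijectors.find_alpha(wt_y.val, wt_u_hat.val, b.val)
    # Cache only Ω on the tape; the partials are computed in the reverse pass.
    primal = EnzymeRules.needs_primal(config) ? Ω : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, Ω)
end

function EnzymeRules.reverse(
    ::EnzymeRules.ConfigWidth{1},
    ::Const{typeof(Bijectors.find_alpha)},
    ΔΩ::Active,
    Ω,
    wt_y::Active{<:Real},
    wt_u_hat::Active{<:Real},
    b::Active{<:Real},
)
    # Hypothetical helper call: returns (∂α/∂wt_y, ∂α/∂wt_u_hat, ∂α/∂b).
    ∂wt_y, ∂wt_u_hat, ∂b = find_alpha_partials(Ω, wt_y.val, wt_u_hat.val, b.val)
    return (∂wt_y * ΔΩ.val, ∂wt_u_hat * ΔΩ.val, ∂b * ΔΩ.val)
end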

@devmotion (Member, Author):

What would be the advantage of moving it? In ChainRules, such computations are usually intentionally performed outside of the pullback so that they do not have to be recomputed when the pullback is re-evaluated (e.g., this is also what the @scalar_rule macro does).
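
For reference, a minimal sketch of that ChainRules convention, with a hypothetical function mynorm2 (not from Bijectors):

using ChainRulesCore

mynorm2(x::Real, y::Real) = hypot(x, y)  # hypothetical function, for illustration

function ChainRulesCore.rrule(::typeof(mynorm2), x::Real, y::Real)
    Ω = mynorm2(x, y)
    # Partials are computed once here, outside the pullback, so repeated
    # pullback evaluations do not recompute them.
    ∂x, ∂y = x / Ω, y / Ω
    mynorm2_pullback(ΔΩ) = (NoTangent(), ∂x * ΔΩ, ∂y * ΔΩ)
    return Ω, mynorm2_pullback
end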

@wsmoses (Collaborator):

Less data needs to be cached from the forward to the reverse pass. We also know the reverse pass is executed only once.

@devmotion (Member, Author):

> We also know the reverse pass is executed only once

That means the motivation for the common ChainRules design (precomputing as much as possible outside of the pullback) does not exist in Enzyme?

@wsmoses (Collaborator):

Correct

@devmotion (Member, Author):

Is this mentioned somewhere in the docs? I think it would be very useful to explain, in particular for defining custom rules; I guess people familiar with ChainRules might be biased towards precomputing and caching stuff.

@wsmoses (Collaborator):

I don't think so, since I at least was unaware of the ChainRules convention.

Feel free to open a PR.

@torfjelde (Member):

> can you help review this PR since you worked on the rules for ReverseDiff?

I have no clue about Enzyme.jl, unfortunately 😕 It seems @wsmoses's feedback is much more useful 👍

@yebai (Member) commented Nov 11, 2024

@devmotion perhaps we could merge this PR into #341, which adds support for Enzyme v0.13?
