-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explicit Enzyme rules #349
base: master
Are you sure you want to change the base?
Conversation
bc56823
to
b1049ed
Compare
@torfjelde can you help review this PR since you worked on the rules for ReverseDiff? |
@wsmoses, you might want to look since the added rules are for Enzyme. |
_muladd_partial(x::Real, y::BatchDuplicated, ::Nothing) = map(Base.Fix1(*, x), y.dval) | ||
|
||
function EnzymeRules.forward( | ||
::Const{typeof(find_alpha)}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is using the rule system for Enzyme <=0.12, which was changed in Enzyme 0.13. [or equivalently EnzymeCore 0.7 and 0.8 iirc]. We should probably use the latter.
The breaking change of relevance is that forward mode rules now have a config as the first arg [similar to ReverseMode], with utility functions that specify whether the original return and/or derivative returns are requested.
For example: https://github.com/EnzymeAD/Enzyme.jl/blob/31df08b376634e2926e1d147c67284b20c493c22/src/internal_rules.jl#L858
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is using the rule system for Enzyme <=0.12, which was changed in Enzyme 0.13.
Yes, I know, this is intentional since the master branch uses Enzyme 0.12 and I wanted to keep the update to Enzyme 0.13 and the switch from the imported to the explicit rules separate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The breaking change of relevance is that forward mode rules now have a config as the first arg [similar to ReverseMode], with utility functions that specify whether the original return and/or derivative returns are requested.
I saw this, and one thing that confused me was: What's the purpose of the return type/activity (Const/Duplicated/DuplicatedNoNeed) if it's solely based on needs_primal(config)
and needs_shadow(config)
what's returned? Why does the argument exist at all if it is ignored?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, it was previously impossible to represent a function with a constant return, whose primal result was not required by enzyme
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simultaneously marking a return DuplicatedNoNeed is no longer how to specify that the primal is not needed (eg outside of rules you would do autodiff(ForwardWithPrimal, Duplicated, …)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I think I still don't fully understand. Based on the sort!
example it seems the only use case would be to dispatch on batched/non-batched mode? But that could be done solely by dispatching on argument xs
, it seems? Both versions (batched/non-batched) completely ignore whether the return type is Const/Duplicated/...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So that’s one use of it.
But basically they represent different (though related) pieces of information.
const:
Result is not differentiated wrt.
It can either be required to be returned (needs primal) or not
Duplicated:
Result is differentiated wrt
The primal can be required to be returned or not (previously this was represented in dupnoneed vs duplicated, now duplicated is always passed in?
The shadow can be required to be returned or not (needs shadow). Prior to 0.13 it was always required to be returned, even if not needed.
return Ω | ||
end | ||
|
||
Ω̇ = if wt_y isa Const && wt_u_hat isa Const && b isa Const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with the bump to 0.13 you can probably do !needs_shadow || (the const conditions)
::Const{typeof(find_alpha)}, | ||
::Type{<:Const}, | ||
::Nothing, | ||
wt_y::Union{Const,Active}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
at this point why bother having an annotation at all here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to ensure when and with which types of arguments the function is called, to be sure that it is actually doing the correct thing and does not give silently wrong results. Additionally, IMO it makes it easier to read and understand the code (non-annotated arguments and very generic types were quite confusing to me when I tried to read and learn from rules in Enzyme).
config::EnzymeRules.ConfigWidth{1}, | ||
::Const{typeof(find_alpha)}, | ||
::Type{<:Const}, | ||
::Nothing, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this relies on some type stability guarantees, I'd probalby just mark it as tape,
without annotation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is a fair assumption to make, but also might as well leave it
wt_u_hat::Union{Const,Active}, | ||
b::Union{Const,Active}, | ||
) | ||
# Trivial case: Nothing to be differentiated (return activity is `Const`) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should also likely be extended for batch size != 1.
Basically in these cases active returns are just an ntuple of the relevant derivatives (in this case zero)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added support for batches in reverse mode (or at least I tried to 😄). How would I go about testing it? AFAICT EnzymeTestUtils.test_reverse
won't test batched mode since all activities are either Const
or Active
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The easiest way is probably to make a wrapper function that takes and returns variables by reference then calls this function. Then you can use batchduplicated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to add such a test in 22471da but now julia is aborted in CI due to an LLVM assertion error: https://github.com/TuringLang/Bijectors.jl/actions/runs/11788862392/job/32836666863?pr=349#step:5:488 Do you have any idea what is going and what the problem is?
ext/BijectorsEnzymeCoreExt.jl
Outdated
nothing | ||
else | ||
# Store the partial derivatives | ||
∂find_alpha(Ω, wt_y, wt_u_hat, b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know it's legal to run ∂find_alpha here, but why not move that to the reverse pass, and potentially cache Ω (or even recompute it if desirable)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this then lets you simplify the reverse rule too since you can make dfind_alpha take ΔΩ_val as an argument and just return it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What would be the advantage of moving it? In ChainRules, usually such computations are intentionally performed outside of the pullback such that they do not have to be recomputed when the pullback is reevaluated (eg this is also done by the @scalar_rule
macro).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Less data needs to be cached from fwd to rev. We also know the reverse pass is executed only once
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also know the reverse pass is executed only once
That means the motivation for the common ChainRules design (precomputing as much as possible outside of the pullback) does not exist in Enzyme?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this mentioned in the docs somewhere? I think it would be very useful to explain this, in particular for defining custom rules, I guess people familiar with ChainRules might be biased towards precomputing and caching stuff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don’t think so, since I at least was unaware of the chainrules convention.
feel free to open a PR
I have no clue about Enzyme.jl, unfortunately 😕 Seems @wsmoses's feedback is much more useful 👍 |
@devmotion perhaps we could merge this PR into #341, which supports Enzyme to v0.13? |
I started to implement the Enzyme rule for
find_alpha
directly, without importing the ChainRules. This should be more performant and less brittle given that the@import_frule
and@import_rrule
are not very polished (I think there are a few bugs actually but I haven't checked it with a MWE yet) and even advised against in the docstrings.Tests with EnzymeTestUtils pass successfully on Julia 1.10. I assume Julia 1.6 might be too old for Enzyme and 1.11 only properly supported by Enzyme 0.13.
Benchmark
A simple benchmark:
master
This PR