
Commit

tweaks, but more is needed
mcabbott committed Dec 3, 2024
1 parent aca91d0 commit fb630c0
Showing 1 changed file with 28 additions and 8 deletions: docs/src/tutorials/gradient_zoo.md
@@ -157,7 +157,7 @@ Reverse-mode source-to-source automatic differentiation, written by hooking into

* By far the best-tested option for Flux models.

* Medium compilation times, on the first call.

* Allows mutation of structs, but not of arrays. This leads to the most common error: sometimes it happens because you mutate an array yourself, but more often because you call some function which, internally, creates the array it wants to return and then fills it in.

@@ -175,13 +175,16 @@ Zygote.jacobian(x -> mysum2(x).^2, Float32[1 2 3; 4 5 6])[1] # ERROR: Mutating
```

* Custom rules via `ZygoteRules.@adjoint` or (equivalently) `ChainRulesCore.rrule`.
Among other things, this lets you wrap a function which internally mutates an array, so that Zygote need not look inside it (see the sketch after this list).

* Returns nested NamedTuples and Tuples, and uses `nothing` to mean zero.

* Does not track shared arrays, hence an array which appears twice in the model receives two separate contributions rather than their sum:

```julia
shared = [2.0]
nt = (a = shared, b = shared, c = [2.0])
Zygote.gradient(x -> sum(abs2, x.a + 2*x.b + 3*x.c), nt)[1] # (a = [24.0], b = [48.0], c = [72.0])
```
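
As an illustration of the custom-rules point above, here is a minimal sketch (not taken from the original page) of wrapping a hypothetical mutating helper in a `ChainRulesCore.rrule`, so that Zygote never differentiates through the mutation; `cumsum_inplace` and the rule itself are assumptions made for this example:

```julia
# Hedged sketch: a custom rrule around a function which mutates an array internally.
# `cumsum_inplace` is a hypothetical helper, not part of Zygote or Flux.
using ChainRulesCore, Zygote

function cumsum_inplace(x::AbstractVector)
    y = similar(x)
    acc = zero(eltype(x))
    for i in eachindex(x)
        acc += x[i]
        y[i] = acc          # this write would trigger Zygote's mutation error without a rule
    end
    return y
end

function ChainRulesCore.rrule(::typeof(cumsum_inplace), x::AbstractVector)
    y = cumsum_inplace(x)
    # pullback of cumsum: reversed cumulative sum of the incoming cotangent
    cumsum_pullback(ȳ) = (NoTangent(), reverse(cumsum(reverse(ȳ))))
    return y, cumsum_pullback
end

Zygote.gradient(x -> sum(abs2, cumsum_inplace(x)), [1.0, 2.0, 3.0])  # works, via the rule
```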

!!! compat "Deprecated: Zygote's implicit mode"
@@ -255,8 +258,16 @@ New package which works on the LLVM code which Julia compiles down to.

* Returns another struct of the same type as the model, such as `Chain` above. Non-differentiable objects are left alone, not replaced by a zero (see the sketch after the next example).

* Shared arrays are shared in the gradient:

```julia
shared = [2.0]
nt = (a = shared, b = shared, c = [2.0])
Enzyme.gradient(Reverse, x -> sum(abs2, x.a + 2*x.b + 3*x.c), nt)[1] # (a = [72.0], b = [72.0], c = [72.0])
```
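
To illustrate the point about gradient types, here is a hedged sketch (the model, input and loss are assumptions, not from the original page); the object returned by `Enzyme.gradient` mirrors the model's own structure:

```julia
# Hedged sketch: Enzyme's reverse-mode gradient of a small Flux model.
using Enzyme, Flux

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
x = randn(Float32, 2)

grad = Enzyme.gradient(Reverse, m -> sum(abs2, m(x)), model)[1]

grad isa Chain              # true: same container type as the model, not a NamedTuple
grad.layers[1].weight       # 3×2 array of gradients, matching model.layers[1].weight
```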

Enzyme likes to work in-place, with objects and their gradients stored together in a `Duplicated(x, dx)`.
Flux now has an interface which uses this:
```julia
julia> Flux.train!((m,x) -> sum(abs2, m(1)), model, 1:1, opt_state) # train! with Zygote
```

@@ -274,7 +285,15 @@ julia> Flux.withgradient(loss, Duplicated(model))
### [Mooncake.jl](https://github.com/compintell/Mooncake.jl)

Another new AD to watch. It has many similarities in its approach to Enzyme.jl, but operates entirely in Julia.
[Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) has an interface to try this out:

```julia
julia> grads_m2 = Flux.gradient(loss, Moonduo(model))
((layers = ((weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), nothing),),)

julia> Flux.withgradient(loss, Moonduo(model))
(val = 0.5665111155481435, grad = ((layers = ((weight = [-0.15810298866515066 0.0 0.0; 0.1581029886651505 0.0 0.0],), nothing),),))
```

### [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)

@@ -303,7 +322,8 @@ Another Julia source-to-source reverse-mode AD.

### [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)

Forward mode AD is a different algorithm, which is easier to implement. This is a reliable old package, but it is of limited interest for use with Flux:

* Needs a simple array of parameters, i.e. supports only `gradient(f, x::AbstractArray{<:Real})`.
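
For a Flux model this usually means flattening the parameters first, e.g. with `Flux.destructure`; a minimal sketch (the model, input and loss are assumptions):

```julia
# Hedged sketch: ForwardDiff needs a flat vector, so flatten the model with Flux.destructure.
using Flux, ForwardDiff

model = Dense(3 => 2)
x = randn(Float32, 3)
flat, rebuild = Flux.destructure(model)    # flat::Vector{Float32}; rebuild(flat) rebuilds the model

grad_vector = ForwardDiff.gradient(v -> sum(abs2, rebuild(v)(x)), flat)
```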

@@ -316,9 +336,9 @@

* Like Tracker, this passes a special TrackedArray type through your function. It allows you to record & compile the tape, and to pre-allocate storage (see the sketch after this list).

* Like ForwardDiff it needs a flat vector, and supports only `gradient(f, x::AbstractArray{<:Real})`.

* No support for GPU operations.
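
A hedged sketch of the record-and-compile workflow mentioned above, again going through a flat parameter vector (the model, input and loss are assumptions, not tested against Flux here):

```julia
# Hedged sketch: record a ReverseDiff tape once, compile it, and reuse it with a
# pre-allocated gradient buffer. Assumes the flat-vector approach via Flux.destructure.
using Flux, ReverseDiff

model = Dense(3 => 2)
x = randn(Float32, 3)
flat, rebuild = Flux.destructure(model)

loss(v) = sum(abs2, rebuild(v)(x))

tape = ReverseDiff.compile(ReverseDiff.GradientTape(loss, flat))
grad = similar(flat)
ReverseDiff.gradient!(grad, tape, flat)    # writes the gradient into `grad`
```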


<hr/>
@@ -343,15 +363,15 @@ I haven't tried really, but I think it ought to work.

## Meta-packages

Besides AD packages, several packages have been written aiming to provide a unified interface to many options. These may offer useful ways to quickly switch between the options you are trying. However, Flux does not directly interface with any of them.

### [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl)

The original meta-package for calling any of several engines.

### [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl)

This year's new attempt to build a simpler such meta-package. However, as noted above, Flux does not directly interface with it.
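
For completeness, a small sketch of its unified calling convention, shown on a plain array rather than a Flux model (the choice of `AutoZygote()` is just an illustration):

```julia
# Hedged sketch: one gradient call, with the AD engine chosen by a backend object.
using DifferentiationInterface
import Zygote

backend = AutoZygote()   # backend types come from ADTypes.jl, re-exported here

DifferentiationInterface.gradient(x -> sum(abs2, x), backend, [1.0, 2.0, 3.0])  # [2.0, 4.0, 6.0]
```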

### [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)

