[BREAKING] Force use of Tangents in pushforward, pullback and `hvp` (#455)

* Force use of `Tangents`

* Fix

* Add `map` and other utilities

* Add methods

* Scenarios

* Fixes

* Fixes

* Fix

* Coverage and docs

* Avoid duplicate imports

* Map

* Fix docs
gdalle authored Sep 7, 2024
1 parent 25289d9 commit ccf5247
Showing 54 changed files with 437 additions and 386 deletions.
4 changes: 4 additions & 0 deletions DifferentiationInterface/docs/src/api.md
@@ -10,6 +10,10 @@ DifferentiationInterface

## First order

```@docs
Tangents
```

### Pushforward

```@docs
35 changes: 18 additions & 17 deletions DifferentiationInterface/docs/src/operators.md
@@ -30,16 +30,17 @@ These operators are computed using only the input `x`.

### Low-level operators

These operators are computed using the input `x` and a "seed" `v`, which lives either
These operators are computed using the input `x` and a tangent `t` of type [`Tangents`](@ref).
This tangent is essentially an `NTuple`, whose elements live either

- in the same space as `x` (we call it `dx`)
- or in the same space as `y` (we call it `dy`)
- in the same space as `x` (we call it `tx`)
- or in the same space as `y` (we call it `ty`)

| operator | order | input `x` | output `y` | seed `v` | operator result type | operator result shape |
| :-------------------------- | :---- | :-------------- | :----------- | :------- | :------------------- | :-------------------- |
| [`pushforward`](@ref) (JVP) | 1 | `Any` | `Any` | `dx` | same as `y` | `size(y)` |
| [`pullback`](@ref) (VJP) | 1 | `Any` | `Any` | `dy` | same as `x` | `size(x)` |
| [`hvp`](@ref) | 2 | `AbstractArray` | `Number` | `dx` | same as `x` | `size(x)` |
| operator | order | input `x` | output `y` | tangent `t` | operator result type | operator result shape |
| :-------------------------- | :---- | :-------------- | :----------- | :---------- | :------------------- | :-------------------- |
| [`pushforward`](@ref) (JVP) | 1 | `Any` | `Any` | `tx` | same as `y` | `size(y)` |
| [`pullback`](@ref) (VJP) | 1 | `Any` | `Any` | `ty` | same as `x` | `size(x)` |
| [`hvp`](@ref) | 2 | `AbstractArray` | `Number` | `tx` | same as `x` | `size(x)` |
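
As a rough illustration of the "essentially an `NTuple`" idea above, here is a minimal, self-contained sketch. `ToyTangents`, `toy_tangents` and `toy_pushforward` are hypothetical names invented for this example, not DifferentiationInterface's actual implementation; the sketch only mirrors the `map` and `only` behaviour that the code changes in this commit rely on.

```julia
# Hypothetical stand-in for `Tangents`: a thin wrapper around an NTuple.
struct ToyTangents{B,T}
    d::NTuple{B,T}
end

# Convenience constructor (assumed name): toy_tangents(a, b, ...) wraps (a, b, ...).
toy_tangents(ds...) = ToyTangents(ds)

# `map` applies a function to each wrapped element and re-wraps the results,
# so callers no longer need to reach into the `.d` field themselves.
Base.map(f, t::ToyTangents) = ToyTangents(map(f, t.d))

# `only` extracts the single element of a batch of size 1 (errors otherwise).
Base.only(t::ToyTangents) = only(t.d)

# Toy pushforward (JVP) of the elementwise square x -> x .^ 2, written by hand:
# each input tangent dx is mapped to the output tangent 2 .* x .* dx.
toy_pushforward(x, tx::ToyTangents) = map(dx -> 2 .* x .* dx, tx)

x = [1.0, 2.0]
tx = toy_tangents([1.0, 0.0], [0.0, 1.0])   # a batch of two tangents "tx"
ty = toy_pushforward(x, tx)                 # a batch of two tangents "ty"
single = only(toy_pushforward(x, toy_tangents([1.0, 1.0])))
```

In this sketch, `ty` wraps one output tangent per input tangent, and `only` unwraps the result when a single tangent was passed.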

## Variants

@@ -73,8 +74,8 @@ This results in various operator signatures (the necessary arguments and their o

| function signature | out-of-place operator | in-place operator |
| :-------------------- | :--------------------------- | :------------------------------------ |
| out-of-place function | `op(f, backend, x, [v])` | `op!(f, result, backend, x, [v])` |
| in-place function | `op(f!, y, backend, x, [v])` | `op!(f!, y, result, backend, x, [v])` |
| out-of-place function | `op(f, backend, x, [t])` | `op!(f, result, backend, x, [t])` |
| in-place function | `op(f!, y, backend, x, [t])` | `op!(f!, y, result, backend, x, [t])` |

!!! warning
The positional arguments between `f`/`f!` and `backend` are always mutated.
@@ -103,15 +104,15 @@ In addition, the preparation syntax depends on the number of arguments accepted

| function signature | preparation signature |
| :-------------------- | :----------------------------------- |
| out-of-place function | `prepare_op(f, backend, x, [v])` |
| in-place function | `prepare_op(f!, y, backend, x, [v])` |
| out-of-place function | `prepare_op(f, backend, x, [t])` |
| in-place function | `prepare_op(f!, y, backend, x, [t])` |

Preparation creates an object called `extras` which contains the necessary information to speed up an operator and its variants.
The idea is that you prepare only once, which can be costly, but then call the operator several times while reusing the same `extras`.

```julia
op(f, backend, x, [v]) # slow because it includes preparation
op(f, extras, backend, x, [v]) # fast because it skips preparation
op(f, backend, x, [t]) # slow because it includes preparation
op(f, extras, backend, x, [t]) # fast because it skips preparation
```
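
The prepare-once, call-many pattern can be sketched in isolation as follows. Everything here (`ToyExtras`, `toy_prepare`, `toy_directional_derivative`) is a hypothetical toy built on finite differences, not DifferentiationInterface's API; it only shows why reusing a prepared object saves work on repeated calls.

```julia
# Hypothetical stand-in for a preparation result; real `extras` types differ per backend.
struct ToyExtras
    buffer::Vector{Float64}  # scratch space allocated once during preparation
end

# "Preparation": the one-time setup, here just a buffer allocation.
toy_prepare(f, x::Vector{Float64}) = ToyExtras(similar(x))

# Directional derivative of `f` at `x` along `dx` by forward finite differences.
# The prepared buffer is overwritten on each call instead of being reallocated.
function toy_directional_derivative(f, extras::ToyExtras, x, dx; h=1e-6)
    extras.buffer .= x .+ h .* dx
    return (f(extras.buffer) - f(x)) / h
end

f(x) = sum(abs2, x)
x = [1.0, 2.0]

extras = toy_prepare(f, x)  # done once
d1 = toy_directional_derivative(f, extras, x, [1.0, 0.0])  # reuses `extras`
d2 = toy_directional_derivative(f, extras, x, [0.0, 1.0])  # reuses `extras`
```

Here `d1` and `d2` approximate the partial derivatives of `f` at `x` (about 2 and 4), and both calls share the same preallocated buffer.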

!!! warning
@@ -124,9 +125,9 @@ Here are the general rules that we strive to implement:

| | different point | same point |
| :------------------------ | :--------------------------------------- | :--------------------------------------- |
| the output `extras` of... | `prepare_op(f, b, x)` | `prepare_op_same_point(f, b, x, v)` |
| can be used in... | `op(f, extras, b, other_x)` | `op(f, extras, b, x, other_v)` |
| provided that... | `other_x` has same type and shape as `x` | `other_v` has same type and shape as `v` |
| the output `extras` of... | `prepare_op(f, b, x)` | `prepare_op_same_point(f, b, x, t)` |
| can be used in... | `op(f, extras, b, other_x)` | `op(f, extras, b, x, other_t)` |
| provided that... | `other_x` has same type and shape as `x` | `other_t` has same type and shape as `t` |

These rules hold for the majority of backends, but there are some exceptions: see [this page](@ref "Preparation") to know more.

Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
# forward mode unused for lack of implementations
#=
function ChainRulesCore.frule((_, dx), dw::DifferentiateWith, x)
@compat (; f, backend) = dw
y, dy = DI.value_and_pushforward(f, backend, x, dx)
return y, dy
end
=#

function ChainRulesCore.rrule(dw::DifferentiateWith, x)
@compat (; f, backend) = dw
y = f(x)
extras_same = DI.prepare_pullback_same_point(f, backend, x, y)
pullbackfunc(dy) = (NoTangent(), DI.pullback(f, extras_same, backend, x, dy))
extras_same = DI.prepare_pullback_same_point(f, backend, x, DI.Tangents(y))
function pullbackfunc(dy)
tx = DI.pullback(f, extras_same, backend, x, DI.Tangents(dy))
return (NoTangent(), only(tx))
end
return y, pullbackfunc
end
Original file line number Diff line number Diff line change
@@ -20,19 +20,28 @@ function DI.value_and_pullback(
)
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x)
return y, Tangents(last.(pb.(ty.d)))
tx = map(ty) do dy
last(pb(dy))
end
return y, tx
end

function DI.value_and_pullback(
f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
)
@compat (; y, pb) = extras
return copy(y), Tangents(last.(pb.(ty.d)))
tx = map(ty) do dy
last(pb(dy))
end
return copy(y), tx
end

function DI.pullback(
f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
)
@compat (; pb) = extras
return Tangents(last.(pb.(ty.d)))
tx = map(ty) do dy
last(pb(dy))
end
return tx
end
Original file line number Diff line number Diff line change
@@ -14,12 +14,12 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::Tangents) = NoPushforwardExtras()

function DI.pushforward(f, ::NoPushforwardExtras, ::AutoDiffractor, x, tx::Tangents)
dys = map(tx.d) do dx
ty = map(tx) do dx
# code copied from Diffractor.jl
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
dy = z[TaylorTangentIndex(1)]
end
return Tangents(dys)
return ty
end

function DI.value_and_pushforward(
Original file line number Diff line number Diff line change
@@ -16,7 +16,6 @@ using DifferentiationInterface:
NoPullbackExtras,
NoPushforwardExtras,
Tangents,
SingleTangent,
pick_batchsize
using Enzyme:
Active,
Original file line number Diff line number Diff line change
@@ -13,11 +13,11 @@ function DI.value_and_pushforward(
x,
tx::Tangents,
)
dys = map(tx.d) do dx
DI.pushforward(f, extras, backend, x, dx)
ty = map(tx) do dx
only(DI.pushforward(f, extras, backend, x, Tangents(dx)))
end
y = f(x)
return y, Tangents(dys)
return y, ty
end

function DI.value_and_pushforward(
@@ -36,7 +36,7 @@ function DI.value_and_pushforward(
else
autodiff(forward_mode(backend), f_and_df, Duplicated, x_and_dx)
end
return y, SingleTangent(new_dy)
return y, Tangents(new_dy)
end

function DI.pushforward(
@@ -55,7 +55,7 @@ function DI.pushforward(
else
only(autodiff(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx))
end
return SingleTangent(new_dy)
return Tangents(new_dy)
end

function DI.value_and_pushforward!(
Original file line number Diff line number Diff line change
@@ -14,11 +14,11 @@ function DI.value_and_pushforward(
x,
tx::Tangents,
)
dys = map(tx.d) do dx
DI.pushforward(f!, y, extras, backend, x, dx)
ty = map(tx) do dx
only(DI.pushforward(f!, y, extras, backend, x, Tangents(dx)))
end
f!(y, x)
return y, Tangents(dys)
return y, ty
end

function DI.value_and_pushforward(
@@ -40,5 +40,5 @@ function DI.value_and_pushforward(
else
autodiff(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
end
return y, SingleTangent(dy_sametype)
return y, Tangents(dy_sametype)
end
Original file line number Diff line number Diff line change
@@ -13,11 +13,11 @@ function DI.value_and_pullback(
x,
ty::Tangents,
)
dxs = map(ty.d) do dy
only(DI.pullback(f, extras, backend, x, SingleTangent(dy)))
tx = map(ty) do dy
only(DI.pullback(f, extras, backend, x, Tangents(dy)))
end
y = f(x)
return y, Tangents(dxs)
return y, tx
end

### Out-of-place
@@ -38,7 +38,7 @@ function DI.value_and_pullback(
autodiff(ReverseWithPrimal, f_and_df, Active, Active(x))
end
new_dx = dy * only(der)
return y, SingleTangent(new_dx)
return y, Tangents(new_dx)
else
dy = only(ty)
f_and_df = force_annotation(get_f_and_df(f, backend))
@@ -51,7 +51,7 @@ function DI.value_and_pullback(
tape, y, new_dy = forw(f_and_df, Active(x))
copyto!(new_dy, dy)
new_dx = only(only(rev(f_and_df, Active(x), tape)))
return y, SingleTangent(new_dx)
return y, Tangents(new_dx)
end
end

@@ -76,10 +76,10 @@ function DI.value_and_pullback(
# TODO: generalize beyond Arrays?
dx_sametype .*= dy
end
return y, SingleTangent(dx_sametype)
return y, Tangents(dx_sametype)
else
dx = make_zero(x)
return DI.value_and_pullback!(f, SingleTangent(dx), extras, backend, x, ty)
return DI.value_and_pullback!(f, Tangents(dx), extras, backend, x, ty)
end
end

@@ -201,7 +201,8 @@ function DI.value_and_gradient(
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
x,
)
return DI.value_and_pullback(f, NoPullbackExtras(), backend, x, true)
y, tx = DI.value_and_pullback(f, NoPullbackExtras(), backend, x, Tangents(true))
return y, only(tx)
end

function DI.value_and_gradient!(
@@ -211,7 +212,10 @@ function DI.value_and_gradient!(
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
x,
)
return DI.value_and_pullback!(f, grad, NoPullbackExtras(), backend, x, true)
y, _ = DI.value_and_pullback!(
f, Tangents(grad), NoPullbackExtras(), backend, x, Tangents(true)
)
return y, grad
end

## Jacobian
Original file line number Diff line number Diff line change
@@ -14,11 +14,11 @@ function DI.value_and_pullback(
x,
ty::Tangents,
)
dxs = map(ty.d) do dy
only(DI.pullback(f!, y, extras, backend, x, SingleTangent(dy)))
tx = map(ty) do dy
only(DI.pullback(f!, y, extras, backend, x, Tangents(dy)))
end
f!(y, x)
return y, Tangents(dxs)
return y, tx
end

function DI.value_and_pullback(
@@ -38,7 +38,7 @@ function DI.value_and_pullback(
else
only(autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, Active(x)))
end
return y, SingleTangent(new_dx)
return y, Tangents(new_dx)
end

function DI.value_and_pullback(
@@ -60,5 +60,5 @@ function DI.value_and_pullback(
else
autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
end
return y, SingleTangent(dx_sametype)
return y, Tangents(dx_sametype)
end
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ function DI.prepare_hvp(f, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1})
end

function DI.hvp(f, ::NoHVPExtras, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1})
return SingleTangent(hvp(f, x, only(tx)))
return Tangents(hvp(f, x, only(tx)))
end

function DI.hvp!(
Original file line number Diff line number Diff line change
@@ -30,15 +30,15 @@ function DI.pushforward(
x,
tx::Tangents,
)
dys = map(tx.d) do dx
ty = map(tx) do dx
v_vec = vcat(myvec(x), myvec(dx))
if extras.y_prototype isa Number
return only(extras.jvp_exe(v_vec))
else
return reshape(extras.jvp_exe(v_vec), size(extras.y_prototype))
end
end
return Tangents(dys)
return ty
end

function DI.pushforward!(
@@ -108,15 +108,15 @@ function DI.pullback(
x,
ty::Tangents,
)
dxs = map(ty.d) do dy
tx = map(ty) do dy
v_vec = vcat(myvec(x), myvec(dy))
if x isa Number
return only(extras.vjp_exe(v_vec))
else
return reshape(extras.vjp_exe(v_vec), size(x))
end
end
return Tangents(dxs)
return tx
end

function DI.pullback!(
@@ -426,12 +426,12 @@ end
function DI.hvp(
f, extras::FastDifferentiationHVPExtras, ::AutoFastDifferentiation, x, tx::Tangents
)
dgs = map(tx.d) do dx
tg = map(tx) do dx
v_vec = vcat(vec(x), vec(dx))
dg_vec = extras.hvp_exe(v_vec)
return reshape(dg_vec, size(x))
end
return Tangents(dgs)
return tg
end

function DI.hvp!(