[BREAKING] Force use of Tangents in pushforward, pullback and hvp #455

Merged
merged 12 commits into from
Sep 7, 2024
4 changes: 4 additions & 0 deletions DifferentiationInterface/docs/src/api.md
@@ -10,6 +10,10 @@ DifferentiationInterface

 ## First order

+```@docs
+Tangents
+```
+
 ### Pushforward

 ```@docs
35 changes: 18 additions & 17 deletions DifferentiationInterface/docs/src/operators.md
@@ -30,16 +30,17 @@ These operators are computed using only the input `x`.

 ### Low-level operators

-These operators are computed using the input `x` and a "seed" `v`, which lives either
+These operators are computed using the input `x` and a tangent `t` of type [`Tangents`](@ref).
+This tangent is essentially an `NTuple`, whose elements live either

-- in the same space as `x` (we call it `dx`)
-- or in the same space as `y` (we call it `dy`)
+- in the same space as `x` (we call it `tx`)
+- or in the same space as `y` (we call it `ty`)

-| operator | order | input `x` | output `y` | seed `v` | operator result type | operator result shape |
-| :-------------------------- | :---- | :-------------- | :----------- | :------- | :------------------- | :-------------------- |
-| [`pushforward`](@ref) (JVP) | 1 | `Any` | `Any` | `dx` | same as `y` | `size(y)` |
-| [`pullback`](@ref) (VJP) | 1 | `Any` | `Any` | `dy` | same as `x` | `size(x)` |
-| [`hvp`](@ref) | 2 | `AbstractArray` | `Number` | `dx` | same as `x` | `size(x)` |
+| operator | order | input `x` | output `y` | tangent `t` | operator result type | operator result shape |
+| :-------------------------- | :---- | :-------------- | :----------- | :---------- | :------------------- | :-------------------- |
+| [`pushforward`](@ref) (JVP) | 1 | `Any` | `Any` | `tx` | same as `y` | `size(y)` |
+| [`pullback`](@ref) (VJP) | 1 | `Any` | `Any` | `ty` | same as `x` | `size(x)` |
+| [`hvp`](@ref) | 2 | `AbstractArray` | `Number` | `tx` | same as `x` | `size(x)` |

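To make the `NTuple` analogy concrete, here is a minimal, self-contained sketch. `ToyTangents` is a hypothetical stand-in written for this illustration, not the package's actual `Tangents` definition. It only assumes what the diff itself shows: a varargs-style constructor such as `Tangents(dx)`, plus support for `map` and `only`.

```julia
# Hypothetical stand-in for `Tangents`: a thin wrapper around an NTuple.
struct ToyTangents{B,T}
    d::NTuple{B,T}
end

# Varargs constructor, mirroring calls such as `Tangents(dx)` in the diff.
ToyTangents(d::Vararg{T,B}) where {T,B} = ToyTangents{B,T}(d)

# Just enough of the collection interface for this example.
Base.length(::ToyTangents{B}) where {B} = B
Base.only(t::ToyTangents) = only(t.d)
Base.map(f, t::ToyTangents) = ToyTangents(map(f, t.d)...)

# The pushforward of the linear map x -> 3x simply scales each tangent.
tx = ToyTangents([1.0, 0.0], [0.0, 1.0])  # two tangents in the space of `x`
ty = map(dx -> 3 .* dx, tx)               # two tangents in the space of `y`

length(ty)                      # number of tangents in the batch
only(ToyTangents([1.0, 2.0]))   # unwrap a batch of size one
```

Under this reading, an operator called with a single tangent still returns a wrapper, which is why `only` appears in the backend code whenever a plain array or number is needed.
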
## Variants

@@ -73,8 +74,8 @@ This results in various operator signatures (the necessary arguments and their o

 | function signature | out-of-place operator | in-place operator |
 | :-------------------- | :--------------------------- | :------------------------------------ |
-| out-of-place function | `op(f, backend, x, [v])` | `op!(f, result, backend, x, [v])` |
-| in-place function | `op(f!, y, backend, x, [v])` | `op!(f!, y, result, backend, x, [v])` |
+| out-of-place function | `op(f, backend, x, [t])` | `op!(f, result, backend, x, [t])` |
+| in-place function | `op(f!, y, backend, x, [t])` | `op!(f!, y, result, backend, x, [t])` |

 !!! warning
     The positional arguments between `f`/`f!` and `backend` are always mutated.
@@ -103,15 +104,15 @@ In addition, the preparation syntax depends on the number of arguments accepted

 | function signature | preparation signature |
 | :-------------------- | :----------------------------------- |
-| out-of-place function | `prepare_op(f, backend, x, [v])` |
-| in-place function | `prepare_op(f!, y, backend, x, [v])` |
+| out-of-place function | `prepare_op(f, backend, x, [t])` |
+| in-place function | `prepare_op(f!, y, backend, x, [t])` |

 Preparation creates an object called `extras` which contains the necessary information to speed up an operator and its variants.
 The idea is that you prepare only once, which can be costly, but then call the operator several times while reusing the same `extras`.

 ```julia
-op(f, backend, x, [v]) # slow because it includes preparation
-op(f, extras, backend, x, [v]) # fast because it skips preparation
+op(f, backend, x, [t]) # slow because it includes preparation
+op(f, extras, backend, x, [t]) # fast because it skips preparation
 ```

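The prepare-once, call-many pattern can be sketched without any autodiff backend. Everything below is hypothetical and written only for illustration: `ToyPushforwardExtras`, `prepare_toy_pushforward` and `toy_pushforward` are not package functions, and a forward finite difference stands in for a real pushforward. Here "preparation" merely allocates a scratch buffer, whereas a real backend might record a tape or build a configuration object.

```julia
# Hypothetical `extras` object: a scratch buffer allocated once.
struct ToyPushforwardExtras{X}
    x_shifted::X
end

# "Preparation": the potentially costly, one-time setup.
prepare_toy_pushforward(f, x, dx) = ToyPushforwardExtras(similar(x))

# Slow path: prepares on every call.
function toy_pushforward(f, x, dx)
    extras = prepare_toy_pushforward(f, x, dx)
    return toy_pushforward(f, extras, x, dx)
end

# Fast path: reuses the buffer stored in `extras`.
# A forward finite difference stands in for a real pushforward (assumption).
function toy_pushforward(f, extras::ToyPushforwardExtras, x, dx; h=1e-6)
    extras.x_shifted .= x .+ h .* dx
    return (f(extras.x_shifted) .- f(x)) ./ h
end

f(x) = sum(abs2, x)
x = [1.0, 2.0]

extras = prepare_toy_pushforward(f, x, [1.0, 0.0])  # prepare once
jvp1 = toy_pushforward(f, extras, x, [1.0, 0.0])    # reuse: close to 2.0
jvp2 = toy_pushforward(f, extras, x, [0.0, 1.0])    # reuse with another tangent: close to 4.0
```

The second call reuses `extras` with a different tangent of the same type and shape, which is the situation the preparation rules are meant to cover.
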
!!! warning
@@ -124,9 +125,9 @@ Here are the general rules that we strive to implement:

 | | different point | same point |
 | :------------------------ | :--------------------------------------- | :--------------------------------------- |
-| the output `extras` of... | `prepare_op(f, b, x)` | `prepare_op_same_point(f, b, x, v)` |
-| can be used in... | `op(f, extras, b, other_x)` | `op(f, extras, b, x, other_v)` |
-| provided that... | `other_x` has same type and shape as `x` | `other_v` has same type and shape as `v` |
+| the output `extras` of... | `prepare_op(f, b, x)` | `prepare_op_same_point(f, b, x, t)` |
+| can be used in... | `op(f, extras, b, other_x)` | `op(f, extras, b, x, other_t)` |
+| provided that... | `other_x` has same type and shape as `x` | `other_t` has same type and shape as `t` |

 These rules hold for the majority of backends, but there are some exceptions: see [this page](@ref "Preparation") to know more.

@@ -1,16 +1,10 @@
-# forward mode unused for lack of implementations
-#=
-function ChainRulesCore.frule((_, dx), dw::DifferentiateWith, x)
-    @compat (; f, backend) = dw
-    y, dy = DI.value_and_pushforward(f, backend, x, dx)
-    return y, dy
-end
-=#
-
 function ChainRulesCore.rrule(dw::DifferentiateWith, x)
     @compat (; f, backend) = dw
     y = f(x)
-    extras_same = DI.prepare_pullback_same_point(f, backend, x, y)
-    pullbackfunc(dy) = (NoTangent(), DI.pullback(f, extras_same, backend, x, dy))
+    extras_same = DI.prepare_pullback_same_point(f, backend, x, DI.Tangents(y))
+    function pullbackfunc(dy)
+        tx = DI.pullback(f, extras_same, backend, x, DI.Tangents(dy))
+        return (NoTangent(), only(tx))
+    end
     return y, pullbackfunc
 end
@@ -20,19 +20,28 @@ function DI.value_and_pullback(
 )
     rc = ruleconfig(backend)
     y, pb = rrule_via_ad(rc, f, x)
-    return y, Tangents(last.(pb.(ty.d)))
+    tx = map(ty) do dy
+        last(pb(dy))
+    end
+    return y, tx
 end

 function DI.value_and_pullback(
     f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
 )
     @compat (; y, pb) = extras
-    return copy(y), Tangents(last.(pb.(ty.d)))
+    tx = map(ty) do dy
+        last(pb(dy))
+    end
+    return copy(y), tx
 end

 function DI.pullback(
     f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
 )
     @compat (; pb) = extras
-    return Tangents(last.(pb.(ty.d)))
+    tx = map(ty) do dy
+        last(pb(dy))
+    end
+    return tx
end
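
The `map(ty) do dy ... end` pattern introduced in this hunk applies one pullback closure to every cotangent in the batch and keeps only the last entry of each result. The sketch below reproduces that shape in plain Julia under two labeled assumptions: `toy_rrule` is a hand-written, hypothetical replacement for `rrule_via_ad`, and an ordinary tuple stands in for `Tangents`.

```julia
# Hand-written stand-in for the pair returned by `rrule_via_ad` (assumption):
# the closure maps a cotangent `dy` to (tangent for `f`, tangent for `x`).
function toy_rrule(x)
    y = sum(abs2, x)
    pb(dy) = (nothing, 2 .* x .* dy)  # `nothing` plays the role of `NoTangent()`
    return y, pb
end

x = [1.0, 2.0, 3.0]
y, pb = toy_rrule(x)

# A plain tuple stands in for `Tangents`: two scalar cotangents for `y`.
ty = (1.0, 10.0)

# One pullback call per cotangent, keeping only the tangent with respect to `x`.
tx = map(ty) do dy
    last(pb(dy))
end
```

Because `map` over the container returns a container of the same kind, the result no longer needs to be rewrapped by hand, which appears to be the simplification the hunk is after.
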
@@ -14,12 +14,12 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
 DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::Tangents) = NoPushforwardExtras()

 function DI.pushforward(f, ::NoPushforwardExtras, ::AutoDiffractor, x, tx::Tangents)
-    dys = map(tx.d) do dx
+    ty = map(tx) do dx
         # code copied from Diffractor.jl
         z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
         dy = z[TaylorTangentIndex(1)]
     end
-    return Tangents(dys)
+    return ty
 end

 function DI.value_and_pushforward(
@@ -16,7 +16,6 @@ using DifferentiationInterface:
     NoPullbackExtras,
     NoPushforwardExtras,
     Tangents,
-    SingleTangent,
     pick_batchsize
 using Enzyme:
     Active,
@@ -13,11 +13,11 @@ function DI.value_and_pushforward(
     x,
     tx::Tangents,
 )
-    dys = map(tx.d) do dx
-        DI.pushforward(f, extras, backend, x, dx)
+    ty = map(tx) do dx
+        only(DI.pushforward(f, extras, backend, x, Tangents(dx)))
     end
     y = f(x)
-    return y, Tangents(dys)
+    return y, ty
 end

 function DI.value_and_pushforward(
@@ -36,7 +36,7 @@ function DI.value_and_pushforward(
     else
         autodiff(forward_mode(backend), f_and_df, Duplicated, x_and_dx)
     end
-    return y, SingleTangent(new_dy)
+    return y, Tangents(new_dy)
 end

 function DI.pushforward(
@@ -55,7 +55,7 @@ function DI.pushforward(
     else
         only(autodiff(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx))
     end
-    return SingleTangent(new_dy)
+    return Tangents(new_dy)
 end

 function DI.value_and_pushforward!(
@@ -14,11 +14,11 @@ function DI.value_and_pushforward(
     x,
     tx::Tangents,
 )
-    dys = map(tx.d) do dx
-        DI.pushforward(f!, y, extras, backend, x, dx)
+    ty = map(tx) do dx
+        only(DI.pushforward(f!, y, extras, backend, x, Tangents(dx)))
     end
     f!(y, x)
-    return y, Tangents(dys)
+    return y, ty
 end

 function DI.value_and_pushforward(
@@ -40,5 +40,5 @@ function DI.value_and_pushforward(
     else
         autodiff(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
     end
-    return y, SingleTangent(dy_sametype)
+    return y, Tangents(dy_sametype)
 end
@@ -13,11 +13,11 @@ function DI.value_and_pullback(
     x,
     ty::Tangents,
 )
-    dxs = map(ty.d) do dy
-        only(DI.pullback(f, extras, backend, x, SingleTangent(dy)))
+    tx = map(ty) do dy
+        only(DI.pullback(f, extras, backend, x, Tangents(dy)))
     end
     y = f(x)
-    return y, Tangents(dxs)
+    return y, tx
 end

 ### Out-of-place
@@ -38,7 +38,7 @@ function DI.value_and_pullback(
             autodiff(ReverseWithPrimal, f_and_df, Active, Active(x))
         end
         new_dx = dy * only(der)
-        return y, SingleTangent(new_dx)
+        return y, Tangents(new_dx)
     else
         dy = only(ty)
         f_and_df = force_annotation(get_f_and_df(f, backend))
@@ -51,7 +51,7 @@ function DI.value_and_pullback(
         tape, y, new_dy = forw(f_and_df, Active(x))
         copyto!(new_dy, dy)
         new_dx = only(only(rev(f_and_df, Active(x), tape)))
-        return y, SingleTangent(new_dx)
+        return y, Tangents(new_dx)
     end
 end

@@ -76,10 +76,10 @@ function DI.value_and_pullback(
             # TODO: generalize beyond Arrays?
             dx_sametype .*= dy
         end
-        return y, SingleTangent(dx_sametype)
+        return y, Tangents(dx_sametype)
     else
         dx = make_zero(x)
-        return DI.value_and_pullback!(f, SingleTangent(dx), extras, backend, x, ty)
+        return DI.value_and_pullback!(f, Tangents(dx), extras, backend, x, ty)
     end
 end

@@ -201,7 +201,8 @@ function DI.value_and_gradient(
     backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
     x,
 )
-    return DI.value_and_pullback(f, NoPullbackExtras(), backend, x, true)
+    y, tx = DI.value_and_pullback(f, NoPullbackExtras(), backend, x, Tangents(true))
+    return y, only(tx)
 end

 function DI.value_and_gradient!(
@@ -211,7 +212,10 @@ function DI.value_and_gradient!(
     backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
     x,
 )
-    return DI.value_and_pullback!(f, grad, NoPullbackExtras(), backend, x, true)
+    y, _ = DI.value_and_pullback!(
+        f, Tangents(grad), NoPullbackExtras(), backend, x, Tangents(true)
+    )
+    return y, grad
 end

 ## Jacobian
@@ -14,11 +14,11 @@ function DI.value_and_pullback(
     x,
     ty::Tangents,
 )
-    dxs = map(ty.d) do dy
-        only(DI.pullback(f!, y, extras, backend, x, SingleTangent(dy)))
+    tx = map(ty) do dy
+        only(DI.pullback(f!, y, extras, backend, x, Tangents(dy)))
     end
     f!(y, x)
-    return y, Tangents(dxs)
+    return y, tx
 end

 function DI.value_and_pullback(
@@ -38,7 +38,7 @@ function DI.value_and_pullback(
     else
         only(autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, Active(x)))
     end
-    return y, SingleTangent(new_dx)
+    return y, Tangents(new_dx)
 end

 function DI.value_and_pullback(
@@ -60,5 +60,5 @@ function DI.value_and_pullback(
     else
         autodiff(reverse_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
     end
-    return y, SingleTangent(dx_sametype)
+    return y, Tangents(dx_sametype)
 end
@@ -3,7 +3,7 @@ function DI.prepare_hvp(f, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1})
 end

 function DI.hvp(f, ::NoHVPExtras, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1})
-    return SingleTangent(hvp(f, x, only(tx)))
+    return Tangents(hvp(f, x, only(tx)))
 end

 function DI.hvp!(
@@ -30,15 +30,15 @@ function DI.pushforward(
     x,
     tx::Tangents,
 )
-    dys = map(tx.d) do dx
+    ty = map(tx) do dx
         v_vec = vcat(myvec(x), myvec(dx))
         if extras.y_prototype isa Number
             return only(extras.jvp_exe(v_vec))
         else
             return reshape(extras.jvp_exe(v_vec), size(extras.y_prototype))
         end
     end
-    return Tangents(dys)
+    return ty
 end

 function DI.pushforward!(
@@ -108,15 +108,15 @@ function DI.pullback(
     x,
     ty::Tangents,
 )
-    dxs = map(ty.d) do dy
+    tx = map(ty) do dy
         v_vec = vcat(myvec(x), myvec(dy))
         if x isa Number
             return only(extras.vjp_exe(v_vec))
         else
             return reshape(extras.vjp_exe(v_vec), size(x))
         end
     end
-    return Tangents(dxs)
+    return tx
 end

 function DI.pullback!(
@@ -426,12 +426,12 @@ end
 function DI.hvp(
     f, extras::FastDifferentiationHVPExtras, ::AutoFastDifferentiation, x, tx::Tangents
 )
-    dgs = map(tx.d) do dx
+    tg = map(tx) do dx
         v_vec = vcat(vec(x), vec(dx))
         dg_vec = extras.hvp_exe(v_vec)
         return reshape(dg_vec, size(x))
     end
-    return Tangents(dgs)
+    return tg
 end

 function DI.hvp!(