RFC: strip most types from gradient output #1362

Draft. Wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions Project.toml
@@ -17,10 +17,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
EnzymeSparseArraysExt = "SparseArrays"
EnzymeSpecialFunctionsExt = "SpecialFunctions"
EnzymeStaticArraysExt = "StaticArrays"

[compat]
CEnum = "0.4, 0.5"
9 changes: 9 additions & 0 deletions ext/EnzymeSparseArraysExt.jl
@@ -0,0 +1,9 @@
module EnzymeSparseArraysExt

using SparseArrays
using Enzyme

Enzyme.strip_types(x::SparseVector{<:Enzyme.GoodNum}) = x
Enzyme.strip_types(x::SparseMatrixCSC{<:Enzyme.GoodNum}) = x

end
9 changes: 9 additions & 0 deletions ext/EnzymeStaticArraysExt.jl
@@ -0,0 +1,9 @@
module EnzymeStaticArraysExt

using StaticArrays
using Enzyme

Enzyme.strip_types(x::StaticArrays.SArray{<:Any, <:Enzyme.GoodNum}) = x
Enzyme.strip_types(x::StaticArrays.MArray{<:Any, <:Enzyme.GoodNum}) = x

end
103 changes: 99 additions & 4 deletions src/Enzyme.jl
@@ -904,7 +904,8 @@ This will allocate and return new array `make_zero(x)` with the gradient result.
Besides arrays, for struct `x` it returns another instance of the same type,
whose fields contain the components of the gradient.
In the result, `grad.a` contains `∂f/∂x.a` for any differential `x.a`,
-while `grad.c == x.c` for other types.
+while `grad.c == nothing` for other types.
+The result is post-processed to remove types from `make_zero(x)`.

Examples:

@@ -925,18 +926,112 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str"))

# output

-(a = 3.0, b = [2.0], c = "str")
+(a = 3.0, b = [2.0], c = nothing)
```
"""
@inline function gradient(::ReverseMode, f::F, x::X) where {F, X}
if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState
dx = Ref(make_zero(x))
autodiff(Reverse, f∘only, Active, Duplicated(Ref(x), dx))
-return only(dx)
+return strip_types(only(dx))
else
dx = make_zero(x)
autodiff(Reverse, f, Active, Duplicated(x, dx))
-return dx
+return strip_types(dx)
end
end


"""
strip_types(x)

Aims to replace all structs with plain NamedTuples,
except for types which are their own natural cotangent representations,
such as Complex, Diagonal, Transpose.

Maps non-diff elements to nothing (whereas `make_zero` preserves their values).

Ignores mutability (whereas `make_zero` uses a cache to preserve identifications).

!!! warning
This function is used by `gradient` to clean up its return, but it is not public at present.
Extending it in your package is not recommended.

# Examples
```
julia> [1,2,3.]' |> strip_types
1×3 adjoint(::Vector{Float64}) with eltype Float64:
1.0 2.0 3.0

julia> [1,2,3]' |> strip_types
(parent = nothing,)

julia> Symmetric([1 2; 3 4.]) |> strip_types
(data = [1.0 2.0; 3.0 4.0], uplo = nothing)

julia> LinRange(1,2,3) |> strip_types
(start = 1.0, stop = 2.0, len = nothing, lendiv = nothing)

julia> Ref([[1,2,3.]', (4:6)', (7.0, 16//2)]) |> strip_types
(x = Any[[1.0 2.0 3.0], (parent = (start = nothing, stop = nothing),), (7.0, (num = nothing, den = nothing))],)
```
"""
function strip_types end

export strip_types

const GoodNum = Union{AbstractFloat, Complex{<:AbstractFloat}}
strip_types(x::GoodNum) = x
strip_types(x::Array{<:GoodNum}) = x
# strip_types(x::DenseArray{<:GoodNum}) = x
# strip_types(x::Base.TwicePrecision{<:GoodNum}) = x

# Non-differentiable types
strip_types(x::Integer) = nothing # zero(x)
strip_types(x::Union{Symbol, Char, AbstractString, Nothing}) = nothing # x
strip_types(x::Array{<:Integer}) = nothing
strip_types(x::Array{<:Union{Symbol, Char, AbstractString, Nothing}}) = nothing

# Containers to recurse into
strip_types(x::Union{Tuple, NamedTuple, Array}) = map(strip_types, x) # need to worry about undef?
Member

This won't work for a recursive type, for example (presumably it will loop forever instead?)

Contributor Author

Indeed. I guess the implementation is a sketch, and a real one would need some kind of IdDict cache for this purpose too.
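
To make the `IdDict` idea concrete, here is a minimal sketch of how such a cache might guard the recursion. It is not taken from the PR: `strip_cached` and `Node` are hypothetical names, the `Array` method ignores undefined elements and self-referential arrays, and back-edges of a cycle are simply mapped to `nothing`, which is one possible choice among several.

```julia
# Hypothetical sketch: `strip_cached` mimics `strip_types` but threads an IdDict through.
const GoodNum = Union{AbstractFloat, Complex{<:AbstractFloat}}

strip_cached(x) = strip_cached(x, IdDict{Any,Any}())

strip_cached(x::GoodNum, seen::IdDict) = x
strip_cached(x::Array{<:GoodNum}, seen::IdDict) = x
strip_cached(x::Union{Integer, Symbol, Char, AbstractString, Nothing}, seen::IdDict) = nothing
strip_cached(x::Union{Tuple, NamedTuple}, seen::IdDict) = map(y -> strip_cached(y, seen), x)
strip_cached(x::Array, seen::IdDict) = map(y -> strip_cached(y, seen), x)

function strip_cached(x::T, seen::IdDict) where T
    if Base.issingletontype(T)
        return nothing
    elseif !isstructtype(T) || fieldcount(T) == 0
        return x
    elseif ismutabletype(T)
        # Only mutable objects can take part in a cycle, so only they are cached.
        haskey(seen, x) && return seen[x]
        seen[x] = nothing  # placeholder seen by any back-edge while x is in progress
    end
    names = fieldnames(T)
    vals = map(n -> isdefined(x, n) ? strip_cached(getfield(x, n), seen) : nothing, names)
    out = NamedTuple{names}(vals)
    ismutabletype(T) && (seen[x] = out)  # later visits reuse the finished result
    return out
end

# A two-element cycle that would recurse forever without the cache.
mutable struct Node
    value::Float64
    next::Union{Node, Nothing}
end

a = Node(1.0, nothing)
b = Node(2.0, a)
a.next = b

g = strip_cached(a)  # (value = 1.0, next = (value = 2.0, next = nothing))
```

Unlike the cache in `make_zero`, this sketch does not preserve the cyclic structure in its output, since NamedTuples are immutable and cannot refer back to themselves.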


function strip_types(x::T) where T
# Doing this would trim many branches of nothings, à la Zygote; confusing to handle?
# Compiler.guaranteed_const_nongen(T, nothing) && return nothing
if Base.issingletontype(T)
T <: Function || @info "non-Function singleton" T
return nothing # maybe this also trims some branches
elseif !Base.isstructtype(T)
@info "not isstructtype" T
return x
elseif fieldcount(T) == 0
@info "zero fields" T
return x
else
return strip_namedtuple(x)
end
end
function strip_namedtuple(x::T) where T
names = fieldnames(T)
# may need to think about unassigned fields?
tup = map(n -> strip_types(getfield(x, n)), names)
return NamedTuple{names}(tup)
end

for wrap in (:SubArray, :Slices, :PermutedDimsArray, :ReshapedArray)
@eval function strip_types(x::Base.$wrap)
nt = strip_namedtuple(x)
x.parent === nt.parent ? x : nt
end
end

# LinearAlgebra, only some types are kosher. Must recurse to allow e.g. transpose(@SVector [1,2.])
for (wrap, field) in [(:Diagonal, :diag),
(:AdjOrTrans, :parent),
(:UpperOrLowerTriangular, :data)]
@eval function strip_types(x::LinearAlgebra.$wrap)
nt = strip_namedtuple(x)
x.$field === nt.$field ? x : nt
end
end
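
The rule used in the two loops above (return the original wrapper when its wrapped field comes back unchanged, otherwise fall back to a NamedTuple) can be illustrated in isolation. The sketch below is a simplified stand-in, not the PR's code: `keep_or_strip` is a hypothetical name, and it handles only `Diagonal` over plain arrays.

```julia
using LinearAlgebra

# Hypothetical stand-in for `strip_types`, reduced to the cases needed here.
keep_or_strip(x::AbstractFloat) = x
keep_or_strip(x::Array{<:AbstractFloat}) = x
keep_or_strip(x::Array{<:Integer}) = nothing

function keep_or_strip(x::Diagonal)
    d = keep_or_strip(x.diag)
    # Getting the very same object back means the contents needed no stripping,
    # so the wrapper can be returned as its own cotangent representation.
    return d === x.diag ? x : (diag = d,)
end

kept = keep_or_strip(Diagonal([1.0, 2.0]))   # still a Diagonal
stripped = keep_or_strip(Diagonal([1, 2]))   # (diag = nothing,)
```

The `===` comparison is what lets the same method cover both cases without a separate type constraint on the wrapper's element type.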
