Skip to content

Commit

Permalink
Make generated functions safe for extension (#426)
Browse files Browse the repository at this point in the history
* Tidy up comments in code

* Remove redundant method of backing_type

* Formatting

* Make tangent type not generated and tidy up a bit

* Typo

* stable_ntuple

* Remove more generated functions

* make safe __make_ref generated function

* Formatting

* temporarily revert change

* Add extra perf tests to fwds_rvs data

* Fix perf problem

* Fix battery tests

* Fix up perf

* Start making fdata and rdata functions safe

* Remove redundant generated

* Make tuple fdata rdata safe

* Fix all on 1.10

* Formatting

* Fix performance

* Stop using generated function

* Revert "Fix performance"

This reverts commit 824acd0.

* Revert "Revert "Fix performance""

This reverts commit 2997c78.

* Revert "Stop using generated function"

This reverts commit 9a427e7.

* Unrevert change to generated function

* Make function not generated

* Fix up more generated functions

* Make safe zero_rdata_from_type

* Catch edge case in unit tests

* Fix formatting

* Test stable_all properly

* Tell inference what to do

* Remove offending code

* Remove bootstrap calls

* Formatting

* Improve comments

* Fix array test

* Formatting

* Add test to prime inference

* Also add sinh to test suite

* Generated function to enforce specialisation

* Try disabling debug mode on bulk of tests

* Remove entirely redundant stable_ntuple function

* Test overall performance with uninferred tangent type

* Revert "Test overall performance with uninferred tangent type"

This reverts commit 931e27f.

* Remove redundant comment

* Remove compiler-level assertions

* Fix typo

* Typo

* Just assume effects

* Revert "Just assume effects"

This reverts commit 96ff8b9.

* Assume some effects

* Formatting

* Bump patch version

* Remove redundant function

* Tidy up a bit

* Tidy up further

* Remove unused functionality

* Tidy tidy tidy

* Enforce effects in unit tests

* Formatting and effects macro

* Effects for CuArray

* More effects

* Assume more effects

* Tidy up more
  • Loading branch information
willtebbutt authored Dec 24, 2024
1 parent c6d33fa commit 658d566
Show file tree
Hide file tree
Showing 19 changed files with 322 additions and 271 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.68"
version = "0.4.69"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion ext/MooncakeCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Mooncake.TestUtils: populate_address_map!, AddressMap, __increment_should

# Tell Mooncake.jl how to handle CuArrays.

tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P
Mooncake.@tt_effects tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P
zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x)
function randn_tangent(rng::AbstractRNG, x::CuArray{Float32})
return cu(randn(rng, Float32, size(x)...))
Expand Down
162 changes: 89 additions & 73 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,13 @@ end
T == NoTangent && return NoFData

# This method can only handle struct types. Tell user to implement their own method.
isprimitivetype(T) &&
throw(error("$T is a primitive type. Implement a method of `fdata_type` for it."))
if isprimitivetype(T)
msg = "$T is a primitive type. Implement a method of `fdata_type` for it."
return :(error($msg))
end

# If the type is a Union, then take the union type of its arguments.
T isa Union && return Union{fdata_type(T.a),fdata_type(T.b)}
T isa Union && return :(Union{fdata_type($(T.a)),fdata_type($(T.b))})

# If `P` is a mutable type, then its forwards data is its tangent.
ismutabletype(T) && return T
Expand All @@ -179,33 +181,37 @@ end
# The same goes for if the type has any undetermined type parameters.
(isabstracttype(T) || !isconcretetype(T)) && return Any

# We should now have a `Tangent`. If not, we do not know what to do, so error.
T <: Tangent || return :(error("Unhandled type $T"))

# If `P` is an immutable type, then some of its fields may not need to be propagated
# on the forwards-pass.
if T <: Tangent
Tfields = fields_type(T)
fwds_data_field_types = map(1:fieldcount(Tfields)) do n
return fdata_type(fieldtype(Tfields, n))
end
all(==(NoFData), fwds_data_field_types) && return NoFData
return FData{NamedTuple{fieldnames(Tfields),Tuple{fwds_data_field_types...}}}
field_names = fieldnames(fields_type(T))
Tfields = fieldtypes(fields_type(T))
fdata_type_exprs = map(n -> :(fdata_type($(Tfields[n]))), 1:length(Tfields))
return quote
fwds_data_field_types = $(Expr(:call, :tuple, fdata_type_exprs...))
stable_all(tuple_map(==(NoFData), fwds_data_field_types)) && return NoFData
return FData{NamedTuple{$field_names,Tuple{fwds_data_field_types...}}}
end

return :(error("Unhandled type $T"))
end

fdata_type(::Type{T}) where {T<:Ptr} = T

@generated function fdata_type(::Type{P}) where {P<:Tuple}
isa(P, Union) && return Union{fdata_type(P.a),fdata_type(P.b)}
isa(P, Union) && return :(Union{fdata_type($(P.a)),fdata_type($(P.b))})
isempty(P.parameters) && return NoFData
isa(last(P.parameters), Core.TypeofVararg) && return Any
nofdata_tt = Tuple{Vararg{NoFData,length(P.parameters)}}
fdata_tt = Tuple{map(fdata_type, fieldtypes(P))...}
fdata_tt <: nofdata_tt && return NoFData
return nofdata_tt <: fdata_tt ? Union{NoFData,fdata_tt} : fdata_tt
fdata_type_exprs = map(_P -> Expr(:call, :fdata_type, _P), P.parameters)
return quote
fdata_tt = $(Expr(:curly, Tuple, fdata_type_exprs...))
fdata_tt <: $nofdata_tt && return NoFData
return $nofdata_tt <: fdata_tt ? Union{NoFData,fdata_tt} : fdata_tt
end
end

@generated function fdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple}
function fdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple}
if fdata_type(T) == NoFData
return NoFData
elseif isconcretetype(fdata_type(T))
Expand All @@ -224,28 +230,28 @@ Returns the type of to the nth field of the fdata type associated to `P`. Will b
function fdata_field_type(::Type{P}, n::Int) where {P}
Tf = tangent_type(fieldtype(P, n))
f = ismutabletype(P) ? Tf : fdata_type(Tf)
return is_always_initialised(P, n) ? f : _wrap_type(f)
return is_always_initialised(P, n) ? f : PossiblyUninitTangent{f}
end

"""
fdata(t)::fdata_type(typeof(t))
Extract the forwards data from tangent `t`.
"""
@generated function fdata(t::T) where {T}
function fdata(t::T) where {T}

# Ask for the forwards-data type. Useful catch-all error checking for unexpected types.
F = fdata_type(T)

# Catch-all for anything with no forwards-data.
F == NoFData && return :(NoFData())
F == NoFData && return NoFData()

# Catch-all for anything where we return the whole object (mutable structs, arrays...).
F == T && return :(t)
F == T && return t

# T must be a `Tangent` by now. If it's not, something has gone wrong.
!(T <: Tangent) && return :(error("Unhandled type $T"))
return :($F(fdata(t.fields)))
T <: Tangent || error("Unhandled type $T")
return F(fdata(t.fields))
end

function fdata(t::T) where {T<:PossiblyUninitTangent}
Expand Down Expand Up @@ -415,11 +421,13 @@ end
T == NoTangent && return NoRData

# This method can only handle struct types. Tell user to implement their own method.
isprimitivetype(T) &&
throw(error("$T is a primitive type. Implement a method of `rdata_type` for it."))
if isprimitivetype(T)
msg = "$T is a primitive type. Implement a method of `rdata_type` for it."
return :(error(msg))
end

# If the type is a Union, then take the union type of its arguments.
T isa Union && return Union{rdata_type(T.a),rdata_type(T.b)}
T isa Union && return :(Union{rdata_type($(T.a)),rdata_type($(T.b))})

# If `P` is a mutable type, then all tangent info is propagated on the forwards-pass.
ismutabletype(T) && return NoRData
Expand All @@ -428,26 +436,31 @@ end
# The same goes for if the type has any undetermined type parameters.
(isabstracttype(T) || !isconcretetype(T)) && return Any

# If `T` is an immutable type, then some of its fields may not have been propagated on
# the forwards-pass.
if T <: Tangent
Tfs = fields_type(T)
rvs_types = map(n -> rdata_type(fieldtype(Tfs, n)), 1:fieldcount(Tfs))
all(==(NoRData), rvs_types) && return NoRData
return RData{NamedTuple{fieldnames(Tfs),Tuple{rvs_types...}}}
# If `T` is an immutable type, then some of its fields may not need to be propagated
# on the forwards-pass.
field_names = fieldnames(fields_type(T))
Tfields = fieldtypes(fields_type(T))
rdata_type_exprs = map(n -> :(rdata_type($(Tfields[n]))), 1:length(Tfields))
return quote
rvs_data_field_types = $(Expr(:call, :tuple, rdata_type_exprs...))
stable_all(tuple_map(==(NoRData), rvs_data_field_types)) && return NoRData
return RData{NamedTuple{$field_names,Tuple{rvs_data_field_types...}}}
end
end

rdata_type(::Type{<:Ptr}) = NoRData

@generated function rdata_type(::Type{P}) where {P<:Tuple}
isa(P, Union) && return Union{rdata_type(P.a),rdata_type(P.b)}
isa(P, Union) && return :(Union{rdata_type($(P.a)),rdata_type($(P.b))})
isempty(P.parameters) && return NoRData
isa(last(P.parameters), Core.TypeofVararg) && return Any
nordata_tt = Tuple{Vararg{NoRData,length(P.parameters)}}
rdata_tt = Tuple{map(rdata_type, fieldtypes(P))...}
rdata_tt <: nordata_tt && return NoRData
return nordata_tt <: rdata_tt ? Union{NoRData,rdata_tt} : rdata_tt
rdata_type_exprs = map(_P -> Expr(:call, :rdata_type, _P), P.parameters)
return quote
rdata_tt = $(Expr(:curly, Tuple, rdata_type_exprs...))
rdata_tt <: $nordata_tt && return NoRData
return $nordata_tt <: rdata_tt ? Union{NoRData,rdata_tt} : rdata_tt
end
end

function rdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple}
Expand All @@ -468,7 +481,7 @@ Returns the type of to the nth field of the rdata type associated to `P`. Will b
"""
function rdata_field_type(::Type{P}, n::Int) where {P}
r = rdata_type(tangent_type(fieldtype(P, n)))
return is_always_initialised(P, n) ? r : _wrap_type(r)
return is_always_initialised(P, n) ? r : PossiblyUninitTangent{r}
end

"""
Expand All @@ -480,20 +493,20 @@ Extract the reverse data from tangent `t`.
See extended help section of [fdata_type](@ref).
"""
@generated function rdata(t::T) where {T}
function rdata(t::T) where {T}

# Ask for the reverse-data type. Useful catch-all error checking for unexpected types.
R = rdata_type(T)

# Catch-all for anything with no reverse-data.
R == NoRData && return :(NoRData())
R == NoRData && return NoRData()

# Catch-all for anything where we return the whole object (Float64, isbits structs, ...)
R == T && return :(t)
R == T && return t

# T must be a `Tangent` by now. If it's not, something has gone wrong.
!(T <: Tangent) && return :(error("Unhandled type $T"))
return :($(rdata_type(T))(rdata(t.fields)))
T <: Tangent || error("Unhandled type $T")
return R(rdata(t.fields))
end

function rdata(t::T) where {T<:PossiblyUninitTangent}
Expand Down Expand Up @@ -604,41 +617,48 @@ constitute a correctness problem, but can be detrimental to performance, so shou
with.
"""
@generated function zero_rdata_from_type(::Type{P}) where {P}
R = rdata_type(tangent_type(P))

# If we know we can't produce a tangent, say so.
can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType()

# Simple case.
R == NoRData && return NoRData()

# If `P` is a struct type, attempt to derive the zero rdata for it. We cannot derive
# the zero rdata if it is not possible to derive the zero rdata for any of its fields.
if isstructtype(P)
# Prepare expressions for manually-unrolled loop to construct zero rdata elements.
if P isa DataType
names = fieldnames(P)
types = fieldtypes(P)
wrapped_field_zeros = tuple_map(ntuple(identity, length(names))) do n
wrapped_field_zeros = map(enumerate(tangent_field_types(P))) do (n, tt)
fzero = :(zero_rdata_from_type($(types[n])))
if tangent_field_type(P, n) <: PossiblyUninitTangent
Q = rdata_type(tangent_type(fieldtype(P, n)))
return :(_wrap_field($Q, $fzero))
if tt <: PossiblyUninitTangent
Q = :(rdata_type(tangent_type($(fieldtype(P, n)))))
return :(PossiblyUninitTangent{$Q}($fzero))
else
return fzero
end
end
wrapped_field_zeros_tuple = Expr(:call, :tuple, wrapped_field_zeros...)
return :($R(NamedTuple{$names}($wrapped_field_zeros_tuple)))
wrapped_expr = :(R(NamedTuple{$names}($wrapped_field_zeros_tuple)))
else
wrapped_expr = nothing
end

# Fallback -- we've not been able to figure out how to produce an instance of zero rdata
# so report that it cannot be done.
return throw(error("Unhandled type $P"))
return quote

# If we know we can't produce a tangent, say so.
can_produce_zero_rdata_from_type($P) || return CannotProduceZeroRDataFromType()

# Simple case.
R = rdata_type(tangent_type($P))
R == NoRData && return NoRData()

$(isstructtype(P)) || error("Unhandled type $P")
return $wrapped_expr
end
end

@generated function zero_rdata_from_type(::Type{P}) where {P<:Tuple}
can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType()
rdata_type(tangent_type(P)) == NoRData && return NoRData()
return tuple_map(zero_rdata_from_type, fieldtypes(P))
has_fields = P isa DataType && Base.datatype_fieldcount(P) !== nothing
zero_exprs = has_fields ? map(_P -> :(zero_rdata_from_type($_P)), fieldtypes(P)) : []
return quote
can_produce_zero_rdata_from_type($P) || return CannotProduceZeroRDataFromType()
rdata_type(tangent_type($P)) == NoRData && return NoRData()
return $(Expr(:call, :tuple, zero_exprs...))
end
end

function zero_rdata_from_type(::Type{P}) where {P<:NamedTuple}
Expand Down Expand Up @@ -785,15 +805,14 @@ tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Array} = F

# Tuples
@generated function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple}
return Tuple{tuple_map(tangent_type, Tuple(F.parameters), Tuple(R.parameters))...}
tt_exprs = map((f, r) -> :(tangent_type($f, $r)), fieldtypes(F), fieldtypes(R))
return Expr(:curly, :Tuple, tt_exprs...)
end
function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:Tuple}
F_tuple = Tuple{tuple_fill(NoFData, Val(length(R.parameters)))...}
return tangent_type(F_tuple, R)
return tangent_type(Tuple{tuple_fill(NoFData, Val(length(R.parameters)))...}, R)
end
function tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Tuple}
R_tuple = Tuple{tuple_fill(NoRData, Val(length(F.parameters)))...}
return tangent_type(F, R_tuple)
return tangent_type(F, Tuple{tuple_fill(NoRData, Val(length(F.parameters)))...})
end

# NamedTuples
Expand Down Expand Up @@ -904,10 +923,7 @@ Equivalent to `tangent(fdata, rdata(zero_tangent(primal)))`.
zero_tangent(p, ::NoFData) = zero_tangent(p)

function zero_tangent(p::P, f::F) where {P,F}
T = tangent_type(P)
T == F && return f
r = rdata(zero_tangent(p))
return tangent(f, r)
return tangent_type(P) == F ? f : tangent(f, rdata(zero_tangent(p)))
end

zero_tangent(p::Tuple, f::Union{Tuple,NamedTuple}) = tuple_map(zero_tangent, p, f)
Loading

2 comments on commit 658d566

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/121932

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.69 -m "<description of version>" 658d566dc1525157c65e01df48dc7bef5a53d810
git push origin v0.4.69

Please sign in to comment.