Skip to content

Commit

Permalink
Some Performance Fixes (#402)
Browse files Browse the repository at this point in the history
* Check more small-union Tuple test cases

* Improve small union handling for Tuples

* Change s2s implementation

* Avoid type instability

* Add test

* Bump patch

* Fix up ret type

* Fix test case

* Formatting

* Formatting

* Fix up array tests

* Stabilise test case

* Formatting

* argmin allocates on 1.10

* Fix up testing

* Improve codual implementation

* Remove redundant code

* Remove redundant test

* Fix up formatting
  • Loading branch information
willtebbutt authored Dec 1, 2024
1 parent e70c730 commit 1a0ed04
Show file tree
Hide file tree
Showing 12 changed files with 684 additions and 372 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.52"
version = "0.4.53"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion ext/MooncakeAllocCheckExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ module MooncakeAllocCheckExt
using AllocCheck, Mooncake
import Mooncake.TestUtils: check_allocs, Shim

@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any,N}}) where {F,N} = f(x...)
@check_allocs check_allocs(::Shim, f::F, x...) where {F} = f(x...)

end
81 changes: 46 additions & 35 deletions src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,46 +31,48 @@ Equivalent to `CoDual(x, uninit_tangent(x))`.
"""
uninit_codual(x) = CoDual(x, uninit_tangent(x))

function _codual_internal(::Type{P}, f::F, extractor::E) where {P,F,E}
P == Union{} && return CoDual
P == DataType && return CoDual
P isa Union && return Union{f(P.a),f(P.b)}

if P <: Tuple && !all(isconcretetype, (P.parameters...,))
field_types = (P.parameters...,)
union_fields = _findall(Base.Fix2(isa, Union), 1, field_types)
if length(union_fields) == 1 &&
all(p -> p isa Union || isconcretetype(p), field_types)
P_split = split_union_tuple_type(field_types)
return Union{f(P_split.a),f(P_split.b)}
end
end

P <: UnionAll && return CoDual # P is abstract, so we don't know its tangent type.
return isconcretetype(P) ? CoDual{P,extractor(P)} : CoDual
end

"""
codual_type(P::Type)
The type of the `CoDual` which contains instances of `P` and associated tangents.
"""
function codual_type(::Type{P}) where {P}
P == DataType && return CoDual
P isa Union && return Union{codual_type(P.a),codual_type(P.b)}
P <: UnionAll && return CoDual # P is abstract, so we don't know its tangent type.
return isconcretetype(P) ? CoDual{P,tangent_type(P)} : CoDual
end
codual_type(::Type{P}) where {P} = _codual_internal(P, codual_type, tangent_type)

function codual_type(p::Type{Type{P}}) where {P}
return @isdefined(P) ? CoDual{Type{P},NoTangent} : CoDual{_typeof(p),NoTangent}
end

struct NoPullback{R<:Tuple}
r::R
end

_copy(x::P) where {P<:NoPullback} = P(_copy(x.r))

"""
NoPullback(args::CoDual...)
Construct a `NoPullback` from the arguments passed to an `rrule!!`. For each argument,
extracts the primal value, and constructs a `LazyZeroRData`. These are stored in a
`NoPullback` which, in the reverse-pass of AD, instantiates these `LazyZeroRData`s and
returns them in order to perform the reverse-pass of AD.
fcodual_type(P::Type)
The advantage of this approach is that if it is possible to construct the zero rdata element
for each of the arguments lazily, the `NoPullback` generated will be a singleton type. This
means that AD can avoid generating a stack to store this pullback, which can result in
significant performance improvements.
The type of the `CoDual` which contains instances of `P` and its fdata.
"""
function NoPullback(args::Vararg{CoDual,N}) where {N}
return NoPullback(tuple_map(lazy_zero_rdata primal, args))
function fcodual_type(::Type{P}) where {P}
return _codual_internal(P, fcodual_type, P -> fdata_type(tangent_type(P)))
end

@inline (pb::NoPullback)(_) = tuple_map(instantiate, pb.r)
function fcodual_type(p::Type{Type{P}}) where {P}
return @isdefined(P) ? CoDual{Type{P},NoFData} : CoDual{_typeof(p),NoFData}
end

to_fwds(x::CoDual) = CoDual(primal(x), fdata(tangent(x)))

Expand All @@ -86,18 +88,27 @@ See implementation for details, as this function is subject to change.
"""
@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

struct NoPullback{R<:Tuple}
r::R
end

_copy(x::P) where {P<:NoPullback} = P(_copy(x.r))

"""
fcodual_type(P::Type)
NoPullback(args::CoDual...)
The type of the `CoDual` which contains instances of `P` and its fdata.
Construct a `NoPullback` from the arguments passed to an `rrule!!`. For each argument,
extracts the primal value, and constructs a `LazyZeroRData`. These are stored in a
`NoPullback` which, in the reverse-pass of AD, instantiates these `LazyZeroRData`s and
returns them in order to perform the reverse-pass of AD.
The advantage of this approach is that if it is possible to construct the zero rdata element
for each of the arguments lazily, the `NoPullback` generated will be a singleton type. This
means that AD can avoid generating a stack to store this pullback, which can result in
significant performance improvements.
"""
function fcodual_type(::Type{P}) where {P}
P == DataType && return CoDual
P isa Union && return Union{fcodual_type(P.a),fcodual_type(P.b)}
P <: UnionAll && return CoDual
return isconcretetype(P) ? CoDual{P,fdata_type(tangent_type(P))} : CoDual
function NoPullback(args::Vararg{CoDual,N}) where {N}
return NoPullback(tuple_map(lazy_zero_rdata primal, args))
end

function fcodual_type(p::Type{Type{P}}) where {P}
return @isdefined(P) ? CoDual{Type{P},NoFData} : CoDual{_typeof(p),NoFData}
end
@inline (pb::NoPullback)(_) = tuple_map(instantiate, pb.r)
57 changes: 56 additions & 1 deletion src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ provides a convenient way to do this.
function normalise!(ir::IRCode, spnames::Vector{Symbol})
sp_map = Dict{Symbol,CC.VarState}(zip(spnames, ir.sptypes))
ir = interpolate_boundschecks!(ir)
ir = CC.compact!(ir)
ir = fix_up_invoke_inference!(ir)
for (n, inst) in enumerate(stmt(ir.stmts))
inst = foreigncall_to_call(inst, sp_map)
inst = new_to_call(inst)
Expand Down Expand Up @@ -65,6 +65,61 @@ function _interpolate_boundschecks!(statements::Vector{Any})
return nothing
end

"""
fix_up_invoke_inference!(ir::IRCode)
# The Problem
Consider the following:
```julia
@noinline function bar!(x)
x .*= 2
end
function foo!(x)
bar!(x)
return nothing
end
```
In this case, the IR associated to `Tuple{typeof(foo), Vector{Float64}}` will be something
along the lines of
```julia
julia> Base.code_ircode_by_type(Tuple{typeof(foo), Vector{Float64}})
1-element Vector{Any}:
2 1 ─ invoke Main.bar!(_2::Vector{Float64})::Any
3 └── return Main.nothing
=> Nothing
```
Observe that the type inferred for the first line is `Any`. Inference is at liberty to do
this without any risk of performance problems because the first line is not used anywhere
else in the function. Had this line been used elsewhere in the function, inference would
have inferred its type to be `Vector{Float64}`.
This causes performance problems for Mooncake, because it uses the return type to do
various things, including allocating storage for quantities required on the reverse-pass.
Consequently, inference infering `Any` rather than `Vector{Float64}` causes type
instabilities in the code that Mooncake generates, which can have catastrophic conseqeuences
for performance.
# The Solution
`:invoke` expressions contain the `Core.MethodInstance` associated to them, which contains
a `Core.CodeCache`, which contains the return type of the `:invoke`. This function looks
for `:invoke` statements whose return type is inferred to be `Any` in `ir`, and modifies it
to be the return type given by the code cache.
"""
function fix_up_invoke_inference!(ir::IRCode)::IRCode
stmts = ir.stmts
for n in 1:length(stmts)
if Meta.isexpr(stmt(stmts)[n], :invoke) && _type(stmts.type[n]) == Any
mi = stmt(stmts)[n].args[1]::Core.MethodInstance
R = isdefined(mi, :cache) ? mi.cache.rettype : CC.return_type(mi.specTypes)
stmts.type[n] = R
end
end
return ir
end

"""
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})
Expand Down
22 changes: 5 additions & 17 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -944,22 +944,10 @@ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where
pb_type = Pullback{sig,Base.RefValue{pb_oc_type},Val{isva},nvargs(isva, sig)}
nargs = Val{length(ir.argtypes)}

if isconcretetype(Treturn)
Tderived_rule = DerivedRule{
sig,RuleMC{arg_fwds_types,fcodual_type(Treturn)},pb_type,Val{isva},nargs
}
return debug_mode ? DebugRRule{Tderived_rule} : Tderived_rule
else
if debug_mode
return DebugRRule{
DerivedRule{sig,RuleMC{arg_fwds_types,P},pb_type,Val{isva},nargs}
} where {P<:fcodual_type(Treturn)}
else
return DerivedRule{
sig,RuleMC{arg_fwds_types,P},pb_type,Val{isva},nargs
} where {P<:fcodual_type(Treturn)}
end
end
Tderived_rule = DerivedRule{
sig,RuleMC{arg_fwds_types,fcodual_type(Treturn)},pb_type,Val{isva},nargs
}
return debug_mode ? DebugRRule{Tderived_rule} : Tderived_rule
end

nvargs(isva, sig) = Val{isva ? length(sig.parameters[end].parameters) : 0}
Expand Down Expand Up @@ -1738,7 +1726,7 @@ end

_rtype(::Type{<:DebugRRule}) = Tuple{CoDual,DebugPullback}
_rtype(T::Type{<:MistyClosure}) = _rtype(fieldtype(T, :oc))
_rtype(::Type{<:OpaqueClosure{<:Any,<:R}}) where {R} = (@isdefined R) ? R : CoDual
_rtype(::Type{<:OpaqueClosure{<:Any,R}}) where {R} = R
_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)),fieldtype(T, :pb)}

@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule}
Expand Down
31 changes: 31 additions & 0 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,30 @@ tangent_type(::Type{Method}) = NoTangent

tangent_type(::Type{<:Enum}) = NoTangent

# Inferable version of `findall` for `Tuple`s.
_findall(::Any, ::Int, ::Tuple{}) = ()
function _findall(cond, ind::Int, x::Tuple)
tail = _findall(cond, ind + 1, x[2:end])
return cond(x[1]) ? (ind, tail...) : tail
end

function split_union_tuple_type(tangent_types)

# Create first split.
ta_types = map(tangent_types) do T
return T isa Union ? T.a : T
end
ta = Tuple{ta_types...}

# Create second split.
tb_types = map(tangent_types) do T
return T isa Union ? T.b : T
end
tb = Tuple{tb_types...}

return Union{ta,tb}
end

function tangent_type(::Type{P}) where {N,P<:Tuple{Vararg{Any,N}}}
# As with other types, tangent type of Union is Union of tangent types.
P isa Union && return Union{tangent_type(P.a),tangent_type(P.b)}
Expand All @@ -365,6 +389,13 @@ function tangent_type(::Type{P}) where {N,P<:Tuple{Vararg{Any,N}}}
T_all_notangent = Tuple{Vararg{NoTangent,length(tangent_types)}}
T <: T_all_notangent && return NoTangent

# If exactly one of the field types is a Union, then split.
union_fields = _findall(Base.Fix2(isa, Union), 1, tangent_types)
if length(union_fields) == 1 &&
all(p -> p isa Union || isconcretetype(p), tangent_types)
return split_union_tuple_type(tangent_types)
end

# If it's _possible_ for a subtype of `P` to have tangent type `NoTangent`, then we must
# account for that by returning the union of `NoTangent` and `T`. For example, if
# `P = Tuple{Any, Int}`, then `P2 = Tuple{Int, Int}` is a subtype. Since `P2` has
Expand Down
13 changes: 13 additions & 0 deletions src/test_resources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,24 @@ function typevar_tester()
end

tuple_with_union(x::Bool) = (x ? 5.0 : 5, nothing)
tuple_with_union_2(x::Bool) = (x ? 5.0 : 5, x ? 5 : 5.0)
tuple_with_union_3(x::Bool, y::Bool) = (x ? 5.0 : (y ? 5 : nothing), nothing)

struct NoDefaultCtor{T}
x::T
NoDefaultCtor(x::T) where {T} = new{T}(x)
end

@noinline function __inplace_function!(x::Vector{Float64})
x .*= 2
return nothing
end

function inplace_invoke!(x::Vector{Float64})
__inplace_function!(x)
return nothing
end

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down Expand Up @@ -814,6 +826,7 @@ function generate_test_functions()
(false, :none, nothing, hvcat, (2, 2), 3.0, 2.0, 0.0, 1.0),
(false, :none, nothing, partial_typevar_tester),
(false, :none, nothing, typevar_tester),
(false, :allocs, nothing, inplace_invoke!, randn(1_024)),
]
end

Expand Down
10 changes: 5 additions & 5 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ function test_set_tangent_field!_correctness(t1::T, t2::T) where {T<:MutableTang
end
end

function check_allocs(::Any, f::F, x::Tuple{Vararg{Any,N}}) where {F,N}
function check_allocs(::Any, f::F, x::Vararg{Any,N}) where {F,N}
throw(error("Load AllocCheck.jl to use this functionality."))
end

Expand Down Expand Up @@ -970,10 +970,10 @@ function test_tangent_performance(rng::AbstractRNG, p::P) where {P}
end

function test_allocations(t::T, z::T) where {T}
check_allocs(Shim(), increment!!, (t, t))
check_allocs(Shim(), increment!!, (t, z))
check_allocs(Shim(), increment!!, (z, t))
return check_allocs(Shim(), increment!!, (z, z))
check_allocs(Shim(), increment!!, t, t)
check_allocs(Shim(), increment!!, t, z)
check_allocs(Shim(), increment!!, z, t)
return check_allocs(Shim(), increment!!, z, z)
end

_set_tangent_field!(x, ::Val{i}, v) where {i} = set_tangent_field!(x, i, v)
Expand Down
Loading

2 comments on commit 1a0ed04

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120473

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.53 -m "<description of version>" 1a0ed0478518654a1820526b294c83755b39e33a
git push origin v0.4.53

Please sign in to comment.