Skip to content

Commit

Permalink
Performance Robustness in Reverse Pass (#442)
Browse files Browse the repository at this point in the history
* Fix up zero_rdata_from_type

* Stop generating hundreds of methods

* Call getfield directly instead of getindex

* Manually inline ad stmts for rvs-pass for call

* Refactor PhiNode

* Remove increment_if_ref

* Remove commented-out code

* Improve docstring

* Improve special_functions testset display

* Extend increment

* Remove increment_ref usage

* PiNode

* Remove increment_ref and __pi_rvs

* Add regression test

* Remove __deref_and_zero

* Docstring

* Reformat

* Bump patch version
  • Loading branch information
willtebbutt authored Jan 9, 2025
1 parent b242b3b commit 35b432c
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 96 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.75"
version = "0.4.76"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
11 changes: 4 additions & 7 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,14 +662,11 @@ with.
if P isa DataType
names = fieldnames(P)
types = fieldtypes(P)
wrapped_field_zeros = map(enumerate(tangent_field_types(P))) do (n, tt)
wrapped_field_zeros = map(enumerate(always_initialised(P))) do (n, init)
fzero = :(zero_rdata_from_type($(types[n])))
if tt <: PossiblyUninitTangent
Q = :(rdata_type(tangent_type($(fieldtype(P, n)))))
return :(PossiblyUninitTangent{$Q}($fzero))
else
return fzero
end
init && return fzero
Q = :(rdata_type(tangent_type($(fieldtype(P, n)))))
return :(PossiblyUninitTangent{$Q}($fzero))
end
wrapped_field_zeros_tuple = Expr(:call, :tuple, wrapped_field_zeros...)
wrapped_expr = :(R(NamedTuple{$names}($wrapped_field_zeros_tuple)))
Expand Down
167 changes: 109 additions & 58 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
end
if is_active(stmt.val)
rdata_id = get_rev_data_id(info, stmt.val)
rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2)))
rvs = increment_ref_stmts(rdata_id, Argument(2))
assert_id = ID()
val = __inc(stmt.val)
fwds = [
Expand Down Expand Up @@ -479,7 +479,13 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
val_rdata_ref_id = get_rev_data_id(info, stmt.val)
output_rdata_ref_id = get_rev_data_id(info, line)
fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ)))
rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id)

# Get the rdata from the output_rdata_ref, and set its new value to zero, and
# increment the output ref.
output_rdata_id = ID()
deref_stmts = deref_and_zero_stmts(P, output_rdata_ref_id, output_rdata_id)
inc_exprs = increment_ref_stmts(val_rdata_ref_id, output_rdata_id)
rvs = vcat(deref_stmts, inc_exprs)
else
# If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to
# do on the reverse-pass.
Expand All @@ -494,11 +500,6 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
return ad_stmt_info(line, nothing, fwds, rvs)
end

@inline function __pi_rvs!(::Type{P}, val_rdata_ref::Ref, output_rdata_ref::Ref) where {P}
increment_ref!(val_rdata_ref, __deref_and_zero(P, output_rdata_ref))
return nothing
end

# Constant GlobalRefs are handled. See const_codual. Non-constant GlobalRefs are handled by
# assuming that they are constant, and creating a CoDual with the value. We then check at
# run-time that the value has not changed.
Expand Down Expand Up @@ -723,17 +724,53 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
rvs_pass = if T_pb!! <: NoPullback
nothing
else
Expr(
:call,
__run_rvs_pass!,
get_primal_type(info, line),
sig,
pb,
get_rev_data_id(info, line),
map(Base.Fix1(get_rev_data_id, info), args)...,
# Get the rdata which we pass into the pullback from its rdata ref.
rdata_ref_id = get_rev_data_id(info, line)
rdata_output_id = ID()
rdata_output_expr = Expr(:call, getfield, rdata_ref_id, QuoteNode(:x))
rdata_output = (rdata_output_id, new_inst(rdata_output_expr))

# Zero out the value stored in this rdata ref now that we have its current
# value. The new value is rdata, so must be an instance of a bits type, so is
# safe to interpolate straight into instruction.
zero_val = zero_like_rdata_from_type(get_primal_type(info, line))
zero_rdata_expr = Expr(:call, setfield!, rdata_ref_id, QuoteNode(:x), zero_val)
zero_rdata_ref = (ID(), new_inst(zero_rdata_expr))

# Run the pullback. The result is a tuple comprising `length(args)` elements.
call_pullback_id = ID()
call_pullback = (call_pullback_id, new_inst(Expr(:call, pb, rdata_output_id)))

# For each element of the tuple returned by call_pullback, if the corresponding
# value in the primal IR is an Argument / SSA (if `get_rev_data_id` does not
# return nothing), increment the value in its rdata ref. This is equivalent to
# rdata_ref[] = increment!!(rdata_ref[], rdata_inc_resulting_from_pullback),
# but written out manually to ensure nothing fails to inline.
# If the corresponding value in the primal IR is not an Argument / SSA (e.g. it
# is a literal, a `QuoteNode`, or a `GlobalRef`), do nothing as we do not track
# gradients w.r.t. it.
tmp = map(enumerate(args)) do (n, arg)
rev_data_id = get_rev_data_id(info, arg)

# If arg is not an SSA / Argument, then no rdata ref to inc.
rev_data_id === nothing && return nothing

# Extract rdata from result of calling pullback.
rdata_inc_id = ID()
rdata_inc_expr = Expr(:call, getfield, call_pullback_id, n)
rdata_inc = (rdata_inc_id, new_inst(rdata_inc_expr))

# Construct statments to increment ref.
return vcat(rdata_inc, increment_ref_stmts(rev_data_id, rdata_inc_id))
end

# Concatenate all statements, and return them.
vcat(
IDInstPair[rdata_output, zero_rdata_ref, call_pullback],
reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[]),
)
end
return ad_stmt_info(line, comms_id, fwds, new_inst(rvs_pass))
return ad_stmt_info(line, comms_id, fwds, rvs_pass)

elseif Meta.isexpr(stmt, :boundscheck)
# For some reason the compiler cannot handle boundscheck statements when we run it
Expand Down Expand Up @@ -782,6 +819,29 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
end
end

"""
increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair}
Equivalent to `ref[] = increment!!(ref[], inc_data)`, where `ref` and `inc_data` are the
values associated to `ref_id` and `inc_data` respectively.
"""
function increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair}

# Get the value stored in the `Base.RefValue`.
ref_val_id = ID()
ref_val = (ref_val_id, new_inst(Expr(:call, getfield, ref_id, QuoteNode(:x))))

# Increment the value by inc_data.
new_val_id = ID()
new_val = (new_val_id, new_inst(Expr(:call, increment!!, ref_val_id, inc_data)))

# Update the value stored in the rdata reference.
set_ref_expr = Expr(:call, setfield!, ref_id, QuoteNode(:x), new_val_id)
set_ref = (ID(), new_inst(set_ref_expr))

return IDInstPair[ref_val, new_val, set_ref]
end

is_active(::Union{Argument,ID}) = true
is_active(::Any) = false

Expand All @@ -807,33 +867,6 @@ end
__get_primal(x::CoDual) = primal(x)
__get_primal(x) = x

"""
__run_rvs_pass!(
P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}
Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`.
"""
@inline function __run_rvs_pass!(
P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}
tuple_map(increment_if_ref!, arg_rev_data_refs, pb!!(ret_rev_data_ref[]))
set_ret_ref_to_zero!!(P, ret_rev_data_ref)
return nothing
end

@inline increment_if_ref!(ref::Ref, rvs_data) = increment_ref!(ref, rvs_data)
@inline increment_if_ref!(::Ref, ::ZeroRData) = nothing
@inline increment_if_ref!(::Nothing, ::Any) = nothing

@inline increment_ref!(x::Ref, t) = setindex!(x, increment!!(x[], t))
@inline increment_ref!(::Base.RefValue{NoRData}, t) = nothing

@inline function set_ret_ref_to_zero!!(::Type{P}, r::Ref{R}) where {P,R}
return r[] = zero_like_rdata_from_type(P)
end
@inline set_ret_ref_to_zero!!(::Type{P}, r::Base.RefValue{NoRData}) where {P} = nothing

const RuleMC{A,R} = MistyClosure{OpaqueClosure{A,R}}

#
Expand Down Expand Up @@ -1437,7 +1470,7 @@ function pullback_ir(

# De-reference the nth rdata.
rdata_id = ID()
rdata = new_inst(Expr(:call, getindex, arg_rdata_ref_ids[n]))
rdata = new_inst(Expr(:call, getfield, arg_rdata_ref_ids[n], QuoteNode(:x)))

# Get the nth lazy zero rdata.
lazy_zero_rdata_id = ID()
Expand Down Expand Up @@ -1511,11 +1544,12 @@ function conclude_rvs_block(

# Create statements which extract + zero the rdata refs associated to them.
rdata_ids = map(_ -> ID(), phi_ids)
deref_stmts = map(phi_ids, rdata_ids) do phi_id, deref_id
tmp = map(phi_ids, rdata_ids) do phi_id, deref_id
P = get_primal_type(info, phi_id)
r = get_rev_data_id(info, phi_id)
return (deref_id, new_inst(Expr(:call, __deref_and_zero, P, r)))
return deref_and_zero_stmts(P, r, deref_id)
end
deref_stmts = reduce(vcat, tmp; init=IDInstPair[])

# For each predecessor, create a `BBlock` which processes its corresponding edge in
# each of the `PhiNode`s.
Expand All @@ -1540,14 +1574,19 @@ function __get_value(edge::ID, x::IDPhiNode)
end

"""
__deref_and_zero(::Type{P}, x::Ref) where {P}
deref_and_zero_stmts(P, ref_id, val_id)
Helper, used in conclude_rvs_block.
Equivalent to something like
```julia
val = ref[]
ref[] = zero_rdata_from_type(P)
```
"""
@inline function __deref_and_zero(::Type{P}, x::Ref) where {P}
t = x[]
x[] = Mooncake.zero_like_rdata_from_type(P)
return t
function deref_and_zero_stmts(P, ref_id, val_id)
val = (val_id, new_inst(Expr(:call, getfield, ref_id, QuoteNode(:x))))
r = Mooncake.zero_like_rdata_from_type(P)
set_ref = (ID(), new_inst(Expr(:call, setfield!, ref_id, QuoteNode(:x), r)))
return IDInstPair[val, set_ref]
end

"""
Expand All @@ -1562,10 +1601,14 @@ of some block:
%6 = φ (#2 => _1, #3 => %5)
%7 = φ (#2 => 5., #3 => _2)
```
Let the tangent refs associated to `%6`, `%7`, and `_1`` be denoted `t%6`, `t%7`, and `t_1`
resp., and let `pred_id` be `#2`, then this function will produce a basic block of the form
Let the rdata refs associated to `%6`, `%7`, and `_1`` be denoted `r%6`, `r%7`, and `r_1`
resp., and let `pred_id` be `#2`, and `increment_ref!` be the following function,
```julia
increment_ref!(t_1, t%6)
increment_ref!(ref, x) = ref[] = increment!!(ref[], x)
```
then this `rvs_phi_block` will produce a basic block of the form
```julia
increment_ref!(r_1, r%6)
nothing
goto #2
```
Expand All @@ -1577,15 +1620,23 @@ on.
The same ideas apply if `pred_id` were `#3`. The block would end with `#3`, and there would
be two `increment_ref!` calls because both `%5` and `_2` are not constants.
In practice, code which is equivalent to `increment_ref!` is created directly, rather than
inserting a call to a generic Julia function. This is because we need to be certain that
the getfield and setfield! calls applied to any references are visible to the SROA
optimisation pass. If we insert a call to a function like `increment_ref!`, it might not be
inlined away, making such references opaque.
"""
function rvs_phi_block(
pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo
)
@assert length(rdata_ids) == length(values)
inc_stmts = map(rdata_ids, values) do id, val
stmt = Expr(:call, increment_if_ref!, get_rev_data_id(info, val), id)
return (ID(), new_inst(stmt))
tmp = map(rdata_ids, values) do id, val
rev_data_id = get_rev_data_id(info, val)
rev_data_id === nothing && return nothing
return increment_ref_stmts(rev_data_id, id)
end
inc_stmts = reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[])
goto_stmt = (ID(), new_inst(IDGotoNode(pred_id)))
return BBlock(ID(), vcat(inc_stmts, goto_stmt))
end
Expand Down
2 changes: 2 additions & 0 deletions src/interpreter/zero_like_rdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ error -- please open an issue in such a situation.
struct ZeroRData end

@inline increment!!(::ZeroRData, r::R) where {R} = r
@inline increment!!(r::R, ::ZeroRData) where {R} = r
@inline increment!!(::ZeroRData, ::ZeroRData) = ZeroRData()

"""
zero_like_rdata_type(::Type{P}) where {P}
Expand Down
38 changes: 12 additions & 26 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,29 @@ the same length, while `map` will just produce a new tuple whose length is equal
shorter of `x` and `y`.
"""
@inline @generated function tuple_map(f::F, x::Tuple) where {F}
return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), eachindex(x.parameters))...)
return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:fieldcount(x))...)
end

@inline @generated function tuple_map(f::F, x::Tuple, y::Tuple) where {F}
if length(x.parameters) != length(y.parameters)
return :(throw(ArgumentError("length(x) != length(y)")))
else
stmts = map(n -> :(f(getfield(x, $n), getfield(y, $n))), eachindex(x.parameters))
stmts = map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:fieldcount(x))
return Expr(:call, :tuple, stmts...)
end
end

for N in 1:128
@eval @inline function tuple_map(f::F, x::Tuple{Vararg{Any,$N}}) where {F}
return $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...))
end
@eval @inline function tuple_map(
f::F, x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}}
) where {F,names}
return NamedTuple{names}(
$(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...))
)
end
@eval @inline function tuple_map(f, x::Tuple{Vararg{Any,$N}}, y::Tuple{Vararg{Any,$N}})
return $(Expr(
:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...
))
end
@eval @inline function tuple_map(
f::F,
x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}},
y::NamedTuple{names,<:Tuple{Vararg{Any,$N}}},
) where {F,names}
return NamedTuple{names}(
$(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...))
)
@generated function tuple_map(f, x::NamedTuple{names}) where {names}
getfield_exprs = map(n -> :(f(getfield(x, $n))), 1:fieldcount(x))
return :(NamedTuple{names}($(Expr(:call, :tuple, getfield_exprs...))))
end

@generated function tuple_map(f, x::NamedTuple{names}, y::NamedTuple{names}) where {names}
if fieldcount(x) != fieldcount(y)
return :(throw(ArgumentError("length(x) != length(y)")))
end
getfield_exprs = map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:fieldcount(x))
return :(NamedTuple{names}($(Expr(:call, :tuple, getfield_exprs...))))
end

for N in 1:256
Expand Down
4 changes: 2 additions & 2 deletions test/ext/special_functions/special_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Mooncake.TestUtils: test_rule

# Rules in this file are only lightly tester, because they are all just @from_rrule rules.
@testset "special_functions" begin
@testset for (perf_flag, f, x...) in vcat(
@testset "$perf_flag, $(typeof((f, x...)))" for (perf_flag, f, x...) in vcat(
map([Float64, Float32]) do P
return Any[
(:stability, airyai, P(0.1)),
Expand Down Expand Up @@ -51,7 +51,7 @@ using Mooncake.TestUtils: test_rule
)
test_rule(StableRNG(123456), f, x...; perf_flag)
end
@testset for (perf_flag, f, x...) in vcat(
@testset "$perf_flag, $(typeof((f, x...)))" for (perf_flag, f, x...) in vcat(
map([Float64, Float32]) do P
return Any[
(:none, logerf, P(0.3), P(0.5)), # first branch
Expand Down
9 changes: 7 additions & 2 deletions test/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct A
end
f(a, x) = dot(a.data, x)

unstable_tester(x::Ref{Any}) = sin(x[])

end

@testset "s2s_reverse_mode_ad" begin
Expand Down Expand Up @@ -106,8 +108,6 @@ end
@test length(stmts.fwds) == 2
@test stmts.fwds[1][2].stmt isa Expr
@test stmts.fwds[2][2].stmt isa ReturnNode
@test Meta.isexpr(only(stmts.rvs)[2].stmt, :call)
@test only(stmts.rvs)[2].stmt.args[1] == Mooncake.increment_ref!
end
@testset "literal" begin
stmt_info = make_ad_stmts!(ReturnNode(5.0), line, info)
Expand Down Expand Up @@ -344,4 +344,9 @@ end
f() = Float64
@test length(build_rrule(Tuple{typeof(f)}).fwds_oc.oc.captures) == 2
end
@testset "all `Ref`s for rdata are eliminated in type unstable code" begin
ir = Mooncake.rvs_ir(Tuple{typeof(S2SGlobals.unstable_tester),Ref{Any}})
stmts = Mooncake.stmt(ir.stmts)
@test !any(x -> Meta.isexpr(x, :new) && x.args[1] <: Base.RefValue, stmts)
end
end

2 comments on commit 35b432c

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/122700

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.76 -m "<description of version>" 35b432c47b454d6af925497337c32bc0f3958df0
git push origin v0.4.76

Please sign in to comment.