diff --git a/Project.toml b/Project.toml index c5619e9b4..86c5d1c78 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.72" +version = "0.4.73" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 7398a1dbd..d8a55b615 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -253,38 +253,25 @@ get_rev_data_id(::ADInfo, ::Any) = nothing """ reverse_data_ref_stmts(info::ADInfo) -Create the statements which initialise the reverse-data `Ref`s. +Create the `:new` statements which initialise the reverse-data `Ref`s. Interpolates the +initial rdata directly into the statement, which is safe because it is always a bits type. """ function reverse_data_ref_stmts(info::ADInfo) + function make_ref_stmt(id, P) + ref_type = Base.RefValue{P <: Type ? NoRData : zero_like_rdata_type(P)} + init_ref_val = P <: Type ? NoRData() : Mooncake.zero_like_rdata_from_type(P) + return (id, new_inst(Expr(:new, ref_type, QuoteNode(init_ref_val)))) + end return vcat( map(collect(info.arg_rdata_ref_ids)) do (k, id) - (id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.arg_types[k])))) + return make_ref_stmt(id, CC.widenconst(info.arg_types[k])) end, map(collect(info.ssa_rdata_ref_ids)) do (k, id) - (id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.ssa_insts[k].type)))) + return make_ref_stmt(id, CC.widenconst(info.ssa_insts[k].type)) end, ) end -""" - __make_ref(p::Type{P}) where {P} - -Helper for [`reverse_data_ref_stmts`](@ref). Constructs a `Ref` whose element type is the -[`zero_like_rdata_type`](@ref) for `P`, and whose element is the zero-like rdata for `P`. -""" -@inline function __make_ref(p::Type{P}) where {P} - _P = @isdefined(P) ? P : _typeof(p) - return Ref{zero_like_rdata_type(_P)}(Mooncake.zero_like_rdata_from_type(_P)) -end - -# This specialised method is necessary to ensure that `__make_ref` works properly for -# `DataType`s with unbound type parameters. See `TestResources.typevar_tester` for an -# example. The above method requires that `P` be a type in which all parameters are fully- -# bound. Strange errors occur if this property does not hold. -@inline __make_ref(::Type{<:Type}) = Ref{NoRData}(NoRData()) - -@inline __make_ref(::Type{Union{}}) = nothing - # Returns the number of arguments that the primal function has. num_args(info::ADInfo) = length(info.arg_types)