diff --git a/Project.toml b/Project.toml index 8498a4428..da173b6b5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.68" +version = "0.4.69" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt.jl index b08d97807..56ab065bc 100644 --- a/ext/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt.jl @@ -26,7 +26,7 @@ import Mooncake.TestUtils: populate_address_map!, AddressMap, __increment_should # Tell Mooncake.jl how to handle CuArrays. -tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P +Mooncake.@tt_effects tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x) function randn_tangent(rng::AbstractRNG, x::CuArray{Float32}) return cu(randn(rng, Float32, size(x)...)) diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index 554ee2d82..02b69ca6c 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -166,11 +166,13 @@ end T == NoTangent && return NoFData # This method can only handle struct types. Tell user to implement their own method. - isprimitivetype(T) && - throw(error("$T is a primitive type. Implement a method of `fdata_type` for it.")) + if isprimitivetype(T) + msg = "$T is a primitive type. Implement a method of `fdata_type` for it." + return :(error($msg)) + end # If the type is a Union, then take the union type of its arguments. - T isa Union && return Union{fdata_type(T.a),fdata_type(T.b)} + T isa Union && return :(Union{fdata_type($(T.a)),fdata_type($(T.b))}) # If `P` is a mutable type, then its forwards data is its tangent. ismutabletype(T) && return T @@ -179,33 +181,37 @@ end # The same goes for if the type has any undetermined type parameters. (isabstracttype(T) || !isconcretetype(T)) && return Any + # We should now have a `Tangent`. If not, we do not know what to do, so error. + T <: Tangent || return :(error("Unhandled type $T")) + # If `P` is an immutable type, then some of its fields may not need to be propagated # on the forwards-pass. - if T <: Tangent - Tfields = fields_type(T) - fwds_data_field_types = map(1:fieldcount(Tfields)) do n - return fdata_type(fieldtype(Tfields, n)) - end - all(==(NoFData), fwds_data_field_types) && return NoFData - return FData{NamedTuple{fieldnames(Tfields),Tuple{fwds_data_field_types...}}} + field_names = fieldnames(fields_type(T)) + Tfields = fieldtypes(fields_type(T)) + fdata_type_exprs = map(n -> :(fdata_type($(Tfields[n]))), 1:length(Tfields)) + return quote + fwds_data_field_types = $(Expr(:call, :tuple, fdata_type_exprs...)) + stable_all(tuple_map(==(NoFData), fwds_data_field_types)) && return NoFData + return FData{NamedTuple{$field_names,Tuple{fwds_data_field_types...}}} end - - return :(error("Unhandled type $T")) end fdata_type(::Type{T}) where {T<:Ptr} = T @generated function fdata_type(::Type{P}) where {P<:Tuple} - isa(P, Union) && return Union{fdata_type(P.a),fdata_type(P.b)} + isa(P, Union) && return :(Union{fdata_type($(P.a)),fdata_type($(P.b))}) isempty(P.parameters) && return NoFData isa(last(P.parameters), Core.TypeofVararg) && return Any nofdata_tt = Tuple{Vararg{NoFData,length(P.parameters)}} - fdata_tt = Tuple{map(fdata_type, fieldtypes(P))...} - fdata_tt <: nofdata_tt && return NoFData - return nofdata_tt <: fdata_tt ? Union{NoFData,fdata_tt} : fdata_tt + fdata_type_exprs = map(_P -> Expr(:call, :fdata_type, _P), P.parameters) + return quote + fdata_tt = $(Expr(:curly, Tuple, fdata_type_exprs...)) + fdata_tt <: $nofdata_tt && return NoFData + return $nofdata_tt <: fdata_tt ? Union{NoFData,fdata_tt} : fdata_tt + end end -@generated function fdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple} +function fdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple} if fdata_type(T) == NoFData return NoFData elseif isconcretetype(fdata_type(T)) @@ -224,7 +230,7 @@ Returns the type of to the nth field of the fdata type associated to `P`. Will b function fdata_field_type(::Type{P}, n::Int) where {P} Tf = tangent_type(fieldtype(P, n)) f = ismutabletype(P) ? Tf : fdata_type(Tf) - return is_always_initialised(P, n) ? f : _wrap_type(f) + return is_always_initialised(P, n) ? f : PossiblyUninitTangent{f} end """ @@ -232,20 +238,20 @@ end Extract the forwards data from tangent `t`. """ -@generated function fdata(t::T) where {T} +function fdata(t::T) where {T} # Ask for the forwards-data type. Useful catch-all error checking for unexpected types. F = fdata_type(T) # Catch-all for anything with no forwards-data. - F == NoFData && return :(NoFData()) + F == NoFData && return NoFData() # Catch-all for anything where we return the whole object (mutable structs, arrays...). - F == T && return :(t) + F == T && return t # T must be a `Tangent` by now. If it's not, something has gone wrong. - !(T <: Tangent) && return :(error("Unhandled type $T")) - return :($F(fdata(t.fields))) + T <: Tangent || error("Unhandled type $T") + return F(fdata(t.fields)) end function fdata(t::T) where {T<:PossiblyUninitTangent} @@ -415,11 +421,13 @@ end T == NoTangent && return NoRData # This method can only handle struct types. Tell user to implement their own method. - isprimitivetype(T) && - throw(error("$T is a primitive type. Implement a method of `rdata_type` for it.")) + if isprimitivetype(T) + msg = "$T is a primitive type. Implement a method of `rdata_type` for it." + return :(error(msg)) + end # If the type is a Union, then take the union type of its arguments. - T isa Union && return Union{rdata_type(T.a),rdata_type(T.b)} + T isa Union && return :(Union{rdata_type($(T.a)),rdata_type($(T.b))}) # If `P` is a mutable type, then all tangent info is propagated on the forwards-pass. ismutabletype(T) && return NoRData @@ -428,26 +436,31 @@ end # The same goes for if the type has any undetermined type parameters. (isabstracttype(T) || !isconcretetype(T)) && return Any - # If `T` is an immutable type, then some of its fields may not have been propagated on - # the forwards-pass. - if T <: Tangent - Tfs = fields_type(T) - rvs_types = map(n -> rdata_type(fieldtype(Tfs, n)), 1:fieldcount(Tfs)) - all(==(NoRData), rvs_types) && return NoRData - return RData{NamedTuple{fieldnames(Tfs),Tuple{rvs_types...}}} + # If `T` is an immutable type, then some of its fields may not need to be propagated + # on the forwards-pass. + field_names = fieldnames(fields_type(T)) + Tfields = fieldtypes(fields_type(T)) + rdata_type_exprs = map(n -> :(rdata_type($(Tfields[n]))), 1:length(Tfields)) + return quote + rvs_data_field_types = $(Expr(:call, :tuple, rdata_type_exprs...)) + stable_all(tuple_map(==(NoRData), rvs_data_field_types)) && return NoRData + return RData{NamedTuple{$field_names,Tuple{rvs_data_field_types...}}} end end rdata_type(::Type{<:Ptr}) = NoRData @generated function rdata_type(::Type{P}) where {P<:Tuple} - isa(P, Union) && return Union{rdata_type(P.a),rdata_type(P.b)} + isa(P, Union) && return :(Union{rdata_type($(P.a)),rdata_type($(P.b))}) isempty(P.parameters) && return NoRData isa(last(P.parameters), Core.TypeofVararg) && return Any nordata_tt = Tuple{Vararg{NoRData,length(P.parameters)}} - rdata_tt = Tuple{map(rdata_type, fieldtypes(P))...} - rdata_tt <: nordata_tt && return NoRData - return nordata_tt <: rdata_tt ? Union{NoRData,rdata_tt} : rdata_tt + rdata_type_exprs = map(_P -> Expr(:call, :rdata_type, _P), P.parameters) + return quote + rdata_tt = $(Expr(:curly, Tuple, rdata_type_exprs...)) + rdata_tt <: $nordata_tt && return NoRData + return $nordata_tt <: rdata_tt ? Union{NoRData,rdata_tt} : rdata_tt + end end function rdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple} @@ -468,7 +481,7 @@ Returns the type of to the nth field of the rdata type associated to `P`. Will b """ function rdata_field_type(::Type{P}, n::Int) where {P} r = rdata_type(tangent_type(fieldtype(P, n))) - return is_always_initialised(P, n) ? r : _wrap_type(r) + return is_always_initialised(P, n) ? r : PossiblyUninitTangent{r} end """ @@ -480,20 +493,20 @@ Extract the reverse data from tangent `t`. See extended help section of [fdata_type](@ref). """ -@generated function rdata(t::T) where {T} +function rdata(t::T) where {T} # Ask for the reverse-data type. Useful catch-all error checking for unexpected types. R = rdata_type(T) # Catch-all for anything with no reverse-data. - R == NoRData && return :(NoRData()) + R == NoRData && return NoRData() # Catch-all for anything where we return the whole object (Float64, isbits structs, ...) - R == T && return :(t) + R == T && return t # T must be a `Tangent` by now. If it's not, something has gone wrong. - !(T <: Tangent) && return :(error("Unhandled type $T")) - return :($(rdata_type(T))(rdata(t.fields))) + T <: Tangent || error("Unhandled type $T") + return R(rdata(t.fields)) end function rdata(t::T) where {T<:PossiblyUninitTangent} @@ -604,41 +617,48 @@ constitute a correctness problem, but can be detrimental to performance, so shou with. """ @generated function zero_rdata_from_type(::Type{P}) where {P} - R = rdata_type(tangent_type(P)) - - # If we know we can't produce a tangent, say so. - can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType() - - # Simple case. - R == NoRData && return NoRData() - # If `P` is a struct type, attempt to derive the zero rdata for it. We cannot derive - # the zero rdata if it is not possible to derive the zero rdata for any of its fields. - if isstructtype(P) + # Prepare expressions for manually-unrolled loop to construct zero rdata elements. + if P isa DataType names = fieldnames(P) types = fieldtypes(P) - wrapped_field_zeros = tuple_map(ntuple(identity, length(names))) do n + wrapped_field_zeros = map(enumerate(tangent_field_types(P))) do (n, tt) fzero = :(zero_rdata_from_type($(types[n]))) - if tangent_field_type(P, n) <: PossiblyUninitTangent - Q = rdata_type(tangent_type(fieldtype(P, n))) - return :(_wrap_field($Q, $fzero)) + if tt <: PossiblyUninitTangent + Q = :(rdata_type(tangent_type($(fieldtype(P, n))))) + return :(PossiblyUninitTangent{$Q}($fzero)) else return fzero end end wrapped_field_zeros_tuple = Expr(:call, :tuple, wrapped_field_zeros...) - return :($R(NamedTuple{$names}($wrapped_field_zeros_tuple))) + wrapped_expr = :(R(NamedTuple{$names}($wrapped_field_zeros_tuple))) + else + wrapped_expr = nothing end - # Fallback -- we've not been able to figure out how to produce an instance of zero rdata - # so report that it cannot be done. - return throw(error("Unhandled type $P")) + return quote + + # If we know we can't produce a tangent, say so. + can_produce_zero_rdata_from_type($P) || return CannotProduceZeroRDataFromType() + + # Simple case. + R = rdata_type(tangent_type($P)) + R == NoRData && return NoRData() + + $(isstructtype(P)) || error("Unhandled type $P") + return $wrapped_expr + end end @generated function zero_rdata_from_type(::Type{P}) where {P<:Tuple} - can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType() - rdata_type(tangent_type(P)) == NoRData && return NoRData() - return tuple_map(zero_rdata_from_type, fieldtypes(P)) + has_fields = P isa DataType && Base.datatype_fieldcount(P) !== nothing + zero_exprs = has_fields ? map(_P -> :(zero_rdata_from_type($_P)), fieldtypes(P)) : [] + return quote + can_produce_zero_rdata_from_type($P) || return CannotProduceZeroRDataFromType() + rdata_type(tangent_type($P)) == NoRData && return NoRData() + return $(Expr(:call, :tuple, zero_exprs...)) + end end function zero_rdata_from_type(::Type{P}) where {P<:NamedTuple} @@ -785,15 +805,14 @@ tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Array} = F # Tuples @generated function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple} - return Tuple{tuple_map(tangent_type, Tuple(F.parameters), Tuple(R.parameters))...} + tt_exprs = map((f, r) -> :(tangent_type($f, $r)), fieldtypes(F), fieldtypes(R)) + return Expr(:curly, :Tuple, tt_exprs...) end function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:Tuple} - F_tuple = Tuple{tuple_fill(NoFData, Val(length(R.parameters)))...} - return tangent_type(F_tuple, R) + return tangent_type(Tuple{tuple_fill(NoFData, Val(length(R.parameters)))...}, R) end function tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Tuple} - R_tuple = Tuple{tuple_fill(NoRData, Val(length(F.parameters)))...} - return tangent_type(F, R_tuple) + return tangent_type(F, Tuple{tuple_fill(NoRData, Val(length(F.parameters)))...}) end # NamedTuples @@ -904,10 +923,7 @@ Equivalent to `tangent(fdata, rdata(zero_tangent(primal)))`. zero_tangent(p, ::NoFData) = zero_tangent(p) function zero_tangent(p::P, f::F) where {P,F} - T = tangent_type(P) - T == F && return f - r = rdata(zero_tangent(p)) - return tangent(f, r) + return tangent_type(P) == F ? f : tangent(f, rdata(zero_tangent(p))) end zero_tangent(p::Tuple, f::Union{Tuple,NamedTuple}) = tuple_map(zero_tangent, p, f) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 8a1c32ba7..bb5fedf2b 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -32,6 +32,7 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult} code_cache::MooncakeCache oc_cache::Dict{ClosureCacheKey,Any} + inline_primitives::Bool function MooncakeInterpreter( ::Type{C}; meta=nothing, @@ -41,8 +42,18 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::MooncakeCache=MooncakeCache(), oc_cache::Dict{ClosureCacheKey,Any}=Dict{ClosureCacheKey,Any}(), + inline_primitives::Bool=false, ) where {C} - return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) + return new{C}( + meta, + world, + inf_params, + opt_params, + inf_cache, + code_cache, + oc_cache, + inline_primitives, + ) end end @@ -61,28 +72,6 @@ MooncakeInterpreter() = MooncakeInterpreter(DefaultCtx) context_type(::MooncakeInterpreter{C}) where {C} = C -""" - const GLOBAL_INTERPRETER - -Globally cached interpreter. Should only be accessed via `get_interpreter`. -""" -const GLOBAL_INTERPRETER = Ref(MooncakeInterpreter()) - -""" - get_interpreter() - -Returns a `MooncakeInterpreter` appropriate for the current world age. Will use a cached -interpreter if one already exists for the current world age, otherwise creates a new one. - -This should be prefered over constructing a `MooncakeInterpreter` directly. -""" -function get_interpreter() - if GLOBAL_INTERPRETER[].world != Base.get_world_counter() - GLOBAL_INTERPRETER[] = MooncakeInterpreter() - end - return GLOBAL_INTERPRETER[] -end - CC.InferenceParams(interp::MooncakeInterpreter) = interp.inf_params CC.OptimizationParams(interp::MooncakeInterpreter) = interp.opt_params CC.get_inference_cache(interp::MooncakeInterpreter) = interp.inf_cache @@ -134,7 +123,9 @@ function Core.Compiler.abstract_call_gf_by_type( sv::CC.AbsIntState, max_methods::Int, ) where {C} - ret = @invoke CC.abstract_call_gf_by_type( + + # invoke the default abstract call to get the default CC.CallMeta. + cm = @invoke CC.abstract_call_gf_by_type( interp::CC.AbstractInterpreter, f::Any, arginfo::CC.ArgInfo, @@ -143,14 +134,19 @@ function Core.Compiler.abstract_call_gf_by_type( sv::CC.AbsIntState, max_methods::Int, ) - callinfo = ret.info - if Mooncake.is_primitive(C, atype) + + # Check to see whether the call in question is a Mooncake primitive. If it is, set its + # call info such that in the `CC.inlining_policy` it is not inlined away. + callinfo = cm.info + if !interp.inline_primitives && Mooncake.is_primitive(C, atype) callinfo = NoInlineCallInfo(callinfo, atype) end + + # Construct a CallMeta correctly depending on the version of Julia. @static if VERSION ≥ v"1.11-" - return CC.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) + return CC.CallMeta(cm.rt, cm.exct, cm.effects, callinfo) else - return CC.CallMeta(ret.rt, ret.effects, callinfo) + return CC.CallMeta(cm.rt, cm.effects, callinfo) end end @@ -194,3 +190,42 @@ else # 1.11 and up. ) end end + +""" + const GLOBAL_INTERPRETER + +Globally cached interpreter. Should only be accessed via `get_interpreter`. +""" +const GLOBAL_INTERPRETER = Ref(MooncakeInterpreter()) + +""" + const GLOBAL_INLINING_INTERPRETER + +Globally cached interpreter which inline away AD primitives. +""" +const GLOBAL_INLINING_INTERPRETER = Ref( + MooncakeInterpreter(DefaultCtx; inline_primitives=true) +) + +""" + get_interpreter() + +Returns a `MooncakeInterpreter` appropriate for the current world age. Will use a cached +interpreter if one already exists for the current world age, otherwise creates a new one. + +This should be prefered over constructing a `MooncakeInterpreter` directly. +""" +function get_interpreter(; inline_primitives=false) + if inline_primitives + if GLOBAL_INLINING_INTERPRETER[].world != Base.get_world_counter() + interp = MooncakeInterpreter(DefaultCtx; inline_primitives) + GLOBAL_INLINING_INTERPRETER[] = interp + end + return GLOBAL_INLINING_INTERPRETER[] + else + if GLOBAL_INTERPRETER[].world != Base.get_world_counter() + GLOBAL_INTERPRETER[] = MooncakeInterpreter() + end + return GLOBAL_INTERPRETER[] + end +end diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 48334fbdd..96102931c 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -173,7 +173,7 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) CC.verify_ir(ir) ir = __strip_coverage!(ir) ir = CC.compact!(ir) - local_interp = CC.NativeInterpreter() + local_interp = get_interpreter(; inline_primitives=true) mi = __get_toplevel_mi_from_ir(ir, @__MODULE__) ir = __infer_ir!(ir, local_interp, mi) if show_ir diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 4e13ea7f2..1ab4e71b9 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -272,10 +272,9 @@ end Helper for [`reverse_data_ref_stmts`](@ref). Constructs a `Ref` whose element type is the [`zero_like_rdata_type`](@ref) for `P`, and whose element is the zero-like rdata for `P`. """ -@inline @generated function __make_ref(p::Type{P}) where {P} +@inline function __make_ref(p::Type{P}) where {P} _P = @isdefined(P) ? P : _typeof(p) - R = zero_like_rdata_type(_P) - return :(Ref{$R}(Mooncake.zero_like_rdata_from_type($_P))) + return Ref{zero_like_rdata_type(_P)}(Mooncake.zero_like_rdata_from_type(_P)) end # This specialised method is necessary to ensure that `__make_ref` works properly for @@ -1341,11 +1340,12 @@ straightforward to figure out much time is spent pushing to the block stack when """ @inline __push_blk_stack!(block_stack::BlockStack, id::Int32) = push!(block_stack, id) -@inline function __assemble_lazy_zero_rdata( +__lazy_zero_rdata_primal(T, x) = lazy_zero_rdata(T, primal(x)) + +@inline @generated function __assemble_lazy_zero_rdata( r::Ref{T}, args::Vararg{CoDual,N} ) where {T<:Tuple,N} - r[] = map((T, x) -> lazy_zero_rdata(T, primal(x)), fieldtypes(T), args) - return nothing + return :(r[] = tuple_map(__lazy_zero_rdata_primal, $(fieldtypes(T)), args)) end """ diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 6cfa99751..7acbb59ab 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -40,6 +40,8 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) vcat, map([Float64, Float32]) do P return Any[ + (false, :stability_and_allocs, nothing, cosh, P(0.3)), + (false, :stability_and_allocs, nothing, sinh, P(0.3)), (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)), (false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)), (false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)), diff --git a/src/rrules/iddict.jl b/src/rrules/iddict.jl index 40eaaeb64..d8a947462 100644 --- a/src/rrules/iddict.jl +++ b/src/rrules/iddict.jl @@ -1,6 +1,6 @@ # We're going to use `IdDict`s to represent tangents for `IdDict`s. -tangent_type(::Type{<:IdDict{K,V}}) where {K,V} = IdDict{K,tangent_type(V)} +@tt_effects tangent_type(::Type{<:IdDict{K,V}}) where {K,V} = IdDict{K,tangent_type(V)} function randn_tangent(rng::AbstractRNG, d::IdDict{K,V}) where {K,V} return IdDict{K,tangent_type(V)}([k => randn_tangent(rng, v) for (k, v) in d]) end diff --git a/src/rrules/memory.jl b/src/rrules/memory.jl index 80ec3f1fc..dc439bd44 100644 --- a/src/rrules/memory.jl +++ b/src/rrules/memory.jl @@ -14,7 +14,7 @@ const Maybe{T} = Union{Nothing,T} -tangent_type(::Type{<:Memory{P}}) where {P} = Memory{tangent_type(P)} +@tt_effects tangent_type(::Type{<:Memory{P}}) where {P} = Memory{tangent_type(P)} function zero_tangent_internal(x::Memory{P}, stackdict::Maybe{IdDict}) where {P} T = tangent_type(typeof(x)) @@ -241,7 +241,7 @@ end # Tangent Interface Implementation -tangent_type(::Type{<:MemoryRef{P}}) where {P} = MemoryRef{tangent_type(P)} +@tt_effects tangent_type(::Type{<:MemoryRef{P}}) where {P} = MemoryRef{tangent_type(P)} #= Given a new chunk of memory `m`, construct a `MemoryRef` which points to the same relative diff --git a/src/rrules/twice_precision.jl b/src/rrules/twice_precision.jl index 521ccaeeb..199caa94b 100644 --- a/src/rrules/twice_precision.jl +++ b/src/rrules/twice_precision.jl @@ -12,7 +12,7 @@ const TwicePrecisionFloat{P<:IEEEFloat} = TwicePrecision{P} const TWP{P} = TwicePrecisionFloat{P} -tangent_type(P::Type{<:TWP}) = P +@tt_effects tangent_type(P::Type{<:TWP}) = P zero_tangent_internal(::TWP{F}, ::StackDict) where {F} = TWP{F}(zero(F), zero(F)) diff --git a/src/tangents.jl b/src/tangents.jl index f4e3b4a89..de37106b9 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -28,27 +28,9 @@ _copy(x::P) where {P<:PossiblyUninitTangent} = is_init(x) ? P(_copy(x.tangent)) @inline is_init(t::PossiblyUninitTangent) = isdefined(t, :tangent) is_init(t) = true -function val(x::PossiblyUninitTangent{T}) where {T} - if is_init(x) - return x.tangent - else - throw(error("Uninitialised")) - end -end +val(x::PossiblyUninitTangent) = is_init(x) ? x.tangent : error("Uninitialised") val(x) = x -function Base.:(==)(t::PossiblyUninitTangent{T}, s::PossiblyUninitTangent{T}) where {T} - is_init(t) && is_init(s) && return val(t) == val(s) - is_init(t) && !is_init(s) && return false - !is_init(t) && is_init(s) && return false - return true -end - -_wrap_type(::Type{T}) where {T} = PossiblyUninitTangent{T} - -_wrap_field(::Type{Q}, x::T) where {Q,T} = PossiblyUninitTangent{Q}(x) -_wrap_field(x::T) where {T} = _wrap_field(T, x) - struct Tangent{Tfields<:NamedTuple} fields::Tfields end @@ -83,10 +65,7 @@ Has the same semantics that `getfield!` would have if the data in the `fields` f were actually fields of `t`. This is the moral equivalent of `getfield` for `MutableTangent`. """ -@inline function get_tangent_field(t::PossiblyMutableTangent{Tfs}, i::Int) where {Tfs} - v = getfield(t.fields, i) - return fieldtype(Tfs, i) <: PossiblyUninitTangent ? val(v) : v -end +@inline get_tangent_field(t::PossiblyMutableTangent, i::Int) = val(getfield(t.fields, i)) @inline function get_tangent_field(t::PossiblyMutableTangent{F}, s::Symbol) where {F} return get_tangent_field(t, _sym_to_int(F, Val(s))) @@ -110,52 +89,44 @@ were actually fields of `t`. This is the moral equivalent of `setfield!` for return x end -@inline function set_tangent_field!( - t::MutableTangent{Tfields}, s::Symbol, x -) where {Tfields} - return set_tangent_field!(t, _sym_to_int(Tfields, Val(s)), x) +@inline function set_tangent_field!(t::MutableTangent{T}, s::Symbol, x) where {T} + return set_tangent_field!(t, _sym_to_int(T, Val(s)), x) end @generated function _sym_to_int(::Type{Tfields}, ::Val{s}) where {Tfields,s} return findfirst(==(s), fieldnames(Tfields)) end -@generated function build_tangent(::Type{P}, fields::Vararg{Any,N}) where {P,N} - tangent_values_exprs = map(enumerate(fieldtypes(P))) do (n, field_type) - if tangent_field_type(P, n) <: PossiblyUninitTangent - tt = PossiblyUninitTangent{tangent_type(field_type)} - if n <= N - return Expr(:call, tt, :(fields[$n])) - else - return Expr(:call, tt) - end - else - return :(fields[$n]) - end +function tangent_field_types_exprs(P::Type) + tangent_type_exprs = map(fieldtypes(P), always_initialised(P)) do _P, init + T_expr = Expr(:call, :tangent_type, _P) + return init ? T_expr : Expr(:curly, PossiblyUninitTangent, T_expr) end - return Expr( - :call, - tangent_type(P), - Expr(:call, NamedTuple{fieldnames(P)}, Expr(:tuple, tangent_values_exprs...)), - ) + return tangent_type_exprs end -function build_tangent( - ::Type{P}, fields::Vararg{Any,N} -) where {P<:Union{Tuple,NamedTuple},N} - T = tangent_type(P) - if T == NoTangent - return NoTangent() - elseif isconcretetype(P) - return T(fields) - else - return __tangent_from_non_concrete(P, fields) +# It is essential that this gets inlined. If it does not, then we run into performance +# issues with the recursion to compute tangent types for nested types. +@generated function tangent_field_types(::Type{P}) where {P} + return Expr(:call, :tuple, tangent_field_types_exprs(P)...) +end + +@generated function build_tangent(::Type{P}, fields::Vararg{Any,N}) where {P,N} + tangent_values_exprs = map(enumerate(tangent_field_types(P))) do (n, tt) + tt <: PossiblyUninitTangent && return n <= N ? :($tt(fields[$n])) : :($tt()) + return :(fields[$n]) end + tuple_expr = Expr(:tuple, tangent_values_exprs...) + return Expr(:call, tangent_type(P), Expr(:call, NamedTuple{fieldnames(P)}, tuple_expr)) end -__tangent_from_non_concrete(::Type{P}, fields) where {P<:Tuple} = Tuple(fields) -function __tangent_from_non_concrete(::Type{P}, fields) where {names,P<:NamedTuple{names}} - return NamedTuple{names}(fields) +""" + @tt_effects tangent_type(...) + +Effects which ought to be applied to `tangent_type`. +""" +macro tt_effects(expr) + return esc(:(Base.@assume_effects :consistent :removable $expr)) end """ @@ -164,6 +135,11 @@ end There must be a single type used to represents tangents of primals of type `P`, and it must be given by `tangent_type(P)`. +Warning: this function assumes the effects `:removable` and `:consistent`. This is necessary +to ensure good performance, but imposes precise constraints on your implementation. If +adding new methods to `tangent_type`, you should consult the extended help of +`Base.@assume_effects` to see what this imposes upon your implementation. + # Extended help The tangent types which Mooncake.jl uses are quite similar in spirit to ChainRules.jl. @@ -294,7 +270,7 @@ tangent_type(::Type{<:Type}) = NoTangent tangent_type(::Type{<:TypeVar}) = NoTangent -tangent_type(::Type{Ptr{P}}) where {P} = Ptr{tangent_type(P)} +@tt_effects tangent_type(::Type{Ptr{P}}) where {P} = Ptr{tangent_type(P)} tangent_type(::Type{<:Ptr}) = NoTangent @@ -320,13 +296,13 @@ tangent_type(::Type{P}) where {P<:Union{Int8,Int16,Int32,Int64,Int128}} = NoTang tangent_type(::Type{<:Core.Builtin}) = NoTangent -tangent_type(::Type{P}) where {P<:IEEEFloat} = P +@tt_effects tangent_type(::Type{P}) where {P<:IEEEFloat} = P tangent_type(::Type{<:Core.LLVMPtr}) = NoTangent tangent_type(::Type{String}) = NoTangent -tangent_type(::Type{<:Array{P,N}}) where {P,N} = Array{tangent_type(P),N} +@tt_effects tangent_type(::Type{<:Array{P,N}}) where {P,N} = Array{tangent_type(P),N} tangent_type(::Type{<:Array{P,N} where {P}}) where {N} = Array @@ -369,7 +345,7 @@ end # Generated functions cannot emit closures, so this is defined here for use below. isconcrete_or_union(p) = p isa Union || isconcretetype(p) -@generated function tangent_type(::Type{P}) where {N,P<:Tuple{Vararg{Any,N}}} +@tt_effects @generated function tangent_type(::Type{P}) where {N,P<:Tuple{Vararg{Any,N}}} # As with other types, tangent type of Union is Union of tangent types. P isa Union && return :(Union{tangent_type($(P.a)),tangent_type($(P.b))}) @@ -383,14 +359,18 @@ isconcrete_or_union(p) = p isa Union || isconcretetype(p) # a UnionAll before running to ensure that datatype_fieldcount will run. isa(P, DataType) && N == 0 && return NoTangent - # Get tangent types for all fields. If they're all `NoTangent`, return `NoTangent`. - # i.e. if `P = Tuple{Int, Int}`, do not return `Tuple{NoTangent, NoTangent}`. Simplify - # and return `NoTangent`. + # Expression to construct `Tuple` type containing tangent type for all fields. tangent_type_exprs = map(n -> :(tangent_type(fieldtype(P, $n))), 1:N) tangent_types = Expr(:call, tuple, tangent_type_exprs...) + + # Construct a Tuple type of the same length as `P`, containing all `NoTangent`s. T_all_notangent = Tuple{Vararg{NoTangent,N}} return quote + + # Get tangent types for all fields. If they're all `NoTangent`, return `NoTangent`. + # i.e. if `P = Tuple{Int, Int}`, do not return `Tuple{NoTangent, NoTangent}`. + # Simplify and return `NoTangent`. tangent_types = $tangent_types T = Tuple{tangent_types...} T <: $T_all_notangent && return NoTangent @@ -411,7 +391,7 @@ isconcrete_or_union(p) = p isa Union || isconcretetype(p) end end -function tangent_type(::Type{P}) where {N,P<:NamedTuple{N}} +@tt_effects function tangent_type(::Type{P}) where {N,P<:NamedTuple{N}} P isa Union && return Union{tangent_type(P.a),tangent_type(P.b)} !isconcretetype(P) && return Union{NoTangent,NamedTuple{N}} TT = tangent_type(Tuple{fieldtypes(P)...}) @@ -419,49 +399,38 @@ function tangent_type(::Type{P}) where {N,P<:NamedTuple{N}} return isconcretetype(TT) ? NamedTuple{N,TT} : Any end -@generated function tangent_type(::Type{P}) where {P} - # This method can only handle struct types. Tell user to implement tangent type - # directly for primitive types. - isprimitivetype(P) && - throw(error("$P is a primitive type. Implement a method of `tangent_type` for it.")) +@tt_effects @generated function tangent_type(::Type{P}) where {P} + + # This method can only handle struct types. Something has gone wrong if P is primitive. + if isprimitivetype(P) + return error("$P is a primitive type. Implement a method of `tangent_type` for it.") + end # If the type is a Union, then take the union type of its arguments. - P isa Union && return Union{tangent_type(P.a),tangent_type(P.b)} + P isa Union && return :(Union{tangent_type($(P.a)),tangent_type($(P.b))}) # If the type is itself abstract, it's tangent could be anything. # The same goes for if the type has any undetermined type parameters. (isabstracttype(P) || !isconcretetype(P)) && return Any - # If all fields are definitely NoTangents, then the overall tangent type is NoTangent. + tangent_fields_types_expr = Expr(:curly, Tuple, tangent_field_types_exprs(P)...) T_all_notangent = Tuple{Vararg{NoTangent,fieldcount(P)}} - Tuple{tangent_field_types(P)...} <: T_all_notangent && return NoTangent + return quote - # Derive tangent type. - bt = backing_type(P) - return bt == NoTangent ? bt : (ismutabletype(P) ? MutableTangent : Tangent){bt} -end + # Construct a `Tuple{...}` whose fields are the tangent types of the fields of `P`. + tangent_field_types_tuple = $tangent_fields_types_expr -@inline function tangent_field_types(P) - return tuple_map(Base.Fix1(tangent_field_type, P), (1:fieldcount(P)...,)) -end + # If all fields are definitely `NoTangent`s, then return `NoTangent`. + tangent_field_types_tuple <: $T_all_notangent && return NoTangent -backing_type(P::Type{<:Tuple}) = Tuple{tangent_field_types(P)...} + # Derive tangent type. + bt = NamedTuple{$(fieldnames(P)),tangent_field_types_tuple} + return $(ismutabletype(P) ? MutableTangent : Tangent){bt} + end +end backing_type(P::Type) = NamedTuple{fieldnames(P),Tuple{tangent_field_types(P)...}} -""" - tangent_field_type(::Type{P}, n::Int) where {P} - -Returns the type that lives in the nth elements of `fields` in a `Tangent` / -`MutableTangent`. Will either be the `tangent_type` of the nth fieldtype of `P`, or the -`tangent_type` wrapped in a `PossiblyUninitTangent`. The latter case only occurs if it is -possible for the field to be undefined. -""" -function tangent_field_type(::Type{P}, n::Int) where {P} - t = tangent_type(fieldtype(P, n)) - return is_always_initialised(P, n) ? t : _wrap_type(t) -end - """ zero_tangent(x) @@ -476,9 +445,7 @@ Internally, `zero_tangent` calls `zero_tangent_internal`, which handles differen handles both circular references and aliasing correctly. """ zero_tangent(x) -function zero_tangent(x::P) where {P} - return zero_tangent_internal(x, isbitstype(P) ? nothing : IdDict()) -end +zero_tangent(x::P) where {P} = zero_tangent_internal(x, isbitstype(P) ? nothing : IdDict()) const StackDict = Union{Nothing,IdDict} @@ -532,23 +499,15 @@ function zero_tangent_internal(x::P, stackdict) where {P} end end -@generated function zero_tangent_struct_field(x::P, stackdict) where {P} - tangent_field_zeros_exprs = ntuple(fieldcount(P)) do n - if tangent_field_type(P, n) <: PossiblyUninitTangent - V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))} - return :( - if isdefined(x, $n) - $V(zero_tangent_internal(getfield(x, $n), stackdict)) - else - $V() - end - ) - else - return :(zero_tangent_internal(getfield(x, $n), stackdict)) - end +function zero_tangent_struct_field(x::P, d) where {P} + Tfs = tangent_field_types(P) + inits = always_initialised(P) + tangent_field_zeros = ntuple(Val(fieldcount(P))) do n + T = Tfs[n] + inits[n] && return zero_tangent_internal(getfield(x, n), d) + return isdefined(x, n) ? T(zero_tangent_internal(getfield(x, n), d)) : T() end - tangent_fields_expr = Expr(:call, :tuple, tangent_field_zeros_exprs...) - return :($(backing_type(P))($tangent_fields_expr)) + return backing_type(P)(tangent_field_zeros) end """ @@ -613,23 +572,15 @@ function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict) where {P} end end -@generated function randn_tangent_struct_field(rng::AbstractRNG, x::P, stackdict) where {P} - tangent_field_exprs = map(1:fieldcount(P)) do n - if tangent_field_type(P, n) <: PossiblyUninitTangent - V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))} - return :( - if isdefined(x, $n) - $V(randn_tangent_internal(rng, getfield(x, $n), stackdict)) - else - $V() - end - ) - else - return :(randn_tangent_internal(rng, getfield(x, $n), stackdict)) - end +function randn_tangent_struct_field(rng::AbstractRNG, x::P, d) where {P} + Tfs = tangent_field_types(P) + inits = always_initialised(P) + tangent_field_zeros = ntuple(Val(fieldcount(P))) do n + T = Tfs[n] + inits[n] && return randn_tangent_internal(rng, getfield(x, n), d) + return isdefined(x, n) ? T(randn_tangent_internal(rng, getfield(x, n), d)) : T() end - tangent_fields_expr = Expr(:call, :tuple, tangent_field_exprs...) - return :($(backing_type(P))($tangent_fields_expr)) + return backing_type(P)(tangent_field_zeros) end """ @@ -1057,6 +1008,7 @@ function tangent_test_cases() # Regression tests to catch type inference failures, see https://github.com/compintell/Mooncake.jl/pull/422 (((((randn(33)...,),),),),), (((((((((randn(33)...,),),),),), randn(5)...),),),), + Base.OneTo{Int}, ] VERSION >= v"1.11" && push!(rel_test_cases, fill!(Memory{Float64}(undef, 3), 3.0)) return vcat( diff --git a/src/test_utils.jl b/src/test_utils.jl index 412753a82..579795f5c 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -123,7 +123,8 @@ using Mooncake: InvalidRDataException, uninit_codual, lgetfield, - lsetfield! + lsetfield!, + CC struct Shim end @@ -807,10 +808,15 @@ end test_tangent_type(primal_type, expected_tangent_type) Checks that `tangent_type(primal_type)` yields `expected_tangent_type`, and that everything -infers / optimises away. +infers / optimises away, and that the effects are as expected. """ function test_tangent_type(primal_type::Type, expected_tangent_type::Type) @test tangent_type(primal_type) == expected_tangent_type + effects = Base.infer_effects(tangent_type, (Type{expected_tangent_type},)) + @test effects.consistent == CC.ALWAYS_TRUE + @test effects.effect_free == CC.ALWAYS_TRUE + @test effects.nothrow + @test effects.terminates return test_opt(Shim(), tangent_type, Tuple{_typeof(primal_type)}) end @@ -1061,9 +1067,7 @@ end __tangent_generation_should_allocate(::Type{P}) where {P<:Array} = true function __increment_should_allocate(::Type{P}) where {P} - return any(eachindex(fieldtypes(P))) do n - Mooncake.tangent_field_type(P, n) <: PossiblyUninitTangent - end + return any(tt -> tt <: PossiblyUninitTangent, Mooncake.tangent_field_types(P)) end __increment_should_allocate(::Type{Core.SimpleVector}) = true @@ -1135,8 +1139,10 @@ function test_fwds_rvs_data(rng::AbstractRNG, p::P) where {P} T = tangent_type(P) F = Mooncake.fdata_type(T) @test F isa Type + check_allocs(Shim(), Mooncake.fdata_type, T) R = Mooncake.rdata_type(T) @test R isa Type + check_allocs(Shim(), Mooncake.rdata_type, T) # Check that fdata and rdata produce the correct types. t = randn_tangent(rng, p) diff --git a/src/utils.jl b/src/utils.jl index ed0d9cc5c..6914ab2ad 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -81,6 +81,16 @@ end return Expr(:call, :tuple, map(_ -> :val, 1:N)...) end +""" + stable_all(x::NTuple{N, Bool}) where {N} + +`all(x::NTuple{N, Bool})` does not constant-fold nicely on 1.10 if the values of `x` are +known statically. This implementation constant-folds nicely on both 1.10 and 1.11, so can +be used in its place in situations where this is important. +""" +stable_all(x::NTuple{1,Bool}) = x[1] +stable_all(x::NTuple{N,Bool}) where {N} = x[1] & stable_all(x[2:end]) + """ _map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P} @@ -181,6 +191,18 @@ function sparam_names(m::Core.Method)::Vector{Symbol} end end +""" + always_initialised(::Type{P}) where {P} + +Returns a tuple with number of fields equal to the number of fields in `P`. The nth field +is set to `true` if the nth field of `P` is initialised, and `false` otherwise. +""" +@generated function always_initialised(::Type{P}) where {P} + P isa DataType || return :(error("$P is not a DataType.")) + num_init = CC.datatype_min_ninitialized(P) + return (map(n -> n <= num_init, 1:fieldcount(P))...,) +end + """ is_always_initialised(P::DataType, n::Int)::Bool diff --git a/test/front_matter.jl b/test/front_matter.jl index 636684a10..1627f11f1 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -25,12 +25,10 @@ using Mooncake: TestUtils, TestResources, CoDual, - _wrap_field, DefaultCtx, rrule!!, lgetfield, lsetfield!, - build_tangent, Stack, _typeof, BBCode, @@ -63,7 +61,10 @@ using Mooncake: verify_rdata_value, is_primitive, MinimalCtx, - stmt + stmt, + can_produce_zero_rdata_from_type, + zero_rdata_from_type, + CannotProduceZeroRDataFromType using .TestUtils: test_rule, diff --git a/test/fwds_rvs_data.jl b/test/fwds_rvs_data.jl index 660748c06..d96edc66f 100644 --- a/test/fwds_rvs_data.jl +++ b/test/fwds_rvs_data.jl @@ -18,18 +18,26 @@ end TestUtils.test_fwds_rvs_data(Xoshiro(123456), p) end @testset "zero_rdata_from_type checks" begin - @test Mooncake.can_produce_zero_rdata_from_type(Vector) == true - @test Mooncake.zero_rdata_from_type(Vector) == NoRData() - @test !Mooncake.can_produce_zero_rdata_from_type(FwdsRvsDataTestResources.Foo) - @test Mooncake.can_produce_zero_rdata_from_type(Tuple{Float64,Type{Float64}}) + @test can_produce_zero_rdata_from_type(Vector) == true + @test zero_rdata_from_type(Vector) == NoRData() + @test !can_produce_zero_rdata_from_type(FwdsRvsDataTestResources.Foo) + @test can_produce_zero_rdata_from_type(Tuple{Float64,Type{Float64}}) @test ==( - Mooncake.zero_rdata_from_type(FwdsRvsDataTestResources.Foo), - Mooncake.CannotProduceZeroRDataFromType(), + zero_rdata_from_type(FwdsRvsDataTestResources.Foo), + CannotProduceZeroRDataFromType(), ) - @test !Mooncake.can_produce_zero_rdata_from_type(Tuple) - @test !Mooncake.can_produce_zero_rdata_from_type(Union{Tuple{Float64},Tuple{Int}}) - @test !Mooncake.can_produce_zero_rdata_from_type(Tuple{T,T} where {T<:Integer}) - @test Mooncake.can_produce_zero_rdata_from_type(Type{Float64}) + @test !can_produce_zero_rdata_from_type(Tuple) + @test zero_rdata_from_type(Tuple) == CannotProduceZeroRDataFromType() + @test !can_produce_zero_rdata_from_type(Union{Tuple{Float64},Tuple{Int}}) + @test ==( + zero_rdata_from_type(Union{Tuple{Float64},Tuple{Int}}), + CannotProduceZeroRDataFromType(), + ) + @test !can_produce_zero_rdata_from_type(Tuple{T,T} where {T<:Integer}) + @test can_produce_zero_rdata_from_type(Type{Float64}) + @test can_produce_zero_rdata_from_type(Union{Tuple{Int},Tuple{Int,Int}}) + @test zero_rdata_from_type(Union{Tuple{Int},Tuple{Int,Int}}) == NoRData() + @test zero_rdata_from_type(Union{Float64,Int}) == CannotProduceZeroRDataFromType() # Edge case: Types with unbound type parameters. P = (Type{T} where {T}).body diff --git a/test/integration_testing/battery_tests/Project.toml b/test/integration_testing/battery_tests/Project.toml index 625c65412..2299cb44c 100644 --- a/test/integration_testing/battery_tests/Project.toml +++ b/test/integration_testing/battery_tests/Project.toml @@ -1,4 +1,5 @@ [deps] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" diff --git a/test/integration_testing/battery_tests/battery_tests.jl b/test/integration_testing/battery_tests/battery_tests.jl index 266d164f9..6b5521be5 100644 --- a/test/integration_testing/battery_tests/battery_tests.jl +++ b/test/integration_testing/battery_tests/battery_tests.jl @@ -2,7 +2,7 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using JET, LinearAlgebra, Mooncake, Random, StableRNGs, Test +using AllocCheck, JET, LinearAlgebra, Mooncake, Random, StableRNGs, Test using Mooncake: TestResources @testset "battery_tests" begin diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 810d41a40..344f8f76a 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -277,15 +277,15 @@ end TestUtils.test_rule( Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false ) - TestUtils.test_rule( - Xoshiro(123456), - f, - x...; - perf_flag=:none, - interface_only, - is_primitive=false, - debug_mode=true, - ) + # TestUtils.test_rule( + # Xoshiro(123456), + # f, + # x...; + # perf_flag=:none, + # interface_only, + # is_primitive=false, + # debug_mode=true, + # ) # interp = Mooncake.get_interpreter() # codual_args = map(zero_codual, (f, x...)) diff --git a/test/utils.jl b/test/utils.jl index 03d630a7b..4daf0028d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -25,6 +25,14 @@ @test_throws ArgumentError Mooncake.tuple_map(*, (5.0, 4.0), (4.0,)) @test_throws ArgumentError Mooncake.tuple_map(*, (4.0,), (5.0, 4.0)) end + @testset "stable_all" begin + @test Mooncake.stable_all((false,)) == false + @test Mooncake.stable_all((true,)) == true + @test Mooncake.stable_all((false, true)) == false + @test Mooncake.stable_all((false, false)) == false + @test Mooncake.stable_all((true, false)) == false + @test Mooncake.stable_all((true, true)) == true + end @testset "_map_if_assigned!" begin @testset "unary bits type" begin x = Vector{Float64}(undef, 10)