From 90997a029a7bcd1f7eb3460a8787bcd6006b1c51 Mon Sep 17 00:00:00 2001 From: Simeon David Schaub Date: Fri, 27 Dec 2024 22:05:34 +0100 Subject: [PATCH] fix: use `return_type` instead of `_return_type` (#1148) --- lib/LuxLib/src/impl/activation.jl | 2 +- lib/LuxLib/src/impl/batched_mul.jl | 2 +- lib/LuxLib/src/traits.jl | 4 ++-- lib/LuxLib/src/utils.jl | 4 ++-- src/helpers/losses.jl | 4 ++-- src/utils.jl | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl index 0b015e3b1..2dd7df43f 100644 --- a/lib/LuxLib/src/impl/activation.jl +++ b/lib/LuxLib/src/impl/activation.jl @@ -59,7 +59,7 @@ function activation(::AbstractInternalArrayOpMode, σ::F, x::AbstractArray) wher end @stable default_mode="disable" function activation( opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T} - RT = Core.Compiler._return_type(σ, Tuple{T}) + RT = Core.Compiler.return_type(σ, Tuple{T}) y = similar(x, ifelse(isconcretetype(RT), RT, T)) activation!(y, opmode, σ, x) return y diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl index 37e62de67..0464ca240 100644 --- a/lib/LuxLib/src/impl/batched_mul.jl +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -17,7 +17,7 @@ end function batched_matmul(opmode::GPUBroadcastOp{<:AbstractGPUDevice}, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} - if isconcretetype(Core.Compiler._return_type( + if isconcretetype(Core.Compiler.return_type( NNlib.batched_mul, Tuple{typeof(x), typeof(y)})) return NNlib.batched_mul(x, y) # GPU versions are well optimized end diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 0a768de86..6df0fc8f7 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -61,12 +61,12 @@ end activation_intermediate_not_needed(::typeof(identity), ::Type) = True() function activation_intermediate_not_needed(::F, ::Type{T}) where {F, T} - return static(isconcretetype(Core.Compiler._return_type( + return static(isconcretetype(Core.Compiler.return_type( only_derivative, Tuple{T, F, NotaNumber}))) end function activation_has_rrule(::F, ::Type{T}) where {F, T} - return static(isconcretetype(Core.Compiler._return_type( + return static(isconcretetype(Core.Compiler.return_type( only_derivative, Tuple{T, F, T}))) end diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl index 1ef926b93..717df4e78 100644 --- a/lib/LuxLib/src/utils.jl +++ b/lib/LuxLib/src/utils.jl @@ -54,7 +54,7 @@ safe_vec(::Nothing) = nothing ## This part is taken from NNlib.jl # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` -# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. +# is independent of `x`, as `return_type` says `Union{}` when calling is an error. struct NotaNumber <: Real end # This just saves typing `only.(only.(` many times: @@ -118,7 +118,7 @@ CRC.@non_differentiable default_epsilon(::Any...) function concrete_bias_act_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, b::Optional{<:AbstractVector}) where {F, Tw, Tx} Ty = promote_type(Tw, Tx, safe_eltype(b)) - Tact = Core.Compiler._return_type(act, Tuple{Ty}) + Tact = Core.Compiler.return_type(act, Tuple{Ty}) return ifelse(isconcretetype(Tact), Tact, Ty) end diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 8cf17a997..9a8f575b6 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -42,7 +42,7 @@ end fused_agg(::typeof(sum), op::OP, x::Number, y::Number) where {OP} = op(x, y) function fused_agg(::typeof(sum), op::OP, x::AbstractArray, y::AbstractArray) where {OP} if fast_scalar_indexing(x) && fast_scalar_indexing(y) - res = Core.Compiler._return_type(op, Tuple{eltype(x), eltype(y)})(0) + res = Core.Compiler.return_type(op, Tuple{eltype(x), eltype(y)})(0) @simd ivdep for i in eachindex(x, y) @inbounds res += op(x[i], y[i]) end @@ -73,7 +73,7 @@ function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, Nothing, eltype(x), 1}.(x, (Partials{1, eltype(x)}((one(eltype(x)),)),)) x_partials = similar(x) T = eltype(x) - res = Core.Compiler._return_type(op, Tuple{T, eltype(y)})(0) + res = Core.Compiler.return_type(op, Tuple{T, eltype(y)})(0) @inbounds @simd for i in eachindex(x_partials, x, y) x_dual = Dual{Nothing, T, 1}(x[i], Partials{1, T}((one(T),))) tmp = op(x_dual, y[i]) diff --git a/src/utils.jl b/src/utils.jl index bd73087ae..2a6930a2a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -42,7 +42,7 @@ unbatched_structure(x) = fmapstructure(size_unbatched, x) can_named_tuple(::NamedTuple) = true can_named_tuple(::T) where {T} = can_named_tuple(T) function can_named_tuple(::Type{T}) where {T} - return Core.Compiler._return_type(named_tuple, Tuple{T}) !== Union{} + return Core.Compiler.return_type(named_tuple, Tuple{T}) !== Union{} end @non_differentiable can_named_tuple(::Any) @@ -50,7 +50,7 @@ end # Convert to a NamedTuple named_tuple(nt::NamedTuple) = nt function named_tuple(x::T) where {T} - NT = Core.Compiler._return_type(NamedTuple, Tuple{T}) + NT = Core.Compiler.return_type(NamedTuple, Tuple{T}) if NT === Union{} || NT === NamedTuple error("`NamedTuple` is not defined for type `$(T)`. Please define \ `Lux.Utils.named_tuple(::$(T))` method (or preferably \