From cda5278fdacdc54ebc48ae1430a30c7bad78247f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 Dec 2024 21:26:44 +0530 Subject: [PATCH] refactor: remove promote_to handling --- ext/LuxReactantExt/LuxReactantExt.jl | 4 ---- src/helpers/optimizers.jl | 22 ++++------------------ src/utils.jl | 1 - 3 files changed, 4 insertions(+), 23 deletions(-) diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 373ff38681..6f49b076d0 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -18,10 +18,6 @@ function Utils.promote_to(::Type{T}, x::Number) where {T <: Number} return Reactant.ConcreteRNumber{T}(x) end -function Utils.promote_to_inside_interpreter(::Type{T}, x::Number) where {T <: Number} - return Reactant.TracedUtils.promote_to(TracedRNumber{T}, x) -end - include("patches.jl") include("training.jl") diff --git a/src/helpers/optimizers.jl b/src/helpers/optimizers.jl index f83e96a833..fe0116bb4e 100644 --- a/src/helpers/optimizers.jl +++ b/src/helpers/optimizers.jl @@ -48,7 +48,7 @@ end Optimisers.init(::ReactantDescent, ::AbstractArray) = nothing function Optimisers.apply!(opt::ReactantDescent, state, x::AbstractArray{T}, dx) where {T} - η = Utils.promote_to_inside_interpreter(T, opt.eta) + η = T(opt.eta) return state, @. dx * η end @@ -74,8 +74,7 @@ function Optimisers.init(::ReactantMomentum, x::AbstractArray) end function Optimisers.apply!(opt::ReactantMomentum, mvel, ::AbstractArray{T}, dx) where {T} - η = Utils.promote_to_inside_interpreter(T, opt.eta) - ρ = Utils.promote_to_inside_interpreter(T, opt.rho) + η, ρ = T(opt.eta), T(opt.rho) @. mvel = ρ * mvel + η * dx return mvel, mvel end @@ -110,13 +109,7 @@ function Optimisers.init(opt::ReactantAdam, x::AbstractArray{T}) where {T} end function Optimisers.apply!(o::ReactantAdam, state, ::AbstractArray{T}, dx) where {T} - η = Utils.promote_to_inside_interpreter(T, o.eta) - β = ( - Utils.promote_to_inside_interpreter(T, o.beta[1]), - Utils.promote_to_inside_interpreter(T, o.beta[2]) - ) - ϵ = Utils.promote_to_inside_interpreter(T, o.epsilon) # XXX: See Optimisers._eps - + η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon) # XXX: See Optimisers._eps mt, vt, βt = state @. mt = β[1] * mt + (1 - β[1]) * dx @@ -161,14 +154,7 @@ function Optimisers.init(opt::ReactantAdamW, x::AbstractArray{T}) where {T} end function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) where {T} - η = Utils.promote_to_inside_interpreter(T, o.eta) - β = ( - Utils.promote_to_inside_interpreter(T, o.beta[1]), - Utils.promote_to_inside_interpreter(T, o.beta[2]) - ) - ϵ = Utils.promote_to_inside_interpreter(T, o.epsilon) # XXX: See Optimisers._eps - λ = Utils.promote_to_inside_interpreter(T, o.lambda) - + η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda) # XXX: See Optimisers._eps mt, vt, βt = state # standard Adam update with learning rate eta=1 diff --git a/src/utils.jl b/src/utils.jl index 4dbe13c53e..eb24d2e25d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -213,7 +213,6 @@ matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:e function to_rarray end function promote_to end -function promote_to_inside_interpreter end # This should probably be in WeightInitializers.jl calculate_gain(_, __) = 1.0f0