refactor: remove promote_to handling
avik-pal committed Dec 24, 2024
1 parent adcf30a · commit 2832a6c
Showing 3 changed files with 4 additions and 23 deletions.
4 changes: 0 additions & 4 deletions ext/LuxReactantExt/LuxReactantExt.jl

@@ -18,10 +18,6 @@ function Utils.promote_to(::Type{T}, x::Number) where {T <: Number}
     return Reactant.ConcreteRNumber{T}(x)
 end
 
-function Utils.promote_to_inside_interpreter(::Type{T}, x::Number) where {T <: Number}
-    return Reactant.TracedUtils.promote_to(TracedRNumber{T}, x)
-end
-
 include("patches.jl")
 include("training.jl")
22 changes: 4 additions & 18 deletions src/helpers/optimizers.jl

@@ -48,7 +48,7 @@ end
 Optimisers.init(::ReactantDescent, ::AbstractArray) = nothing
 
 function Optimisers.apply!(opt::ReactantDescent, state, x::AbstractArray{T}, dx) where {T}
-    η = Utils.promote_to_inside_interpreter(T, opt.eta)
+    η = T(opt.eta)
     return state, @. dx * η
 end

@@ -74,8 +74,7 @@ function Optimisers.init(::ReactantMomentum, x::AbstractArray)
 end
 
 function Optimisers.apply!(opt::ReactantMomentum, mvel, ::AbstractArray{T}, dx) where {T}
-    η = Utils.promote_to_inside_interpreter(T, opt.eta)
-    ρ = Utils.promote_to_inside_interpreter(T, opt.rho)
+    η, ρ = T(opt.eta), T(opt.rho)
     @. mvel = ρ * mvel + η * dx
     return mvel, mvel
 end

@@ -110,13 +109,7 @@ function Optimisers.init(opt::ReactantAdam, x::AbstractArray{T}) where {T}
 end
 
 function Optimisers.apply!(o::ReactantAdam, state, ::AbstractArray{T}, dx) where {T}
-    η = Utils.promote_to_inside_interpreter(T, o.eta)
-    β = (
-        Utils.promote_to_inside_interpreter(T, o.beta[1]),
-        Utils.promote_to_inside_interpreter(T, o.beta[2])
-    )
-    ϵ = Utils.promote_to_inside_interpreter(T, o.epsilon) # XXX: See Optimisers._eps
-
+    η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon) # XXX: See Optimisers._eps
     mt, vt, βt = state
 
     @. mt = β[1] * mt + (1 - β[1]) * dx

@@ -161,14 +154,7 @@ function Optimisers.init(opt::ReactantAdamW, x::AbstractArray{T}) where {T}
 end
 
 function Optimisers.apply!(o::ReactantAdamW, state, x::AbstractArray{T}, dx) where {T}
-    η = Utils.promote_to_inside_interpreter(T, o.eta)
-    β = (
-        Utils.promote_to_inside_interpreter(T, o.beta[1]),
-        Utils.promote_to_inside_interpreter(T, o.beta[2])
-    )
-    ϵ = Utils.promote_to_inside_interpreter(T, o.epsilon) # XXX: See Optimisers._eps
-    λ = Utils.promote_to_inside_interpreter(T, o.lambda)
-
+    η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda) # XXX: See Optimisers._eps
     mt, vt, βt = state
 
     # standard Adam update with learning rate eta=1
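
Each hunk above makes the same substitution: scalar hyperparameters that previously went through Utils.promote_to_inside_interpreter are now converted to the parameter eltype with a plain T(...) call, with a broadcast T.(...) handling the beta tuple. A minimal standalone sketch of the new conversion pattern (the named tuple `o` below is a hypothetical stand-in for a ReactantAdam instance, not part of the commit):

    # Hypothetical optimizer hyperparameters; field names mirror those in the diff.
    o = (eta = 3.0e-4, beta = (0.9, 0.999), epsilon = 1.0e-8)
    T = Float32  # eltype of the parameter array being updated

    η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)

    η  # 0.0003f0          :: Float32
    β  # (0.9f0, 0.999f0)  :: NTuple{2, Float32}
    ϵ  # 1.0f-8            :: Float32

Note that T.(o.beta) broadcasts the conversion over the tuple and returns a tuple of the same length, so the β[1] and β[2] references in the update rules work unchanged.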
1 change: 0 additions & 1 deletion src/utils.jl

@@ -213,7 +213,6 @@ matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...)
 
 function to_rarray end
 function promote_to end
-function promote_to_inside_interpreter end
 
 # This should probably be in WeightInitializers.jl
 calculate_gain(_, __) = 1.0f0
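
For context on the stub being deleted here: `function promote_to end` declares a function with no methods, so that a package extension (LuxReactantExt above) can attach methods once the weak dependency is loaded. A minimal sketch of that idiom, with illustrative names (MiniUtils is not from the repository):

    module MiniUtils
    function promote_to end  # no methods yet; an extension adds them
    end

    # What an extension does, simulated here in the same file:
    MiniUtils.promote_to(::Type{T}, x::Number) where {T <: Number} = T(x)

    MiniUtils.promote_to(Float32, 1.0)  # returns 1.0f0

The commit drops the `promote_to_inside_interpreter` stub because nothing defines or calls it any more once the optimizer code converts scalars with T(...) directly.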
