Skip to content

Commit

Permalink
fix: don't transfer via CPU
Browse files (browse the repository at this point in the history)
  • Loading branch information
avik-pal committed Dec 23, 2024
1 parent 6167fd6 commit adcf30a
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/helpers/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Static: StaticBool, Static, False, True

using ..Lux: Lux, Utils, ReactantCompatibleOptimisers
using LuxCore: LuxCore, AbstractLuxLayer
- using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, get_device, cpu_device
+ using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, cpu_device

"""
TrainState
Expand Down Expand Up @@ -63,12 +63,10 @@ Constructor for [`TrainState`](@ref).
[`TrainState`](@ref) object.
"""
function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
- dev = get_device(ps)
- st_opt = if dev isa ReactantDevice
- ps_cpu = ps |> cpu_device()
+ st_opt = if get_device_type(ps) <: ReactantDevice
  Optimisers.setup(
- ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps_cpu
- ) |> dev
+ ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps
+ )
else
Optimisers.setup(optimizer, ps)
end
Expand Down

0 comments on commit adcf30a

Please sign in to comment.