diff --git a/src/helpers/training.jl b/src/helpers/training.jl index ffe16ffce..e5bbee395 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -11,7 +11,7 @@ using Static: StaticBool, Static, False, True using ..Lux: Lux, Utils, ReactantCompatibleOptimisers using LuxCore: LuxCore, AbstractLuxLayer -using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, get_device, cpu_device +using MLDataDevices: MLDataDevices, ReactantDevice, get_device_type, cpu_device """ TrainState @@ -63,12 +63,10 @@ Constructor for [`TrainState`](@ref). [`TrainState`](@ref) object. """ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule) - dev = get_device(ps) - st_opt = if dev isa ReactantDevice - ps_cpu = ps |> cpu_device() + st_opt = if get_device_type(ps) <: ReactantDevice Optimisers.setup( - ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps_cpu - ) |> dev + ReactantCompatibleOptimisers.make_reactant_compatible(optimizer), ps + ) else Optimisers.setup(optimizer, ps) end