diff --git a/src/functor.jl b/src/functor.jl index 8296e6bd98..24dc41d3ed 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -187,16 +187,9 @@ _isbitsarray(x) = false _isleaf(::AbstractRNG) = true _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) -const GPU_BACKEND_ORDER = sort( - Dict( - "CUDA" => 1, - "AMD" => 2, - "Metal" => 3, - "CPU" => 4, - ), - byvalue = true -) -const GPU_BACKENDS = tuple(collect(keys(GPU_BACKEND_ORDER))...) +# the order below is important +const GPU_BACKENDS = ["CUDA", "AMD", "Metal", "CPU"] +const GPU_BACKEND_ORDER = Dict(collect(zip(GPU_BACKENDS, 1:length(GPU_BACKENDS)))) const GPU_BACKEND = @load_preference("gpu_backend", "CUDA") function gpu_backend!(backend::String)