From 3d5d84929dab2c2536d4dda44312f0a029ecc16d Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 7 Aug 2023 16:36:34 +0530 Subject: [PATCH 01/29] Adding new `get_device` method to return a CUDA device with particular ordinal. --- Project.toml | 9 ++++----- ext/FluxCUDAExt/functor.jl | 6 ++++++ src/Flux.jl | 1 + src/functor.jl | 3 +++ 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 6b58703e1c..56e5ecd50a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.14.2" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -24,8 +25,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] FluxAMDGPUExt = "AMDGPU" @@ -56,8 +57,8 @@ julia = "1.9" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" @@ -68,6 +69,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", - "FillArrays", "ComponentArrays", "BSON", "Pkg", - "CUDA", "cuDNN", "Metal", "AMDGPU"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"] diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index 347cfce372..0177c9daab 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -30,3 +30,9 @@ function _cuda(x) USE_CUDA[] || return x fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude=Flux._isleaf) end + +function Flux.get_device(::Type{CUDA.CUDABackend}, ordinal::UInt) + CUDA.device!(ordinal) do + return Flux.FluxCUDADevice(CUDA.device()) + end +end diff --git a/src/Flux.jl b/src/Flux.jl index d522b91e78..78239588cb 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -14,6 +14,7 @@ using Random: default_rng using Zygote, ChainRulesCore using Zygote: Params, @adjoint, gradient, pullback using Zygote.ForwardDiff: value +import KernelAbstractions export gradient # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.) 
diff --git a/src/functor.jl b/src/functor.jl index 24dc41d3ed..290be4ea3a 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,5 +1,6 @@ import Adapt: adapt, adapt_storage using LinearAlgebra: Cholesky +using NNlib: KernelAbstractions using Zygote: IdSet import Functors: Functors, @functor, functor, fmap, isleaf using SparseArrays: AbstractSparseArray @@ -653,3 +654,5 @@ function get_device(; verbose=false)::AbstractDevice end end end + +function get_device(::Type{<:KernelAbstractions.GPU}, ::UInt) end From f720d5f97f9fa7ee0f25a5b30e9331411c4b2778 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Thu, 10 Aug 2023 02:41:46 +0530 Subject: [PATCH 02/29] Adding an `adapt` function for `AbstractArray` to handle movement across devices. --- ext/FluxCUDAExt/functor.jl | 36 ++++++++++++++++++++++++++++++------ src/functor.jl | 16 ++++++++++++---- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index 0177c9daab..1160b09806 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -1,5 +1,27 @@ - adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) +function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray) + typeof(to.ordinal) <: Nothing && return CUDA.cu(x) + + # remember current device + old_ordinal = CUDA.device().handle + + if !(x isa CuArray) + CUDA.device!(to.ordinal) + x_new = CUDA.cu(x) + CUDA.device!(old_ordinal) + return x_new + else + if CUDA.device(x).handle == to.ordinal + return x + else + CUDA.device!(to.ordinal) + x_new = CUDA.rand(size(x)...) + copy!(x_new, x) + CUDA.device!(old_ordinal) + return x_new + end + end +end adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x @@ -25,14 +47,16 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) = ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) = adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ))) -function _cuda(x) +function _cuda(ordinal::Union{Nothing, UInt}, x) check_use_cuda() USE_CUDA[] || return x - fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude=Flux._isleaf) + fmap(x -> Adapt.adapt(FluxCUDAAdaptor(ordinal), x), x; exclude=Flux._isleaf) end function Flux.get_device(::Type{CUDA.CUDABackend}, ordinal::UInt) - CUDA.device!(ordinal) do - return Flux.FluxCUDADevice(CUDA.device()) - end + old_ordinal = CUDA.device().handle + CUDA.device!(ordinal) + device = Flux.FluxCUDADevice(CUDA.device()) + CUDA.device!(old_ordinal) + return device end diff --git a/src/functor.jl b/src/functor.jl index 290be4ea3a..ec53a6b702 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -333,13 +333,15 @@ trainable(c::Cholesky) = () # CUDA extension. 
######## -struct FluxCUDAAdaptor end +Base.@kwdef struct FluxCUDAAdaptor + ordinal::Union{Nothing, UInt} = nothing +end const CUDA_LOADED = Ref{Bool}(false) -function gpu(::FluxCUDAAdaptor, x) +function gpu(to::FluxCUDAAdaptor, x) if CUDA_LOADED[] - return _cuda(x) + return _cuda(to.ordinal, x) else @info """ The CUDA functionality is being called but @@ -501,7 +503,13 @@ Base.@kwdef struct FluxCUDADevice <: AbstractDevice deviceID end -(::FluxCUDADevice)(x) = gpu(FluxCUDAAdaptor(), x) +function (device::FluxCUDADevice)(x) + if typeof(device.deviceID) <: Nothing + return gpu(FluxCUDAAdaptor(), x) + else + return gpu(FluxCUDAAdaptor(UInt(device.deviceID.handle)), x) + end +end _get_device_name(::FluxCUDADevice) = "CUDA" """ From 40e085ccfc5d48cb5d8c86a4cc689c534ff798d0 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sat, 12 Aug 2023 14:04:43 +0530 Subject: [PATCH 03/29] Making the `get_device` interface simpler, and some minor changes. --- ext/FluxCUDAExt/functor.jl | 19 ++++++++----------- src/functor.jl | 6 ++++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index 1160b09806..b98e3401e3 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -10,16 +10,13 @@ function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray) x_new = CUDA.cu(x) CUDA.device!(old_ordinal) return x_new + elseif CUDA.device(x).handle == to.ordinal + return x else - if CUDA.device(x).handle == to.ordinal - return x - else - CUDA.device!(to.ordinal) - x_new = CUDA.rand(size(x)...) - copy!(x_new, x) - CUDA.device!(old_ordinal) - return x_new - end + CUDA.device!(to.ordinal) + x_new = copy(x) + CUDA.device!(old_ordinal) + return x_new end end adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) @@ -47,13 +44,13 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) = ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) = adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ))) -function _cuda(ordinal::Union{Nothing, UInt}, x) +function _cuda(ordinal::Union{Nothing, Int}, x) check_use_cuda() USE_CUDA[] || return x fmap(x -> Adapt.adapt(FluxCUDAAdaptor(ordinal), x), x; exclude=Flux._isleaf) end -function Flux.get_device(::Type{CUDA.CUDABackend}, ordinal::UInt) +function Flux.get_device(::Val{:CUDA}, ordinal::Int) old_ordinal = CUDA.device().handle CUDA.device!(ordinal) device = Flux.FluxCUDADevice(CUDA.device()) diff --git a/src/functor.jl b/src/functor.jl index ec53a6b702..12ff5ffde4 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -334,7 +334,7 @@ trainable(c::Cholesky) = () # CUDA extension. ######## Base.@kwdef struct FluxCUDAAdaptor - ordinal::Union{Nothing, UInt} = nothing + ordinal::Union{Nothing, Int} = nothing end const CUDA_LOADED = Ref{Bool}(false) @@ -663,4 +663,6 @@ function get_device(; verbose=false)::AbstractDevice end end -function get_device(::Type{<:KernelAbstractions.GPU}, ::UInt) end +function get_device(backend::String, ordinal::Int) + get_device(Val(Symbol(backend)), ordinal) +end From 652bf95668052e8be1602d97181fbf54b120cd60 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sat, 12 Aug 2023 14:37:42 +0530 Subject: [PATCH 04/29] Adding CPU option to `get_device`. 
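A minimal usage sketch of the string-based entry point after this change (the CPU path ignores the ordinal, so none is needed; assumes this PR's `Flux`):

```julia
using Flux

cpu_dev = Flux.get_device("CPU")   # returns a Flux.FluxCPUDevice
x = rand(Float32, 3)
cpu_dev(x) == x                    # true: data already on the CPU comes back unchanged
```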
--- src/functor.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 12ff5ffde4..45f557edac 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -663,6 +663,10 @@ function get_device(; verbose=false)::AbstractDevice end end -function get_device(backend::String, ordinal::Int) - get_device(Val(Symbol(backend)), ordinal) +function get_device(backend::String, ordinal::Int = 0) + if backend == "CPU" + return FluxCPUDevice() + else + return get_device(Val(Symbol(backend)), ordinal) + end end From b3cd29285ae072a911ae7cd327ad16ff08770c41 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sat, 12 Aug 2023 14:40:17 +0530 Subject: [PATCH 05/29] Removing `KernelAbstractions` from deps. --- Project.toml | 1 - src/Flux.jl | 1 - src/functor.jl | 1 - 3 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index 56e5ecd50a..5322cdd9b9 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.14.2" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/src/Flux.jl b/src/Flux.jl index 78239588cb..d522b91e78 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -14,7 +14,6 @@ using Random: default_rng using Zygote, ChainRulesCore using Zygote: Params, @adjoint, gradient, pullback using Zygote.ForwardDiff: value -import KernelAbstractions export gradient # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.) diff --git a/src/functor.jl b/src/functor.jl index 45f557edac..c7d483eb5b 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,6 +1,5 @@ import Adapt: adapt, adapt_storage using LinearAlgebra: Cholesky -using NNlib: KernelAbstractions using Zygote: IdSet import Functors: Functors, @functor, functor, fmap, isleaf using SparseArrays: AbstractSparseArray From 9925c5b3916b334bdc6bf18300039a7e4340c7ed Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 14 Aug 2023 18:32:35 +0530 Subject: [PATCH 06/29] Adding new `get_device` method to return a particular AMD device. --- ext/FluxAMDGPUExt/functor.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index dc3d3cbcce..5cc64eb743 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -74,3 +74,11 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV) Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims), Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups) end + +function Flux.get_device(::Val{:AMD}, ordinal::Int) + old_ordinal = AMDGPU.device_id(AMDGPU.device()) + AMDGPU.device_id!(ordinal) + device = Flux.FluxAMDDevice(AMDGPU.device()) + AMDGPU.device_id!(old_ordinal) + return device +end From df577f38bd9fdd2a73c4d5ccbb632b4388c895fa Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 14 Aug 2023 22:21:10 +0530 Subject: [PATCH 07/29] Adding new `adapt_storage` function for moving arrays. Also passing ordinal information through `FluxAMDAdaptor`. 
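A sketch of the call path this patch is building toward (assumes a functional AMDGPU setup; the ordinal `0` is illustrative and selects the first device):

```julia
using Flux, AMDGPU

x = rand(Float32, 4, 4)
# The adaptor now carries an ordinal, so `gpu` can place `x` on a chosen AMD device.
xr = Flux.gpu(Flux.FluxAMDAdaptor(0), x)
AMDGPU.device(xr)   # expected: the device selected via the ordinal
```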
--- ext/FluxAMDGPUExt/functor.jl | 51 +++++++++++++++++++++++++++++------- src/functor.jl | 8 +++--- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index 5cc64eb743..dd6602fb68 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -1,10 +1,41 @@ # Convert Float64 to Float32, but preserve Float16. -adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray = - isbits(x) ? x : ROCArray(x) -adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} = - isbits(x) ? x : ROCArray{Float32, N}(x) -adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N = - isbits(x) ? x : ROCArray{Float16, N}(x) +function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) + if typeof(to.ordinal) <: Nothing + if (typeof(x) <: AbstractArray{Float16, N} where N) + N = length(size(x)) + return isbits(x) ? x : ROCArray{Float16, N}(x) + elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) + N = length(size(x)) + return isbits(x) ? x : ROCArray{Float32, N}(x) + else + return isbits(x) ? x : ROCArray(x) + end + end + + old_ordinal = AMDGPU.device_id(AMDGPU.device()) + + if !(x isa ROCArray) + AMDGPU.device!(AMD.devices()[to.ordinal]) + if (typeof(x) <: AbstractArray{Float16, N} where N) + N = length(size(x)) + x_new = isbits(x) ? x : ROCArray{Float16, N}(x) + elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) + N = length(size(x)) + x_new = isbits(x) ? x : ROCArray{Float32, N}(x) + else + x_new = isbits(x) ? x : ROCArray(x) + end + AMDGPU.device!(AMD.devices()[old_ordinal]) + return x_new + elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal + return x + else + AMDGPU.device!(AMD.devices()[to.ordinal]) + x_new = copy(x) + AMDGPU.device!(AMD.devices()[old_ordinal]) + return x_new + end +end adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) = ROCArray(collect(x)) @@ -45,10 +76,10 @@ Flux._isleaf(::AMD_CONV) = true _exclude(x) = Flux._isleaf(x) _exclude(::CPU_CONV) = true -function _amd(x) +function _amd(ordinal::Union{Nothing, Int}, x) check_use_amdgpu() USE_AMDGPU[] || return x - fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_exclude) + fmap(x -> Adapt.adapt(FluxAMDAdaptor(ordinal), x), x; exclude=_exclude) end # CPU -> GPU @@ -77,8 +108,8 @@ end function Flux.get_device(::Val{:AMD}, ordinal::Int) old_ordinal = AMDGPU.device_id(AMDGPU.device()) - AMDGPU.device_id!(ordinal) + AMDGPU.device!(AMDGPU.devices()[ordinal]) device = Flux.FluxAMDDevice(AMDGPU.device()) - AMDGPU.device_id!(old_ordinal) + AMDGPU.device!(AMDGPU.devices()[old_ordinal]) return device end diff --git a/src/functor.jl b/src/functor.jl index c7d483eb5b..c2ec8406d7 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -355,13 +355,15 @@ function _cuda end # AMDGPU extension. ######## -struct FluxAMDAdaptor end +Base.@kwdef struct FluxAMDAdaptor + ordinal::Union{Nothing, Int} = nothing +end const AMDGPU_LOADED = Ref{Bool}(false) -function gpu(::FluxAMDAdaptor, x) +function gpu(to::FluxAMDAdaptor, x) if AMDGPU_LOADED[] - return _amd(x) + return _amd(to.ordinal, x) else @info """ The AMDGPU functionality is being called but From 8aa7eed9dd9b9e2f553cf19bb8fe8fdc2646c829 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Tue, 15 Aug 2023 00:12:45 +0530 Subject: [PATCH 08/29] Moving relevant function definitions to extensions. 
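The device objects remain callable after the move; only where the methods are defined changes. A usage sketch, assuming at least one functional GPU backend is loaded:

```julia
using Flux, CUDA   # loading CUDA activates FluxCUDAExt, which now owns these methods

device = Flux.get_device()            # picks the best available backend
dense_model = Dense(2 => 3)
dense_model = dense_model |> device   # device objects stay callable, so they compose with |>
```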
--- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 8 ++++++++ ext/FluxCUDAExt/FluxCUDAExt.jl | 8 ++++++++ ext/FluxMetalExt/FluxMetalExt.jl | 8 ++++++++ src/functor.jl | 23 +++++------------------ 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index f41984ec38..b6f19f8609 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -17,6 +17,14 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat # Set to boolean on the first call to check_use_amdgpu const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) +function (device::Flux.FluxAMDDevice)(x) + if typeof(device.deviceId) <: Nothing + Flux.gpu(Flux.FluxAMDAdaptor(), x) + else + return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceId))) + end +end +Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD" Flux._isavailable(::Flux.FluxAMDDevice) = true Flux._isfunctional(::Flux.FluxAMDDevice) = AMDGPU.functional() diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index ad80cf8a58..adb2c7af66 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -14,6 +14,14 @@ import Adapt: adapt_storage const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing) +function (device::Flux.FluxCUDADevice)(x) + if typeof(device.deviceID) <: Nothing + return Flux.gpu(Flux.FluxCUDAAdaptor(), x) + else + return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x) + end +end +Flux._get_device_name(::Flux.FluxCUDADevice) = "CUDA" Flux._isavailable(::Flux.FluxCUDADevice) = true Flux._isfunctional(::Flux.FluxCUDADevice) = CUDA.functional() diff --git a/ext/FluxMetalExt/FluxMetalExt.jl b/ext/FluxMetalExt/FluxMetalExt.jl index bca48fe279..f7bb700c4c 100644 --- a/ext/FluxMetalExt/FluxMetalExt.jl +++ b/ext/FluxMetalExt/FluxMetalExt.jl @@ -12,6 +12,14 @@ using Zygote const USE_METAL = Ref{Union{Nothing, Bool}}(nothing) +function (device::Flux.FluxMetalDevice)(x) + if typeof(device.deviceId) <: Nothing + Flux.gpu(Flux.FluxMetalAdaptor(), x) + else + return Flux.gpu(Flux.FluxMetalAdaptor(device.deviceId.registryID), x) + end +end +Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal" Flux._isavailable(::Flux.FluxMetalDevice) = true Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional() diff --git a/src/functor.jl b/src/functor.jl index c2ec8406d7..3d44c6bccd 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -378,13 +378,15 @@ function _amd end # Metal extension. 
###### -struct FluxMetalAdaptor end +Base.@kwdef struct FluxMetalAdaptor + ordinal::Union{Nothing, Int} = nothing +end const METAL_LOADED = Ref{Bool}(false) -function gpu(::FluxMetalAdaptor, x) +function gpu(to::FluxMetalAdaptor, x) if METAL_LOADED[] - return _metal(x) + return _metal(to.ordinal, x) else @info """ The Metal functionality is being called but @@ -504,15 +506,6 @@ Base.@kwdef struct FluxCUDADevice <: AbstractDevice deviceID end -function (device::FluxCUDADevice)(x) - if typeof(device.deviceID) <: Nothing - return gpu(FluxCUDAAdaptor(), x) - else - return gpu(FluxCUDAAdaptor(UInt(device.deviceID.handle)), x) - end -end -_get_device_name(::FluxCUDADevice) = "CUDA" - """ FluxAMDDevice <: AbstractDevice @@ -522,9 +515,6 @@ Base.@kwdef struct FluxAMDDevice <: AbstractDevice deviceID end -(::FluxAMDDevice)(x) = gpu(FluxAMDAdaptor(), x) -_get_device_name(::FluxAMDDevice) = "AMD" - """ FluxMetalDevice <: AbstractDevice @@ -534,9 +524,6 @@ Base.@kwdef struct FluxMetalDevice <: AbstractDevice deviceID end -(::FluxMetalDevice)(x) = gpu(FluxMetalAdaptor(), x) -_get_device_name(::FluxMetalDevice) = "Metal" - ## device list. order is important const DEVICES = Ref{Vector{Union{Nothing, AbstractDevice}}}(Vector{Union{Nothing, AbstractDevice}}(nothing, length(GPU_BACKENDS))) DEVICES[][GPU_BACKEND_ORDER["CPU"]] = FluxCPUDevice() From f137080954b063b698b641ac4b727017ed75477a Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Tue, 15 Aug 2023 00:15:22 +0530 Subject: [PATCH 09/29] Making `_metal` accept an ordinal. --- ext/FluxMetalExt/functor.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/FluxMetalExt/functor.jl b/ext/FluxMetalExt/functor.jl index 27914af7b4..6ff47be04a 100644 --- a/ext/FluxMetalExt/functor.jl +++ b/ext/FluxMetalExt/functor.jl @@ -27,8 +27,8 @@ function ChainRulesCore.rrule( end -function _metal(x) +function _metal(ordinal::Union{Nothing, Int}, x) check_use_metal() USE_METAL[] || return x - fmap(x -> Adapt.adapt(FluxMetalAdaptor(), x), x; exclude=_isleaf) + fmap(x -> Adapt.adapt(FluxMetalAdaptor(ordinal), x), x; exclude=_isleaf) end From ef265eb26f280e013c766b918ecf58cee0fcaaf5 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Tue, 15 Aug 2023 00:25:43 +0530 Subject: [PATCH 10/29] Adding new `get_device` method to return particular Metal device. --- ext/FluxMetalExt/functor.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ext/FluxMetalExt/functor.jl b/ext/FluxMetalExt/functor.jl index 6ff47be04a..6c3e074303 100644 --- a/ext/FluxMetalExt/functor.jl +++ b/ext/FluxMetalExt/functor.jl @@ -32,3 +32,11 @@ function _metal(ordinal::Union{Nothing, Int}, x) USE_METAL[] || return x fmap(x -> Adapt.adapt(FluxMetalAdaptor(ordinal), x), x; exclude=_isleaf) end + +function Flux.get_device(::Val{:Metal}, ordinal::Int) + old_device = Metal.current_device() + Metal.device!(Metal.devices()[ordinal]) + device = Flux.FluxMetalDevice(Metal.device()) + Metal.device!(old_device) + return device +end From 3fbb4f5b52f2dd3c2ff041b14350f7e389e85948 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Tue, 15 Aug 2023 00:38:59 +0530 Subject: [PATCH 11/29] Adding new `adapt_storage` method for metal arrays. 
--- ext/FluxMetalExt/functor.jl | 43 +++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/ext/FluxMetalExt/functor.jl b/ext/FluxMetalExt/functor.jl index 6c3e074303..b9fe2fdb6b 100644 --- a/ext/FluxMetalExt/functor.jl +++ b/ext/FluxMetalExt/functor.jl @@ -1,10 +1,41 @@ # Convert Float64 to Float32, but preserve Float16. -adapt_storage(::FluxMetalAdaptor, x::T) where T <: AbstractArray = - isbits(x) ? x : MtlArray(x) -adapt_storage(::FluxMetalAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} = - isbits(x) ? x : MtlArray{Float32, N}(x) -adapt_storage(::FluxMetalAdaptor, x::AbstractArray{Float16, N}) where N = - isbits(x) ? x : MtlArray{Float16, N}(x) +function adapt_storage(to::FluxMetalAdaptor, x::AbstractArray) + if typeof(to.ordinal) <: Nothing + if (typeof(x) <: AbstractArray{Float16, N} where N) + N = length(size(x)) + return isbits(x) ? x : MtlArray{Float16, N}(x) + elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) + N = length(size(x)) + return isbits(x) ? x : MtlArray{Float32, N}(x) + else + return isbits(x) ? x : MtlArray(x) + end + end + + old_device = Metal.current_device() + + if !(x isa MtlArray) + Metal.device!(Metal.devices()[to.ordinal]) + if (typeof(x) <: AbstractArray{Float16, N} where N) + N = length(size(x)) + x_new = isbits(x) ? x : MtlArray{Float16, N}(x) + elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) + N = length(size(x)) + x_new = isbits(x) ? x : MtlArray{Float32, N}(x) + else + x_new = isbits(x) ? x : MtlArray(x) + end + Metal.device!(old_device) + return x_new + elseif Metal.device(x).registryID == Metal.devices()[to.ordinal].registryID + return x + else + Metal.device!(Metal.devices()[to.ordinal]) + x_new = copy(x) + Metal.device!(old_device) + return x_new + end +end adapt_storage(::FluxMetalAdaptor, x::Zygote.FillArrays.AbstractFill) = MtlArray(collect(x)) From 2c6bc556a53d19e5202b39fae5ba06cbdbd412f9 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Tue, 15 Aug 2023 00:42:05 +0530 Subject: [PATCH 12/29] Fixing minor error. --- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 2 +- ext/FluxMetalExt/FluxMetalExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index b6f19f8609..7c39bb3c6d 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -21,7 +21,7 @@ function (device::Flux.FluxAMDDevice)(x) if typeof(device.deviceId) <: Nothing Flux.gpu(Flux.FluxAMDAdaptor(), x) else - return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceId))) + return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID))) end end Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD" diff --git a/ext/FluxMetalExt/FluxMetalExt.jl b/ext/FluxMetalExt/FluxMetalExt.jl index f7bb700c4c..c15fce6b9c 100644 --- a/ext/FluxMetalExt/FluxMetalExt.jl +++ b/ext/FluxMetalExt/FluxMetalExt.jl @@ -16,7 +16,7 @@ function (device::Flux.FluxMetalDevice)(x) if typeof(device.deviceId) <: Nothing Flux.gpu(Flux.FluxMetalAdaptor(), x) else - return Flux.gpu(Flux.FluxMetalAdaptor(device.deviceId.registryID), x) + return Flux.gpu(Flux.FluxMetalAdaptor(device.deviceID.registryID), x) end end Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal" From 829dcfa767085d1ffb661dc00a99a915b54b63bb Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Wed, 16 Aug 2023 01:54:19 +0530 Subject: [PATCH 13/29] Fixing minor error and spelling mistake. 
--- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 4 ++-- ext/FluxMetalExt/FluxMetalExt.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index 7c39bb3c6d..446a3cbd43 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -18,10 +18,10 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) function (device::Flux.FluxAMDDevice)(x) - if typeof(device.deviceId) <: Nothing + if typeof(device.deviceID) <: Nothing Flux.gpu(Flux.FluxAMDAdaptor(), x) else - return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID))) + return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID)), x) end end Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD" diff --git a/ext/FluxMetalExt/FluxMetalExt.jl b/ext/FluxMetalExt/FluxMetalExt.jl index c15fce6b9c..fba65fac2b 100644 --- a/ext/FluxMetalExt/FluxMetalExt.jl +++ b/ext/FluxMetalExt/FluxMetalExt.jl @@ -13,7 +13,7 @@ using Zygote const USE_METAL = Ref{Union{Nothing, Bool}}(nothing) function (device::Flux.FluxMetalDevice)(x) - if typeof(device.deviceId) <: Nothing + if typeof(device.deviceID) <: Nothing Flux.gpu(Flux.FluxMetalAdaptor(), x) else return Flux.gpu(Flux.FluxMetalAdaptor(device.deviceID.registryID), x) From 930d29ca058cd52a8e89db89dae43e871f4becac Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Wed, 16 Aug 2023 23:00:06 +0530 Subject: [PATCH 14/29] Fixing package name: `AMDGPU` instead of `AMD`. --- ext/FluxAMDGPUExt/functor.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index dd6602fb68..6b26431e2b 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -15,7 +15,7 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) old_ordinal = AMDGPU.device_id(AMDGPU.device()) if !(x isa ROCArray) - AMDGPU.device!(AMD.devices()[to.ordinal]) + AMDGPU.device!(AMDGPU.devices()[to.ordinal]) if (typeof(x) <: AbstractArray{Float16, N} where N) N = length(size(x)) x_new = isbits(x) ? x : ROCArray{Float16, N}(x) @@ -25,14 +25,14 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) else x_new = isbits(x) ? x : ROCArray(x) end - AMDGPU.device!(AMD.devices()[old_ordinal]) + AMDGPU.device!(AMDGPU.devices()[old_ordinal]) return x_new elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal return x else - AMDGPU.device!(AMD.devices()[to.ordinal]) + AMDGPU.device!(AMDGPU.devices()[to.ordinal]) x_new = copy(x) - AMDGPU.device!(AMD.devices()[old_ordinal]) + AMDGPU.device!(AMDGPU.devices()[old_ordinal]) return x_new end end From fedee3b45e30ae56675578686b1d0e02795ad24b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 18 Aug 2023 15:00:35 +0530 Subject: [PATCH 15/29] Reverting back to old metal functionality. 
--- ext/FluxMetalExt/FluxMetalExt.jl | 8 +---- ext/FluxMetalExt/functor.jl | 55 +++++--------------------------- src/functor.jl | 8 ++--- 3 files changed, 12 insertions(+), 59 deletions(-) diff --git a/ext/FluxMetalExt/FluxMetalExt.jl b/ext/FluxMetalExt/FluxMetalExt.jl index fba65fac2b..a11046d244 100644 --- a/ext/FluxMetalExt/FluxMetalExt.jl +++ b/ext/FluxMetalExt/FluxMetalExt.jl @@ -12,13 +12,7 @@ using Zygote const USE_METAL = Ref{Union{Nothing, Bool}}(nothing) -function (device::Flux.FluxMetalDevice)(x) - if typeof(device.deviceID) <: Nothing - Flux.gpu(Flux.FluxMetalAdaptor(), x) - else - return Flux.gpu(Flux.FluxMetalAdaptor(device.deviceID.registryID), x) - end -end +(::Flux.FluxMetalDevice)(x) = Flux.gpu(Flux.FluxMetalAdaptor(), x) Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal" Flux._isavailable(::Flux.FluxMetalDevice) = true Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional() diff --git a/ext/FluxMetalExt/functor.jl b/ext/FluxMetalExt/functor.jl index b9fe2fdb6b..27914af7b4 100644 --- a/ext/FluxMetalExt/functor.jl +++ b/ext/FluxMetalExt/functor.jl @@ -1,41 +1,10 @@ # Convert Float64 to Float32, but preserve Float16. -function adapt_storage(to::FluxMetalAdaptor, x::AbstractArray) - if typeof(to.ordinal) <: Nothing - if (typeof(x) <: AbstractArray{Float16, N} where N) - N = length(size(x)) - return isbits(x) ? x : MtlArray{Float16, N}(x) - elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) - N = length(size(x)) - return isbits(x) ? x : MtlArray{Float32, N}(x) - else - return isbits(x) ? x : MtlArray(x) - end - end - - old_device = Metal.current_device() - - if !(x isa MtlArray) - Metal.device!(Metal.devices()[to.ordinal]) - if (typeof(x) <: AbstractArray{Float16, N} where N) - N = length(size(x)) - x_new = isbits(x) ? x : MtlArray{Float16, N}(x) - elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) - N = length(size(x)) - x_new = isbits(x) ? x : MtlArray{Float32, N}(x) - else - x_new = isbits(x) ? x : MtlArray(x) - end - Metal.device!(old_device) - return x_new - elseif Metal.device(x).registryID == Metal.devices()[to.ordinal].registryID - return x - else - Metal.device!(Metal.devices()[to.ordinal]) - x_new = copy(x) - Metal.device!(old_device) - return x_new - end -end +adapt_storage(::FluxMetalAdaptor, x::T) where T <: AbstractArray = + isbits(x) ? x : MtlArray(x) +adapt_storage(::FluxMetalAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} = + isbits(x) ? x : MtlArray{Float32, N}(x) +adapt_storage(::FluxMetalAdaptor, x::AbstractArray{Float16, N}) where N = + isbits(x) ? x : MtlArray{Float16, N}(x) adapt_storage(::FluxMetalAdaptor, x::Zygote.FillArrays.AbstractFill) = MtlArray(collect(x)) @@ -58,16 +27,8 @@ function ChainRulesCore.rrule( end -function _metal(ordinal::Union{Nothing, Int}, x) +function _metal(x) check_use_metal() USE_METAL[] || return x - fmap(x -> Adapt.adapt(FluxMetalAdaptor(ordinal), x), x; exclude=_isleaf) -end - -function Flux.get_device(::Val{:Metal}, ordinal::Int) - old_device = Metal.current_device() - Metal.device!(Metal.devices()[ordinal]) - device = Flux.FluxMetalDevice(Metal.device()) - Metal.device!(old_device) - return device + fmap(x -> Adapt.adapt(FluxMetalAdaptor(), x), x; exclude=_isleaf) end diff --git a/src/functor.jl b/src/functor.jl index 3d44c6bccd..5977972563 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -378,15 +378,13 @@ function _amd end # Metal extension. 
###### -Base.@kwdef struct FluxMetalAdaptor - ordinal::Union{Nothing, Int} = nothing -end +struct FluxMetalAdaptor end const METAL_LOADED = Ref{Bool}(false) -function gpu(to::FluxMetalAdaptor, x) +function gpu(::FluxMetalAdaptor, x) if METAL_LOADED[] - return _metal(to.ordinal, x) + return _metal(x) else @info """ The Metal functionality is being called but From a4449f8ffac0062570096342878b1432ebc45f66 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 18 Aug 2023 15:33:31 +0530 Subject: [PATCH 16/29] Adding tests for moving models between CPU and NVIDIA devices. --- test/ext_cuda/get_devices.jl | 21 +++++++++++++++++++++ test/functors.jl | 4 ++++ 2 files changed, 25 insertions(+) diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl index e1f4c7d8a8..73fd9dc0c7 100644 --- a/test/ext_cuda/get_devices.jl +++ b/test/ext_cuda/get_devices.jl @@ -22,4 +22,25 @@ if CUDA.functional() cx = x |> device @test cx isa CUDA.CuArray @test CUDA.device(cx).handle == device.deviceID.handle + + + # moving models to specific NVIDIA devices + m = Dense(2 => 3) # initially lives on CPU + for ordinal in 0:(length(CUDA.devices()) - 1) + device = Flux.get_device("CUDA", ordinal) + @test typeof(device.deviceID) <: CUDA.CuDevice + @test device.deviceID.handle == ordinal + + m = m |> device + @test m.weight isa CUDA.CuArray + @test m.bias isa CUDA.CuArray + @test CUDA.device(m.weight).handle == ordinal + @test CUDA.device(m.bias).handle == ordinal + end + # finally move to CPU, and see if things work + cpu_device = Flux.get_device("CPU") + m = cpu_device(m) + @test m.weight isa Matrix + @test m.bias isa Vector + end diff --git a/test/functors.jl b/test/functors.jl index b2c6c37f8d..9abc477c7f 100644 --- a/test/functors.jl +++ b/test/functors.jl @@ -12,3 +12,7 @@ device = Flux.get_device() @test typeof(device) <: Flux.FluxCPUDevice @test device(x) == x @test Flux._get_device_name(device) in Flux.supported_devices() + +# specifically getting CPU device +device = Flux.get_device("CPU") +@test typeof(device) <: Flux.FluxCPUDevice From 2a14650cc5f123f59cdb4a52401389f59d48fc54 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 18 Aug 2023 18:21:16 +0530 Subject: [PATCH 17/29] Adding tests for data movement on AMD devices. 
--- test/ext_amdgpu/get_devices.jl | 24 ++++++++++++++++++++++++ test/ext_cuda/get_devices.jl | 7 +++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index 7691241f38..33877c7c16 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -21,4 +21,28 @@ if AMDGPU.functional() && AMDGPU.functional(:MIOpen) cx = x |> device @test cx isa AMDGPU.ROCArray @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(device.deviceID) + + # moving models to specific NVIDIA devices + m = Dense(2 => 3) # initially lives on CPU + weight = copy(m.weight) # store the weight + bias = copy(m.bias) # store the bias + for ordinal in 0:(length(AMDGPU.devices()) - 1) + device = Flux.get_device("AMD", ordinal) + @test typeof(device.deviceID) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(device.deviceID) == ordinal + + m = m |> device + @test m.weight isa AMDGPU.ROCArray + @test m.bias isa AMDGPU.ROCArray + @test ADMGPU.device_id(AMDGPU.device(m.weight)) == ordinal + @test ADMGPU.device_id(AMDGPU.device(m.bias)) == ordinal + @test isequal(Flux.cpu(m.weight), weight) + @test isequal(Flux.cpu(m.bias), bias) + end + # finally move to CPU, and see if things work + cpu_device = Flux.get_device("CPU") + m = cpu_device(m) + @test m.weight isa Matrix + @test m.bias isa Vector + end diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl index 73fd9dc0c7..f75c4b3ee1 100644 --- a/test/ext_cuda/get_devices.jl +++ b/test/ext_cuda/get_devices.jl @@ -23,9 +23,10 @@ if CUDA.functional() @test cx isa CUDA.CuArray @test CUDA.device(cx).handle == device.deviceID.handle - # moving models to specific NVIDIA devices - m = Dense(2 => 3) # initially lives on CPU + m = Dense(2 => 3) # initially lives on CPU + weight = copy(m.weight) # store the weight + bias = copy(m.bias) # store the bias for ordinal in 0:(length(CUDA.devices()) - 1) device = Flux.get_device("CUDA", ordinal) @test typeof(device.deviceID) <: CUDA.CuDevice @@ -36,6 +37,8 @@ if CUDA.functional() @test m.bias isa CUDA.CuArray @test CUDA.device(m.weight).handle == ordinal @test CUDA.device(m.bias).handle == ordinal + @test isequal(Flux.cpu(m.weight), weight) + @test isequal(Flux.cpu(m.bias), bias) end # finally move to CPU, and see if things work cpu_device = Flux.get_device("CPU") From a9fb328096f64c74402afc36ce9b004dc1808dac Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 18 Aug 2023 23:27:25 +0530 Subject: [PATCH 18/29] Fixing index error while choosing AMD gpu device. --- ext/FluxAMDGPUExt/functor.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index 6b26431e2b..3a9402df73 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -15,7 +15,7 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) old_ordinal = AMDGPU.device_id(AMDGPU.device()) if !(x isa ROCArray) - AMDGPU.device!(AMDGPU.devices()[to.ordinal]) + AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) # adding 1 because ordinals start from 0 if (typeof(x) <: AbstractArray{Float16, N} where N) N = length(size(x)) x_new = isbits(x) ? x : ROCArray{Float16, N}(x) @@ -25,14 +25,14 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) else x_new = isbits(x) ? 
x : ROCArray(x) end - AMDGPU.device!(AMDGPU.devices()[old_ordinal]) + AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1]) return x_new elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal return x else - AMDGPU.device!(AMDGPU.devices()[to.ordinal]) + AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) x_new = copy(x) - AMDGPU.device!(AMDGPU.devices()[old_ordinal]) + AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1]) return x_new end end @@ -108,8 +108,8 @@ end function Flux.get_device(::Val{:AMD}, ordinal::Int) old_ordinal = AMDGPU.device_id(AMDGPU.device()) - AMDGPU.device!(AMDGPU.devices()[ordinal]) + AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0 device = Flux.FluxAMDDevice(AMDGPU.device()) - AMDGPU.device!(AMDGPU.devices()[old_ordinal]) + AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1]) return device end From 91a756f2455908d7fcd9111f4aa8d0457a855d31 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 14:32:27 +0530 Subject: [PATCH 19/29] Fixing AMD ordinal starting index. --- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 2 +- ext/FluxAMDGPUExt/functor.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index 446a3cbd43..9420b11340 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -21,7 +21,7 @@ function (device::Flux.FluxAMDDevice)(x) if typeof(device.deviceID) <: Nothing Flux.gpu(Flux.FluxAMDAdaptor(), x) else - return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID)), x) + return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID)) - 1, x) # subtracting 1, because device_id returns a positive integer end end Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD" diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index 3a9402df73..cb665f44cd 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -106,7 +106,7 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV) Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups) end -function Flux.get_device(::Val{:AMD}, ordinal::Int) +function Flux.get_device(::Val{:AMD}, ordinal::Int) # ordinal should start from 0 old_ordinal = AMDGPU.device_id(AMDGPU.device()) AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0 device = Flux.FluxAMDDevice(AMDGPU.device()) From a9f661568f8607b309d2b55ae2948c8dbaaeaad5 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 14:33:09 +0530 Subject: [PATCH 20/29] Adding docstring for new `get_device` method. --- src/functor.jl | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/functor.jl b/src/functor.jl index 5977972563..916fabcb44 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -545,7 +545,7 @@ julia> Flux.supported_devices() supported_devices() = GPU_BACKENDS """ - Flux.get_device(; verbose=false)::AbstractDevice + Flux.get_device(; verbose=false)::Flux.AbstractDevice Returns a `device` object for the most appropriate backend for the current Julia session. @@ -649,6 +649,40 @@ function get_device(; verbose=false)::AbstractDevice end end +""" + Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice + +Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values +of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. 
+ +# Examples + +```julia-repl +julia> using Flux, CUDA; + +julia> CUDA.devices() +CUDA.DeviceIterator() for 3 devices: +0. GeForce RTX 2080 Ti +1. GeForce RTX 2080 Ti +2. TITAN X (Pascal) + +julia> device0 = Flux.get_device("CUDA", 0) +(::Flux.FluxCUDADevice) (generic function with 1 method) + +julia> device0.deviceID +CuDevice(0): GeForce RTX 2080 Ti + +julia> device1 = Flux.get_device("CUDA", 1) +(::Flux.FluxCUDADevice) (generic function with 1 method) + +julia> device1.deviceID +CuDevice(1): GeForce RTX 2080 Ti + +julia> cpu_device = Flux.get_device("CPU") +(::Flux.FluxCPUDevice) (generic function with 1 method) + +``` +""" function get_device(backend::String, ordinal::Int = 0) if backend == "CPU" return FluxCPUDevice() From b47a6f433b402f1dc2380e2b01f66992c24ea040 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 15:33:05 +0530 Subject: [PATCH 21/29] Removing global name conflicts in tests. --- test/ext_amdgpu/get_devices.jl | 46 +++++++++++++++++----------------- test/ext_cuda/get_devices.jl | 46 +++++++++++++++++----------------- test/ext_metal/get_devices.jl | 12 ++++----- test/functors.jl | 12 ++++----- 4 files changed, 58 insertions(+), 58 deletions(-) diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index 33877c7c16..69519d9ea4 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -10,39 +10,39 @@ else end if AMDGPU.functional() && AMDGPU.functional(:MIOpen) - device = Flux.get_device() + amd_device = Flux.get_device() - @test typeof(device) <: Flux.FluxAMDDevice - @test typeof(device.deviceID) <: AMDGPU.HIPDevice - @test Flux._get_device_name(device) in Flux.supported_devices() + @test typeof(amd_device) <: Flux.FluxAMDDevice + @test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice + @test Flux._get_device_name(amd_device) in Flux.supported_devices() # correctness of data transfer x = randn(5, 5) - cx = x |> device + cx = x |> amd_device @test cx isa AMDGPU.ROCArray - @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(device.deviceID) + @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amd_device.deviceID) # moving models to specific NVIDIA devices - m = Dense(2 => 3) # initially lives on CPU - weight = copy(m.weight) # store the weight - bias = copy(m.bias) # store the bias + dense_model = Dense(2 => 3) # initially lives on CPU + weight = copy(dense_model.weight) # store the weight + bias = copy(dense_model.bias) # store the bias for ordinal in 0:(length(AMDGPU.devices()) - 1) - device = Flux.get_device("AMD", ordinal) - @test typeof(device.deviceID) <: AMDGPU.HIPDevice - @test AMDGPU.device_id(device.deviceID) == ordinal - - m = m |> device - @test m.weight isa AMDGPU.ROCArray - @test m.bias isa AMDGPU.ROCArray - @test ADMGPU.device_id(AMDGPU.device(m.weight)) == ordinal - @test ADMGPU.device_id(AMDGPU.device(m.bias)) == ordinal - @test isequal(Flux.cpu(m.weight), weight) - @test isequal(Flux.cpu(m.bias), bias) + amd_device = Flux.get_device("AMD", ordinal) + @test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(amd_device.deviceID) == ordinal + + dense_model = dense_model |> amd_device + @test dense_model.weight isa AMDGPU.ROCArray + @test dense_model.bias isa AMDGPU.ROCArray + @test ADMGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + @test ADMGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + @test isequal(Flux.cpu(dense_model.weight), weight) + @test isequal(Flux.cpu(dense_model.bias), bias) end # finally move to CPU, and see 
if things work cpu_device = Flux.get_device("CPU") - m = cpu_device(m) - @test m.weight isa Matrix - @test m.bias isa Vector + dense_model = cpu_device(dense_model) + @test dense_model.weight isa Matrix + @test dense_model.bias isa Vector end diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl index f75c4b3ee1..ba3108120d 100644 --- a/test/ext_cuda/get_devices.jl +++ b/test/ext_cuda/get_devices.jl @@ -11,39 +11,39 @@ end # testing get_device if CUDA.functional() - device = Flux.get_device() + cuda_device = Flux.get_device() - @test typeof(device) <: Flux.FluxCUDADevice - @test typeof(device.deviceID) <: CUDA.CuDevice - @test Flux._get_device_name(device) in Flux.supported_devices() + @test typeof(cuda_device) <: Flux.FluxCUDADevice + @test typeof(cuda_device.deviceID) <: CUDA.CuDevice + @test Flux._get_device_name(cuda_device) in Flux.supported_devices() # correctness of data transfer x = randn(5, 5) - cx = x |> device + cx = x |> cuda_device @test cx isa CUDA.CuArray - @test CUDA.device(cx).handle == device.deviceID.handle + @test CUDA.device(cx).handle == cuda_device.deviceID.handle # moving models to specific NVIDIA devices - m = Dense(2 => 3) # initially lives on CPU - weight = copy(m.weight) # store the weight - bias = copy(m.bias) # store the bias + dense_model = Dense(2 => 3) # initially lives on CPU + weight = copy(dense_model.weight) # store the weight + bias = copy(dense_model.bias) # store the bias for ordinal in 0:(length(CUDA.devices()) - 1) - device = Flux.get_device("CUDA", ordinal) - @test typeof(device.deviceID) <: CUDA.CuDevice - @test device.deviceID.handle == ordinal - - m = m |> device - @test m.weight isa CUDA.CuArray - @test m.bias isa CUDA.CuArray - @test CUDA.device(m.weight).handle == ordinal - @test CUDA.device(m.bias).handle == ordinal - @test isequal(Flux.cpu(m.weight), weight) - @test isequal(Flux.cpu(m.bias), bias) + cuda_device = Flux.get_device("CUDA", ordinal) + @test typeof(cuda_device.deviceID) <: CUDA.CuDevice + @test cuda_device.deviceID.handle == ordinal + + dense_model = dense_model |> cuda_device + @test dense_model.weight isa CUDA.CuArray + @test dense_model.bias isa CUDA.CuArray + @test CUDA.device(dense_model.weight).handle == ordinal + @test CUDA.device(dense_model.bias).handle == ordinal + @test isequal(Flux.cpu(dense_model.weight), weight) + @test isequal(Flux.cpu(dense_model.bias), bias) end # finally move to CPU, and see if things work cpu_device = Flux.get_device("CPU") - m = cpu_device(m) - @test m.weight isa Matrix - @test m.bias isa Vector + dense_model = cpu_device(dense_model) + @test dense_model.weight isa Matrix + @test dense_model.bias isa Vector end diff --git a/test/ext_metal/get_devices.jl b/test/ext_metal/get_devices.jl index 83786e9834..12f7b87210 100644 --- a/test/ext_metal/get_devices.jl +++ b/test/ext_metal/get_devices.jl @@ -11,15 +11,15 @@ end # testing get_device if Metal.functional() - device = Flux.get_device() + metal_device = Flux.get_device() - @test typeof(device) <: Flux.FluxMetalDevice - @test typeof(device.deviceID) <: Metal.MTLDevice - @test Flux._get_device_name(device) in Flux.supported_devices() + @test typeof(metal_device) <: Flux.FluxMetalDevice + @test typeof(metal_device.deviceID) <: Metal.MTLDevice + @test Flux._get_device_name(metal_device) in Flux.supported_devices() # correctness of data transfer x = randn(5, 5) - cx = x |> device + cx = x |> metal_device @test cx isa Metal.MtlArray - @test Metal.device(cx).registryID == device.deviceID.registryID + @test 
Metal.device(cx).registryID == metal_device.deviceID.registryID end diff --git a/test/functors.jl b/test/functors.jl index 9abc477c7f..a9f26194ea 100644 --- a/test/functors.jl +++ b/test/functors.jl @@ -8,11 +8,11 @@ end @test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]) <: Nothing @test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CPU"]]) <: Flux.FluxCPUDevice -device = Flux.get_device() -@test typeof(device) <: Flux.FluxCPUDevice -@test device(x) == x -@test Flux._get_device_name(device) in Flux.supported_devices() +dev = Flux.get_device() +@test typeof(dev) <: Flux.FluxCPUDevice +@test dev(x) == x +@test Flux._get_device_name(dev) in Flux.supported_devices() # specifically getting CPU device -device = Flux.get_device("CPU") -@test typeof(device) <: Flux.FluxCPUDevice +dev = Flux.get_device("CPU") +@test typeof(dev) <: Flux.FluxCPUDevice From f1ab56954661b4ff38645f85a10b3f1090174c4b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 15:35:17 +0530 Subject: [PATCH 22/29] Minor fix to AMDs device id tests. --- test/ext_amdgpu/get_devices.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index 69519d9ea4..a3f4242e5c 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -29,13 +29,13 @@ if AMDGPU.functional() && AMDGPU.functional(:MIOpen) for ordinal in 0:(length(AMDGPU.devices()) - 1) amd_device = Flux.get_device("AMD", ordinal) @test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice - @test AMDGPU.device_id(amd_device.deviceID) == ordinal + @test AMDGPU.device_id(amd_device.deviceID) == ordinal + 1 dense_model = dense_model |> amd_device @test dense_model.weight isa AMDGPU.ROCArray @test dense_model.bias isa AMDGPU.ROCArray - @test ADMGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal - @test ADMGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + @test ADMGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1 + @test ADMGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + 1 @test isequal(Flux.cpu(dense_model.weight), weight) @test isequal(Flux.cpu(dense_model.bias), bias) end From 129a0b56c8b8dca054d20214a5c9f566125559e0 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 16:40:31 +0530 Subject: [PATCH 23/29] Disambiguating test variables. 
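The `global` annotations added below are needed because the test loops assign to a top-level variable from inside `for`. A standalone illustration of the scoping rule, unrelated to Flux:

```julia
counter = 0
for i in 1:3
    # Without `global`, this assignment would create a new local `counter`
    # in a non-interactive (file) context and Julia warns about the ambiguity.
    global counter = counter + i
end
counter == 6   # true
```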
--- test/ext_amdgpu/get_devices.jl | 15 ++++++++------- test/ext_cuda/get_devices.jl | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index a3f4242e5c..f89286c4b2 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -9,6 +9,10 @@ else @test typeof(amd_device.deviceID) <: Nothing end +# testing get_device +dense_model = Dense(2 => 3) # initially lives on CPU +weight = copy(dense_model.weight) # store the weight +bias = copy(dense_model.bias) # store the bias if AMDGPU.functional() && AMDGPU.functional(:MIOpen) amd_device = Flux.get_device() @@ -23,15 +27,12 @@ if AMDGPU.functional() && AMDGPU.functional(:MIOpen) @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amd_device.deviceID) # moving models to specific NVIDIA devices - dense_model = Dense(2 => 3) # initially lives on CPU - weight = copy(dense_model.weight) # store the weight - bias = copy(dense_model.bias) # store the bias for ordinal in 0:(length(AMDGPU.devices()) - 1) - amd_device = Flux.get_device("AMD", ordinal) - @test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice - @test AMDGPU.device_id(amd_device.deviceID) == ordinal + 1 + current_amd_device = Flux.get_device("AMD", ordinal) + @test typeof(current_amd_device.deviceID) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(current_amd_device.deviceID) == ordinal + 1 - dense_model = dense_model |> amd_device + global dense_model = dense_model |> current_amd_device @test dense_model.weight isa AMDGPU.ROCArray @test dense_model.bias isa AMDGPU.ROCArray @test ADMGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1 diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl index ba3108120d..f3adf4233e 100644 --- a/test/ext_cuda/get_devices.jl +++ b/test/ext_cuda/get_devices.jl @@ -10,6 +10,9 @@ else end # testing get_device +dense_model = Dense(2 => 3) # initially lives on CPU +weight = copy(dense_model.weight) # store the weight +bias = copy(dense_model.bias) # store the bias if CUDA.functional() cuda_device = Flux.get_device() @@ -24,15 +27,12 @@ if CUDA.functional() @test CUDA.device(cx).handle == cuda_device.deviceID.handle # moving models to specific NVIDIA devices - dense_model = Dense(2 => 3) # initially lives on CPU - weight = copy(dense_model.weight) # store the weight - bias = copy(dense_model.bias) # store the bias for ordinal in 0:(length(CUDA.devices()) - 1) - cuda_device = Flux.get_device("CUDA", ordinal) - @test typeof(cuda_device.deviceID) <: CUDA.CuDevice - @test cuda_device.deviceID.handle == ordinal + current_cuda_device = Flux.get_device("CUDA", ordinal) + @test typeof(current_cuda_device.deviceID) <: CUDA.CuDevice + @test current_cuda_device.deviceID.handle == ordinal - dense_model = dense_model |> cuda_device + global dense_model = dense_model |> current_cuda_device @test dense_model.weight isa CUDA.CuArray @test dense_model.bias isa CUDA.CuArray @test CUDA.device(dense_model.weight).handle == ordinal From 7a5b847ff72a2e3f1964404255788feeeb2536a0 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 22:16:09 +0530 Subject: [PATCH 24/29] Adding more info in docstring of `get_device`, and writing some documentation in the guide. 
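The guide additions below walk through the CUDA case; for reference, a matching AMD sketch would look like this (ordinals and device listing illustrative, assuming multiple AMD GPUs):

```julia
using Flux, AMDGPU

device0 = Flux.get_device("AMD", 0)      # ordinals are zero-based, as in the CUDA example
dense_model = Dense(2 => 3) |> device0
AMDGPU.device_id(AMDGPU.device(dense_model.weight))   # 1-based id of the selected device
```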
--- docs/src/gpu.md | 65 +++++++++++++++++++++++++++++++++++++++++++++++++ src/functor.jl | 2 +- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 70708c444a..1d951a89d9 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -311,6 +311,71 @@ julia> device = Flux.get_device(; verbose=true) # this will resort to auto ``` For detailed information about how the backend is selected, check the documentation for [`Flux.get_device`](@ref). +## Data movement across GPU devices + +Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another GPU +device from the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices: + +```julia-repl +julia> using Flux, CUDA; + +julia> CUDA.devices() +CUDA.DeviceIterator() for 3 devices: +0. GeForce RTX 2080 Ti +1. GeForce RTX 2080 Ti +2. TITAN X (Pascal) + +``` + +Then, let's select the device with ordinal `0`: + +```julia-repl +julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMD" +(::Flux.FluxCUDADevice) (generic function with 1 method) + +``` + +Then, let's move a simple dense layer to the GPU represented by `device0`: + +```julia-repl +julia> dense_model = Dense(2 => 3) +Dense(2 => 3) # 9 parameters + +julia> dense_model = dense_model |> device0; + +julia> dense_model.weight +3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: + 0.695662 0.816299 + -0.204763 -0.10232 + -0.955829 0.538412 + +julia> CUDA.device(dense_model.weight) # check the GPU to which dense_model is attached +CuDevice(0): GeForce RTX 2080 Ti + +``` + +Next, we'll get a handle to the device with ordinal `1`, and move `dense_model` to that device: + +```julia-repl +julia> device1 = Flux.get_device("CUDA", 1) +(::Flux.FluxCUDADevice) (generic function with 1 method) + +julia> dense_model = dense_model |> device1; # don't directly print the model; see warning below + +julia> CUDA.device(dense_model.weight) +CuDevice(1): GeForce RTX 2080 Ti + +``` + +Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMD` backends. + +!!! warning "Printing models after moving to a different device" + + Due to a limitation in how GPU packages currently work, printing + models on the REPL after moving them to a GPU device which is different + from the current device will lead to an error. + + ```@docs Flux.AbstractDevice Flux.FluxCPUDevice diff --git a/src/functor.jl b/src/functor.jl index 916fabcb44..a39e542847 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -653,7 +653,7 @@ end Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values -of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. +of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `ordinal` must be an integer value between `0` and the number of available devices. # Examples From 1b770bc4f42f3fa204626a2a8e502e60ade78a5e Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 23:18:18 +0530 Subject: [PATCH 25/29] Fixing minor error in AMD code. 
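The fix below moves the `- 1` inside the `FluxAMDAdaptor` call. The offset is needed because `AMDGPU.device_id` is 1-based while the ordinal convention used in this PR is 0-based; a small sketch of the convention (assumes at least one AMD device):

```julia
using AMDGPU

dev     = AMDGPU.device()
id      = AMDGPU.device_id(dev)   # 1-based position in AMDGPU.devices()
ordinal = id - 1                  # 0-based ordinal accepted by Flux.get_device("AMD", ordinal)
AMDGPU.devices()[ordinal + 1]     # indexes back to the same device
```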
--- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index 9420b11340..5606f2db01 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -21,7 +21,7 @@ function (device::Flux.FluxAMDDevice)(x) if typeof(device.deviceID) <: Nothing Flux.gpu(Flux.FluxAMDAdaptor(), x) else - return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID)) - 1, x) # subtracting 1, because device_id returns a positive integer + return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer end end Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD" From 8c094852bffb88b5d7582d25523f26eca088005b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 23:22:09 +0530 Subject: [PATCH 26/29] Fixing yet another ordinal index error in AMD code. --- ext/FluxAMDGPUExt/functor.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index cb665f44cd..4799892075 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -12,7 +12,7 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) end end - old_ordinal = AMDGPU.device_id(AMDGPU.device()) + old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0 if !(x isa ROCArray) AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) # adding 1 because ordinals start from 0 From e36759896789320c8fa3a14c19bf61e45265ec06 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 23:26:43 +0530 Subject: [PATCH 27/29] Fixing another ordinal index error in AMD code. --- ext/FluxAMDGPUExt/functor.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index 4799892075..4e79d0b01e 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -107,8 +107,8 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV) end function Flux.get_device(::Val{:AMD}, ordinal::Int) # ordinal should start from 0 - old_ordinal = AMDGPU.device_id(AMDGPU.device()) - AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0 + old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0 + AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0 device = Flux.FluxAMDDevice(AMDGPU.device()) AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1]) return device From 08b196295fca0a3e1cb691d6c9c0c919110cbcf3 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 23:31:20 +0530 Subject: [PATCH 28/29] Fixing spelling mistake. 
--- test/ext_amdgpu/get_devices.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl index f89286c4b2..9023fa743c 100644 --- a/test/ext_amdgpu/get_devices.jl +++ b/test/ext_amdgpu/get_devices.jl @@ -35,8 +35,8 @@ if AMDGPU.functional() && AMDGPU.functional(:MIOpen) global dense_model = dense_model |> current_amd_device @test dense_model.weight isa AMDGPU.ROCArray @test dense_model.bias isa AMDGPU.ROCArray - @test ADMGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1 - @test ADMGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + 1 + @test AMDGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1 + @test AMDGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + 1 @test isequal(Flux.cpu(dense_model.weight), weight) @test isequal(Flux.cpu(dense_model.bias), bias) end
From 0063bc046aecc3f915e3d09d240112fcc23a9b2f Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 23:40:44 +0530 Subject: [PATCH 29/29] Replacing type checks for `nothing` with equality checks.
--- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 2 +- ext/FluxAMDGPUExt/functor.jl | 2 +- ext/FluxCUDAExt/FluxCUDAExt.jl | 2 +- ext/FluxCUDAExt/functor.jl | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index 5606f2db01..a199cd270e 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -18,7 +18,7 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) function (device::Flux.FluxAMDDevice)(x) - if typeof(device.deviceID) <: Nothing + if device.deviceID === nothing Flux.gpu(Flux.FluxAMDAdaptor(), x) else return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer end end Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD"
diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index 4e79d0b01e..62507122db 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -1,6 +1,6 @@ # Convert Float64 to Float32, but preserve Float16. function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) - if typeof(to.ordinal) <: Nothing + if to.ordinal === nothing if (typeof(x) <: AbstractArray{Float16, N} where N) N = length(size(x)) return isbits(x) ? x : ROCArray{Float16, N}(x)
diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index adb2c7af66..3fcdc5c263 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -15,7 +15,7 @@ import Adapt: adapt_storage const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing) function (device::Flux.FluxCUDADevice)(x) - if typeof(device.deviceID) <: Nothing + if device.deviceID === nothing return Flux.gpu(Flux.FluxCUDAAdaptor(), x) else return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x) end end Flux._get_device_name(::Flux.FluxCUDADevice) = "CUDA"
diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index b98e3401e3..5a4c1d1152 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -1,6 +1,6 @@ adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray) - typeof(to.ordinal) <: Nothing && return CUDA.cu(x) + to.ordinal === nothing && return CUDA.cu(x) # remember current device old_ordinal = CUDA.device().handle
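To close out the series: the final patch swaps `typeof(x) <: Nothing` checks for the idiomatic equality form. A tiny standalone illustration of the idiom:

```julia
x = nothing

typeof(x) <: Nothing   # true, but an indirect way to ask the question
x === nothing          # idiomatic; equivalent to isnothing(x)
```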