diff --git a/Project.toml b/Project.toml index 6b58703e1c..5322cdd9b9 100644 --- a/Project.toml +++ b/Project.toml @@ -24,8 +24,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] FluxAMDGPUExt = "AMDGPU" @@ -56,8 +56,8 @@ julia = "1.9" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" @@ -68,6 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", - "FillArrays", "ComponentArrays", "BSON", "Pkg", - "CUDA", "cuDNN", "Metal", "AMDGPU"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"] diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 70708c444a..1d951a89d9 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -311,6 +311,71 @@ julia> device = Flux.get_device(; verbose=true) # this will resort to auto ``` For detailed information about how the backend is selected, check the documentation for [`Flux.get_device`](@ref). +## Data movement across GPU devices + +Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another GPU +device from the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices: + +```julia-repl +julia> using Flux, CUDA; + +julia> CUDA.devices() +CUDA.DeviceIterator() for 3 devices: +0. GeForce RTX 2080 Ti +1. GeForce RTX 2080 Ti +2. TITAN X (Pascal) + +``` + +Then, let's select the device with ordinal `0`: + +```julia-repl +julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMD" +(::Flux.FluxCUDADevice) (generic function with 1 method) + +``` + +Then, let's move a simple dense layer to the GPU represented by `device0`: + +```julia-repl +julia> dense_model = Dense(2 => 3) +Dense(2 => 3) # 9 parameters + +julia> dense_model = dense_model |> device0; + +julia> dense_model.weight +3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: + 0.695662 0.816299 + -0.204763 -0.10232 + -0.955829 0.538412 + +julia> CUDA.device(dense_model.weight) # check the GPU to which dense_model is attached +CuDevice(0): GeForce RTX 2080 Ti + +``` + +Next, we'll get a handle to the device with ordinal `1`, and move `dense_model` to that device: + +```julia-repl +julia> device1 = Flux.get_device("CUDA", 1) +(::Flux.FluxCUDADevice) (generic function with 1 method) + +julia> dense_model = dense_model |> device1; # don't directly print the model; see warning below + +julia> CUDA.device(dense_model.weight) +CuDevice(1): GeForce RTX 2080 Ti + +``` + +Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMD` backends. + +!!! 
warning "Printing models after moving to a different device" + + Due to a limitation in how GPU packages currently work, printing + models on the REPL after moving them to a GPU device which is different + from the current device will lead to an error. + + ```@docs Flux.AbstractDevice Flux.FluxCPUDevice diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index f41984ec38..a199cd270e 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -17,6 +17,14 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat # Set to boolean on the first call to check_use_amdgpu const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) +function (device::Flux.FluxAMDDevice)(x) + if device.deviceID === nothing + Flux.gpu(Flux.FluxAMDAdaptor(), x) + else + return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer + end +end +Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD" Flux._isavailable(::Flux.FluxAMDDevice) = true Flux._isfunctional(::Flux.FluxAMDDevice) = AMDGPU.functional() diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index dc3d3cbcce..62507122db 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -1,10 +1,41 @@ # Convert Float64 to Float32, but preserve Float16. -adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray = - isbits(x) ? x : ROCArray(x) -adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} = - isbits(x) ? x : ROCArray{Float32, N}(x) -adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N = - isbits(x) ? x : ROCArray{Float16, N}(x) +function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) + if to.ordinal === nothing + if (typeof(x) <: AbstractArray{Float16, N} where N) + N = length(size(x)) + return isbits(x) ? x : ROCArray{Float16, N}(x) + elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) + N = length(size(x)) + return isbits(x) ? x : ROCArray{Float32, N}(x) + else + return isbits(x) ? x : ROCArray(x) + end + end + + old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0 + + if !(x isa ROCArray) + AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) # adding 1 because ordinals start from 0 + if (typeof(x) <: AbstractArray{Float16, N} where N) + N = length(size(x)) + x_new = isbits(x) ? x : ROCArray{Float16, N}(x) + elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N}) + N = length(size(x)) + x_new = isbits(x) ? x : ROCArray{Float32, N}(x) + else + x_new = isbits(x) ? 
x : ROCArray(x) + end + AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1]) + return x_new + elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal + return x + else + AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) + x_new = copy(x) + AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1]) + return x_new + end +end adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) = ROCArray(collect(x)) @@ -45,10 +76,10 @@ Flux._isleaf(::AMD_CONV) = true _exclude(x) = Flux._isleaf(x) _exclude(::CPU_CONV) = true -function _amd(x) +function _amd(ordinal::Union{Nothing, Int}, x) check_use_amdgpu() USE_AMDGPU[] || return x - fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_exclude) + fmap(x -> Adapt.adapt(FluxAMDAdaptor(ordinal), x), x; exclude=_exclude) end # CPU -> GPU @@ -74,3 +105,11 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV) Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims), Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups) end + +function Flux.get_device(::Val{:AMD}, ordinal::Int) # ordinal should start from 0 + old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0 + AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0 + device = Flux.FluxAMDDevice(AMDGPU.device()) + AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1]) + return device +end diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index ad80cf8a58..3fcdc5c263 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -14,6 +14,14 @@ import Adapt: adapt_storage const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing) +function (device::Flux.FluxCUDADevice)(x) + if device.deviceID === nothing + return Flux.gpu(Flux.FluxCUDAAdaptor(), x) + else + return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x) + end +end +Flux._get_device_name(::Flux.FluxCUDADevice) = "CUDA" Flux._isavailable(::Flux.FluxCUDADevice) = true Flux._isfunctional(::Flux.FluxCUDADevice) = CUDA.functional() diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index 347cfce372..5a4c1d1152 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -1,5 +1,24 @@ - adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) +function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray) + to.ordinal === nothing && return CUDA.cu(x) + + # remember current device + old_ordinal = CUDA.device().handle + + if !(x isa CuArray) + CUDA.device!(to.ordinal) + x_new = CUDA.cu(x) + CUDA.device!(old_ordinal) + return x_new + elseif CUDA.device(x).handle == to.ordinal + return x + else + CUDA.device!(to.ordinal) + x_new = copy(x) + CUDA.device!(old_ordinal) + return x_new + end +end adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x @@ -25,8 +44,16 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) = ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) = adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ))) -function _cuda(x) +function _cuda(ordinal::Union{Nothing, Int}, x) check_use_cuda() USE_CUDA[] || return x - fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude=Flux._isleaf) + fmap(x -> Adapt.adapt(FluxCUDAAdaptor(ordinal), x), x; exclude=Flux._isleaf) +end + +function Flux.get_device(::Val{:CUDA}, ordinal::Int) + old_ordinal = 
CUDA.device().handle + CUDA.device!(ordinal) + device = Flux.FluxCUDADevice(CUDA.device()) + CUDA.device!(old_ordinal) + return device end diff --git a/ext/FluxMetalExt/FluxMetalExt.jl b/ext/FluxMetalExt/FluxMetalExt.jl index bca48fe279..a11046d244 100644 --- a/ext/FluxMetalExt/FluxMetalExt.jl +++ b/ext/FluxMetalExt/FluxMetalExt.jl @@ -12,6 +12,8 @@ using Zygote const USE_METAL = Ref{Union{Nothing, Bool}}(nothing) +(::Flux.FluxMetalDevice)(x) = Flux.gpu(Flux.FluxMetalAdaptor(), x) +Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal" Flux._isavailable(::Flux.FluxMetalDevice) = true Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional() diff --git a/src/functor.jl b/src/functor.jl index 24dc41d3ed..a39e542847 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -332,13 +332,15 @@ trainable(c::Cholesky) = () # CUDA extension. ######## -struct FluxCUDAAdaptor end +Base.@kwdef struct FluxCUDAAdaptor + ordinal::Union{Nothing, Int} = nothing +end const CUDA_LOADED = Ref{Bool}(false) -function gpu(::FluxCUDAAdaptor, x) +function gpu(to::FluxCUDAAdaptor, x) if CUDA_LOADED[] - return _cuda(x) + return _cuda(to.ordinal, x) else @info """ The CUDA functionality is being called but @@ -353,13 +355,15 @@ function _cuda end # AMDGPU extension. ######## -struct FluxAMDAdaptor end +Base.@kwdef struct FluxAMDAdaptor + ordinal::Union{Nothing, Int} = nothing +end const AMDGPU_LOADED = Ref{Bool}(false) -function gpu(::FluxAMDAdaptor, x) +function gpu(to::FluxAMDAdaptor, x) if AMDGPU_LOADED[] - return _amd(x) + return _amd(to.ordinal, x) else @info """ The AMDGPU functionality is being called but @@ -500,9 +504,6 @@ Base.@kwdef struct FluxCUDADevice <: AbstractDevice deviceID end -(::FluxCUDADevice)(x) = gpu(FluxCUDAAdaptor(), x) -_get_device_name(::FluxCUDADevice) = "CUDA" - """ FluxAMDDevice <: AbstractDevice @@ -512,9 +513,6 @@ Base.@kwdef struct FluxAMDDevice <: AbstractDevice deviceID end -(::FluxAMDDevice)(x) = gpu(FluxAMDAdaptor(), x) -_get_device_name(::FluxAMDDevice) = "AMD" - """ FluxMetalDevice <: AbstractDevice @@ -524,9 +522,6 @@ Base.@kwdef struct FluxMetalDevice <: AbstractDevice deviceID end -(::FluxMetalDevice)(x) = gpu(FluxMetalAdaptor(), x) -_get_device_name(::FluxMetalDevice) = "Metal" - ## device list. order is important const DEVICES = Ref{Vector{Union{Nothing, AbstractDevice}}}(Vector{Union{Nothing, AbstractDevice}}(nothing, length(GPU_BACKENDS))) DEVICES[][GPU_BACKEND_ORDER["CPU"]] = FluxCPUDevice() @@ -550,7 +545,7 @@ julia> Flux.supported_devices() supported_devices() = GPU_BACKENDS """ - Flux.get_device(; verbose=false)::AbstractDevice + Flux.get_device(; verbose=false)::Flux.AbstractDevice Returns a `device` object for the most appropriate backend for the current Julia session. @@ -653,3 +648,45 @@ function get_device(; verbose=false)::AbstractDevice end end end + +""" + Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice + +Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values +of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `ordinal` must be an integer value between `0` and the number of available devices. + +# Examples + +```julia-repl +julia> using Flux, CUDA; + +julia> CUDA.devices() +CUDA.DeviceIterator() for 3 devices: +0. GeForce RTX 2080 Ti +1. GeForce RTX 2080 Ti +2. 
TITAN X (Pascal)
+
+julia> device0 = Flux.get_device("CUDA", 0)
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+julia> device0.deviceID
+CuDevice(0): GeForce RTX 2080 Ti
+
+julia> device1 = Flux.get_device("CUDA", 1)
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+julia> device1.deviceID
+CuDevice(1): GeForce RTX 2080 Ti
+
+julia> cpu_device = Flux.get_device("CPU")
+(::Flux.FluxCPUDevice) (generic function with 1 method)
+
+```
+"""
+function get_device(backend::String, ordinal::Int = 0)
+  if backend == "CPU"
+    return FluxCPUDevice()
+  else
+    return get_device(Val(Symbol(backend)), ordinal)
+  end
+end
diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl
index 7691241f38..9023fa743c 100644
--- a/test/ext_amdgpu/get_devices.jl
+++ b/test/ext_amdgpu/get_devices.jl
@@ -9,16 +9,41 @@ else
   @test typeof(amd_device.deviceID) <: Nothing
 end
 
+# testing get_device
+dense_model = Dense(2 => 3)     # initially lives on CPU
+weight = copy(dense_model.weight)   # store the weight
+bias = copy(dense_model.bias)       # store the bias
 if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
-  device = Flux.get_device()
+  amd_device = Flux.get_device()
 
-  @test typeof(device) <: Flux.FluxAMDDevice
-  @test typeof(device.deviceID) <: AMDGPU.HIPDevice
-  @test Flux._get_device_name(device) in Flux.supported_devices()
+  @test typeof(amd_device) <: Flux.FluxAMDDevice
+  @test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice
+  @test Flux._get_device_name(amd_device) in Flux.supported_devices()
 
   # correctness of data transfer
   x = randn(5, 5)
-  cx = x |> device
+  cx = x |> amd_device
   @test cx isa AMDGPU.ROCArray
-  @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(device.deviceID)
+  @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amd_device.deviceID)
+
+  # moving models to specific AMD devices
+  for ordinal in 0:(length(AMDGPU.devices()) - 1)
+    current_amd_device = Flux.get_device("AMD", ordinal)
+    @test typeof(current_amd_device.deviceID) <: AMDGPU.HIPDevice
+    @test AMDGPU.device_id(current_amd_device.deviceID) == ordinal + 1
+
+    global dense_model = dense_model |> current_amd_device
+    @test dense_model.weight isa AMDGPU.ROCArray
+    @test dense_model.bias isa AMDGPU.ROCArray
+    @test AMDGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1
+    @test AMDGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + 1
+    @test isequal(Flux.cpu(dense_model.weight), weight)
+    @test isequal(Flux.cpu(dense_model.bias), bias)
+  end
+  # finally move to CPU, and see if things work
+  cpu_device = Flux.get_device("CPU")
+  dense_model = cpu_device(dense_model)
+  @test dense_model.weight isa Matrix
+  @test dense_model.bias isa Vector
+
 end
diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl
index e1f4c7d8a8..f3adf4233e 100644
--- a/test/ext_cuda/get_devices.jl
+++ b/test/ext_cuda/get_devices.jl
@@ -10,16 +10,40 @@ else
 end
 
 # testing get_device
+dense_model = Dense(2 => 3)     # initially lives on CPU
+weight = copy(dense_model.weight)   # store the weight
+bias = copy(dense_model.bias)       # store the bias
 if CUDA.functional()
-  device = Flux.get_device()
+  cuda_device = Flux.get_device()
 
-  @test typeof(device) <: Flux.FluxCUDADevice
-  @test typeof(device.deviceID) <: CUDA.CuDevice
-  @test Flux._get_device_name(device) in Flux.supported_devices()
+  @test typeof(cuda_device) <: Flux.FluxCUDADevice
+  @test typeof(cuda_device.deviceID) <: CUDA.CuDevice
+  @test Flux._get_device_name(cuda_device) in Flux.supported_devices()
 
   # correctness of data transfer
   x = 
randn(5, 5) - cx = x |> device + cx = x |> cuda_device @test cx isa CUDA.CuArray - @test CUDA.device(cx).handle == device.deviceID.handle + @test CUDA.device(cx).handle == cuda_device.deviceID.handle + + # moving models to specific NVIDIA devices + for ordinal in 0:(length(CUDA.devices()) - 1) + current_cuda_device = Flux.get_device("CUDA", ordinal) + @test typeof(current_cuda_device.deviceID) <: CUDA.CuDevice + @test current_cuda_device.deviceID.handle == ordinal + + global dense_model = dense_model |> current_cuda_device + @test dense_model.weight isa CUDA.CuArray + @test dense_model.bias isa CUDA.CuArray + @test CUDA.device(dense_model.weight).handle == ordinal + @test CUDA.device(dense_model.bias).handle == ordinal + @test isequal(Flux.cpu(dense_model.weight), weight) + @test isequal(Flux.cpu(dense_model.bias), bias) + end + # finally move to CPU, and see if things work + cpu_device = Flux.get_device("CPU") + dense_model = cpu_device(dense_model) + @test dense_model.weight isa Matrix + @test dense_model.bias isa Vector + end diff --git a/test/ext_metal/get_devices.jl b/test/ext_metal/get_devices.jl index 83786e9834..12f7b87210 100644 --- a/test/ext_metal/get_devices.jl +++ b/test/ext_metal/get_devices.jl @@ -11,15 +11,15 @@ end # testing get_device if Metal.functional() - device = Flux.get_device() + metal_device = Flux.get_device() - @test typeof(device) <: Flux.FluxMetalDevice - @test typeof(device.deviceID) <: Metal.MTLDevice - @test Flux._get_device_name(device) in Flux.supported_devices() + @test typeof(metal_device) <: Flux.FluxMetalDevice + @test typeof(metal_device.deviceID) <: Metal.MTLDevice + @test Flux._get_device_name(metal_device) in Flux.supported_devices() # correctness of data transfer x = randn(5, 5) - cx = x |> device + cx = x |> metal_device @test cx isa Metal.MtlArray - @test Metal.device(cx).registryID == device.deviceID.registryID + @test Metal.device(cx).registryID == metal_device.deviceID.registryID end diff --git a/test/functors.jl b/test/functors.jl index b2c6c37f8d..a9f26194ea 100644 --- a/test/functors.jl +++ b/test/functors.jl @@ -8,7 +8,11 @@ end @test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]) <: Nothing @test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CPU"]]) <: Flux.FluxCPUDevice -device = Flux.get_device() -@test typeof(device) <: Flux.FluxCPUDevice -@test device(x) == x -@test Flux._get_device_name(device) in Flux.supported_devices() +dev = Flux.get_device() +@test typeof(dev) <: Flux.FluxCPUDevice +@test dev(x) == x +@test Flux._get_device_name(dev) in Flux.supported_devices() + +# specifically getting CPU device +dev = Flux.get_device("CPU") +@test typeof(dev) <: Flux.FluxCPUDevice
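
Taken together, the pieces above can be exercised end to end roughly as follows. This is a usage sketch, not part of the patch: it assumes CUDA.jl and cuDNN are installed and functional, that at least two CUDA devices are visible, and `model` stands in for any Flux model.

```julia
using Flux, CUDA

model = Dense(2 => 3)                 # parameters start out on the CPU

device0 = Flux.get_device("CUDA", 0)  # handle to the CUDA device with ordinal 0
model = model |> device0              # parameters now live on GPU 0

device1 = Flux.get_device("CUDA", 1)  # handle to the CUDA device with ordinal 1
model = model |> device1              # parameters are copied over to GPU 1

cpu_device = Flux.get_device("CPU")   # explicit CPU device
model = cpu_device(model)             # back to ordinary CPU arrays
```

As the `adapt_storage` methods above implement it, moving an array that already lives on the requested device is a no-op, while moving between two GPUs copies the data and then restores whichever device was active before the call. Per the warning added to `docs/src/gpu.md`, avoid printing a model in the REPL right after moving it to a device other than the currently active one.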