Implement interface for data transfer across GPU devices. #2308

Merged: 29 commits, Aug 21, 2023

Commits
3d5d849
Adding new `get_device` method to return a CUDA device with particular
codetalker7 Aug 7, 2023
f720d5f
Adding an `adapt` function for `AbstractArray` to handle movement across
codetalker7 Aug 9, 2023
40e085c
Making the `get_device` interface simpler, and some minor changes.
codetalker7 Aug 12, 2023
652bf95
Adding CPU option to `get_device`.
codetalker7 Aug 12, 2023
b3cd292
Removing `KernelAbstractions` from deps.
codetalker7 Aug 12, 2023
9925c5b
Adding new `get_device` method to return a particular AMD device.
codetalker7 Aug 14, 2023
df577f3
Adding new `adapt_storage` function for moving arrays. Also passing
codetalker7 Aug 14, 2023
8aa7eed
Moving relevant function definitions to extensions.
codetalker7 Aug 14, 2023
f137080
Making `_metal` accept an ordinal.
codetalker7 Aug 14, 2023
ef265eb
Adding new `get_device` method to return particular Metal device.
codetalker7 Aug 14, 2023
3fbb4f5
Adding new `adapt_storage` method for metal arrays.
codetalker7 Aug 14, 2023
2c6bc55
Fixing minor error.
codetalker7 Aug 14, 2023
829dcfa
Fixing minor error and spelling mistake.
codetalker7 Aug 15, 2023
930d29c
Fixing package name: `AMDGPU` instead of `AMD`.
codetalker7 Aug 16, 2023
fedee3b
Reverting back to old metal functionality.
codetalker7 Aug 18, 2023
a4449f8
Adding tests for moving models between CPU and NVIDIA devices.
codetalker7 Aug 18, 2023
2a14650
Adding tests for data movement on AMD devices.
codetalker7 Aug 18, 2023
a9fb328
Fixing index error while choosing AMD gpu device.
codetalker7 Aug 18, 2023
91a756f
Fixing AMD ordinal starting index.
codetalker7 Aug 20, 2023
a9f6615
Adding docstring for new `get_device` method.
codetalker7 Aug 20, 2023
b47a6f4
Removing global name conflicts in tests.
codetalker7 Aug 20, 2023
f1ab569
Minor fix to AMDs device id tests.
codetalker7 Aug 20, 2023
129a0b5
Disambiguating test variables.
codetalker7 Aug 20, 2023
7a5b847
Adding more info in docstring of `get_device`, and writing some
codetalker7 Aug 20, 2023
1b770bc
Fixing minor error in AMD code.
codetalker7 Aug 20, 2023
8c09485
Fixing yet another ordinal index error in AMD code.
codetalker7 Aug 20, 2023
e367598
Fixing another ordinal index error in AMD code.
codetalker7 Aug 20, 2023
08b1962
Fixing spelling mistake.
codetalker7 Aug 20, 2023
0063bc0
Replacing type checks for `nothing` with equality checks.
codetalker7 Aug 20, 2023
8 changes: 3 additions & 5 deletions Project.toml
@@ -24,8 +24,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
FluxAMDGPUExt = "AMDGPU"
@@ -56,8 +56,8 @@ julia = "1.9"
[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
@@ -68,6 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra",
"FillArrays", "ComponentArrays", "BSON", "Pkg",
"CUDA", "cuDNN", "Metal", "AMDGPU"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"]
65 changes: 65 additions & 0 deletions docs/src/gpu.md
@@ -311,6 +311,71 @@ julia> device = Flux.get_device(; verbose=true) # this will resort to auto
```
For detailed information about how the backend is selected, check the documentation for [`Flux.get_device`](@ref).

## Data movement across GPU devices

Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another device of the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices:

```julia-repl
julia> using Flux, CUDA;

julia> CUDA.devices()
CUDA.DeviceIterator() for 3 devices:
0. GeForce RTX 2080 Ti
1. GeForce RTX 2080 Ti
2. TITAN X (Pascal)

```

Then, let's select the device with ordinal `0`:

```julia-repl
julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMD"
(::Flux.FluxCUDADevice) (generic function with 1 method)

```

Now, let's move a simple dense layer to the GPU represented by `device0`:

```julia-repl
julia> dense_model = Dense(2 => 3)
Dense(2 => 3) # 9 parameters

julia> dense_model = dense_model |> device0;

julia> dense_model.weight
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
0.695662 0.816299
-0.204763 -0.10232
-0.955829 0.538412

julia> CUDA.device(dense_model.weight) # check the GPU to which dense_model is attached
CuDevice(0): GeForce RTX 2080 Ti

```

Next, we'll get a handle to the device with ordinal `1`, and move `dense_model` to that device:

```julia-repl
julia> device1 = Flux.get_device("CUDA", 1)
(::Flux.FluxCUDADevice) (generic function with 1 method)

julia> dense_model = dense_model |> device1; # don't directly print the model; see warning below

julia> CUDA.device(dense_model.weight)
CuDevice(1): GeForce RTX 2080 Ti

```

Due to a limitation in `Metal.jl`, this kind of cross-device data movement is currently supported only for the `CUDA` and `AMD` backends.
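
The same pattern works for AMD GPUs. Here is a minimal sketch rather than a transcript, assuming a machine where `AMDGPU.jl` is functional and at least two AMD devices are visible:

```julia
using Flux, AMDGPU

AMDGPU.devices()                        # list the available AMD devices

device0 = Flux.get_device("AMD", 0)     # handle to the device with ordinal 0
dense_model = Dense(2 => 3) |> device0  # parameters become ROCArrays on device 0

device1 = Flux.get_device("AMD", 1)     # handle to the device with ordinal 1
dense_model = dense_model |> device1    # transfer the parameters to device 1

AMDGPU.device(dense_model.weight)       # check which device now holds the weights
```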

!!! warning "Printing models after moving to a different device"

    Due to a limitation in how GPU packages currently work, printing
    models in the REPL after moving them to a GPU device different from
    the current device will lead to an error.
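
As a simple workaround, you can bring the model back to host memory with `cpu` before inspecting it. A minimal sketch, continuing the example above:

```julia
using Flux

cpu_model = dense_model |> cpu  # copy the parameters back to plain CPU Arrays
println(cpu_model)              # printing a CPU-resident model touches no GPU device
```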


```@docs
Flux.AbstractDevice
Flux.FluxCPUDevice
8 changes: 8 additions & 0 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -17,6 +17,14 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
# Set to boolean on the first call to check_use_amdgpu
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)

function (device::Flux.FluxAMDDevice)(x)
if device.deviceID === nothing
Flux.gpu(Flux.FluxAMDAdaptor(), x)
else
return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
end
end
Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD"
Flux._isavailable(::Flux.FluxAMDDevice) = true
Flux._isfunctional(::Flux.FluxAMDDevice) = AMDGPU.functional()

55 changes: 47 additions & 8 deletions ext/FluxAMDGPUExt/functor.jl
@@ -1,10 +1,41 @@
# Convert Float64 to Float32, but preserve Float16.
adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray =
isbits(x) ? x : ROCArray(x)
adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} =
isbits(x) ? x : ROCArray{Float32, N}(x)
adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N =
isbits(x) ? x : ROCArray{Float16, N}(x)
function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
if to.ordinal === nothing
if (typeof(x) <: AbstractArray{Float16, N} where N)
N = length(size(x))
return isbits(x) ? x : ROCArray{Float16, N}(x)
elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
N = length(size(x))
return isbits(x) ? x : ROCArray{Float32, N}(x)
else
return isbits(x) ? x : ROCArray(x)
end
end

old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0

if !(x isa ROCArray)
AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) # adding 1 because ordinals start from 0
if (typeof(x) <: AbstractArray{Float16, N} where N)
N = length(size(x))
x_new = isbits(x) ? x : ROCArray{Float16, N}(x)
elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
N = length(size(x))
x_new = isbits(x) ? x : ROCArray{Float32, N}(x)
else
x_new = isbits(x) ? x : ROCArray(x)
end
AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
return x_new
elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal
return x
else
AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1])
x_new = copy(x)
AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
return x_new
end
end

adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) =
ROCArray(collect(x))
@@ -45,10 +76,10 @@ Flux._isleaf(::AMD_CONV) = true
_exclude(x) = Flux._isleaf(x)
_exclude(::CPU_CONV) = true

function _amd(x)
function _amd(ordinal::Union{Nothing, Int}, x)
check_use_amdgpu()
USE_AMDGPU[] || return x
fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_exclude)
fmap(x -> Adapt.adapt(FluxAMDAdaptor(ordinal), x), x; exclude=_exclude)
end

# CPU -> GPU
@@ -74,3 +105,11 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV)
Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
end

function Flux.get_device(::Val{:AMD}, ordinal::Int) # ordinal should start from 0
old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0
AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0
device = Flux.FluxAMDDevice(AMDGPU.device())
AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
return device
end
8 changes: 8 additions & 0 deletions ext/FluxCUDAExt/FluxCUDAExt.jl
@@ -14,6 +14,14 @@ import Adapt: adapt_storage

const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing)

function (device::Flux.FluxCUDADevice)(x)
if device.deviceID === nothing
return Flux.gpu(Flux.FluxCUDAAdaptor(), x)
else
return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x)
end
end
Flux._get_device_name(::Flux.FluxCUDADevice) = "CUDA"
Flux._isavailable(::Flux.FluxCUDADevice) = true
Flux._isfunctional(::Flux.FluxCUDADevice) = CUDA.functional()

33 changes: 30 additions & 3 deletions ext/FluxCUDAExt/functor.jl
@@ -1,5 +1,24 @@

adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray)
to.ordinal === nothing && return CUDA.cu(x)

# remember current device
old_ordinal = CUDA.device().handle

if !(x isa CuArray)
CUDA.device!(to.ordinal)
x_new = CUDA.cu(x)
CUDA.device!(old_ordinal)
return x_new
elseif CUDA.device(x).handle == to.ordinal
return x
else
CUDA.device!(to.ordinal)
x_new = copy(x)
CUDA.device!(old_ordinal)
return x_new
end
end
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
@@ -25,8 +44,16 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) =
ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) =
adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ)))

function _cuda(x)
function _cuda(ordinal::Union{Nothing, Int}, x)
check_use_cuda()
USE_CUDA[] || return x
fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude=Flux._isleaf)
fmap(x -> Adapt.adapt(FluxCUDAAdaptor(ordinal), x), x; exclude=Flux._isleaf)
end

function Flux.get_device(::Val{:CUDA}, ordinal::Int)
old_ordinal = CUDA.device().handle
CUDA.device!(ordinal)
device = Flux.FluxCUDADevice(CUDA.device())
CUDA.device!(old_ordinal)
return device
end
2 changes: 2 additions & 0 deletions ext/FluxMetalExt/FluxMetalExt.jl
@@ -12,6 +12,8 @@ using Zygote

const USE_METAL = Ref{Union{Nothing, Bool}}(nothing)

(::Flux.FluxMetalDevice)(x) = Flux.gpu(Flux.FluxMetalAdaptor(), x)
Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal"
Flux._isavailable(::Flux.FluxMetalDevice) = true
Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional()

69 changes: 53 additions & 16 deletions src/functor.jl
@@ -332,13 +332,15 @@ trainable(c::Cholesky) = ()

# CUDA extension. ########

struct FluxCUDAAdaptor end
Base.@kwdef struct FluxCUDAAdaptor
ordinal::Union{Nothing, Int} = nothing
end

const CUDA_LOADED = Ref{Bool}(false)

function gpu(::FluxCUDAAdaptor, x)
function gpu(to::FluxCUDAAdaptor, x)
if CUDA_LOADED[]
return _cuda(x)
return _cuda(to.ordinal, x)
else
@info """
The CUDA functionality is being called but
@@ -353,13 +355,15 @@ function _cuda end

# AMDGPU extension. ########

struct FluxAMDAdaptor end
Base.@kwdef struct FluxAMDAdaptor
ordinal::Union{Nothing, Int} = nothing
end

const AMDGPU_LOADED = Ref{Bool}(false)

function gpu(::FluxAMDAdaptor, x)
function gpu(to::FluxAMDAdaptor, x)
if AMDGPU_LOADED[]
return _amd(x)
return _amd(to.ordinal, x)
else
@info """
The AMDGPU functionality is being called but
@@ -500,9 +504,6 @@ Base.@kwdef struct FluxCUDADevice <: AbstractDevice
deviceID
end

(::FluxCUDADevice)(x) = gpu(FluxCUDAAdaptor(), x)
_get_device_name(::FluxCUDADevice) = "CUDA"

"""
FluxAMDDevice <: AbstractDevice

@@ -512,9 +513,6 @@ Base.@kwdef struct FluxAMDDevice <: AbstractDevice
deviceID
end

(::FluxAMDDevice)(x) = gpu(FluxAMDAdaptor(), x)
_get_device_name(::FluxAMDDevice) = "AMD"

"""
FluxMetalDevice <: AbstractDevice

@@ -524,9 +522,6 @@ Base.@kwdef struct FluxMetalDevice <: AbstractDevice
deviceID
end

(::FluxMetalDevice)(x) = gpu(FluxMetalAdaptor(), x)
_get_device_name(::FluxMetalDevice) = "Metal"

## device list. order is important
const DEVICES = Ref{Vector{Union{Nothing, AbstractDevice}}}(Vector{Union{Nothing, AbstractDevice}}(nothing, length(GPU_BACKENDS)))
DEVICES[][GPU_BACKEND_ORDER["CPU"]] = FluxCPUDevice()
@@ -550,7 +545,7 @@ julia> Flux.supported_devices()
supported_devices() = GPU_BACKENDS

"""
Flux.get_device(; verbose=false)::AbstractDevice
Flux.get_device(; verbose=false)::Flux.AbstractDevice

Returns a `device` object for the most appropriate backend for the current Julia session.

@@ -653,3 +648,45 @@ function get_device(; verbose=false)::AbstractDevice
end
end
end

"""
Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice

Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values
of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `ordinal` must be an integer value between `0` and the number of available devices.

# Examples

```julia-repl
julia> using Flux, CUDA;

julia> CUDA.devices()
CUDA.DeviceIterator() for 3 devices:
0. GeForce RTX 2080 Ti
1. GeForce RTX 2080 Ti
2. TITAN X (Pascal)

julia> device0 = Flux.get_device("CUDA", 0)
(::Flux.FluxCUDADevice) (generic function with 1 method)

julia> device0.deviceID
CuDevice(0): GeForce RTX 2080 Ti

julia> device1 = Flux.get_device("CUDA", 1)
(::Flux.FluxCUDADevice) (generic function with 1 method)

julia> device1.deviceID
CuDevice(1): GeForce RTX 2080 Ti

julia> cpu_device = Flux.get_device("CPU")
(::Flux.FluxCPUDevice) (generic function with 1 method)

```
"""
function get_device(backend::String, ordinal::Int = 0)
if backend == "CPU"
return FluxCPUDevice()
else
return get_device(Val(Symbol(backend)), ordinal)
end
end