Implement interface for data transfer across GPU devices. #2308

Merged
merged 29 commits on Aug 21, 2023
Changes from 23 commits

Commits (29 commits)
3d5d849
Adding new `get_device` method to return a CUDA device with particular
codetalker7 Aug 7, 2023
f720d5f
Adding an `adapt` function for `AbstractArray` to handle movement across
codetalker7 Aug 9, 2023
40e085c
Making the `get_device` interface simpler, and some minor changes.
codetalker7 Aug 12, 2023
652bf95
Adding CPU option to `get_device`.
codetalker7 Aug 12, 2023
b3cd292
Removing `KernelAbstractions` from deps.
codetalker7 Aug 12, 2023
9925c5b
Adding new `get_device` method to return a particular AMD device.
codetalker7 Aug 14, 2023
df577f3
Adding new `adapt_storage` function for moving arrays. Also passing
codetalker7 Aug 14, 2023
8aa7eed
Moving relevant function definitions to extensions.
codetalker7 Aug 14, 2023
f137080
Making `_metal` accept an ordinal.
codetalker7 Aug 14, 2023
ef265eb
Adding new `get_device` method to return particular Metal device.
codetalker7 Aug 14, 2023
3fbb4f5
Adding new `adapt_storage` method for metal arrays.
codetalker7 Aug 14, 2023
2c6bc55
Fixing minor error.
codetalker7 Aug 14, 2023
829dcfa
Fixing minor error and spelling mistake.
codetalker7 Aug 15, 2023
930d29c
Fixing package name: `AMDGPU` instead of `AMD`.
codetalker7 Aug 16, 2023
fedee3b
Reverting back to old metal functionality.
codetalker7 Aug 18, 2023
a4449f8
Adding tests for moving models between CPU and NVIDIA devices.
codetalker7 Aug 18, 2023
2a14650
Adding tests for data movement on AMD devices.
codetalker7 Aug 18, 2023
a9fb328
Fixing index error while choosing AMD gpu device.
codetalker7 Aug 18, 2023
91a756f
Fixing AMD ordinal starting index.
codetalker7 Aug 20, 2023
a9f6615
Adding docstring for new `get_device` method.
codetalker7 Aug 20, 2023
b47a6f4
Removing global name conflicts in tests.
codetalker7 Aug 20, 2023
f1ab569
Minor fix to AMDs device id tests.
codetalker7 Aug 20, 2023
129a0b5
Disambiguating test variables.
codetalker7 Aug 20, 2023
7a5b847
Adding more info in docstring of `get_device`, and writing some
codetalker7 Aug 20, 2023
1b770bc
Fixing minor error in AMD code.
codetalker7 Aug 20, 2023
8c09485
Fixing yet another ordinal index error in AMD code.
codetalker7 Aug 20, 2023
e367598
Fixing another ordinal index error in AMD code.
codetalker7 Aug 20, 2023
08b1962
Fixing spelling mistake.
codetalker7 Aug 20, 2023
0063bc0
Replacing type checks for `nothing` with equality checks.
codetalker7 Aug 20, 2023
8 changes: 3 additions & 5 deletions Project.toml
@@ -24,8 +24,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
FluxAMDGPUExt = "AMDGPU"
@@ -56,8 +56,8 @@ julia = "1.9"
[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
@@ -68,6 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra",
"FillArrays", "ComponentArrays", "BSON", "Pkg",
"CUDA", "cuDNN", "Metal", "AMDGPU"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"]
8 changes: 8 additions & 0 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -17,6 +17,14 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
# Set to boolean on the first call to check_use_amdgpu
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)

function (device::Flux.FluxAMDDevice)(x)
if typeof(device.deviceID) <: Nothing
Flux.gpu(Flux.FluxAMDAdaptor(), x)
else
return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
end
end
Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD"
Flux._isavailable(::Flux.FluxAMDDevice) = true
Flux._isfunctional(::Flux.FluxAMDDevice) = AMDGPU.functional()

55 changes: 47 additions & 8 deletions ext/FluxAMDGPUExt/functor.jl
@@ -1,10 +1,41 @@
# Convert Float64 to Float32, but preserve Float16.
adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray =
isbits(x) ? x : ROCArray(x)
adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} =
isbits(x) ? x : ROCArray{Float32, N}(x)
adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N =
isbits(x) ? x : ROCArray{Float16, N}(x)
function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
if typeof(to.ordinal) <: Nothing
if (typeof(x) <: AbstractArray{Float16, N} where N)
N = length(size(x))
return isbits(x) ? x : ROCArray{Float16, N}(x)
elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
N = length(size(x))
return isbits(x) ? x : ROCArray{Float32, N}(x)
else
return isbits(x) ? x : ROCArray(x)
end
end

old_ordinal = AMDGPU.device_id(AMDGPU.device())

if !(x isa ROCArray)
AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) # adding 1 because ordinals start from 0
if (typeof(x) <: AbstractArray{Float16, N} where N)
N = length(size(x))
x_new = isbits(x) ? x : ROCArray{Float16, N}(x)
elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
N = length(size(x))
x_new = isbits(x) ? x : ROCArray{Float32, N}(x)
else
x_new = isbits(x) ? x : ROCArray(x)
end
AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
return x_new
elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal
return x
else
AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1])
x_new = copy(x)
AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
return x_new
end
end

adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) =
ROCArray(collect(x))
@@ -45,10 +45,10 @@ Flux._isleaf(::AMD_CONV) = true
_exclude(x) = Flux._isleaf(x)
_exclude(::CPU_CONV) = true

function _amd(x)
function _amd(ordinal::Union{Nothing, Int}, x)
check_use_amdgpu()
USE_AMDGPU[] || return x
fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_exclude)
fmap(x -> Adapt.adapt(FluxAMDAdaptor(ordinal), x), x; exclude=_exclude)
end

# CPU -> GPU
@@ -74,3 +74,11 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV)
Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
end

function Flux.get_device(::Val{:AMD}, ordinal::Int) # ordinal should start from 0
old_ordinal = AMDGPU.device_id(AMDGPU.device())
AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0
device = Flux.FluxAMDDevice(AMDGPU.device())
AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
return device
end
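
For context, a minimal usage sketch of the AMD path added here (a sketch only, not part of the diff; assumes AMDGPU.jl is loaded with at least one functional device):

```julia
using Flux, AMDGPU

# Ordinals are zero-based, matching the CUDA convention,
# while AMDGPU.devices() itself is one-indexed.
device0 = Flux.get_device("AMD", 0)
device0.deviceID                     # an AMDGPU.HIPDevice

m = Dense(2 => 3) |> device0         # device objects are callable, so |> works
m.weight isa AMDGPU.ROCArray         # true
```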
8 changes: 8 additions & 0 deletions ext/FluxCUDAExt/FluxCUDAExt.jl
@@ -14,6 +14,14 @@ import Adapt: adapt_storage

const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing)

function (device::Flux.FluxCUDADevice)(x)
if typeof(device.deviceID) <: Nothing
return Flux.gpu(Flux.FluxCUDAAdaptor(), x)
else
return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x)
end
end
Flux._get_device_name(::Flux.FluxCUDADevice) = "CUDA"
Flux._isavailable(::Flux.FluxCUDADevice) = true
Flux._isfunctional(::Flux.FluxCUDADevice) = CUDA.functional()
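
A quick sketch of how the callable above behaves from the user's side (illustrative only; assumes a functional CUDA setup):

```julia
using Flux, CUDA

device = Flux.get_device("CUDA", 0)

# With a bound deviceID the call pins the transfer to that device's handle;
# with deviceID === nothing it falls back to the plain FluxCUDAAdaptor().
m = device(Dense(2 => 3))
m.weight isa CUDA.CuArray            # true
```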

33 changes: 30 additions & 3 deletions ext/FluxCUDAExt/functor.jl
@@ -1,5 +1,24 @@

adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray)
typeof(to.ordinal) <: Nothing && return CUDA.cu(x)

Review comment (Member), suggested change:
-    typeof(to.ordinal) <: Nothing && return CUDA.cu(x)
+    to.ordinal === nothing && return CUDA.cu(x)

# remember current device
old_ordinal = CUDA.device().handle

if !(x isa CuArray)
CUDA.device!(to.ordinal)
x_new = CUDA.cu(x)
CUDA.device!(old_ordinal)
return x_new
elseif CUDA.device(x).handle == to.ordinal
return x
else
CUDA.device!(to.ordinal)
x_new = copy(x)
CUDA.device!(old_ordinal)
return x_new
end
end
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
@@ -25,8 +25,16 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) =
ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) =
adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ)))

function _cuda(x)
function _cuda(ordinal::Union{Nothing, Int}, x)
check_use_cuda()
USE_CUDA[] || return x
fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude=Flux._isleaf)
fmap(x -> Adapt.adapt(FluxCUDAAdaptor(ordinal), x), x; exclude=Flux._isleaf)
end

function Flux.get_device(::Val{:CUDA}, ordinal::Int)
old_ordinal = CUDA.device().handle
CUDA.device!(ordinal)
device = Flux.FluxCUDADevice(CUDA.device())
CUDA.device!(old_ordinal)
return device
end
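
A rough illustration of the ordinal-aware adaptor above (a sketch using the internal `Flux.gpu(::FluxCUDAAdaptor, x)` entry point from `src/functor.jl`; assumes at least two functional CUDA devices):

```julia
using Flux, CUDA

x = rand(Float32, 4)

# adapt_storage switches to the requested device, allocates there,
# then restores whichever device was active before the transfer.
x1 = Flux.gpu(Flux.FluxCUDAAdaptor(1), x)
CUDA.device(x1).handle == 1          # the array lives on device 1
CUDA.device()                        # the active device is unchanged
```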
2 changes: 2 additions & 0 deletions ext/FluxMetalExt/FluxMetalExt.jl
@@ -12,6 +12,8 @@ using Zygote

const USE_METAL = Ref{Union{Nothing, Bool}}(nothing)

(::Flux.FluxMetalDevice)(x) = Flux.gpu(Flux.FluxMetalAdaptor(), x)
Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal"
Flux._isavailable(::Flux.FluxMetalDevice) = true
Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional()

69 changes: 53 additions & 16 deletions src/functor.jl
@@ -332,13 +332,15 @@ trainable(c::Cholesky) = ()

# CUDA extension. ########

struct FluxCUDAAdaptor end
Base.@kwdef struct FluxCUDAAdaptor
ordinal::Union{Nothing, Int} = nothing
end

const CUDA_LOADED = Ref{Bool}(false)

function gpu(::FluxCUDAAdaptor, x)
function gpu(to::FluxCUDAAdaptor, x)
if CUDA_LOADED[]
return _cuda(x)
return _cuda(to.ordinal, x)
else
@info """
The CUDA functionality is being called but
Expand All @@ -353,13 +355,15 @@ function _cuda end

# AMDGPU extension. ########

struct FluxAMDAdaptor end
Base.@kwdef struct FluxAMDAdaptor
ordinal::Union{Nothing, Int} = nothing
end

const AMDGPU_LOADED = Ref{Bool}(false)

function gpu(::FluxAMDAdaptor, x)
function gpu(to::FluxAMDAdaptor, x)
if AMDGPU_LOADED[]
return _amd(x)
return _amd(to.ordinal, x)
else
@info """
The AMDGPU functionality is being called but
@@ -500,9 +504,6 @@ Base.@kwdef struct FluxCUDADevice <: AbstractDevice
deviceID
end

(::FluxCUDADevice)(x) = gpu(FluxCUDAAdaptor(), x)
_get_device_name(::FluxCUDADevice) = "CUDA"

"""
FluxAMDDevice <: AbstractDevice

@@ -512,9 +513,6 @@ Base.@kwdef struct FluxAMDDevice <: AbstractDevice
deviceID
end

(::FluxAMDDevice)(x) = gpu(FluxAMDAdaptor(), x)
_get_device_name(::FluxAMDDevice) = "AMD"

"""
FluxMetalDevice <: AbstractDevice

@@ -524,9 +522,6 @@ Base.@kwdef struct FluxMetalDevice <: AbstractDevice
deviceID
end

(::FluxMetalDevice)(x) = gpu(FluxMetalAdaptor(), x)
_get_device_name(::FluxMetalDevice) = "Metal"

## device list. order is important
const DEVICES = Ref{Vector{Union{Nothing, AbstractDevice}}}(Vector{Union{Nothing, AbstractDevice}}(nothing, length(GPU_BACKENDS)))
DEVICES[][GPU_BACKEND_ORDER["CPU"]] = FluxCPUDevice()
@@ -550,7 +545,7 @@ julia> Flux.supported_devices()
supported_devices() = GPU_BACKENDS

"""
Flux.get_device(; verbose=false)::AbstractDevice
Flux.get_device(; verbose=false)::Flux.AbstractDevice

Returns a `device` object for the most appropriate backend for the current Julia session.

@@ -653,3 +648,45 @@ function get_device(; verbose=false)::AbstractDevice
end
end
end

"""
Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice

Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values
of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`.

# Examples

```julia-repl
julia> using Flux, CUDA;

julia> CUDA.devices()
CUDA.DeviceIterator() for 3 devices:
0. GeForce RTX 2080 Ti
1. GeForce RTX 2080 Ti
2. TITAN X (Pascal)

julia> device0 = Flux.get_device("CUDA", 0)
(::Flux.FluxCUDADevice) (generic function with 1 method)

julia> device0.deviceID
CuDevice(0): GeForce RTX 2080 Ti

julia> device1 = Flux.get_device("CUDA", 1)
(::Flux.FluxCUDADevice) (generic function with 1 method)

julia> device1.deviceID
CuDevice(1): GeForce RTX 2080 Ti

julia> cpu_device = Flux.get_device("CPU")
(::Flux.FluxCPUDevice) (generic function with 1 method)

```
"""
function get_device(backend::String, ordinal::Int = 0)
if backend == "CPU"
return FluxCPUDevice()
else
return get_device(Val(Symbol(backend)), ordinal)
end
end
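
Putting the pieces together, a rough end-to-end sketch of the intended workflow (hypothetical hardware; mirrors the tests below and assumes at least two CUDA devices):

```julia
using Flux, CUDA

device1 = Flux.get_device("CUDA", 1)   # device object bound to ordinal 1
model   = Dense(2 => 3) |> device1     # parameters now live on CuDevice(1)

cpu_device = Flux.get_device("CPU")
model = cpu_device(model)              # back to plain Arrays on the CPU
```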
37 changes: 31 additions & 6 deletions test/ext_amdgpu/get_devices.jl
@@ -9,16 +9,41 @@ else
@test typeof(amd_device.deviceID) <: Nothing
end

# testing get_device
dense_model = Dense(2 => 3) # initially lives on CPU
weight = copy(dense_model.weight) # store the weight
bias = copy(dense_model.bias) # store the bias
if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
device = Flux.get_device()
amd_device = Flux.get_device()

@test typeof(device) <: Flux.FluxAMDDevice
@test typeof(device.deviceID) <: AMDGPU.HIPDevice
@test Flux._get_device_name(device) in Flux.supported_devices()
@test typeof(amd_device) <: Flux.FluxAMDDevice
@test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice
@test Flux._get_device_name(amd_device) in Flux.supported_devices()

# correctness of data transfer
x = randn(5, 5)
cx = x |> device
cx = x |> amd_device
@test cx isa AMDGPU.ROCArray
@test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(device.deviceID)
@test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amd_device.deviceID)

# moving models to specific AMD devices
for ordinal in 0:(length(AMDGPU.devices()) - 1)
current_amd_device = Flux.get_device("AMD", ordinal)
@test typeof(current_amd_device.deviceID) <: AMDGPU.HIPDevice
@test AMDGPU.device_id(current_amd_device.deviceID) == ordinal + 1

global dense_model = dense_model |> current_amd_device
@test dense_model.weight isa AMDGPU.ROCArray
@test dense_model.bias isa AMDGPU.ROCArray
@test AMDGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1
@test AMDGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + 1
@test isequal(Flux.cpu(dense_model.weight), weight)
@test isequal(Flux.cpu(dense_model.bias), bias)
end
# finally move to CPU, and see if things work
cpu_device = Flux.get_device("CPU")
dense_model = cpu_device(dense_model)
@test dense_model.weight isa Matrix
@test dense_model.bias isa Vector

end