Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement interface for data transfer across GPU devices. #2308

Merged
merged 29 commits into from
Aug 21, 2023

Conversation

codetalker7
Copy link
Contributor

@codetalker7 codetalker7 commented Aug 9, 2023

This PR addresses issue #2302. Here, the goal is to pass device information in adaptors (FluxCUDAAdaptor, FluxAMDAdaptor and FluxMetalAdaptor respectively), which will then be used to do data transfer at the parameter level only.

Also, new Flux.get_device methods will be added in the extensions which can return a Flux.AbstractDevice with a specific backend and ordinal.

Example for data movement with CUDA

julia> using Flux, CUDA;

julia> CUDA.devices()
CUDA.DeviceIterator() for 3 devices:
0. GeForce RTX 2080 Ti
1. GeForce RTX 2080 Ti
2. TITAN X (Pascal)

julia> CUDA.device()        # default device is 0
CuDevice(0): GeForce RTX 2080 Ti

julia> m = Dense(2 => 3);

julia> m.weight       # m lives on CPU initially
3×2 Matrix{Float32}: 
 -0.705636  -0.369204
 -0.265636  -0.730012
 -0.461941   0.0260784

julia> device1 = Flux.get_device(CUDABackend, UInt(1))    # get device with id 1
(::Flux.FluxCUDADevice) (generic function with 1 method)

julia> device1.deviceID
CuDevice(1): GeForce RTX 2080 Ti

julia> m = m |> device1;    # transfer m from CPU to device1

julia> m.weight
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -0.705636  -0.369204
 -0.265636  -0.730012
 -0.461941   0.0260784

julia> device(m.weight)    # verify that m is attached to device1
CuDevice(1): GeForce RTX 2080 Ti

julia> device2 = Flux.get_device(CUDABackend, UInt(2))    # get device with id 2
(::Flux.FluxCUDADevice) (generic function with 1 method)

julia> m = m |> device2;         # now transfer m to device2

julia> m.weight
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -0.370832  -1.05765
  0.926787  -0.192985
  0.686085  -0.00252177

julia> device(m.weight)
CuDevice(2): TITAN X (Pascal)

cc @CarloLucibello @ToucheSir @darsnack.

PR Checklist

  • Do we need to garbage collect manually?
  • What about adapt_storage for data types other than AbstractArray?
  • Add relevant tests.
  • Adding relevant documentation.

@codetalker7
Copy link
Contributor Author

Also, things seem to be working fine here, except when I try to directly print the new model:

julia> using Flux, CUDA

julia> m = Dense(2 => 3);

julia> device1 = Flux.get_device(CUDABackend, UInt(1));

julia> m = m |> device1
Dense(2 => 3)       # 9 parametersError showing value of type Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}:
ERROR: CUDA error: an illegal memory access was encountered (code 700, ERROR_ILLEGAL_ADDRESS)
Stacktrace:
  [1] throw_api_error(res::CUDA.cudaError_enum)
    @ CUDA ~/.julia/packages/CUDA/tVtYo/lib/cudadrv/libcuda.jl:27
  [2] check
    @ ~/.julia/packages/CUDA/tVtYo/lib/cudadrv/libcuda.jl:34 [inlined]
  [3] cuMemcpyDtoHAsync_v2
    @ ~/.julia/packages/CUDA/tVtYo/lib/utils/call.jl:26 [inlined]
  [4] #unsafe_copyto!#8
    @ ~/.julia/packages/CUDA/tVtYo/lib/cudadrv/memory.jl:397 [inlined]
  [5] (::CUDA.var"#1014#1015"{Bool, Vector{Bool}, Int64, CuArray{Bool, 2, CUDA.Mem.DeviceBuffer}, Int64, Int64})()
    @ CUDA ~/.julia/packages/CUDA/tVtYo/src/array.jl:482
  [6] #context!#887
    @ ~/.julia/packages/CUDA/tVtYo/lib/cudadrv/state.jl:170 [inlined]
  [7] context!
    @ ~/.julia/packages/CUDA/tVtYo/lib/cudadrv/state.jl:165 [inlined]
  [8] unsafe_copyto!(dest::Vector{Bool}, doffs::Int64, src::CuArray{Bool, 2, CUDA.Mem.DeviceBuffer}, soffs::Int64, n::Int64)
    @ CUDA ~/.julia/packages/CUDA/tVtYo/src/array.jl:475
  [9] copyto!
    @ ~/.julia/packages/CUDA/tVtYo/src/array.jl:429 [inlined]
 [10] getindex
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:12 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:136 [inlined]
 [12] _mapreduce(f::ComposedFunction{typeof(!), typeof(iszero)}, op::typeof(|), As::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}; dims::Colon, init::Nothing)
    @ GPUArrays ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:73
 [13] _mapreduce
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:35 [inlined]
 [14] #mapreduce#29
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:31 [inlined]
 [15] mapreduce
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:31 [inlined]
 [16] any
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/mapreduce.jl:82 [inlined]
 [17] _any
    @ ~/fluxml/Flux.jl/src/layers/show.jl:129 [inlined]
 [18] (::Flux.var"#330#331"{ComposedFunction{typeof(!), typeof(iszero)}})(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Flux ~/fluxml/Flux.jl/src/layers/show.jl:131
 [19] _any(f::Flux.var"#330#331"{ComposedFunction{typeof(!), typeof(iszero)}}, itr::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, #unused#::Colon)
    @ Base ./reduce.jl:1215
 [20] any
    @ ./reduce.jl:1210 [inlined]
 [21] _any
    @ ~/fluxml/Flux.jl/src/layers/show.jl:131 [inlined]
 [22] _all
    @ ~/fluxml/Flux.jl/src/layers/show.jl:135 [inlined]
 [23] _nan_show(io::IOContext{Base.TTY}, x::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Flux ~/fluxml/Flux.jl/src/layers/show.jl:120
 [24] _layer_show(io::IOContext{Base.TTY}, layer::Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, indent::Int64, name::Nothing)
    @ Flux ~/fluxml/Flux.jl/src/layers/show.jl:86
 [25] _layer_show
    @ ~/fluxml/Flux.jl/src/layers/show.jl:75 [inlined]
 [26] show(io::IOContext{Base.TTY}, m::MIME{Symbol("text/plain")}, x::Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
    @ Flux ~/fluxml/Flux.jl/src/layers/show.jl:67
 [27] (::REPL.var"#55#56"{REPL.REPLDisplay{REPL.LineEditREPL}, MIME{Symbol("text/plain")}, Base.RefValue{Any}})(io::Any)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:276
 [28] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:557
 [29] display(d::REPL.REPLDisplay, mime::MIME{Symbol("text/plain")}, x::Any)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:262
 [30] display
    @ ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:281 [inlined]
 [31] display(x::Any)
    @ Base.Multimedia ./multimedia.jl:340
 [32] #invokelatest#2
    @ ./essentials.jl:816 [inlined]
 [33] invokelatest
    @ ./essentials.jl:813 [inlined]
 [34] print_response(errio::IO, response::Any, show_value::Bool, have_color::Bool, specialdisplay::Union{Nothing, AbstractDisplay})
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:305
 [35] (::REPL.var"#57#58"{REPL.LineEditREPL, Pair{Any, Bool}, Bool, Bool})(io::Any)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:287
 [36] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:557
 [37] print_response(repl::REPL.AbstractREPL, response::Any, show_value::Bool, have_color::Bool)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:285
 [38] (::REPL.var"#do_respond#80"{Bool, Bool, REPL.var"#93#103"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt})(s::REPL.LineEdit.MIState, buf::Any, ok::Bool)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:899
 [39] #invokelatest#2
    @ ./essentials.jl:816 [inlined]
 [40] invokelatest
    @ ./essentials.jl:813 [inlined]
 [41] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState)
    @ REPL.LineEdit ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/LineEdit.jl:2647
 [42] run_frontend(repl::REPL.LineEditREPL, backend::REPL.REPLBackendRef)
    @ REPL ~/julia-1.9.2/share/julia/stdlib/v1.9/REPL/src/REPL.jl:1300
 [43] (::REPL.var"#62#68"{REPL.LineEditREPL, REPL.REPLBackendRef})()
    @ REPL ./task.jl:514

I'm not exactly sure where the illegal memory access is happening. If I try to manually try to run the code for show for Dense models, it seems to work fine.

@ToucheSir
Copy link
Member

I believe you have to switch to device!(1) before printing the model because of how CUDA.jl works. That's a fine limitation though and one I totally expected. We can just document it with a small note for now.

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Care to try this for the other GPU backends as well?

ext/FluxCUDAExt/functor.jl Outdated Show resolved Hide resolved
@codetalker7
Copy link
Contributor Author

Care to try this for the other GPU backends as well?

Yes, will update the PR. Thanks for the suggestions.

@CarloLucibello
Copy link
Member

CarloLucibello commented Aug 11, 2023

I think the interface should be simpler. I suggest

device = get_device("CUDA", 1)

@codetalker7
Copy link
Contributor Author

Care to try this for the other GPU backends as well?

Hi @ToucheSir. I can surely attempt to implement this for metal and AMD, but I don't have access to these GPUs (so I won't be able to test on them). Just wondering: is there any way for me to test AMD/Metal code?

@codecov-commenter
Copy link

codecov-commenter commented Aug 12, 2023

Codecov Report

Patch coverage: 6.41% and project coverage change: -4.24% ⚠️

Comparison is base (656175d) 79.47% compared to head (0063bc0) 75.23%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2308      +/-   ##
==========================================
- Coverage   79.47%   75.23%   -4.24%     
==========================================
  Files          31       31              
  Lines        1749     1805      +56     
==========================================
- Hits         1390     1358      -32     
- Misses        359      447      +88     
Files Changed Coverage Δ
ext/FluxAMDGPUExt/FluxAMDGPUExt.jl 5.00% <0.00%> (-1.67%) ⬇️
ext/FluxAMDGPUExt/functor.jl 0.00% <0.00%> (ø)
ext/FluxCUDAExt/FluxCUDAExt.jl 28.57% <0.00%> (-8.93%) ⬇️
ext/FluxCUDAExt/functor.jl 0.00% <0.00%> (ø)
ext/FluxMetalExt/FluxMetalExt.jl 7.14% <0.00%> (-1.20%) ⬇️
src/functor.jl 59.66% <55.55%> (-20.18%) ⬇️

... and 3 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ToucheSir
Copy link
Member

Just wondering: is there any way for me to test AMD/Metal code?

Probably not locally, but since you'll be writing tests those should be auto-tested by CI. Also happy to help test the AMDGPU routines locally once those are in a good place.

@codetalker7
Copy link
Contributor Author

I've implemented the same interface for AMDGPU and Metal as well. The tests are already catching an error for Metal. I suspect it is because I'm using the registryID for a metal device. Unlike the CUDA and AMDGPU packages, Metal doesn't provide a way to get the ordinal of a metal GPU device. The closest thing I got from there was a registry ID.

Any nice way to get the ordinal for a metal device? Maybe we can match our device in the Metal.devices() list, and return the first matching index? But how do we match?

@ToucheSir
Copy link
Member

I wasn't aware Metal.jl had this limitation. If there's no clean way to do it now, we can always disable specific device selection for it.

@codetalker7
Copy link
Contributor Author

I wasn't aware Metal.jl had this limitation. If there's no clean way to do it now, we can always disable specific device selection for it.

Okay, I don't see a clean way to do this at the moment. So I'll skip adding this functionality to metal for now.

@codetalker7
Copy link
Contributor Author

Added some tests for CUDA and AMD. How do they look? Next I'll add some documentation.

@codetalker7
Copy link
Contributor Author

codetalker7 commented Aug 18, 2023

Getting a weird error in buildkite, which says: LoadError: cannot assign a value to imported variable CUDA.device from module Main. This is because CUDA is functional in the AMD and Metal tests? This wasn't the case before. Did something change in the CI?

adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray)
typeof(to.ordinal) <: Nothing && return CUDA.cu(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
typeof(to.ordinal) <: Nothing && return CUDA.cu(x)
to.ordinal === nothing && return CUDA.cu(x)

@CarloLucibello
Copy link
Member

This seems good to go. You could remove the draft status.

@codetalker7 codetalker7 marked this pull request as ready for review August 20, 2023 12:34
@codetalker7
Copy link
Contributor Author

This seems good to go. You could remove the draft status.

Sure. I'll add some documentation.

@codetalker7
Copy link
Contributor Author

Completed everything. Please let me know if something needs to be changed.

@CarloLucibello CarloLucibello merged commit 36fbdf1 into FluxML:master Aug 21, 2023
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants