feat: Distributed data parallel training support (#2464)
* first experiment distributed
* feat: add DistributedUtils (MPI&NCCL working)
* feat: add DistributedUtils (MPI&NCCL working)
* fix: no need for amdgpu now
* chore: cleanup&propose how to use amdgpu
* chore: add preferences for CUDA-awareness
* feat: fix devices for CUDA-awareness
* chore: add tests
* chore: get rid of unnecessary deps
* chore: update NEWS.md
* chore: cleanup env
* chore: update docs
* chore: update docs & cleanup
* chore: update docs & cleanup
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md (Co-authored-by: Carlo Lucibello <[email protected]>)
* Update docs/src/guide/gpu.md
* Update docs/src/guide/gpu.md
* chore: add PR review suggestions
* chore: fix docs
* fix: add runtests.jl
* chore: small docs update
* chore: remove pkgs from deps

Co-authored-by: CarloLucibello <[email protected]>
Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 033f4b2 · commit d1ff714
Showing 14 changed files with 1,042 additions and 0 deletions.
@@ -0,0 +1,183 @@
module FluxMPIExt

using CUDA
using Flux: MPIBackend, NCCLBackend, DistributedUtils,
            AbstractDevice, FluxCUDADevice, FluxAMDGPUDevice, cpu, gpu,
            get_device, MPI_CUDA_AWARE, MPI_ROCM_AWARE
using MPI: MPI

if Base.find_package("AMDGPU") !== nothing
    using AMDGPU
end

function DistributedUtils.__initialize(
        ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing,
        force_cuda::Bool=false, caller::String="", force_amdgpu::Bool=false) # Undocumented internal kwarg
    !MPI.Initialized() && MPI.Init()
    DistributedUtils.MPI_Initialized[] = true

    local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

    if cuda_devices !== missing && CUDA.functional()
        if cuda_devices === nothing
            CUDA.device!((local_rank + 1) % length(CUDA.devices()))
        else
            CUDA.device!(cuda_devices[local_rank + 1])
        end
    elseif force_cuda
        error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
    end

    if Base.find_package("AMDGPU") !== nothing
        if amdgpu_devices !== missing && AMDGPU.functional()
            if amdgpu_devices === nothing
                AMDGPU.device!((local_rank + 1) % length(AMDGPU.devices()))
            else
                AMDGPU.device!(amdgpu_devices[local_rank + 1])
            end
        elseif force_amdgpu
            error(lazy"AMDGPU devices are not functional (or `AMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`. This is caused by backend: $(caller).")
        end
    end

    return
end
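
# Hypothetical call sketch (not part of this diff): `__initialize` is an internal
# entry point; each MPI rank picks its own GPU, round-robin by default. Launched
# with e.g. `mpiexec -n 4 julia script.jl`:
#
#   DistributedUtils.__initialize(MPIBackend)                      # round-robin device per rank
#   DistributedUtils.__initialize(MPIBackend; cuda_devices=[1, 0]) # explicit device ids, indexed by rank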

DistributedUtils.__get_distributed_backend(::Type{MPIBackend}) = MPIBackend(MPI.COMM_WORLD)

DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm)

DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm)
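
# Hypothetical sketch (not part of this diff): rank and world size typically
# drive data sharding, e.g. rank r of N workers takes every N-th sample:
#
#   rank = DistributedUtils.local_rank(backend)        # 0-based
#   nworkers = DistributedUtils.total_workers(backend)
#   myshard = data[(rank + 1):nworkers:end]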

# Broadcast
# The `Union` with `Function` is needed because `Flux.cpu` is a `Function`, not an `AbstractDevice`.
# The CPU path is required when the MPI implementation is not CUDA-aware.
function DistributedUtils.__bcast!(
        backend::MPIBackend, sendrecvbuf, dev::Union{AbstractDevice, Function}; root=0)
    MPI.Bcast!(sendrecvbuf, backend.comm; root)
    return sendrecvbuf
end

function DistributedUtils.__bcast!(
        backend::MPIBackend, sendbuf, recvbuf, dev::Union{AbstractDevice, Function}; root=0)
    return DistributedUtils.__bcast!(
        backend, ifelse(DistributedUtils.local_rank(backend) == root, sendbuf, recvbuf),
        dev; root)
end
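
# Hypothetical sketch (not part of this diff): broadcasting rank 0's buffer so
# that every rank holds identical values, e.g. for initial parameter sync:
#
#   W = rand(Float32, 4, 4)                             # same shape on every rank
#   DistributedUtils.__bcast!(backend, W, cpu; root=0)  # every rank's W now matches rank 0's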

# If the MPI implementation is not CUDA-aware we have to move the data to the
# CPU first, run the collective there, and copy the result back to the device.
for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
    if !aware
        @eval begin
            function DistributedUtils.__bcast!(
                    backend::MPIBackend, sendrecvbuf, dev::$dType; root=0)
                sendrecvbuf_ = sendrecvbuf |> cpu
                DistributedUtils.__bcast!(backend, sendrecvbuf_, cpu; root)
                copyto!(sendrecvbuf, sendrecvbuf_)  # copy the broadcast result back to the device buffer
                return sendrecvbuf
            end

            function DistributedUtils.__bcast!(
                    backend::MPIBackend, sendbuf, recvbuf, dev::$dType; root=0)
                sendbuf_ = sendbuf |> cpu
                recvbuf_ = recvbuf |> cpu
                DistributedUtils.__bcast!(backend, sendbuf_, recvbuf_, cpu; root)
                copyto!(recvbuf, recvbuf_)  # copy the received data back to the device buffer
                return recvbuf
            end
        end
    end
end

# Allreduce
function DistributedUtils.__allreduce!(
        backend::MPIBackend, sendrecvbuf, op::F, dev::Union{AbstractDevice, Function}) where {F}
    mpiop = ifelse(op === DistributedUtils.avg, +, op)
    MPI.Allreduce!(sendrecvbuf, mpiop, backend.comm)
    if op === DistributedUtils.avg
        sendrecvbuf ./= DistributedUtils.total_workers(backend)
    end
    return sendrecvbuf
end

function DistributedUtils.__allreduce!(
        backend::MPIBackend, sendbuf, recvbuf, op::F, dev::Union{AbstractDevice, Function}) where {F}
    mpiop = ifelse(op === DistributedUtils.avg, +, op)
    MPI.Allreduce!(sendbuf, recvbuf, mpiop, backend.comm)
    if op === DistributedUtils.avg
        recvbuf ./= DistributedUtils.total_workers(backend)
    end
    return recvbuf
end
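
# Hypothetical sketch (not part of this diff): `DistributedUtils.avg` is lowered
# to an MPI sum followed by division by `total_workers`, i.e. an element-wise mean:
#
#   grad = ones(Float32, 10) .* DistributedUtils.local_rank(backend)
#   DistributedUtils.__allreduce!(backend, grad, DistributedUtils.avg, cpu)
#   # with 4 workers, every rank now holds fill(1.5f0, 10)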

for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
    if !aware
        @eval begin
            function DistributedUtils.__allreduce!(
                    backend::MPIBackend, sendrecvbuf, op::F, dev::$dType) where {F}
                sendrecvbuf_ = sendrecvbuf |> cpu
                DistributedUtils.__allreduce!(backend, sendrecvbuf_, op, cpu)
                copyto!(sendrecvbuf, sendrecvbuf_)  # copy the reduced result back to the device buffer
                return sendrecvbuf
            end

            function DistributedUtils.__allreduce!(
                    backend::MPIBackend, sendbuf, recvbuf, op::F, dev::$dType) where {F}
                sendbuf_ = sendbuf |> cpu
                recvbuf_ = recvbuf |> cpu
                DistributedUtils.__allreduce!(backend, sendbuf_, recvbuf_, op, cpu)
                copyto!(recvbuf, recvbuf_)  # copy the reduced result back to the device buffer
                return recvbuf
            end
        end
    end
end

# Reduce
function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
        dev::Union{AbstractDevice, Function}; root::Int) where {F}
    mpiop = ifelse(op === DistributedUtils.avg, +, op)
    MPI.Reduce!(sendrecvbuf, mpiop, backend.comm; root)
    if op === DistributedUtils.avg
        sendrecvbuf ./= DistributedUtils.total_workers(backend)
    end
    return sendrecvbuf
end

function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F,
        dev::Union{AbstractDevice, Function}; root::Int) where {F}
    mpiop = ifelse(op === DistributedUtils.avg, +, op)
    MPI.Reduce!(sendbuf, recvbuf, mpiop, backend.comm; root)
    if op === DistributedUtils.avg
        recvbuf ./= DistributedUtils.total_workers(backend)
    end
    return recvbuf
end

for (aware, dType) in ((MPI_CUDA_AWARE, FluxCUDADevice), (MPI_ROCM_AWARE, FluxAMDGPUDevice))
    if !aware
        @eval begin
            function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
                    dev::$dType; root::Int) where {F}
                sendrecvbuf_ = sendrecvbuf |> cpu
                DistributedUtils.__reduce!(backend, sendrecvbuf_, op, cpu; root)
                copyto!(sendrecvbuf, sendrecvbuf_)  # copy the reduced result back to the device buffer
                return sendrecvbuf
            end

            function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf,
                    op::F, dev::$dType; root::Int) where {F}
                sendbuf_ = sendbuf |> cpu
                recvbuf_ = recvbuf |> cpu
                DistributedUtils.__reduce!(backend, sendbuf_, recvbuf_, op, cpu; root)
                copyto!(recvbuf, recvbuf_)  # copy the reduced result back to the device buffer
                return recvbuf
            end
        end
    end
end

end # module FluxMPIExt
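
For context, a minimal end-to-end sketch of how this backend is driven from user code. The names `DistributedUtils.initialize`, `get_distributed_backend`, `synchronize!!`, and `DistributedOptimizer` are assumed from the guide updated in this PR (docs/src/guide/gpu.md), not defined in this file; treat the snippet as illustrative, launched with one process per GPU, e.g. `mpiexecjl -n 2 julia --project train.jl`:

    using Flux, MPI, CUDA, Optimisers

    DistributedUtils.initialize(MPIBackend)
    backend = DistributedUtils.get_distributed_backend(MPIBackend)

    model = Chain(Dense(4 => 8, tanh), Dense(8 => 1)) |> gpu
    model = DistributedUtils.synchronize!!(backend, model)    # every rank starts from the same weights

    opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(1f-3))
    st_opt = Optimisers.setup(opt, model)
    st_opt = DistributedUtils.synchronize!!(backend, st_opt)  # optimiser state in sync as well

The `DistributedOptimizer` averages gradients across workers before each update via the allreduce path above: in place on the GPU when MPI is CUDA/ROCm-aware, otherwise through the CPU round-trips defined in this file.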