Skip to content

Commit

Permalink
Adding docstring for new get_device method.
Browse files Browse the repository at this point in the history
  • Loading branch information
codetalker7 committed Aug 18, 2023
1 parent 581d74c commit fafef86
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ julia> Flux.supported_devices()
supported_devices() = GPU_BACKENDS

"""
Flux.get_device(; verbose=false)::AbstractDevice
Flux.get_device(; verbose=false)::Flux.AbstractDevice
Returns a `device` object for the most appropriate backend for the current Julia session.
Expand Down Expand Up @@ -656,6 +656,40 @@ function get_device(; verbose=false)::AbstractDevice
end
end

"""
Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice
Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values
of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`.
# Examples
```julia-repl
julia> using Flux, CUDA;
julia> CUDA.devices()
CUDA.DeviceIterator() for 3 devices:
0. GeForce RTX 2080 Ti
1. GeForce RTX 2080 Ti
2. TITAN X (Pascal)
julia> device0 = Flux.get_device("CUDA", 0)
(::Flux.FluxCUDADevice) (generic function with 1 method)
julia> device0.deviceID
CuDevice(0): GeForce RTX 2080 Ti
julia> device1 = Flux.get_device("CUDA", 1)
(::Flux.FluxCUDADevice) (generic function with 1 method)
julia> device1.deviceID
CuDevice(1): GeForce RTX 2080 Ti
julia> cpu_device = Flux.get_device("CPU")
(::Flux.FluxCPUDevice) (generic function with 1 method)
```
"""
function get_device(backend::String, ordinal::Int = 0)
if backend == "CPU"
return FluxCPUDevice()
Expand Down

0 comments on commit fafef86

Please sign in to comment.