Skip to content

Commit

Permalink
Replacing type checks for nothing but equality checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
codetalker7 committed Aug 20, 2023
1 parent 08b1962 commit 0063bc0
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)

function (device::Flux.FluxAMDDevice)(x)
if typeof(device.deviceID) <: Nothing
if device.deviceID === nothing
Flux.gpu(Flux.FluxAMDAdaptor(), x)
else
return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
Expand Down
2 changes: 1 addition & 1 deletion ext/FluxAMDGPUExt/functor.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Convert Float64 to Float32, but preserve Float16.
function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
if typeof(to.ordinal) <: Nothing
if to.ordinal === nothing
if (typeof(x) <: AbstractArray{Float16, N} where N)
N = length(size(x))
return isbits(x) ? x : ROCArray{Float16, N}(x)
Expand Down
2 changes: 1 addition & 1 deletion ext/FluxCUDAExt/FluxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import Adapt: adapt_storage
const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing)

function (device::Flux.FluxCUDADevice)(x)
if typeof(device.deviceID) <: Nothing
if device.deviceID === nothing
return Flux.gpu(Flux.FluxCUDAAdaptor(), x)
else
return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x)
Expand Down
2 changes: 1 addition & 1 deletion ext/FluxCUDAExt/functor.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray)
typeof(to.ordinal) <: Nothing && return CUDA.cu(x)
to.ordinal === nothing && return CUDA.cu(x)

# remember current device
old_ordinal = CUDA.device().handle
Expand Down

0 comments on commit 0063bc0

Please sign in to comment.