From 0063bc046aecc3f915e3d09d240112fcc23a9b2f Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 20 Aug 2023 23:40:44 +0530 Subject: [PATCH] Replacing type checks for `nothing` with equality checks. --- ext/FluxAMDGPUExt/FluxAMDGPUExt.jl | 2 +- ext/FluxAMDGPUExt/functor.jl | 2 +- ext/FluxCUDAExt/FluxCUDAExt.jl | 2 +- ext/FluxCUDAExt/functor.jl | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index 5606f2db01..a199cd270e 100644 --- a/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -18,7 +18,7 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) function (device::Flux.FluxAMDDevice)(x) - if typeof(device.deviceID) <: Nothing + if device.deviceID === nothing Flux.gpu(Flux.FluxAMDAdaptor(), x) else return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer diff --git a/ext/FluxAMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl index 4e79d0b01e..62507122db 100644 --- a/ext/FluxAMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -1,6 +1,6 @@ # Convert Float64 to Float32, but preserve Float16. function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray) - if typeof(to.ordinal) <: Nothing + if to.ordinal === nothing if (typeof(x) <: AbstractArray{Float16, N} where N) N = length(size(x)) return isbits(x) ? 
x : ROCArray{Float16, N}(x) diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl index adb2c7af66..3fcdc5c263 100644 --- a/ext/FluxCUDAExt/FluxCUDAExt.jl +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -15,7 +15,7 @@ import Adapt: adapt_storage const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing) function (device::Flux.FluxCUDADevice)(x) - if typeof(device.deviceID) <: Nothing + if device.deviceID === nothing return Flux.gpu(Flux.FluxCUDAAdaptor(), x) else return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x) diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index b98e3401e3..5a4c1d1152 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -1,6 +1,6 @@ adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray) - typeof(to.ordinal) <: Nothing && return CUDA.cu(x) + to.ordinal === nothing && return CUDA.cu(x) # remember current device old_ordinal = CUDA.device().handle