diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d34f1475..5b8d971c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,7 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} + name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} - ${{ matrix.loopvec }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -43,27 +43,49 @@ jobs: - "others" blas_backend: - "default" + loopvec: + - "true" include: - os: ubuntu-latest test_group: "dense" blas_backend: "blis" version: "1.10" + loopvec: "true" - os: ubuntu-latest test_group: "dense" blas_backend: "mkl" version: "1.10" + loopvec: "true" + - os: ubuntu-latest + test_group: "dense" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "batched_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "other_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" - os: macos-latest test_group: "dense" blas_backend: "appleaccelerate" version: "1.10" + loopvec: "true" - os: macos-latest test_group: "all" blas_backend: "default" version: "1.10" + loopvec: "true" - os: windows-latest test_group: "all" blas_backend: "default" version: "1.10" + loopvec: "true" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -84,6 +106,7 @@ jobs: env: LUXLIB_TEST_GROUP: ${{ matrix.test_group }} LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} + LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/Project.toml b/Project.toml index 5598564a..7225334c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and 
contributors"] -version = "1.3.3" +version = "1.3.4" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -15,16 +15,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -36,7 +34,10 @@ BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" @@ -46,7 +47,10 @@ LuxLibBLISBLASExt = "BLISBLAS" LuxLibCUDAExt = "CUDA" LuxLibMKLExt = "MKL" LuxLibEnzymeExt = "Enzyme" +LuxLibLoopVectorizationExt = "LoopVectorization" +LuxLibOctavianExt = ["Octavian", "LoopVectorization"] LuxLibReverseDiffExt = "ReverseDiff" +LuxLibSLEEFPiratesExt = "SLEEFPirates" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] LuxLibTrackerExt = "Tracker" LuxLibcuDNNExt = 
["CUDA", "cuDNN"] @@ -75,6 +79,7 @@ MLDataDevices = "1.2" Markdown = "1.10" NNlib = "0.9.24" Octavian = "0.3.28" +Preferences = "1.4.3" Polyester = "0.7.15" Random = "1.10" Reexport = "1" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 7fe762e6..b9a9db67 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -1,9 +1,11 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/benchmarks/runbenchmarks.jl b/benchmarks/runbenchmarks.jl index 7313b7c2..6035c8b2 100644 --- a/benchmarks/runbenchmarks.jl +++ b/benchmarks/runbenchmarks.jl @@ -3,6 +3,7 @@ using Pkg using BenchmarkTools using InteractiveUtils using LinearAlgebra +using Octavian, LoopVectorization const SUITE = BenchmarkGroup() BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5 diff --git a/ext/LuxLibLoopVectorizationExt.jl b/ext/LuxLibLoopVectorizationExt.jl new file mode 100644 index 00000000..87a912be --- /dev/null +++ b/ext/LuxLibLoopVectorizationExt.jl @@ -0,0 +1,72 @@ +module LuxLibLoopVectorizationExt + +using LoopVectorization: LoopVectorization, @tturbo, @turbo, indices +using Polyester: @batch +using Static: True + +using LuxLib: LuxLib, Utils + +Utils.is_extension_loaded(::Val{:LoopVectorization}) = True() + +Utils.can_loopvec_args_check(::True, args...) = LoopVectorization.check_args(args...) + +# matmul +for serial in (true, false) + opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! 
+ @eval @inline function LuxLib.Impl.$(opname)( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + if !iszero(β) # Special case this because Base.FastMath.mul_fast(NaN, false) = NaN + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + β * C[J, K] + end + else + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + end + end + end +end + +@inline function LuxLib.Impl.matmuladd_loopvec!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + @tturbo for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = bias[J] + Cⱼₖ + end + return +end + +# batched matmul +function LuxLib.Impl.batched_matmul_loopvec_impl!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT} + if size(x, 3) == size(y, 3) + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, L), α, β) + end + elseif size(x, 3) == 1 + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, 1), Utils.batchview(y, L), α, β) + end + else # has to be size(y, 3) == 1 + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, 1), α, β) + end + end +end + +end diff --git a/ext/LuxLibOctavianExt.jl b/ext/LuxLibOctavianExt.jl new file mode 100644 index 00000000..a112fa94 --- /dev/null +++ b/ext/LuxLibOctavianExt.jl @@ -0,0 +1,16 @@ +module LuxLibOctavianExt + +using Octavian: Octavian +using Static: True + +using 
LuxLib: LuxLib, Utils + +Utils.is_extension_loaded(::Val{:Octavian}) = True() + +@inline function LuxLib.Impl.matmul_octavian!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + Octavian.matmul!(C, A, B, α, β) + return +end + +end diff --git a/ext/LuxLibSLEEFPiratesExt.jl b/ext/LuxLibSLEEFPiratesExt.jl new file mode 100644 index 00000000..6c522b2b --- /dev/null +++ b/ext/LuxLibSLEEFPiratesExt.jl @@ -0,0 +1,58 @@ +module LuxLibSLEEFPiratesExt + +using ChainRulesCore: ChainRulesCore +using NNlib: NNlib +using SLEEFPirates: SLEEFPirates + +using LuxLib: Numeric, Impl + +const CRC = ChainRulesCore + +sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) +softplus(x::Number) = SLEEFPirates.softplus(x) +logsigmoid(x::Number) = -softplus(-x) +swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) +lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) +tanh(x::Number) = SLEEFPirates.tanh(x) +tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) + +for (f, dfdx) in [ + #! format: off + (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), + (:softplus, :(sigmoid_fast(x))), + (:logsigmoid, :(sigmoid_fast(-x))), + (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), + (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) + #! 
format: on +] + @eval CRC.@scalar_rule($f(x), $(dfdx)) + + ∇f = Symbol(:∇broadcasted_, f) + @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), + x::Union{Numeric, Broadcast.Broadcasted}) + Ω = $(f).(x) + function $(∇f)(dΩ) + ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx))) + return CRC.NoTangent(), CRC.NoTangent(), ∂x + end + return Ω, $(∇f) + end +end + +for (fbase, ffast) in [ + #! format: off + (NNlib.sigmoid_fast, sigmoid_fast), + (NNlib.softplus, softplus), + (NNlib.logsigmoid, logsigmoid), + (NNlib.swish, swish), + (NNlib.lisht, lisht), + (Base.tanh, tanh), + (NNlib.tanh_fast, tanh_fast) + #! format: on +] + @eval Impl.sleefpirates_fast_act(::typeof($fbase)) = $ffast +end + +end diff --git a/src/LuxLib.jl b/src/LuxLib.jl index 05c77f60..f0e5ca70 100644 --- a/src/LuxLib.jl +++ b/src/LuxLib.jl @@ -1,6 +1,7 @@ module LuxLib using Compat: @compat +using Preferences: @load_preference using Reexport: @reexport using Static: Static, known @@ -15,6 +16,8 @@ const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} const ∂∅ = NoTangent() const CRC = ChainRulesCore +const DISABLE_LOOP_VECTORIZATION = @load_preference("disable_loop_vectorization", false) + include("utils.jl") include("traits.jl") include("impl/Impl.jl") diff --git a/src/api/activation.jl b/src/api/activation.jl index 9ef1c544..df44aa0c 100644 --- a/src/api/activation.jl +++ b/src/api/activation.jl @@ -10,7 +10,7 @@ generic implementation. This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be done by the user if needed. -!!! tip +!!! tip "Load `SLEEFPirates.jl` to get faster activations" Certain activation functions are replaced with specialized implementations from [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl) for FP32. 
This might diff --git a/src/api/batched_mul.jl b/src/api/batched_mul.jl index a5d7b132..c6cb379a 100644 --- a/src/api/batched_mul.jl +++ b/src/api/batched_mul.jl @@ -4,6 +4,11 @@ Computes the batched matrix multiplication of `x` and `y`. For more details see the NNlib documentation on `NNlib.batched_mul`. This function is mostly a wrapper around `batched_mul` but attempts to be faster on CPUs. + +!!! tip "Load `LoopVectorization.jl` to get faster batched matrix multiplication" + + On CPUs loading LoopVectorization adds faster implementations of batched matrix + multiplication. """ function batched_matmul(x::AbstractMatrix, y::AbstractArray{yT, 3}) where {yT} return batched_matmul(expand_batchdim(x), y) diff --git a/src/api/dense.jl b/src/api/dense.jl index 0e83dac7..f51b2518 100644 --- a/src/api/dense.jl +++ b/src/api/dense.jl @@ -24,6 +24,11 @@ multiple operations. - For small CPU Arrays, we use LoopVectorization.jl. On `x86_64` we use Octavian for medium sized matrices. This is overridden if special BLAS implementations are loaded (currently `MKL`, `AppleAccelerate`, and `BLISBLAS`). + +!!! tip "Load `Octavian.jl`" + + Loading `Octavian.jl` enables a polyalgorithm that uses different backends based on the + input sizes. """ function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} diff --git a/src/impl/Impl.jl b/src/impl/Impl.jl index 8956a639..b6a6a0d9 100644 --- a/src/impl/Impl.jl +++ b/src/impl/Impl.jl @@ -12,8 +12,6 @@ using ForwardDiff: ForwardDiff using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index -using LoopVectorization: LoopVectorization, @turbo, @tturbo, indices -using Octavian: Octavian using Polyester: @batch using LinearAlgebra: LinearAlgebra, mul! 
@@ -31,7 +29,7 @@ using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, co copy_drop_gradients, eltype_mismatch, expand_batchdim, maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, - unsafe_known, unrolled_mapreduce, @enzyme_alternative + unsafe_known, unrolled_mapreduce, can_loopvec_args, @enzyme_alternative using ..Traits: activation_intermediate_not_needed, activation_has_rrule, is_mutable_array, fuse_cpu_activation using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2cache, @@ -39,7 +37,6 @@ using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2c const CRC = ChainRulesCore const KA = KernelAbstractions -const LV = LoopVectorization include("activation.jl") include("batched_mul.jl") diff --git a/src/impl/activation.jl b/src/impl/activation.jl index dfd1d0c9..0b015e3b 100644 --- a/src/impl/activation.jl +++ b/src/impl/activation.jl @@ -91,16 +91,6 @@ function activation!( return end function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} - activation_loop!(y, σ, x) - return -end - -function activation_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} - # We use fuse activation as a proxy check for "simple functions" - if LV.check_args(y, x) && unsafe_known(!fuse_cpu_activation(σ)) - LV.vmap!(σ, y, x) - return - end activation_simd_loop!(y, σ, x) return end @@ -111,8 +101,6 @@ function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where end end -@enzyme_alternative activation_loop! activation_simd_loop! 
- # Gradient for activations ∇activation(Δ, _, ::typeof(identity), x) = Δ function ∇activation(Δ, out, act::F, x) where {F} @@ -124,11 +112,11 @@ end @inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} y = similar(out) if x isa NotaNumber - @simd ivdep for i in indices((Δ, out)) + @simd ivdep for i in eachindex(Δ, out) @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] end else - @simd ivdep for i in indices((Δ, out, x)) + @simd ivdep for i in eachindex(Δ, out, x) @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] end end @@ -144,73 +132,13 @@ end select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T} - return SLEEFActivations.fast_act(f, T) + return sleefpirates_fast_act(f, T) end CRC.@non_differentiable select_fastest_activation(::Any...) -# Fast activations via SLEEFPirates.jl -module SLEEFActivations - -using ChainRulesCore: ChainRulesCore -using NNlib: NNlib -using SLEEFPirates: SLEEFPirates - -using ....LuxLib: Numeric - -const CRC = ChainRulesCore - -sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) -softplus(x::Number) = SLEEFPirates.softplus(x) -logsigmoid(x::Number) = -softplus(-x) -swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) -lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) -tanh(x::Number) = SLEEFPirates.tanh(x) -tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) - -for (f, dfdx) in [ - #! 
format: off - (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), - (:softplus, :(sigmoid_fast(x))), - (:logsigmoid, :(sigmoid_fast(-x))), - (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), - (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), - (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), - (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) - #! format: on -] - @eval CRC.@scalar_rule($f(x), $(dfdx)) - - ∇f = Symbol(:∇broadcasted_, f) - @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), - x::Union{Numeric, Broadcast.Broadcasted}) - Ω = $(f).(x) - function $(∇f)(dΩ) - ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx))) - return CRC.NoTangent(), CRC.NoTangent(), ∂x - end - return Ω, $(∇f) - end -end - -fast_act(f::F, ::Type{T}) where {F, T} = f -fast_act(f::F, ::Type{Float32}) where {F} = fast_act(f) - -for (fbase, ffast) in [ - #! format: off - (NNlib.sigmoid_fast, sigmoid_fast), - (NNlib.softplus, softplus), - (NNlib.logsigmoid, logsigmoid), - (NNlib.swish, swish), - (NNlib.lisht, lisht), - (Base.tanh, tanh), - (NNlib.tanh_fast, tanh_fast) - #! format: on -] - @eval fast_act(::typeof($fbase)) = $ffast -end -fast_act(f::F) where {F} = f - -CRC.@non_differentiable fast_act(::Any...) +sleefpirates_fast_act(f::F, ::Type{T}) where {F, T} = f +sleefpirates_fast_act(f::F, ::Type{Float32}) where {F} = sleefpirates_fast_act(f) +sleefpirates_fast_act(f::F) where {F} = f -end +CRC.@non_differentiable sleefpirates_fast_act(::Any...) 
diff --git a/src/impl/batched_mul.jl b/src/impl/batched_mul.jl index af10d57e..257b4e0f 100644 --- a/src/impl/batched_mul.jl +++ b/src/impl/batched_mul.jl @@ -50,33 +50,25 @@ end function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - if !LV.check_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) || - unsafe_known(explicit_blas_loaded()) - NNlib.batched_mul!(z, x, y) - return - end - batched_matmul_loopvec_impl!(z, x, y) + batched_matmul_cpu!(z, x, y) return end -function batched_matmul_loopvec_impl!( - z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, - y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT} - if size(x, 3) == size(y, 3) - @batch for L in indices((z, x, y), 3) - serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, L), α, β) - end - elseif size(x, 3) == 1 - @batch for L in indices((z, y), 3) - serial_matmul_loopvec!(batchview(z, L), batchview(x, 1), batchview(y, L), α, β) - end - else # has to be size(y, 3) == 1 - @batch for L in indices((z, x), 3) - serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, 1), α, β) - end +function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) && + !unsafe_known(explicit_blas_loaded()) + batched_matmul_loopvec_impl!(z, x, y) + return end + # Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983 + fallback_batched_matmul!(z, LoopedArrayOp(), x, y) + # NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed + return end +function batched_matmul_loopvec_impl! 
end + function fallback_batched_matmul( dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), @@ -88,26 +80,35 @@ end function fallback_batched_matmul!( z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ - $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ - slow." maxlog=1 + # XXX: bring back once the enzyme segfault is fixed + # @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ + # $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ + # slow." maxlog=1 + if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || (size(x, 2) != size(y, 1)) throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) end + + old_threads = maybe_reduce_BLAS_threads(z) + if size(x, 3) == size(y, 3) - Threads.@threads for L in indices((x, y), 3) + Threads.@threads for L in axes(z, 3) mul!(batchview(z, L), batchview(x, L), batchview(y, L)) end elseif size(x, 3) == 1 - Threads.@threads for L in indices((x, y), 3) + Threads.@threads for L in axes(z, 3) mul!(batchview(z, L), batchview(x, 1), batchview(y, L)) end else # has to be size(y, 3) == 1 - Threads.@threads for L in indices((x, y), 3) + Threads.@threads for L in axes(z, 3) mul!(batchview(z, L), batchview(x, L), batchview(y, 1)) end end + + reset_BLAS_threads(old_threads) + + return end function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, @@ -192,7 +193,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) 
if size(dA, 3) == 1 && size(B.val, 3) != 1 B′ = NNlib.batched_adjoint(B.val) dA′ = batchview(dA, 1) - for L in indices(B′, 3) + for L in axes(B′, 3) mul!(dA′, batchview(dC, L), batchview(B′, L), true, true) end @@ -205,7 +206,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) if size(dB, 3) == 1 && size(A.val, 3) != 1 A′ = NNlib.batched_adjoint(A.val) dB′ = batchview(dB, 1) - for L in indices(A′, 3) + for L in axes(A′, 3) mul!(dB′, batchview(A′, L), batchview(dC, L), true, true) end diff --git a/src/impl/batchnorm.jl b/src/impl/batchnorm.jl index c1e377fb..b15490f1 100644 --- a/src/impl/batchnorm.jl +++ b/src/impl/batchnorm.jl @@ -97,12 +97,12 @@ end function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) if γ === nothing && β === nothing - @simd ivdep for J in indices((γ′, β′, μ, σ²)) + @simd ivdep for J in eachindex(γ′, β′, μ, σ²) @fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) @fastmath @inbounds β′[J] = -μ[J] * γ′[J] end else - @simd ivdep for J in indices((γ′, β′, γ, β, μ, σ²)) + @simd ivdep for J in eachindex(γ′, β′, γ, β, μ, σ²) @fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) @fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J] end @@ -122,8 +122,8 @@ end @inline function apply_batchnorm_scale_bias_act_2d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} - for K in indices((x, y), 3) - @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @fastmath @inbounds y[1, J, K] = σ(x[1, J, K] * γ′[J] + β′[J]) end end @@ -132,9 +132,9 @@ end @inline function apply_batchnorm_scale_bias_act_3d_threaded_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} - @batch for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + @batch for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep 
for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) end end @@ -144,9 +144,9 @@ end @inline function apply_batchnorm_scale_bias_act_3d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} - for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) end end @@ -167,8 +167,8 @@ end @inline function apply_batchnorm_scale_bias_2d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} - for K in indices((x, y), 3) - @simd ivdep for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @fastmath @inbounds y[1, J, K] = x[1, J, K] * γ′[J] + β′[J] end end @@ -177,9 +177,9 @@ end @inline function apply_batchnorm_scale_bias_3d_threaded_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} - @batch for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + @batch for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] end end @@ -189,9 +189,9 @@ end @inline function apply_batchnorm_scale_bias_3d_serial_cpu!( y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} - for K in indices((x, y), 3) - for J in indices((x, y, γ′, β′), (2, 2, 1, 1)) - @simd ivdep for I in indices((x, y), 1) + for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] end end @@ -307,8 +307,8 @@ function ∇batchnorm_affine_normalize_cpu!( fill!(∂σ², 0) if size(∂y, 1) == 1 - 
@fastmath @inbounds for K in indices(∂y, 3) - @simd for J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3) + @simd for J in axes(∂y, 2) idenom = γ′[J] idenom² = idenom^2 @@ -320,11 +320,11 @@ function ∇batchnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3), J in axes(∂y, 2) idenom = γ′[J] idenom² = idenom^2 - @simd for I in indices(∂y, 1) + @simd for I in axes(∂y, 1) xμ = x[I, J, K] - μ[J] ∂x[I, J, K] = ∂y[I, J, K] * idenom @@ -349,8 +349,8 @@ function ∇batchnorm_affine_normalize_cpu!( fill!(∂β, 0) if size(∂y, 1) == 1 - @fastmath @inbounds for K in indices(∂y, 3) - @simd for J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3) + @simd for J in axes(∂y, 2) idenom = inv(sqrt(σ²[J] + ϵ)) idenom² = idenom^2 @@ -364,11 +364,11 @@ function ∇batchnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for K in indices(∂y, 3), J in indices(∂y, 2) + @fastmath @inbounds for K in axes(∂y, 3), J in axes(∂y, 2) idenom = inv(sqrt(σ²[J] + ϵ)) idenom² = idenom^2 - @simd for I in indices(∂y, 1) + @simd for I in axes(∂y, 1) xμ = x[I, J, K] - μ[J] ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] diff --git a/src/impl/bias_activation.jl b/src/impl/bias_activation.jl index a84fd152..f96531a7 100644 --- a/src/impl/bias_activation.jl +++ b/src/impl/bias_activation.jl @@ -194,38 +194,21 @@ end function bias_activation_cpu!(y::AbstractArray{yT, 3}, ::False, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} - if !LV.check_args(y, x, bias) - bias_activation_simd_loop!(y, σ, x, bias) - return - end - bias_activation_loop!(y, σ, x, bias) + bias_activation_simd_loop!(y, σ, x, bias) return end -function bias_activation_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, - bias::AbstractVector) where {F, xT, yT} - if size(y, 1) == 1 - @tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)) - y[1, J, K] = σ(x[1, J, K] + bias[J]) - end - else - 
@tturbo for K in indices(x, 3), J in indices((x, bias), (2, 1)), I in indices(y, 1) - y[I, J, K] = σ(x[I, J, K] + bias[J]) - end - end -end - function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} if size(y, 1) == 1 - for K in indices(x, 3) - @simd ivdep for J in indices((x, bias), (2, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @inbounds y[1, J, K] = σ(x[1, J, K] + bias[J]) end end else - for K in indices(x, 3), J in indices((x, bias), (2, 1)) - @simd ivdep for I in indices(y, 1) + for K in axes(x, 3), J in axes(x, 2) + @simd ivdep for I in axes(x, 1) @inbounds y[I, J, K] = σ(x[I, J, K] + bias[J]) end end @@ -233,8 +216,6 @@ function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractA return end -@enzyme_alternative bias_activation_loop! bias_activation_simd_loop! - function bias_add!(y::AbstractArray{yT, N}, ::AbstractInternalArrayOpMode, x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} broadcast!(+, y, x, reshape_bias(x, bias)) @@ -251,14 +232,14 @@ end function bias_add_loop!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 3}, bias::AbstractVector) where {xT, yT} if size(y, 1) == 1 - for K in indices(x, 3) - @simd ivdep for J in indices((x, bias), (2, 1)) + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) @inbounds y[1, J, K] = x[1, J, K] + bias[J] end end else - for K in indices(x, 3), J in indices((x, bias), (2, 1)) - @simd ivdep for I in indices(y, 1) + for K in axes(x, 3), J in axes(x, 2) + @simd ivdep for I in axes(y, 1) @inbounds y[I, J, K] = x[I, J, K] + bias[J] end end diff --git a/src/impl/dropout.jl b/src/impl/dropout.jl index 64d28fa5..5b424829 100644 --- a/src/impl/dropout.jl +++ b/src/impl/dropout.jl @@ -80,29 +80,16 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra p::Real, x::AbstractArray, α::Real, A::Real, B::Real) cond = similar(noise, Bool) y = similar(x, promote_type(typeof(p), 
typeof(α), typeof(A), typeof(B), eltype(x))) - if LV.check_args(noise, x, y, cond) - @tturbo for I in indices((noise, x, y, cond)) - cond[I] = noise[I] > p - y[I] = ifelse(cond[I], x[I], α) * A + B - end - else - @batch for I in indices((noise, x, y, cond)) - cond[I] = noise[I] > p - y[I] = ifelse(cond[I], x[I], α) * A + B - end + @simd ivdep for I in eachindex(noise, x, y, cond) + @inbounds cond[I] = noise[I] > p + @inbounds y[I] = ifelse(cond[I], x[I], α) * A + B end ∇alpha_dropout = let cond = cond, 𝒫x = CRC.ProjectTo(x), x = x Δ -> begin ∂x = similar(x) - if LV.check_args(∂x, cond, Δ) - @tturbo for I in indices((∂x, cond, Δ)) - ∂x[I] = cond[I] * Δ[I] * A - end - else - @batch for I in indices((∂x, cond, Δ)) - ∂x[I] = cond[I] * Δ[I] * A - end + @simd ivdep for I in eachindex(cond, Δ, ∂x) + @inbounds ∂x[I] = cond[I] * Δ[I] * A end return (ntuple(Returns(∂∅), 4)..., 𝒫x(∂x), ntuple(Returns(∂∅), 3)...) end @@ -125,29 +112,14 @@ function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode, return y, ∇alpha_dropout end -function alpha_dropout!(res::AbstractArray, ::LoopedArrayOp, noise::AbstractArray, - p::Real, x::AbstractArray, α::Real, A::Real, B::Real) - if LV.check_args(noise, x, res) - @tturbo for I in indices((noise, x, res)) - res[I] = ifelse(noise[I] > p, x[I], α) * A + B - end - else - @batch for I in indices((noise, x, res)) - res[I] = ifelse(noise[I] > p, x[I], α) * A + B - end - end -end - -function alpha_dropout_simd_loop!( +function alpha_dropout!( res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T}, p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} - @simd ivdep for I in indices((noise, x, res)) + @simd ivdep for I in eachindex(noise, x, res) res[I] = ifelse(noise[I] > p, x[I], α) * A + B end end -@enzyme_alternative alpha_dropout! alpha_dropout_simd_loop! - dropout_fptype(x) = float(real(remove_tracking(eltype(x)))) CRC.@non_differentiable dropout_fptype(::Any...) 
@@ -177,27 +149,13 @@ function generate_dropout_mask!(y::AbstractArray, ::LoopedArrayOp, p, invp) return end -function generate_dropout_mask_loop!(y::AbstractArray, p, invp) - if LV.check_args(y) - @tturbo for I in indices(y) - y[I] = (y[I] > p) * invp - end - else - @batch for I in indices(y) - y[I] = (y[I] > p) * invp - end - end -end - -function generate_dropout_mask_simd_loop!(y::AbstractArray{T}, p, invp) where {T} +function generate_dropout_mask_loop!(y::AbstractArray{T}, p, invp) where {T} p, invp = T(p), T(invp) - @simd ivdep for I in indices(y) + @simd ivdep for I in eachindex(y) y[I] = (y[I] > p) * invp end end -@enzyme_alternative generate_dropout_mask_loop! generate_dropout_mask_simd_loop! - function generate_dropout_mask!( y::AbstractArray{T}, ::AbstractInternalArrayOpMode, p, invp) where {T} p, invp = T(p), T(invp) diff --git a/src/impl/groupnorm.jl b/src/impl/groupnorm.jl index 4ebc70c3..9a64fd73 100644 --- a/src/impl/groupnorm.jl +++ b/src/impl/groupnorm.jl @@ -95,17 +95,17 @@ function groupnorm_affine_normalize_act_3d_serial_cpu!( σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - @simd ivdep for J in indices(y, 2) + @simd ivdep for J in axes(y, 2) y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - @simd for J in indices(y, 2) + @simd for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) @@ -119,22 +119,22 @@ function groupnorm_affine_normalize_act_4d_serial_cpu!( σ²::AbstractArray{σ²T, 4}, 
γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) end end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) + for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ - @simd ivdep for I in indices(y, 1) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) end end @@ -158,17 +158,17 @@ end σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - @simd ivdep for J in indices(y, 2) + @simd ivdep for J in axes(y, 2) y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - @simd for J in indices(y, 2) + @simd for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ @@ -182,22 +182,22 @@ end σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} if γ === nothing && β === nothing - @fastmath @inbounds for L in indices(y, 
4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) β′ = -μ[1, 1, K, L] * γ′ - for J in indices(y, 2) - @simd ivdep for I in indices(y, 1) + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ end end end else - @fastmath @inbounds for L in indices(y, 4), K in indices(y, 3) + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) - for J in indices(y, 2) + for J in axes(y, 2) γ′ = γ[1, J, K, 1] * idenom β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ - @simd ivdep for I in indices(y, 1) + @simd ivdep for I in axes(y, 1) y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ end end @@ -305,11 +305,11 @@ function ∇groupnorm_affine_normalize_cpu!( fill!(∂σ², 0) if size(∂y, 1) == 1 - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - @simd for J in indices(∂y, 2) + @simd for J in axes(∂y, 2) xμ = x[1, J, K, L] - μ[1, 1, K, L] ∂x[1, J, K, L] = ∂y[1, J, K, L] * idenom @@ -318,12 +318,12 @@ function ∇groupnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in indices(∂y, 2) - @simd for I in indices(∂y, 1) + for J in axes(∂y, 2) + @simd for I in axes(∂y, 1) xμ = x[I, J, K, L] - μ[1, 1, K, L] ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom @@ -349,11 +349,11 @@ function ∇groupnorm_affine_normalize_cpu!( fill!(∂β, 0) if size(∂y, 1) == 1 - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - @simd for J in indices(∂y, 2) + @simd for J in axes(∂y, 2) γ′ = γ[1, J, K, 1] * idenom xμ = x[1, J, K, L] - μ[1, 1, K, L] @@ 
-366,13 +366,13 @@ function ∇groupnorm_affine_normalize_cpu!( end end else - @fastmath @inbounds for L in indices(∂y, 4), K in indices(∂y, 3) + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) idenom² = idenom^2 - for J in indices(∂y, 2) + for J in axes(∂y, 2) γ′ = γ[1, J, K, 1] * idenom - @simd for I in indices(∂y, 1) + @simd for I in axes(∂y, 1) xμ = x[I, J, K, L] - μ[1, 1, K, L] ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ diff --git a/src/impl/matmul.jl b/src/impl/matmul.jl index 13f643bf..e202df32 100644 --- a/src/impl/matmul.jl +++ b/src/impl/matmul.jl @@ -67,7 +67,7 @@ end function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - if LV.check_args(C, A, B, bias) && fits_in_l2cache(C, A, B, bias) + if can_loopvec_args(C, A, B, bias) && fits_in_l2cache(C, A, B, bias) matmuladd_loopvec!(C, A, B, bias) return end @@ -95,7 +95,7 @@ for spl_blas in (True, False) function matmul_cpu!( # Octavian can be used C::AbstractMatrix, ::True, ::$(spl_blas), A::AbstractMatrix, B::AbstractMatrix) - if LV.check_args(C, A, B) + if can_loopvec_args(C, A, B) if fits_in_l1cache(C, A, B) matmul_loopvec!(C, A, B, true, false) return @@ -112,7 +112,7 @@ for spl_blas in (True, False) function matmul_cpu!( # Octavian cannot be used C::AbstractMatrix, ::False, ::$(spl_blas), A::AbstractMatrix, B::AbstractMatrix) - if LV.check_args(C, A, B) + if can_loopvec_args(C, A, B) if $(unsafe_known(spl_blas()) ? 
fits_in_l1cache : fits_in_l2cache)(C, A, B) matmul_loopvec!(C, A, B, true, false) return @@ -126,11 +126,6 @@ end # Low-Level Matmul implementations -- Either call libraries or implement our own # We force inlining here to avoid allocations in the inner loops -@inline function matmul_octavian!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) - Octavian.matmul!(C, A, B, α, β) - return -end # Best case fallback, we are likely going to hit BLAS @inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, @@ -141,7 +136,7 @@ end @inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{AT}, B::AbstractMatrix{BT}, α::Number, β::Number) where {T, AT, BT} - if LV.check_args(C, A, B) # Use Octavian if possible. Don't check via `use_octavian()` + if can_loopvec_args(C, A, B) && unsafe_known(is_extension_loaded(Val(:Octavian))) matmul_octavian!(C, A, B, α, β) return end @@ -163,41 +158,11 @@ end return end -for serial in (true, false) - opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! - @eval @inline function $opname( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) - if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN - @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = α * Cⱼₖ + β * C[J, K] - end - else - @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = α * Cⱼₖ - end - end - end -end +function serial_matmul_loopvec! end +function matmul_loopvec! end +function matmuladd_loopvec! 
end -@inline function matmuladd_loopvec!( - C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) - @tturbo for K in indices((C, B), 2), J in indices((C, A), 1) - Cⱼₖ = zero(eltype(C)) - for I in indices((A, B), (2, 1)) - Cⱼₖ += A[J, I] * B[I, K] - end - C[J, K] = bias[J] + Cⱼₖ - end - return -end +function matmul_octavian! end @inline function matmuladd_cpu_fallback!( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl index 9afc4cde..f9dafcdf 100644 --- a/src/impl/normalization.jl +++ b/src/impl/normalization.jl @@ -43,7 +43,7 @@ end function update_running_statistics_simd_loop!( rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) - @simd ivdep for I in indices((rμₙ, rσ²ₙ)) + @simd ivdep for I in eachindex(rμₙ, rσ²ₙ) rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] end diff --git a/src/traits.jl b/src/traits.jl index 7f660da5..29d3dc1e 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -80,6 +80,7 @@ using ChainRulesCore: ChainRulesCore using Hwloc: Hwloc using Static: static, False, True +using ..LuxLib: DISABLE_LOOP_VECTORIZATION using ..Utils: is_extension_loaded, safe_minimum const CRC = ChainRulesCore @@ -130,7 +131,14 @@ end CRC.@non_differentiable explicit_blas_loaded() -use_octavian() = is_x86_64() & (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) +@static if DISABLE_LOOP_VECTORIZATION + use_octavian() = False() +else + function use_octavian() + return is_extension_loaded(Val(:Octavian)) & is_x86_64() & + (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) + end +end CRC.@non_differentiable use_octavian() diff --git a/src/utils.jl b/src/utils.jl index 0639b5d5..0104457c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,13 +11,16 @@ using NNlib: NNlib using Static: Static, StaticBool, False, True, static using StaticArraysCore: SVector, SMatrix -using ..LuxLib: Optional, ∂∅ +using ..LuxLib: Optional, ∂∅, DISABLE_LOOP_VECTORIZATION const CRC 
= ChainRulesCore const KA = KernelAbstractions is_extension_loaded(::Val) = False() +CRC.@non_differentiable is_extension_loaded(::Any...) +EnzymeRules.inactive_noinl(::typeof(is_extension_loaded), ::Any...) = nothing + # Simple Operations -- no rrules needed ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x function ofeltype_array( @@ -322,4 +325,18 @@ end CRC.@non_differentiable static_training_mode_check(::Any...) +@static if DISABLE_LOOP_VECTORIZATION + @inline can_loopvec_args(args...) = false +else + @inline function can_loopvec_args(args...) + return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...) + end +end + +@inline can_loopvec_args_check(::False, args...) = false + +CRC.@non_differentiable can_loopvec_args_check(::Any...) + +EnzymeRules.inactive_noinl(::typeof(can_loopvec_args_check), ::Any...) = nothing + end diff --git a/test/Project.toml b/test/Project.toml index 3b238301..1005c488 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,10 +12,12 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -44,10 +46,12 @@ ForwardDiff = "0.10.36" Hwloc = "3.2" InteractiveUtils = "<0.0.1, 1" JLArrays = "0.1.5" +LoopVectorization = "0.12.171" LuxTestUtils = "1.2.1" MKL = "0.7" MLDataDevices = "1.0.0" NNlib = "0.9.21" +Octavian = "0.3.28" Pkg = "1.10" Preferences = "1.4.3" Random = "1.10" diff --git 
a/test/common_ops/activation_tests.jl b/test/common_ops/activation_tests.jl index 2045f20f..e2b80e71 100644 --- a/test/common_ops/activation_tests.jl +++ b/test/common_ops/activation_tests.jl @@ -36,7 +36,7 @@ @jet apply_act_fast2(f, x) @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any - if f !== lisht || (f === lisht && T == Float32 && !ongpu) + if f !== lisht @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any end @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any diff --git a/test/common_ops/bias_act_tests.jl b/test/common_ops/bias_act_tests.jl index 1429c9b2..3b2f22d0 100644 --- a/test/common_ops/bias_act_tests.jl +++ b/test/common_ops/bias_act_tests.jl @@ -44,12 +44,9 @@ @jet bias_act_loss2(act, x, b) @jet bias_act_loss3(act, x, b) - if (act !== lisht || (act === lisht && T == Float32 && !ongpu)) && T != Float16 + if act !== lisht && T != Float16 @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any - elseif T != Float16 - @test_broken @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test_broken @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any end @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 487a50d5..2ba51d0a 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -8,6 +8,10 @@ LuxTestUtils.jet_target_modules!(["LuxLib"]) const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) +if parse(Bool, get(ENV, "LUXLIB_LOAD_LOOPVEC", "true")) + import LoopVectorization, Octavian +end + if LUXLIB_BLAS_BACKEND == "default" @info "Using default BLAS backend: OpenBLAS" elseif LUXLIB_BLAS_BACKEND == "appleaccelerate"