This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: use LoopVectorization for faster operations #111

Merged
merged 25 commits on Aug 1, 2024
Changes from 19 commits
25 commits
362a916
test: bug fixes and use correct threads
avik-pal Jul 31, 2024
6509016
feat: use LoopVectorization for faster operations
avik-pal Jul 30, 2024
5e54f29
fix: rework matmul to use operation modes
avik-pal Jul 30, 2024
6173637
feat: add rrules for `matmul` and `matmuladd`
avik-pal Jul 30, 2024
46feb93
feat: replace mean and var with VectorizedStatistics
avik-pal Jul 30, 2024
bbb8787
feat: add EnzymeRules for `matmul!` and `matmuladd!`
avik-pal Jul 30, 2024
d5445ba
feat: add EnzymeRules for `_alpha_dropout_kernel!`
avik-pal Jul 30, 2024
4724704
feat: add EnzymeRules for `_fast_activation!`
avik-pal Jul 30, 2024
1afd91b
refactor: remove unwanted reshapes in BN impl
avik-pal Jul 30, 2024
267b81b
docs: add perf note on LV to dense
avik-pal Jul 31, 2024
73b8961
feat: add a public version of OOP activation
avik-pal Jul 31, 2024
59145df
fix: instance norm gradients with enzyme
avik-pal Jul 31, 2024
585db59
feat: bias activation enzyme rules
avik-pal Jul 31, 2024
6aa1dfe
perf: tune the impls a bit
avik-pal Jul 31, 2024
e83277e
refactor: restructure normalization functions
avik-pal Jul 31, 2024
6193da5
fix: support batchnorm and groupnorm for enzyme bypassing turbo
avik-pal Jul 31, 2024
10fa8bf
fix: dimension checks for matmul
avik-pal Jul 31, 2024
482df28
fix: error in enzyme gradient for matmul
avik-pal Aug 1, 2024
63cb023
refactor: use macro to bypass loopvectorization
avik-pal Aug 1, 2024
8c0334c
fix: run LV matmul only if check_args is true
avik-pal Aug 1, 2024
565b8f1
chore: run formatter
avik-pal Aug 1, 2024
ded928f
fix: dispatch to loopvec for groupnorm
avik-pal Aug 1, 2024
953b710
perf: upperbound LV usage
avik-pal Aug 1, 2024
eef1dc0
fix: wrong function in macro
avik-pal Aug 1, 2024
bedde7b
perf: revert upperbound LV usage
avik-pal Aug 1, 2024
6 changes: 5 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.37"
version = "0.3.38"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -12,13 +12,15 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
@@ -55,6 +57,7 @@ InteractiveUtils = "<0.0.1, 1"
JLArrays = "0.1.5"
KernelAbstractions = "0.9.22"
LinearAlgebra = "1.10"
LoopVectorization = "0.12.171"
LuxCore = "0.1.13"
LuxTestUtils = "1.1"
MLDataDevices = "1.0.0"
@@ -67,6 +70,7 @@ ReTestItems = "1.23.1"
Reexport = "1"
ReverseDiff = "1.15"
SLEEFPirates = "0.6.43"
Setfield = "1.1.1"
StableRNGs = "1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.3"
7 changes: 5 additions & 2 deletions src/LuxLib.jl
@@ -8,14 +8,16 @@ using FastClosures: @closure
using ForwardDiff: ForwardDiff
using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index
using LinearAlgebra: LinearAlgebra, BLAS, mul!
using LoopVectorization: indices, @tturbo
using LuxCore: LuxCore
using Markdown: @doc_str
using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice,
AbstractGPUDevice, AbstractDevice
using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter
using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
using StaticArraysCore: StaticArraysCore, StaticVector
using Setfield: @set!
using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector
using Statistics: Statistics, mean, var
using SLEEFPirates: SLEEFPirates
using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce
@@ -48,13 +50,14 @@ include("impl/fast_ops.jl")
include("impl/fused_dense.jl")
include("impl/fused_conv.jl")
include("impl/forward_diff.jl")
include("impl/matmul.jl")
include("impl/normalization.jl")

include("deprecations.jl")

export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
export fused_dense_bias_activation, fused_conv_bias_activation
export fast_activation!!
export fast_activation, fast_activation!!
export bias_activation, bias_activation!!

end
23 changes: 23 additions & 0 deletions src/api/activation.jl
@@ -39,3 +39,26 @@ function _fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F}
_fast_activation!(σ, x)
return x
end

"""
fast_activation(σ::F, x::AbstractArray) where {F}

Compute `σ.(x)` with the best possible implementation available. On CPUs we unroll the
loop and use LoopVectorization.jl to vectorize the computation. On GPUs we simply use
broadcasting.

!!! note

This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`; that needs to be
done by the user if needed.

## Arguments

- `σ`: Activation function
- `x`: Input array

## Returns

- Output Array with the same size as `x`
"""
fast_activation(σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x)
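
For context, here is a minimal usage sketch of the newly exported `fast_activation` (the choice of `tanh` and the array size are illustrative, not taken from the PR):

using LuxLib, NNlib

x = rand(Float32, 128, 32)

# Out-of-place activation: per the docstring above, this uses a LoopVectorization
# kernel on CPU arrays and plain broadcasting on GPU arrays.
y = fast_activation(tanh, x)
y ≈ tanh.(x)                              # same result as naive broadcasting

# As the note says, `σ` is not swapped for a faster variant automatically;
# do that explicitly if desired.
y2 = fast_activation(NNlib.fast_act(tanh, x), x)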
4 changes: 1 addition & 3 deletions src/api/dense.jl
@@ -16,15 +16,13 @@ multiple operations.

## Notes on implementation

- Despite the naming, currently only the activation (σ) is fused with the bias addition.
Currently this is equivalent to using matrix multiply followed by `NNlib.bias_act!`,
though this function doesn't call those operations.
- If any of the inputs don't support setindexing (aka immutable arrays), we fall back to
the generic non-mutating implementation.
- Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD
backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff`
fallback to the generic implementation.
- For CUDA Arrays, this uses a special fused implementation via cuBLASLt.
- For small CPU Arrays, we use LoopVectorization.jl.
"""
function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix,
b::Optional{<:AbstractVector}) where {F}
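
To make the notes above concrete, a hedged usage sketch (the names `W`, `x`, `b` and the sizes are made up for illustration):

using LuxLib, NNlib

W = randn(Float32, 32, 64)   # weight matrix
x = randn(Float32, 64, 8)    # input batch (features × batch)
b = randn(Float32, 32)       # bias vector

# Fused matmul + bias + activation. The result matches the naive composition
# below (up to floating-point tolerance when a SLEEFPirates approximation is
# selected), while taking the fast paths described above: cuBLASLt for CUDA
# arrays, LoopVectorization for small CPU arrays.
y = fused_dense_bias_activation(gelu, W, x, b)
y ≈ gelu.(W * x .+ b)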
2 changes: 1 addition & 1 deletion src/api/instancenorm.jl
@@ -45,7 +45,7 @@ end
end

function _test_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N}
N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least 2."))
N > 2 || throw(ArgumentError("`ndims(x) = $(N)` must be at least > 2."))
return nothing
end

58 changes: 51 additions & 7 deletions src/impl/activation.jl
@@ -9,8 +9,8 @@ function __activation_gradient(Δ, out, act::F, x) where {F}
@inbounds y[i] = only_derivative(out[i], act, x) * Δ[i]
end
else
@simd ivdep for i in eachindex(Δ, out, x)
@inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i]
@simd ivdep for I in eachindex(Δ, out, x)
@inbounds y[I] = only_derivative(out[I], act, x[I]) * Δ[I]
end
end
return y
@@ -19,15 +19,53 @@ function __activation_gradient(Δ, out, act::F, x) where {F}
return broadcast(only_deriv, Δ, out, x)
end

function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F}
broadcast!(σ, y, x)
return
end
function _fast_activation!(
::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F}
@tturbo for I in indices((y, x))
y[I] = σ(x[I])
end
end

function _fast_activation_no_turbo!(
::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F}
@simd ivdep for I in eachindex(y, x)
@inbounds y[I] = σ(x[I])
y[I] = σ(x[I])
end
end
function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F}
broadcast!(σ, y, x)
return

function EnzymeRules.augmented_primal(
cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(_fast_activation!)},
::Type{RT}, opmode::EnzymeCore.Const{LoopedArrayOp},
y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F},
x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT}
dx = one.(x.val)
dy = zero.(y.val)
EnzymeCore.autodiff(EnzymeCore.Forward, _fast_activation_no_turbo!,
opmode, EnzymeCore.Duplicated(y.val, dy),
EnzymeCore.Const(σ.val), EnzymeCore.Duplicated(x.val, dx))

primal = EnzymeRules.needs_primal(cfg) ? y.val : nothing
shadow = EnzymeRules.needs_shadow(cfg) ? y.dval : nothing

return EnzymeRules.AugmentedReturn(primal, shadow, (dy,))
end

function EnzymeRules.reverse(
::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(_fast_activation!)},
::Type{RT}, (dy,), opmode::EnzymeCore.Const{LoopedArrayOp},
y::EnzymeCore.Duplicated{<:AbstractArray}, σ::EnzymeCore.Const{F},
x::EnzymeCore.Duplicated{<:AbstractArray}) where {F, RT}
@tturbo for I in indices((y.dval, x.dval, dy))
x.dval[I] = y.dval[I] * dy[I]
end

x.dval !== y.dval && fill!(y.dval, false)

return nothing, nothing, nothing, nothing
end

# Entry Points to the implementation
@@ -155,11 +193,17 @@ function EnzymeRules.augmented_primal(
end

function EnzymeRules.reverse(
cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)},
::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)},
dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number})
return (dret.val * ∂gelu_sleefpirates(x.val),)
end

function EnzymeRules.forward(::EnzymeCore.Const{typeof(gelu_sleefpirates)},
::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number})
return EnzymeCore.Duplicated(
gelu_sleefpirates(x.val), x.dval * ∂gelu_sleefpirates(x.val))
end

# Convert to SLEEFPirates.jl
function select_fastest_activation(f::F, xs...) where {F}
return select_fastest_activation(
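
For readers unfamiliar with LoopVectorization.jl, here is a standalone sketch of the `@tturbo` pattern these kernels rely on; `my_activation!` is a made-up name, and the package's real kernels additionally dispatch on an operation mode and carry the Enzyme rules shown above:

using LoopVectorization: @tturbo, indices
using NNlib: relu

function my_activation!(y::AbstractArray, σ::F, x::AbstractArray) where {F}
    # `indices((y, x))` yields a common index range for `y` and `x`;
    # `@tturbo` SIMD-vectorizes and multithreads the loop body.
    @tturbo for I in indices((y, x))
        y[I] = σ(x[I])
    end
    return y
end

x = rand(Float32, 256, 16)
y = similar(x)
my_activation!(y, relu, x)   # ≈ relu.(x), computed with a vectorized loop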