From 54904a37026d0cefecb7f1a789d48079eadf3284 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 30 Jul 2024 18:11:33 -0700
Subject: [PATCH] feat: add a public version of OOP activation

---
 src/LuxLib.jl                       |  2 +-
 src/api/activation.jl               | 23 +++++++++++++++++++++++
 test/common_ops/activation_tests.jl | 14 ++++++++++++++
 3 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index 2d55589d..8f41e597 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -57,7 +57,7 @@ include("deprecations.jl")
 
 export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
 export fused_dense_bias_activation, fused_conv_bias_activation
-export fast_activation!!
+export fast_activation, fast_activation!!
 export bias_activation, bias_activation!!
 
 end
diff --git a/src/api/activation.jl b/src/api/activation.jl
index 14815593..2599f1ac 100644
--- a/src/api/activation.jl
+++ b/src/api/activation.jl
@@ -39,3 +39,26 @@ function _fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F}
     _fast_activation!(σ, x)
     return x
 end
+
+"""
+    fast_activation(σ::F, x::AbstractArray) where {F}
+
+Compute `σ.(x)` with the best possible implementation available. On CPUs we unroll the
+loop and use LoopVectorization.jl to vectorize the computation. On GPUs we simply use
+broadcasting.
+
+!!! note
+
+    This function does not replace `σ` with `NNlib.fast_act(σ, ...)`; that needs to be
+    done by the user if needed.
+
+## Arguments
+
+  - `σ`: Activation function
+  - `x`: Input array
+
+## Returns
+
+  - Output array with the same size as `x`
+"""
+fast_activation(σ::F, x::AbstractArray) where {F} = _fast_activation(σ, x)
diff --git a/test/common_ops/activation_tests.jl b/test/common_ops/activation_tests.jl
index d4af9f0f..2c99bf72 100644
--- a/test/common_ops/activation_tests.jl
+++ b/test/common_ops/activation_tests.jl
@@ -3,6 +3,7 @@
 
     apply_act(f::F, x) where {F} = sum(abs2, f.(x))
     apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x)))
+    apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x))
 
     @testset "$mode" for (mode, aType, ongpu) in MODES
         @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
@@ -13,26 +14,39 @@
 
             y1 = apply_act(f, x)
             y2 = apply_act_fast(f, x)
+            y3 = apply_act_fast2(f, x)
 
             fp16 = T == Float16
             atol = fp16 ? 1.0f-1 : 1.0f-3
             rtol = fp16 ? 1.0f-1 : 1.0f-3
 
             @test y1≈y2 atol=atol rtol=rtol
+            @test y1≈y3 atol=atol rtol=rtol
 
             @test eltype(y1) == T
+            @test eltype(y2) == T
+            @test eltype(y3) == T
 
             @test @inferred(apply_act(f, x)) isa Any
             @test @inferred(apply_act_fast(f, x)) isa Any
+            @test @inferred(apply_act_fast2(f, x)) isa Any
+            @jet apply_act_fast(f, x)
+            @jet apply_act_fast2(f, x)
+            @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
+            @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
+            @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any
 
             test_gradients(Base.Fix1(apply_act, f), x; atol, rtol)
             test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol)
+            test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol)
 
             ∂x1 = Zygote.gradient(apply_act, f, x)[2]
             ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2]
+            ∂x3 = Zygote.gradient(apply_act_fast2, f, x)[2]
 
             @test ∂x1≈∂x2 atol=atol rtol=rtol
+            @test ∂x1≈∂x3 atol=atol rtol=rtol
 
         end
     end
 end
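
Usage sketch (illustrative; not part of the diff above): a minimal example of how the new
out-of-place entry point is meant to be called alongside the existing `fast_activation!!`.
It assumes LuxLib with this patch applied plus NNlib for the `gelu`/`tanh` activations; the
array shape, variable names, and tolerances are placeholders borrowed from the tests.

    using LuxLib, NNlib

    x = randn(Float32, 4, 3)

    # Out-of-place: `x` is left untouched and a new array is returned.
    y = fast_activation(gelu, x)

    # Existing in-place-if-possible variant; it may overwrite its input, hence the copy.
    y_bang = fast_activation!!(gelu, copy(x))

    isapprox(y, y_bang; atol=1.0f-3, rtol=1.0f-3)  # true, same tolerances as the tests
    eltype(y) == Float32                           # true

    # As the docstring notes, `σ` is not rewritten automatically; opt in explicitly:
    y_tanh = fast_activation(NNlib.fast_act(tanh, x), x)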