
Commit

feat: add a public version of OOP activation
avik-pal committed Jul 31, 2024
1 parent 875f419 commit 54904a3
Showing 3 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/LuxLib.jl
@@ -57,7 +57,7 @@ include("deprecations.jl")

export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
export fused_dense_bias_activation, fused_conv_bias_activation
export fast_activation!!
export fast_activation, fast_activation!!
export bias_activation, bias_activation!!

end
23 changes: 23 additions & 0 deletions src/api/activation.jl
@@ -39,3 +39,26 @@ function _fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F}
_fast_activation!(σ, x)
return x
end

"""
fast_activation(σ::F, x::AbstractArray) where {F}
Compute `σ.(x)` with the best possible implementation available. On CPUs we unroll the
loop and use LoopVectorization.jl to vectorize the computation. On GPUs we use simply use
broadcasting.
!!! note
This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be
done by the user if needed.
## Arguments
- `σ`: Activation function
- `x`: Input array
## Returns
- Output Array with the same size as `x`
"""
fast_activation::F, x::AbstractArray) where {F} = _fast_activation(σ, x)
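
For context, a minimal usage sketch of the new out-of-place API (not part of the commit; `relu` and `fast_act` come from NNlib, and the input values are illustrative):

```julia
using LuxLib, NNlib  # NNlib provides `relu` and `fast_act`

x = randn(Float32, 4, 3)

# Out-of-place: `x` is left untouched and a fresh array is returned.
y = fast_activation(relu, x)
@assert y ≈ relu.(x)

# In-place variant exported alongside it; may overwrite its input when safe.
y2 = fast_activation!!(relu, copy(x))
@assert y2 ≈ y

# As the docstring notes, swapping in a faster activation (e.g. `tanh_fast`)
# is left to the caller via `NNlib.fast_act`.
σ = NNlib.fast_act(tanh)
z = fast_activation(σ, x)
```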
14 changes: 14 additions & 0 deletions test/common_ops/activation_tests.jl
@@ -3,6 +3,7 @@

apply_act(f::F, x) where {F} = sum(abs2, f.(x))
apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x)))
apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x))

@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
@@ -13,26 +14,39 @@

y1 = apply_act(f, x)
y2 = apply_act_fast(f, x)
y3 = apply_act_fast2(f, x)

fp16 = T == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3

@test y1 ≈ y2 atol=atol rtol=rtol
@test y1 ≈ y3 atol=atol rtol=rtol
@test eltype(y1) == T
@test eltype(y2) == T
@test eltype(y3) == T

@test @inferred(apply_act(f, x)) isa Any
@test @inferred(apply_act_fast(f, x)) isa Any
@test @inferred(apply_act_fast2(f, x)) isa Any

@jet apply_act_fast(f, x)
@jet apply_act_fast2(f, x)

@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
@test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any

test_gradients(Base.Fix1(apply_act, f), x; atol, rtol)
test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol)
test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol)

∂x1 = Zygote.gradient(apply_act, f, x)[2]
∂x2 = Zygote.gradient(apply_act_fast, f, x)[2]
∂x3 = Zygote.gradient(apply_act_fast2, f, x)[2]

@test ∂x1 ≈ ∂x2 atol=atol rtol=rtol
@test ∂x1 ≈ ∂x3 atol=atol rtol=rtol
end
end
end
