Skip to content

Commit

Permalink
Use logistic function in StatsFuns
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jun 9, 2020
1 parent 0d16973 commit 882c873
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 23 deletions.
45 changes: 37 additions & 8 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[CompilerSupportLibraries_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612"
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.3+0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LibGit2]]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
Expand All @@ -38,8 +41,14 @@ git-tree-sha1 = "c3d1a616362645754b18e12dbba96ec311b0867f"
uuid = "a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee"
version = "2018.6.22+0"

[[OpenSpecFun_jll]]
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+3"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
Expand All @@ -60,6 +69,18 @@ git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.0.1"

[[Rmath]]
deps = ["Random", "Rmath_jll"]
git-tree-sha1 = "86c5647b565873641538d8f812c04e4c9dbeb370"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.6.1"

[[Rmath_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "1660f8fefbf5ab9c67560513131d4e933012fc4b"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.2.2+0"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

Expand All @@ -73,13 +94,21 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["OpenSpecFun_jll"]
git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.10.3"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions"]
git-tree-sha1 = "04a5a8e6ab87966b43f247920eab053fd5fdc925"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.5"

[[UUIDs]]
deps = ["Random", "SHA"]
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ NNPACK_jll = "a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Requires = "0.5, 1.0"
StatsFuns = "0.9"
julia = "1.3"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module NNlib
using Pkg
using Requires
using NNPACK_jll
using StatsFuns: logistic, softplus

# Include APIs
include("dim_helpers.jl")
Expand Down
16 changes: 2 additions & 14 deletions src/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
function.
"""
function σ(x::Real)
t = exp(-abs(x))
ifelse(x 0, inv(one(t) + t), t / (one(t) + t))
end
const σ = logistic
const sigmoid = σ

"""
Expand Down Expand Up @@ -181,15 +178,6 @@ See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal
"""
softsign(x::Real) = x / (one(x) + abs(x))


"""
softplus(x) = log(exp(x) + 1)
See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf).
"""
softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))


"""
logcosh(x)
Expand Down Expand Up @@ -222,7 +210,7 @@ See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_A
softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)

# Provide an informative error message if activation functions are called with an array
for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
for f in (:hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :logcosh, :mish, :tanhshrink, :softshrink)
@eval $(f)(x::AbstractArray, args...) =
error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
end
4 changes: 3 additions & 1 deletion test/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ end
@testset "Array input" begin
x = rand(5)
for a in ACTIVATION_FUNCTIONS
@test_throws ErrorException a(x)
if a != σ && a != softplus
@test_throws ErrorException a(x)
end
end
end

Expand Down

0 comments on commit 882c873

Please sign in to comment.