diff --git a/Manifest.toml b/Manifest.toml index c244e94bc..3162da3d8 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,10 +1,31 @@ # This file is machine-generated - editing it directly is not advised +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + [[BinaryProvider]] -deps = ["Libdl", "SHA"] -git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" +deps = ["Libdl", "Logging", "SHA"] +git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.8" +version = "0.5.10" + +[[CompilerSupportLibraries_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612" +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.3.3+0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[LibGit2]] +deps = ["Printf"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -13,6 +34,31 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[OpenSpecFun_jll]] +deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] +git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+3" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + [[Random]] deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -23,20 +69,50 @@ git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.0.1" +[[Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "86c5647b565873641538d8f812c04e4c9dbeb370" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.6.1" + +[[Rmath_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "1660f8fefbf5ab9c67560513131d4e933012fc4b" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.2.2+0" + [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + [[SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +[[SpecialFunctions]] +deps = ["OpenSpecFun_jll"] +git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.10.3" + [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[[StatsFuns]] +deps = ["Rmath", "SpecialFunctions"] +git-tree-sha1 = "04a5a8e6ab87966b43f247920eab053fd5fdc925" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.9.5" + [[UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/Project.toml b/Project.toml index f13f49fef..6ed4dd8d8 100644 --- a/Project.toml +++ b/Project.toml @@ -8,10 +8,12 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] BinaryProvider = "0.5" Requires = "0.5, 1.0" +StatsFuns = "0.9" julia = "1" [extras] diff --git a/src/NNlib.jl b/src/NNlib.jl index 597817f66..9f364b8bd 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -1,5 +1,6 @@ module NNlib using Requires +using StatsFuns: logistic, softplus # Include APIs include("dim_helpers.jl") diff --git a/src/activation.jl b/src/activation.jl index 947dc496e..2d4105007 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -12,10 +12,7 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation function. """ -function σ(x::Real) - t = exp(-abs(x)) - ifelse(x ≥ 0, inv(one(t) + t), t / (one(t) + t)) -end +const σ = logistic const sigmoid = σ """ @@ -181,15 +178,6 @@ See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal """ softsign(x::Real) = x / (one(x) + abs(x)) - -""" - softplus(x) = log(exp(x) + 1) - -See [Deep Sparse Rectifier Neural Networks](http://proceedings.mlr.press/v15/glorot11a/glorot11a.pdf). -""" -softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x))) - - """ logcosh(x) @@ -222,7 +210,7 @@ See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_A softshrink(x::Real, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ) # Provide an informative error message if activation functions are called with an array -for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink) +for f in (:hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :logcosh, :mish, :tanhshrink, :softshrink) @eval $(f)(x::AbstractArray, args...) = error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.") end diff --git a/test/activation.jl b/test/activation.jl index 70558fc62..c8297836c 100644 --- a/test/activation.jl +++ b/test/activation.jl @@ -108,7 +108,9 @@ end @testset "Array input" begin x = rand(5) for a in ACTIVATION_FUNCTIONS - @test_throws ErrorException a(x) + if a != σ && a != softplus + @test_throws ErrorException a(x) + end end end