Use better softmax #7

Open
mossr opened this issue Nov 8, 2023 · 3 comments

Comments

@mossr
Member

mossr commented Nov 8, 2023

Use the softmax function instead of hand-writing it, and change the weight to 1/w (catching w=0 so that it returns the min/max instead of the smoothmin/smoothmax).

@mossr
Member Author

mossr commented Nov 22, 2023

Implemented logsumexp for smoothmin and smoothmax, added unit tests, and changed the weight so w=0 returns the min/max and w → ∞ returns the mean (so the weight in the smoothmin/smoothmax uses 1/w).
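
For readers following along, here is a minimal sketch of a logsumexp-based smoothmax/smoothmin with the limiting behavior described above. It is not the code from the commit: the names `stable_logsumexp`, `smoothmax_sketch`, and `smoothmin_sketch` are hypothetical, and normalizing by log(n) so that large w approaches the mean is an assumption about one way to obtain that limit.

```julia
# Hypothetical helper names; a sketch, not the package's implementation.

# Numerically stable log(sum(exp.(z))): shift by the maximum before exponentiating
# so that exp never overflows.
function stable_logsumexp(z)
    m = maximum(z)
    return m + log(sum(exp.(z .- m)))
end

# Smooth maximum with weight w (used as 1/w inside the exponent).
# w == 0 falls back to the hard maximum; as w grows, the value approaches the mean.
# Subtracting log(n) is an assumed normalization that yields the mean in the limit.
function smoothmax_sketch(x; w=1.0)
    w == 0 && return maximum(x)
    return w * (stable_logsumexp(x ./ w) - log(length(x)))
end

# Smooth minimum via the identity min(x) = -max(-x).
smoothmin_sketch(x; w=1.0) = -smoothmax_sketch(-1 .* x; w=w)

x = [-0.25, 0.0, 0.1, 0.6, 0.75, 1.0]
smoothmax_sketch(x; w=0)     # hard maximum
smoothmax_sketch(x; w=1e6)   # close to the mean of x
```

With a small weight such as w=1e-3, the shifted form still evaluates cleanly, whereas a naive `log(sum(exp.(x ./ w)))` would overflow.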

Currently, a strange issue with Zygote prevents taking the gradient of the stable version when the formula uses an interval:

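using SignalTemporalLogic  # package named in the stack trace

x = [-0.25, 0, 0.1, 0.6, 0.75, 1.0]
ϕ_eventually = @formula ◊([3,5], xₜ -> μ(xₜ) > 0.5)  # "eventually" over the interval [3,5]
∇ρ̃(x, ϕ_eventually)  # gradient of the smooth robustness; raises the error below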

Error:

MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(, :I), Tuple{NamedTuple{(, :c), Tuple{Nothing, Float64}}, Nothing}})

Closest candidates are:
+(::Any, ::Any, !Matched::Any, !Matched::Any...)
@ Base operators.jl:578
+(!Matched::ChainRulesCore.Tangent{P}, ::P) where P
@ ChainRulesCore E:\.julia\packages\ChainRulesCore\7MWx2\src\tangent_arithmetic.jl:146
+(!Matched::ChainRulesCore.AbstractThunk, ::Any)
@ ChainRulesCore E:\.julia\packages\ChainRulesCore\7MWx2\src\tangent_arithmetic.jl:122
...
accum(::Base.RefValue{Any}, ::NamedTuple{(, :I), Tuple{NamedTuple{(, :c), Tuple{Nothing, Float64}}, Nothing}})@lib.jl:17
Pullback@stl.jl:360[inlined]
(::Zygote.Pullback{Tuple{SignalTemporalLogic.var"##ρ̃#127", Float64, typeof(SignalTemporalLogic.ρ̃), Vector{Float64}, SignalTemporalLogic.Eventually}, Tuple{Zygote.Pullback{Tuple{typeof(SignalTemporalLogic.get_interval), SignalTemporalLogic.Eventually, Vector{Float64}}, Any}, Zygote.Pullback{Tuple{typeof(Core.kwcall), NamedTuple{(:w,), Tuple{Float64}}, typeof(SignalTemporalLogic.smoothmax), Base.Generator{UnitRange{Int64}, SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, SignalTemporalLogic.Eventually}}}, Any}, Zygote.Pullback{Tuple{Type{NamedTuple{(:w,)}}, Tuple{Float64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:w,), Tuple{Float64}}}, Tuple{Float64}}, Tuple{Zygote.var"#2224#back#315"{Zygote.Jnew{NamedTuple{(:w,), Tuple{Float64}}, Nothing, true}}}}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.Pullback{Tuple{Type{Base.Generator}, SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, SignalTemporalLogic.Eventually}, UnitRange{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{UnitRange{Int64}, SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, SignalTemporalLogic.Eventually}}}, SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, SignalTemporalLogic.Eventually}, UnitRange{Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(convert), Type{UnitRange{Int64}}, UnitRange{Int64}}, Any}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(convert), Type{SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, SignalTemporalLogic.Eventually}}, SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, SignalTemporalLogic.Eventually}}, Tuple{}}, Zygote.var"#2214#back#313"{Zygote.Jnew{Base.Generator{UnitRange{Int64}, SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, SignalTemporalLogic.Eventually}}, Nothing, false}}}}}}, Zygote.var"#2214#back#313"{Zygote.Jnew{SignalTemporalLogic.var"#128#129"{Float64, Vector{Float64}, 
SignalTemporalLogic.Eventually}, Nothing, false}}}})(::Float64)@interface2.jl:0

Therefore, the stable version is turned off by default: _smoothmin(x, w; stable=false), etc.

@mossr
Member Author

mossr commented Nov 22, 2023

Reference commit 7823805 (incorrectly labeled as issue 6)

@mykelk
Member

mykelk commented Nov 22, 2023

Would you want to use functionality from https://github.com/JuliaStats/LogExpFunctions.jl/?

Looking at their source code, it looks like they have some tricks to make it work with automatic differentiation.
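
As background on why a hand-written derivative rule is feasible here: the gradient of logsumexp is the softmax of its inputs, so the backward pass can be stated in closed form rather than traced through the implementation. The sketch below only checks that identity against finite differences. The helper names are hypothetical, and it does not claim to reproduce how LogExpFunctions.jl or Zygote handle this internally.

```julia
# Hypothetical helper names for illustration only.

# Stable log(sum(exp.(z))) using the max-shift trick.
function logsumexp_sketch(z)
    m = maximum(z)
    return m + log(sum(exp.(z .- m)))
end

# Analytic gradient of logsumexp: the softmax of the inputs.
function softmax_sketch(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Central finite differences, used only to check the analytic gradient.
function finite_difference_gradient(f, z; h=1e-6)
    g = similar(z)
    for i in eachindex(z)
        zp = copy(z); zp[i] += h
        zm = copy(z); zm[i] -= h
        g[i] = (f(zp) - f(zm)) / (2h)
    end
    return g
end

x = [-0.25, 0.0, 0.1, 0.6, 0.75, 1.0]
analytic = softmax_sketch(x)
numeric = finite_difference_gradient(logsumexp_sketch, x)
maximum(abs.(analytic .- numeric))  # largest discrepancy between the two gradients
```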
