Skip to content

Commit

Permalink
Update softmax.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Mar 23, 2024
1 parent 050b835 commit e470cee
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
if all(isfinite, max_)
@fastmath out .= exp.(x .- max_)
else
@fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
_zero, _one, _inf = T(0), T(1), T(Inf)
@fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
end
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
out ./= tmp
Expand Down

0 comments on commit e470cee

Please sign in to comment.