diff --git a/src/softmax.jl b/src/softmax.jl index 182f2fb93..6c522dac7 100644 --- a/src/softmax.jl +++ b/src/softmax.jl @@ -62,7 +62,8 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T} if all(isfinite, max_) @fastmath out .= exp.(x .- max_) else - @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_)) + _zero, _one, _inf = T(0), T(1), T(Inf) + @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_)) end tmp = dims isa Colon ? sum(out) : sum!(max_, out) out ./= tmp