Skip to content

Commit

Permalink
Fix NaN gradient in spectrogram (#617)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Dec 22, 2024
1 parent 81e6cd1 commit be9c1c8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/audio/spectrogram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ See [`stft`](@ref) for other arguments.
Spectrogram in the shape `(T, F, B)`, where
`T` is the number of window hops and `F = n_fft ÷ 2 + 1`.
"""
function spectrogram(waveform;
function spectrogram(waveform::AbstractArray{T};
pad::Int = 0, n_fft::Int, hop_length::Int, window,
center::Bool = true, power::Real = 2.0,
normalized::Bool = false, window_normalized::Bool = false,
)
) where T
pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);)

# Pack batch dimensions.
Expand All @@ -41,8 +41,8 @@ function spectrogram(waveform;
window_normalized && (spec = spec .* inv(norm(window));)

if power > 0
p = eltype(waveform)(power)
spec = abs.(spec).^p
p = T(power)
spec = abs.(spec .+ eps(T)).^p
end
return spec
end
Expand Down
12 changes: 11 additions & 1 deletion test/testsuite/spectral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,19 @@ function spectral_testsuite(Backend)
spec = spectrogram(x;
n_fft=1024, hop_length=128, window,
center=true, normalized=false)

@test abs.(y).^2 spec

# Gradient with `0`s in spectrogram.
# We add small ϵ to spectrogram before computing power
# to prevent `NaN` in gradient due to `abs(0)`.
x = device(ones(Float32, 1024))
g = Zygote.gradient(x) do x
sum(spectrogram(x;
n_fft=1024, hop_length=128, window,
center=true, normalized=false))
end
@test !any(isnan.(g[1]))

# Batched.
x = device(rand(Float32, 1024, 3))
spec = spectrogram(x;
Expand Down

0 comments on commit be9c1c8

Please sign in to comment.