From be9c1c8b932110c39955b4dd41785df8ce2cd5db Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 22 Dec 2024 15:04:23 +0200 Subject: [PATCH] Fix `NaN` gradient in spectrogram (#617) --- src/audio/spectrogram.jl | 8 ++++---- test/testsuite/spectral.jl | 12 +++++++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/audio/spectrogram.jl b/src/audio/spectrogram.jl index 68198902..efee1d11 100644 --- a/src/audio/spectrogram.jl +++ b/src/audio/spectrogram.jl @@ -25,11 +25,11 @@ See [`stft`](@ref) for other arguments. Spectrogram in the shape `(T, F, B)`, where `T` is the number of window hops and `F = n_fft ÷ 2 + 1`. """ -function spectrogram(waveform; +function spectrogram(waveform::AbstractArray{T}; pad::Int = 0, n_fft::Int, hop_length::Int, window, center::Bool = true, power::Real = 2.0, normalized::Bool = false, window_normalized::Bool = false, -) +) where T pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);) # Pack batch dimensions. @@ -41,8 +41,8 @@ function spectrogram(waveform; window_normalized && (spec = spec .* inv(norm(window));) if power > 0 - p = eltype(waveform)(power) - spec = abs.(spec).^p + p = T(power) + spec = abs.(spec .+ eps(T)).^p end return spec end diff --git a/test/testsuite/spectral.jl b/test/testsuite/spectral.jl index 12e38cc4..a78a7915 100644 --- a/test/testsuite/spectral.jl +++ b/test/testsuite/spectral.jl @@ -110,9 +110,19 @@ function spectral_testsuite(Backend) spec = spectrogram(x; n_fft=1024, hop_length=128, window, center=true, normalized=false) - @test abs.(y).^2 ≈ spec + # Gradient with `0`s in spectrogram. + # We add small ϵ to spectrogram before computing power + # to prevent `NaN` in gradient due to `abs(0)`. + x = device(ones(Float32, 1024)) + g = Zygote.gradient(x) do x + sum(spectrogram(x; + n_fft=1024, hop_length=128, window, + center=true, normalized=false)) + end + @test !any(isnan.(g[1])) + # Batched. x = device(rand(Float32, 1024, 3)) spec = spectrogram(x;