From 004ab31702b1f563cc221c66f3bb7d737a2073d4 Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 20 Nov 2024 14:56:00 +0800 Subject: [PATCH] norm by sqrt(nfft) or window l2 energy --- src/layer/inversespectrogram.cpp | 26 ++++++++++++++++++--- src/layer/inversespectrogram.h | 2 +- src/layer/spectrogram.cpp | 28 ++++++++++++++++++----- src/layer/spectrogram.h | 2 +- tools/pnnx/tests/ncnn/test_torch_istft.py | 2 +- tools/pnnx/tests/ncnn/test_torch_stft.py | 2 +- tools/pnnx/tests/test_torch_istft.py | 2 +- tools/pnnx/tests/test_torch_stft.py | 2 +- 8 files changed, 51 insertions(+), 15 deletions(-) diff --git a/src/layer/inversespectrogram.cpp b/src/layer/inversespectrogram.cpp index bc17ab303ef..3c45a3dfe65 100644 --- a/src/layer/inversespectrogram.cpp +++ b/src/layer/inversespectrogram.cpp @@ -34,7 +34,7 @@ int InverseSpectrogram::load_param(const ParamDict& pd) // assert winlen <= n_fft // generate window - window_data.create(n_fft); + window_data.create(normalized == 2 ? n_fft + 1 : n_fft); { float* p = window_data; for (int i = 0; i < (n_fft - winlen) / 2; i++) @@ -69,6 +69,17 @@ int InverseSpectrogram::load_param(const ParamDict& pd) { *p++ = 0.f; } + + // pre-calculated window norm factor + if (normalized == 2) + { + float sqsum = 0.f; + for (int i = 0; i < n_fft; i++) + { + sqsum += window_data[i] * window_data[i]; + } + window_data[n_fft] = sqrt(sqsum); + } } return 0; @@ -134,11 +145,20 @@ int InverseSpectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Opt } } - if (normalized) + if (normalized == 1) + { + float norm = sqrt(n_fft); + for (int i = 0; i < 2 * n_fft; i++) + { + sp[i] *= norm; + } + } + if (normalized == 2) { + float norm = window_data[n_fft]; for (int i = 0; i < 2 * n_fft; i++) { - sp[i] *= sqrt(winlen); + sp[i] *= norm; } } diff --git a/src/layer/inversespectrogram.h b/src/layer/inversespectrogram.h index 097e2ff28df..969868d1540 100644 --- a/src/layer/inversespectrogram.h +++ b/src/layer/inversespectrogram.h @@ -35,7 +35,7 @@ class InverseSpectrogram : public Layer int winlen; int window_type; // 0=ones 1=hann 2=hamming int center; - int normalized; + int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy Mat window_data; }; diff --git a/src/layer/spectrogram.cpp b/src/layer/spectrogram.cpp index b8a3220ca74..9f3c35b779a 100644 --- a/src/layer/spectrogram.cpp +++ b/src/layer/spectrogram.cpp @@ -36,7 +36,7 @@ int Spectrogram::load_param(const ParamDict& pd) // assert winlen <= n_fft // generate window - window_data.create(n_fft); + window_data.create(normalized == 2 ? n_fft + 1 : n_fft); { float* p = window_data; for (int i = 0; i < (n_fft - winlen) / 2; i++) @@ -71,6 +71,17 @@ int Spectrogram::load_param(const ParamDict& pd) { *p++ = 0.f; } + + // pre-calculated window norm factor + if (normalized == 2) + { + float sqsum = 0.f; + for (int i = 0; i < n_fft; i++) + { + sqsum += window_data[i] * window_data[i]; + } + window_data[n_fft] = 1.f / sqrt(sqsum); + } } return 0; @@ -139,12 +150,17 @@ int Spectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op im -= v * sin(angle); // + imag * cos(angle); } - // fprintf(stderr, "%.2f %.2f %.2f %.2f\n", re, im, magnitude, power); - - if (normalized) + if (normalized == 1) + { + float norm = 1.f / sqrt(n_fft); + re *= norm; + im *= norm; + } + if (normalized == 2) { - re /= sqrt(winlen); - im /= sqrt(winlen); + float norm = window_data[n_fft]; + re *= norm; + im *= norm; } if (power == 0) diff --git a/src/layer/spectrogram.h b/src/layer/spectrogram.h index e7f80b58da9..4b2db08a581 100644 --- a/src/layer/spectrogram.h +++ b/src/layer/spectrogram.h @@ -36,7 +36,7 @@ class Spectrogram : public Layer int window_type; // 0=ones 1=hann 2=hamming int center; int pad_type; // 0=CONSTANT 1=REPLICATE 2=REFLECT - int normalized; + int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy int onesided; Mat window_data; diff --git a/tools/pnnx/tests/ncnn/test_torch_istft.py b/tools/pnnx/tests/ncnn/test_torch_istft.py index d3b988d09d9..414b3b21765 100644 --- a/tools/pnnx/tests/ncnn/test_torch_istft.py +++ b/tools/pnnx/tests/ncnn/test_torch_istft.py @@ -25,7 +25,7 @@ def forward(self, x, y, z, w): y = torch.view_as_complex(y) z = torch.view_as_complex(z) w = torch.view_as_complex(w) - out0 = torch.istft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=False) + out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False) out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False) out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False) out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True) diff --git a/tools/pnnx/tests/ncnn/test_torch_stft.py b/tools/pnnx/tests/ncnn/test_torch_stft.py index 550d4ccf854..896cd133053 100644 --- a/tools/pnnx/tests/ncnn/test_torch_stft.py +++ b/tools/pnnx/tests/ncnn/test_torch_stft.py @@ -21,7 +21,7 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y): - out0 = torch.stft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=True) + out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True) out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True) out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True) out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True) diff --git a/tools/pnnx/tests/test_torch_istft.py b/tools/pnnx/tests/test_torch_istft.py index 0e0763c93b4..771217c70d2 100644 --- a/tools/pnnx/tests/test_torch_istft.py +++ b/tools/pnnx/tests/test_torch_istft.py @@ -21,7 +21,7 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): - out0 = torch.istft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=False) + out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False) out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False) out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False) out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True) diff --git a/tools/pnnx/tests/test_torch_stft.py b/tools/pnnx/tests/test_torch_stft.py index b90c3c01be2..3ec307294a5 100644 --- a/tools/pnnx/tests/test_torch_stft.py +++ b/tools/pnnx/tests/test_torch_stft.py @@ -21,7 +21,7 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y): - out0 = torch.stft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=True) + out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True) out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True) out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True) out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True)