Skip to content

Commit

Permalink
norm by sqrt(nfft) or window l2 energy
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Nov 20, 2024
1 parent 2deea7c commit 004ab31
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 15 deletions.
26 changes: 23 additions & 3 deletions src/layer/inversespectrogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int InverseSpectrogram::load_param(const ParamDict& pd)

// assert winlen <= n_fft
// generate window
window_data.create(n_fft);
window_data.create(normalized == 2 ? n_fft + 1 : n_fft);
{
float* p = window_data;
for (int i = 0; i < (n_fft - winlen) / 2; i++)
Expand Down Expand Up @@ -69,6 +69,17 @@ int InverseSpectrogram::load_param(const ParamDict& pd)
{
*p++ = 0.f;
}

// pre-calculated window norm factor
if (normalized == 2)
{
float sqsum = 0.f;
for (int i = 0; i < n_fft; i++)
{
sqsum += window_data[i] * window_data[i];
}
window_data[n_fft] = sqrt(sqsum);
}
}

return 0;
Expand Down Expand Up @@ -134,11 +145,20 @@ int InverseSpectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Opt
}
}

if (normalized)
if (normalized == 1)
{
float norm = sqrt(n_fft);
for (int i = 0; i < 2 * n_fft; i++)
{
sp[i] *= norm;
}
}
if (normalized == 2)
{
float norm = window_data[n_fft];
for (int i = 0; i < 2 * n_fft; i++)
{
sp[i] *= sqrt(winlen);
sp[i] *= norm;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/layer/inversespectrogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class InverseSpectrogram : public Layer
int winlen;
int window_type; // 0=ones 1=hann 2=hamming
int center;
int normalized;
int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy

Mat window_data;
};
Expand Down
28 changes: 22 additions & 6 deletions src/layer/spectrogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int Spectrogram::load_param(const ParamDict& pd)

// assert winlen <= n_fft
// generate window
window_data.create(n_fft);
window_data.create(normalized == 2 ? n_fft + 1 : n_fft);
{
float* p = window_data;
for (int i = 0; i < (n_fft - winlen) / 2; i++)
Expand Down Expand Up @@ -71,6 +71,17 @@ int Spectrogram::load_param(const ParamDict& pd)
{
*p++ = 0.f;
}

// pre-calculated window norm factor
if (normalized == 2)
{
float sqsum = 0.f;
for (int i = 0; i < n_fft; i++)
{
sqsum += window_data[i] * window_data[i];
}
window_data[n_fft] = 1.f / sqrt(sqsum);
}
}

return 0;
Expand Down Expand Up @@ -139,12 +150,17 @@ int Spectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op
im -= v * sin(angle); // + imag * cos(angle);
}

// fprintf(stderr, "%.2f %.2f %.2f %.2f\n", re, im, magnitude, power);

if (normalized)
if (normalized == 1)
{
float norm = 1.f / sqrt(n_fft);
re *= norm;
im *= norm;
}
if (normalized == 2)
{
re /= sqrt(winlen);
im /= sqrt(winlen);
float norm = window_data[n_fft];
re *= norm;
im *= norm;
}

if (power == 0)
Expand Down
2 changes: 1 addition & 1 deletion src/layer/spectrogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class Spectrogram : public Layer
int window_type; // 0=ones 1=hann 2=hamming
int center;
int pad_type; // 0=CONSTANT 1=REPLICATE 2=REFLECT
int normalized;
int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy
int onesided;

Mat window_data;
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/tests/ncnn/test_torch_istft.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def forward(self, x, y, z, w):
y = torch.view_as_complex(y)
z = torch.view_as_complex(z)
w = torch.view_as_complex(w)
out0 = torch.istft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=False)
out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False)
out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False)
out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False)
out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True)
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/tests/ncnn/test_torch_stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
out0 = torch.stft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=True)
out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True)
out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True)
out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True)
out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True)
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/tests/test_torch_istft.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
out0 = torch.istft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=False)
out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False)
out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False)
out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False)
out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True)
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/tests/test_torch_stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
out0 = torch.stft(x, n_fft=64, window=torch.hann_window(64), center=True, normalized=True, return_complex=True)
out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True)
out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True)
out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True)
out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True)
Expand Down

0 comments on commit 004ab31

Please sign in to comment.