From d940c489a0474ef2475d0f202bfd2b234c9bd399 Mon Sep 17 00:00:00 2001 From: junhyouk lee Date: Thu, 20 Apr 2023 13:25:40 +0900 Subject: [PATCH] update for pytorch2.0 --- README.md | 10 ++-- phaseaug.py | 111 ++++++++++++++++++++++--------------------- phaseaug/phaseaug.py | 111 ++++++++++++++++++++++--------------------- setup.py | 2 +- 4 files changed, 124 insertions(+), 110 deletions(-) diff --git a/README.md b/README.md index e39a8c0..473e1bc 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Accepted to ICASSP 2023 ## TODO -- [ ] PyTorch 2.0 is released, need to modify STFT and iSTFT for complex support +- [x] PyTorch 2.0 is released, need to modify STFT and iSTFT for complex support (solved at `1.0.0`) - [x] Arxiv updated - [x] Errata in paper will be fixed. Section 2.5 in paper, transition band half-width 0.06-> 0.012. - [x] Section 2.5, mention about multiplyinng rotation matrix to "the left side of F(x)" will be added. -> transpose m,k to reduce ambiguity @@ -60,7 +60,8 @@ with autocast(enabled=True) ``` ## Requirements -- [Pytorch>=1.7.0](https://pytorch.org/) for [alias-free-torch](https://github.com/junjun3518/alias-free-torch) +- [PyTorch>=1.7.0](https://pytorch.org/) for [alias-free-torch](https://github.com/junjun3518/alias-free-torch) +- Support PyTorch>=2.0.0 - The requirements are highlighted in [requirements.txt](./requirements.txt). - We also provide docker setup [Dockerfile](./Dockerfile). ``` @@ -128,5 +129,8 @@ If this repostory useful for yout research, please consider citing! year=2022, } ``` -If you have a question or any kind of inquiries, please contact Junhyeok Lee at [jun3518@mindslab.ai](mailto:jun3518@mindslab.ai) + +Bibtex will be updated after ICASSP 2023. 
+ +If you have a question or any kind of inquiries, please contact Junhyeok Lee at [jun3518@icloud.com](mailto:jun3518@icloud.com) diff --git a/phaseaug.py b/phaseaug.py index ebc2608..bd7a42d 100644 --- a/phaseaug.py +++ b/phaseaug.py @@ -14,7 +14,8 @@ def __init__( cutoff=0.05, half_width=0.012, kernel_size=128, - filter_padding='constant' + filter_padding='constant', + complex_calc=None ): super().__init__() self.nfft = nfft @@ -22,6 +23,7 @@ def __init__( self.var = var self.delta_max = delta_max self.use_filter = use_filter + self.complex_calc = complex_calc self.register_buffer('window', torch.hann_window(nfft)) self.register_buffer('phi_ref', torch.arange(nfft // 2 + 1).unsqueeze(0) * 2 * pi / nfft) @@ -46,70 +48,73 @@ def sample_phi(self, batch_size): return phi #[B,nfft//2+1] self.sample_phi = sample_phi + if complex_calc or int(torch.__version__.split('.')[0])>=2: + def stft_rot_istft(self, x, phi): + X = torch.stft( + x, + self.nfft, + self.hop, + window=self.window, + return_complex=True + ) #[B,F,T] + + rot = torch.exp(torch.tensor([(0.+1.j)], dtype = torch.complex64, device = x.device) * phi) #[B,F,1] + X_aug = X * rot + x_aug = torch.istft( + X_aug, + self.nfft, + self.hop, + window=self.window, + return_complex=False + ).unsqueeze(1) + return x_aug + else: + def stft_rot_istft(self, x, phi): + X = torch.stft( + x, + self.nfft, + self.hop, + window=self.window, + return_complex=False + ) #[B,F,T,2] + + phi_cos = phi.cos() + phi_sin = phi.sin() + rot_mat = torch.cat( + [phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2] + dim=-1).view(-1, self.nfft // 2 + 1, 2, 2) + # We did not mention that we multiplied rot_mat to "the left side of X" + # Paper will be modified at rebuttal phase for clarity. 
+ X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X) + x_aug = torch.istft( + X_aug, + self.nfft, + self.hop, + window=self.window, + return_complex=False + ).unsqueeze(1) + return x_aug + self.stft_rot_istft = stft_rot_istft + # x: audio [B,1,T] -> [B,1,T] # phi: [B,nfft//2+1] # also possible for x :[B,C,T] but we did not generalize it. def forward(self, x, phi=None): x = x.squeeze(1) #[B,t] - X = torch.stft( - x, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) #[B,F,T,2] if phi is None: - phi = self.sample_phi(self, X.shape[0]) + phi = self.sample_phi(self, x.shape[0]) #[B,nfft//2+1] - phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme phi = phi.unsqueeze(-1) #[B,F,1] - phi_cos = phi.cos() - phi_sin = phi.sin() - rot_mat = torch.cat( - [phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2] - dim=-1).view(-1, self.nfft // 2 + 1, 2, 2) - # We did not mention that we multiplied rot_mat to "the left side of X" - # Paper will be modified at rebuttal phase for clarity. - X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X) - x_aug = torch.istft( - X_aug, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) - return x_aug.unsqueeze(1) #[B,1,t] + x_aug = self.stft_rot_istft(self, x, phi) + return x_aug #[B,1,t] # x: audio [B,1,T] -> [B,1,T] # phi: [B,nfft//2+1] def forward_sync(self, x, x_hat, phi=None): - x = torch.cat([x, x_hat], dim=0).squeeze(1) #[2B,t] - X = torch.stft( - x, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) #[2B,F,T,2] + B = x.shape[0] + x = torch.cat([x, x_hat], dim=0) #[2B,1,t] if phi is None: - phi = self.sample_phi(self, X.shape[0] // 2) + phi = self.sample_phi(self, B) #[B, nfft//2+1] phi = torch.cat([phi, phi], dim=0) - - phi[:, 0] = 0. 
# we are multiplying phi_ref to mu, so it is always zero in our scheme - phi = phi.unsqueeze(-1) #[2B,F,1] - phi_cos = phi.cos() - phi_sin = phi.sin() - rot_mat = torch.cat( - [phi_cos, -phi_sin, phi_sin, phi_cos], #[2B,F,2,2] - dim=-1).view(-1, self.nfft // 2 + 1, 2, 2) - # We did not mention that we multiplied rot_mat to "the left side of X" - # Paper will be modified at rebuttal phase for clarity. - X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X) - x_aug = torch.istft( - X_aug, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) - return x_aug.unsqueeze(1).split(x_aug.shape[0] // 2, dim=0) #[B,1,t],[B,1,t] \ No newline at end of file + x_augs = self.forward(x, phi).split(B, dim=0) + return x_augs diff --git a/phaseaug/phaseaug.py b/phaseaug/phaseaug.py index ebc2608..bd7a42d 100644 --- a/phaseaug/phaseaug.py +++ b/phaseaug/phaseaug.py @@ -14,7 +14,8 @@ def __init__( cutoff=0.05, half_width=0.012, kernel_size=128, - filter_padding='constant' + filter_padding='constant', + complex_calc=None ): super().__init__() self.nfft = nfft @@ -22,6 +23,7 @@ def __init__( self.var = var self.delta_max = delta_max self.use_filter = use_filter + self.complex_calc = complex_calc self.register_buffer('window', torch.hann_window(nfft)) self.register_buffer('phi_ref', torch.arange(nfft // 2 + 1).unsqueeze(0) * 2 * pi / nfft) @@ -46,70 +48,73 @@ def sample_phi(self, batch_size): return phi #[B,nfft//2+1] self.sample_phi = sample_phi + if complex_calc or int(torch.__version__.split('.')[0])>=2: + def stft_rot_istft(self, x, phi): + X = torch.stft( + x, + self.nfft, + self.hop, + window=self.window, + return_complex=True + ) #[B,F,T] + + rot = torch.exp(torch.tensor([(0.+1.j)], dtype = torch.complex64, device = x.device) * phi) #[B,F,1] + X_aug = X * rot + x_aug = torch.istft( + X_aug, + self.nfft, + self.hop, + window=self.window, + return_complex=False + ).unsqueeze(1) + return x_aug + else: + def stft_rot_istft(self, x, phi): + X = torch.stft( + x, + self.nfft, + 
self.hop, + window=self.window, + return_complex=False + ) #[B,F,T,2] + + phi_cos = phi.cos() + phi_sin = phi.sin() + rot_mat = torch.cat( + [phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2] + dim=-1).view(-1, self.nfft // 2 + 1, 2, 2) + # We did not mention that we multiplied rot_mat to "the left side of X" + # Paper will be modified at rebuttal phase for clarity. + X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X) + x_aug = torch.istft( + X_aug, + self.nfft, + self.hop, + window=self.window, + return_complex=False + ).unsqueeze(1) + return x_aug + self.stft_rot_istft = stft_rot_istft + # x: audio [B,1,T] -> [B,1,T] # phi: [B,nfft//2+1] # also possible for x :[B,C,T] but we did not generalize it. def forward(self, x, phi=None): x = x.squeeze(1) #[B,t] - X = torch.stft( - x, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) #[B,F,T,2] if phi is None: - phi = self.sample_phi(self, X.shape[0]) + phi = self.sample_phi(self, x.shape[0]) #[B,nfft//2+1] - phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme phi = phi.unsqueeze(-1) #[B,F,1] - phi_cos = phi.cos() - phi_sin = phi.sin() - rot_mat = torch.cat( - [phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2] - dim=-1).view(-1, self.nfft // 2 + 1, 2, 2) - # We did not mention that we multiplied rot_mat to "the left side of X" - # Paper will be modified at rebuttal phase for clarity. 
- X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X) - x_aug = torch.istft( - X_aug, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) - return x_aug.unsqueeze(1) #[B,1,t] + x_aug = self.stft_rot_istft(self, x, phi) + return x_aug #[B,1,t] # x: audio [B,1,T] -> [B,1,T] # phi: [B,nfft//2+1] def forward_sync(self, x, x_hat, phi=None): - x = torch.cat([x, x_hat], dim=0).squeeze(1) #[2B,t] - X = torch.stft( - x, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) #[2B,F,T,2] + B = x.shape[0] + x = torch.cat([x, x_hat], dim=0) #[2B,1,t] if phi is None: - phi = self.sample_phi(self, X.shape[0] // 2) + phi = self.sample_phi(self, B) #[B, nfft//2+1] phi = torch.cat([phi, phi], dim=0) - - phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme - phi = phi.unsqueeze(-1) #[2B,F,1] - phi_cos = phi.cos() - phi_sin = phi.sin() - rot_mat = torch.cat( - [phi_cos, -phi_sin, phi_sin, phi_cos], #[2B,F,2,2] - dim=-1).view(-1, self.nfft // 2 + 1, 2, 2) - # We did not mention that we multiplied rot_mat to "the left side of X" - # Paper will be modified at rebuttal phase for clarity. - X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X) - x_aug = torch.istft( - X_aug, - self.nfft, - self.hop, - window=self.window, - return_complex=False - ) - return x_aug.unsqueeze(1).split(x_aug.shape[0] // 2, dim=0) #[B,1,t],[B,1,t] \ No newline at end of file + x_augs = self.forward(x, phi).split(B, dim=0) + return x_augs diff --git a/setup.py b/setup.py index b49f25d..a7e517e 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name = 'phaseaug', - version = '0.0.2', + version = '1.0.0', description = 'PhaseAug: A Differentiable Augmentation for Speech Synthesis to Simulate One-to-Many Mapping', long_description=long_description, long_description_content_type="text/markdown",