Skip to content

Commit

Permalink
update for pytorch2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
junjun3518 committed Apr 20, 2023
1 parent 0a7ab71 commit d940c48
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 110 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Accepted to ICASSP 2023


## TODO
- [ ] PyTorch 2.0 is released, need to modify STFT and iSTFT for complex support
- [x] PyTorch 2.0 is released, need to modify STFT and iSTFT for complex support (solved at `1.0.0`)
- [x] Arxiv updated
- [x] Errata in paper will be fixed. Section 2.5 in paper, transition band half-width 0.06-> 0.012.
- [x] Section 2.5: a note about multiplying the rotation matrix to "the left side of F(x)" will be added. -> transpose m,k to reduce ambiguity
Expand Down Expand Up @@ -60,7 +60,8 @@ with autocast(enabled=True)
```

## Requirements
- [Pytorch>=1.7.0](https://pytorch.org/) for [alias-free-torch](https://github.com/junjun3518/alias-free-torch)
- [PyTorch>=1.7.0](https://pytorch.org/) for [alias-free-torch](https://github.com/junjun3518/alias-free-torch)
- Support PyTorch>=2.0.0
- The requirements are highlighted in [requirements.txt](./requirements.txt).
- We also provide docker setup [Dockerfile](./Dockerfile).
```
Expand Down Expand Up @@ -128,5 +129,8 @@ If this repository is useful for your research, please consider citing!
year=2022,
}
```
If you have any questions or inquiries, please contact Junhyeok Lee at [[email protected]](mailto:[email protected])

Bibtex will be updated after ICASSP 2023.

If you have any questions or inquiries, please contact Junhyeok Lee at [[email protected]](mailto:[email protected])

111 changes: 58 additions & 53 deletions phaseaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ def __init__(
cutoff=0.05,
half_width=0.012,
kernel_size=128,
filter_padding='constant'
filter_padding='constant',
complex_calc=None
):
super().__init__()
self.nfft = nfft
self.hop = hop
self.var = var
self.delta_max = delta_max
self.use_filter = use_filter
self.complex_calc = complex_calc
self.register_buffer('window', torch.hann_window(nfft))
self.register_buffer('phi_ref', torch.arange(nfft // 2 + 1).unsqueeze(0) * 2 * pi / nfft)

Expand All @@ -46,70 +48,73 @@ def sample_phi(self, batch_size):
return phi #[B,nfft//2+1]
self.sample_phi = sample_phi

# Choose the STFT -> phase rotation -> iSTFT implementation once, at
# construction time. The complex-tensor path is taken when `complex_calc` is
# truthy OR the first character of torch.__version__ parses to an int >= 2;
# because of the `or`, a falsy `complex_calc` cannot force the real-valued
# path when that version check succeeds.
# NOTE(review): only the first character of the version string is inspected,
# so a two-digit major version would be misread - confirm this is acceptable.
if complex_calc or int(torch.__version__[0])>=2:
# Complex path: rotate the phase by multiplying the complex STFT with exp(1j * phi).
# Shapes follow the existing inline comments (x as passed by the caller, phi [B,F,1]).
def stft_rot_istft(self, x, phi):
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=True
) #[B,F,T]

# exp(1j * phi): the imaginary unit is built on x's device and multiplied by phi.
rot = torch.exp(torch.tensor([(0.+1.j)], dtype = torch.complex64, device = x.device) * phi) #[B,F,1]
X_aug = X * rot
# Back to the time domain; unsqueeze(1) inserts a dimension at index 1 of the result.
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
).unsqueeze(1)
return x_aug
else:
# Real-valued path: the STFT is requested with return_complex=False, and the
# rotation is applied as a 2x2 matrix acting on the trailing size-2 dimension.
def stft_rot_istft(self, x, phi):
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=False
) #[B,F,T,2]

phi_cos = phi.cos()
phi_sin = phi.sin()
# Concatenate [cos, -sin, sin, cos] along the last dim, then view the result
# as one 2x2 matrix per (batch, frequency-bin) pair.
rot_mat = torch.cat(
[phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2]
dim=-1).view(-1, self.nfft // 2 + 1, 2, 2)
# We did not mention that we multiplied rot_mat to "the left side of X"
# Paper will be modified at rebuttal phase for clarity.
# X_aug[b,f,t,i] = sum_j rot_mat[b,f,i,j] * X[b,f,t,j]
X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X)
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
).unsqueeze(1)
return x_aug
# The selected function is stored as a plain instance attribute (not a bound
# method), so callers have to pass `self` explicitly as the first argument.
self.stft_rot_istft = stft_rot_istft

# x: audio [B,1,T] -> [B,1,T]
# phi: [B,nfft//2+1]
# also possible for x :[B,C,T] but we did not generalize it.
def forward(self, x, phi=None):
x = x.squeeze(1) #[B,t]
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=False
) #[B,F,T,2]
if phi is None:
phi = self.sample_phi(self, X.shape[0])

phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
phi = phi.unsqueeze(-1) #[B,F,1]
phi_cos = phi.cos()
phi_sin = phi.sin()
rot_mat = torch.cat(
[phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2]
dim=-1).view(-1, self.nfft // 2 + 1, 2, 2)
# We did not mention that we multiplied rot_mat to "the left side of X"
# Paper will be modified at rebuttal phase for clarity.
X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X)
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
)
return x_aug.unsqueeze(1) #[B,1,t]
x_aug = self.stft_rot_istft(self, x, phi)
return x_aug #[B,1,t]

# x: audio [B,1,T] -> [B,1,T]
# phi: [B,nfft//2+1]
def forward_sync(self, x, x_hat, phi=None):
x = torch.cat([x, x_hat], dim=0).squeeze(1) #[2B,t]
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=False
) #[2B,F,T,2]
B = x.shape[0]
x = torch.cat([x, x_hat], dim=0) #[2B,1,t]
if phi is None:
phi = self.sample_phi(self, X.shape[0] // 2)
phi = self.sample_phi(self, X.shape[0] // 2) #[2B, nfft//2+1]
phi = torch.cat([phi, phi], dim=0)

phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
phi = phi.unsqueeze(-1) #[2B,F,1]
phi_cos = phi.cos()
phi_sin = phi.sin()
rot_mat = torch.cat(
[phi_cos, -phi_sin, phi_sin, phi_cos], #[2B,F,2,2]
dim=-1).view(-1, self.nfft // 2 + 1, 2, 2)
# We did not mention that we multiplied rot_mat to "the left side of X"
# Paper will be modified at rebuttal phase for clarity.
X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X)
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
)
return x_aug.unsqueeze(1).split(x_aug.shape[0] // 2, dim=0) #[B,1,t],[B,1,t]
x_augs = self.forward(x, phi).split(B, dim=0)
return x_augs
111 changes: 58 additions & 53 deletions phaseaug/phaseaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ def __init__(
cutoff=0.05,
half_width=0.012,
kernel_size=128,
filter_padding='constant'
filter_padding='constant',
complex_calc=None
):
super().__init__()
self.nfft = nfft
self.hop = hop
self.var = var
self.delta_max = delta_max
self.use_filter = use_filter
self.complex_calc = complex_calc
self.register_buffer('window', torch.hann_window(nfft))
self.register_buffer('phi_ref', torch.arange(nfft // 2 + 1).unsqueeze(0) * 2 * pi / nfft)

Expand All @@ -46,70 +48,73 @@ def sample_phi(self, batch_size):
return phi #[B,nfft//2+1]
self.sample_phi = sample_phi

# Choose the STFT -> phase rotation -> iSTFT implementation once, at
# construction time. The complex-tensor path is taken when `complex_calc` is
# truthy OR the first character of torch.__version__ parses to an int >= 2;
# because of the `or`, a falsy `complex_calc` cannot force the real-valued
# path when that version check succeeds.
# NOTE(review): only the first character of the version string is inspected,
# so a two-digit major version would be misread - confirm this is acceptable.
if complex_calc or int(torch.__version__[0])>=2:
# Complex path: rotate the phase by multiplying the complex STFT with exp(1j * phi).
# Shapes follow the existing inline comments (x as passed by the caller, phi [B,F,1]).
def stft_rot_istft(self, x, phi):
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=True
) #[B,F,T]

# exp(1j * phi): the imaginary unit is built on x's device and multiplied by phi.
rot = torch.exp(torch.tensor([(0.+1.j)], dtype = torch.complex64, device = x.device) * phi) #[B,F,1]
X_aug = X * rot
# Back to the time domain; unsqueeze(1) inserts a dimension at index 1 of the result.
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
).unsqueeze(1)
return x_aug
else:
# Real-valued path: the STFT is requested with return_complex=False, and the
# rotation is applied as a 2x2 matrix acting on the trailing size-2 dimension.
def stft_rot_istft(self, x, phi):
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=False
) #[B,F,T,2]

phi_cos = phi.cos()
phi_sin = phi.sin()
# Concatenate [cos, -sin, sin, cos] along the last dim, then view the result
# as one 2x2 matrix per (batch, frequency-bin) pair.
rot_mat = torch.cat(
[phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2]
dim=-1).view(-1, self.nfft // 2 + 1, 2, 2)
# We did not mention that we multiplied rot_mat to "the left side of X"
# Paper will be modified at rebuttal phase for clarity.
# X_aug[b,f,t,i] = sum_j rot_mat[b,f,i,j] * X[b,f,t,j]
X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X)
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
).unsqueeze(1)
return x_aug
# The selected function is stored as a plain instance attribute (not a bound
# method), so callers have to pass `self` explicitly as the first argument.
self.stft_rot_istft = stft_rot_istft

# x: audio [B,1,T] -> [B,1,T]
# phi: [B,nfft//2+1]
# also possible for x :[B,C,T] but we did not generalize it.
def forward(self, x, phi=None):
x = x.squeeze(1) #[B,t]
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=False
) #[B,F,T,2]
if phi is None:
phi = self.sample_phi(self, X.shape[0])

phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
phi = phi.unsqueeze(-1) #[B,F,1]
phi_cos = phi.cos()
phi_sin = phi.sin()
rot_mat = torch.cat(
[phi_cos, -phi_sin, phi_sin, phi_cos], #[B,F,2,2]
dim=-1).view(-1, self.nfft // 2 + 1, 2, 2)
# We did not mention that we multiplied rot_mat to "the left side of X"
# Paper will be modified at rebuttal phase for clarity.
X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X)
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
)
return x_aug.unsqueeze(1) #[B,1,t]
x_aug = self.stft_rot_istft(self, x, phi)
return x_aug #[B,1,t]

# x: audio [B,1,T] -> [B,1,T]
# phi: [B,nfft//2+1]
def forward_sync(self, x, x_hat, phi=None):
x = torch.cat([x, x_hat], dim=0).squeeze(1) #[2B,t]
X = torch.stft(
x,
self.nfft,
self.hop,
window=self.window,
return_complex=False
) #[2B,F,T,2]
B = x.shape[0]
x = torch.cat([x, x_hat], dim=0) #[2B,1,t]
if phi is None:
phi = self.sample_phi(self, X.shape[0] // 2)
phi = self.sample_phi(self, X.shape[0] // 2) #[2B, nfft//2+1]
phi = torch.cat([phi, phi], dim=0)

phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
phi = phi.unsqueeze(-1) #[2B,F,1]
phi_cos = phi.cos()
phi_sin = phi.sin()
rot_mat = torch.cat(
[phi_cos, -phi_sin, phi_sin, phi_cos], #[2B,F,2,2]
dim=-1).view(-1, self.nfft // 2 + 1, 2, 2)
# We did not mention that we multiplied rot_mat to "the left side of X"
# Paper will be modified at rebuttal phase for clarity.
X_aug = torch.einsum('bfij ,bftj->bfti', rot_mat, X)
x_aug = torch.istft(
X_aug,
self.nfft,
self.hop,
window=self.window,
return_complex=False
)
return x_aug.unsqueeze(1).split(x_aug.shape[0] // 2, dim=0) #[B,1,t],[B,1,t]
x_augs = self.forward(x, phi).split(B, dim=0)
return x_augs
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name = 'phaseaug',
version = '0.0.2',
version = '1.0.0',
description = 'PhaseAug: A Differentiable Augmentation for Speech Synthesis to Simulate One-to-Many Mapping',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit d940c48

Please sign in to comment.