From d10b407bc315f4bbdf62b0f52a93736041d663f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20D=C3=ADaz-Guerra=20Aparicio?= Date: Fri, 29 Mar 2024 18:32:10 +0200 Subject: [PATCH] [bug] Quick fix for XUMX in torch 2.0 (#684) --- asteroid/models/x_umx.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/asteroid/models/x_umx.py b/asteroid/models/x_umx.py index 1c29fd6b1..5b2911af2 100755 --- a/asteroid/models/x_umx.py +++ b/asteroid/models/x_umx.py @@ -7,13 +7,6 @@ class XUMX(BaseModel): - def __init__(self, *args, **kwargs): - raise RuntimeError( - "XUMX is broken in torch 2.0, use torch<2.0 with asteroid<0.7 to use it until it's fixed." - ) - - -class BrokenXUMX(BaseModel): r"""CrossNet-Open-Unmix (X-UMX) for Music Source Separation introduced in [1]. There are two notable contributions with no effect on inference: a) Multi Domain Losses @@ -352,8 +345,9 @@ def forward(self, x): normalized=False, onesided=True, pad_mode="reflect", - return_complex=False, + return_complex=True, ) + stft_f = torch.view_as_real(stft_f) # reshape back to channel dimension stft_f = stft_f.contiguous().view(nb_samples, nb_channels, self.n_fft // 2 + 1, -1, 2) @@ -405,6 +399,7 @@ def forward(self, spec, ang): x_i = spec * torch.sin(ang) x = torch.stack([x_r, x_i], dim=-1) x = x.view(sources * bsize * channels, fbins, frames, 2) + x = torch.view_as_complex(x) wav = torch.istft( x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=self.center )