From b0b2652402cc363ece19746f6dd71805250cb0e0 Mon Sep 17 00:00:00 2001 From: dogukan uraz tuna <156364766+simudt@users.noreply.github.com> Date: Tue, 26 Mar 2024 00:22:45 +0300 Subject: [PATCH] append einsum and 2dsolver version of EinFFT --- simba/main.py | 69 ++++++++++++++++++++------------------------------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/simba/main.py b/simba/main.py index 213868e..6071292 100644 --- a/simba/main.py +++ b/simba/main.py @@ -40,9 +40,12 @@ def __init__( # silu self.act = nn.SiLU() - # Weights for Wr and Wi - self.Wr = nn.Parameter(torch.randn(in_channels, out_channels)) - self.Wi = nn.Parameter(torch.randn(in_channels, out_channels)) + # complex weights for channel-wise transformation + self.complex_weight = nn.Parameter( + torch.randn( + in_channels, out_channels, dtype=torch.complex64 + ) + ) def forward(self, x: Tensor) -> Tensor: """ @@ -57,52 +60,34 @@ def forward(self, x: Tensor) -> Tensor: """ b, c, h, w = x.shape - # Get Xr and X1 - fast_fouried = torch.fft.fft(x) - print(fast_fouried.shape) - - # Get Wr Wi use pytorch split instead - xr = fast_fouried.real - xi = fast_fouried.imag - - # Einstein Matrix Multiplication with XR, Xi, Wr, Wi use torch split instead - # matmul = torch.matmul(xr, self.Wr) + torch.matmul(xi, self.Wi) - matmul = torch.matmul(xr, xi) - # matmul = torch.matmul(self.Wr, self.Wi) - print(matmul.shape) + # apply 2D FFTSolver, transform input tensor to frequency domain + fast_fouried = torch.fft.fft2(x, dim=(-2, -1)) - # Xr, Xi hat, use torch split instead - xr_hat = matmul # .real - xi_hat = matmul # .imag - - # Silu - acted_xr_hat = self.act(xr_hat) - acted_xi_hat = self.act(xi_hat) + # complex-valued multiplication + einsum_mul = torch.einsum( + "bchw,cf->bhwf", fast_fouried, self.complex_weight + ) - # Emm with the weights use torch split instead - # emmed = torch.matmul( - # acted_xr_hat, - # self.Wr - # ) + torch.matmul( - # acted_xi_hat, - # self.Wi - # ) - emmed = torch.matmul(acted_xr_hat, acted_xi_hat) + # get xr xi splitted parts + xr = einsum_mul.real + xi = einsum_mul.imag - # Split up into Xr and Xi again for the ifft use torch split instead - xr_hat = emmed # .real - xi_hat = emmed # .imag + # apply silu + real_act = self.act(xr) + imag_act = self.act(xi) - # IFFT - iffted = torch.fft.ifft(xr_hat + xi_hat) + # activated complex + activated_complex = torch.complex(real_act, imag_act) - return iffted + # apply ifft2d solver as notated + iffted = torch.fft.ifft2(activated_complex, dim=(-2, -1)) + return iffted.real -x = torch.randn(1, 3, 64, 64) -einfft = EinFFT(3, 64, 64) -out = einfft(x) -print(out) +x = torch.randn(1, 3, 64, 64) # Random input tensor +einfft = EinFFT(3, 64, 64) # Instantiate EinFFT module +output = einfft(x) # Apply EinFFT +print(output) # Print module architecture class Simba(nn.Module):