From b0b2652402cc363ece19746f6dd71805250cb0e0 Mon Sep 17 00:00:00 2001
From: dogukan uraz tuna <156364766+simudt@users.noreply.github.com>
Date: Tue, 26 Mar 2024 00:22:45 +0300
Subject: [PATCH 1/2] Add einsum and 2D FFT solver version of EinFFT

---
 simba/main.py | 69 ++++++++++++++++++++++++++-------------------
 1 file changed, 27 insertions(+), 42 deletions(-)

diff --git a/simba/main.py b/simba/main.py
index 213868e..6071292 100644
--- a/simba/main.py
+++ b/simba/main.py
@@ -40,9 +40,12 @@ def __init__(
         # silu
         self.act = nn.SiLU()
 
-        # Weights for Wr and Wi
-        self.Wr = nn.Parameter(torch.randn(in_channels, out_channels))
-        self.Wi = nn.Parameter(torch.randn(in_channels, out_channels))
+        # complex weights for channel-wise transformation
+        self.complex_weight = nn.Parameter(
+            torch.randn(
+                in_channels, out_channels, dtype=torch.complex64
+            )
+        )
 
     def forward(self, x: Tensor) -> Tensor:
         """
@@ -57,52 +60,34 @@ def forward(self, x: Tensor) -> Tensor:
         """
         b, c, h, w = x.shape
 
-        # Get Xr and X1
-        fast_fouried = torch.fft.fft(x)
-        print(fast_fouried.shape)
-
-        # Get Wr Wi use pytorch split instead
-        xr = fast_fouried.real
-        xi = fast_fouried.imag
-
-        # Einstein Matrix Multiplication with XR, Xi, Wr, Wi use torch split instead
-        # matmul = torch.matmul(xr, self.Wr) + torch.matmul(xi, self.Wi)
-        matmul = torch.matmul(xr, xi)
-        # matmul = torch.matmul(self.Wr, self.Wi)
-        print(matmul.shape)
+        # apply 2D FFT, transforming the input tensor to the frequency domain
+        fast_fouried = torch.fft.fft2(x, dim=(-2, -1))
 
-        # Xr, Xi hat, use torch split instead
-        xr_hat = matmul  # .real
-        xi_hat = matmul  # .imag
-
-        # Silu
-        acted_xr_hat = self.act(xr_hat)
-        acted_xi_hat = self.act(xi_hat)
+        # complex-valued channel mixing; "bfhw" keeps the spatial dims last
+        einsum_mul = torch.einsum(
+            "bchw,cf->bfhw", fast_fouried, self.complex_weight
+        )
 
-        # Emm with the weights use torch split instead
-        # emmed = torch.matmul(
-        #     acted_xr_hat,
-        #     self.Wr
-        # ) + torch.matmul(
-        #     acted_xi_hat,
-        #     self.Wi
-        # )
-        emmed = torch.matmul(acted_xr_hat, acted_xi_hat)
+        # split into real (xr) and imaginary (xi) parts
+        xr = einsum_mul.real
+        xi = einsum_mul.imag
 
-        # Split up into Xr and Xi again for the ifft use torch split instead
-        xr_hat = emmed  # .real
-        xi_hat = emmed  # .imag
+        # apply SiLU to each part
+        real_act = self.act(xr)
+        imag_act = self.act(xi)
 
-        # IFFT
-        iffted = torch.fft.ifft(xr_hat + xi_hat)
+        # recombine into a complex tensor
+        activated_complex = torch.complex(real_act, imag_act)
 
-        return iffted
+        # apply inverse 2D FFT to return to the spatial domain
+        iffted = torch.fft.ifft2(activated_complex, dim=(-2, -1))
+        return iffted.real
 
 
-x = torch.randn(1, 3, 64, 64)
-einfft = EinFFT(3, 64, 64)
-out = einfft(x)
-print(out)
+x = torch.randn(1, 3, 64, 64)  # Random input tensor
+einfft = EinFFT(3, 64, 64)  # Instantiate EinFFT module
+output = einfft(x)  # Apply EinFFT
+print(output)  # Print module architecture
 
 
 class Simba(nn.Module):

From 44a4c21d41e5a05d6dedeace399cc7bbf49ac622 Mon Sep 17 00:00:00 2001
From: dogukan uraz tuna <156364766+simudt@users.noreply.github.com>
Date: Tue, 26 Mar 2024 00:32:11 +0300
Subject: [PATCH 2/2] Add einsum and 2D FFT solver version of EinFFT

---
 simba/main.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/simba/main.py b/simba/main.py
index 6071292..2815f34 100644
--- a/simba/main.py
+++ b/simba/main.py
@@ -83,11 +83,14 @@ def forward(self, x: Tensor) -> Tensor:
         iffted = torch.fft.ifft2(activated_complex, dim=(-2, -1))
         return iffted.real
 
-
-x = torch.randn(1, 3, 64, 64)  # Random input tensor
-einfft = EinFFT(3, 64, 64)  # Instantiate EinFFT module
-output = einfft(x)  # Apply EinFFT
-print(output)  # Print module architecture
+# Random input tensor
+x = torch.randn(1, 3, 64, 64)
+# Instantiate EinFFT module
+einfft = EinFFT(3, 64, 64)
+# Apply EinFFT to get an output
+output = einfft(x)
+# Print output tensor
+print(output)
 
 
 class Simba(nn.Module):
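
Taken together, the two patches make EinFFT compute Y = Re(IFFT2(SiLU(FFT2(X) * W))): a 2D FFT over the spatial dims, a learned complex channel-mixing matrix W of shape (in_channels, out_channels) applied via einsum, SiLU on the real and imaginary parts separately, and an inverse 2D FFT back to the spatial domain. Below is a minimal standalone shape check of that pipeline, a sketch mirroring the patched forward pass step by step; it assumes only standard PyTorch, and names like freq and mixed are illustrative, not from the patch:

    import torch
    import torch.nn.functional as F

    b, c_in, c_out, h, w = 1, 3, 64, 64, 64
    x = torch.randn(b, c_in, h, w)
    weight = torch.randn(c_in, c_out, dtype=torch.complex64)

    # 2D FFT over the spatial dims -> complex (b, c_in, h, w)
    freq = torch.fft.fft2(x, dim=(-2, -1))
    # channel mixing in the frequency domain -> complex (b, c_out, h, w)
    mixed = torch.einsum("bchw,cf->bfhw", freq, weight)
    # SiLU applied to the real and imaginary parts separately
    acted = torch.complex(F.silu(mixed.real), F.silu(mixed.imag))
    # inverse 2D FFT back to the spatial domain; keep the real part
    out = torch.fft.ifft2(acted, dim=(-2, -1)).real

    assert out.shape == (b, c_out, h, w)

The "bchw,cf->bfhw" subscripts matter here: they put the output channels at dim 1 and the spatial dims last, so ifft2(dim=(-2, -1)) inverts exactly the (h, w) axes that fft2 transformed, and the output keeps the conventional (B, C, H, W) layout.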