append einsum and 2dsolver version of EinFFT
dtreai committed Mar 25, 2024
1 parent e142190 commit b0b2652
Showing 1 changed file with 27 additions and 42 deletions.
69 changes: 27 additions & 42 deletions simba/main.py
@@ -40,9 +40,12 @@ def __init__(
         # silu
         self.act = nn.SiLU()
 
-        # Weights for Wr and Wi
-        self.Wr = nn.Parameter(torch.randn(in_channels, out_channels))
-        self.Wi = nn.Parameter(torch.randn(in_channels, out_channels))
+        # complex weights for channel-wise transformation
+        self.complex_weight = nn.Parameter(
+            torch.randn(
+                in_channels, out_channels, dtype=torch.complex64
+            )
+        )
 
     def forward(self, x: Tensor) -> Tensor:
         """
@@ -57,52 +60,34 @@ def forward(self, x: Tensor) -> Tensor:
         """
         b, c, h, w = x.shape
 
-        # Get Xr and X1
-        fast_fouried = torch.fft.fft(x)
-        print(fast_fouried.shape)
-
-        # Get Wr Wi use pytorch split instead
-        xr = fast_fouried.real
-        xi = fast_fouried.imag
-
-        # Einstein Matrix Multiplication with XR, Xi, Wr, Wi use torch split instead
-        # matmul = torch.matmul(xr, self.Wr) + torch.matmul(xi, self.Wi)
-        matmul = torch.matmul(xr, xi)
-        # matmul = torch.matmul(self.Wr, self.Wi)
-        print(matmul.shape)
+        # apply 2D FFT solver: transform the input tensor to the frequency domain
+        fast_fouried = torch.fft.fft2(x, dim=(-2, -1))
 
-        # Xr, Xi hat, use torch split instead
-        xr_hat = matmul  # .real
-        xi_hat = matmul  # .imag
-
-        # Silu
-        acted_xr_hat = self.act(xr_hat)
-        acted_xi_hat = self.act(xi_hat)
+        # complex-valued channel mixing via einsum
+        einsum_mul = torch.einsum(
+            "bchw,cf->bhwf", fast_fouried, self.complex_weight
+        )
 
-        # Emm with the weights use torch split instead
-        # emmed = torch.matmul(
-        #     acted_xr_hat,
-        #     self.Wr
-        # ) + torch.matmul(
-        #     acted_xi_hat,
-        #     self.Wi
-        # )
-        emmed = torch.matmul(acted_xr_hat, acted_xi_hat)
+        # split into real (xr) and imaginary (xi) parts
+        xr = einsum_mul.real
+        xi = einsum_mul.imag
 
-        # Split up into Xr and Xi again for the ifft use torch split instead
-        xr_hat = emmed  # .real
-        xi_hat = emmed  # .imag
+        # apply SiLU to each part
+        real_act = self.act(xr)
+        imag_act = self.act(xi)
 
-        # IFFT
-        iffted = torch.fft.ifft(xr_hat + xi_hat)
+        # recombine into a complex tensor
+        activated_complex = torch.complex(real_act, imag_act)
 
-        return iffted
+        # apply the inverse 2D FFT solver and return the real part
+        iffted = torch.fft.ifft2(activated_complex, dim=(-2, -1))
+        return iffted.real
 
 
-x = torch.randn(1, 3, 64, 64)
-einfft = EinFFT(3, 64, 64)
-out = einfft(x)
-print(out)
+x = torch.randn(1, 3, 64, 64)  # Random input tensor
+einfft = EinFFT(3, 64, 64)  # Instantiate EinFFT module
+output = einfft(x)  # Apply EinFFT
+print(output)  # Print the output tensor
 
 
 class Simba(nn.Module):
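For reference, here is a minimal, self-contained sketch of the forward path this commit introduces. The diff does not show the module's imports or the full __init__ signature, so the class name EinFFTSketch, its two-argument constructor, and the example at the bottom are assumptions for illustration only. The sketch also swaps the einsum subscripts to "bchw,cf->bfhw" (the commit uses "bchw,cf->bhwf") so that the mixed channel axis stays in position 1 and the inverse 2D FFT over dim=(-2, -1) keeps acting on the spatial axes; with the committed subscripts, ifft2 would run over the width and output-channel axes instead.

import torch
from torch import nn, Tensor


class EinFFTSketch(nn.Module):
    # Hypothetical stand-in for the updated EinFFT; constructor args assumed.
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.act = nn.SiLU()
        # complex weights for the channel-wise transformation
        self.complex_weight = nn.Parameter(
            torch.randn(in_channels, out_channels, dtype=torch.complex64)
        )

    def forward(self, x: Tensor) -> Tensor:
        # 2D FFT over the spatial dimensions (h, w)
        freq = torch.fft.fft2(x, dim=(-2, -1))

        # channel mixing in the frequency domain:
        # (b, c, h, w) x (c, f) -> (b, f, h, w)
        mixed = torch.einsum("bchw,cf->bfhw", freq, self.complex_weight)

        # SiLU applied separately to the real and imaginary parts
        activated = torch.complex(self.act(mixed.real), self.act(mixed.imag))

        # inverse 2D FFT back to the spatial domain; keep the real part
        return torch.fft.ifft2(activated, dim=(-2, -1)).real


x = torch.randn(1, 3, 64, 64)
module = EinFFTSketch(3, 64)
print(module(x).shape)  # torch.Size([1, 64, 64, 64])

Running the sketch on a (1, 3, 64, 64) input yields a real-valued (1, 64, 64, 64) output, i.e. in_channels=3 mapped to out_channels=64 in the frequency domain.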
