[UPDATE] einsum and 2DFFT-2DIFFT version of EinFFT #2

Closed · wants to merge 2 commits
simba/main.py: 29 additions & 41 deletions (70 changes)
@@ -40,9 +40,12 @@ def __init__(
         # silu
         self.act = nn.SiLU()

-        # Weights for Wr and Wi
-        self.Wr = nn.Parameter(torch.randn(in_channels, out_channels))
-        self.Wi = nn.Parameter(torch.randn(in_channels, out_channels))
+        # complex weights for channel-wise transformation
+        self.complex_weight = nn.Parameter(
+            torch.randn(
+                in_channels, out_channels, dtype=torch.complex64
+            )
+        )

     def forward(self, x: Tensor) -> Tensor:
         """
@@ -57,52 +60,37 @@ def forward(self, x: Tensor) -> Tensor:
"""
b, c, h, w = x.shape

# Get Xr and X1
fast_fouried = torch.fft.fft(x)
print(fast_fouried.shape)

# Get Wr Wi use pytorch split instead
xr = fast_fouried.real
xi = fast_fouried.imag

# Einstein Matrix Multiplication with XR, Xi, Wr, Wi use torch split instead
# matmul = torch.matmul(xr, self.Wr) + torch.matmul(xi, self.Wi)
matmul = torch.matmul(xr, xi)
# matmul = torch.matmul(self.Wr, self.Wi)
print(matmul.shape)
# apply 2D FFTSolver, transform input tensor to frequency domain
fast_fouried = torch.fft.fft2(x, dim=(-2, -1))

# Xr, Xi hat, use torch split instead
xr_hat = matmul # .real
xi_hat = matmul # .imag

# Silu
acted_xr_hat = self.act(xr_hat)
acted_xi_hat = self.act(xi_hat)

# Emm with the weights use torch split instead
# emmed = torch.matmul(
# acted_xr_hat,
# self.Wr
# ) + torch.matmul(
# acted_xi_hat,
# self.Wi
# )
emmed = torch.matmul(acted_xr_hat, acted_xi_hat)
# complex-valued multiplication
einsum_mul = torch.einsum(
"bchw,cf->bhwf", fast_fouried, self.complex_weight
)

# Split up into Xr and Xi again for the ifft use torch split instead
xr_hat = emmed # .real
xi_hat = emmed # .imag
# get xr xi splitted parts
xr = einsum_mul.real
xi = einsum_mul.imag

# IFFT
iffted = torch.fft.ifft(xr_hat + xi_hat)
# apply silu
real_act = self.act(xr)
imag_act = self.act(xi)

return iffted
# activated complex
activated_complex = torch.complex(real_act, imag_act)

# apply ifft2d solver as notated
iffted = torch.fft.ifft2(activated_complex, dim=(-2, -1))
return iffted.real

 # Random input tensor
 x = torch.randn(1, 3, 64, 64)
 # Instantiate EinFFT module
 einfft = EinFFT(3, 64, 64)
-out = einfft(x)
-print(out)
+# Apply EinFFT to get an output
+output = einfft(x)
+# Print output tensor
+print(output)


class Simba(nn.Module):
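
Taken together, the updated forward path is: fft2 over the spatial dims, a complex einsum over the channel dim, SiLU applied separately to the real and imaginary parts, then ifft2 back to the spatial domain. Below is a minimal standalone sketch of that path, not the repository's class: the name EinFFTSketch is illustrative, the constructor is reduced to in_channels/out_channels (the actual EinFFT takes a third argument, as the EinFFT(3, 64, 64) call above shows), and the rest of simba/main.py is not reproduced.

import torch
from torch import nn, Tensor

class EinFFTSketch(nn.Module):
    # hypothetical reduction of the EinFFT layer in this PR
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.act = nn.SiLU()
        # complex weights for channel-wise transformation
        self.complex_weight = nn.Parameter(
            torch.randn(in_channels, out_channels, dtype=torch.complex64)
        )

    def forward(self, x: Tensor) -> Tensor:
        # (b, c, h, w) -> frequency domain over the spatial dims
        fast_fouried = torch.fft.fft2(x, dim=(-2, -1))
        # complex channel mixing, keeping the layout channels-first
        mixed = torch.einsum("bchw,cf->bfhw", fast_fouried, self.complex_weight)
        # SiLU on real and imaginary parts separately, then recombine
        activated = torch.complex(self.act(mixed.real), self.act(mixed.imag))
        # inverse 2D FFT over the spatial dims; keep the real part
        return torch.fft.ifft2(activated, dim=(-2, -1)).real

x = torch.randn(1, 3, 64, 64)
print(EinFFTSketch(3, 64)(x).shape)  # torch.Size([1, 64, 64, 64])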
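One property worth noting about the weight change: a single complex parameter is arithmetically equivalent to the Wr/Wi pair it replaces, since multiplying by Wr + i*Wi expands into the usual four real contractions, (xr*Wr - xi*Wi) + i*(xr*Wi + xi*Wr). A quick numerical check of that identity (all tensor names here are illustrative):

import torch

b, c, f, h, w = 2, 3, 5, 4, 4
x = torch.randn(b, c, h, w, dtype=torch.complex64)  # stand-in for fft2 output
W = torch.randn(c, f, dtype=torch.complex64)        # stand-in for complex_weight

# one einsum over the complex tensors...
mixed = torch.einsum("bchw,cf->bfhw", x, W)

# ...equals the expanded real/imaginary arithmetic done by hand
xr, xi, Wr, Wi = x.real, x.imag, W.real, W.imag
real = torch.einsum("bchw,cf->bfhw", xr, Wr) - torch.einsum("bchw,cf->bfhw", xi, Wi)
imag = torch.einsum("bchw,cf->bfhw", xr, Wi) + torch.einsum("bchw,cf->bfhw", xi, Wr)

assert torch.allclose(mixed, torch.complex(real, imag), atol=1e-5)
print("complex weight matches the Wr/Wi pair")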