Commit 2cb0c06: [DONE]

Kye committed Mar 26, 2024
1 parent 023d3e2 commit 2cb0c06
Showing 4 changed files with 342 additions and 28 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -4,6 +4,36 @@
A simpler PyTorch + Zeta implementation of the paper "SiMBA: Simplified Mamba-based Architecture for Vision and Multivariate Time series".


## Install
`$ pip install simba-torch`

## Usage
```python
import torch
from simba_torch.main import Simba

# Input batch of images
img = torch.randn(1, 3, 224, 224)

# Create model
model = Simba(
    dim=4,            # Dimension of the transformer
    dropout=0.1,      # Dropout rate for regularization
    d_state=64,       # Dimension of the SSM state
    d_conv=64,        # Dimension of the convolutional layers
    num_classes=64,   # Number of output classes
    depth=8,          # Number of transformer layers
    patch_size=16,    # Size of the image patches
    image_size=224,   # Size of the input image
    channels=3,       # Number of input channels
)

# Forward pass
out = model(img)
print(out.shape)
```
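
With the configuration above, `out` is expected to have shape `(1, 64)`, i.e. one logit per class for the single image in the batch, assuming `Simba` ends in a classification head sized by `num_classes`.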


# License
22 changes: 22 additions & 0 deletions example.py
@@ -0,0 +1,22 @@
import torch
from simba_torch.main import Simba

# Input batch of images
img = torch.randn(1, 3, 224, 224)

# Create model
model = Simba(
    dim=4,            # Dimension of the transformer
    dropout=0.1,      # Dropout rate for regularization
    d_state=64,       # Dimension of the SSM state
    d_conv=64,        # Dimension of the convolutional layers
    num_classes=64,   # Number of output classes
    depth=8,          # Number of transformer layers
    patch_size=16,    # Size of the image patches
    image_size=224,   # Size of the input image
    channels=3,       # Number of input channels
)

# Forward pass
out = model(img)
print(out.shape)
95 changes: 95 additions & 0 deletions simba_torch/ein_fft.py
@@ -108,3 +108,98 @@ def forward(self, x: Tensor) -> Tensor:
# output = einfft(x)
# # Print output tensor
# print(output)


class EinFFTText(nn.Module):
    """
    EinFFTText applies the EinFFT operation to a sequence tensor: an FFT
    along the sequence axis, a learnable complex-valued modulation of the
    spectrum, a SiLU non-linearity on the real and imaginary parts, and an
    inverse FFT back to the sequence domain.

    Args:
        sequence_length (int): Length of the input sequence.
        dim (int): Channel dimension of the input tensor.

    Attributes:
        dim (int): Channel dimension of the input tensor.
        act (nn.SiLU): Activation function (SiLU).
        complex_weight (nn.Parameter): Learnable complex-valued weight of
            shape (sequence_length, dim).
        real_weight (nn.Parameter): Learnable real-valued weight of shape
            (sequence_length, dim).
    """

    def __init__(
        self,
        sequence_length: int,
        dim: int,
    ):
        super().__init__()
        self.dim = dim

        # SiLU activation
        self.act = nn.SiLU()

        # complex weights for channel-wise transformation in the frequency domain
        self.complex_weight = nn.Parameter(
            torch.randn(sequence_length, dim, dtype=torch.complex64)
        )

        # real weights for the second element-wise multiplication (EMM)
        self.real_weight = nn.Parameter(torch.randn(sequence_length, dim))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the EinFFTText module.

        Args:
            x (Tensor): Input tensor of shape (batch_size, sequence_length, dimension).

        Returns:
            Tensor: Output tensor of shape (batch_size, sequence_length, dimension).
        """
        b, s, d = x.shape

        # 1D FFT along the sequence axis: transform the input to the frequency domain
        fast_fouried = torch.fft.fft(x, dim=-2)

        # split into real and imaginary parts
        xr = fast_fouried.real
        xi = fast_fouried.imag

        # complex-valued multiplication with the learnable complex weight:
        # (xr + i*xi) * (wr + i*wi), computed on the split parts
        wr = self.complex_weight.real
        wi = self.complex_weight.imag
        new_r = torch.einsum("bsd,sd->bsd", xr, wr) - torch.einsum(
            "bsd,sd->bsd", xi, wi
        )
        new_i = torch.einsum("bsd,sd->bsd", xr, wi) + torch.einsum(
            "bsd,sd->bsd", xi, wr
        )

        # apply SiLU to the real and imaginary parts separately
        real_act = self.act(new_r)
        imag_act = self.act(new_i)

        # element-wise multiplication (EMM) with the real weight
        emmed = torch.complex(
            torch.einsum("bsd,sd->bsd", real_act, self.real_weight),
            torch.einsum("bsd,sd->bsd", imag_act, self.real_weight),
        )

        # inverse FFT back to the sequence domain; return the real part
        iffted = torch.fft.ifft(emmed, dim=-2)
        return iffted.real
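
# A quick sanity check (illustrative, not part of the original module): the
# split real/imaginary einsum path used in forward() should agree with a
# direct broadcast complex multiplication.
_z = torch.randn(2, 3, 4, dtype=torch.complex64)
_w = torch.randn(3, 4, dtype=torch.complex64)
_split = torch.complex(
    _z.real * _w.real - _z.imag * _w.imag,
    _z.real * _w.imag + _z.imag * _w.real,
)
assert torch.allclose(_z * _w, _split, atol=1e-5)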


# Random input tensor of shape (batch, sequence_length, dim)
x = torch.randn(1, 3, 64)

# Instantiate the EinFFTText module with sequence_length=3, dim=64
einfft = EinFFTText(3, 64)

# Apply EinFFTText to get an output
output = einfft(x)
print(output.shape)  # torch.Size([1, 3, 64])
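
# Illustrative sketch (an assumption, not part of this commit): EinFFTText is
# shape-preserving, so it can serve as a drop-in channel-mixing layer inside
# a small pre-norm residual block.
class ToyEinFFTBlock(nn.Module):
    def __init__(self, sequence_length: int, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mixer = EinFFTText(sequence_length, dim)

    def forward(self, x: Tensor) -> Tensor:
        # residual connection around the frequency-domain mixing
        return x + self.mixer(self.norm(x))


block = ToyEinFFTBlock(sequence_length=3, dim=64)
print(block(x).shape)  # torch.Size([1, 3, 64])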