Commit f2a7946
[PIP]
Kye committed Apr 4, 2024
1 parent 75e9345 commit f2a7946
Showing 5 changed files with 65 additions and 29 deletions.
31 changes: 31 additions & 0 deletions README.md
@@ -1,6 +1,37 @@
[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

# MambaFormer
Implementation of MambaFormer in PyTorch and Zeta, from the paper "Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks".

## Install
`pip3 install mamba-former`

## Usage
```python
import torch
from mamba_former.main import MambaFormer

# Example forward pass: a batch of token IDs (tokens are integers)
x = torch.randint(1, 1000, (1, 100))

# Model
model = MambaFormer(
    dim=512,
    num_tokens=1000,
    depth=6,
    d_state=512,
    d_conv=128,
    heads=8,
    dim_head=64,
    return_tokens=True,
)

# Forward
out = model(x)
print(out)
print(out.shape)
```
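
If `return_tokens=True` makes the model emit per-token logits over the vocabulary (an assumption based on the parameter name; check `mamba_former/main.py` to confirm), the printed shape should be `(1, 100, 1000)`, i.e. `(batch, sequence length, num_tokens)`. Continuing the snippet above, a minimal next-token training sketch under that assumption:

```python
import torch.nn.functional as F

# Hypothetical targets; in practice these come from your dataset.
targets = torch.randint(1, 1000, (1, 100))

logits = model(x)  # assumed shape: (batch, seq_len, num_tokens)
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),  # (batch * seq_len, num_tokens)
    targets.reshape(-1),                  # (batch * seq_len,)
)
loss.backward()
```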


# License
23 changes: 23 additions & 0 deletions example.py
@@ -0,0 +1,23 @@
import torch
from mamba_former.main import MambaFormer

# Example forward pass: a batch of token IDs (tokens are integers)
x = torch.randint(1, 1000, (1, 100))

# Model
model = MambaFormer(
    dim=512,
    num_tokens=1000,
    depth=6,
    d_state=512,
    d_conv=128,
    heads=8,
    dim_head=64,
    return_tokens=True,
)

# Forward
out = model(x)
print(out)
print(out.shape)
5 changes: 5 additions & 0 deletions mamba_former/__init__.py
@@ -0,0 +1,5 @@
from mamba_former.main import MambaFormer

__all__ = [
"MambaFormer"
]
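
Because `__init__.py` re-exports the class, it can also be imported from the package root:

```python
# Equivalent to `from mamba_former.main import MambaFormer`
from mamba_former import MambaFormer
```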
21 changes: 0 additions & 21 deletions mamba_former/main.py
@@ -135,24 +135,3 @@ def forward(self, x):
else:
return x


# Example forward pass: a batch of token IDs (tokens are integers)
x = torch.randint(1, 1000, (1, 100))

# Model
model = MambaFormer(
    dim=512,
    num_tokens=1000,
    depth=6,
    d_state=512,
    d_conv=128,
    heads=8,
    dim_head=64,
    return_tokens=True,
)

# Forward
out = model(x)
print(out)
print(out.shape)
14 changes: 6 additions & 8 deletions pyproject.toml
@@ -3,15 +3,15 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "paper"
name = "mamba-former"
version = "0.0.1"
description = "Paper - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
-homepage = "https://github.com/kyegomez/paper"
-documentation = "https://github.com/kyegomez/paper" # Add this if you have documentation.
+homepage = "https://github.com/kyegomez/MambaFormer"
+documentation = "https://github.com/kyegomez/MambaFormer" # Add this if you have documentation.
readme = "README.md" # Assuming you have a README.md
-repository = "https://github.com/kyegomez/paper"
+repository = "https://github.com/kyegomez/MambaFormer"
keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
classifiers = [
"Development Status :: 4 - Beta",
@@ -25,10 +25,8 @@ classifiers = [
python = "^3.6"
swarms = "*"
zetascale = "*"

-[tool.poetry.dev-dependencies]
-# Add development dependencies here
-
+einops = "*"
+torch = "*"

[tool.poetry.group.lint.dependencies]
ruff = "^0.1.6"
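
With this layout, `einops` and `torch` install as regular runtime dependencies alongside `swarms` and `zetascale`, while `ruff` lives in a dedicated `lint` dependency group that Poetry 1.2+ can include or exclude via `poetry install --with lint` / `--without lint`.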
