Skip to content

Commit

Permalink
[BUGFIX][Output head]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Mar 4, 2024
1 parent ad478a9 commit 32352dd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
14 changes: 7 additions & 7 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
from griffin_torch.main import Griffin

# Build a batch of random token ids: batch size 1, sequence length 1000,
# each id drawn from the vocabulary [0, 10000).
# (The scraped diff contained both the old and new versions of these lines;
# duplicate keyword arguments are a SyntaxError — this is the post-commit version.)
x = torch.randint(0, 10000, (1, 1000))

# Model
model = Griffin(
    dim=2048,  # Hidden dimension of the model
    num_tokens=10000,  # Vocabulary size; must cover every id in `x`
    seq_len=1000,  # Input sequence length; must match x.shape[1]
    depth=32,  # Number of Griffin blocks
    mlp_mult=16,  # Multiplier for the hidden dimension in the MLPs
    dropout=0.1,  # Dropout rate
)

# Forward pass — NOTE(review): presumably returns per-token probabilities
# over the vocabulary, shape (1, 1000, 10000); confirm against Griffin.forward.
y = model(x)

print(y.shape)
6 changes: 3 additions & 3 deletions griffin_torch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forward(self, x):
return F.normalize(x, dim=-1) * self.scale * self.g


def output_head(x: Tensor, num_tokens: int, dim: int) -> Tensor:
    """
    Map hidden states to a probability distribution over the vocabulary.

    Applies RMSNorm, then a linear projection from ``dim`` to
    ``num_tokens``, then softmax over the last dimension.

    Args:
        x (Tensor): Input tensor whose last dimension is ``dim``.
        num_tokens (int): Vocabulary size — output width of the projection.
        dim (int): Hidden dimension of ``x``.

    Returns:
        Tensor: Same leading dimensions as ``x``, last dimension
        ``num_tokens``; values sum to 1 along the last dimension.

    NOTE(review): ``RMSNorm(dim)`` and ``nn.Linear(dim, num_tokens)`` are
    constructed fresh on every call, so the projection uses untrained,
    randomly re-initialized weights each time. These layers should be
    created once in the enclosing module; preserved here only to keep the
    existing function interface unchanged.
    """
    # Normalize the hidden states before projecting.
    x = RMSNorm(dim)(x)

    # Project hidden dimension -> vocabulary size.
    x = nn.Linear(dim, num_tokens)(x)

    # Probabilities over the vocabulary (last dimension).
    return F.softmax(x, dim=-1)
Expand Down Expand Up @@ -329,4 +329,4 @@ def forward(self, x: Tensor) -> Tensor:
for layer in self.layers:
x = layer(x) + x

return output_head(x, self.dim)
return output_head(x, self.num_tokens, self.dim)

0 comments on commit 32352dd

Please sign in to comment.