diff --git a/example.py b/example.py
index 1dd23e5..0931eb4 100644
--- a/example.py
+++ b/example.py
@@ -2,19 +2,19 @@
 from griffin_torch.main import Griffin
 
 # Forward pass
-x = torch.randint(0, 100, (1, 10))
+x = torch.randint(0, 10000, (1, 1000))  # Increase the number of tokens
 
 # Model
 model = Griffin(
-    dim=512,  # Dimension of the model
-    num_tokens=100,  # Number of tokens in the input
-    seq_len=10,  # Length of the input sequence
-    depth=8,  # Number of transformer blocks
-    mlp_mult=4,  # Multiplier for the hidden dimension in the MLPs
+    dim=2048,  # Increase the dimension of the model
+    num_tokens=10000,  # Increase the number of tokens in the input
+    seq_len=1000,  # Increase the length of the input sequence
+    depth=32,  # Increase the number of transformer blocks
+    mlp_mult=16,  # Increase the multiplier for the hidden dimension in the MLPs
     dropout=0.1,  # Dropout rate
 )
 
 # Forward pass
 y = model(x)
 
-print(y)
+print(y.shape)
\ No newline at end of file
diff --git a/griffin_torch/main.py b/griffin_torch/main.py
index f624d46..5053582 100644
--- a/griffin_torch/main.py
+++ b/griffin_torch/main.py
@@ -100,7 +100,7 @@ def forward(self, x):
         return F.normalize(x, dim=-1) * self.scale * self.g
 
 
-def output_head(x: Tensor, dim: int):
+def output_head(x: Tensor, num_tokens: int, dim: int):
     """
     Applies a linear transformation followed by softmax activation to the input tensor.
 
@@ -114,7 +114,7 @@ def output_head(x: Tensor, dim: int):
     x = RMSNorm(dim)(x)
 
     # Linear transformation
-    x = nn.Linear(dim, dim)(x)
+    x = nn.Linear(dim, num_tokens)(x)
 
     # Softmax
     return F.softmax(x, dim=-1)
@@ -329,4 +329,4 @@ def forward(self, x: Tensor) -> Tensor:
         for layer in self.layers:
             x = layer(x) + x
 
-        return output_head(x, self.dim)
+        return output_head(x, self.num_tokens, self.dim)
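
Below is a minimal sketch (not part of the patch) of why output_head now projects from dim to num_tokens: the softmax must run over the vocabulary, not the hidden dimension, which is also why the example now prints a shape of (batch, seq_len, num_tokens). The RMSNorm applied inside the repo's output_head is omitted for brevity, and the tensor sizes reuse the small values from the pre-change example.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_tokens, seq_len = 512, 100, 10      # values from the pre-change example.py
hidden = torch.randn(1, seq_len, dim)        # stand-in for hidden states after the last block

logits = nn.Linear(dim, num_tokens)(hidden)  # dim -> num_tokens, as in the patched output_head
probs = F.softmax(logits, dim=-1)            # one probability per token id

print(probs.shape)                           # torch.Size([1, 10, 100])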