Failed to reproduce paper's model sizes. #48

Open
AI-Guru opened this issue Aug 22, 2024 · 2 comments


AI-Guru commented Aug 22, 2024

Hi everyone,

I am most excited about xLSTM. Great and promising work!

Today, I am having trouble reproducing the model sizes from the paper, for example xLSTM[7:1] with 125M trainable parameters.

From the paper, I constructed the following config:

from omegaconf import OmegaConf
from dacite import from_dict
from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig

# Load the config.
config_string = """ 
model:
  vocab_size: 50257
  num_blocks: 24
  embedding_dim: 384
  mlstm_block:
    mlstm:
      num_heads: 4
  slstm_block:
    slstm:
      num_heads: 4
  slstm_at: [3, 20]
  context_length: 2048
"""
config = OmegaConf.create(config_string)

# Create the model.
model_config = from_dict(xLSTMLMModelConfig, OmegaConf.to_container(config.model))
model = xLSTMLMModel(model_config)
print(model_config)
print(model)

# Get the number of parameters.
number_of_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {number_of_parameters:_}")

It yields:

Number of parameters: 60_575_792

This is roughly half of the expected parameters. What did I miss?

Cheers,
Tristan

@PRamoneda

We have the same problem.

@kpoeppel
Collaborator

It should be an embedding dimension of 768. Where did you find the 384?
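
As a rough sanity check on how much the embedding dimension matters, here is a back-of-the-envelope sketch in plain Python. It uses only the vocab_size from the config above; the helper names and the single d x d projection are illustrative assumptions, not the actual xLSTM layer layout.

```python
# Back-of-the-envelope: how parameter counts scale with embedding_dim.
# Assumption: helper names and the d x d projection are illustrative only
# and do not mirror the real xLSTM block structure.

VOCAB_SIZE = 50257  # vocab_size from the config in this issue


def embedding_params(embedding_dim: int, vocab_size: int = VOCAB_SIZE) -> int:
    """Parameters in a token embedding table of shape (vocab_size, embedding_dim)."""
    return vocab_size * embedding_dim


def square_projection_params(embedding_dim: int) -> int:
    """Weights of one bias-free d x d linear projection (hypothetical layer)."""
    return embedding_dim * embedding_dim


for d in (384, 768):
    print(
        f"embedding_dim={d}: "
        f"embedding table={embedding_params(d):_}, "
        f"one d x d projection={square_projection_params(d):_}"
    )
```

Doubling embedding_dim doubles the embedding table and quadruples every square projection, so a model built with 384 instead of 768 would be expected to come out considerably smaller than the paper's figure.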
