Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reset optimizer state is not working #726

Open
ahmeda14960 opened this issue Sep 12, 2024 · 0 comments
Open

Reset optimizer state is not working #726

ahmeda14960 opened this issue Sep 12, 2024 · 0 comments

Comments

@ahmeda14960
Copy link
Contributor

In principle, we could add a flag to reset the optimizer state and then change the following in trainer.py:

if self.config.reset_optimizer_state:
            saveable_train_state = dataclasses.replace(saveable_train_state, optimizer=False)

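However, this causes a cryptic pytree-structure error. Judging from the error message, `optimizer` is stored as static node metadata of `TrainerState` (alongside `is_trainable` and `mp`). Replacing it with `False` therefore changes the tree structure, which then no longer matches `state_shape` when `equinox.filter` is applied in `load_or_init`. The stack trace is below: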

Traceback (most recent call last):
  File "/home/ahmed/levanter/src/levanter/main/train_lm.py", line 222, in <module>
    levanter.config.main(main)()
  File "/home/ahmed/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/ahmed/levanter/src/levanter/main/train_lm.py", line 131, in main
    state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
  File "/home/ahmed/levanter/src/levanter/trainer.py", line 345, in initial_state
    state = load_checkpoint_or_initialize(
  File "/home/ahmed/levanter/src/levanter/checkpoint.py", line 464, in load_or_init
    filtered_state_shape = equinox.filter(state_shape, is_checkpointed)
  File "/home/ahmed/venv310/lib/python3.10/site-packages/equinox/_filters.py", line 129, in filter
    filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
  File "/home/ahmed/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 342, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/home/ahmed/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 342, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Mismatch custom node data: ('step', 'model', 'opt_state', 'training_key'), ('optimizer', 'is_trainable', 'mp'), (False, True, Policy(param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.bfloat16'>, output_dtype=<class 'jax.numpy.bfloat16'>)) != ('step', 'model', 'opt_state', 'training_key'), ('optimizer', 'is_trainable', 'mp'), (GradientTransformationExtraArgs(init=<function inject_hyperparams.<locals>.wrapped_transform.<locals>.init_fn at 0x7ee1a0719bd0>, update=<function inject_hyperparams.<locals>.wrapped_transform.<locals>.update_fn at 0x7ee1a0719c60>), True, Policy(param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.bfloat16'>, output_dtype=<class 'jax.numpy.bfloat16'>)); value: TrainerState(
  step=ShapeDtypeStruct(shape=(), dtype=int32),
  model=LlamaLMHeadModel(
    transformer=LlamaTransformer(
      config=LlamaConfig(
        seq_len=4096,
        hidden_dim=4096,
        intermediate_dim=11008,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        activation_function='silu',
        initializer_range=0.02,
        layer_norm_epsilon=0,
        upcast_attn=False,
        use_flash_attention=True,
        attn_backend=None,
        flash_attention_block_size=None,
        gradient_checkpointing=True,
        gradient_checkpointing_block_size=5,
        scan_layers=True,
        use_bias=False,
        use_layer_norm_weight=False,
        rope_scaling=None,
        rope_theta=10000.0,
        reference_checkpoint='meta-llama/Llama-2-7b-hf',
        tokenizer=None
      ),
      layers=Stacked(
        stacked=LlamaDecoderLayer(
          config=LlamaConfig(
            seq_len=4096,
            hidden_dim=4096,
            intermediate_dim=11008,
            num_layers=32,
            num_heads=32,
            num_kv_heads=32,
            activation_function='silu',
            initializer_range=0.02,
            layer_norm_epsilon=0,
            upcast_attn=False,
            use_flash_attention=True,
            attn_backend=None,
            flash_attention_block_size=None,
            gradient_checkpointing=True,
            gradient_checkpointing_block_size=5,
            scan_layers=True,
            use_bias=False,
            use_layer_norm_weight=False,
            rope_scaling=None,
            rope_theta=10000.0,
            reference_checkpoint='meta-llama/Llama-2-7b-hf',
            tokenizer=None
          ),
          self_attn=LlamaAttention(
            config=LlamaConfig(
              seq_len=4096,
              hidden_dim=4096,
              intermediate_dim=11008,
              num_layers=32,
              num_heads=32,
              num_kv_heads=32,
              activation_function='silu',
              initializer_range=0.02,
              layer_norm_epsilon=0,
              upcast_attn=False,
              use_flash_attention=True,
              attn_backend=None,
              flash_attention_block_size=None,
              gradient_checkpointing=True,
              gradient_checkpointing_block_size=5,
              scan_layers=True,
              use_bias=False,
              use_layer_norm_weight=False,
              rope_scaling=None,
              rope_theta=10000.0,
              reference_checkpoint='meta-llama/Llama-2-7b-hf',
              tokenizer=None
            ),
            q_proj=Linear(
              weight=Named(float32{'layers': 32, 'kv_heads': 32, 'q_heads_per_group': 1, 'head_size': 128, 'embed': 4096}),
              bias=None,
              In=Axis(name='embed', size=4096),
              Out=(
                Axis(name='kv_heads', size=32),
                Axis(name='q_heads_per_group', size=1),
                Axis(name='head_size', size=128)
              ),
              dot_general=DefaultDotGeneralOp()
            ),
            k_proj=Linear(
              weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}),
              bias=None,
              In=Axis(name='embed', size=4096),
              Out=(
                Axis(name='kv_heads', size=32),
                Axis(name='head_size', size=128)
              ),
              dot_general=DefaultDotGeneralOp()
            ),
            v_proj=Linear(
              weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}),
              bias=None,
              In=Axis(name='embed', size=4096),
              Out=(
                Axis(name='kv_heads', size=32),
                Axis(name='head_size', size=128)
              ),
              dot_general=DefaultDotGeneralOp()
            ),
            o_proj=Linear(
              weight=Named(float32{'layers': 32, 'embed': 4096, 'heads': 32, 'head_size': 128}),
              bias=None,
              In=(Axis(name='heads', size=32), Axis(name='head_size', size=128)),
              Out=Axis(name='embed', size=4096),
              dot_general=DefaultDotGeneralOp()
            )
          ),
          mlp=LlamaMlp(
            gate_proj=Linear(
              weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}),
              bias=None,
              In=Axis(name='embed', size=4096),
              Out=Axis(name='mlp', size=11008),
              dot_general=DefaultDotGeneralOp()
            ),
            up_proj=Linear(
              weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}),
              bias=None,
              In=Axis(name='embed', size=4096),
              Out=Axis(name='mlp', size=11008),
              dot_general=DefaultDotGeneralOp()
            ),
            down_proj=Linear(
              weight=Named(float32{'layers': 32, 'embed': 4096, 'mlp': 11008}),
              bias=None,
              In=Axis(name='mlp', size=11008),
              Out=Axis(name='embed', size=4096),
              dot_general=DefaultDotGeneralOp()
            ),
            act=<function silu>
          ),
          input_layernorm=LlamaRMSNorm(
            axis=Axis(name='embed', size=4096),
            weight=None,
            bias=None,
            eps=0,
            dtype=<class 'jax.numpy.float32'>
          ),
          post_attention_layernorm=LlamaRMSNorm(
            axis=Axis(name='embed', size=4096),
            weight=None,
            bias=None,
            eps=0,
            dtype=<class 'jax.numpy.float32'>
          )
        ),
        Block=Axis(name='layers', size=32),
        gradient_checkpointing=True,
        prevent_cse=False
      ),
      norm=LlamaRMSNorm(
        axis=Axis(name='embed', size=4096),
        weight=None,
        bias=None,
        eps=0,
        dtype=<class 'jax.numpy.float32'>
      )
    ),
    embeddings=LlamaEmbedding(
      Vocab=Axis(name='vocab', size=50280),
      token_embeddings=Embedding(
        weight=Named(float32{'vocab': 50280, 'embed': 4096}),
        Vocab=Axis(name='vocab', size=50280),
        Embed=Axis(name='embed', size=4096)
      )
    ),
    lm_head=Linear(
      weight=Named(float32{'vocab': 50280, 'embed': 4096}),
      bias=None,
      In=Axis(name='embed', size=4096),
      Out=Axis(name='vocab', size=50280),
      dot_general=DefaultDotGeneralOp()
    )
  ),
  optimizer=GradientTransformationExtraArgs(
    init=<function init_fn>,
    update=<function update_fn>
  ),
  opt_state=InjectStatefulHyperparamsState(
    count=ShapeDtypeStruct(shape=(), dtype=int32),
    hyperparams={'learning_rate': ShapeDtypeStruct(shape=(), dtype=float32)},
    hyperparams_states={
      'learning_rate':
      WrappedScheduleState(count=ShapeDtypeStruct(shape=(), dtype=int32))
    },
    inner_state=(
      EmptyState(),
      ScaleByAdamState(
        count=ShapeDtypeStruct(shape=(), dtype=int32),
        mu=LlamaLMHeadModel(
          transformer=LlamaTransformer(
            config=LlamaConfig(
              seq_len=4096,
              hidden_dim=4096,
              intermediate_dim=11008,
              num_layers=32,
              num_heads=32,
              num_kv_heads=32,
              activation_function='silu',
              initializer_range=0.02,
              layer_norm_epsilon=0,
              upcast_attn=False,
              use_flash_attention=True,
              attn_backend=None,
              flash_attention_block_size=None,
              gradient_checkpointing=True,
              gradient_checkpointing_block_size=5,
              scan_layers=True,
              use_bias=False,
              use_layer_norm_weight=False,
              rope_scaling=None,
              rope_theta=10000.0,
              reference_checkpoint='meta-llama/Llama-2-7b-hf',
              tokenizer=None
            ),
            layers=Stacked(
              stacked=LlamaDecoderLayer(
                config=LlamaConfig(
                  seq_len=4096,
                  hidden_dim=4096,
                  intermediate_dim=11008,
                  num_layers=32,
                  num_heads=32,
                  num_kv_heads=32,
                  activation_function='silu',
                  initializer_range=0.02,
                  layer_norm_epsilon=0,
                  upcast_attn=False,
                  use_flash_attention=True,
                  attn_backend=None,
                  flash_attention_block_size=None,
                  gradient_checkpointing=True,
                  gradient_checkpointing_block_size=5,
                  scan_layers=True,
                  use_bias=False,
                  use_layer_norm_weight=False,
                  rope_scaling=None,
                  rope_theta=10000.0,
                  reference_checkpoint='meta-llama/Llama-2-7b-hf',
                  tokenizer=None
                ),
                self_attn=LlamaAttention(
                  config=LlamaConfig(
                    seq_len=4096,
                    hidden_dim=4096,
                    intermediate_dim=11008,
                    num_layers=32,
                    num_heads=32,
                    num_kv_heads=32,
                    activation_function='silu',
                    initializer_range=0.02,
                    layer_norm_epsilon=0,
                    upcast_attn=False,
                    use_flash_attention=True,
                    attn_backend=None,
                    flash_attention_block_size=None,
                    gradient_checkpointing=True,
                    gradient_checkpointing_block_size=5,
                    scan_layers=True,
                    use_bias=False,
                    use_layer_norm_weight=False,
                    rope_scaling=None,
                    rope_theta=10000.0,
                    reference_checkpoint='meta-llama/Llama-2-7b-hf',
                    tokenizer=None
                  ),
                  q_proj=Linear(
                    weight=Named(float32{'layers': 32, 'kv_heads': 32, 'q_heads_per_group': 1, 'head_size': 128, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=(
                      Axis(name='kv_heads', size=32),
                      Axis(name='q_heads_per_group', size=1),
                      Axis(name='head_size', size=128)
                    ),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  k_proj=Linear(
                    weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=(
                      Axis(name='kv_heads', size=32),
                      Axis(name='head_size', size=128)
                    ),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  v_proj=Linear(
                    weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=(
                      Axis(name='kv_heads', size=32),
                      Axis(name='head_size', size=128)
                    ),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  o_proj=Linear(
                    weight=Named(float32{'layers': 32, 'embed': 4096, 'heads': 32, 'head_size': 128}),
                    bias=None,
                    In=(
                      Axis(name='heads', size=32),
                      Axis(name='head_size', size=128)
                    ),
                    Out=Axis(name='embed', size=4096),
                    dot_general=DefaultDotGeneralOp()
                  )
                ),
                mlp=LlamaMlp(
                  gate_proj=Linear(
                    weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=Axis(name='mlp', size=11008),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  up_proj=Linear(
                    weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=Axis(name='mlp', size=11008),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  down_proj=Linear(
                    weight=Named(float32{'layers': 32, 'embed': 4096, 'mlp': 11008}),
                    bias=None,
                    In=Axis(name='mlp', size=11008),
                    Out=Axis(name='embed', size=4096),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  act=<function silu>
                ),
                input_layernorm=LlamaRMSNorm(
                  axis=Axis(name='embed', size=4096),
                  weight=None,
                  bias=None,
                  eps=0,
                  dtype=<class 'jax.numpy.float32'>
                ),
                post_attention_layernorm=LlamaRMSNorm(
                  axis=Axis(name='embed', size=4096),
                  weight=None,
                  bias=None,
                  eps=0,
                  dtype=<class 'jax.numpy.float32'>
                )
              ),
              Block=Axis(name='layers', size=32),
              gradient_checkpointing=True,
              prevent_cse=False
            ),
            norm=LlamaRMSNorm(
              axis=Axis(name='embed', size=4096),
              weight=None,
              bias=None,
              eps=0,
              dtype=<class 'jax.numpy.float32'>
            )
          ),
          embeddings=LlamaEmbedding(
            Vocab=Axis(name='vocab', size=50280),
            token_embeddings=Embedding(
              weight=Named(float32{'vocab': 50280, 'embed': 4096}),
              Vocab=Axis(name='vocab', size=50280),
              Embed=Axis(name='embed', size=4096)
            )
          ),
          lm_head=Linear(
            weight=Named(float32{'vocab': 50280, 'embed': 4096}),
            bias=None,
            In=Axis(name='embed', size=4096),
            Out=Axis(name='vocab', size=50280),
            dot_general=DefaultDotGeneralOp()
          )
        ),
        nu=LlamaLMHeadModel(
          transformer=LlamaTransformer(
            config=LlamaConfig(
              seq_len=4096,
              hidden_dim=4096,
              intermediate_dim=11008,
              num_layers=32,
              num_heads=32,
              num_kv_heads=32,
              activation_function='silu',
              initializer_range=0.02,
              layer_norm_epsilon=0,
              upcast_attn=False,
              use_flash_attention=True,
              attn_backend=None,
              flash_attention_block_size=None,
              gradient_checkpointing=True,
              gradient_checkpointing_block_size=5,
              scan_layers=True,
              use_bias=False,
              use_layer_norm_weight=False,
              rope_scaling=None,
              rope_theta=10000.0,
              reference_checkpoint='meta-llama/Llama-2-7b-hf',
              tokenizer=None
            ),
            layers=Stacked(
              stacked=LlamaDecoderLayer(
                config=LlamaConfig(
                  seq_len=4096,
                  hidden_dim=4096,
                  intermediate_dim=11008,
                  num_layers=32,
                  num_heads=32,
                  num_kv_heads=32,
                  activation_function='silu',
                  initializer_range=0.02,
                  layer_norm_epsilon=0,
                  upcast_attn=False,
                  use_flash_attention=True,
                  attn_backend=None,
                  flash_attention_block_size=None,
                  gradient_checkpointing=True,
                  gradient_checkpointing_block_size=5,
                  scan_layers=True,
                  use_bias=False,
                  use_layer_norm_weight=False,
                  rope_scaling=None,
                  rope_theta=10000.0,
                  reference_checkpoint='meta-llama/Llama-2-7b-hf',
                  tokenizer=None
                ),
                self_attn=LlamaAttention(
                  config=LlamaConfig(
                    seq_len=4096,
                    hidden_dim=4096,
                    intermediate_dim=11008,
                    num_layers=32,
                    num_heads=32,
                    num_kv_heads=32,
                    activation_function='silu',
                    initializer_range=0.02,
                    layer_norm_epsilon=0,
                    upcast_attn=False,
                    use_flash_attention=True,
                    attn_backend=None,
                    flash_attention_block_size=None,
                    gradient_checkpointing=True,
                    gradient_checkpointing_block_size=5,
                    scan_layers=True,
                    use_bias=False,
                    use_layer_norm_weight=False,
                    rope_scaling=None,
                    rope_theta=10000.0,
                    reference_checkpoint='meta-llama/Llama-2-7b-hf',
                    tokenizer=None
                  ),
                  q_proj=Linear(
                    weight=Named(float32{'layers': 32, 'kv_heads': 32, 'q_heads_per_group': 1, 'head_size': 128, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=(
                      Axis(name='kv_heads', size=32),
                      Axis(name='q_heads_per_group', size=1),
                      Axis(name='head_size', size=128)
                    ),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  k_proj=Linear(
                    weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=(
                      Axis(name='kv_heads', size=32),
                      Axis(name='head_size', size=128)
                    ),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  v_proj=Linear(
                    weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=(
                      Axis(name='kv_heads', size=32),
                      Axis(name='head_size', size=128)
                    ),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  o_proj=Linear(
                    weight=Named(float32{'layers': 32, 'embed': 4096, 'heads': 32, 'head_size': 128}),
                    bias=None,
                    In=(
                      Axis(name='heads', size=32),
                      Axis(name='head_size', size=128)
                    ),
                    Out=Axis(name='embed', size=4096),
                    dot_general=DefaultDotGeneralOp()
                  )
                ),
                mlp=LlamaMlp(
                  gate_proj=Linear(
                    weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=Axis(name='mlp', size=11008),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  up_proj=Linear(
                    weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}),
                    bias=None,
                    In=Axis(name='embed', size=4096),
                    Out=Axis(name='mlp', size=11008),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  down_proj=Linear(
                    weight=Named(float32{'layers': 32, 'embed': 4096, 'mlp': 11008}),
                    bias=None,
                    In=Axis(name='mlp', size=11008),
                    Out=Axis(name='embed', size=4096),
                    dot_general=DefaultDotGeneralOp()
                  ),
                  act=<function silu>
                ),
                input_layernorm=LlamaRMSNorm(
                  axis=Axis(name='embed', size=4096),
                  weight=None,
                  bias=None,
                  eps=0,
                  dtype=<class 'jax.numpy.float32'>
                ),
                post_attention_layernorm=LlamaRMSNorm(
                  axis=Axis(name='embed', size=4096),
                  weight=None,
                  bias=None,
                  eps=0,
                  dtype=<class 'jax.numpy.float32'>
                )
              ),
              Block=Axis(name='layers', size=32),
              gradient_checkpointing=True,
              prevent_cse=False
            ),
            norm=LlamaRMSNorm(
              axis=Axis(name='embed', size=4096),
              weight=None,
              bias=None,
              eps=0,
              dtype=<class 'jax.numpy.float32'>
            )
          ),
          embeddings=LlamaEmbedding(
            Vocab=Axis(name='vocab', size=50280),
            token_embeddings=Embedding(
              weight=Named(float32{'vocab': 50280, 'embed': 4096}),
              Vocab=Axis(name='vocab', size=50280),
              Embed=Axis(name='embed', size=4096)
            )
          ),
          lm_head=Linear(
            weight=Named(float32{'vocab': 50280, 'embed': 4096}),
            bias=None,
            In=Axis(name='embed', size=4096),
            Out=Axis(name='vocab', size=50280),
            dot_general=DefaultDotGeneralOp()
          )
        )
      ),
      EmptyState()
    )
  ),
  training_key=ShapeDtypeStruct(shape=(2,), dtype=uint32),
  is_trainable=True,
  mp=Policy(
    param_dtype=<class 'jax.numpy.float32'>,
    compute_dtype=<class 'jax.numpy.bfloat16'>,
    output_dtype=<class 'jax.numpy.bfloat16'>
  )
).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant