In principle, we could add a flag to reset the optimizer state and then change the following in trainer.py:

    if self.config.reset_optimizer_state:
        saveable_train_state = dataclasses.replace(saveable_train_state, optimizer=False)
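For concreteness, a minimal sketch of what the flag itself might look like, assuming it lives on the trainer's config dataclass (TrainerConfig here); the name reset_optimizer_state is hypothetical and simply mirrors the snippet above:

    from dataclasses import dataclass

    @dataclass
    class TrainerConfig:
        ...  # existing fields unchanged
        # Hypothetical flag: when True, skip loading the optimizer state from the
        # checkpoint and start the optimizer from scratch.
        reset_optimizer_state: bool = False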
However, the trainer.py change above runs into a cryptic configuration error. The stack trace is below:
Traceback (most recent call last): File "/home/ahmed/levanter/src/levanter/main/train_lm.py", line 222, in <module> levanter.config.main(main)() File "/home/ahmed/levanter/src/levanter/config.py", line 84, in wrapper_inner response = fn(cfg, *args, **kwargs) File "/home/ahmed/levanter/src/levanter/main/train_lm.py", line 131, in main state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) File "/home/ahmed/levanter/src/levanter/trainer.py", line 345, in initial_state state = load_checkpoint_or_initialize( File "/home/ahmed/levanter/src/levanter/checkpoint.py", line 464, in load_or_init filtered_state_shape = equinox.filter(state_shape, is_checkpointed) File "/home/ahmed/venv310/lib/python3.10/site-packages/equinox/_filters.py", line 129, in filter filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree) File "/home/ahmed/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 342, in tree_map all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] File "/home/ahmed/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 342, in <listcomp> all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] ValueError: Mismatch custom node data: ('step', 'model', 'opt_state', 'training_key'), ('optimizer', 'is_trainable', 'mp'), (False, True, Policy(param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.bfloat16'>, output_dtype=<class 'jax.numpy.bfloat16'>)) != ('step', 'model', 'opt_state', 'training_key'), ('optimizer', 'is_trainable', 'mp'), (GradientTransformationExtraArgs(init=<function inject_hyperparams.<locals>.wrapped_transform.<locals>.init_fn at 0x7ee1a0719bd0>, update=<function inject_hyperparams.<locals>.wrapped_transform.<locals>.update_fn at 0x7ee1a0719c60>), True, Policy(param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.bfloat16'>, output_dtype=<class 'jax.numpy.bfloat16'>)); value: TrainerState( step=ShapeDtypeStruct(shape=(), dtype=int32), model=LlamaLMHeadModel( transformer=LlamaTransformer( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), layers=Stacked( stacked=LlamaDecoderLayer( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), self_attn=LlamaAttention( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, 
gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), q_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'q_heads_per_group': 1, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='q_heads_per_group', size=1), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), k_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), v_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), o_proj=Linear( weight=Named(float32{'layers': 32, 'embed': 4096, 'heads': 32, 'head_size': 128}), bias=None, In=(Axis(name='heads', size=32), Axis(name='head_size', size=128)), Out=Axis(name='embed', size=4096), dot_general=DefaultDotGeneralOp() ) ), mlp=LlamaMlp( gate_proj=Linear( weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='mlp', size=11008), dot_general=DefaultDotGeneralOp() ), up_proj=Linear( weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='mlp', size=11008), dot_general=DefaultDotGeneralOp() ), down_proj=Linear( weight=Named(float32{'layers': 32, 'embed': 4096, 'mlp': 11008}), bias=None, In=Axis(name='mlp', size=11008), Out=Axis(name='embed', size=4096), dot_general=DefaultDotGeneralOp() ), act=<function silu> ), input_layernorm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ), post_attention_layernorm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ) ), Block=Axis(name='layers', size=32), gradient_checkpointing=True, prevent_cse=False ), norm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ) ), embeddings=LlamaEmbedding( Vocab=Axis(name='vocab', size=50280), token_embeddings=Embedding( weight=Named(float32{'vocab': 50280, 'embed': 4096}), Vocab=Axis(name='vocab', size=50280), Embed=Axis(name='embed', size=4096) ) ), lm_head=Linear( weight=Named(float32{'vocab': 50280, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='vocab', size=50280), dot_general=DefaultDotGeneralOp() ) ), optimizer=GradientTransformationExtraArgs( init=<function init_fn>, update=<function update_fn> ), opt_state=InjectStatefulHyperparamsState( count=ShapeDtypeStruct(shape=(), dtype=int32), hyperparams={'learning_rate': ShapeDtypeStruct(shape=(), dtype=float32)}, hyperparams_states={ 'learning_rate': WrappedScheduleState(count=ShapeDtypeStruct(shape=(), dtype=int32)) }, inner_state=( EmptyState(), ScaleByAdamState( count=ShapeDtypeStruct(shape=(), dtype=int32), mu=LlamaLMHeadModel( transformer=LlamaTransformer( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, 
upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), layers=Stacked( stacked=LlamaDecoderLayer( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), self_attn=LlamaAttention( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), q_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'q_heads_per_group': 1, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='q_heads_per_group', size=1), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), k_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), v_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), o_proj=Linear( weight=Named(float32{'layers': 32, 'embed': 4096, 'heads': 32, 'head_size': 128}), bias=None, In=( Axis(name='heads', size=32), Axis(name='head_size', size=128) ), Out=Axis(name='embed', size=4096), dot_general=DefaultDotGeneralOp() ) ), mlp=LlamaMlp( gate_proj=Linear( weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='mlp', size=11008), dot_general=DefaultDotGeneralOp() ), up_proj=Linear( weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='mlp', size=11008), dot_general=DefaultDotGeneralOp() ), down_proj=Linear( weight=Named(float32{'layers': 32, 'embed': 4096, 'mlp': 11008}), bias=None, In=Axis(name='mlp', size=11008), Out=Axis(name='embed', size=4096), dot_general=DefaultDotGeneralOp() ), act=<function silu> ), input_layernorm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ), post_attention_layernorm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ) ), Block=Axis(name='layers', size=32), gradient_checkpointing=True, prevent_cse=False ), 
norm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ) ), embeddings=LlamaEmbedding( Vocab=Axis(name='vocab', size=50280), token_embeddings=Embedding( weight=Named(float32{'vocab': 50280, 'embed': 4096}), Vocab=Axis(name='vocab', size=50280), Embed=Axis(name='embed', size=4096) ) ), lm_head=Linear( weight=Named(float32{'vocab': 50280, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='vocab', size=50280), dot_general=DefaultDotGeneralOp() ) ), nu=LlamaLMHeadModel( transformer=LlamaTransformer( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), layers=Stacked( stacked=LlamaDecoderLayer( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), self_attn=LlamaAttention( config=LlamaConfig( seq_len=4096, hidden_dim=4096, intermediate_dim=11008, num_layers=32, num_heads=32, num_kv_heads=32, activation_function='silu', initializer_range=0.02, layer_norm_epsilon=0, upcast_attn=False, use_flash_attention=True, attn_backend=None, flash_attention_block_size=None, gradient_checkpointing=True, gradient_checkpointing_block_size=5, scan_layers=True, use_bias=False, use_layer_norm_weight=False, rope_scaling=None, rope_theta=10000.0, reference_checkpoint='meta-llama/Llama-2-7b-hf', tokenizer=None ), q_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'q_heads_per_group': 1, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='q_heads_per_group', size=1), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), k_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), v_proj=Linear( weight=Named(float32{'layers': 32, 'kv_heads': 32, 'head_size': 128, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=( Axis(name='kv_heads', size=32), Axis(name='head_size', size=128) ), dot_general=DefaultDotGeneralOp() ), o_proj=Linear( weight=Named(float32{'layers': 32, 'embed': 4096, 'heads': 32, 'head_size': 128}), bias=None, In=( Axis(name='heads', size=32), Axis(name='head_size', size=128) ), Out=Axis(name='embed', size=4096), dot_general=DefaultDotGeneralOp() ) ), mlp=LlamaMlp( gate_proj=Linear( weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='mlp', size=11008), dot_general=DefaultDotGeneralOp() ), 
up_proj=Linear( weight=Named(float32{'layers': 32, 'mlp': 11008, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='mlp', size=11008), dot_general=DefaultDotGeneralOp() ), down_proj=Linear( weight=Named(float32{'layers': 32, 'embed': 4096, 'mlp': 11008}), bias=None, In=Axis(name='mlp', size=11008), Out=Axis(name='embed', size=4096), dot_general=DefaultDotGeneralOp() ), act=<function silu> ), input_layernorm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ), post_attention_layernorm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ) ), Block=Axis(name='layers', size=32), gradient_checkpointing=True, prevent_cse=False ), norm=LlamaRMSNorm( axis=Axis(name='embed', size=4096), weight=None, bias=None, eps=0, dtype=<class 'jax.numpy.float32'> ) ), embeddings=LlamaEmbedding( Vocab=Axis(name='vocab', size=50280), token_embeddings=Embedding( weight=Named(float32{'vocab': 50280, 'embed': 4096}), Vocab=Axis(name='vocab', size=50280), Embed=Axis(name='embed', size=4096) ) ), lm_head=Linear( weight=Named(float32{'vocab': 50280, 'embed': 4096}), bias=None, In=Axis(name='embed', size=4096), Out=Axis(name='vocab', size=50280), dot_general=DefaultDotGeneralOp() ) ) ), EmptyState() ) ), training_key=ShapeDtypeStruct(shape=(2,), dtype=uint32), is_trainable=True, mp=Policy( param_dtype=<class 'jax.numpy.float32'>, compute_dtype=<class 'jax.numpy.bfloat16'>, output_dtype=<class 'jax.numpy.bfloat16'> ) ).
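For what it's worth, the mismatch in the error is between the filter spec built by the snippet above and the real TrainerState. Equinox modules keep static fields in the treedef's aux data, and the trace shows that optimizer (together with is_trainable and mp) is one of those static fields, while the optimizer state itself is the opt_state child. Replacing the static optimizer field with False therefore gives the spec a treedef that no longer matches the state's, and the jax.tree_util.tree_map call inside equinox.filter refuses to pair the two trees. Below is a minimal sketch (not Levanter code) that reproduces the same failure mode:

    import dataclasses

    import equinox as eqx
    import jax
    import jax.numpy as jnp


    class State(eqx.Module):
        params: jnp.ndarray
        optimizer: object = eqx.field(static=True)  # stored in the treedef, not as a leaf


    state = State(params=jnp.zeros(3), optimizer="adamw")

    # Build a boolean filter spec with the same structure as `state`, then flip the
    # static field, as the trainer.py snippet does for `optimizer`.
    spec = jax.tree_util.tree_map(lambda _: True, state)
    spec = dataclasses.replace(spec, optimizer=False)

    # Raises "ValueError: Mismatch custom node data: ..." because the static
    # `optimizer` value differs between the spec's treedef and the state's treedef.
    eqx.filter(state, spec)

So the spec can only flip pytree children such as opt_state, not static metadata like optimizer; whether flipping opt_state is what the checkpointing code expects is a separate question.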
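Until there is a real flag, one possible workaround, as a sketch under assumptions rather than an existing Levanter API, is to load the checkpoint as usual and then rebuild the optimizer state from the loaded model. The TrainerState fields used below (model, optimizer, opt_state) all appear in the trace above, but the exact filtering Levanter applies before calling optimizer.init (e.g. restricting to trainable parameters) may differ:

    import dataclasses

    import equinox as eqx

    # `state` is the TrainerState produced by trainer.initial_state(...) after the
    # checkpoint has been loaded with its original optimizer state.
    fresh_opt_state = state.optimizer.init(eqx.filter(state.model, eqx.is_inexact_array))
    state = dataclasses.replace(state, opt_state=fresh_opt_state)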