Error when using optax.Multisteps with optax.contrib.schedule_free #1038
Comments
Are your parameters in bfloat16 or in float32? (I know you've probably checked, but just in case, before digging into this further.) Also pinging the author of this code, @nullstring.
My parameters are in bfloat16.
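The traceback's dtype diff pairs bfloat16 with float32 on every leaf, so one experiment is to keep the parameters (and therefore the optimizer state) in float32. Below is a minimal sketch in plain JAX, assuming `params` is an ordinary pytree of arrays. The toy `params` dict and the helper names `leaf_dtypes` and `upcast_floats` are hypothetical, not part of Optax or Flax.

```python
import jax
import jax.numpy as jnp


def leaf_dtypes(params):
    # Map every array leaf to its dtype name, e.g. "bfloat16" or "float32".
    return jax.tree_util.tree_map(lambda x: str(x.dtype), params)


def upcast_floats(params, dtype=jnp.float32):
    # Upcast floating-point leaves only; integer leaves (counters, ids) are kept.
    return jax.tree_util.tree_map(
        lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params,
    )


# Hypothetical toy parameters standing in for the real model's pytree.
params = {
    "dense": {
        "kernel": jnp.ones((4, 2), dtype=jnp.bfloat16),
        "bias": jnp.zeros((2,), dtype=jnp.bfloat16),
    },
    "step": jnp.array(0, dtype=jnp.int32),
}
print(leaf_dtypes(params))
print(leaf_dtypes(upcast_floats(params)))
```

This is a way to narrow the problem down, not a confirmed fix: float32 parameters cost more memory, and the thread does not establish whether they avoid the error.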
I think it's the same issue as #377 (comment).
I don't know if it's exactly the same bug, but it also happens with:

```python
import optax

learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=0.)
optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)
optimizer = optax.apply_if_finite(optimizer, 5)
```

I got an error saying that `NoneType` can't be cast to `float32`.
Hello,
I get the following exception when I try to wrap a schedule-free optimizer with `optax.MultiSteps`. Can you help me?
Exception message:
```
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ozan/cloud/train/trainer.py", line 116, in train_step
    new_state = state.apply_gradients(grads=grads, train_rngs=new_train_rngs)
  File "/home/ozan/.local/lib/python3.10/site-packages/flax/training/train_state.py", line 101, in apply_gradients
    updates, new_opt_state = self.tx.update(
  File "/home/ozan/.local/lib/python3.10/site-packages/optax/transforms/_accumulation.py", line 380, in update
    new_updates, new_state = lax.cond(
TypeError: true_fun and false_fun output must have identical types, got
({'BERT_0': {'embedding_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'},
 'encoders_0': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'out': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[8,64,512]) vs. ShapedArray(float32[8,64,512])'}, 'query': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'value': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}}, 'encoder_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'}, 'ffn': {'layers_0': {'bias': 'DIFFERENT ShapedArray(bfloat16[2048]) vs. ShapedArray(float32[2048])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])'}, 'layers_3': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])'}}},
 'encoders_1': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'out': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[8,64,512]) vs. ShapedArray(float32[8,64,512])'}, 'query': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'value': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}}, 'encoder_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'}, 'ffn': {'layers_0': {'bias': 'DIFFERENT ShapedArray(bfloat16[2048]) vs. ShapedArray(float32[2048])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])'}, 'layers_3': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])'}}},
 'encoders_2': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'out': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[8,64,512]) vs. ShapedArray(float32[8,64,512])'}, 'query': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'value': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}}, 'encoder_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'}, 'ffn': {'layers_0': {'bias': 'DIFFERENT ShapedArray(bfloat16[2048]) vs. ShapedArray(float32[2048])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])'}, 'layers_3': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])'}}},
 'encoders_3': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'out': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[8,64,512]) vs. ShapedArray(float32[8,64,512])'}, 'query': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'value': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}}, 'encoder_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'}, 'ffn': {'layers_0': {'bias': 'DIFFERENT ShapedArray(bfloat16[2048]) vs. ShapedArray(float32[2048])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])'}, 'layers_3': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])'}}},
 'encoders_4': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'out': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[8,64,512]) vs. ShapedArray(float32[8,64,512])'}, 'query': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}, 'value': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,8,64]) vs. ShapedArray(float32[512,8,64])'}}, 'encoder_layer_norm': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'scale': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])'}, 'ffn': {'layers_0': {'bias': 'DIFFERENT ShapedArray(bfloat16[2048]) vs. ShapedArray(float32[2048])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])'}, 'layers_3': {'bias': 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])', 'kernel': 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])'}}},
 'encoders_5': {'attention': {'key': {'bias': 'DIFFERENT ShapedArray(bfloat16[8,64]) vs. ShapedArray(float32[8,64])'....
```
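The `TypeError` is raised by the `lax.cond` call in `optax/transforms/_accumulation.py` (line 380): JAX requires both branches of a `cond` to return pytrees with identical shapes and dtypes, and here every leaf is bfloat16 in one branch and float32 in the other. Below is a toy reproduction in plain JAX. It assumes, without having checked the Optax source, that one branch returns zeros shaped and typed like the bfloat16 gradients while the other runs an inner optimizer whose float32 learning rate or state promotes the result to float32. `mid_step`, `final_step` and `cast_like` are hypothetical stand-ins, not Optax internals.

```python
import jax
import jax.numpy as jnp


def mid_step(grads):
    # Hypothetical "accumulate only" branch: zero updates that keep the
    # gradients' dtype (bfloat16 here).
    return jax.tree_util.tree_map(jnp.zeros_like, grads)


def final_step(grads):
    # Hypothetical "apply inner optimizer" branch: a strongly typed float32
    # learning rate promotes bfloat16 * float32 to float32.
    lr = jnp.asarray(1e-3, dtype=jnp.float32)
    return jax.tree_util.tree_map(lambda g: -lr * g, grads)


def cast_like(updates, reference):
    # Cast every update leaf back to the dtype of the matching reference leaf.
    return jax.tree_util.tree_map(lambda u, r: u.astype(r.dtype), updates, reference)


@jax.jit
def broken_step(apply_now, grads):
    # Branch outputs differ in dtype, so tracing lax.cond raises TypeError.
    return jax.lax.cond(apply_now, final_step, mid_step, grads)


@jax.jit
def fixed_step(apply_now, grads):
    # Both branches now return bfloat16 leaves, so lax.cond type-checks.
    return jax.lax.cond(
        apply_now, lambda g: cast_like(final_step(g), g), mid_step, grads
    )


grads = {"kernel": jnp.ones((2, 3), dtype=jnp.bfloat16)}
try:
    broken_step(jnp.array(True), grads)
except TypeError as err:
    print("broken_step raised", type(err).__name__)
print(fixed_step(jnp.array(True), grads)["kernel"].dtype)
```

In a real training setup, the analogous change would be to make the inner optimizer's updates come back in the gradients' dtype before `MultiSteps` chooses between its branches, or to keep everything in float32. The toy only shows why the dtypes must agree; it does not establish where the float32 values originate inside `schedule_free`.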