You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
365 self._partitioner = partitioner
366 self.global_train_state_shape = jax.eval_shape(
367 initialize_train_state, rng=jax.random.PRNGKey(0))
--> 368 self.train_state_axes = partitioner.get_mesh_axes(
369 self.global_train_state_shape)
370 self._initialize_train_state = initialize_train_state
372 # Currently scanned layers require passing annotations through to the
373 # point of the scan transformation to resolve an XLA SPMD issue.
374
375 # init_fn is always(?) equal to model.get_initial_variables, fetch the model
376 # instance from the bound method.
File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
--> 892 flat_mesh_axes = {
893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
892 flat_mesh_axes = {
--> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
885 return flax_partitioning.logical_to_mesh_axes(logical_axes,
886 self._logical_axis_rules)
887 except ValueError as e:
--> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x2b751e160>)
`
The text was updated successfully, but these errors were encountered:
I am getting the following error when fine-tuning a LongT5 model:
`
ValueError Traceback (most recent call last)
Input In [16], in <cell line: 21>()
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
19 train_using_gin()
---> 21 gin_utils.run(main_train)
File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main)
103 def run(main):
104 """Wrapper for app.run that rewrites gin args before parsing."""
--> 105 app.run(
106 main,
107 flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))
File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
310 callback()
311 try:
--> 312 _run_main(main, args)
313 except UsageError as error:
314 usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)
File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
256 sys.exit(retval)
257 else:
--> 258 sys.exit(main(argv))
Input In [15], in main_train(argv)
1 def main_train(argv: Sequence[str]):
2 """Wrapper for pdb post mortems."""
----> 3 _main(argv)
Input In [16], in _main(argv)
12 train_using_gin = gin.configurable(train)
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
---> 19 train_using_gin()
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
1604 err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1605 utils.augment_exception_message_and_reraise(e, err_str)
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
39 proxy = ExceptionProxy()
40 ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41 raise proxy.with_traceback(exception.__traceback__) from None
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1579 new_kwargs.update(kwargs)
1581 try:
-> 1582 return fn(*new_args, **new_kwargs)
1583 except Exception as e: # pylint: disable=broad-except
1584 err_str = ''
Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
224 input_types = {
225 k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
226 }
227 init_or_restore_tick = time.time()
--> 228 train_state_initializer = utils.TrainStateInitializer(
229 optimizer_def=model.optimizer_def,
230 init_fn=model.get_initial_variables,
231 input_shapes=input_shapes,
232 input_types=input_types,
233 partitioner=partitioner)
234 # 3. From scratch using `init_fn`.
235 train_state = train_state_initializer.from_checkpoint_or_scratch(
236 restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)
File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
365 self._partitioner = partitioner
366 self.global_train_state_shape = jax.eval_shape(
367 initialize_train_state, rng=jax.random.PRNGKey(0))
--> 368 self.train_state_axes = partitioner.get_mesh_axes(
369 self.global_train_state_shape)
370 self._initialize_train_state = initialize_train_state
372 # Currently scanned layers require passing annotations through to the
373 # point of the scan transformation to resolve an XLA SPMD issue.
374
375 # init_fn is always(?) equal to model.get_initial_variables, fetch the model
376 # instance from the bound method.
File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
--> 892 flat_mesh_axes = {
893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
892 flat_mesh_axes = {
--> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
885 return flax_partitioning.logical_to_mesh_axes(logical_axes,
886 self._logical_axis_rules)
887 except ValueError as e:
--> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x2b751e160>)
`
The text was updated successfully, but these errors were encountered: