Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Failed to map logical axes for target/decoder/logits... #2

Open
ibulu opened this issue Apr 29, 2022 · 0 comments
Open

Failed to map logical axes for target/decoder/logits... #2

ibulu opened this issue Apr 29, 2022 · 0 comments

Comments

@ibulu
Copy link

ibulu commented Apr 29, 2022

I am getting the following error when fine-tuning the longT5 model:

`
ValueError Traceback (most recent call last)
Input In [16], in <cell line: 21>()
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
19 train_using_gin()
---> 21 gin_utils.run(main_train)

File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main)
103 def run(main):
104 """Wrapper for app.run that rewrites gin args before parsing."""
--> 105 app.run(
106 main,
107 flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))

File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
310 callback()
311 try:
--> 312 _run_main(main, args)
313 except UsageError as error:
314 usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)

File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
256 sys.exit(retval)
257 else:
--> 258 sys.exit(main(argv))

Input In [15], in main_train(argv)
1 def main_train(argv: Sequence[str]):
2 """Wrapper for pdb post mortems."""
----> 3 _main(argv)

Input In [16], in _main(argv)
12 train_using_gin = gin.configurable(train)
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
---> 19 train_using_gin()

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
1604 err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1605 utils.augment_exception_message_and_reraise(e, err_str)

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
39 proxy = ExceptionProxy()
40 ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41 raise proxy.with_traceback(exception.__traceback__) from None

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1579 new_kwargs.update(kwargs)
1581 try:
-> 1582 return fn(*new_args, **new_kwargs)
1583 except Exception as e: # pylint: disable=broad-except
1584 err_str = ''

Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
224 input_types = {
225 k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
226 }
227 init_or_restore_tick = time.time()
--> 228 train_state_initializer = utils.TrainStateInitializer(
229 optimizer_def=model.optimizer_def,
230 init_fn=model.get_initial_variables,
231 input_shapes=input_shapes,
232 input_types=input_types,
233 partitioner=partitioner)
234 # 3. From scratch using init_fn.
235 train_state = train_state_initializer.from_checkpoint_or_scratch(
236 restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)

File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.init(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
365 self._partitioner = partitioner
366 self.global_train_state_shape = jax.eval_shape(
367 initialize_train_state, rng=jax.random.PRNGKey(0))
--> 368 self.train_state_axes = partitioner.get_mesh_axes(
369 self.global_train_state_shape)
370 self._initialize_train_state = initialize_train_state
372 # Currently scanned layers require passing annotations through to the
373 # point of the scan transformation to resolve an XLA SPMD issue.
374
375 # init_fn is always(?) equal to model.get_initial_variables, fetch the model
376 # instance from the bound method.

File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
--> 892 flat_mesh_axes = {
893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))

File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
892 flat_mesh_axes = {
--> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))

File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.._logical_to_mesh_axes(param_name, logical_axes)
885 return flax_partitioning.logical_to_mesh_axes(logical_axes,
886 self._logical_axis_rules)
887 except ValueError as e:
--> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e

ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x2b751e160>)

`

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant