[BUG] Training a TFT model on a Mac M1 MPS device fails: element 0 of tensors does not require grad and does not have a grad_fn #1721

zy636 opened this issue Nov 23, 2024 · 2 comments
zy636 commented Nov 23, 2024

Describe the bug

While following the tutorial Demand forecasting with the Temporal Fusion Transformer, I ran into the following issue after:

  1. changing the accelerator of the pl.Trainer call from "cpu" to "mps", and
  2. setting PYTORCH_ENABLE_MPS_FALLBACK=1, because I had first hit this error: The operator 'aten::_embedding_bag' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
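
For reference, a minimal sketch of how I enable the fallback (assumption on my part: the variable should be set before torch is first imported; exporting it in the shell before launching Python also works):

import os

# enable CPU fallback for ops that are not yet implemented on MPS
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # imported only after the variable is set

With the fallback enabled, training then fails in the backward pass:
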
File ~/miniconda3/envs/tft/lib/python3.12/site-packages/torch/autograd/graph.py:825, in _engine_run_backward(t_outputs, *args, **kwargs)
    823     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    824 try:
--> 825     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    826         t_outputs, *args, **kwargs
    827     )  # Calls into the C++ engine to run the backward pass
    828 finally:
    829     if attach_logging_hooks:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

To Reproduce

trainer2 = pl.Trainer(
    max_epochs=50,
    accelerator="mps",  # change from "cpu" to "mps"
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft2 = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches; comment out when using the learning rate finder
    optimizer="ranger",
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft2.size() / 1e3:.1f}k")

trainer2.fit(
    tft2,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
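
As a sanity check (not part of the tutorial; plain torch.backends.mps API), the MPS device itself is usable before the Trainer requests it:

import torch

# both print True on an Apple-silicon build of PyTorch
print(torch.backends.mps.is_built())      # binary was compiled with MPS support
print(torch.backends.mps.is_available())  # MPS device is usable at runtime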

Expected behavior

Training runs on the MPS accelerator just as it does on CPU.

Additional context

trainer2 = pl.Trainer(
   ...:     max_epochs=50,
   ...:     accelerator="mps",
   ...:     enable_model_summary=True,
   ...:     gradient_clip_val=0.1,
   ...:     limit_train_batches=50,  # comment in for training, running validation every 30 batches
   ...:     # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
   ...:     callbacks=[lr_logger, early_stop_callback],
   ...:     logger=logger,
   ...: )
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

In [5]: trainer2.fit(
   ...:     tft2,
   ...:     train_dataloaders=train_dataloader,
   ...:     val_dataloaders=val_dataloader,
   ...: )

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 1.3 K  | train
3  | prescalers                         | ModuleDict                      | 256    | train
4  | static_variable_selection          | VariableSelectionNetwork        | 3.4 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 8.0 K  | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 2.7 K  | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 1.1 K  | train
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 1.1 K  | train
9  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 1.1 K  | train
10 | static_context_enrichment          | GatedResidualNetwork            | 1.1 K  | train
11 | lstm_encoder                       | LSTM                            | 2.2 K  | train
12 | lstm_decoder                       | LSTM                            | 2.2 K  | train
13 | post_lstm_gate_encoder             | GatedLinearUnit                 | 544    | train
14 | post_lstm_add_norm_encoder         | AddNorm                         | 32     | train
15 | static_enrichment                  | GatedResidualNetwork            | 1.4 K  | train
16 | multihead_attn                     | InterpretableMultiHeadAttention | 808    | train
17 | post_attn_gate_norm                | GateAddNorm                     | 576    | train
18 | pos_wise_ff                        | GatedResidualNetwork            | 1.1 K  | train
19 | pre_output_gate_norm               | GateAddNorm                     | 576    | train
20 | output_layer                       | Linear                          | 119    | train
------------------------------------------------------------------------------------------------
29.4 K    Trainable params
0         Non-trainable params
29.4 K    Total params
0.118     Total estimated model params size (MB)
480       Modules in train mode
0         Modules in eval mode
Epoch 0:   0%|                                                                                                                                    | 0/50 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 1
----> 1 trainer2.fit(
      2     tft2,
      3     train_dataloaders=train_dataloader,
      4     val_dataloaders=val_dataloader,
      5 )

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1025, in Trainer._run_stage(self)
   1023         self._run_sanity_check()
   1024     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025         self.fit_loop.run()
   1026     return None
   1027 raise RuntimeError(f"Unexpected state {self.state}")

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
    138 while not self.done:
    139     try:
--> 140         self.advance(data_fetcher)
    141         self.on_advance_end(data_fetcher)
    142         self._restarting = False

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py:250, in _TrainingEpochLoop.advance(self, data_fetcher)
    247 with trainer.profiler.profile("run_training_batch"):
    248     if trainer.lightning_module.automatic_optimization:
    249         # in automatic optimization, there can only be one optimizer
--> 250         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    251     else:
    252         batch_output = self.manual_optimization.run(kwargs)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:190, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    183         closure()
    185 # ------------------------------
    186 # BACKWARD PASS
    187 # ------------------------------
    188 # gradient update with accumulated gradients
    189 else:
--> 190     self._optimizer_step(batch_idx, closure)
    192 result = closure.consume_result()
    193 if result.loss is None:

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:268, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    265     self.optim_progress.optimizer.step.increment_ready()
    267 # model hook
--> 268 call._call_lightning_module_hook(
    269     trainer,
    270     "optimizer_step",
    271     trainer.current_epoch,
    272     batch_idx,
    273     optimizer,
    274     train_step_and_backward_closure,
    275 )
    277 if not should_accumulate:
    278     self.optim_progress.optimizer.step.increment_completed()

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:167, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    164 pl_module._current_fx_name = hook_name
    166 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 167     output = fn(*args, **kwargs)
    169 # restore current_fx when nested context
    170 pl_module._current_fx_name = prev_fx_name

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/core/module.py:1306, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1275 def optimizer_step(
   1276     self,
   1277     epoch: int,
   (...)
   1280     optimizer_closure: Optional[Callable[[], Any]] = None,
   1281 ) -> None:
   1282     r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1283     the optimizer.
   1284 
   (...)
   1304 
   1305     """
-> 1306     optimizer.step(closure=optimizer_closure)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py:153, in LightningOptimizer.step(self, closure, **kwargs)
    150     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    152 assert self._strategy is not None
--> 153 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    155 self._on_after_step()
    157 return step_output

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py:238, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    236 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    237 assert isinstance(model, pl.LightningModule)
--> 238 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py:122, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    120 """Hook to run the optimizer step."""
    121 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 122 return optimizer.step(closure=closure, **kwargs)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    482         else:
    483             raise RuntimeError(
    484                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    485             )
--> 487 out = func(*args, **kwargs)
    488 self._optimizer_step_code()
    490 # call optimizer step post hooks

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/pytorch_optimizer/optimizer/ranger21.py:190, in Ranger21.step(self, closure)
    188 if closure is not None:
    189     with torch.enable_grad():
--> 190         loss = closure()
    192 param_size: int = 0
    193 variance_ma_sum: float = 1.0

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py:108, in Precision._wrap_closure(self, model, optimizer, closure)
     95 def _wrap_closure(
     96     self,
     97     model: "pl.LightningModule",
     98     optimizer: Steppable,
     99     closure: Callable[[], Any],
    100 ) -> Any:
    101     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
    102     hook is called.
    103 
   (...)
    106 
    107     """
--> 108     closure_result = closure()
    109     self._after_closure(model, optimizer)
    110     return closure_result

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:144, in Closure.__call__(self, *args, **kwargs)
    142 @override
    143 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 144     self._result = self.closure(*args, **kwargs)
    145     return self._result.loss

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:138, in Closure.closure(self, *args, **kwargs)
    135     self._zero_grad_fn()
    137 if self._backward_fn is not None and step_output.closure_loss is not None:
--> 138     self._backward_fn(step_output.closure_loss)
    140 return step_output

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py:239, in _AutomaticOptimization._make_backward_fn.<locals>.backward_fn(loss)
    238 def backward_fn(loss: Tensor) -> None:
--> 239     call._call_strategy_hook(self.trainer, "backward", loss, optimizer)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    316     return None
    318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319     output = fn(*args, **kwargs)
    321 # restore current_fx when nested context
    322 pl_module._current_fx_name = prev_fx_name

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py:212, in Strategy.backward(self, closure_loss, optimizer, *args, **kwargs)
    209 assert self.lightning_module is not None
    210 closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module)
--> 212 self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
    214 closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module)
    215 self.post_backward(closure_loss)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py:72, in Precision.backward(self, tensor, model, optimizer, *args, **kwargs)
     52 @override
     53 def backward(  # type: ignore[override]
     54     self,
   (...)
     59     **kwargs: Any,
     60 ) -> None:
     61     r"""Performs the actual backpropagation.
     62 
     63     Args:
   (...)
     70 
     71     """
---> 72     model.backward(tensor, *args, **kwargs)

File ~/miniconda3//lib/python3.12/site-packages/lightning/pytorch/core/module.py:1101, in LightningModule.backward(self, loss, *args, **kwargs)
   1099     self._fabric.backward(loss, *args, **kwargs)
   1100 else:
-> 1101     loss.backward(*args, **kwargs)

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/torch/_tensor.py:581, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    571 if has_torch_function_unary(self):
    572     return handle_torch_function(
    573         Tensor.backward,
    574         (self,),
   (...)
    579         inputs=inputs,
    580     )
--> 581 torch.autograd.backward(
    582     self, gradient, retain_graph, create_graph, inputs=inputs
    583 )

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/torch/autograd/__init__.py:347, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    342     retain_graph = create_graph
    344 # The reason we repeat the same comment below is that
    345 # some Python versions print out the first line of a multi-line function
    346 # calls in the traceback and some print out the last line
--> 347 _engine_run_backward(
    348     tensors,
    349     grad_tensors_,
    350     retain_graph,
    351     create_graph,
    352     inputs,
    353     allow_unreachable=True,
    354     accumulate_grad=True,
    355 )

File ~/miniconda3/envs/tft/lib/python3.12/site-packages/torch/autograd/graph.py:825, in _engine_run_backward(t_outputs, *args, **kwargs)
    823     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    824 try:
--> 825     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    826         t_outputs, *args, **kwargs
    827     )  # Calls into the C++ engine to run the backward pass
    828 finally:
    829     if attach_logging_hooks:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Versions
pytorch 2.5.1 py3.12_0 pytorch
pytorch-forecasting 1.1.1 pypi_0 pypi
pytorch-lightning 2.4.0 pypi_0 pypi
pytorch-optimizer 3.2.0 pypi_0 pypi
torchvision 0.20.1 py312_cpu pytorch

this-josh commented

I'm also experiencing this issue.

fnhirwa (Member) commented Dec 9, 2024

I was able to reproduce this bug.

This seems to be an issue with the loss tensor that is passed to the backward pass being detached from the computation graph, which means it does not require gradients and has no grad_fn. This is probably a downside of the MPS fallback. I am going to open a PR to fix this for the MPS accelerator.
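
To illustrate the failure mode, here is a minimal sketch in plain PyTorch (independent of the TFT internals) that reproduces the same RuntimeError from a detached loss:

import torch

x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum().detach()  # detach drops the grad_fn

print(loss.requires_grad)  # False
loss.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn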
