From 8aa3d3d1851ed14fed5b4419d7926a3267dd9deb Mon Sep 17 00:00:00 2001 From: Jurgis Pods <10235147+jurgispods@users.noreply.github.com> Date: Tue, 22 Aug 2023 12:56:12 +0200 Subject: [PATCH] Fix error when handing over trainer_kwargs to predict_depency() --- pytorch_forecasting/models/base_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 58d1b90e..eec6cfc7 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -1473,13 +1473,14 @@ def predict_dependency( # set values data.set_overwrite_values(variable=variable, values=value, target=target) # predict - kwargs.setdefault("mode", "prediction") + pred_kwargs = deepcopy(kwargs) + pred_kwargs.setdefault("mode", "prediction") if idx == 0 and mode == "dataframe": # need index for returning as dataframe - res = self.predict(data, return_index=True, **kwargs) + res = self.predict(data, return_index=True, **pred_kwargs) results.append(res.output) else: - results.append(self.predict(data, **kwargs)) + results.append(self.predict(data, **pred_kwargs)) # increment progress progress_bar.update()