
Fix error when handing over trainer_kwargs to predict_dependency() #1366

Merged 1 commit into sktime:master on Sep 10, 2023

Conversation

jurgispods
Contributor

Description

This PR fixes an error when handing over trainer_kwargs to predict_dependency(), as in the following example:

model.predict_dependency(ds, "feature", np.linspace(0, 30, 30), mode="dataframe", trainer_kwargs={"accelerator": "cpu"})

Overwriting trainer_kwargs might be necessary when using a model trained on GPU. The issue stems from the kwargs dict handed over to predict_dependency() being mutated inside predict(): a PredictCallback(return_index=True, ...) is appended to the kwargs["trainer_kwargs"]["callbacks"] list and persists there into the second iteration, where return_index should be False. This causes the subsequent error:

File ~/Library/Caches/pypoetry/virtualenvs/discovery-lab-rTrwb4a5-py3.10/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:1491, in BaseModel.predict_dependency(self, data, variable, values, mode, target, show_progress_bar, **kwargs)
   1488 data.reset_overwrite_values()  # reset overwrite values to avoid side-effect
   1490 # results to one tensor
-> 1491 results = torch.stack(results, dim=0)
   1493 # convert results to requested output format
   1494 if mode == "series":

TypeError: expected Tensor as element 1 in argument 0, but got list
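The mechanism behind this error can be reproduced in isolation. The sketch below uses illustrative names (this predict() is a stand-in, not the actual pytorch-forecasting internals): because **kwargs expansion shares the nested trainer_kwargs dict between calls, a callback injected in the first call leaks into the second.

```python
# Minimal sketch of the failure mode (illustrative names, not the
# real pytorch-forecasting internals): predict() mutates the nested
# dict inside the caller-supplied kwargs, so state from the first
# call leaks into the second.

def predict(return_index, **kwargs):
    # predict() injects a callback into the caller's trainer_kwargs
    callbacks = kwargs.setdefault("trainer_kwargs", {}).setdefault("callbacks", [])
    callbacks.append({"return_index": return_index})
    return callbacks


user_kwargs = {"trainer_kwargs": {"accelerator": "cpu"}}

first = predict(return_index=True, **user_kwargs)
second = predict(return_index=False, **user_kwargs)

# The second call still carries the callback from the first one:
print(second)
# [{'return_index': True}, {'return_index': False}]
```

Note that **user_kwargs only shallow-copies the top-level keyword mapping; the nested trainer_kwargs dict is the same object in both calls, which is exactly why the stale callback survives.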

This PR fixes this by copying the user-specified kwargs in every iteration. No tests have been added, as this is a very minor change.

Checklist

  • Linked issues (if existing)
  • Amended changelog for large changes (and added myself there as contributor)
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

Make sure to have fun coding!

@codecov-commenter

codecov-commenter commented Sep 10, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.08% 🎉

Comparison is base (f0ead9e) 90.12% compared to head (8aa3d3d) 90.21%.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1366      +/-   ##
==========================================
+ Coverage   90.12%   90.21%   +0.08%     
==========================================
  Files          30       30              
  Lines        4711     4712       +1     
==========================================
+ Hits         4246     4251       +5     
+ Misses        465      461       -4     
Flag Coverage Δ
cpu 90.21% <100.00%> (+0.08%) ⬆️
pytest 90.21% <100.00%> (+0.08%) ⬆️

Flags with carried forward coverage won't be shown.

Files Changed Coverage Δ
pytorch_forecasting/models/base_model.py 88.19% <100.00%> (+0.01%) ⬆️

... and 1 file with indirect coverage changes


@jdb78 jdb78 merged commit c4b1349 into sktime:master Sep 10, 2023
15 checks passed