From 4df92152e67c6d2db2eda91b061268dd68c4983a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20Akko=C3=A7?= Date: Tue, 10 Dec 2024 19:47:35 +0100 Subject: [PATCH] [DOC] Updated documentation on `TimeSeriesDataSet.predict_mode` (#1720) ### Description Documentation PR. Clarifies use of `predict_mode` for [TimeSeriesDataSet](https://pytorch-forecasting.readthedocs.io/en/stable/api/pytorch_forecasting.data.timeseries.TimeSeriesDataSet.html#). ### Checklist - [x] Linked issues (if existing) - [x] Amended changelog for large changes (and added myself there as contributor) - [x] Added/modified tests - [x] Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with `pre-commit install`. To run hooks independent of commit, execute `pre-commit run --all-files` Make sure to have fun coding! --- pytorch_forecasting/data/timeseries.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 0ae2182f..566c16a6 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -323,10 +323,14 @@ def __init__( distribution. If True, defaults to (0.2, 0.05), i.e. ~1/4 of samples around minimum encoder length. Defaults to False otherwise. - predict_mode (bool): if to only iterate over each timeseries once (only the last provided samples). - Effectively, this will take choose for each time series identified by ``group_ids`` + predict_mode (bool): If True, the TimeSeriesDataSet will only create one sequence + per time series (i.e. only from the latest provided samples). + Effectively, this will select each time series identified by ``group_ids`` the last ``max_prediction_length`` samples of each time series as prediction samples and everthing previous up to ``max_encoder_length`` samples as encoder samples. + If False, the TimeSeriesDataSet will create subsequences by sliding a window over the data samples. + For training use cases, it's preferable to set predict_mode=False to get all subseries. + On the other hand, predict_mode = True is ideal for validation cases. """ super().__init__() self.max_encoder_length = max_encoder_length