
[API] Redesign towards pytorch-forecasting 2.0 #1736

Open
fkiraly opened this issue Dec 20, 2024 · 11 comments

Labels: API design (API design & software architecture), enhancement (New feature or request)

Comments

fkiraly commented Dec 20, 2024

Discussion thread for the API re-design of pytorch-forecasting, for the next 1.X releases and towards 2.0. Comments appreciated from everyone!

Summary of discussion on Dec 20, 2024 and prior, about re-design of pytorch-forecasting.

FYI @agobbifbk, @thawn, @sktime/core-developers.

High-level directions:

High-level features for 2.0 with MoSCoW analysis:

  • M: unified model API which is easily extensible and composable, similar to sktime and DSIPTS, but staying as close to the pytorch level as possible. The API need not cover forecasters in general, only torch-based forecasters.
    • M: unified monitoring and logging API
    • M: extension templates need to be created
    • S: skbase can be used to curate the forecasters as records, with tags, etc
    • S: model persistence
    • C: third party extension patterns, so new models can "live" in other repositories or packages, for instance thuml
  • M: reworked and unified data input API
    • M: support static variables and categoricals
    • S: support for multiple data input locations and formats - pandas, polars, distributed solutions etc
  • M: MLops and benchmarking features as in DSIPTS
  • S: support for pre-training, model hubs, foundation models, but this could be post-2.0

Todos:
  0. update documentation on dsipts to signpost the above (README etc.)
  1. highest priority - consolidated API design for model and data layer.
    • Depending on distance to current ptf and dsipts, use one or the other location for improvements (separate 2.0 -> dsipts, 1.X -> ptf as current).
    • ptf = stable and downwards compatible; dsipts = "playground"
    • first step for that: side-by-side comparisons of code, defined core workflows
  2. planning sessions & sprints from Jan 2025
fkiraly added the enhancement and API design labels and pinned this issue on Dec 20, 2024
fkiraly commented Jan 1, 2025

Having reviewed multiple code bases - pytorch-forecasting, DSIPTS, neuralforecast, thuml - I have come to understand that the DataLoader and DataSet conventions are key, in particular the input convention for forward. Interestingly, all the above-mentioned packages have different conventions here, and none seems satisfactory. What is probably most promising is a "merge" of pytorch-forecasting and DSIPTS.

The model layer will mostly follow the data layer, given that torch has an idiosyncratic forward interface.

My suggestions for high-level requirements on data loaders:

  • easy to use in pure torch, so that a detour via pandas can be avoided (this is currently possible but not easy)
  • support for future-known and unknown, endo- and exogenous, group ID and static variables
  • if possible, downwards compatibility to pytorch-forecasting

Observations of the current API:

  • neither package spells out the forward API properly, nor has checking utilities for the containers.
  • pytorch-forecasting seems to do a resampling for a decoder/encoder structure already in the data loader - this may not be necessary for all models
  • DSIPTS is closer to the abstract data type, but lacks support for static variables

The explicit specifications can be reconstructed from usage and docstrings; they are listed below for convenience:

pytorch-forecasting

From the docstring of TimeSeriesDataSet.to_dataloader:

DataLoader: dataloader that returns Tuple.
    First entry is ``x``, a dictionary of tensors with the entries (and shapes in brackets)

    * encoder_cat (batch_size x n_encoder_time_steps x n_features): long tensor of encoded
        categoricals for encoder
    * encoder_cont (batch_size x n_encoder_time_steps x n_features): float tensor of scaled continuous
        variables for encoder
    * encoder_target (batch_size x n_encoder_time_steps or list thereof with each entry for a different
        target):
        float tensor with unscaled continous target or encoded categorical target,
        list of tensors for multiple targets
    * encoder_lengths (batch_size): long tensor with lengths of the encoder time series. No entry will
        be greater than n_encoder_time_steps
    * decoder_cat (batch_size x n_decoder_time_steps x n_features): long tensor of encoded
        categoricals for decoder
    * decoder_cont (batch_size x n_decoder_time_steps x n_features): float tensor of scaled continuous
        variables for decoder
    * decoder_target (batch_size x n_decoder_time_steps or list thereof with each entry for a different
        target):
        float tensor with unscaled continous target or encoded categorical target for decoder
        - this corresponds to first entry of ``y``, list of tensors for multiple targets
    * decoder_lengths (batch_size): long tensor with lengths of the decoder time series. No entry will
        be greater than n_decoder_time_steps
    * group_ids (batch_size x number_of_ids): encoded group ids that identify a time series in the dataset
    * target_scale (batch_size x scale_size or list thereof with each entry for a different target):
        parameters used to normalize the target.
        Typically these are mean and standard deviation. Is list of tensors for multiple targets.


    Second entry is ``y``, a tuple of the form (``target``, `weight`)

    * target (batch_size x n_decoder_time_steps or list thereof with each entry for a different target):
        unscaled (continuous) or encoded (categories) targets, list of tensors for multiple targets
    * weight (None or batch_size x n_decoder_time_steps): weight

There is a custom collate_fn; it is (oddly?) stored as a static method in TimeSeriesDataSet._collate_fn, which is then passed to the data loader.
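
For illustration, a minimal sketch of consuming this batch format (assuming a TimeSeriesDataSet instance named training has already been constructed; the shapes in the comments restate the docstring above):

    # minimal sketch: inspect one batch from a pytorch-forecasting dataloader
    # assumes `training` is an already-constructed TimeSeriesDataSet (setup omitted)
    dataloader = training.to_dataloader(train=True, batch_size=64)

    x, y = next(iter(dataloader))       # x: dict of tensors, y: (target, weight) tuple
    print(x["encoder_cont"].shape)      # (batch_size, n_encoder_time_steps, n_features)
    print(x["decoder_cont"].shape)      # (batch_size, n_decoder_time_steps, n_features)
    print(x["encoder_lengths"].shape)   # (batch_size,)

    target, weight = y                  # weight may be None
    print(target.shape)                 # (batch_size, n_decoder_time_steps); a list of tensors for multiple targets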

DSIPTS

Specifies a simpler structure that is closer to the abstract data type of the time series data - and therefore imo better.

The data loader needs to return batches as follows, from the docstring of Base.forward:

            batch (dict): the batch structure. The keys are:
                y : the target variable(s). This is always present
                x_num_past: the numerical past variables. This is always present
                x_num_future: the numerical future variables
                x_cat_past: the categorical past variables
                x_cat_future: the categorical future variables
                idx_target: index of target features in the past array

This is missing group IDs and static variables, but imo is closer to the end state we want to reach.
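
To make the contract concrete, a toy sketch of a model whose forward consumes this batch dict (the model itself is hypothetical; only the key names come from the docstring above):

    import torch
    import torch.nn as nn

    class ToyDSIPTSStyleModel(nn.Module):
        """Hypothetical model illustrating a forward that consumes a DSIPTS-style batch dict."""

        def __init__(self, n_num_past: int, horizon: int):
            super().__init__()
            self.head = nn.Linear(n_num_past, horizon)

        def forward(self, batch: dict) -> torch.Tensor:
            x_num_past = batch["x_num_past"]        # (batch, past_steps, n_num_past), always present
            # batch["y"], batch["x_num_future"], batch["x_cat_past"], batch["x_cat_future"]
            # and batch["idx_target"] are available as per the docstring above
            return self.head(x_num_past[:, -1, :])  # toy forecast from the last observed step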

Ensuring downwards compatibility

Downwards compatibility can be ensured by:

  • providing converter functions between the two types of batches (see the sketch after this list). This can be achieved with additional decoder/encoder layers, or a DataLoader depending on another DataLoader.
  • neural networks being tagged with input assumptions on forward. This is probably a good idea in general as well.
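
A rough sketch of such a converter, mapping the pytorch-forecasting (x, y) batch onto the DSIPTS-style dict (key names follow the two docstrings quoted above; scaling, lengths and multi-target handling are deliberately omitted):

    def ptf_batch_to_dsipts_batch(x: dict, y: tuple) -> dict:
        """Sketch only: convert a pytorch-forecasting batch into a DSIPTS-style batch dict."""
        target, _weight = y
        return {
            "y": target,                        # decoder-side target
            "x_num_past": x["encoder_cont"],    # numerical past variables
            "x_num_future": x["decoder_cont"],  # numerical future (known) variables
            "x_cat_past": x["encoder_cat"],     # categorical past variables
            "x_cat_future": x["decoder_cat"],   # categorical future variables
            # "idx_target" would have to be reconstructed from dataset metadata;
            # it is not part of the pytorch-forecasting batch
        }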

Also, currently none of the libraries seems to have stringent tests for the API - we should probably introduce these. scikit-base can be used to draw these up quickly.

fkiraly commented Jan 1, 2025

question for @jdb78 - why did you design the forward API with encoder/decoder specific variables? Personally, I consider this a modelling choice, since not every deep learning forecaster is encoder/decoder based.

Side note: one possible design is to have data loaders that are specific to neural networks, facing a more general API

geetu040 commented Jan 1, 2025

question for @jdb78 - why did you design the forward API with encoder/decoder specific variables? Personally, I consider this a modelling choice, since not every deep learning forecaster is encoder/decoder based.

The forward method is part of the API design of torch.nn.Module, which is the base class for every layer and model in pytorch. So I would say it's not encoder/decoder based or a modelling choice in just this specific context.

fkiraly commented Jan 1, 2025

@geetu040, what I mean is the format of the x in forward, not the choice of forward itself (which indeed is fixed by torch). This can be an arbitrarily nested structure of dict and tuple, with leaf entries being tensors. The convention on the exact structure of x is up to the user, and this is where a core part of the API definition is "hidden" - all listed packages differ in their choices for the type of x that needs to be passed.

So, for pytorch-forecasting, the choice of having decoder/encoder related fields is indeed a choice.
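
A toy illustration of the point - both of the following are perfectly valid nn.Module subclasses, they simply assume different conventions for what forward receives (both classes are made up for illustration):

    import torch
    import torch.nn as nn

    class DictInputNet(nn.Module):
        # pytorch-forecasting-style convention: forward takes a dict of tensors
        def forward(self, x: dict) -> torch.Tensor:
            return x["encoder_cont"].mean(dim=1)

    class FlatInputNet(nn.Module):
        # alternative convention: forward takes flat tensor arguments
        def forward(self, past: torch.Tensor, future_known: torch.Tensor) -> torch.Tensor:
            return past.mean(dim=1) + future_known.mean(dim=1)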

Sohaib-Ahmed21 commented:

@fkiraly what are your thoughts on model initialization in pytorch_forecasting via the from_dataset class method on the model, given that other packages initialize models through the init method?

jdb78 commented Jan 2, 2025

The idea is that basically all models can be represented as encoder/decoder. In some cases they are the same.

fkiraly commented Jan 2, 2025

The idea is that basically all models can be represented as encoder/decoder. In some cases they are the same.

Is that really true though for all models out there? And do we need this as forward args at the top level - as opposed to inside a layer?

See for instance Amazon Chronos:
https://github.com/amazon-science/chronos-forecasting/blob/main/src/chronos/chronos.py

or Google TimesFM:
https://github.com/google-research/timesfm/blob/master/src/timesfm/pytorch_patched_decoder.py

What I think we need for 2.0 is an API design that can cover all torch-based forecasting models.

fkiraly commented Jan 2, 2025

@fkiraly what are your thoughts on model initialization in pytorch_forecasting via the from_dataset class method on the model, given that other packages initialize models through the init method?

I think there is no serious problem with that, as ultimately it calls __init__ which in turn calls the hooks. I have three main feelings here (a small sketch of the pattern follows the list):

  • positive: I think it is a smart idea, since many parameters will be the same for multiple models, given one dataset
  • negative: it complicates the interface, since we are passing information about the model and the dataset in many places.
  • question: I would really like @jdb78's thoughts on why/how you decided to put which parameters where - in TimeSeriesDataSet args, the model __init__, or the forward args, e.g., allowed_encoder_known_variable_names
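
For reference, a minimal, hypothetical sketch of the from_dataset pattern being discussed - dataset-dependent parameters are derived once from the dataset object and then passed to the ordinary __init__ (attribute names such as reals and max_prediction_length are illustrative only):

    import torch.nn as nn

    class MyForecaster(nn.Module):
        """Hypothetical model illustrating the from_dataset idiom."""

        def __init__(self, n_features: int, prediction_length: int, hidden_size: int = 32):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(n_features, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, prediction_length),
            )

        @classmethod
        def from_dataset(cls, dataset, **kwargs):
            # derive dataset-dependent parameters, then call __init__ as usual;
            # user-chosen parameters (e.g. hidden_size) come in via **kwargs
            n_features = len(dataset.reals)                    # illustrative attribute
            prediction_length = dataset.max_prediction_length  # illustrative attribute
            return cls(n_features=n_features, prediction_length=prediction_length, **kwargs)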

Sohaib-Ahmed21 commented Jan 2, 2025

I think there is no serious problem with that as ultimately it calls __init__ which in turn calls the hooks. I have three main feelings here:

  • positive: I think it is a smart idea, since many parameters will be the same for multiple models, given one dataset
  • negative: it complicates the interface, since we are passing information about the model and the dataset in many places.

Yup, the interface complication is the main concern, especially for use cases involving a single model. But yes, the positive and negative sides need a cost-benefit analysis before a final decision.

fkiraly commented Jan 3, 2025

Some further thoughts about the design:

  • I think there should be a DataSet that provides the time series raw, without any transformation or resampling.
    • optimally, this will be decoupled from pandas, using pandas only as one possible source.
    • this should be more similar to DSIPTS
  • I think we should clearly follow the idiomatic DataSet vs DataLoader separation - the usual split is DataSet = sample-level operations and loading; DataLoader = shuffling, batching, etc. (see the minimal sketch after this list)
    • the current pytorch-forecasting does not follow this separation! The data loader mostly just copies what the DataSet does, introducing high coupling between layers that should be separated.
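
A minimal sketch of that idiomatic separation (toy example, not tied to either library): the DataSet only knows how to return one raw series per index, while the DataLoader owns shuffling and batching:

    import torch
    from torch.utils.data import Dataset, DataLoader

    class RawTimeSeriesDataset(Dataset):
        """Sample-level concerns only: hold raw series, return one per index."""

        def __init__(self, series):
            self.series = series  # e.g. a list of tensors of equal length

        def __len__(self):
            return len(self.series)

        def __getitem__(self, idx):
            return self.series[idx]

    # batch-level concerns (shuffling, batching, collation) stay in the DataLoader
    loader = DataLoader(
        RawTimeSeriesDataset([torch.randn(50) for _ in range(8)]),
        batch_size=4,
        shuffle=True,
    )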

I would hence suggest, on the pytorch-forecasting side, a refactor that introduces a clear layer separation but leaves the current interfaces intact until 2.0 (a rough sketch follows after the list):

  • introduction of a DataSet subclass C similar to DSIPTS, close to the data. This can be subclassed for non-memory data sources
    • optionally, there can be subclasses that take sktime data types as lazy arguments. This would greatly facilitate interfacing.
  • introduction of a SlidingDataLoader that unifies the current logic in TimeSeriesDataSet.__getitem__, TimeSeriesDataSet._construct_index and the dataloader returned by TimeSeriesDataSet.to_dataloader. This DataLoader would take C as argument, and the parameters used in the above, and return the same batches as the current TimeSeriesDataSet.
    • for downwards compatibility, it can also take a TimeSeriesDataSet - polymorphism purely to keep existing code working.
  • on the model side, we design the API so that each model comes with its own loader - or loaders. There is a default loader for each model - current pytorch-forecasting models would all point to the loader implied by TimeSeriesDataSet.
    • optionally, we could introduce a composite class closer to sktime, which consists of a loader and a model - one for each model.
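
A rough sketch of how this layering could look; the names C and SlidingDataLoader are taken from the proposal above, everything else (attributes, constructor parameters, subclassing DataLoader rather than composing one) is illustrative and open for debate:

    import torch
    from torch.utils.data import Dataset, DataLoader

    class C(Dataset):
        """Proposed raw-series DataSet: one item per time series, no windowing or resampling."""

        def __init__(self, series):
            # each element could be e.g. {"y": Tensor, "x_num": Tensor, "group_id": int, ...}
            self.series = series

        def __len__(self):
            return len(self.series)

        def __getitem__(self, idx):
            return self.series[idx]

    class SlidingDataLoader(DataLoader):
        """Proposed loader: takes a raw DataSet C plus windowing parameters and yields
        encoder/decoder batches shaped like the current TimeSeriesDataSet output."""

        def __init__(self, dataset, encoder_length: int, decoder_length: int, **kwargs):
            self.encoder_length = encoder_length
            self.decoder_length = decoder_length
            super().__init__(dataset, collate_fn=self._collate, **kwargs)

        def _collate(self, items):
            # the windowing / resampling logic currently spread across
            # TimeSeriesDataSet.__getitem__, _construct_index and to_dataloader
            # would live here (omitted in this sketch)
            raise NotImplementedError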

fkiraly commented Jan 3, 2025

@geetu040, @benHeid, @jdb78, I would appreciate your thoughts and opinions!
