From 18153aa87dd229e4d354a078310a7a4b69466615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 28 Dec 2024 16:55:36 +0100 Subject: [PATCH] [ENH] move tide model - part 2 (#1744) Follows https://github.com/sktime/pytorch-forecasting/pull/1743, adds exports and linting. --- pytorch_forecasting/models/tide/__init__.py | 9 +++++++++ pytorch_forecasting/models/tide/_tide.py | 12 ++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) create mode 100644 pytorch_forecasting/models/tide/__init__.py diff --git a/pytorch_forecasting/models/tide/__init__.py b/pytorch_forecasting/models/tide/__init__.py new file mode 100644 index 00000000..0f265a15 --- /dev/null +++ b/pytorch_forecasting/models/tide/__init__.py @@ -0,0 +1,9 @@ +"""Tide model.""" + +from pytorch_forecasting.models.tide._tide import TiDEModel +from pytorch_forecasting.models.tide.sub_modules import _TideModule + +__all__ = [ + "_TideModule", + "TiDEModel", +] diff --git a/pytorch_forecasting/models/tide/_tide.py b/pytorch_forecasting/models/tide/_tide.py index b8556fbe..101d4a4b 100644 --- a/pytorch_forecasting/models/tide/_tide.py +++ b/pytorch_forecasting/models/tide/_tide.py @@ -1,15 +1,15 @@ -from typing import Dict, List, Optional, Tuple, Union from copy import copy +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.data.encoders import NaNLabelEncoder from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE from pytorch_forecasting.models.base_model import BaseModelWithCovariates from pytorch_forecasting.models.nn.embeddings import MultiEmbedding from pytorch_forecasting.models.tide.sub_modules import _TideModule -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder - -from torch import nn -import torch class TiDEModel(BaseModelWithCovariates):