diff --git a/CHANGELOG.md b/CHANGELOG.md index 94bb83678..7f0392df3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- Add `**kwargs` argument description for models based on `LinearRegression`, `ElasticNet` and `CatBoostRegressor` ([#454](https://github.com/etna-team/etna/pull/454https://github.com/etna-team/etna/pull/454)) -- -- +- Add `**kwargs` argument description for models based on `LinearRegression`, `ElasticNet` and `CatBoostRegressor` ([#454](https://github.com/etna-team/etna/pull/454)) +- Add possibility to load pretrained embedding models ([#461](https://github.com/etna-team/etna/pull/461)) +- Add `is_freezed` parameter to `TS2VecEmbeddingModel` and `TSTCCEmbeddingModel` ([#461](https://github.com/etna-team/etna/pull/461)) - - - diff --git a/etna/transforms/embeddings/models/ts2vec.py b/etna/transforms/embeddings/models/ts2vec.py index fdd4c5279..ea38c0301 100644 --- a/etna/transforms/embeddings/models/ts2vec.py +++ b/etna/transforms/embeddings/models/ts2vec.py @@ -1,8 +1,13 @@ +import os import pathlib import tempfile +import warnings import zipfile +from pathlib import Path +from typing import List from typing import Literal from typing import Optional +from urllib import request import numpy as np @@ -12,6 +17,8 @@ if SETTINGS.torch_required: from etna.libs.ts2vec import TS2Vec +_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "ts2vec" + class TS2VecEmbeddingModel(BaseEmbeddingModel): """TS2Vec embedding model. @@ -39,6 +46,7 @@ def __init__( num_workers: int = 0, max_train_length: Optional[int] = None, temporal_unit: int = 0, + is_freezed: bool = False, ): """Init TS2VecEmbeddingModel. @@ -64,6 +72,8 @@ def __init__( temporal_unit: The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory. + is_freezed: + Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method. Notes ----- In case of long series to reduce memory consumption it is recommended to use max_train_length parameter or manually break the series into smaller subseries. @@ -88,8 +98,10 @@ def __init__( max_train_length=self.max_train_length, temporal_unit=self.temporal_unit, ) + self._is_freezed = is_freezed - self._is_freezed: bool = False + if self._is_freezed: + self.freeze() @property def is_freezed(self): @@ -257,7 +269,7 @@ def save(self, path: pathlib.Path): archive.write(model_save_path, "model.zip") @classmethod - def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel": + def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TS2VecEmbeddingModel": """Load an object. Model's weights are transferred to cpu during loading. @@ -267,11 +279,51 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel": path: Path to load object from. + - if ``path`` is not None and ``model_name`` is None, load the local model from ``path``. + - if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the etna folder in the home directory and load it. If ``path`` exists, external model will not be downloaded. + - if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded. + + model_name: + Name of external model to load. To get list of available models use ``list_models`` method. + Returns ------- : Loaded object. + + Raises + ------ + ValueError: + If none of parameters ``path`` and ``model_name`` are set. + NotImplementedError: + If ``model_name`` isn't from list of available model names. """ + warnings.filterwarnings( + "ignore", + message="The object was saved under etna version 2.7.1 but running version is", + category=UserWarning, + ) + + if model_name is not None: + if path is None: + path = _DOWNLOAD_PATH / f"{model_name}.zip" + if os.path.exists(path): + warnings.warn( + f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model." + ) + else: + Path(path).parent.mkdir(exist_ok=True, parents=True) + + if model_name in cls.list_models(): + url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip" + request.urlretrieve(url=url, filename=path) + else: + raise NotImplementedError( + f"Model {model_name} is not available. To get list of available models use `list_models` method." + ) + elif path is None and model_name is None: + raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.") + obj: TS2VecEmbeddingModel = super().load(path=path) obj.embedding_model = TS2Vec( input_dims=obj.input_dims, @@ -292,3 +344,22 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel": obj.embedding_model.load(fn=str(model_path)) return obj + + @staticmethod + def list_models() -> List[str]: + """ + Return a list of available pretrained models. + + Main information about available models: + + - ts2vec_tiny: + + - Number of parameters - 40k + - Dimension of output embeddings - 16 + + Returns + ------- + : + List of available pretrained models. + """ + return ["ts2vec_tiny"] diff --git a/etna/transforms/embeddings/models/tstcc.py b/etna/transforms/embeddings/models/tstcc.py index f808d393f..f5b0e4665 100644 --- a/etna/transforms/embeddings/models/tstcc.py +++ b/etna/transforms/embeddings/models/tstcc.py @@ -1,7 +1,13 @@ +import os import pathlib import tempfile +import warnings import zipfile +from pathlib import Path +from typing import List from typing import Literal +from typing import Optional +from urllib import request import numpy as np @@ -11,6 +17,8 @@ if SETTINGS.torch_required: from etna.libs.tstcc import TSTCC +_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "tstcc" + class TSTCCEmbeddingModel(BaseEmbeddingModel): """TSTCC embedding model. @@ -49,6 +57,7 @@ def __init__( device: Literal["cpu", "cuda"] = "cpu", batch_size: int = 16, num_workers: int = 0, + is_freezed: bool = False, ): """Init TSTCCEmbeddingModel. @@ -87,6 +96,8 @@ def __init__( The batch size (number of segments in a batch). To swap batch_size, change this attribute. num_workers: How many subprocesses to use for data loading. See (api reference :py:class:`torch.utils.data.DataLoader`). To swap num_workers, change this attribute. + is_freezed: + Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method. """ super().__init__(output_dims=output_dims) self.input_dims = input_dims @@ -125,8 +136,10 @@ def __init__( jitter_ratio=self.jitter_ratio, use_cosine_similarity=self.use_cosine_similarity, ) + self._is_freezed = is_freezed - self._is_freezed: bool = False + if self._is_freezed: + self.freeze() @property def is_freezed(self): @@ -252,7 +265,7 @@ def save(self, path: pathlib.Path): archive.write(model_save_path, "model.zip") @classmethod - def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel": + def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TSTCCEmbeddingModel": """Load an object. Model's weights are transferred to cpu during loading. @@ -262,11 +275,51 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel": path: Path to load object from. + - if ``path`` is not None and ``model_name`` is None, load the local model from ``path``. + - if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the etna folder in the home directory and load it. If ``path`` exists, external model will not be downloaded. + - if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded. + + model_name: + name of external model to load. To get list of available models use ``list_models`` method. + Returns ------- : Loaded object. + + Raises + ------ + ValueError: + If none of parameters ``path`` and ``model_name`` are set. + NotImplementedError: + If ``model_name`` isn't from list of available model names. """ + warnings.filterwarnings( + "ignore", + message="The object was saved under etna version 2.7.1 but running version is", + category=UserWarning, + ) + + if model_name is not None: + if path is None: + path = _DOWNLOAD_PATH / f"{model_name}.zip" + if os.path.exists(path): + warnings.warn( + f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model." + ) + else: + Path(path).parent.mkdir(exist_ok=True, parents=True) + + if model_name in cls.list_models(): + url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/tstcc/{model_name}.zip" + request.urlretrieve(url=url, filename=path) + else: + raise NotImplementedError( + f"Model {model_name} is not available. To get list of available models use `list_models` method." + ) + elif path is None and model_name is None: + raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.") + obj: TSTCCEmbeddingModel = super().load(path=path) obj.embedding_model = TSTCC( input_dims=obj.input_dims, @@ -293,3 +346,22 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel": obj.embedding_model.load(fn=str(model_path)) return obj + + @staticmethod + def list_models() -> List[str]: + """ + Return a list of available pretrained models. + + Main information about available models: + + - tstcc_medium: + + - Number of parameters - 234k + - Dimension of output embeddings - 16 + + Returns + ------- + : + List of available pretrained models. + """ + return ["tstcc_medium"] diff --git a/examples/210-embedding_models.ipynb b/examples/210-embedding_models.ipynb index 4cf141d41..7792d09de 100644 --- a/examples/210-embedding_models.ipynb +++ b/examples/210-embedding_models.ipynb @@ -24,7 +24,8 @@ " * [Baseline](#section_2_1)\n", " * [EmbeddingSegmentTransform](#section_2_2)\n", " * [EmbeddingWindowTransform](#section_2_3)\n", - "* [Saving and loading models](#chapter3)" + "* [Saving and loading models](#chapter3)\n", + "* [Loading external pretrained models](#chapter4)" ] }, { @@ -657,8 +658,7 @@ "from etna.datasets import load_dataset\n", "\n", "ts = load_dataset(\"m3_monthly\")\n", - "ts.drop_features(features=[\"origin_timestamp\"])\n", - "ts.df_exog = None\n", + "ts = TSDataset(ts.to_pandas(features=[\"target\"]), freq=None)\n", "ts.head()" ] }, @@ -676,19 +676,19 @@ "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 4.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 8.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 13.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 13.1s finished\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 9.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 14.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 14.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.4s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.8s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.1s finished\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.2s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" ] } ], @@ -755,14 +755,16 @@ "from etna.transforms.embeddings.models import BaseEmbeddingModel\n", "\n", "\n", - "def forecast_with_segment_embeddings(emb_model: BaseEmbeddingModel, training_params: dict) -> float:\n", + "def forecast_with_segment_embeddings(\n", + " emb_model: BaseEmbeddingModel, training_params: dict = {}, n_folds: int = 3\n", + ") -> float:\n", " model = CatBoostMultiSegmentModel()\n", "\n", " emb_transform = EmbeddingSegmentTransform(\n", " in_columns=[\"target\"], embedding_model=emb_model, training_params=training_params, out_column=\"emb\"\n", " )\n", " pipeline = Pipeline(model=model, transforms=[lag_transform, emb_transform], horizon=HORIZON)\n", - " metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=3)\n", + " metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=n_folds)\n", " smape_score = metrics_df[\"SMAPE\"].mean()\n", " return smape_score" ] @@ -802,20 +804,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 34.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.1min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 35.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.2min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 2.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.1s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 2.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], @@ -870,15 +872,15 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 27.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 58.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 26.7s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 53.9s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.4min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.4min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 3.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 4.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 4.1s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.9s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", @@ -987,20 +989,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 53.9s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.8min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.7min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.7min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 45.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.9min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 10.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 20.9s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 31.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 31.6s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 10.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 20.7s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], @@ -1051,20 +1053,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 34.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.2min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 44.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.5min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.3min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.3min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 8.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 17.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 26.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 26.3s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 10.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 20.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], @@ -1213,6 +1215,505 @@ "source": [ "model_loaded.is_freezed" ] + }, + { + "cell_type": "markdown", + "id": "5d5a6f56", + "metadata": {}, + "source": [ + "## 4. Loading external pretrained models \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "565729ea", + "metadata": {}, + "source": [ + "In this section we introduce our pretrained embedding models." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "8d38bf52", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentm1m10m100m101m102m103m104m105m106m107...m90m91m92m93m94m95m96m97m98m99
featuretargettargettargettargettargettargettargettargettargettarget...targettargettargettargettargettargettargettargettargettarget
timestamp
0NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
1NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3NaNNaN4.0329.01341.0319.01419.0462.0921.03118.0...7301.04374.0803.0191.0124.0319.0270.036.0109.038.0
4NaNNaN40.0439.01258.0315.01400.0550.01060.02775.0...13980.03470.0963.0265.0283.0690.0365.031.0158.074.0
\n", + "

5 rows × 366 columns

\n", + "
" + ], + "text/plain": [ + "segment m1 m10 m100 m101 m102 m103 m104 m105 m106 \\\n", + "feature target target target target target target target target target \n", + "timestamp \n", + "0 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", + "1 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", + "2 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", + "3 NaN NaN 4.0 329.0 1341.0 319.0 1419.0 462.0 921.0 \n", + "4 NaN NaN 40.0 439.0 1258.0 315.0 1400.0 550.0 1060.0 \n", + "\n", + "segment m107 ... m90 m91 m92 m93 m94 m95 m96 \\\n", + "feature target ... target target target target target target target \n", + "timestamp ... \n", + "0 NaN ... NaN NaN NaN NaN NaN NaN NaN \n", + "1 NaN ... NaN NaN NaN NaN NaN NaN NaN \n", + "2 NaN ... NaN NaN NaN NaN NaN NaN NaN \n", + "3 3118.0 ... 7301.0 4374.0 803.0 191.0 124.0 319.0 270.0 \n", + "4 2775.0 ... 13980.0 3470.0 963.0 265.0 283.0 690.0 365.0 \n", + "\n", + "segment m97 m98 m99 \n", + "feature target target target \n", + "timestamp \n", + "0 NaN NaN NaN \n", + "1 NaN NaN NaN \n", + "2 NaN NaN NaN \n", + "3 36.0 109.0 38.0 \n", + "4 31.0 158.0 74.0 \n", + "\n", + "[5 rows x 366 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "HORIZON = 12\n", + "\n", + "ts = load_dataset(\"tourism_monthly\")\n", + "ts = TSDataset(ts.to_pandas(features=[\"target\"]), freq=None)\n", + "ts.head()" + ] + }, + { + "cell_type": "markdown", + "id": "70588951", + "metadata": {}, + "source": [ + "Our base pipeline with lags. " + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "3ed0d8d8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 4.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 4.1s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s finished\n" + ] + } + ], + "source": [ + "model = CatBoostMultiSegmentModel()\n", + "\n", + "lag_transform = LagTransform(in_column=\"target\", lags=list(range(HORIZON, HORIZON + 6)), out_column=\"lag\")\n", + "\n", + "pipeline = Pipeline(model=model, transforms=[lag_transform], horizon=HORIZON)\n", + "metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "73c6f34c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SMAPE: 18.80136468764402\n" + ] + } + ], + "source": [ + "print(\"SMAPE: \", metrics_df[\"SMAPE\"].mean())" + ] + }, + { + "cell_type": "markdown", + "id": "64b52daa", + "metadata": {}, + "source": [ + "It is often useful to encode segment by `SegmentEncoderTransform` when using multi-segment models like now." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "56b8f36a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 12.5s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 12.5s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.4s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s finished\n" + ] + } + ], + "source": [ + "from etna.transforms import SegmentEncoderTransform\n", + "\n", + "model = CatBoostMultiSegmentModel()\n", + "\n", + "lag_transform = LagTransform(in_column=\"target\", lags=list(range(HORIZON, HORIZON + 6)), out_column=\"lag\")\n", + "segment_transform = SegmentEncoderTransform()\n", + "\n", + "pipeline = Pipeline(model=model, transforms=[lag_transform, segment_transform], horizon=HORIZON)\n", + "metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "05d839c2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SMAPE: 18.719919206298737\n" + ] + } + ], + "source": [ + "print(\"SMAPE: \", metrics_df[\"SMAPE\"].mean())" + ] + }, + { + "cell_type": "markdown", + "id": "ccda7e72", + "metadata": {}, + "source": [ + "Segment embeddings from `EmbeddingSegmentTransform` can replace `SegmentEncoderTransform`'s feature. The main advantage of using segment embeddings is that you can forecast new segments by your trained pipeline. `SegmentEncoderTransform` can't work with segments that weren't present during training.\n", + "\n", + "To see available embedding models use `list_models` method of `TS2VecEmbeddingModel` or `TSTCCEmbeddingModel`" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "fa270a0a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['ts2vec_tiny']" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "TS2VecEmbeddingModel.list_models()" + ] + }, + { + "cell_type": "markdown", + "id": "04e8575b", + "metadata": {}, + "source": [ + "Let's load `ts2vec_tiny` pretrained model." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "2834d4cb", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 7.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 7.0s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.0s finished\n", + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s finished\n" + ] + } + ], + "source": [ + "emb_model = TS2VecEmbeddingModel.load(path=\"ts2vec_model.zip\", model_name=\"ts2vec_tiny\")\n", + "\n", + "smape_score = forecast_with_segment_embeddings(emb_model, n_folds=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "837728f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SMAPE: 18.436162523492154\n" + ] + } + ], + "source": [ + "print(\"SMAPE: \", smape_score)" + ] + }, + { + "cell_type": "markdown", + "id": "ab62fb69", + "metadata": {}, + "source": [ + "We get better result compared to `SegmentEncoderTransform` and opportunity to use pipeline for new segments." + ] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 2977f6de1..3681d4620 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,6 +209,7 @@ filterwarnings = [ "ignore: Given top_k=.* is less than n_segments=.*. Algo will filter data without Gale-Shapley run.", "ignore: This model doesn't work with exogenous features", "ignore: Some of external objects in input parameters could be not", + "ignore: You haven't set all parameters inside class __init__ method.* 'is_freezed'", # external warnings "ignore: Attribute 'logging_metrics' is an instance of `nn.Module` and is already", "ignore: Attribute 'loss' is an instance of `nn.Module` and is already", diff --git a/tests/test_transforms/test_embeddings/test_models/test_ts2vec.py b/tests/test_transforms/test_embeddings/test_models/test_ts2vec.py index f04f81090..c8c66e321 100644 --- a/tests/test_transforms/test_embeddings/test_models/test_ts2vec.py +++ b/tests/test_transforms/test_embeddings/test_models/test_ts2vec.py @@ -1,4 +1,5 @@ -import pathlib +import os +from pathlib import Path from tempfile import NamedTemporaryFile import numpy as np @@ -33,7 +34,7 @@ def test_encode_window(ts_with_exog_nan_begin_numpy): def test_save(tmp_path): model = TS2VecEmbeddingModel(input_dims=3) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) @@ -41,11 +42,16 @@ def test_save(tmp_path): def test_load(tmp_path): model = TS2VecEmbeddingModel(input_dims=3) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) TS2VecEmbeddingModel.load(path=path) +@pytest.mark.smoke +def test_list_models(): + TS2VecEmbeddingModel.list_models() + + @pytest.mark.parametrize( "output_dims, segment_shape_expected, window_shape_expected", [(2, (5, 2), (5, 10, 2)), (3, (5, 3), (5, 10, 3))] ) @@ -60,7 +66,7 @@ def test_encode_format(ts_with_exog_nan_begin_numpy, output_dims, segment_shape_ def test_encode_pre_fitted(ts_with_exog_nan_begin_numpy, tmp_path): model = TS2VecEmbeddingModel(input_dims=3) model.fit(ts_with_exog_nan_begin_numpy, n_epochs=1) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) model_loaded = TS2VecEmbeddingModel.load(path=path) @@ -77,7 +83,7 @@ def test_not_freeze_fit(ts_with_exog_nan_begin_numpy, tmp_path): model = TS2VecEmbeddingModel(input_dims=3) model.fit(ts_with_exog_nan_begin_numpy, n_epochs=1) model.freeze(is_freezed=False) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) model_loaded = TS2VecEmbeddingModel.load(path=path) @@ -98,7 +104,7 @@ def test_freeze_fit(ts_with_exog_nan_begin_numpy, tmp_path): model = TS2VecEmbeddingModel(input_dims=3) model.fit(ts_with_exog_nan_begin_numpy, n_epochs=1) model.freeze(is_freezed=True) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) model_loaded = TS2VecEmbeddingModel.load(path=path) @@ -138,3 +144,51 @@ def test_logged_loss(ts_with_exog_nan_begin_numpy, verbose, n_epochs, n_lines_ex model.fit(ts_with_exog_nan_begin_numpy, n_epochs=n_epochs, verbose=verbose) check_logged_loss(log_file=file.name, n_lines_expected=n_lines_expected) tslogger.remove(idx) + + +def test_correct_list_models(): + assert TS2VecEmbeddingModel.list_models() == ["ts2vec_tiny"] + + +@pytest.mark.parametrize("model_name", ["ts2vec_tiny"]) +def test_load_pretrained_model_default_path(model_name): + path = Path.home() / ".etna" / "embeddings" / "ts2vec" / f"{model_name}.zip" + path.unlink(missing_ok=True) + _ = TS2VecEmbeddingModel.load(model_name=model_name) + assert os.path.isfile(path) + + +@pytest.mark.parametrize("model_name", ["ts2vec_tiny"]) +def test_load_pretrained_model_exact_path(model_name, tmp_path): + path = Path(tmp_path) / "tmp.zip" + _ = TS2VecEmbeddingModel.load(path=path, model_name=model_name) + assert os.path.isfile(path) + + +def test_load_unknown_pretrained_model(): + model_name = "unknown_model" + with pytest.raises( + NotImplementedError, + match=f"Model {model_name} is not available. To get list of available models use `list_models` method.", + ): + TS2VecEmbeddingModel.load(model_name=model_name) + + +def test_load_set_none_parameters(): + with pytest.raises( + ValueError, match="Both path and model_name are not specified. At least one parameter should be specified." + ): + TS2VecEmbeddingModel.load() + + +def test_warning_existing_path(tmp_path): + model = TS2VecEmbeddingModel(input_dims=1) + path = Path(tmp_path) / "tmp.zip" + model.save(path) + + model_name = "ts2vec_tiny" + with pytest.warns( + UserWarning, + match=f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model.", + ): + TS2VecEmbeddingModel.load(path=path, model_name=model_name) diff --git a/tests/test_transforms/test_embeddings/test_models/test_tstcc.py b/tests/test_transforms/test_embeddings/test_models/test_tstcc.py index cb41439db..e2f52f7a9 100644 --- a/tests/test_transforms/test_embeddings/test_models/test_tstcc.py +++ b/tests/test_transforms/test_embeddings/test_models/test_tstcc.py @@ -1,4 +1,5 @@ -import pathlib +import os +from pathlib import Path from tempfile import NamedTemporaryFile import numpy as np @@ -33,7 +34,7 @@ def test_encode_window(ts_with_exog_nan_begin_numpy): def test_save(tmp_path): model = TSTCCEmbeddingModel(input_dims=3) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) @@ -41,11 +42,16 @@ def test_save(tmp_path): def test_load(tmp_path): model = TSTCCEmbeddingModel(input_dims=3) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) TSTCCEmbeddingModel.load(path=path) +@pytest.mark.smoke +def test_list_models(): + TSTCCEmbeddingModel.list_models() + + @pytest.mark.parametrize( "output_dims, segment_shape_expected, window_shape_expected", [(2, (5, 2), (5, 10, 2)), (3, (5, 3), (5, 10, 3))] ) @@ -60,7 +66,7 @@ def test_encode_format(ts_with_exog_nan_begin_numpy, output_dims, segment_shape_ def test_encode_pre_fitted(ts_with_exog_nan_begin_numpy, tmp_path): model = TSTCCEmbeddingModel(input_dims=3, batch_size=3) model.fit(ts_with_exog_nan_begin_numpy, n_epochs=1) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) model_loaded = TSTCCEmbeddingModel.load(path=path) @@ -77,7 +83,7 @@ def test_not_freeze_fit(ts_with_exog_nan_begin_numpy, tmp_path): model = TSTCCEmbeddingModel(input_dims=3, batch_size=3) model.fit(ts_with_exog_nan_begin_numpy, n_epochs=1) model.freeze(is_freezed=False) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) model_loaded = TSTCCEmbeddingModel.load(path=path) @@ -98,7 +104,7 @@ def test_freeze_fit(ts_with_exog_nan_begin_numpy, tmp_path): model = TSTCCEmbeddingModel(input_dims=3, batch_size=3) model.fit(ts_with_exog_nan_begin_numpy, n_epochs=1) model.freeze(is_freezed=True) - path = pathlib.Path(tmp_path) / "tmp.zip" + path = Path(tmp_path) / "tmp.zip" model.save(path=path) model_loaded = TSTCCEmbeddingModel.load(path=path) @@ -144,3 +150,51 @@ def test_logged_loss(ts_with_exog_nan_begin_numpy, verbose, n_epochs, n_lines_ex model.fit(ts_with_exog_nan_begin_numpy, n_epochs=n_epochs, verbose=verbose) check_logged_loss(log_file=file.name, n_lines_expected=n_lines_expected) tslogger.remove(idx) + + +def test_correct_list_models(): + assert TSTCCEmbeddingModel.list_models() == ["tstcc_medium"] + + +@pytest.mark.parametrize("model_name", ["tstcc_medium"]) +def test_load_pretrained_model_default_path(model_name): + path = Path.home() / ".etna" / "embeddings" / "tstcc" / f"{model_name}.zip" + path.unlink(missing_ok=True) + _ = TSTCCEmbeddingModel.load(path=path, model_name=model_name) + assert os.path.isfile(path) + + +@pytest.mark.parametrize("model_name", ["tstcc_medium"]) +def test_load_pretrained_model_exact_path(model_name, tmp_path): + path = Path(tmp_path) / "tmp.zip" + _ = TSTCCEmbeddingModel.load(path=path, model_name=model_name) + assert os.path.isfile(path) + + +def test_load_unknown_pretrained_model(): + model_name = "unknown_model" + with pytest.raises( + NotImplementedError, + match=f"Model {model_name} is not available. To get list of available models use `list_models` method.", + ): + TSTCCEmbeddingModel.load(model_name=model_name) + + +def test_load_set_none_parameters(): + with pytest.raises( + ValueError, match="Both path and model_name are not specified. At least one parameter should be specified." + ): + TSTCCEmbeddingModel.load() + + +def test_warning_existing_path(tmp_path): + model = TSTCCEmbeddingModel(input_dims=1) + path = Path(tmp_path) / "tmp.zip" + model.save(path) + + model_name = "tstcc_medium" + with pytest.warns( + UserWarning, + match=f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model.", + ): + TSTCCEmbeddingModel.load(path=path, model_name=model_name)