diff --git a/CHANGELOG.md b/CHANGELOG.md index 94bb83678..7f0392df3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- Add `**kwargs` argument description for models based on `LinearRegression`, `ElasticNet` and `CatBoostRegressor` ([#454](https://github.com/etna-team/etna/pull/454https://github.com/etna-team/etna/pull/454)) -- -- +- Add `**kwargs` argument description for models based on `LinearRegression`, `ElasticNet` and `CatBoostRegressor` ([#454](https://github.com/etna-team/etna/pull/454)) +- Add possibility to load pretrained embedding models ([#461](https://github.com/etna-team/etna/pull/461)) +- Add `is_freezed` parameter to `TS2VecEmbeddingModel` and `TSTCCEmbeddingModel` ([#461](https://github.com/etna-team/etna/pull/461)) - - - diff --git a/etna/transforms/embeddings/models/ts2vec.py b/etna/transforms/embeddings/models/ts2vec.py index fdd4c5279..ea38c0301 100644 --- a/etna/transforms/embeddings/models/ts2vec.py +++ b/etna/transforms/embeddings/models/ts2vec.py @@ -1,8 +1,13 @@ +import os import pathlib import tempfile +import warnings import zipfile +from pathlib import Path +from typing import List from typing import Literal from typing import Optional +from urllib import request import numpy as np @@ -12,6 +17,8 @@ if SETTINGS.torch_required: from etna.libs.ts2vec import TS2Vec +_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "ts2vec" + class TS2VecEmbeddingModel(BaseEmbeddingModel): """TS2Vec embedding model. @@ -39,6 +46,7 @@ def __init__( num_workers: int = 0, max_train_length: Optional[int] = None, temporal_unit: int = 0, + is_freezed: bool = False, ): """Init TS2VecEmbeddingModel. @@ -64,6 +72,8 @@ def __init__( temporal_unit: The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory. + is_freezed: + Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method. Notes ----- In case of long series to reduce memory consumption it is recommended to use max_train_length parameter or manually break the series into smaller subseries. @@ -88,8 +98,10 @@ def __init__( max_train_length=self.max_train_length, temporal_unit=self.temporal_unit, ) + self._is_freezed = is_freezed - self._is_freezed: bool = False + if self._is_freezed: + self.freeze() @property def is_freezed(self): @@ -257,7 +269,7 @@ def save(self, path: pathlib.Path): archive.write(model_save_path, "model.zip") @classmethod - def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel": + def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TS2VecEmbeddingModel": """Load an object. Model's weights are transferred to cpu during loading. @@ -267,11 +279,51 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel": path: Path to load object from. + - if ``path`` is not None and ``model_name`` is None, load the local model from ``path``. + - if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the etna folder in the home directory and load it. If ``path`` exists, external model will not be downloaded. + - if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded. + + model_name: + Name of external model to load. To get list of available models use ``list_models`` method. + Returns ------- : Loaded object. + + Raises + ------ + ValueError: + If none of parameters ``path`` and ``model_name`` are set. + NotImplementedError: + If ``model_name`` isn't from list of available model names. """ + warnings.filterwarnings( + "ignore", + message="The object was saved under etna version 2.7.1 but running version is", + category=UserWarning, + ) + + if model_name is not None: + if path is None: + path = _DOWNLOAD_PATH / f"{model_name}.zip" + if os.path.exists(path): + warnings.warn( + f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model." + ) + else: + Path(path).parent.mkdir(exist_ok=True, parents=True) + + if model_name in cls.list_models(): + url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip" + request.urlretrieve(url=url, filename=path) + else: + raise NotImplementedError( + f"Model {model_name} is not available. To get list of available models use `list_models` method." + ) + elif path is None and model_name is None: + raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.") + obj: TS2VecEmbeddingModel = super().load(path=path) obj.embedding_model = TS2Vec( input_dims=obj.input_dims, @@ -292,3 +344,22 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel": obj.embedding_model.load(fn=str(model_path)) return obj + + @staticmethod + def list_models() -> List[str]: + """ + Return a list of available pretrained models. + + Main information about available models: + + - ts2vec_tiny: + + - Number of parameters - 40k + - Dimension of output embeddings - 16 + + Returns + ------- + : + List of available pretrained models. + """ + return ["ts2vec_tiny"] diff --git a/etna/transforms/embeddings/models/tstcc.py b/etna/transforms/embeddings/models/tstcc.py index f808d393f..f5b0e4665 100644 --- a/etna/transforms/embeddings/models/tstcc.py +++ b/etna/transforms/embeddings/models/tstcc.py @@ -1,7 +1,13 @@ +import os import pathlib import tempfile +import warnings import zipfile +from pathlib import Path +from typing import List from typing import Literal +from typing import Optional +from urllib import request import numpy as np @@ -11,6 +17,8 @@ if SETTINGS.torch_required: from etna.libs.tstcc import TSTCC +_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "tstcc" + class TSTCCEmbeddingModel(BaseEmbeddingModel): """TSTCC embedding model. @@ -49,6 +57,7 @@ def __init__( device: Literal["cpu", "cuda"] = "cpu", batch_size: int = 16, num_workers: int = 0, + is_freezed: bool = False, ): """Init TSTCCEmbeddingModel. @@ -87,6 +96,8 @@ def __init__( The batch size (number of segments in a batch). To swap batch_size, change this attribute. num_workers: How many subprocesses to use for data loading. See (api reference :py:class:`torch.utils.data.DataLoader`). To swap num_workers, change this attribute. + is_freezed: + Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method. """ super().__init__(output_dims=output_dims) self.input_dims = input_dims @@ -125,8 +136,10 @@ def __init__( jitter_ratio=self.jitter_ratio, use_cosine_similarity=self.use_cosine_similarity, ) + self._is_freezed = is_freezed - self._is_freezed: bool = False + if self._is_freezed: + self.freeze() @property def is_freezed(self): @@ -252,7 +265,7 @@ def save(self, path: pathlib.Path): archive.write(model_save_path, "model.zip") @classmethod - def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel": + def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TSTCCEmbeddingModel": """Load an object. Model's weights are transferred to cpu during loading. @@ -262,11 +275,51 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel": path: Path to load object from. + - if ``path`` is not None and ``model_name`` is None, load the local model from ``path``. + - if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the etna folder in the home directory and load it. If ``path`` exists, external model will not be downloaded. + - if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded. + + model_name: + name of external model to load. To get list of available models use ``list_models`` method. + Returns ------- : Loaded object. + + Raises + ------ + ValueError: + If none of parameters ``path`` and ``model_name`` are set. + NotImplementedError: + If ``model_name`` isn't from list of available model names. """ + warnings.filterwarnings( + "ignore", + message="The object was saved under etna version 2.7.1 but running version is", + category=UserWarning, + ) + + if model_name is not None: + if path is None: + path = _DOWNLOAD_PATH / f"{model_name}.zip" + if os.path.exists(path): + warnings.warn( + f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model." + ) + else: + Path(path).parent.mkdir(exist_ok=True, parents=True) + + if model_name in cls.list_models(): + url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/tstcc/{model_name}.zip" + request.urlretrieve(url=url, filename=path) + else: + raise NotImplementedError( + f"Model {model_name} is not available. To get list of available models use `list_models` method." + ) + elif path is None and model_name is None: + raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.") + obj: TSTCCEmbeddingModel = super().load(path=path) obj.embedding_model = TSTCC( input_dims=obj.input_dims, @@ -293,3 +346,22 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel": obj.embedding_model.load(fn=str(model_path)) return obj + + @staticmethod + def list_models() -> List[str]: + """ + Return a list of available pretrained models. + + Main information about available models: + + - tstcc_medium: + + - Number of parameters - 234k + - Dimension of output embeddings - 16 + + Returns + ------- + : + List of available pretrained models. + """ + return ["tstcc_medium"] diff --git a/examples/210-embedding_models.ipynb b/examples/210-embedding_models.ipynb index 4cf141d41..7792d09de 100644 --- a/examples/210-embedding_models.ipynb +++ b/examples/210-embedding_models.ipynb @@ -24,7 +24,8 @@ " * [Baseline](#section_2_1)\n", " * [EmbeddingSegmentTransform](#section_2_2)\n", " * [EmbeddingWindowTransform](#section_2_3)\n", - "* [Saving and loading models](#chapter3)" + "* [Saving and loading models](#chapter3)\n", + "* [Loading external pretrained models](#chapter4)" ] }, { @@ -657,8 +658,7 @@ "from etna.datasets import load_dataset\n", "\n", "ts = load_dataset(\"m3_monthly\")\n", - "ts.drop_features(features=[\"origin_timestamp\"])\n", - "ts.df_exog = None\n", + "ts = TSDataset(ts.to_pandas(features=[\"target\"]), freq=None)\n", "ts.head()" ] }, @@ -676,19 +676,19 @@ "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 4.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 8.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 13.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 13.1s finished\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 9.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 14.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 14.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.4s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.8s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.1s finished\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.2s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" ] } ], @@ -755,14 +755,16 @@ "from etna.transforms.embeddings.models import BaseEmbeddingModel\n", "\n", "\n", - "def forecast_with_segment_embeddings(emb_model: BaseEmbeddingModel, training_params: dict) -> float:\n", + "def forecast_with_segment_embeddings(\n", + " emb_model: BaseEmbeddingModel, training_params: dict = {}, n_folds: int = 3\n", + ") -> float:\n", " model = CatBoostMultiSegmentModel()\n", "\n", " emb_transform = EmbeddingSegmentTransform(\n", " in_columns=[\"target\"], embedding_model=emb_model, training_params=training_params, out_column=\"emb\"\n", " )\n", " pipeline = Pipeline(model=model, transforms=[lag_transform, emb_transform], horizon=HORIZON)\n", - " metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=3)\n", + " metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=n_folds)\n", " smape_score = metrics_df[\"SMAPE\"].mean()\n", " return smape_score" ] @@ -802,20 +804,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 34.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.1min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 35.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.2min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 2.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.1s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 2.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], @@ -870,15 +872,15 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 27.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 58.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 26.7s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 53.9s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.4min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.4min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.7s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 3.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 4.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 4.1s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.9s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", @@ -987,20 +989,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 53.9s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.8min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.7min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.7min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 45.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.9min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.8min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 10.0s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 20.9s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 31.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 31.6s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 10.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 20.7s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], @@ -1051,20 +1053,20 @@ "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 34.5s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.2min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 44.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.5min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.3min remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.3min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 8.6s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 17.4s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 26.3s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 26.3s finished\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 10.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 20.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 30.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", - "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", - "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], @@ -1213,6 +1215,505 @@ "source": [ "model_loaded.is_freezed" ] + }, + { + "cell_type": "markdown", + "id": "5d5a6f56", + "metadata": {}, + "source": [ + "## 4. Loading external pretrained models \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "565729ea", + "metadata": {}, + "source": [ + "In this section we introduce our pretrained embedding models." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "8d38bf52", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
segment | \n", + "m1 | \n", + "m10 | \n", + "m100 | \n", + "m101 | \n", + "m102 | \n", + "m103 | \n", + "m104 | \n", + "m105 | \n", + "m106 | \n", + "m107 | \n", + "... | \n", + "m90 | \n", + "m91 | \n", + "m92 | \n", + "m93 | \n", + "m94 | \n", + "m95 | \n", + "m96 | \n", + "m97 | \n", + "m98 | \n", + "m99 | \n", + "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
feature | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "... | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "target | \n", + "
timestamp | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
0 | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "... | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "
1 | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "... | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "
2 | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "... | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "NaN | \n", + "
3 | \n", + "NaN | \n", + "NaN | \n", + "4.0 | \n", + "329.0 | \n", + "1341.0 | \n", + "319.0 | \n", + "1419.0 | \n", + "462.0 | \n", + "921.0 | \n", + "3118.0 | \n", + "... | \n", + "7301.0 | \n", + "4374.0 | \n", + "803.0 | \n", + "191.0 | \n", + "124.0 | \n", + "319.0 | \n", + "270.0 | \n", + "36.0 | \n", + "109.0 | \n", + "38.0 | \n", + "
4 | \n", + "NaN | \n", + "NaN | \n", + "40.0 | \n", + "439.0 | \n", + "1258.0 | \n", + "315.0 | \n", + "1400.0 | \n", + "550.0 | \n", + "1060.0 | \n", + "2775.0 | \n", + "... | \n", + "13980.0 | \n", + "3470.0 | \n", + "963.0 | \n", + "265.0 | \n", + "283.0 | \n", + "690.0 | \n", + "365.0 | \n", + "31.0 | \n", + "158.0 | \n", + "74.0 | \n", + "
5 rows × 366 columns
\n", + "