Skip to content

Commit

Permalink
Add possibility to load pretrained embedding models (#461)
Browse files Browse the repository at this point in the history
* update load method

* add tests, some fixes

* update notebook

* add filterwarning

* fixes

* add freezed parameter

* fix

* check model_name in list_models

* fix

* reformat freeze parameter

* fix tests for tstcc

* add warning ignore

* fix test

* fix notebook

* update changelog

* update changelog

---------

Co-authored-by: Egor Baturin <[email protected]>
  • Loading branch information
egoriyaa and Egor Baturin authored Aug 30, 2024
1 parent b19507d commit b3a48c3
Show file tree
Hide file tree
Showing 7 changed files with 824 additions and 71 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- Add `**kwargs` argument description for models based on `LinearRegression`, `ElasticNet` and `CatBoostRegressor` ([#454](https://github.com/etna-team/etna/pull/454https://github.com/etna-team/etna/pull/454))
-
-
- Add `**kwargs` argument description for models based on `LinearRegression`, `ElasticNet` and `CatBoostRegressor` ([#454](https://github.com/etna-team/etna/pull/454))
- Add possibility to load pretrained embedding models ([#461](https://github.com/etna-team/etna/pull/461))
- Add `is_freezed` parameter to `TS2VecEmbeddingModel` and `TSTCCEmbeddingModel` ([#461](https://github.com/etna-team/etna/pull/461))
-
-
-
Expand Down
75 changes: 73 additions & 2 deletions etna/transforms/embeddings/models/ts2vec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import pathlib
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from urllib import request

import numpy as np

Expand All @@ -12,6 +17,8 @@
if SETTINGS.torch_required:
from etna.libs.ts2vec import TS2Vec

_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "ts2vec"


class TS2VecEmbeddingModel(BaseEmbeddingModel):
"""TS2Vec embedding model.
Expand Down Expand Up @@ -39,6 +46,7 @@ def __init__(
num_workers: int = 0,
max_train_length: Optional[int] = None,
temporal_unit: int = 0,
is_freezed: bool = False,
):
"""Init TS2VecEmbeddingModel.
Expand All @@ -64,6 +72,8 @@ def __init__(
temporal_unit:
The minimum unit to perform temporal contrast. When training on a very long sequence,
this param helps to reduce the cost of time and memory.
is_freezed:
Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method.
Notes
-----
In case of long series to reduce memory consumption it is recommended to use max_train_length parameter or manually break the series into smaller subseries.
Expand All @@ -88,8 +98,10 @@ def __init__(
max_train_length=self.max_train_length,
temporal_unit=self.temporal_unit,
)
self._is_freezed = is_freezed

self._is_freezed: bool = False
if self._is_freezed:
self.freeze()

@property
def is_freezed(self):
Expand Down Expand Up @@ -257,7 +269,7 @@ def save(self, path: pathlib.Path):
archive.write(model_save_path, "model.zip")

@classmethod
def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel":
def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TS2VecEmbeddingModel":
"""Load an object.
Model's weights are transferred to cpu during loading.
Expand All @@ -267,11 +279,51 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel":
path:
Path to load object from.
- if ``path`` is not None and ``model_name`` is None, load the local model from ``path``.
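- if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the default location ``~/.etna/embeddings/ts2vec/{model_name}.zip`` in the home directory and load it. If that file already exists, the external model will not be downloaded and the existing local file is loaded instead.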
- if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded.
model_name:
Name of external model to load. To get list of available models use ``list_models`` method.
Returns
-------
:
Loaded object.
Raises
------
ValueError:
If neither ``path`` nor ``model_name`` is set.
NotImplementedError:
If ``model_name`` isn't from list of available model names.
"""
warnings.filterwarnings(
"ignore",
message="The object was saved under etna version 2.7.1 but running version is",
category=UserWarning,
)

if model_name is not None:
if path is None:
path = _DOWNLOAD_PATH / f"{model_name}.zip"
if os.path.exists(path):
warnings.warn(
f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model."
)
else:
Path(path).parent.mkdir(exist_ok=True, parents=True)

if model_name in cls.list_models():
url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip"
request.urlretrieve(url=url, filename=path)
else:
raise NotImplementedError(
f"Model {model_name} is not available. To get list of available models use `list_models` method."
)
elif path is None and model_name is None:
raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.")

obj: TS2VecEmbeddingModel = super().load(path=path)
obj.embedding_model = TS2Vec(
input_dims=obj.input_dims,
Expand All @@ -292,3 +344,22 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel":
obj.embedding_model.load(fn=str(model_path))

return obj

@staticmethod
def list_models() -> List[str]:
"""
Return a list of available pretrained models.
Main information about available models:
- ts2vec_tiny:
- Number of parameters - 40k
- Dimension of output embeddings - 16
Returns
-------
:
List of available pretrained models.
"""
return ["ts2vec_tiny"]
76 changes: 74 additions & 2 deletions etna/transforms/embeddings/models/tstcc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os
import pathlib
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from urllib import request

import numpy as np

Expand All @@ -11,6 +17,8 @@
if SETTINGS.torch_required:
from etna.libs.tstcc import TSTCC

_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "tstcc"


class TSTCCEmbeddingModel(BaseEmbeddingModel):
"""TSTCC embedding model.
Expand Down Expand Up @@ -49,6 +57,7 @@ def __init__(
device: Literal["cpu", "cuda"] = "cpu",
batch_size: int = 16,
num_workers: int = 0,
is_freezed: bool = False,
):
"""Init TSTCCEmbeddingModel.
Expand Down Expand Up @@ -87,6 +96,8 @@ def __init__(
The batch size (number of segments in a batch). To swap batch_size, change this attribute.
num_workers:
How many subprocesses to use for data loading. See (api reference :py:class:`torch.utils.data.DataLoader`). To swap num_workers, change this attribute.
is_freezed:
Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method.
"""
super().__init__(output_dims=output_dims)
self.input_dims = input_dims
Expand Down Expand Up @@ -125,8 +136,10 @@ def __init__(
jitter_ratio=self.jitter_ratio,
use_cosine_similarity=self.use_cosine_similarity,
)
self._is_freezed = is_freezed

self._is_freezed: bool = False
if self._is_freezed:
self.freeze()

@property
def is_freezed(self):
Expand Down Expand Up @@ -252,7 +265,7 @@ def save(self, path: pathlib.Path):
archive.write(model_save_path, "model.zip")

@classmethod
def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel":
def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TSTCCEmbeddingModel":
"""Load an object.
Model's weights are transferred to cpu during loading.
Expand All @@ -262,11 +275,51 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel":
path:
Path to load object from.
- if ``path`` is not None and ``model_name`` is None, load the local model from ``path``.
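- if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the default location ``~/.etna/embeddings/tstcc/{model_name}.zip`` in the home directory and load it. If that file already exists, the external model will not be downloaded and the existing local file is loaded instead.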
- if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded.
model_name:
Name of external model to load. To get list of available models use ``list_models`` method.
Returns
-------
:
Loaded object.
Raises
------
ValueError:
If neither ``path`` nor ``model_name`` is set.
NotImplementedError:
If ``model_name`` isn't from list of available model names.
"""
warnings.filterwarnings(
"ignore",
message="The object was saved under etna version 2.7.1 but running version is",
category=UserWarning,
)

if model_name is not None:
if path is None:
path = _DOWNLOAD_PATH / f"{model_name}.zip"
if os.path.exists(path):
warnings.warn(
f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model."
)
else:
Path(path).parent.mkdir(exist_ok=True, parents=True)

if model_name in cls.list_models():
url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/tstcc/{model_name}.zip"
request.urlretrieve(url=url, filename=path)
else:
raise NotImplementedError(
f"Model {model_name} is not available. To get list of available models use `list_models` method."
)
elif path is None and model_name is None:
raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.")

obj: TSTCCEmbeddingModel = super().load(path=path)
obj.embedding_model = TSTCC(
input_dims=obj.input_dims,
Expand All @@ -293,3 +346,22 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel":
obj.embedding_model.load(fn=str(model_path))

return obj

@staticmethod
def list_models() -> List[str]:
"""
Return a list of available pretrained models.
Main information about available models:
- tstcc_medium:
- Number of parameters - 234k
- Dimension of output embeddings - 16
Returns
-------
:
List of available pretrained models.
"""
return ["tstcc_medium"]
Loading

0 comments on commit b3a48c3

Please sign in to comment.