Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to load pretrained embedding models #461

Merged
merged 18 commits into from
Aug 30, 2024
75 changes: 73 additions & 2 deletions etna/transforms/embeddings/models/ts2vec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import pathlib
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from urllib import request

import numpy as np

Expand All @@ -12,6 +17,8 @@
if SETTINGS.torch_required:
from etna.libs.ts2vec import TS2Vec

_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "ts2vec"


class TS2VecEmbeddingModel(BaseEmbeddingModel):
"""TS2Vec embedding model.
Expand Down Expand Up @@ -39,6 +46,7 @@ def __init__(
num_workers: int = 0,
max_train_length: Optional[int] = None,
temporal_unit: int = 0,
is_freezed: bool = False,
):
"""Init TS2VecEmbeddingModel.

Expand All @@ -64,6 +72,8 @@ def __init__(
temporal_unit:
The minimum unit to perform temporal contrast. When training on a very long sequence,
this param helps to reduce the cost of time and memory.
is_freezed:
Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method.
Notes
-----
In case of long series to reduce memory consumption it is recommended to use max_train_length parameter or manually break the series into smaller subseries.
Expand All @@ -88,8 +98,10 @@ def __init__(
max_train_length=self.max_train_length,
temporal_unit=self.temporal_unit,
)
self._is_freezed = is_freezed

self._is_freezed: bool = False
if self._is_freezed:
self.freeze()

@property
def is_freezed(self):
Expand Down Expand Up @@ -257,7 +269,7 @@ def save(self, path: pathlib.Path):
archive.write(model_save_path, "model.zip")

@classmethod
def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel":
def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TS2VecEmbeddingModel":
"""Load an object.

Model's weights are transferred to cpu during loading.
Expand All @@ -267,11 +279,51 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel":
path:
Path to load object from.

- if ``path`` is not None and ``model_name`` is None, load the local model from ``path``.
- if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the etna folder in the home directory and load it. If ``path`` exists, external model will not be downloaded.
- if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded.

model_name:
Name of external model to load. To get list of available models use ``list_models`` method.

Returns
-------
:
Loaded object.

Raises
------
ValueError:
If none of parameters ``path`` and ``model_name`` are set.
NotImplementedError:
If ``model_name`` isn't from list of available model names.
"""
warnings.filterwarnings(
"ignore",
message="The object was saved under etna version 2.7.1 but running version is",
category=UserWarning,
)

if model_name is not None:
if path is None:
path = _DOWNLOAD_PATH / f"{model_name}.zip"
if os.path.exists(path):
warnings.warn(
f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model."
)
else:
Path(path).parent.mkdir(exist_ok=True, parents=True)

if model_name in cls.list_models():
url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip"
request.urlretrieve(url=url, filename=path)
else:
raise NotImplementedError(
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
f"Model {model_name} is not available. To get list of available models use `list_models` method."
)
elif path is None and model_name is None:
raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.")

obj: TS2VecEmbeddingModel = super().load(path=path)
obj.embedding_model = TS2Vec(
input_dims=obj.input_dims,
Expand All @@ -292,3 +344,22 @@ def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel":
obj.embedding_model.load(fn=str(model_path))

return obj

@staticmethod
def list_models() -> List[str]:
"""
Return a list of available pretrained models.

Main information about available models:
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved

- ts2vec_tiny:

- Number of parameters - 40k
- Dimension of output embeddings - 16

Returns
-------
:
List of available pretrained models.
"""
return ["ts2vec_tiny"]
76 changes: 74 additions & 2 deletions etna/transforms/embeddings/models/tstcc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os
import pathlib
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from urllib import request

import numpy as np

Expand All @@ -11,6 +17,8 @@
if SETTINGS.torch_required:
from etna.libs.tstcc import TSTCC

_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "tstcc"


class TSTCCEmbeddingModel(BaseEmbeddingModel):
"""TSTCC embedding model.
Expand Down Expand Up @@ -49,6 +57,7 @@ def __init__(
device: Literal["cpu", "cuda"] = "cpu",
batch_size: int = 16,
num_workers: int = 0,
is_freezed: bool = False,
):
"""Init TSTCCEmbeddingModel.

Expand Down Expand Up @@ -87,6 +96,8 @@ def __init__(
The batch size (number of segments in a batch). To swap batch_size, change this attribute.
num_workers:
How many subprocesses to use for data loading. See (api reference :py:class:`torch.utils.data.DataLoader`). To swap num_workers, change this attribute.
is_freezed:
Whether to ``freeze`` model in constructor or not. For more details see ``freeze`` method.
"""
super().__init__(output_dims=output_dims)
self.input_dims = input_dims
Expand Down Expand Up @@ -125,8 +136,10 @@ def __init__(
jitter_ratio=self.jitter_ratio,
use_cosine_similarity=self.use_cosine_similarity,
)
self._is_freezed = is_freezed

self._is_freezed: bool = False
if self._is_freezed:
self.freeze()

@property
def is_freezed(self):
Expand Down Expand Up @@ -252,7 +265,7 @@ def save(self, path: pathlib.Path):
archive.write(model_save_path, "model.zip")

@classmethod
def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel":
def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TSTCCEmbeddingModel":
"""Load an object.

Model's weights are transferred to cpu during loading.
Expand All @@ -262,11 +275,51 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel":
path:
Path to load object from.

- if ``path`` is not None and ``model_name`` is None, load the local model from ``path``.
- if ``path`` is None and ``model_name`` is not None, save the external ``model_name`` model to the etna folder in the home directory and load it. If ``path`` exists, external model will not be downloaded.
- if ``path`` is not None and ``model_name`` is not None, save the external ``model_name`` model to ``path`` and load it. If ``path`` exists, external model will not be downloaded.

model_name:
name of external model to load. To get list of available models use ``list_models`` method.

Returns
-------
:
Loaded object.

Raises
------
ValueError:
If none of parameters ``path`` and ``model_name`` are set.
NotImplementedError:
If ``model_name`` isn't from list of available model names.
"""
warnings.filterwarnings(
"ignore",
message="The object was saved under etna version 2.7.1 but running version is",
category=UserWarning,
)

if model_name is not None:
if path is None:
path = _DOWNLOAD_PATH / f"{model_name}.zip"
if os.path.exists(path):
warnings.warn(
f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model."
)
else:
Path(path).parent.mkdir(exist_ok=True, parents=True)

if model_name in cls.list_models():
url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/tstcc/{model_name}.zip"
request.urlretrieve(url=url, filename=path)
else:
raise NotImplementedError(
f"Model {model_name} is not available. To get list of available models use `list_models` method."
)
elif path is None and model_name is None:
raise ValueError("Both path and model_name are not specified. At least one parameter should be specified.")

obj: TSTCCEmbeddingModel = super().load(path=path)
obj.embedding_model = TSTCC(
input_dims=obj.input_dims,
Expand All @@ -293,3 +346,22 @@ def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel":
obj.embedding_model.load(fn=str(model_path))

return obj

@staticmethod
def list_models() -> List[str]:
"""
Return a list of available pretrained models.

Main information about available models:

- tstcc_medium:

- Number of parameters - 234k
- Dimension of output embeddings - 16

Returns
-------
:
List of available pretrained models.
"""
return ["tstcc_medium"]
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading