Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to load pretrained embedding models #461

Merged
merged 18 commits into from
Aug 30, 2024
54 changes: 53 additions & 1 deletion etna/transforms/embeddings/models/ts2vec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import os
import pathlib
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from urllib import request
from urllib.error import HTTPError

import numpy as np

Expand All @@ -12,6 +18,8 @@
if SETTINGS.torch_required:
from etna.libs.ts2vec import TS2Vec

_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "ts2vec"


class TS2VecEmbeddingModel(BaseEmbeddingModel):
"""TS2Vec embedding model.
Expand Down Expand Up @@ -257,7 +265,7 @@
archive.write(model_save_path, "model.zip")

@classmethod
def load(cls, path: pathlib.Path) -> "TS2VecEmbeddingModel":
def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TS2VecEmbeddingModel":
"""Load an object.

Model's weights are transferred to cpu during loading.
Expand All @@ -267,11 +275,50 @@
path:
Path to load object from.

- if `path` is not None and `model_name` is None, load the local model from `path`.
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
- if `path` is None and `model_name` is not None, save the external `model_name` model to the etna folder in the home directory and load it.
If the model file already exists there, the external model will not be downloaded again and the existing local file is loaded.
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
- if `path` is not None and `model_name` is not None, save the external `model_name` model to `path` and load it.
If `path` already exists, the external model will not be downloaded and the existing local file is loaded.

model_name:
Name of the external model to load. To get the list of available models, use the `list_models` method.
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
:
Loaded object.

Raises
------
ValueError:
If neither `path` nor `model_name` is set.
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
NotImplementedError:
If `model_name` is not in the list of available model names.
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
"""
if model_name is not None:
if path is None:
path = _DOWNLOAD_PATH / f"{model_name}.zip"
if os.path.exists(path):
warnings.warn(

Check warning on line 303 in etna/transforms/embeddings/models/ts2vec.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/ts2vec.py#L300-L303

Added lines #L300 - L303 were not covered by tests
f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model."
)
else:
directory = os.path.dirname(path)

Check warning on line 307 in etna/transforms/embeddings/models/ts2vec.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/ts2vec.py#L307

Added line #L307 was not covered by tests
# If path not in current directory and it doesn't exist
if directory and not os.path.exists(directory):
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
os.makedirs(directory)

Check warning on line 310 in etna/transforms/embeddings/models/ts2vec.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/ts2vec.py#L309-L310

Added lines #L309 - L310 were not covered by tests

try:
url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip"
request.urlretrieve(url=url, filename=path)
except HTTPError:
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(

Check warning on line 316 in etna/transforms/embeddings/models/ts2vec.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/ts2vec.py#L312-L316

Added lines #L312 - L316 were not covered by tests
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
f"Model {model_name} is not available. To get list of available models use `list_models` method."
)
elif path is None and model_name is None:
raise ValueError("Both path and model_name are not specified. Specify one parameter at least.")

Check warning on line 320 in etna/transforms/embeddings/models/ts2vec.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/ts2vec.py#L320

Added line #L320 was not covered by tests
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved

obj: TS2VecEmbeddingModel = super().load(path=path)
obj.embedding_model = TS2Vec(
input_dims=obj.input_dims,
Expand All @@ -292,3 +339,8 @@
obj.embedding_model.load(fn=str(model_path))

return obj

@staticmethod
def list_models() -> List[str]:
"""Return a list of available pretrained models.

The returned names are the values intended for the ``model_name`` argument of ``load``.
"""
return ["ts2vec_tiny"]

Check warning on line 346 in etna/transforms/embeddings/models/ts2vec.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/ts2vec.py#L346

Added line #L346 was not covered by tests
55 changes: 54 additions & 1 deletion etna/transforms/embeddings/models/tstcc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import os
import pathlib
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from urllib import request
from urllib.error import HTTPError

import numpy as np

Expand All @@ -11,6 +18,8 @@
if SETTINGS.torch_required:
from etna.libs.tstcc import TSTCC

_DOWNLOAD_PATH = Path.home() / ".etna" / "embeddings" / "tstcc"


class TSTCCEmbeddingModel(BaseEmbeddingModel):
"""TSTCC embedding model.
Expand Down Expand Up @@ -252,7 +261,7 @@
archive.write(model_save_path, "model.zip")

@classmethod
def load(cls, path: pathlib.Path) -> "TSTCCEmbeddingModel":
def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = None) -> "TSTCCEmbeddingModel":
"""Load an object.

Model's weights are transferred to cpu during loading.
Expand All @@ -262,11 +271,50 @@
path:
Path to load object from.

- if `path` is not None and `model_name` is None, load the local model from `path`.
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
- if `path` is None and `model_name` is not None, save the external `model_name` model to the etna folder in the home directory and load it.
If the model file already exists there, the external model will not be downloaded again and the existing local file is loaded.
- if `path` is not None and `model_name` is not None, save the external `model_name` model to `path` and load it.
If `path` already exists, the external model will not be downloaded and the existing local file is loaded.

model_name:
Name of the external model to load. To get the list of available models, use the `list_models` method.

Returns
-------
:
Loaded object.

Raises
------
ValueError:
If neither `path` nor `model_name` is set.
NotImplementedError:
If `model_name` is not in the list of available model names.
"""
if model_name is not None:
if path is None:
path = _DOWNLOAD_PATH / f"{model_name}.zip"
if os.path.exists(path):
warnings.warn(

Check warning on line 299 in etna/transforms/embeddings/models/tstcc.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/tstcc.py#L296-L299

Added lines #L296 - L299 were not covered by tests
f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model."
)
else:
directory = os.path.dirname(path)

Check warning on line 303 in etna/transforms/embeddings/models/tstcc.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/tstcc.py#L303

Added line #L303 was not covered by tests
# If path not in current directory and it doesn't exist
if directory and not os.path.exists(directory):
os.makedirs(directory)

Check warning on line 306 in etna/transforms/embeddings/models/tstcc.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/tstcc.py#L305-L306

Added lines #L305 - L306 were not covered by tests

try:
url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/tstcc/{model_name}.zip"
request.urlretrieve(url=url, filename=path)
except HTTPError:
raise NotImplementedError(

Check warning on line 312 in etna/transforms/embeddings/models/tstcc.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/tstcc.py#L308-L312

Added lines #L308 - L312 were not covered by tests
f"Model {model_name} is not available. To get list of available models use `list_models` method."
)
elif path is None and model_name is None:
raise ValueError("Both path and model_name are not specified. Specify one parameter at least.")

Check warning on line 316 in etna/transforms/embeddings/models/tstcc.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/tstcc.py#L316

Added line #L316 was not covered by tests

obj: TSTCCEmbeddingModel = super().load(path=path)
obj.embedding_model = TSTCC(
input_dims=obj.input_dims,
Expand All @@ -293,3 +341,8 @@
obj.embedding_model.load(fn=str(model_path))

return obj

@staticmethod
def list_models() -> List[str]:
"""Return a list of available pretrained models.

The returned names are the values intended for the ``model_name`` argument of ``load``.
"""
return ["tstcc_medium"]

Check warning on line 348 in etna/transforms/embeddings/models/tstcc.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/embeddings/models/tstcc.py#L348

Added line #L348 was not covered by tests
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading