From 7fdd5993054048eb6056de7f8cc46fb0bef71f96 Mon Sep 17 00:00:00 2001 From: Avik Basu <3485425+ab93@users.noreply.github.com> Date: Fri, 20 Jan 2023 13:00:04 -0800 Subject: [PATCH] chore: catch explicit model not found exception in mlflow (#133) Signed-off-by: Avik Basu --- .codecov.yml | 4 +- .github/CODEOWNERS | 2 +- Makefile | 2 +- numalogic/__init__.py | 12 +++++ numalogic/_constants.py | 12 +++++ numalogic/models/autoencoder/__init__.py | 12 +++++ numalogic/models/autoencoder/base.py | 12 +++++ numalogic/models/autoencoder/trainer.py | 12 +++++ numalogic/models/autoencoder/variants/conv.py | 12 +++++ numalogic/models/autoencoder/variants/lstm.py | 12 +++++ .../autoencoder/variants/transformer.py | 12 +++++ .../models/autoencoder/variants/vanilla.py | 12 +++++ numalogic/postprocess.py | 12 +++++ numalogic/preprocess/__init__.py | 14 ++++++ numalogic/preprocess/transformer.py | 12 +++++ numalogic/registry/__init__.py | 12 +++++ numalogic/registry/artifact.py | 14 ++++++ numalogic/registry/mlflow_registry.py | 30 ++++++++++-- numalogic/synthetic/__init__.py | 12 +++++ numalogic/synthetic/anomalies.py | 12 +++++ numalogic/synthetic/sparsity.py | 12 +++++ numalogic/synthetic/timeseries.py | 12 +++++ numalogic/tools/callbacks.py | 12 +++++ numalogic/tools/data.py | 12 +++++ numalogic/tools/exceptions.py | 12 +++++ numalogic/tools/types.py | 12 +++++ pyproject.toml | 2 +- tests/registry/test_mlflow_registry.py | 46 +++++++++++++++---- 28 files changed, 336 insertions(+), 18 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 58d1399a..443feeb8 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -3,8 +3,8 @@ coverage: project: default: target: auto - threshold: 1% + threshold: 5% patch: default: target: auto - threshold: 1% + threshold: 10% diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b01da95c..c3f81a9e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,3 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence -* @ab93 @vigith @whynowy +* @ab93 @vigith @nkoppisetty diff --git a/Makefile b/Makefile index 31b0beed..bfb5c64a 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ lint: format # install all dependencies setup: - poetry install --with dev --all-extras + poetry install --with dev,torch --all-extras --no-root # test your application (tests in the tests/ directory) test: diff --git a/numalogic/__init__.py b/numalogic/__init__.py index 4dcd58d9..3b960d3d 100644 --- a/numalogic/__init__.py +++ b/numalogic/__init__.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import logging diff --git a/numalogic/_constants.py b/numalogic/_constants.py index dc341ba5..250354e2 100644 --- a/numalogic/_constants.py +++ b/numalogic/_constants.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import os NUMALOGIC_DIR = os.path.dirname(__file__) diff --git a/numalogic/models/autoencoder/__init__.py b/numalogic/models/autoencoder/__init__.py index b56c2b01..296eb72f 100644 --- a/numalogic/models/autoencoder/__init__.py +++ b/numalogic/models/autoencoder/__init__.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from numalogic.models.autoencoder.trainer import AutoencoderTrainer __all__ = ["AutoencoderTrainer"] diff --git a/numalogic/models/autoencoder/base.py b/numalogic/models/autoencoder/base.py index ca9a1992..d563d3d4 100644 --- a/numalogic/models/autoencoder/base.py +++ b/numalogic/models/autoencoder/base.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from abc import ABCMeta import pytorch_lightning as pl diff --git a/numalogic/models/autoencoder/trainer.py b/numalogic/models/autoencoder/trainer.py index 1344ac85..961f0dbc 100644 --- a/numalogic/models/autoencoder/trainer.py +++ b/numalogic/models/autoencoder/trainer.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import logging import pytorch_lightning as pl diff --git a/numalogic/models/autoencoder/variants/conv.py b/numalogic/models/autoencoder/variants/conv.py index 76336e4e..c4d5ba03 100644 --- a/numalogic/models/autoencoder/variants/conv.py +++ b/numalogic/models/autoencoder/variants/conv.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import logging from typing import Tuple diff --git a/numalogic/models/autoencoder/variants/lstm.py b/numalogic/models/autoencoder/variants/lstm.py index dce7e13d..f469dfd0 100644 --- a/numalogic/models/autoencoder/variants/lstm.py +++ b/numalogic/models/autoencoder/variants/lstm.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import logging from typing import Tuple diff --git a/numalogic/models/autoencoder/variants/transformer.py b/numalogic/models/autoencoder/variants/transformer.py index 07570dec..f5818535 100644 --- a/numalogic/models/autoencoder/variants/transformer.py +++ b/numalogic/models/autoencoder/variants/transformer.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from typing import Tuple import torch diff --git a/numalogic/models/autoencoder/variants/vanilla.py b/numalogic/models/autoencoder/variants/vanilla.py index 82487ba6..f6c58c24 100644 --- a/numalogic/models/autoencoder/variants/vanilla.py +++ b/numalogic/models/autoencoder/variants/vanilla.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from typing import Tuple, Sequence import torch diff --git a/numalogic/postprocess.py b/numalogic/postprocess.py index a9514719..ba7d28b8 100644 --- a/numalogic/postprocess.py +++ b/numalogic/postprocess.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import numpy as np from numpy.typing import ArrayLike diff --git a/numalogic/preprocess/__init__.py b/numalogic/preprocess/__init__.py index e69de29b..8eb1167d 100644 --- a/numalogic/preprocess/__init__.py +++ b/numalogic/preprocess/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from numalogic.preprocess.transformer import LogTransformer, StaticPowerTransformer + +__all__ = ["LogTransformer", "StaticPowerTransformer"] diff --git a/numalogic/preprocess/transformer.py b/numalogic/preprocess/transformer.py index 83854685..4ab270a1 100644 --- a/numalogic/preprocess/transformer.py +++ b/numalogic/preprocess/transformer.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import logging import numpy as np diff --git a/numalogic/registry/__init__.py b/numalogic/registry/__init__.py index fa178b58..92506b84 100644 --- a/numalogic/registry/__init__.py +++ b/numalogic/registry/__init__.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from numalogic.registry.artifact import ArtifactManager from numalogic.registry.artifact import ArtifactData diff --git a/numalogic/registry/artifact.py b/numalogic/registry/artifact.py index 3be27f99..21060225 100644 --- a/numalogic/registry/artifact.py +++ b/numalogic/registry/artifact.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import Sequence, Any, Dict @@ -19,6 +31,8 @@ class ArtifactManager(metaclass=ABCMeta): :param uri: server/connection uri """ + __slots__ = ("uri",) + def __init__(self, uri: str): self.uri = uri diff --git a/numalogic/registry/mlflow_registry.py b/numalogic/registry/mlflow_registry.py index fd25e320..d6276733 100644 --- a/numalogic/registry/mlflow_registry.py +++ b/numalogic/registry/mlflow_registry.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import logging from enum import Enum from typing import Optional, Sequence @@ -6,6 +18,7 @@ import mlflow.pytorch from mlflow.entities.model_registry import ModelVersion from mlflow.exceptions import RestException +from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST from mlflow.tracking import MlflowClient from numalogic.registry import ArtifactManager, ArtifactData @@ -39,9 +52,9 @@ class MLflowRegistry(ArtifactManager): Examples -------- - >>> from numalogic.models.autoencoder.variants.vanilla import VanillaAE - >>> from numalogic.registry.mlflow_registry import MLflowRegistry - >>> from sklearn.pipeline import make_pipeline + >>> from numalogic.models.autoencoder.variants import VanillaAE + >>> from numalogic.registry import MLflowRegistry + >>> from sklearn.preprocessing import StandardScaler >>> >>> data = [[0, 0], [0, 0], [1, 1], [1, 1]] >>> scaler = StandardScaler.fit(data) @@ -50,6 +63,7 @@ class MLflowRegistry(ArtifactManager): >>> artifact_data = registry.load(skeys=["model"], dkeys=["AE"]) """ + __slots__ = ("client", "handler", "models_to_retain") _TRACKING_URI = None def __new__( @@ -132,6 +146,14 @@ def load( _LOGGER.info("Successfully loaded model metadata from Mlflow!") return ArtifactData(artifact=model, metadata=metadata, extras=dict(version_info)) + except RestException as mlflow_err: + if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST: + _LOGGER.info("Model not found with key: %s", model_key) + else: + _LOGGER.exception( + "Mlflow error when loading a model with key: %s: %r", model_key, mlflow_err + ) + return None except Exception as ex: _LOGGER.exception("Error when loading a model with key: %s: %r", model_key, ex) return None @@ -164,7 +186,7 @@ def save( _LOGGER.info("Successfully inserted model %s to Mlflow", model_key) return model_version except Exception as ex: - _LOGGER.exception("Error when saving a model with key: %s: %r", model_key, ex) + _LOGGER.exception("Unhandled error when saving a model with key: %s: %r", model_key, ex) return None finally: mlflow.end_run() diff --git a/numalogic/synthetic/__init__.py b/numalogic/synthetic/__init__.py index eaef2a4c..c510e22b 100644 --- a/numalogic/synthetic/__init__.py +++ b/numalogic/synthetic/__init__.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from numalogic.synthetic.timeseries import SyntheticTSGenerator from numalogic.synthetic.anomalies import AnomalyGenerator from numalogic.synthetic.sparsity import SparsityGenerator diff --git a/numalogic/synthetic/anomalies.py b/numalogic/synthetic/anomalies.py index 27b1c253..07e61580 100644 --- a/numalogic/synthetic/anomalies.py +++ b/numalogic/synthetic/anomalies.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from typing import Sequence, List import numpy as np diff --git a/numalogic/synthetic/sparsity.py b/numalogic/synthetic/sparsity.py index a2d99af8..0bea99bc 100644 --- a/numalogic/synthetic/sparsity.py +++ b/numalogic/synthetic/sparsity.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import random diff --git a/numalogic/synthetic/timeseries.py b/numalogic/synthetic/timeseries.py index 898ea189..89bf61cb 100644 --- a/numalogic/synthetic/timeseries.py +++ b/numalogic/synthetic/timeseries.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from typing import Tuple import numpy as np diff --git a/numalogic/tools/callbacks.py b/numalogic/tools/callbacks.py index d2e8d9b5..4e5a02ff 100644 --- a/numalogic/tools/callbacks.py +++ b/numalogic/tools/callbacks.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import logging import pytorch_lightning as pl diff --git a/numalogic/tools/data.py b/numalogic/tools/data.py index 4f78a2d7..03d6a907 100644 --- a/numalogic/tools/data.py +++ b/numalogic/tools/data.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import numpy as np import torch from numalogic.tools.exceptions import DataModuleError, InvalidDataShapeError diff --git a/numalogic/tools/exceptions.py b/numalogic/tools/exceptions.py index 74598b20..d7f3dcb1 100644 --- a/numalogic/tools/exceptions.py +++ b/numalogic/tools/exceptions.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + class ModelInitializationError(Exception): pass diff --git a/numalogic/tools/types.py b/numalogic/tools/types.py index a29d2855..dca8ec5c 100644 --- a/numalogic/tools/types.py +++ b/numalogic/tools/types.py @@ -1,3 +1,15 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from typing import Union, Dict, NewType, TypeVar, Sequence, Optional from sklearn.base import BaseEstimator, TransformerMixin diff --git a/pyproject.toml b/pyproject.toml index 27f1e2d2..103db0d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic" -version = "0.3.1" +version = "0.3.2" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/registry/test_mlflow_registry.py b/tests/registry/test_mlflow_registry.py index 11c6943e..d24291b8 100644 --- a/tests/registry/test_mlflow_registry.py +++ b/tests/registry/test_mlflow_registry.py @@ -3,6 +3,8 @@ from unittest.mock import patch, Mock from mlflow import ActiveRun +from mlflow.exceptions import RestException +from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode, RESOURCE_LIMIT_EXCEEDED from sklearn.ensemble import RandomForestRegressor from numalogic.models.autoencoder.variants import VanillaAE @@ -54,7 +56,7 @@ def test_construct_key(self): @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version) - def test_insert_model(self): + def test_save_model(self): ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys dkeys = self.dkeys @@ -73,7 +75,7 @@ def test_insert_model(self): @patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage) @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) - def test_insert_model_sklearn(self): + def test_save_model_sklearn(self): model = self.model_sklearn ml = MLflowRegistry(TRACKING_URI, artifact_type="sklearn") skeys = self.skeys @@ -94,7 +96,7 @@ def test_insert_model_sklearn(self): @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) - def test_select_model_when_pytorch_model_exist1(self): + def test_load_model_when_pytorch_model_exist1(self): model = self.model ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch") skeys = self.skeys @@ -111,7 +113,7 @@ def test_select_model_when_pytorch_model_exist1(self): @patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) - def test_select_model_when_pytorch_model_exist2(self): + def test_load_model_when_pytorch_model_exist2(self): model = self.model ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch", models_to_retain=2) skeys = self.skeys @@ -133,7 +135,7 @@ def test_select_model_when_pytorch_model_exist2(self): @patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2) @patch("mlflow.sklearn.load_model", Mock(return_value=RandomForestRegressor())) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) - def test_select_model_when_sklearn_model_exist(self): + def test_load_model_when_sklearn_model_exist(self): model = self.model_sklearn ml = MLflowRegistry(TRACKING_URI, artifact_type="sklearn") skeys = self.skeys @@ -155,7 +157,7 @@ def test_select_model_when_sklearn_model_exist(self): @patch("mlflow.tracking.MlflowClient.get_model_version", mock_get_model_version_obj) @patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10))) @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_empty_rundata())) - def test_select_model_with_version(self): + def test_load_model_with_version(self): model = self.model ml = MLflowRegistry(TRACKING_URI) skeys = self.skeys @@ -170,7 +172,7 @@ def test_select_model_with_version(self): self.assertIsNone(data.metadata) @patch("mlflow.pyfunc.load_model", Mock(side_effect=RuntimeError)) - def test_select_model_when_no_model_01(self): + def test_load_model_when_no_model_01(self): fake_skeys = ["Fakemodel_"] fake_dkeys = ["error"] ml = MLflowRegistry(TRACKING_URI, artifact_type="pyfunc") @@ -179,7 +181,7 @@ def test_select_model_when_no_model_01(self): self.assertTrue(log.output) @patch("mlflow.tensorflow.load_model", Mock(side_effect=RuntimeError)) - def test_select_model_when_no_model_02(self): + def test_load_model_when_no_model_02(self): fake_skeys = ["Fakemodel_"] fake_dkeys = ["error"] ml = MLflowRegistry(TRACKING_URI, artifact_type="tensorflow") @@ -237,7 +239,7 @@ def test_delete_model_when_no_model(self): @patch("mlflow.pytorch.log_model", Mock(side_effect=RuntimeError)) @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_empty_rundata()))) @patch("mlflow.active_run", Mock(return_value=return_empty_rundata())) - def test_insertion_failed(self): + def test_save_failed(self): fake_skeys = ["Fakemodel_"] fake_dkeys = ["error"] @@ -246,6 +248,32 @@ def test_insertion_failed(self): ml.save(skeys=fake_skeys, dkeys=fake_dkeys, artifact=self.model) self.assertTrue(log.output) + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) + @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) + @patch( + "mlflow.pytorch.load_model", + Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})), + ) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) + def test_load_no_model_found(self): + ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch") + skeys = self.skeys + dkeys = self.dkeys + self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys)) + + @patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict()))) + @patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict())) + @patch( + "mlflow.pytorch.load_model", + Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_LIMIT_EXCEEDED)})), + ) + @patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict())) + def test_load_other_mlflow_err(self): + ml = MLflowRegistry(TRACKING_URI, artifact_type="pytorch") + skeys = self.skeys + dkeys = self.dkeys + self.assertIsNone(ml.load(skeys=skeys, dkeys=dkeys)) + if __name__ == "__main__": unittest.main()