From 031cd895fc4b83dc306732927785c03985966632 Mon Sep 17 00:00:00 2001 From: Safoine El khabich Date: Mon, 29 Jul 2024 11:55:51 +0100 Subject: [PATCH] chore: Remove redundant clean_requirements function and update package_utils --- .../orchestrators/databricks_orchestrator.py | 2 +- .../databricks/utils/databricks_utils.py | 23 -------- src/zenml/utils/package_utils.py | 25 +++++++++ tests/unit/utils/test_package_utils.py | 53 +++++++++++++++++++ 4 files changed, 79 insertions(+), 24 deletions(-) create mode 100644 tests/unit/utils/test_package_utils.py diff --git a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py index 7a3c8060599..046edb6469d 100644 --- a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py +++ b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py @@ -42,7 +42,6 @@ DatabricksEntrypointConfiguration, ) from zenml.integrations.databricks.utils.databricks_utils import ( - clean_requirements, convert_step_to_task, ) from zenml.io import fileio @@ -53,6 +52,7 @@ from zenml.orchestrators.wheeled_orchestrator import WheeledOrchestrator from zenml.stack import StackValidator from zenml.utils import io_utils +from zenml.utils.package_utils import clean_requirements from zenml.utils.pipeline_docker_image_builder import ( PipelineDockerImageBuilder, ) diff --git a/src/zenml/integrations/databricks/utils/databricks_utils.py b/src/zenml/integrations/databricks/utils/databricks_utils.py index 4c7d77e0c07..58b371c909a 100644 --- a/src/zenml/integrations/databricks/utils/databricks_utils.py +++ b/src/zenml/integrations/databricks/utils/databricks_utils.py @@ -85,26 +85,3 @@ def sanitize_labels(labels: Dict[str, str]) -> None: labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip( "-_." ) - - -def clean_requirements(requirements: List[str]) -> List[str]: - """Clean requirements list. 
def clean_requirements(requirements: List[str]) -> List[str]:
    """Clean requirements list from redundant requirements.

    Entries are grouped by package name, i.e. the text in front of the first
    version operator or extras bracket. Within a group, an entry that carries
    a version constraint replaces whatever was stored before it, so the last
    constrained entry wins; an unconstrained entry is only kept when it is
    the first one seen for its package.

    Args:
        requirements: List of requirements.

    Returns:
        Cleaned list of requirements, sorted alphabetically.

    Raises:
        TypeError: If `requirements` is a single string (or bytes) instead
            of a list of requirement strings.
        ValueError: If an entry of `requirements` is not a string.
    """
    # A bare string is iterable, so without this guard it would silently be
    # processed as a list of one-character requirements.
    if isinstance(requirements, (str, bytes)):
        raise TypeError(
            "`requirements` must be a list of requirement strings, got "
            f"{type(requirements).__name__}."
        )

    # Tokens that end the package name. "<" and ">" also cover "<=" / ">=",
    # and "==" also covers "===". A lone "=" is deliberately not a delimiter
    # so that entries such as "--index-url=..." keep their full text as key.
    name_delimiters = ("==", "!=", "~=", "<", ">", "[")

    cleaned = {}
    for req in requirements:
        if not isinstance(req, str):
            raise ValueError(
                "Every requirement must be a string, got "
                f"{type(req).__name__}: {req!r}."
            )

        # Cut the requirement at the earliest delimiter; keep it whole if
        # none of the delimiters occurs.
        positions = [req.find(token) for token in name_delimiters]
        cut = min((pos for pos in positions if pos != -1), default=len(req))
        package = req[:cut].strip()

        has_constraint = any(char in req for char in "=<>")
        if package not in cleaned or has_constraint:
            cleaned[package] = req
    return sorted(cleaned.values())
# You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
import pytest

from zenml.utils.package_utils import clean_requirements

# (requirements passed in, list expected back) pairs for the happy path.
_CLEAN_REQUIREMENTS_CASES = [
    # Distinct packages are all kept.
    (
        ["package1==1.0.0", "package2>=2.0.0", "package3<3.0.0"],
        ["package1==1.0.0", "package2>=2.0.0", "package3<3.0.0"],
    ),
    # Two pins of the same package: the later one is kept.
    (
        ["package1==1.0.0", "package1==2.0.0", "package2>=2.0.0"],
        ["package1==2.0.0", "package2>=2.0.0"],
    ),
    # Extras are part of the requirement string that is returned.
    (
        ["package1[extra]==1.0.0", "package2[test,dev]>=2.0.0"],
        ["package1[extra]==1.0.0", "package2[test,dev]>=2.0.0"],
    ),
    # A constrained entry replaces the bare entry of the same package.
    (
        ["package1", "package2==2.0.0", "package1>=1.5.0", "package3<3.0.0"],
        ["package1>=1.5.0", "package2==2.0.0", "package3<3.0.0"],
    ),
    # Empty input gives empty output.
    (
        [],
        [],
    ),
]


@pytest.mark.parametrize(
    "input_reqs, expected_output", _CLEAN_REQUIREMENTS_CASES
)
def test_clean_requirements(input_reqs, expected_output):
    """Check the cleaned list for each parametrized input."""
    cleaned = clean_requirements(input_reqs)
    assert cleaned == expected_output


def test_clean_requirements_type_error():
    """Expect a TypeError when a plain string is passed instead of a list."""
    with pytest.raises(TypeError):
        clean_requirements("not a list")


def test_clean_requirements_value_error():
    """Expect a ValueError when the list holds non-string elements."""
    non_string_requirements = [1, 2, 3]  # List of non-string elements
    with pytest.raises(ValueError):
        clean_requirements(non_string_requirements)