diff --git a/src/zenml/utils/package_utils.py b/src/zenml/utils/package_utils.py index 5b5115721c5..0b92c1fd532 100644 --- a/src/zenml/utils/package_utils.py +++ b/src/zenml/utils/package_utils.py @@ -60,16 +60,30 @@ def clean_requirements(requirements: List[str]) -> List[str]: Returns: Cleaned list of requirements + + Raises: + TypeError: If input is not a list + ValueError: If any element in the list is not a string """ + if not isinstance(requirements, list): + raise TypeError("Input must be a list") + + if not all(isinstance(req, str) for req in requirements): + raise ValueError("All elements in the list must be strings") + cleaned = {} for req in requirements: package = ( req.split(">=")[0] .split("==")[0] .split("<")[0] + .split("~=")[0] + .split("^=")[0] .split("[")[0] .strip() ) - if package not in cleaned or ("=" in req or ">" in req or "<" in req): + if package not in cleaned or any( + op in req for op in ["=", ">", "<", "~", "^"] + ): cleaned[package] = req return sorted(cleaned.values()) diff --git a/tests/unit/utils/test_package_utils.py b/tests/unit/utils/test_package_utils.py index c69e9d4641c..5358babb9ef 100644 --- a/tests/unit/utils/test_package_utils.py +++ b/tests/unit/utils/test_package_utils.py @@ -41,6 +41,15 @@ ["package1>=1.5.0", "package2==2.0.0", "package3<3.0.0"], ), ([], []), + ( + ["package1~=1.0.0", "package2^=2.0.0", "package3==3.0.0"], + ["package1~=1.0.0", "package2^=2.0.0", "package3==3.0.0"], + ), + ( + ["package1~=1.0.0", "package1^=1.1.0", "package1==1.2.0"], + ["package1==1.2.0"], + ), + (["package1", "package1~=1.0.0"], ["package1~=1.0.0"]), ], ) def test_clean_requirements(input_reqs, expected_output): @@ -48,7 +57,19 @@ def test_clean_requirements(input_reqs, expected_output): assert clean_requirements(input_reqs) == expected_output +def test_clean_requirements_type_error(): + """Test clean_requirements function with wrong input type.""" + with pytest.raises(TypeError): + clean_requirements("not a list") + + def test_clean_requirements_value_error(): """Test clean_requirements function with wrong input value.""" - with pytest.raises(AttributeError): + with pytest.raises(ValueError): clean_requirements([1, 2, 3]) # List of non-string elements + + +def test_clean_requirements_mixed_types(): + """Test clean_requirements function with mixed types in list.""" + with pytest.raises(ValueError): + clean_requirements(["package1==1.0.0", 2, "package3<3.0.0"])