Skip to content

Commit

Permalink
fix clean req
Browse files Browse the repository at this point in the history
  • Loading branch information
safoinme committed Jul 29, 2024
1 parent e629eef commit 05dd774
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
16 changes: 15 additions & 1 deletion src/zenml/utils/package_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,30 @@ def clean_requirements(requirements: List[str]) -> List[str]:
Returns:
Cleaned list of requirements
Raises:
TypeError: If input is not a list
ValueError: If any element in the list is not a string
"""
if not isinstance(requirements, list):
raise TypeError("Input must be a list")

if not all(isinstance(req, str) for req in requirements):
raise ValueError("All elements in the list must be strings")

cleaned = {}
for req in requirements:
package = (
req.split(">=")[0]
.split("==")[0]
.split("<")[0]
.split("~=")[0]
.split("^=")[0]
.split("[")[0]
.strip()
)
if package not in cleaned or ("=" in req or ">" in req or "<" in req):
if package not in cleaned or any(
op in req for op in ["=", ">", "<", "~", "^"]
):
cleaned[package] = req
return sorted(cleaned.values())
23 changes: 22 additions & 1 deletion tests/unit/utils/test_package_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,35 @@
["package1>=1.5.0", "package2==2.0.0", "package3<3.0.0"],
),
([], []),
(
["package1~=1.0.0", "package2^=2.0.0", "package3==3.0.0"],
["package1~=1.0.0", "package2^=2.0.0", "package3==3.0.0"],
),
(
["package1~=1.0.0", "package1^=1.1.0", "package1==1.2.0"],
["package1==1.2.0"],
),
(["package1", "package1~=1.0.0"], ["package1~=1.0.0"]),
],
)
def test_clean_requirements(input_reqs, expected_output):
"""Test clean_requirements function."""
assert clean_requirements(input_reqs) == expected_output


def test_clean_requirements_type_error():
"""Test clean_requirements function with wrong input type."""
with pytest.raises(TypeError):
clean_requirements("not a list")


def test_clean_requirements_value_error():
"""Test clean_requirements function with wrong input value."""
with pytest.raises(AttributeError):
with pytest.raises(ValueError):
clean_requirements([1, 2, 3]) # List of non-string elements


def test_clean_requirements_mixed_types():
"""Test clean_requirements function with mixed types in list."""
with pytest.raises(ValueError):
clean_requirements(["package1==1.0.0", 2, "package3<3.0.0"])

0 comments on commit 05dd774

Please sign in to comment.