diff --git a/src/pipupgrade/_pip.py b/src/pipupgrade/_pip.py index d8970a7..212a4d4 100644 --- a/src/pipupgrade/_pip.py +++ b/src/pipupgrade/_pip.py @@ -5,6 +5,7 @@ import pip import json import os.path as osp +from typing import List # imports - module imports from pipupgrade.util.system import which, popen @@ -15,16 +16,34 @@ logger = get_logger() -PIP9 = int(pip.__version__.split(".")[0]) < 10 +MAJOR_VERSION = int(pip.__version__.split(".")[0]) -if PIP9: - # from pip import get_installed_distributions - from pip.req import parse_requirements - from pip.req.req_install import InstallRequirement -else: +if MAJOR_VERSION >= 20: + from pip._internal.req.constructors import install_req_from_parsed_requirement + from pip._internal.req.req_file import ( + parse_requirements as _real_parse_requirements, + ) + from pip._internal.req.req_install import InstallRequirement + + def parse_requirements( + filename, session + ): # type: (str, str) -> List[InstallRequirement] + """Wrap pip internal `parse_requirements`, which now returns + `ParsedRequirement` instances, to instead return `InstallRequirement`- + as with the previous implementation + Based on https://github.com/pypa/pip/blob/a48ad5385b234097d51283b08c3d933fd81ef534/tests/unit/test_req_file.py#L50""" + for parsed_req in _real_parse_requirements(filename, session): + yield install_req_from_parsed_requirement(parsed_req) + + +elif MAJOR_VERSION >= 10: # from pip._internal.utils.misc import get_installed_distributions - from pip._internal.req import parse_requirements + from pip._internal.req import parse_requirements from pip._internal.req.req_install import InstallRequirement +else: + # from pip import get_installed_distributions + from pip.req import parse_requirements + from pip.req.req_install import InstallRequirement from pip._vendor.pkg_resources import ( Distribution, diff --git a/tests/pipupgrade/test__pip.py b/tests/pipupgrade/test__pip.py index 8c82e75..0688079 100644 --- a/tests/pipupgrade/test__pip.py +++ b/tests/pipupgrade/test__pip.py @@ -2,10 +2,10 @@ from __future__ import absolute_import # imports - standard imports -import subprocess +import os # imports - test imports -import pytest +from testutils import PATH # imports - module imports from pipupgrade import _pip @@ -43,8 +43,17 @@ def _assert_outerr(routerr, outerr): _pip.call("install", "pipupgrade") assert_pip_call(_pip.call("install", "pipupgrade", quiet = True)) - + _pip.call("install", "pipupgrade", log = path) assert tempfile.read() - # assert_pip_call(_pip.call("list", output = True)) \ No newline at end of file + # assert_pip_call(_pip.call("list", output = True)) + + +def test_parse_requirements(): + """`parse_requirements` returns an iterable of `InstallRequirement`""" + filepath = os.path.join(PATH["DATA"], "project", "requirements.txt") + + requirements = list(_pip.parse_requirements(filepath, session="hack")) + + assert all([isinstance(req, _pip.InstallRequirement) for req in requirements])