From 87ae02504ee30c1ce29ae028dc6e589c82ced1cd Mon Sep 17 00:00:00 2001 From: Beshr Kayali Date: Fri, 31 Jan 2020 12:03:54 +0100 Subject: [PATCH] Fix pip api for pip=>10 and read linked requirement files --- reqlice/__init__.py | 63 +++++++++++++++++++++++++++++------------- reqlice/requirement.py | 24 ++++++++++++---- 2 files changed, 62 insertions(+), 25 deletions(-) diff --git a/reqlice/__init__.py b/reqlice/__init__.py index ccbb8f2..f9fd609 100644 --- a/reqlice/__init__.py +++ b/reqlice/__init__.py @@ -9,12 +9,11 @@ from reqlice.requirement import get_valid_pypi_requirement, parse_requirements loop = asyncio.get_event_loop() -comment_start_tag = '# [license] ' -comment_start_re = re.compile(re.escape(comment_start_tag) + '.+') +comment_start_tag = "# [license] " +comment_start_re = re.compile(re.escape(comment_start_tag) + ".+") class Reqlice: - def __init__(self, target, out=sys.stdout, err=sys.stderr): """ :param target: 'path/to/requirements.txt' @@ -31,10 +30,32 @@ def write(self, *what): def error(self, *what): print(*what, file=self.err) + def _parse(self, reqs): + _parsed = [] + for req in reqs.splitlines(): + _req = req.strip() + if _req: + if _req.startswith("-r"): + _parsed += self.read_file( + os.path.join( + os.path.dirname(self.target), _req.partition(" ")[-1] + ) + ) + continue + # Skip comments + elif _req.startswith("#"): + continue + _parsed.append(_req) + + return _parsed + + def read_file(self, filepath): + with open(filepath) as f: + return self._parse(f.read()) + def read_target(self): if os.path.isfile(self.target): - with open(self.target) as f: - return f.read() + return self.read_file(self.target) # Assume string return self.target @@ -48,36 +69,39 @@ def output_requirements(self, package_license_dict): lines = [] # First pass, fetch requirement, strip line - for line in content.splitlines(): - line = comment_start_re.sub('', line).strip() + for line in content: + line = comment_start_re.sub("", line).strip() requirement = get_valid_pypi_requirement(line) if requirement is None: license = None else: - package_name = requirement.req.project_name + try: + package_name = requirement.req.project_name + except AttributeError: + package_name = requirement.req.name + license = package_license_dict.get(package_name) lines.append((line, license)) # Specify default=0 so as no to break when we have no licenses, # comment_startpoint will not be used in that case - comment_startpoint = max( - (len(line) for line, license in lines if license), - default=0 - ) + 2 + comment_startpoint = ( + max((len(line) for line, license in lines if license), default=0) + 2 + ) for line, license in lines: if license is None: self.write(line) continue - padding = ' ' * (comment_startpoint - len(line)) + padding = " " * (comment_startpoint - len(line)) - output = '{line}{padding}{comment_start}{license}'.format( + output = "{line}{padding}{comment_start}{license}".format( line=line, padding=padding, comment_start=comment_start_tag, - license=license + license=license, ) self.write(output) @@ -92,10 +116,10 @@ def fetch_licenses(self, packages): json_data = task.result() if json_data is None: - self.error('Could not fetch info for package:', package_name) + self.error("Could not fetch info for package:", package_name) continue - info = json_data['info'] + info = json_data["info"] license = parse_license(info) if license: @@ -114,11 +138,12 @@ def run(self): def cli(): if len(sys.argv) != 2: - print('Usage: reqlice path/to/requirements.txt') + print("Usage: reqlice path/to/requirements.txt") sys.exit(1) path_to_requirements = sys.argv[1] Reqlice(path_to_requirements).run() -if __name__ == '__main__': + +if __name__ == "__main__": cli() diff --git a/reqlice/requirement.py b/reqlice/requirement.py index 9b659d3..baea0ae 100644 --- a/reqlice/requirement.py +++ b/reqlice/requirement.py @@ -1,5 +1,13 @@ -from pip.req import parse_requirements as pip_parse_requirements -from pip.req import InstallRequirement +try: # for pip >= 10 + from pip._internal.req import parse_requirements as pip_parse_requirements + from pip._internal.req import InstallRequirement + from pip._internal.req.constructors import install_req_from_line as _from_line + +except ImportError: + from pip.req import parse_requirements as pip_parse_requirements + from pip.req import InstallRequirement + + _from_line = InstallRequirement.from_line def is_pypi_requirement(requirement): @@ -13,19 +21,23 @@ def parse_requirements(path_to_requirements): :return: ['package name', ..] """ parsed_reqs = [] - for requirement in pip_parse_requirements(path_to_requirements, - session=False): + for requirement in pip_parse_requirements(path_to_requirements, session=False): if not is_pypi_requirement(requirement): continue - parsed_reqs.append(requirement.req.project_name) + try: + _name = requirement.req.project_name + except AttributeError: + _name = requirement.req.name + + parsed_reqs.append(_name) return parsed_reqs def get_valid_pypi_requirement(line): try: - requirement = InstallRequirement.from_line(line) + requirement = _from_line(line) if not is_pypi_requirement(requirement): raise ValueError except ValueError: