Skip to content

Commit

Permalink
Allow provider to filter unsatisfied names
Browse files Browse the repository at this point in the history
  • Loading branch information
notatallshaw committed Dec 1, 2023
1 parent 044ab9f commit 9084a15
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 3 deletions.
3 changes: 3 additions & 0 deletions examples/extras_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def get_dependencies(self, candidate):
req = self.get_base_requirement(candidate)
deps.append(req)
return deps

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names
3 changes: 3 additions & 0 deletions examples/reporter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def is_satisfied_by(self, requirement, candidate):
def get_dependencies(self, candidate):
return self.candidates[candidate]

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names


class Reporter(resolvelib.BaseReporter):
def starting(self):
Expand Down
16 changes: 16 additions & 0 deletions src/resolvelib/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,19 @@ def get_dependencies(self, candidate: CT) -> Iterable[RT]:
specifies as its dependencies.
"""
raise NotImplementedError

def filter_unsatisfied_names(
self,
unsatisfied_names: Iterable[KT],
causes: Sequence[RequirementInformation[RT, CT]],
) -> Iterable[KT]:
"""Filter unsatisfied names before choosing a preference.
Used as an optimizion to reduce the number of calls to
get_preference. It is recomended to return unsatisfied names
that relate to the causes, or the conflict between causes.
Must return a subset of unsatisfied_names, must return at least
one unsatisfied name.
"""
raise NotImplementedError
22 changes: 19 additions & 3 deletions src/resolvelib/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,28 @@ def resolve(
return self.state

# keep track of satisfied names to calculate diff after pinning
satisfied_names = set(self.state.criteria.keys()) - set(
unsatisfied_names
unsatisfied_names_set = set(unsatisfied_names)
satisfied_names = (
set(self.state.criteria.keys()) - unsatisfied_names_set
)

if len(unsatisfied_names) > 1:
filtered_unstatisfied_names = list(
self._p.filter_unsatisfied_names(
unsatisfied_names_set, self.state.backtrack_causes
)
)
else:
filtered_unstatisfied_names = unsatisfied_names

# Choose the most preferred unpinned criterion to try.
name = min(unsatisfied_names, key=self._get_preference)
if len(filtered_unstatisfied_names) > 1:
name = min(
filtered_unstatisfied_names, key=self._get_preference
)
else:
name = filtered_unstatisfied_names[0]

failure_causes = self._attempt_to_pin_criterion(name)

if failure_causes:
Expand Down
3 changes: 3 additions & 0 deletions tests/functional/cocoapods/test_resolvers_cocoapods.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ def is_satisfied_by(self, requirement, candidate):
def get_dependencies(self, candidate):
return candidate.deps

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names


XFAIL_CASES = {
# ResolveLib does not complain about cycles, so these will be different.
Expand Down
3 changes: 3 additions & 0 deletions tests/functional/python/test_resolvers_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def _iter_dependencies(self, candidate):
def get_dependencies(self, candidate):
return list(self._iter_dependencies(candidate))

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names


INPUTS_DIR = os.path.abspath(os.path.join(__file__, "..", "inputs"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import json
import operator
import os
from typing import Iterable, Sequence

import pytest

from resolvelib.providers import AbstractProvider
from resolvelib.resolvers import Resolver
from resolvelib.structs import RequirementInformation

Requirement = collections.namedtuple("Requirement", "container constraint")
Candidate = collections.namedtuple("Candidate", "container version")
Expand Down Expand Up @@ -132,6 +134,9 @@ def _iter_dependencies(self, candidate):
def get_dependencies(self, candidate):
return list(self._iter_dependencies(candidate))

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names


@pytest.fixture(
params=[os.path.join(INPUTS_DIR, n) for n in INPUT_NAMES],
Expand Down
12 changes: 12 additions & 0 deletions tests/test_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def is_satisfied_by(self, requirement, candidate):
assert candidate is self.candidate
return False

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names

resolver = Resolver(Provider(requirement, candidate), BaseReporter())

with pytest.raises(InconsistentCandidate) as ctx:
Expand Down Expand Up @@ -104,6 +107,9 @@ def find_matches(self, identifier, requirements, incompatibilities):
def is_satisfied_by(self, requirement, candidate):
return candidate[1] in requirement[1]

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names

# Now when resolved, both requirements to child specified by parent should
# be pulled, and the resolver should choose v1, not v2 (happens if the
# v1-only requirement is dropped).
Expand Down Expand Up @@ -164,6 +170,9 @@ def find_matches(self, identifier, requirements, incompatibilities):
def is_satisfied_by(self, requirement, candidate):
return candidate.version in requirement.versions

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names

def run_resolver(*args):
reporter = Reporter()
resolver = Resolver(Provider(), reporter)
Expand Down Expand Up @@ -243,6 +252,9 @@ def is_satisfied_by(
) -> bool:
return candidate[1] in Requirement(requirement).specifier

def filter_unsatisfied_names(self, unsatisfied_names, causes):
return unsatisfied_names

# patch Resolution._get_updated_criteria to collect rejected states
rejected_criteria: list[Criterion] = []
get_updated_criteria_orig = (
Expand Down

0 comments on commit 9084a15

Please sign in to comment.