Skip to content

Commit

Permalink
fix comparer expected on is null
Browse files Browse the repository at this point in the history
  • Loading branch information
geo-martino committed Dec 3, 2024
1 parent 4f00aaf commit 9756000
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/info/release-history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Changed
Fixed
-----
* Paths are now sanitised when assigning ``filename`` to :py:class:`.LocalTrack`
* :py:class:`.Comparer` no longer needs an expected value set for methods which do not use it


1.2.0
Expand Down
4 changes: 3 additions & 1 deletion musify/processors/compare.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Processor making comparisons between objects and data types.
"""
import inspect
import re
from collections.abc import Sequence, Hashable
from datetime import datetime, date
Expand Down Expand Up @@ -89,7 +90,8 @@ def compare[T: Any](self, item: T, reference: T | None = None) -> bool:

if reference is None and self.reference_required:
raise ComparerError("A reference is required for this instance of Comparer")
if reference is None and not self.expected:
signature = inspect.getfullargspec(self._processor_method)
if reference is None and "expected" in signature.args and not self.expected:
raise ComparerError("No comparative item given and no expected values set")

tag_name = None
Expand Down
10 changes: 10 additions & 0 deletions tests/processors/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def test_equality(self, obj: Comparer):
new_filter.field = choice(obj.field.all())
assert obj != new_filter

def test_compare_on_no_expected_value(self):
comparer = Comparer(condition="is null", field=LocalTrackField.DISC_TOTAL)
track = random_track()
track.disc_total = None

assert comparer.compare(track)

comparer = Comparer(condition="is not null", field=LocalTrackField.DISC_TOTAL)
assert not comparer.compare(track)

def test_compare_with_reference(self):
track_1 = random_track()
track_2 = random_track()
Expand Down

0 comments on commit 9756000

Please sign in to comment.