From 97560006d4e38cc0a362c6453180d30a610b34f5 Mon Sep 17 00:00:00 2001 From: geo-martino Date: Tue, 3 Dec 2024 16:08:44 -0500 Subject: [PATCH] fix comparer expected on is null --- docs/info/release-history.rst | 1 + musify/processors/compare.py | 4 +++- tests/processors/test_compare.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/info/release-history.rst b/docs/info/release-history.rst index ffd6caa6..d9ca7ce7 100644 --- a/docs/info/release-history.rst +++ b/docs/info/release-history.rst @@ -42,6 +42,7 @@ Changed Fixed ----- * Paths are now sanitised when assigning ``filename`` to :py:class:`.LocalTrack` +* :py:class:`.Comparer` no longer needs an expected value set for methods which do not use it 1.2.0 diff --git a/musify/processors/compare.py b/musify/processors/compare.py index dfd82bce..844f8f01 100644 --- a/musify/processors/compare.py +++ b/musify/processors/compare.py @@ -1,6 +1,7 @@ """ Processor making comparisons between objects and data types. """ +import inspect import re from collections.abc import Sequence, Hashable from datetime import datetime, date @@ -89,7 +90,8 @@ def compare[T: Any](self, item: T, reference: T | None = None) -> bool: if reference is None and self.reference_required: raise ComparerError("A reference is required for this instance of Comparer") - if reference is None and not self.expected: + signature = inspect.getfullargspec(self._processor_method) + if reference is None and "expected" in signature.args and not self.expected: raise ComparerError("No comparative item given and no expected values set") tag_name = None diff --git a/tests/processors/test_compare.py b/tests/processors/test_compare.py index 5e199b29..be80a323 100644 --- a/tests/processors/test_compare.py +++ b/tests/processors/test_compare.py @@ -67,6 +67,16 @@ def test_equality(self, obj: Comparer): new_filter.field = choice(obj.field.all()) assert obj != new_filter + def test_compare_on_no_expected_value(self): + comparer = Comparer(condition="is null", field=LocalTrackField.DISC_TOTAL) + track = random_track() + track.disc_total = None + + assert comparer.compare(track) + + comparer = Comparer(condition="is not null", field=LocalTrackField.DISC_TOTAL) + assert not comparer.compare(track) + def test_compare_with_reference(self): track_1 = random_track() track_2 = random_track()