Skip to content

Commit

Permalink
Merge pull request #173 from scipy/strict-dtypes
Browse files Browse the repository at this point in the history
ENH: add a strict_check flag to (optionally) check dtypes
  • Loading branch information
ev-br authored Aug 23, 2024
2 parents 3365240 + c0e7adc commit 9ffe962
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
26 changes: 22 additions & 4 deletions scipy_doctest/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class DTConfig:
rtol : float
Absolute and relative tolerances to check doctest examples with.
Specifically, the check is ``np.allclose(want, got, atol=atol, rtol=rtol)``
strict_check : bool
Whether to check that dtypes match or rely on the lax definition of
equality of numpy objects. For instance, `3 == np.float64(3)`, but
dtypes do not match.
Default is False.
optionflags : int
doctest optionflags
Default is ``NORMALIZE_WHITESPACE | ELLIPSIS | IGNORE_EXCEPTION_DETAIL``
Expand Down Expand Up @@ -107,6 +112,7 @@ def __init__(self, *, # DTChecker configuration
rndm_markers=None,
atol=1e-8,
rtol=1e-2,
strict_check=False,
# DTRunner configuration
optionflags=None,
# DTFinder/DTParser configuration
Expand Down Expand Up @@ -161,8 +167,8 @@ def __init__(self, *, # DTChecker configuration
'#random', '#Random',
"# may vary"}
self.rndm_markers = rndm_markers

self.atol, self.rtol = atol, rtol
self.strict_check = strict_check

### DTRunner configuration ###

Expand Down Expand Up @@ -363,23 +369,35 @@ def check_output(self, want, got, optionflags):
return False

# ... and defer to numpy
strict = self.config.strict_check
try:
return self._do_check(a_want, a_got)
return self._do_check(a_want, a_got, strict)
except Exception:
# heterog tuple, eg (1, np.array([1., 2.]))
try:
return all(self._do_check(w, g) for w, g in zip_longest(a_want, a_got))
return all(
self._do_check(w, g, strict) for w, g in zip_longest(a_want, a_got)
)
except (TypeError, ValueError):
return False

def _do_check(self, want, got):
def _do_check(self, want, got, strict_check):
# This should be done exactly as written to correctly handle all of
# numpy-comparable objects, strings, and heterogeneous tuples

# NB: 3 == np.float64(3.0) but dtypes differ
if strict_check:
want_dtype = np.asarray(want).dtype
got_dtype = np.asarray(got).dtype
if want_dtype != got_dtype:
return False

try:
if want == got:
return True
except Exception:
pass

with warnings.catch_warnings():
# NumPy's ragged array deprecation of np.array([1, (2, 3)])
warnings.simplefilter('ignore', VisibleDeprecationWarning)
Expand Down
8 changes: 8 additions & 0 deletions scipy_doctest/tests/failure_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,11 @@ def tuple_and_list_2():
>>> (0, 1, 2)
[0, 1, 2]
"""


def dtype_mismatch():
"""
>>> import numpy as np
>>> 3.0
3
"""
10 changes: 10 additions & 0 deletions scipy_doctest/tests/test_testmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ def test_tuple_and_list():
assert res.failed == 2


@pytest.mark.parametrize('strict, num_fails', [(True, 1), (False, 0)])
class TestStrictDType:
def test_np_fix(self, strict, num_fails):
config = DTConfig(strict_check=strict)
res, _ = _testmod(failure_cases,
strategy=[failure_cases.dtype_mismatch],
config=config)
assert res.failed == num_fails


class TestLocalFiles:
def test_local_files(self):
# A doctest tries to open a local file. Test that it works
Expand Down

0 comments on commit 9ffe962

Please sign in to comment.