From 027f7a11e4032eb26a05659e5fe1c9371c52f46d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 10 Feb 2022 15:41:26 -0600 Subject: [PATCH] Centralize 'from_module' checking and check the whole mro. (#557) --- src/uproot/_util.py | 21 +++++++++++++++ src/uproot/interpretation/objects.py | 4 +-- src/uproot/language/python.py | 6 ++--- src/uproot/source/chunk.py | 1 + src/uproot/writing/_cascadetree.py | 22 ++++++--------- src/uproot/writing/identify.py | 12 +++------ tests/test_0001-source-class.py | 3 +++ tests/test_0006-notify-when-downloaded.py | 3 +++ tests/test_0007-single-chunk-interface.py | 33 ++++++++++++++--------- 9 files changed, 64 insertions(+), 41 deletions(-) diff --git a/src/uproot/_util.py b/src/uproot/_util.py index 6bb81a170..35dc8b8f6 100644 --- a/src/uproot/_util.py +++ b/src/uproot/_util.py @@ -90,6 +90,27 @@ def ensure_numpy(array, types=(numpy.bool_, numpy.integer, numpy.floating)): _regularize_filter_regex = re.compile("^/(.*)/([iLmsux]*)$") +def from_module(obj, module_name): + """ + Returns True if ``obj`` is an instance of a class from a module + given by name. + + This is like ``isinstance`` (in that it searches the whole ``mro``), + except that the module providing the type to check against doesn't + have to be imported and doesn't get imported (as a side effect) by + this function. + """ + try: + mro = type(obj).mro() + except TypeError: + return False + + for t in mro: + if t.__module__ == module_name or t.__module__.startswith(module_name + "."): + return True + return False + + def _regularize_filter_regex_flags(flags): flagsbyte = 0 for flag in flags: diff --git a/src/uproot/interpretation/objects.py b/src/uproot/interpretation/objects.py index 6da9b3512..ceacf6f6b 100644 --- a/src/uproot/interpretation/objects.py +++ b/src/uproot/interpretation/objects.py @@ -205,9 +205,7 @@ def final_array( start = stop - if all( - type(x).__module__.startswith("awkward") for x in basket_arrays.values() - ): + if all(uproot._util.from_module(x, "awkward") for x in basket_arrays.values()): assert isinstance(library, uproot.interpretation.library.Awkward) awkward = library.imported output = awkward.concatenate(trimmed, mergebool=False, highlevel=False) diff --git a/src/uproot/language/python.py b/src/uproot/language/python.py index 315dbc56c..69b4b52e6 100644 --- a/src/uproot/language/python.py +++ b/src/uproot/language/python.py @@ -455,8 +455,7 @@ def getter(name): file_path, object_path, )() - module_name = type(output[expression]).__module__ - if module_name == "pandas" or module_name.startswith("pandas."): + if uproot._util.from_module(output[expression], "pandas"): is_pandas = True cut = None @@ -472,8 +471,7 @@ def getter(name): file_path, object_path, )() - module_name = type(cut).__module__ - if module_name == "pandas" or module_name.startswith("pandas."): + if uproot._util.from_module(cut, "pandas"): is_pandas = True break diff --git a/src/uproot/source/chunk.py b/src/uproot/source/chunk.py index 1b85c75e4..3dad916de 100644 --- a/src/uproot/source/chunk.py +++ b/src/uproot/source/chunk.py @@ -28,6 +28,7 @@ class Resource: :doc:`uproot.source.futures.ResourceFuture`. """ + @property def file_path(self): """ A path to the file (or URL). diff --git a/src/uproot/writing/_cascadetree.py b/src/uproot/writing/_cascadetree.py index 0914cf5e0..8bce03ff2 100644 --- a/src/uproot/writing/_cascadetree.py +++ b/src/uproot/writing/_cascadetree.py @@ -116,7 +116,7 @@ def __init__( else: try: - if type(branch_type).__module__.startswith("awkward."): + if uproot._util.from_module(branch_type, "awkward"): raise TypeError if ( uproot._util.isstr(branch_type) @@ -509,15 +509,14 @@ def extend(self, file, sink, data): sink.flush() provided = None - module_name = type(data).__module__ - if module_name == "pandas" or module_name.startswith("pandas."): + if uproot._util.from_module(data, "pandas"): import pandas if isinstance(data, pandas.DataFrame) and data.index.is_numeric(): provided = dataframe_to_dict(data) - if module_name == "awkward" or module_name.startswith("awkward."): + if uproot._util.from_module(data, "awkward"): try: awkward = uproot.extras.awkward() except ModuleNotFoundError as err: @@ -550,12 +549,9 @@ def extend(self, file, sink, data): provided = {} for k, v in data.items(): - module_name = type(v).__module__ - if not ( - module_name == "pandas" or module_name.startswith("pandas.") - ) and not ( - module_name == "awkward" or module_name.startswith("awkward.") - ): + if not uproot._util.from_module( + v, "pandas" + ) and not uproot._util.from_module(v, "awkward"): if not hasattr(v, "dtype") and not isinstance(v, Mapping): try: with warnings.catch_warnings(): @@ -583,8 +579,7 @@ def extend(self, file, sink, data): ) from err v = awkward.from_iter(v) - module_name = type(v).__module__ - if module_name == "awkward" or module_name.startswith("awkward."): + if uproot._util.from_module(v, "awkward"): try: awkward = uproot.extras.awkward() except ModuleNotFoundError as err: @@ -622,8 +617,7 @@ def extend(self, file, sink, data): if datum["name"] in provided: recordarray = provided.pop(datum["name"]) - module_name = type(recordarray).__module__ - if module_name == "pandas" or module_name.startswith("pandas."): + if uproot._util.from_module(recordarray, "pandas"): import pandas if isinstance(recordarray, pandas.DataFrame): diff --git a/src/uproot/writing/identify.py b/src/uproot/writing/identify.py index cb9b27d12..f9c7f02f4 100644 --- a/src/uproot/writing/identify.py +++ b/src/uproot/writing/identify.py @@ -48,15 +48,14 @@ def add_to_directory(obj, name, directory, streamers): Raises ``TypeError`` if ``obj`` is not recognized as writable data. """ is_ttree = False - module_name = type(obj).__module__ - if module_name == "pandas" or module_name.startswith("pandas."): + if uproot._util.from_module(obj, "pandas"): import pandas if isinstance(obj, pandas.DataFrame) and obj.index.is_numeric(): obj = uproot.writing._cascadetree.dataframe_to_dict(obj) - if module_name == "awkward" or module_name.startswith("awkward."): + if uproot._util.from_module(obj, "awkward"): import awkward if isinstance(obj, awkward.Array): @@ -70,9 +69,7 @@ def add_to_directory(obj, name, directory, streamers): metadata = {} for branch_name, branch_array in obj.items(): - module_name = type(branch_array).__module__ - - if module_name == "pandas" or module_name.startswith("pandas."): + if uproot._util.from_module(branch_array, "pandas"): import pandas if isinstance(branch_array, pandas.DataFrame): @@ -113,8 +110,7 @@ def add_to_directory(obj, name, directory, streamers): metadata[branch_name] = metadatum else: - module_name = type(branch_array).__module__ - if module_name == "awkward" or module_name.startswith("awkward."): + if uproot._util.from_module(branch_array, "awkward"): data[branch_name] = branch_array metadata[branch_name] = branch_array.type diff --git a/tests/test_0001-source-class.py b/tests/test_0001-source-class.py index e5f8850ea..dc2f83088 100644 --- a/tests/test_0001-source-class.py +++ b/tests/test_0001-source-class.py @@ -94,6 +94,7 @@ def test_memmap_fail(tmpdir): uproot.source.file.MultithreadedFileSource(filename + "-does-not-exist") +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") @pytest.mark.network def test_http(): source = uproot.source.http.HTTPSource( @@ -117,6 +118,7 @@ def test_http(): assert [tobytes(x.raw_data) for x in chunks] == [one, two, three] +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") def colons_and_ports(): assert uproot._util.file_object_path_split("https://example.com:443") == ( "https://example.com:443", @@ -131,6 +133,7 @@ def colons_and_ports(): ) == ("https://example.com:443/something", "else") +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") @pytest.mark.network def test_http_port(): source = uproot.source.http.HTTPSource( diff --git a/tests/test_0006-notify-when-downloaded.py b/tests/test_0006-notify-when-downloaded.py index 0eb6f027f..7d4e4edd4 100644 --- a/tests/test_0006-notify-when-downloaded.py +++ b/tests/test_0006-notify-when-downloaded.py @@ -63,6 +63,7 @@ def test_memmap(tmpdir): expected.pop((chunk.start, chunk.stop)) +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") @pytest.mark.network def test_http_multipart(): notifications = queue.Queue() @@ -78,6 +79,7 @@ def test_http_multipart(): expected.pop((chunk.start, chunk.stop)) +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") @pytest.mark.network def test_http(): notifications = queue.Queue() @@ -93,6 +95,7 @@ def test_http(): expected.pop((chunk.start, chunk.stop)) +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") @pytest.mark.network def test_http_workers(): notifications = queue.Queue() diff --git a/tests/test_0007-single-chunk-interface.py b/tests/test_0007-single-chunk-interface.py index bcc00a5b5..0ded2c4ff 100644 --- a/tests/test_0007-single-chunk-interface.py +++ b/tests/test_0007-single-chunk-interface.py @@ -74,6 +74,7 @@ def test_memmap(tmpdir): ) +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") @pytest.mark.network def test_http(): for num_workers in [1, 2]: @@ -84,15 +85,20 @@ def test_http(): chunk = source.chunk(start, stop) assert len(tobytes(chunk.raw_data)) == stop - start - with pytest.raises(Exception): - with uproot.source.http.MultithreadedHTTPSource( - "https://wonky.cern/does-not-exist", - num_workers=num_workers, - timeout=0.1, - ) as source: - source.chunk(0, 100) + +@pytest.mark.network +def test_http_fail(): + for num_workers in [1, 2]: + with pytest.raises(Exception): + with uproot.source.http.MultithreadedHTTPSource( + "https://wonky.cern/does-not-exist", + num_workers=num_workers, + timeout=0.1, + ) as source: + source.chunk(0, 100) +@pytest.mark.skip(reason="RECHECK: example.com is flaky, too") @pytest.mark.network def test_http_multipart(): with uproot.source.http.HTTPSource( @@ -102,11 +108,14 @@ def test_http_multipart(): chunk = source.chunk(start, stop) assert len(tobytes(chunk.raw_data)) == stop - start - with pytest.raises(Exception): - with uproot.source.http.HTTPSource( - "https://wonky.cern/does-not-exist", timeout=0.1, num_fallback_workers=1 - ) as source: - tobytes(source.chunk(0, 100).raw_data) + +@pytest.mark.network +def test_http_multipart_fail(): + with pytest.raises(Exception): + with uproot.source.http.HTTPSource( + "https://wonky.cern/does-not-exist", timeout=0.1, num_fallback_workers=1 + ) as source: + tobytes(source.chunk(0, 100).raw_data) @pytest.mark.skip(