Skip to content

Commit

Permalink
Centralize 'from_module' checking and check the whole mro. (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski authored Feb 10, 2022
1 parent 53c5a99 commit 027f7a1
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 41 deletions.
21 changes: 21 additions & 0 deletions src/uproot/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,27 @@ def ensure_numpy(array, types=(numpy.bool_, numpy.integer, numpy.floating)):
_regularize_filter_regex = re.compile("^/(.*)/([iLmsux]*)$")


def from_module(obj, module_name):
    """
    Args:
        obj: Any Python object whose type is to be checked.
        module_name (str): Name of a module or package, such as
            ``"pandas"`` or ``"awkward"``.

    Returns True if ``obj`` is an instance of a class from a module
    given by name.

    This is like ``isinstance`` (in that it searches the whole ``mro``),
    except that the module providing the type to check against doesn't
    have to be imported and doesn't get imported (as a side effect) by
    this function.

    A class counts as coming from ``module_name`` if its ``__module__``
    is exactly ``module_name`` or starts with ``module_name + "."`` (a
    submodule). A mere string prefix does not match: ``"awk"`` does not
    match ``"awkward"``.

    Classes in the ``mro`` whose ``__module__`` is missing or is not a
    string are skipped rather than raising an exception.
    """
    try:
        mro = type(obj).mro()
    except TypeError:
        # Raised when ``type(obj)`` is ``type`` itself (``obj`` is a plain
        # class): ``type.mro()`` is then an unbound call with no argument.
        return False

    # Loop-invariant: the dotted prefix that identifies submodules.
    submodule_prefix = module_name + "."

    for t in mro:
        name = getattr(t, "__module__", None)
        if not isinstance(name, str):
            # ``__module__`` can be overridden with a non-string (e.g. None);
            # such a class cannot belong to any named module.
            continue
        if name == module_name or name.startswith(submodule_prefix):
            return True
    return False


def _regularize_filter_regex_flags(flags):
flagsbyte = 0
for flag in flags:
Expand Down
4 changes: 1 addition & 3 deletions src/uproot/interpretation/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,7 @@ def final_array(

start = stop

if all(
type(x).__module__.startswith("awkward") for x in basket_arrays.values()
):
if all(uproot._util.from_module(x, "awkward") for x in basket_arrays.values()):
assert isinstance(library, uproot.interpretation.library.Awkward)
awkward = library.imported
output = awkward.concatenate(trimmed, mergebool=False, highlevel=False)
Expand Down
6 changes: 2 additions & 4 deletions src/uproot/language/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,7 @@ def getter(name):
file_path,
object_path,
)()
module_name = type(output[expression]).__module__
if module_name == "pandas" or module_name.startswith("pandas."):
if uproot._util.from_module(output[expression], "pandas"):
is_pandas = True

cut = None
Expand All @@ -472,8 +471,7 @@ def getter(name):
file_path,
object_path,
)()
module_name = type(cut).__module__
if module_name == "pandas" or module_name.startswith("pandas."):
if uproot._util.from_module(cut, "pandas"):
is_pandas = True
break

Expand Down
1 change: 1 addition & 0 deletions src/uproot/source/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Resource:
:doc:`uproot.source.futures.ResourceFuture`.
"""

@property
def file_path(self):
"""
A path to the file (or URL).
Expand Down
22 changes: 8 additions & 14 deletions src/uproot/writing/_cascadetree.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(

else:
try:
if type(branch_type).__module__.startswith("awkward."):
if uproot._util.from_module(branch_type, "awkward"):
raise TypeError
if (
uproot._util.isstr(branch_type)
Expand Down Expand Up @@ -509,15 +509,14 @@ def extend(self, file, sink, data):
sink.flush()

provided = None
module_name = type(data).__module__

if module_name == "pandas" or module_name.startswith("pandas."):
if uproot._util.from_module(data, "pandas"):
import pandas

if isinstance(data, pandas.DataFrame) and data.index.is_numeric():
provided = dataframe_to_dict(data)

if module_name == "awkward" or module_name.startswith("awkward."):
if uproot._util.from_module(data, "awkward"):
try:
awkward = uproot.extras.awkward()
except ModuleNotFoundError as err:
Expand Down Expand Up @@ -550,12 +549,9 @@ def extend(self, file, sink, data):

provided = {}
for k, v in data.items():
module_name = type(v).__module__
if not (
module_name == "pandas" or module_name.startswith("pandas.")
) and not (
module_name == "awkward" or module_name.startswith("awkward.")
):
if not uproot._util.from_module(
v, "pandas"
) and not uproot._util.from_module(v, "awkward"):
if not hasattr(v, "dtype") and not isinstance(v, Mapping):
try:
with warnings.catch_warnings():
Expand Down Expand Up @@ -583,8 +579,7 @@ def extend(self, file, sink, data):
) from err
v = awkward.from_iter(v)

module_name = type(v).__module__
if module_name == "awkward" or module_name.startswith("awkward."):
if uproot._util.from_module(v, "awkward"):
try:
awkward = uproot.extras.awkward()
except ModuleNotFoundError as err:
Expand Down Expand Up @@ -622,8 +617,7 @@ def extend(self, file, sink, data):
if datum["name"] in provided:
recordarray = provided.pop(datum["name"])

module_name = type(recordarray).__module__
if module_name == "pandas" or module_name.startswith("pandas."):
if uproot._util.from_module(recordarray, "pandas"):
import pandas

if isinstance(recordarray, pandas.DataFrame):
Expand Down
12 changes: 4 additions & 8 deletions src/uproot/writing/identify.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ def add_to_directory(obj, name, directory, streamers):
Raises ``TypeError`` if ``obj`` is not recognized as writable data.
"""
is_ttree = False
module_name = type(obj).__module__

if module_name == "pandas" or module_name.startswith("pandas."):
if uproot._util.from_module(obj, "pandas"):
import pandas

if isinstance(obj, pandas.DataFrame) and obj.index.is_numeric():
obj = uproot.writing._cascadetree.dataframe_to_dict(obj)

if module_name == "awkward" or module_name.startswith("awkward."):
if uproot._util.from_module(obj, "awkward"):
import awkward

if isinstance(obj, awkward.Array):
Expand All @@ -70,9 +69,7 @@ def add_to_directory(obj, name, directory, streamers):
metadata = {}

for branch_name, branch_array in obj.items():
module_name = type(branch_array).__module__

if module_name == "pandas" or module_name.startswith("pandas."):
if uproot._util.from_module(branch_array, "pandas"):
import pandas

if isinstance(branch_array, pandas.DataFrame):
Expand Down Expand Up @@ -113,8 +110,7 @@ def add_to_directory(obj, name, directory, streamers):
metadata[branch_name] = metadatum

else:
module_name = type(branch_array).__module__
if module_name == "awkward" or module_name.startswith("awkward."):
if uproot._util.from_module(branch_array, "awkward"):
data[branch_name] = branch_array
metadata[branch_name] = branch_array.type

Expand Down
3 changes: 3 additions & 0 deletions tests/test_0001-source-class.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_memmap_fail(tmpdir):
uproot.source.file.MultithreadedFileSource(filename + "-does-not-exist")


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
@pytest.mark.network
def test_http():
source = uproot.source.http.HTTPSource(
Expand All @@ -117,6 +118,7 @@ def test_http():
assert [tobytes(x.raw_data) for x in chunks] == [one, two, three]


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
def colons_and_ports():
assert uproot._util.file_object_path_split("https://example.com:443") == (
"https://example.com:443",
Expand All @@ -131,6 +133,7 @@ def colons_and_ports():
) == ("https://example.com:443/something", "else")


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
@pytest.mark.network
def test_http_port():
source = uproot.source.http.HTTPSource(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_0006-notify-when-downloaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_memmap(tmpdir):
expected.pop((chunk.start, chunk.stop))


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
@pytest.mark.network
def test_http_multipart():
notifications = queue.Queue()
Expand All @@ -78,6 +79,7 @@ def test_http_multipart():
expected.pop((chunk.start, chunk.stop))


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
@pytest.mark.network
def test_http():
notifications = queue.Queue()
Expand All @@ -93,6 +95,7 @@ def test_http():
expected.pop((chunk.start, chunk.stop))


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
@pytest.mark.network
def test_http_workers():
notifications = queue.Queue()
Expand Down
33 changes: 21 additions & 12 deletions tests/test_0007-single-chunk-interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_memmap(tmpdir):
)


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
@pytest.mark.network
def test_http():
for num_workers in [1, 2]:
Expand All @@ -84,15 +85,20 @@ def test_http():
chunk = source.chunk(start, stop)
assert len(tobytes(chunk.raw_data)) == stop - start

with pytest.raises(Exception):
with uproot.source.http.MultithreadedHTTPSource(
"https://wonky.cern/does-not-exist",
num_workers=num_workers,
timeout=0.1,
) as source:
source.chunk(0, 100)

@pytest.mark.network
def test_http_fail():
    """
    Opening a multithreaded HTTP source on a URL that presumably does not
    exist and requesting a chunk is expected to raise, with 1 and 2 workers.
    """
    bad_url = "https://wonky.cern/does-not-exist"
    for worker_count in (1, 2):
        # Both construction and the chunk request sit inside the
        # ``raises`` context, so a failure at either step satisfies it.
        with pytest.raises(Exception):
            with uproot.source.http.MultithreadedHTTPSource(
                bad_url, num_workers=worker_count, timeout=0.1
            ) as source:
                source.chunk(0, 100)


@pytest.mark.skip(reason="RECHECK: example.com is flaky, too")
@pytest.mark.network
def test_http_multipart():
with uproot.source.http.HTTPSource(
Expand All @@ -102,11 +108,14 @@ def test_http_multipart():
chunk = source.chunk(start, stop)
assert len(tobytes(chunk.raw_data)) == stop - start

with pytest.raises(Exception):
with uproot.source.http.HTTPSource(
"https://wonky.cern/does-not-exist", timeout=0.1, num_fallback_workers=1
) as source:
tobytes(source.chunk(0, 100).raw_data)

@pytest.mark.network
def test_http_multipart_fail():
    """
    Opening an HTTP source on a URL that presumably does not exist and
    reading a chunk's raw data is expected to raise.
    """
    bad_url = "https://wonky.cern/does-not-exist"
    # Construction, the chunk request, and the raw-data access all sit
    # inside the ``raises`` context, so a failure at any step satisfies it.
    with pytest.raises(Exception):
        with uproot.source.http.HTTPSource(
            bad_url, timeout=0.1, num_fallback_workers=1
        ) as source:
            chunk = source.chunk(0, 100)
            tobytes(chunk.raw_data)


@pytest.mark.skip(
Expand Down

0 comments on commit 027f7a1

Please sign in to comment.