Skip to content

Commit

Permalink
feat: support reading of custom-length RNTuple floats and suppressed …
Browse files Browse the repository at this point in the history
…columns (#1347)

* Started implementing reading of quantized and truncated floats

* Added support for suppressed columns

* Added tests

* Cleaner reading of floats with 1, 2, or 3 bytes

* Only support little-endian systems

* Fixed bug with Numpy 1

* Improved tests

* Fixed tests for Numpy 2
  • Loading branch information
ariostas authored Dec 19, 2024
1 parent 2ba58f2 commit 4ee1a2d
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/uproot/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@
rntuple_col_type_to_num_dict["splitindex32"],
rntuple_col_type_to_num_dict["splitindex64"],
)
rntuple_custom_float_types = (
rntuple_col_type_to_num_dict["real32trunc"],
rntuple_col_type_to_num_dict["real32quant"],
)


class RNTupleLocatorType(IntEnum):
Expand Down
57 changes: 54 additions & 3 deletions src/uproot/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import struct
import sys
from collections import defaultdict
from itertools import accumulate

Expand Down Expand Up @@ -241,6 +242,11 @@ def read_members(self, chunk, cursor, context, file):
f"""memberwise serialization of {type(self).__name__}
in file {self.file.file_path}"""
)
# Probably no one will encounter this, but just in case something doesn't work correctly
if sys.byteorder != "little":
raise NotImplementedError(
"RNTuple reading is only supported on little-endian systems"
)

(
self._members["fVersionEpoch"],
Expand Down Expand Up @@ -524,6 +530,8 @@ def base_col_form(self, cr, col_id, parameters=None, cardinality=False):
dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
if dt_str == "bit":
dt_str = "bool"
elif dtype_byte in uproot.const.rntuple_custom_float_types:
dt_str = "float32"
return ak.forms.NumpyForm(
dt_str,
form_key=form_key,
Expand All @@ -546,6 +554,8 @@ def col_form(self, field_id):
)

rel_crs = self._column_records_dict[cfid]
# for this part we can use the default (zeroth) representation
rel_crs = [c for c in rel_crs if c.repr_idx == 0]

if len(rel_crs) == 1: # base case
cardinality = "RNTupleCardinality" in self.field_records[field_id].type_name
Expand Down Expand Up @@ -673,9 +683,14 @@ def read_pagedesc(self, destination, desc, dtype_str, dtype, nbits, split):
context = {}
# bool in RNTuple is always stored as bits
isbit = dtype_str == "bit"
len_divider = 8 if isbit else 1
num_elements = len(destination)
num_elements_toread = int(numpy.ceil(num_elements / len_divider))
if isbit:
num_elements_toread = int(numpy.ceil(num_elements / 8))
elif dtype_str in ("real32trunc", "real32quant"):
num_elements_toread = int(numpy.ceil((num_elements * 4 * nbits) / 32))
dtype = numpy.dtype("uint8")
else:
num_elements_toread = num_elements
uncomp_size = num_elements_toread * dtype.itemsize
decomp_chunk, cursor = self.read_locator(loc, uncomp_size, context)
content = cursor.array(
Expand Down Expand Up @@ -722,6 +737,23 @@ def read_pagedesc(self, destination, desc, dtype_str, dtype, nbits, split):
.reshape(-1, 8)[:, ::-1]
.reshape(-1)
)
elif dtype_str in ("real32trunc", "real32quant"):
if nbits == 32:
content = content.view(numpy.uint32)
elif nbits % 8 == 0:
new_content = numpy.zeros((num_elements, 4), numpy.uint8)
nbytes = nbits // 8
new_content[:, :nbytes] = content.reshape(-1, nbytes)
content = new_content.view(numpy.uint32).reshape(-1)
else:
ak = uproot.extras.awkward()
vm = ak.forth.ForthMachine32(
f"""input x output y uint32 {num_elements} x #{nbits}bit-> y"""
)
vm.run({"x": content})
content = vm["y"]
if dtype_str == "real32trunc":
content <<= 32 - nbits

# needed to chop off extra bits incase we used `unpackbits`
destination[:] = content[:num_elements]
Expand Down Expand Up @@ -754,6 +786,10 @@ def read_col_pages(

def read_col_page(self, ncol, cluster_i):
linklist = self.page_list_envelopes.pagelinklist[cluster_i]
# Check if the column is suppressed and pick the non-suppressed one if so
if ncol < len(linklist) and linklist[ncol].suppressed:
rel_crs = self._column_records_dict[self.column_records[ncol].field_id]
ncol = next(cr.idx for cr in rel_crs if not linklist[cr.idx].suppressed)
pagelist = linklist[ncol].pages if ncol < len(linklist) else []
dtype_byte = self.column_records[ncol].type
dtype_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte]
Expand All @@ -762,14 +798,20 @@ def read_col_page(self, ncol, cluster_i):
dtype = numpy.dtype([("index", "int64"), ("tag", "int32")])
elif dtype_str == "bit":
dtype = numpy.dtype("bool")
elif dtype_byte in uproot.const.rntuple_custom_float_types:
dtype = numpy.dtype("uint32") # for easier bit manipulation
else:
dtype = numpy.dtype(dtype_str)
res = numpy.empty(total_len, dtype)
split = dtype_byte in uproot.const.rntuple_split_types
zigzag = dtype_byte in uproot.const.rntuple_zigzag_types
delta = dtype_byte in uproot.const.rntuple_delta_types
index = dtype_byte in uproot.const.rntuple_index_types
nbits = uproot.const.rntuple_col_num_to_size_dict[dtype_byte]
nbits = (
self.column_records[ncol].nbits
if ncol < len(self.column_records)
else uproot.const.rntuple_col_num_to_size_dict[dtype_byte]
)
tracker = 0
cumsum = 0
for page_desc in pagelist:
Expand All @@ -789,6 +831,15 @@ def read_col_page(self, ncol, cluster_i):
res = _from_zigzag(res)
elif delta:
res = numpy.cumsum(res)
elif dtype_str == "real32trunc":
res = res.view(numpy.float32)
elif dtype_str == "real32quant" and ncol < len(self.column_records):
min_value = self.column_records[ncol].min_value
max_value = self.column_records[ncol].max_value
res = min_value + res.astype(numpy.float32) * (max_value - min_value) / (
(1 << nbits) - 1
)
res = res.astype(numpy.float32)
return res

def arrays(
Expand Down
167 changes: 167 additions & 0 deletions tests/test_1347_rntuple_floats_suppressed_cols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE

import skhep_testdata
import numpy as np

import uproot


def truncate_float(value, bits):
a = np.float32(value).view(np.uint32)
a &= np.uint32(0xFFFFFFFF) << (32 - bits)
return a.astype(np.uint32).view(np.float32)


def quantize_float(value, bits, min, max):
min = np.float32(min)
max = np.float32(max)
if value < min or value > max:
raise ValueError(f"Value {value} is out of range [{min}, {max}]")
scaled_value = (value - min) * (2**bits - 1) / (max - min)
int_value = np.round(scaled_value)
quantized_float = min + int_value * (max - min) / ((1 << bits) - 1)
return quantized_float.astype(np.float32)


def test_custom_floats():
filename = skhep_testdata.data_path("test_float_types_rntuple_v1-0-0-0.root")
with uproot.open(filename) as f:
obj = f["ntuple"]

arrays = obj.arrays()

min_value = -2.0
max_value = 3.0

entry = arrays[0]
true_value = 1.23456789
assert entry.trunc10 == truncate_float(true_value, 10)
assert entry.trunc16 == truncate_float(true_value, 16)
assert entry.trunc24 == truncate_float(true_value, 24)
assert entry.trunc31 == truncate_float(true_value, 31)
assert np.isclose(
entry.quant1, quantize_float(true_value, 1, min_value, max_value)
)
assert np.isclose(
entry.quant8, quantize_float(true_value, 8, min_value, max_value)
)
assert np.isclose(
entry.quant16, quantize_float(true_value, 16, min_value, max_value)
)
assert np.isclose(
entry.quant20, quantize_float(true_value, 20, min_value, max_value)
)
assert np.isclose(
entry.quant24, quantize_float(true_value, 24, min_value, max_value)
)
assert np.isclose(
entry.quant25, quantize_float(true_value, 25, min_value, max_value)
)
assert np.isclose(
entry.quant32, quantize_float(true_value, 32, min_value, max_value)
)

entry = arrays[1]
true_value = 1.4660155e13
assert entry.trunc10 == truncate_float(true_value, 10)
assert entry.trunc16 == truncate_float(true_value, 16)
assert entry.trunc24 == truncate_float(true_value, 24)
assert entry.trunc31 == truncate_float(true_value, 31)
true_value = 1.6666666
assert np.isclose(
entry.quant1, quantize_float(true_value, 1, min_value, max_value)
)
assert np.isclose(
entry.quant8, quantize_float(true_value, 8, min_value, max_value)
)
assert np.isclose(
entry.quant16, quantize_float(true_value, 16, min_value, max_value)
)
assert np.isclose(
entry.quant20, quantize_float(true_value, 20, min_value, max_value)
)
assert np.isclose(
entry.quant24, quantize_float(true_value, 24, min_value, max_value)
)
assert np.isclose(
entry.quant25, quantize_float(true_value, 25, min_value, max_value)
)
assert np.isclose(
entry.quant32, quantize_float(true_value, 32, min_value, max_value)
)

entry = arrays[2]
true_value = -6.2875986e-22
assert entry.trunc10 == truncate_float(true_value, 10)
assert entry.trunc16 == truncate_float(true_value, 16)
assert entry.trunc24 == truncate_float(true_value, 24)
assert entry.trunc31 == truncate_float(true_value, 31)
assert np.isclose(
entry.quant1, quantize_float(true_value, 1, min_value, max_value)
)
assert np.isclose(
entry.quant8, quantize_float(true_value, 8, min_value, max_value)
)
assert np.isclose(
entry.quant16, quantize_float(true_value, 16, min_value, max_value)
)
assert np.isclose(
entry.quant20, quantize_float(true_value, 20, min_value, max_value)
)
assert np.isclose(
entry.quant24, quantize_float(true_value, 24, min_value, max_value)
)
assert np.isclose(
entry.quant25,
quantize_float(true_value, 25, min_value, max_value),
atol=2e-07,
)
assert np.isclose(
entry.quant32, quantize_float(true_value, 32, min_value, max_value)
)

entry = arrays[3]
true_value = -1.9060668
assert entry.trunc10 == truncate_float(true_value, 10)
assert entry.trunc16 == truncate_float(true_value, 16)
assert entry.trunc24 == truncate_float(true_value, 24)
assert entry.trunc31 == truncate_float(true_value, 31)
assert np.isclose(
entry.quant1, quantize_float(true_value, 1, min_value, max_value)
)
assert np.isclose(
entry.quant8, quantize_float(true_value, 8, min_value, max_value)
)
assert np.isclose(
entry.quant16, quantize_float(true_value, 16, min_value, max_value)
)
assert np.isclose(
entry.quant20, quantize_float(true_value, 20, min_value, max_value)
)
assert np.isclose(
entry.quant24, quantize_float(true_value, 24, min_value, max_value)
)
assert np.isclose(
entry.quant25, quantize_float(true_value, 25, min_value, max_value)
)
assert np.isclose(
entry.quant32, quantize_float(true_value, 32, min_value, max_value)
)


def test_multiple_representations():
filename = skhep_testdata.data_path(
"test_multiple_representations_rntuple_v1-0-0-0.root"
)
with uproot.open(filename) as f:
obj = f["ntuple"]

assert len(obj.page_list_envelopes.pagelinklist) == 3
# The zeroth representation is active in clusters 0 and 2, but not in cluster 1
assert not obj.page_list_envelopes.pagelinklist[0][0].suppressed
assert obj.page_list_envelopes.pagelinklist[1][0].suppressed
assert not obj.page_list_envelopes.pagelinklist[2][0].suppressed

arrays = obj.arrays()

assert np.allclose(arrays.real, [1, 2, 3])

0 comments on commit 4ee1a2d

Please sign in to comment.