Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow get_df on all data_types #887

Merged
merged 9 commits into from
Sep 23, 2024
11 changes: 2 additions & 9 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,15 +1908,8 @@ def get_df(

"""
df = self.get_array(run_id, targets, save=save, max_workers=max_workers, **kwargs)
try:
return pd.DataFrame.from_records(df)
except Exception as e:
if "Data must be 1-dimensional" in str(e):
raise ValueError(
f"Cannot load '{targets}' as a dataframe because it has "
"array fields. Please use get_array."
)
raise

return strax.convert_structured_array_to_df(df, log=self.log)

def get_zarr(
self,
Expand Down
37 changes: 37 additions & 0 deletions strax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,40 @@ def convert_tuple_to_list(init_func_input):
else:
# if not a container, return. i.e. int, float, bytes, str etc.
return func_input


@export
def convert_structured_array_to_df(structured_array, log=None):
"""Convert a structured numpy array to a pandas DataFrame.

Parameters:
structured_array (numpy.ndarray): The structured array to be converted.
Returns:
pandas.DataFrame: The converted DataFrame.

"""

if log is None:
import logging

log = logging.getLogger("strax_array_to_df")

data_dict = {}
converted_cols = []
for name in structured_array.dtype.names:
col = structured_array[name]
if col.ndim > 1:
# Convert n-dimensional columns to lists of ndarrays
data_dict[name] = [np.array(row) for row in col]
converted_cols.append(name)
else:
data_dict[name] = col

if converted_cols:
log.warning(
f"Columns {converted_cols} contain non-scalar entries. "
"Some pandas functions (e.g., groupby, apply) might "
"not perform as expected on these columns."
)

return pd.DataFrame(data_dict)
22 changes: 22 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ def test_core(allow_multiprocess, max_workers, processor):
assert bla.dtype == strax.peak_dtype()


@processing_conditions
def test_core_df(allow_multiprocess, max_workers, processor, caplog):
"""Test that get_df works with N-dimensional data."""
"""Test that get_df works with N-dimensional data."""
mystrax = strax.Context(
storage=[],
register=[Records, Peaks],
processors=[processor],
allow_multiprocess=allow_multiprocess,
use_per_run_defaults=True,
)

df = mystrax.get_df(run_id=run_id, targets="peaks", max_workers=max_workers)
p = mystrax.get_single_plugin(run_id, "records")
assert len(df.loc[0, "data"]) == 200
assert len(df) == p.config["recs_per_chunk"] * p.config["n_chunks"]
assert (
"contain non-scalar entries. Some pandas functions (e.g., groupby, apply)"
" might not perform as expected on these columns." in caplog.text
)


def test_post_office_state():
mystrax = strax.Context(
storage=[],
Expand Down