Skip to content

Commit

Permalink
Allow get_df on all data_types (#887)
Browse files Browse the repository at this point in the history
* Allow get_df on all data_types

* Add basic test for get_df on peaks data

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add warning message for non-scalar DataFrame entries

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reduce line length

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dacheng Xu <[email protected]>
  • Loading branch information
3 people authored Sep 23, 2024
1 parent 0cfe543 commit 0696ef5
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 9 deletions.
11 changes: 2 additions & 9 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,15 +1908,8 @@ def get_df(
"""
df = self.get_array(run_id, targets, save=save, max_workers=max_workers, **kwargs)
try:
return pd.DataFrame.from_records(df)
except Exception as e:
if "Data must be 1-dimensional" in str(e):
raise ValueError(
f"Cannot load '{targets}' as a dataframe because it has "
"array fields. Please use get_array."
)
raise

return strax.convert_structured_array_to_df(df, log=self.log)

def get_zarr(
self,
Expand Down
37 changes: 37 additions & 0 deletions strax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,40 @@ def convert_tuple_to_list(init_func_input):
else:
# if not a container, return. i.e. int, float, bytes, str etc.
return func_input


@export
def convert_structured_array_to_df(structured_array, log=None):
    """Convert a structured numpy array to a pandas DataFrame.

    Scalar fields map directly to DataFrame columns. Fields with more
    than one dimension per row (e.g. waveform arrays) cannot be stored
    as scalar entries, so each such column is stored as a list of
    ndarrays and a warning is emitted, since pandas operations like
    ``groupby`` or ``apply`` may not behave as expected on array-valued
    columns.

    Parameters:
        structured_array (numpy.ndarray): The structured array to be converted.
        log (logging.Logger, optional): Logger used to warn about
            non-scalar columns. If None, a logger named
            "strax_array_to_df" is used.

    Returns:
        pandas.DataFrame: The converted DataFrame.
    """

    if log is None:
        import logging

        log = logging.getLogger("strax_array_to_df")

    data_dict = {}
    converted_cols = []
    for name in structured_array.dtype.names:
        col = structured_array[name]
        if col.ndim > 1:
            # Convert n-dimensional columns to lists of ndarrays
            data_dict[name] = [np.array(row) for row in col]
            converted_cols.append(name)
        else:
            data_dict[name] = col

    if converted_cols:
        log.warning(
            f"Columns {converted_cols} contain non-scalar entries. "
            "Some pandas functions (e.g., groupby, apply) might "
            "not perform as expected on these columns."
        )

    return pd.DataFrame(data_dict)
22 changes: 22 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ def test_core(allow_multiprocess, max_workers, processor):
assert bla.dtype == strax.peak_dtype()


@processing_conditions
def test_core_df(allow_multiprocess, max_workers, processor, caplog):
    """Test that get_df works with N-dimensional data."""
    mystrax = strax.Context(
        storage=[],
        register=[Records, Peaks],
        processors=[processor],
        allow_multiprocess=allow_multiprocess,
        use_per_run_defaults=True,
    )

    df = mystrax.get_df(run_id=run_id, targets="peaks", max_workers=max_workers)
    p = mystrax.get_single_plugin(run_id, "records")

    # The peaks "data" field is array-valued; get_df must keep it as a
    # list-of-arrays column instead of raising, with the full 200 samples.
    assert len(df.loc[0, "data"]) == 200
    assert len(df) == p.config["recs_per_chunk"] * p.config["n_chunks"]

    # Conversion of the non-scalar column must have logged a warning.
    assert (
        "contain non-scalar entries. Some pandas functions (e.g., groupby, apply)"
        " might not perform as expected on these columns." in caplog.text
    )


def test_post_office_state():
mystrax = strax.Context(
storage=[],
Expand Down

0 comments on commit 0696ef5

Please sign in to comment.