Skip to content

Commit

Permalink
Merge pull request #297 from proektlab/master
Browse files Browse the repository at this point in the history
Allow non-hashable types and nested params in get_params_diffs
  • Loading branch information
kushalkolar authored Jun 6, 2024
2 parents 6a6a184 + 0168d74 commit 38f6ebe
Showing 1 changed file with 45 additions and 18 deletions.
63 changes: 45 additions & 18 deletions mesmerize_core/caiman_extensions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import Counter
from datetime import datetime
import time
from copy import deepcopy

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -265,35 +266,61 @@ def get_params_diffs(self, algo: str, item_name: str) -> pd.DataFrame:
`item_name`. The returned index corresponds to the
index of the original DataFrame
"""
"""
def flatten_params(params_dict: dict):
"""
Produce a flat dict with one entry for each parameter in the passed dict.
If params_dict['main'] is nested one level (e.g., {'init': {'K': 5}, 'merging': {'merge_thr': 0.85}}...),
each key in the output is <outerKey>.<innerKey>, e.g., [(init.K, 5), (merging.merge_thr, 0.85)]
"""
params = {}
for key1, val1 in params_dict.items():
if isinstance(val1, dict): # nested
for key2, val2 in val1.items():
params[f"{key1}.{key2}"] = val2
else:
params[key1] = val1
return params

sub_df = self._df[self._df["item_name"] == item_name]
sub_df = sub_df[sub_df["algo"] == algo]

if sub_df.index.size == 0:
raise NameError(f"The given `item_name`: {item_name}, does not exist in the DataFrame")

all_variants = set(
tuple(
chain.from_iterable(
[
tuple(p["main"].items()) for p in sub_df.params.values
]
)
)
)

counts = Counter([av[0] for av in all_variants])
variants_exist = [param[0] for param in counts.items() if param[1] > 1]

# gives a series where each item is a dict that has the unique params that correspond to a row
# get flattened parameters for each of the filtered items
params_flat = sub_df.params.map(lambda p: flatten_params(p["main"]))

# build list of params that differ between different parameter sets
common_params = deepcopy(params_flat.iat[0]) # holds the common value for parameters found in all sets (so far)
varying_params = set() # set of parameter keys that appear in not all sets or with varying values

for this_params in params_flat.iloc[1:]:
# first, anything that's not in both this dict and the common set is considered varying
common_paramset = set(common_params.keys())
for not_common_key in common_paramset.symmetric_difference(this_params.keys()):
varying_params.add(not_common_key)
if not_common_key in common_paramset:
del common_params[not_common_key]
common_paramset.remove(not_common_key)

# second, look at params in the common set and remove any that differ for this set
for key in common_paramset: # iterate over this set rather than dict itself to avoid issues when deleting entries
if not np.array_equal(common_params[key], this_params[key]): # (should also work for scalars/arbitrary objects)
varying_params.add(key)
del common_params[key]

# gives a list where each item is a dict that has the unique params that correspond to a row
# the indices of this series correspond to the index of the row in the parent dataframe
diffs: pd.Series = sub_df["params"].apply(
lambda p: {k: p["main"][k] for k in variants_exist if k in p["main"].keys()}
)
diffs = params_flat.map(lambda p: {key: p[key] for key in varying_params if key in p})

# return as a nicely formatted dataframe
diffs_df = pd.DataFrame.from_dict(diffs.tolist(), dtype=object).set_index(diffs.index)

# replace any missing parameters with a string for clarity
with pd.option_context('future.no_silent_downcasting', True): # avoids warning about downcasting
diffs_df.fillna("<default>", inplace=True)

return diffs_df

@warning_experimental("This feature will change in the future and directly return the "
Expand Down

0 comments on commit 38f6ebe

Please sign in to comment.