diff --git a/axidence/context.py b/axidence/context.py index 23abf04..99c6e47 100644 --- a/axidence/context.py +++ b/axidence/context.py @@ -45,6 +45,83 @@ def ordinary_context(**kwargs): return straxen.contexts.xenonnt_online(_database_init=False, **kwargs) +@export +def assign_plugin_attributes( + new_plugin, old_plugin, old_instance, suffix, snake, assign_attributes=None +): + # need to be compatible with strax.camel_to_snake + # https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324 + new_plugin.__name__ = old_plugin.__name__ + suffix + + # assign the attributes from the original plugin + if assign_attributes and old_plugin.__name__ in assign_attributes: + for attr in assign_attributes[old_plugin.__name__]: + setattr(new_plugin, attr, getattr(old_instance, attr)) + + # assign the same attributes as the original plugin + if hasattr(old_instance, "depends_on"): + new_plugin.depends_on = tuple(d + snake for d in old_instance.depends_on) + else: + raise RuntimeError(f"depends_on is not defined for instance of {old_plugin.__name__}") + + if hasattr(old_instance, "provides"): + new_plugin.provides = tuple(p + snake for p in old_instance.provides) + else: + snake_name = strax.camel_to_snake(new_plugin.__name__) + new_plugin.provides = (snake_name,) + + if hasattr(old_instance, "data_kind"): + if isinstance(old_instance.data_kind, (dict, immutabledict)): + keys = [k + snake for k in old_instance.data_kind.keys()] + values = [v + snake for v in old_instance.data_kind.values()] + new_plugin.data_kind = immutabledict(zip(keys, values)) + else: + new_plugin.data_kind = old_instance.data_kind + snake + else: + raise RuntimeError(f"data_kind is not defined for instance of {old_plugin.__name__}") + + if hasattr(old_instance, "save_when"): + if isinstance(old_instance.save_when, (dict, immutabledict)): + keys = [k + snake for k in old_instance.save_when] + new_plugin.save_when = immutabledict(zip(keys, old_instance.save_when.values())) + else: + new_plugin.save_when = old_instance.save_when + snake + else: + raise RuntimeError(f"save_when is not defined for instance of {old_plugin.__name__}") + + if isinstance(old_instance.dtype, (dict, immutabledict)): + new_plugin.dtype = dict( + zip([k + snake for k in old_instance.dtype.keys()], old_instance.dtype.values()) + ) + else: + new_plugin.dtype = old_instance.dtype + + if isinstance(old_instance, CutPlugin): + if hasattr(old_instance, "cut_name"): + new_plugin.cut_name = old_instance.cut_name + snake + else: + raise RuntimeError(f"cut_name is not defined for instance of {old_plugin.__name__}") + + if isinstance(old_instance, CutList): + if hasattr(old_instance, "accumulated_cuts_string"): + new_plugin.accumulated_cuts_string = old_instance.accumulated_cuts_string + snake + else: + raise RuntimeError( + f"accumulated_cuts_string is not defined for instance of {old_plugin.__name__}" + ) + + if isinstance(old_instance, CutPlugin) or isinstance(old_instance, CutList): + # this will make CutList.cuts to be invalid + new_plugin.dtype = np.dtype( + [ + ((d[0][0], d[0][1] + snake), d[1]) if d[0][1] not in ["time", "endtime"] else d + for d in new_plugin.dtype.descr + ] + ) + + return new_plugin + + @strax.Context.add_method def plugin_factory(st, data_type, suffixes, assign_attributes=None): """Create new plugins inheriting from the plugin which provides @@ -73,73 +150,9 @@ def do_compute(self, chunk_i=None, **kwargs): new_kwargs = dict(zip(new_keys, kwargs.values())) return super().do_compute(chunk_i=chunk_i, **new_kwargs) - # need to be compatible with strax.camel_to_snake - # https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324 - new_plugin.__name__ = plugin.__name__ + suffix - - # assign the attributes from the original plugin - if assign_attributes and plugin.__name__ in assign_attributes: - for attr in assign_attributes[plugin.__name__]: - setattr(new_plugin, attr, getattr(p, attr)) - - # assign the same attributes as the original plugin - if hasattr(p, "depends_on"): - new_plugin.depends_on = tuple(d + snake for d in p.depends_on) - else: - raise RuntimeError(f"depends_on is not defined for instance of {plugin.__name__}") - - if hasattr(p, "provides"): - new_plugin.provides = tuple(p + snake for p in p.provides) - else: - snake_name = strax.camel_to_snake(new_plugin.__name__) - new_plugin.provides = (snake_name,) - - if hasattr(p, "data_kind"): - if isinstance(p.data_kind, (dict, immutabledict)): - keys = [k + snake for k in p.data_kind.keys()] - values = [v + snake for v in p.data_kind.values()] - new_plugin.data_kind = immutabledict(zip(keys, values)) - else: - new_plugin.data_kind = p.data_kind + snake - else: - raise RuntimeError(f"data_kind is not defined for instance of {plugin.__name__}") - - if hasattr(p, "save_when"): - if isinstance(p.save_when, (dict, immutabledict)): - keys = [k + snake for k in p.save_when] - new_plugin.save_when = immutabledict(zip(keys, p.save_when.values())) - else: - new_plugin.save_when = p.save_when + snake - else: - raise RuntimeError(f"save_when is not defined for instance of {plugin.__name__}") - - if isinstance(p.dtype, (dict, immutabledict)): - new_plugin.dtype = dict(zip([k + snake for k in p.dtype.keys()], p.dtype.values())) - else: - new_plugin.dtype = p.dtype - - if isinstance(p, CutPlugin): - if hasattr(p, "cut_name"): - new_plugin.cut_name = p.cut_name + snake - else: - raise RuntimeError(f"cut_name is not defined for instance of {plugin.__name__}") - - if isinstance(p, CutList): - if hasattr(p, "accumulated_cuts_string"): - new_plugin.accumulated_cuts_string = p.accumulated_cuts_string + snake - else: - raise RuntimeError( - f"accumulated_cuts_string is not defined for instance of {plugin.__name__}" - ) - - if isinstance(p, CutPlugin) or isinstance(p, CutList): - # this will make CutList.cuts to be invalid - new_plugin.dtype = np.dtype( - [ - ((d[0][0], d[0][1] + snake), d[1]) if d[0][1] not in ["time", "endtime"] else d - for d in new_plugin.dtype.descr - ] - ) + new_plugin = assign_plugin_attributes( + new_plugin, plugin, p, suffix, snake, assign_attributes=assign_attributes + ) new_plugins.append(new_plugin) return new_plugins @@ -160,7 +173,7 @@ def replication_tree(st, suffixes=["Paired", "Salted"], assign_attributes=None, snakes = ["_" + strax.camel_to_snake(suffix) for suffix in suffixes] for k in st._plugin_class_registry.keys(): for s in snakes: - if s in k: + if k.endswith(s): raise ValueError(f"{k} with suffix {s} is already registered!") plugins_collection = [] for k in tqdm(st._plugin_class_registry.keys(), disable=tqdm_disable):