Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate assign_plugin_attributes for future usage #60

Merged
merged 1 commit into from
May 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 81 additions & 68 deletions axidence/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,83 @@ def ordinary_context(**kwargs):
return straxen.contexts.xenonnt_online(_database_init=False, **kwargs)


@export
def assign_plugin_attributes(
new_plugin, old_plugin, old_instance, suffix, snake, assign_attributes=None
):
# need to be compatible with strax.camel_to_snake
# https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324
new_plugin.__name__ = old_plugin.__name__ + suffix

# assign the attributes from the original plugin
if assign_attributes and old_plugin.__name__ in assign_attributes:
for attr in assign_attributes[old_plugin.__name__]:
setattr(new_plugin, attr, getattr(old_instance, attr))

# assign the same attributes as the original plugin
if hasattr(old_instance, "depends_on"):
new_plugin.depends_on = tuple(d + snake for d in old_instance.depends_on)
else:
raise RuntimeError(f"depends_on is not defined for instance of {old_plugin.__name__}")

if hasattr(old_instance, "provides"):
new_plugin.provides = tuple(p + snake for p in old_instance.provides)
else:
snake_name = strax.camel_to_snake(new_plugin.__name__)
new_plugin.provides = (snake_name,)

if hasattr(old_instance, "data_kind"):
if isinstance(old_instance.data_kind, (dict, immutabledict)):
keys = [k + snake for k in old_instance.data_kind.keys()]
values = [v + snake for v in old_instance.data_kind.values()]
new_plugin.data_kind = immutabledict(zip(keys, values))
else:
new_plugin.data_kind = old_instance.data_kind + snake
else:
raise RuntimeError(f"data_kind is not defined for instance of {old_plugin.__name__}")

if hasattr(old_instance, "save_when"):
if isinstance(old_instance.save_when, (dict, immutabledict)):
keys = [k + snake for k in old_instance.save_when]
new_plugin.save_when = immutabledict(zip(keys, old_instance.save_when.values()))
else:
new_plugin.save_when = old_instance.save_when + snake
else:
raise RuntimeError(f"save_when is not defined for instance of {old_plugin.__name__}")

if isinstance(old_instance.dtype, (dict, immutabledict)):
new_plugin.dtype = dict(
zip([k + snake for k in old_instance.dtype.keys()], old_instance.dtype.values())
)
else:
new_plugin.dtype = old_instance.dtype

if isinstance(old_instance, CutPlugin):
if hasattr(old_instance, "cut_name"):
new_plugin.cut_name = old_instance.cut_name + snake
else:
raise RuntimeError(f"cut_name is not defined for instance of {old_plugin.__name__}")

if isinstance(old_instance, CutList):
if hasattr(old_instance, "accumulated_cuts_string"):
new_plugin.accumulated_cuts_string = old_instance.accumulated_cuts_string + snake
else:
raise RuntimeError(
f"accumulated_cuts_string is not defined for instance of {old_plugin.__name__}"
)

if isinstance(old_instance, CutPlugin) or isinstance(old_instance, CutList):
# this will make CutList.cuts to be invalid
new_plugin.dtype = np.dtype(
[
((d[0][0], d[0][1] + snake), d[1]) if d[0][1] not in ["time", "endtime"] else d
for d in new_plugin.dtype.descr
]
)

return new_plugin


@strax.Context.add_method
def plugin_factory(st, data_type, suffixes, assign_attributes=None):
"""Create new plugins inheriting from the plugin which provides
Expand Down Expand Up @@ -73,73 +150,9 @@ def do_compute(self, chunk_i=None, **kwargs):
new_kwargs = dict(zip(new_keys, kwargs.values()))
return super().do_compute(chunk_i=chunk_i, **new_kwargs)

# need to be compatible with strax.camel_to_snake
# https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324
new_plugin.__name__ = plugin.__name__ + suffix

# assign the attributes from the original plugin
if assign_attributes and plugin.__name__ in assign_attributes:
for attr in assign_attributes[plugin.__name__]:
setattr(new_plugin, attr, getattr(p, attr))

# assign the same attributes as the original plugin
if hasattr(p, "depends_on"):
new_plugin.depends_on = tuple(d + snake for d in p.depends_on)
else:
raise RuntimeError(f"depends_on is not defined for instance of {plugin.__name__}")

if hasattr(p, "provides"):
new_plugin.provides = tuple(p + snake for p in p.provides)
else:
snake_name = strax.camel_to_snake(new_plugin.__name__)
new_plugin.provides = (snake_name,)

if hasattr(p, "data_kind"):
if isinstance(p.data_kind, (dict, immutabledict)):
keys = [k + snake for k in p.data_kind.keys()]
values = [v + snake for v in p.data_kind.values()]
new_plugin.data_kind = immutabledict(zip(keys, values))
else:
new_plugin.data_kind = p.data_kind + snake
else:
raise RuntimeError(f"data_kind is not defined for instance of {plugin.__name__}")

if hasattr(p, "save_when"):
if isinstance(p.save_when, (dict, immutabledict)):
keys = [k + snake for k in p.save_when]
new_plugin.save_when = immutabledict(zip(keys, p.save_when.values()))
else:
new_plugin.save_when = p.save_when + snake
else:
raise RuntimeError(f"save_when is not defined for instance of {plugin.__name__}")

if isinstance(p.dtype, (dict, immutabledict)):
new_plugin.dtype = dict(zip([k + snake for k in p.dtype.keys()], p.dtype.values()))
else:
new_plugin.dtype = p.dtype

if isinstance(p, CutPlugin):
if hasattr(p, "cut_name"):
new_plugin.cut_name = p.cut_name + snake
else:
raise RuntimeError(f"cut_name is not defined for instance of {plugin.__name__}")

if isinstance(p, CutList):
if hasattr(p, "accumulated_cuts_string"):
new_plugin.accumulated_cuts_string = p.accumulated_cuts_string + snake
else:
raise RuntimeError(
f"accumulated_cuts_string is not defined for instance of {plugin.__name__}"
)

if isinstance(p, CutPlugin) or isinstance(p, CutList):
# this will make CutList.cuts to be invalid
new_plugin.dtype = np.dtype(
[
((d[0][0], d[0][1] + snake), d[1]) if d[0][1] not in ["time", "endtime"] else d
for d in new_plugin.dtype.descr
]
)
new_plugin = assign_plugin_attributes(
new_plugin, plugin, p, suffix, snake, assign_attributes=assign_attributes
)

new_plugins.append(new_plugin)
return new_plugins
Expand All @@ -160,7 +173,7 @@ def replication_tree(st, suffixes=["Paired", "Salted"], assign_attributes=None,
snakes = ["_" + strax.camel_to_snake(suffix) for suffix in suffixes]
for k in st._plugin_class_registry.keys():
for s in snakes:
if s in k:
if k.endswith(s):
raise ValueError(f"{k} with suffix {s} is already registered!")
plugins_collection = []
for k in tqdm(st._plugin_class_registry.keys(), disable=tqdm_disable):
Expand Down
Loading