From 61538d9a8450848f09db2d70465f152c4af7c357 Mon Sep 17 00:00:00 2001 From: dachengx Date: Fri, 26 Apr 2024 01:01:01 -0500 Subject: [PATCH] Add method `replication_tree` to `strax.Context` --- axidence/context.py | 80 +++++++++++++++++++++++++++++++++++++++++++ tests/test_context.py | 34 ++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 tests/test_context.py diff --git a/axidence/context.py b/axidence/context.py index e3c6155..e88b2da 100644 --- a/axidence/context.py +++ b/axidence/context.py @@ -1,3 +1,5 @@ +from immutabledict import immutabledict +from tqdm import tqdm import strax import straxen @@ -40,3 +42,81 @@ def salt_to_context(self): EventBuilding, ) ) + + +@strax.Context.add_method +def plugin_factory(st, data_type, suffixes): + plugin = st._plugin_class_registry[data_type] + + new_plugins = [] + p = st._get_plugins((data_type,), run_id="0")[data_type] + + for suffix in suffixes: + snake = "_" + strax.camel_to_snake(suffix) + + class new_plugin(plugin): + suffix = snake + + def infer_dtype(self): + # some plugins like PulseProcessing uses self.deps in infer_dtype, + # which will cause error because the dependency tree changes + # https://github.com/XENONnT/straxen/blob/b4910e560a6a7f11288a4368816e692c26f8bc73/straxen/plugins/records/records.py#L142 + # so we assign the dtype manually and raise error in infer_dtype method + raise RuntimeError + + # need to be compatible with strax.camel_to_snake + # https://github.com/AxFoundation/strax/blob/7da9a2a6375e7614181830484b322389986cf064/strax/context.py#L324 + new_plugin.__name__ = plugin.__name__ + suffix + + # assign the same attributes as the original plugin + if hasattr(p, "depends_on"): + new_plugin.depends_on = tuple(d + snake for d in p.depends_on) + else: + raise RuntimeError + + if hasattr(p, "provides"): + new_plugin.provides = tuple(p + snake for p in p.provides) + else: + snake_name = strax.camel_to_snake(new_plugin.__name__) + new_plugin.provides = (snake_name,) + + if hasattr(p, "data_kind"): + if isinstance(p.data_kind, (dict, immutabledict)): + keys = [k + snake for k in p.data_kind.keys()] + values = [v + snake for v in p.data_kind.values()] + new_plugin.data_kind = immutabledict(zip(keys, values)) + else: + new_plugin.data_kind = p.data_kind + snake + else: + raise RuntimeError(f"data_kind is not defined for instance of {plugin.__name__}") + + if hasattr(p, "save_when"): + if isinstance(p.save_when, (dict, immutabledict)): + keys = [k + snake for k in p.save_when] + new_plugin.save_when = immutabledict(zip(keys, p.save_when.values())) + else: + new_plugin.save_when = p.save_when + snake + else: + raise RuntimeError(f"save_when is not defined for instance of {plugin.__name__}") + + if isinstance(p.dtype, (dict, immutabledict)): + new_plugin.dtype = dict(zip([k + snake for k in p.dtype.keys()], p.dtype.values())) + else: + new_plugin.dtype = p.dtype + + new_plugins.append(new_plugin) + return new_plugins + + +@strax.Context.add_method +def replication_tree(st, suffixes=["Paired", "Salted"], tqdm_disable=True): + snakes = ["_" + strax.camel_to_snake(suffix) for suffix in suffixes] + for k in st._plugin_class_registry.keys(): + for s in snakes: + if s in k: + raise ValueError(f"{k} with suffix {s} is already registered!") + plugins_collection = [] + for k in tqdm(st._plugin_class_registry.keys(), disable=tqdm_disable): + plugins_collection += st.plugin_factory(k, suffixes) + + st.register(plugins_collection) diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..4ee5295 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,34 @@ +import os +import uuid +import shutil +import tempfile +from unittest import TestCase + +import axidence + + +class TestContext(TestCase): + @classmethod + def setUpClass(cls) -> None: + # Maybe keeping one temp dir is a bit overkill + temp_folder = uuid.uuid4().hex + cls.tempdir = os.path.join(tempfile.gettempdir(), temp_folder) + assert not os.path.exists(cls.tempdir) + + cls.run_id = "0" * 6 + cls.st = axidence.unsalted_context(output_folder=cls.tempdir) + cls.st.salt_to_context() + + @classmethod + def tearDownClass(cls): + # Make sure to only cleanup this dir after we have done all the tests + if os.path.exists(cls.tempdir): + shutil.rmtree(cls.tempdir) + + def test_replication_tree(self): + """Test the replication_tree method.""" + self.st.replication_tree() + with self.assertRaises( + ValueError, msg="Should raise error calling replication_tree twice!" + ): + self.st.replication_tree()