From 0db533e12d82395fcd6dd7f0f3f68645215237ea Mon Sep 17 00:00:00 2001 From: t-reents Date: Wed, 3 Jul 2024 17:31:08 +0200 Subject: [PATCH 1/2] Simplify `get_builder_from_protocol` in `ProjwfcBandsWorkChain` This commit mainly simplifies the current version of the `get_builder_from_protocol` method in `ProjwfcBandsWorkChain`. Moreover, it adds support for overrides containing standard Python datatypes, e.g. `kpoints_distance` specified as a float`. --- .../workflows/projwfcbands.py | 22 +++---------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/src/aiida_wannier90_workflows/workflows/projwfcbands.py b/src/aiida_wannier90_workflows/workflows/projwfcbands.py index 3117321..e81966c 100644 --- a/src/aiida_wannier90_workflows/workflows/projwfcbands.py +++ b/src/aiida_wannier90_workflows/workflows/projwfcbands.py @@ -107,7 +107,6 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ """ from aiida_wannier90_workflows.utils.workflows.builder.submit import ( recursive_merge_builder, - recursive_merge_container, ) type_check(pw_code, (str, int, orm.Code)) @@ -119,10 +118,6 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ # Prepare workchain builder builder = cls.get_builder() - protocol_inputs = cls.get_protocol_inputs( - protocol=protocol, overrides=overrides - ) - projwfc_overrides = None if overrides: projwfc_overrides = overrides.pop("projwfc", None) @@ -137,25 +132,14 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ # By default do not run relax pwbands_builder.pop("relax", None) - inputs = pwbands_builder._inputs(prune=True) # pylint: disable=protected-access projwfc_builder = ProjwfcBaseWorkChain.get_builder_from_protocol( projwfc_code, protocol=protocol, overrides=projwfc_overrides ) + projwfc_builder.pop("clean_workdir", None) - inputs["projwfc"] = projwfc_builder._inputs( # pylint: disable=protected-access - prune=True - ) - inputs["projwfc"].pop("clean_workdir", None) - - # Need to convert `clean_workdir` to `orm.Bool` - if "clean_workdir" in protocol_inputs: - protocol_inputs["clean_workdir"] = orm.Bool( - protocol_inputs["clean_workdir"] - ) - - inputs = recursive_merge_container(inputs, protocol_inputs) - builder = recursive_merge_builder(builder, inputs) + builder.projwfc = projwfc_builder + builder = recursive_merge_builder(builder, pwbands_builder) return builder From 60bfe9da4feea2131740cc8f53e6f29890417828 Mon Sep 17 00:00:00 2001 From: t-reents Date: Mon, 23 Sep 2024 16:21:56 +0200 Subject: [PATCH 2/2] Fix support for parent protocol --- .../workflows/projwfcbands.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aiida_wannier90_workflows/workflows/projwfcbands.py b/src/aiida_wannier90_workflows/workflows/projwfcbands.py index e81966c..6ffa548 100644 --- a/src/aiida_wannier90_workflows/workflows/projwfcbands.py +++ b/src/aiida_wannier90_workflows/workflows/projwfcbands.py @@ -118,15 +118,17 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ # Prepare workchain builder builder = cls.get_builder() - projwfc_overrides = None - if overrides: - projwfc_overrides = overrides.pop("projwfc", None) + protocol_inputs = cls.get_protocol_inputs( + protocol=protocol, overrides=overrides + ) + + projwfc_overrides = protocol_inputs.pop("projwfc", None) pwbands_builder = PwBandsWorkChain.get_builder_from_protocol( code=pw_code, structure=structure, protocol=protocol, - overrides=overrides, + overrides=protocol_inputs, **kwargs, )