diff --git a/src/aiidalab_qe/app/configuration/__init__.py b/src/aiidalab_qe/app/configuration/__init__.py
index 5551da9b1..b9775ea0b 100644
--- a/src/aiidalab_qe/app/configuration/__init__.py
+++ b/src/aiidalab_qe/app/configuration/__init__.py
@@ -56,7 +56,7 @@ def __init__(self, model: ConfigurationStepModel, **kwargs):
lambda structure: ""
if structure
else """
-
+
Please set the input structure first.
""",
diff --git a/src/aiidalab_qe/app/configuration/advanced/subsettings.py b/src/aiidalab_qe/app/configuration/advanced/subsettings.py
index 5a588d6b5..b251d5160 100644
--- a/src/aiidalab_qe/app/configuration/advanced/subsettings.py
+++ b/src/aiidalab_qe/app/configuration/advanced/subsettings.py
@@ -8,6 +8,7 @@
class AdvancedCalculationSubSettingsModel(Model):
+ identifier = "sub"
dependencies = []
loaded_from_process = tl.Bool(False)
@@ -33,12 +34,10 @@ def reset(self):
class AdvancedConfigurationSubSettingsPanel(ipw.VBox, t.Generic[M]):
- identifier = "sub"
-
def __init__(self, model: M, **kwargs):
from aiidalab_qe.common.widgets import LoadingWidget
- self.loading_message = LoadingWidget(f"Loading {self.identifier} settings")
+ self.loading_message = LoadingWidget(f"Loading {model.identifier} settings")
super().__init__(
layout={"justify_content": "space-between", **kwargs.get("layout", {})},
diff --git a/src/aiidalab_qe/app/configuration/model.py b/src/aiidalab_qe/app/configuration/model.py
index 5c3949eeb..abbf9cfbf 100644
--- a/src/aiidalab_qe/app/configuration/model.py
+++ b/src/aiidalab_qe/app/configuration/model.py
@@ -86,7 +86,7 @@ def update(self):
def get_model_state(self):
parameters = {
identifier: model.get_model_state()
- for identifier, model in self._models.items()
+ for identifier, model in self.get_models()
if model.include
}
parameters["workchain"] |= {
@@ -100,7 +100,7 @@ def set_model_state(self, parameters):
workchain_parameters: dict = parameters.get("workchain", {})
self.relax_type = workchain_parameters.get("relax_type")
properties = set(workchain_parameters.get("properties", []))
- for identifier, model in self._models.items():
+ for identifier, model in self.get_models():
model.include = identifier in self._default_models | properties
if parameters.get(identifier):
model.set_model_state(parameters[identifier])
@@ -111,7 +111,7 @@ def reset(self):
self.relax_type_help = self._get_default_relax_type_help()
self.relax_type_options = self._get_default_relax_type_options()
self.relax_type = self._get_default_relax_type()
- for identifier, model in self._models.items():
+ for identifier, model in self.get_models():
if identifier not in self._default_models:
model.include = False
@@ -135,7 +135,7 @@ def _link_model(self, model: ConfigurationSettingsModel):
def _get_properties(self):
properties = []
- for identifier, model in self._models.items():
+ for identifier, model in self.get_models():
if identifier in self._default_models:
continue
if model.include:
diff --git a/src/aiidalab_qe/app/submission/__init__.py b/src/aiidalab_qe/app/submission/__init__.py
index d9176b9ab..344535f3a 100644
--- a/src/aiidalab_qe/app/submission/__init__.py
+++ b/src/aiidalab_qe/app/submission/__init__.py
@@ -10,7 +10,8 @@
from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.app.utils import get_entry_items
-from aiidalab_qe.common.panel import ResourceSettingsModel, ResourceSettingsPanel
+from aiidalab_qe.common.code import PluginCodes, PwCodeModel
+from aiidalab_qe.common.panel import PluginResourceSettingsModel, ResourceSettingsPanel
from aiidalab_qe.common.setup_codes import QESetupWidget
from aiidalab_qe.common.setup_pseudos import PseudosInstallWidget
from aiidalab_widgets_base import WizardAppWidgetStep
@@ -39,10 +40,6 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs):
self._on_submission,
"confirmed",
)
- self._model.observe(
- self._on_input_structure_change,
- "input_structure",
- )
self._model.observe(
self._on_input_parameters_change,
"input_parameters",
@@ -77,22 +74,28 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs):
self.rendered = False
- global_code_model = GlobalResourceSettingsModel()
- self.global_code_settings = GlobalResourceSettingsPanel(model=global_code_model)
- self._model.add_model("global", global_code_model)
- global_code_model.observe(
+ global_resources_model = GlobalResourceSettingsModel()
+ self.global_resources = GlobalResourceSettingsPanel(
+ model=global_resources_model
+ )
+ self._model.add_model("global", global_resources_model)
+ ipw.dlink(
+ (self._model, "plugin_overrides"),
+ (global_resources_model, "plugin_overrides"),
+ )
+ global_resources_model.observe(
self._on_plugin_submission_blockers_change,
["submission_blockers"],
)
- global_code_model.observe(
+ global_resources_model.observe(
self._on_plugin_submission_warning_messages_change,
["submission_warning_messages"],
)
self.settings = {
- "global": self.global_code_settings,
+ "global": self.global_resources,
}
- self._fetch_plugin_settings()
+ self._fetch_plugin_resource_settings()
self._install_sssp(qe_auto_setup)
self._set_up_qe(qe_auto_setup)
@@ -197,9 +200,7 @@ def submit(self, _=None):
self._model.confirm()
def reset(self):
- with self.hold_trait_notifications():
- self._model.reset()
- self._model.set_selected_codes()
+ self._model.reset()
@tl.observe("previous_step_state")
def _on_previous_step_state_change(self, _):
@@ -211,14 +212,15 @@ def _on_tab_change(self, change):
tab: ResourceSettingsPanel = self.tabs.children[tab_index] # type: ignore
tab.render()
- def _on_input_structure_change(self, _):
- """"""
-
def _on_input_parameters_change(self, _):
- self._model.update_active_models()
- self._update_tabs()
self._model.update_process_label()
+ self._model.update_plugin_inclusion()
+ self._model.update_plugin_overrides()
self._model.update_submission_blockers()
+ self._update_tabs()
+
+ def _on_plugin_overrides_change(self, _):
+ self._model.update_plugin_overrides()
def _on_plugin_submission_blockers_change(self, _):
self._model.update_submission_blockers()
@@ -237,16 +239,13 @@ def _on_submission_blockers_change(self, _):
self._model.update_submission_blocker_message()
self._update_state()
- def _on_submission_warning_change(self, _):
- self._model.update_submission_warning_message()
-
def _on_installation_change(self, _):
self._model.update_submission_blockers()
def _on_qe_installed(self, _):
self._toggle_qe_installation_widget()
if self._model.qe_installed:
- self._model.refresh_codes()
+ self._model.update()
def _on_sssp_installed(self, _):
self._toggle_sssp_installation_widget()
@@ -325,14 +324,24 @@ def _update_state(self, _=None):
else:
self.state = self.state.CONFIGURED
- def _fetch_plugin_settings(self):
- eps = get_entry_items("aiidalab_qe.properties", "code")
- for identifier, data in eps.items():
+ def _fetch_plugin_resource_settings(self):
+ entries = get_entry_items("aiidalab_qe.properties", "resources")
+ codes: PluginCodes = {
+ "dft": {
+ "pw": PwCodeModel(),
+ },
+ }
+ for identifier, resources in entries.items():
for key in ("panel", "model"):
- if key not in data:
+ if key not in resources:
raise ValueError(f"Entry {identifier} is missing the '{key}' key")
- panel = data["panel"]
- model: ResourceSettingsModel = data["model"]()
+
+ panel = resources["panel"]
+ model: PluginResourceSettingsModel = resources["model"]()
+ model.observe(
+ self._on_plugin_overrides_change,
+ "override",
+ )
model.observe(
self._on_plugin_submission_blockers_change,
["submission_blockers"],
@@ -343,16 +352,11 @@ def _fetch_plugin_settings(self):
)
self._model.add_model(identifier, model)
- def toggle_plugin(_, model=model):
- model.update()
- self._update_tabs()
-
- model.observe(
- toggle_plugin,
- "include",
- )
-
self.settings[identifier] = panel(
identifier=identifier,
model=model,
)
+
+ codes[identifier] = dict(model.get_models())
+
+ self.global_resources.set_up_codes(codes)
diff --git a/src/aiidalab_qe/app/submission/global_settings/model.py b/src/aiidalab_qe/app/submission/global_settings/model.py
index 0345cc6ab..8b2710aa8 100644
--- a/src/aiidalab_qe/app/submission/global_settings/model.py
+++ b/src/aiidalab_qe/app/submission/global_settings/model.py
@@ -5,14 +5,11 @@
import traitlets as tl
from aiida import orm
-from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.common.code import CodeModel, PwCodeModel
from aiidalab_qe.common.mixins import HasInputStructure
from aiidalab_qe.common.panel import ResourceSettingsModel
from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget
-DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
-
class GlobalResourceSettingsModel(
ResourceSettingsModel,
@@ -20,40 +17,23 @@ class GlobalResourceSettingsModel(
):
"""Model for the global code setting."""
+ identifier = "global"
+
dependencies = [
- "input_parameters",
"input_structure",
+ "input_parameters",
]
input_parameters = tl.Dict()
- codes = tl.Dict(
- key_trait=tl.Unicode(), # code name
- value_trait=tl.Instance(CodeModel), # code metadata
- )
- # this is a copy of the codes trait, which is used to trigger the update of the plugin
- global_codes = tl.Dict(
- key_trait=tl.Unicode(), # code name
- value_trait=tl.Dict(), # code metadata
- )
-
- plugin_mapping = tl.Dict(
- key_trait=tl.Unicode(), # plugin identifier
- value_trait=tl.List(tl.Unicode()), # list of code names
- )
-
- submission_blockers = tl.List(tl.Unicode())
- submission_warning_messages = tl.Unicode("")
+ plugin_overrides = tl.List(tl.Unicode())
+ plugin_overrides_notification = tl.Unicode("")
include = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- # Used by the code-setup thread to fetch code options
- # This is necessary to avoid passing the User object
- # between session in separate threads.
- self._default_user_email = orm.User.collection.get_default().email
self._RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD = 10
self._RUN_ON_LOCALHOST_VOLUME_WARN_THRESHOLD = 1000 # \AA^3
@@ -66,120 +46,82 @@ def __init__(self, *args, **kwargs):
"""
- def refresh_codes(self):
- for _, code_model in self.codes.items():
- code_model.update(self._default_user_email) # type: ignore
+ self.plugin_mapping: dict[str, list[str]] = {}
+
+ def update_global_codes(self):
+ self.global_codes = self.get_model_state()["codes"]
def update_active_codes(self):
- for name, code_model in self.codes.items():
- if name != "quantumespresso.pw":
+ for identifier, code_model in self.get_models():
+ if identifier != "quantumespresso.pw":
code_model.deactivate()
properties = self._get_properties()
for identifier, code_names in self.plugin_mapping.items():
if identifier in properties:
for code_name in code_names:
- self.codes[code_name].activate()
+ self.get_model(code_name).activate()
- def get_model_state(self):
- codes = {name: model.get_model_state() for name, model in self.codes.items()}
-
- return {"codes": codes}
-
- def set_model_state(self, code_data: dict):
- for name, code_model in self.codes.items():
- if name in code_data and code_model.is_active:
- code_model.set_model_state(code_data[name])
+ def update_plugin_overrides_notification(self):
+ if self.plugin_overrides:
+ formatted = "\n".join(
+ f"
+
The submission is blocked due to the following reason(s):
- {fmt_list}
+ {formatted}
"""
@@ -160,31 +167,29 @@ def get_model_state(self) -> dict[str, dict[str, dict]]:
parameters: dict = deepcopy(self.input_parameters) # type: ignore
parameters["codes"] = {
identifier: model.get_model_state()
- for identifier, model in self._models.items()
+ for identifier, model in self.get_models()
if model.include
}
return parameters
def set_model_state(self, parameters):
+ codes: dict = parameters.get("codes", {})
+
if "resources" in parameters:
- parameters["codes"] = {
- key: {"code": value} for key, value in parameters["codes"].items()
- }
- parameters["codes"]["pw"]["nodes"] = parameters["resources"]["num_machines"]
- parameters["codes"]["pw"]["cpus"] = parameters["resources"][
- "num_mpiprocs_per_machine"
- ]
- parameters["codes"]["pw"]["parallelization"] = {
- "npool": parameters["resources"]["npools"]
- }
+ resources = parameters["resources"]
+ codes |= {key: {"code": value} for key, value in codes.items()}
+ codes["pw"]["nodes"] = resources["num_machines"]
+ codes["pw"]["cpus"] = resources["num_mpiprocs_per_machine"]
+ codes["pw"]["parallelization"] = {"npool": resources["npools"]}
+
workchain_parameters: dict = parameters.get("workchain", {})
properties = set(workchain_parameters.get("properties", []))
- with self.hold_trait_notifications():
- for identifier, model in self._models.items():
- model.include = identifier in self._default_models | properties
- if parameters["codes"].get(identifier):
- model.set_model_state(parameters["codes"][identifier]["codes"])
- model.loaded_from_process = True
+ included = self._default_models | properties
+ for identifier, model in self.get_models():
+ model.include = identifier in included
+ if codes.get(identifier):
+ model.set_model_state(codes[identifier])
+ model.loaded_from_process = True
if self.process_node:
self.process_label = self.process_node.label
@@ -193,8 +198,8 @@ def set_model_state(self, parameters):
def get_selected_codes(self) -> dict[str, dict]:
return {
- name: code_model.get_model_state()
- for name, code_model in self.get_model("global").codes.items()
+ identifier: code_model.get_model_state()
+ for identifier, code_model in self.get_model("global").get_models()
if code_model.is_ready
}
@@ -203,7 +208,7 @@ def reset(self):
self.input_structure = None
self.input_parameters = {}
self.process_node = None
- for identifier, model in self._models.items():
+ for identifier, model in self.get_models():
if identifier not in self._default_models:
model.include = False
@@ -267,11 +272,9 @@ def _create_builder(self, parameters) -> ProcessBuilderNamespace:
return builder
def _check_submission_blockers(self):
- # Do not submit while any of the background setup processes are running.
if self.installing_qe or self.installing_sssp:
yield "Background setup processes must finish."
- # SSSP library not installed
if not self.sssp_installed:
yield "The SSSP library is not installed."
diff --git a/src/aiidalab_qe/common/code/model.py b/src/aiidalab_qe/common/code/model.py
index 4e40cc871..90b761575 100644
--- a/src/aiidalab_qe/common/code/model.py
+++ b/src/aiidalab_qe/common/code/model.py
@@ -24,6 +24,7 @@ class CodeModel(Model):
max_wallclock_seconds = tl.Int(3600 * 12)
allow_hidden_codes = tl.Bool(False)
allow_disabled_computers = tl.Bool(False)
+ override = tl.Bool(False)
def __init__(
self,
@@ -48,19 +49,24 @@ def __init__(
def is_ready(self):
return self.is_active and bool(self.selected)
+ @property
+ def first_option(self):
+ return self.options[0][1] if self.options else None # type: ignore
+
def activate(self):
self.is_active = True
def deactivate(self):
self.is_active = False
- def update(self, user_email: str):
- if not self.options:
+ def update(self, user_email="", refresh=False):
+ if not self.options or refresh:
self.options = self._get_codes(user_email)
- self.selected = self.options[0][1] if self.options else None
+ self.selected = self.first_option
def get_model_state(self) -> dict:
return {
+ "options": self.options,
"code": self.selected,
"nodes": self.num_nodes,
"cpus": self.num_cpus,
@@ -69,8 +75,12 @@ def get_model_state(self) -> dict:
"max_wallclock_seconds": self.max_wallclock_seconds,
}
- def set_model_state(self, parameters):
- self.selected = self._get_uuid(parameters["code"])
+ def set_model_state(self, parameters: dict):
+ self.selected = (
+ self._get_uuid(identifier)
+ if (identifier := parameters.get("code"))
+ else self.first_option
+ )
self.num_nodes = parameters.get("nodes", 1)
self.num_cpus = parameters.get("cpus", 1)
self.ntasks_per_node = parameters.get("ntasks_per_node", 1)
@@ -78,19 +88,15 @@ def set_model_state(self, parameters):
self.max_wallclock_seconds = parameters.get("max_wallclock_seconds", 3600 * 12)
def _get_uuid(self, identifier):
- if not self.selected:
- try:
- uuid = orm.load_code(identifier).uuid
- except NotExistent:
- uuid = None
- # If the code was imported from another user, it is not usable
- # in the app and thus will not be considered as an option!
- self.selected = uuid if uuid in [opt[1] for opt in self.options] else None
- return self.selected
-
- def _get_codes(self, user_email: str):
- # set default user_email if not provided
- user_email = user_email or orm.User.collection.get_default().email
+ try:
+ uuid = orm.load_code(identifier).uuid
+ except NotExistent:
+ uuid = None
+ # If the code was imported from another user, it is not usable
+ # in the app and thus will not be considered as an option!
+ return uuid if uuid in [opt[1] for opt in self.options] else None
+
+ def _get_codes(self, user_email: str = ""):
user = orm.User.collection.get(email=user_email)
filters = (
@@ -122,7 +128,7 @@ def _full_code_label(code):
class PwCodeModel(CodeModel):
- override = tl.Bool(False)
+ parallelization_override = tl.Bool(False)
npool = tl.Int(1)
def __init__(
@@ -142,14 +148,22 @@ def __init__(
def get_model_state(self) -> dict:
parameters = super().get_model_state()
- parameters["parallelization"] = {"npool": self.npool} if self.override else {}
+ parameters["parallelization"] = (
+ {
+ "npool": self.npool,
+ }
+ if self.parallelization_override
+ else {}
+ )
return parameters
def set_model_state(self, parameters):
super().set_model_state(parameters)
if "parallelization" in parameters and "npool" in parameters["parallelization"]:
- self.override = True
+ self.parallelization_override = True
self.npool = parameters["parallelization"].get("npool", 1)
+ else:
+ self.parallelization_override = False
CodesDict = dict[str, CodeModel]
diff --git a/src/aiidalab_qe/common/mixins.py b/src/aiidalab_qe/common/mixins.py
index 21421c233..c08bc73ad 100644
--- a/src/aiidalab_qe/common/mixins.py
+++ b/src/aiidalab_qe/common/mixins.py
@@ -31,12 +31,19 @@ class HasModels(t.Generic[T]):
def __init__(self):
self._models: dict[str, T] = {}
- def add_model(self, identifier, model):
+ def has_model(self, identifier):
+ return identifier in self._models
+
+ def add_model(self, identifier, model: T):
self._models[identifier] = model
self._link_model(model)
+ def add_models(self, models: dict[str, T]):
+ for identifier, model in models.items():
+ self.add_model(identifier, model)
+
def get_model(self, identifier) -> T:
- if identifier in self._models:
+ if self.has_model(identifier):
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")
@@ -44,7 +51,7 @@ def get_models(self) -> t.Iterable[tuple[str, T]]:
return self._models.items()
def _link_model(self, model: T):
- raise NotImplementedError()
+ pass
class HasProcess(tl.HasTraits):
diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py
index 3cca22ab0..0620c382d 100644
--- a/src/aiidalab_qe/common/panel.py
+++ b/src/aiidalab_qe/common/panel.py
@@ -15,8 +15,9 @@
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
+from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.common.code.model import CodeModel
-from aiidalab_qe.common.mixins import Confirmable, HasProcess
+from aiidalab_qe.common.mixins import Confirmable, HasModels, HasProcess
from aiidalab_qe.common.mvc import Model
from aiidalab_qe.common.widgets import (
LoadingWidget,
@@ -24,7 +25,7 @@
QEAppComputationalResourcesWidget,
)
-DEFAULT_PARAMETERS = {}
+DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
class Panel(ipw.VBox):
@@ -83,6 +84,7 @@ def __init__(self, **kwargs):
class SettingsModel(Model):
title = "Model"
+ identifier = ""
dependencies: list[str] = []
include = tl.Bool(False)
@@ -90,18 +92,12 @@ class SettingsModel(Model):
_defaults = {}
- def update(self, specific=""):
- """Updates the model.
-
- Parameters
- ----------
- `specific` : `str`, optional
- If provided, specifies the level of update.
- """
+ def update(self):
+ """Updates the model."""
pass
def get_model_state(self) -> dict:
- """Retrieves the model current state as a dictionary."""
+ """Retrieves the current state of the model as a dictionary."""
raise NotImplementedError()
def set_model_state(self, parameters: dict):
@@ -118,12 +114,11 @@ def reset(self):
class SettingsPanel(Panel, t.Generic[SM]):
title = "Settings"
- description = ""
def __init__(self, model: SM, **kwargs):
from aiidalab_qe.common.widgets import LoadingWidget
- self.loading_message = LoadingWidget(f"Loading {self.identifier} settings")
+ self.loading_message = LoadingWidget(f"Loading {model.identifier} settings")
super().__init__(
children=[self.loading_message],
@@ -209,11 +204,10 @@ def _reset(self):
self._model.reset()
-class ResourceSettingsModel(SettingsModel):
- """Base model for plugin code setting models."""
+class ResourceSettingsModel(SettingsModel, HasModels[CodeModel]):
+ """Base model for resource setting models."""
- dependencies = ["global.global_codes"]
- codes = {} # To be defined by subclasses
+ dependencies = []
global_codes = tl.Dict(
key_trait=tl.Unicode(),
@@ -222,110 +216,60 @@ class ResourceSettingsModel(SettingsModel):
submission_blockers = tl.List(tl.Unicode())
submission_warning_messages = tl.Unicode("")
- override = tl.Bool(False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used by the code-setup thread to fetch code options
- self._default_user_email = orm.User.collection.get_default().email
+ self.DEFAULT_USER_EMAIL = orm.User.collection.get_default().email
- def refresh_codes(self):
- for _, code_model in self.codes.items():
- code_model.update(self._default_user_email)
+ def add_model(self, identifier, model):
+ super().add_model(identifier, model)
+ model.update(self.DEFAULT_USER_EMAIL)
- def update_code_from_global(self):
- # Skip the sync if the user has overridden the settings
- if self.override:
- return
- for _, code_model in self.codes.items():
- default_calc_job_plugin = code_model.default_calc_job_plugin
- if default_calc_job_plugin in self.global_codes:
- code_data = self.global_codes[default_calc_job_plugin]
- code_model.set_model_state(code_data)
+ def update_submission_blockers(self):
+ self.submission_blockers = list(self._check_submission_blockers())
def get_model_state(self):
- codes = {name: model.get_model_state() for name, model in self.codes.items()}
return {
- "codes": codes,
- "override": self.override,
+ "codes": {
+ identifier: code_model.get_model_state()
+ for identifier, code_model in self.get_models()
+ },
}
- def set_model_state(self, code_data: dict):
- for name, code_model in self.codes.items():
- if name in code_data:
- code_model.set_model_state(code_data[name])
+ def set_model_state(self, parameters: dict):
+ self.set_selected_codes(parameters.get("codes", {}))
- def reset(self):
- """Reset the model to its default state."""
- for code_model in self.codes.values():
- code_model.reset()
+ def get_selected_codes(self) -> dict[str, dict]:
+ return {
+ identifier: code_model.get_model_state()
+ for identifier, code_model in self.get_models()
+ if code_model.is_ready
+ }
+
+ def set_selected_codes(self, code_data=DEFAULT["codes"]):
+ for identifier, code_model in self.get_models():
+ if identifier in code_data:
+ code_model.set_model_state(code_data[identifier])
+
+ def _check_submission_blockers(self):
+ return []
RSM = t.TypeVar("RSM", bound=ResourceSettingsModel)
class ResourceSettingsPanel(SettingsPanel[RSM], t.Generic[RSM]):
- """Base class for plugin code setting panels."""
+ """Base class for resource setting panels."""
def __init__(self, model, **kwargs):
super().__init__(model, **kwargs)
- self.code_widgets = {}
- self.rendered = False
- self._model.observe(
- self._on_global_codes_change,
- "global_codes",
- )
- self._model.observe(
- self._on_override_change,
- "override",
- )
- def render(self):
- if self.rendered:
- return
- self.override_help = ipw.HTML(
- "Click to override the resource settings for this plugin."
- )
- self.override = ipw.Checkbox(
- description="",
- indent=False,
- layout=ipw.Layout(max_width="3%"),
- )
- ipw.link(
- (self._model, "override"),
- (self.override, "value"),
- )
- self.code_widgets_container = ipw.VBox()
self.code_widgets = {}
- self.children = [
- ipw.HBox([self.override, self.override_help]),
- self.code_widgets_container,
- ]
-
- self.rendered = True
-
- for code_model in self._model.codes.values():
- self._toggle_code(code_model)
- return self.code_widgets_container
-
- def _on_global_codes_change(self, _):
- self._model.update_code_from_global()
def _on_code_resource_change(self, _):
- """Update the submission blockers and warning messages."""
-
- def _on_override_change(self, change):
- if change["new"]:
- for code_widget in self.code_widgets.values():
- code_widget.num_nodes.disabled = False
- code_widget.num_cpus.disabled = False
- code_widget.code_selection.code_select_dropdown.disabled = False
- else:
- for code_widget in self.code_widgets.values():
- code_widget.num_nodes.disabled = True
- code_widget.num_cpus.disabled = True
- code_widget.code_selection.code_select_dropdown.disabled = True
+ pass
def _toggle_code(self, code_model: CodeModel):
if not self.rendered:
@@ -343,25 +287,23 @@ def _toggle_code(self, code_model: CodeModel):
code_widget = self.code_widgets[code_model.name]
if not code_model.is_rendered:
self._render_code_widget(code_model, code_widget)
+ code_widget.observe(
+ code_widget.update_resources,
+ "value",
+ )
def _render_code_widget(
self,
code_model: CodeModel,
code_widget: QEAppComputationalResourcesWidget,
):
- code_model.update(None)
ipw.dlink(
(code_model, "options"),
(code_widget.code_selection.code_select_dropdown, "options"),
)
ipw.link(
(code_model, "selected"),
- (code_widget.code_selection.code_select_dropdown, "value"),
- )
- ipw.dlink(
- (code_model, "selected"),
- (code_widget.code_selection.code_select_dropdown, "disabled"),
- lambda selected: not selected,
+ (code_widget, "value"),
)
ipw.link(
(code_model, "num_cpus"),
@@ -385,16 +327,25 @@ def _render_code_widget(
)
if isinstance(code_widget, PwCodeResourceSetupWidget):
ipw.link(
- (code_model, "override"),
+ (code_model, "parallelization_override"),
(code_widget.parallelization.override, "value"),
)
ipw.link(
(code_model, "npool"),
(code_widget.parallelization.npool, "value"),
)
+ code_model.observe(
+ self._on_code_resource_change,
+ [
+ "parallelization_override",
+ "npool",
+ ],
+ )
code_model.observe(
self._on_code_resource_change,
[
+ "options",
+ "selected",
"num_cpus",
"num_nodes",
"ntasks_per_node",
@@ -402,19 +353,163 @@ def _render_code_widget(
"max_wallclock_seconds",
],
)
- # disable the code widget if the override is not set
- code_widget.num_nodes.disabled = not self.override.value
- code_widget.num_cpus.disabled = not self.override.value
- code_widget.code_selection.code_select_dropdown.disabled = (
- not self.override.value
- )
-
code_widgets = self.code_widgets_container.children[:-1] # type: ignore
-
self.code_widgets_container.children = [*code_widgets, code_widget]
code_model.is_rendered = True
+class PluginResourceSettingsModel(ResourceSettingsModel):
+ """Base model for plugin resource setting models."""
+
+ dependencies = [
+ "global.global_codes",
+ ]
+
+ override = tl.Bool(False)
+
+ def update(self):
+ """Updates the code models from the global resources.
+
+ Skips synchronization with global resources if the user has chosen to override
+ the resources for the plugin codes.
+ """
+ if self.override:
+ return
+ for _, code_model in self.get_models():
+ default_calc_job_plugin = code_model.default_calc_job_plugin
+ if default_calc_job_plugin in self.global_codes:
+ code_resources: dict = self.global_codes[default_calc_job_plugin] # type: ignore
+ options = code_resources.get("options", [])
+ if options != code_model.options:
+ code_model.update(self.DEFAULT_USER_EMAIL, refresh=True)
+ code_model.set_model_state(code_resources)
+
+ def get_model_state(self):
+ return {
+ "override": self.override,
+ **super().get_model_state(),
+ }
+
+ def set_model_state(self, parameters: dict):
+ self.override = parameters.get("override", False)
+ super().set_model_state(parameters)
+
+ def _link_model(self, model: CodeModel):
+ tl.link(
+ (self, "override"),
+ (model, "override"),
+ )
+
+
+PRSM = t.TypeVar("PRSM", bound=PluginResourceSettingsModel)
+
+
+class PluginResourceSettingsPanel(ResourceSettingsPanel[PRSM], t.Generic[PRSM]):
+ """Base class for plugin resource setting panels."""
+
+ def __init__(self, model, **kwargs):
+ super().__init__(model, **kwargs)
+
+ self._model.observe(
+ self._on_global_codes_change,
+ "global_codes",
+ )
+ self._model.observe(
+ self._on_override_change,
+ "override",
+ )
+
+ def render(self):
+ if self.rendered:
+ return
+
+ self.override_help = ipw.HTML(
+ "Click to override the resource settings for this plugin."
+ )
+ self.override = ipw.Checkbox(
+ description="",
+ indent=False,
+ layout=ipw.Layout(max_width="3%"),
+ )
+ ipw.link(
+ (self._model, "override"),
+ (self.override, "value"),
+ )
+ self.code_widgets_container = ipw.VBox()
+
+ self.children = [
+ ipw.HBox(
+ children=[
+ self.override,
+ self.override_help,
+ ]
+ ),
+ self.code_widgets_container,
+ ]
+
+ self.rendered = True
+
+ # Render any active codes
+ for _, code_model in self._model.get_models():
+ self._toggle_code(code_model)
+
+ return self.code_widgets_container
+
+ def _on_global_codes_change(self, _):
+ self._model.update()
+
+ def _on_override_change(self, _):
+ self._model.update()
+
+ def _render_code_widget(
+ self,
+ code_model: CodeModel,
+ code_widget: QEAppComputationalResourcesWidget,
+ ):
+ super()._render_code_widget(code_model, code_widget)
+ self._link_override_to_widget_disable(code_model, code_widget)
+
+ def _link_override_to_widget_disable(self, code_model, code_widget):
+ """Links the override attribute of the code model to the disable attribute
+ of subwidgets of the code widget."""
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.code_selection.code_select_dropdown, "disabled"),
+ lambda override: not override,
+ )
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.num_cpus, "disabled"),
+ lambda override: not override,
+ )
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.num_nodes, "disabled"),
+ lambda override: not override,
+ )
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.code_selection.btn_setup_new_code, "disabled"),
+ lambda override: not override,
+ )
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.btn_setup_resource_detail, "disabled"),
+ lambda override: not override,
+ )
+ if isinstance(code_widget, PwCodeResourceSetupWidget):
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.parallelization.override, "disabled"),
+ lambda override: not override,
+ )
+ ipw.dlink(
+ (code_model, "override"),
+ (code_widget.parallelization.npool, "disabled"),
+ lambda override: not override,
+ )
+
+
class ResultsModel(Model, HasProcess):
title = "Model"
identifier = "model"
diff --git a/src/aiidalab_qe/common/widgets.py b/src/aiidalab_qe/common/widgets.py
index b53d2b23d..89f8673a4 100644
--- a/src/aiidalab_qe/common/widgets.py
+++ b/src/aiidalab_qe/common/widgets.py
@@ -702,8 +702,7 @@ def __init__(self, **kwargs):
)
traitlets.link((self.code_selection, "value"), (self, "value"))
- @traitlets.observe("value")
- def _update_resources(self, change):
+ def update_resources(self, change):
if change["new"]:
self.set_resource_defaults(load_code(change["new"]).computer)
diff --git a/src/aiidalab_qe/plugins/bands/__init__.py b/src/aiidalab_qe/plugins/bands/__init__.py
index 0c8b86ad6..2cc72a868 100644
--- a/src/aiidalab_qe/plugins/bands/__init__.py
+++ b/src/aiidalab_qe/plugins/bands/__init__.py
@@ -1,8 +1,8 @@
# from aiidalab_qe.bands.result import Result
from aiidalab_qe.common.panel import PluginOutline
-from .code import BandsResourceSettingsModel, BandsResourceSettingsPanel
from .model import BandsConfigurationSettingsModel
+from .resources import BandsResourceSettingsModel, BandsResourceSettingsPanel
from .result import BandsResultsModel, BandsResultsPanel
from .setting import BandsConfigurationSettingsPanel
from .workchain import workchain_and_builder
@@ -18,7 +18,7 @@ class BandsPluginOutline(PluginOutline):
"panel": BandsConfigurationSettingsPanel,
"model": BandsConfigurationSettingsModel,
},
- "code": {
+ "resources": {
"panel": BandsResourceSettingsPanel,
"model": BandsResourceSettingsModel,
},
diff --git a/src/aiidalab_qe/plugins/bands/code.py b/src/aiidalab_qe/plugins/bands/code.py
deleted file mode 100644
index 79571a00f..000000000
--- a/src/aiidalab_qe/plugins/bands/code.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Panel for Bands plugin."""
-
-from aiidalab_qe.common.code.model import CodeModel, PwCodeModel
-from aiidalab_qe.common.panel import ResourceSettingsModel, ResourceSettingsPanel
-
-
-class BandsResourceSettingsModel(ResourceSettingsModel):
- """Model for the band structure plugin."""
-
- codes = {
- "pw": PwCodeModel(
- name="pw.x",
- description="pw.x",
- default_calc_job_plugin="quantumespresso.pw",
- ),
- "projwfc_bands": CodeModel(
- name="projwfc.x",
- description="projwfc.x",
- default_calc_job_plugin="quantumespresso.projwfc",
- ),
- }
-
-
-class BandsResourceSettingsPanel(ResourceSettingsPanel[BandsResourceSettingsModel]):
- title = "Band Structure"
- identifier = "bands"
diff --git a/src/aiidalab_qe/plugins/bands/resources.py b/src/aiidalab_qe/plugins/bands/resources.py
new file mode 100644
index 000000000..2905d72d8
--- /dev/null
+++ b/src/aiidalab_qe/plugins/bands/resources.py
@@ -0,0 +1,36 @@
+"""Panel for Bands plugin."""
+
+from aiidalab_qe.common.code.model import CodeModel, PwCodeModel
+from aiidalab_qe.common.panel import (
+ PluginResourceSettingsModel,
+ PluginResourceSettingsPanel,
+)
+
+
+class BandsResourceSettingsModel(PluginResourceSettingsModel):
+ """Model for the band structure plugin."""
+
+ identifier = "bands"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.add_models(
+ {
+ "pw": PwCodeModel(
+ name="pw.x",
+ description="pw.x",
+ default_calc_job_plugin="quantumespresso.pw",
+ ),
+ "projwfc_bands": CodeModel(
+ name="projwfc.x",
+ description="projwfc.x",
+ default_calc_job_plugin="quantumespresso.projwfc",
+ ),
+ }
+ )
+
+
+class BandsResourceSettingsPanel(
+ PluginResourceSettingsPanel[BandsResourceSettingsModel],
+):
+ title = "Band Structure"
diff --git a/src/aiidalab_qe/plugins/pdos/__init__.py b/src/aiidalab_qe/plugins/pdos/__init__.py
index 280a90c39..c8defb4d9 100644
--- a/src/aiidalab_qe/plugins/pdos/__init__.py
+++ b/src/aiidalab_qe/plugins/pdos/__init__.py
@@ -1,7 +1,7 @@
from aiidalab_qe.common.panel import PluginOutline
-from .code import PdosResourceSettingsModel, PdosResourceSettingsPanel
from .model import PdosConfigurationSettingsModel
+from .resources import PdosResourceSettingsModel, PdosResourceSettingsPanel
from .result import PdosResultsModel, PdosResultsPanel
from .setting import PdosConfigurationSettingPanel
from .workchain import workchain_and_builder
@@ -17,7 +17,7 @@ class PdosPluginOutline(PluginOutline):
"panel": PdosConfigurationSettingPanel,
"model": PdosConfigurationSettingsModel,
},
- "code": {
+ "resources": {
"panel": PdosResourceSettingsPanel,
"model": PdosResourceSettingsModel,
},
diff --git a/src/aiidalab_qe/plugins/pdos/code.py b/src/aiidalab_qe/plugins/pdos/code.py
deleted file mode 100644
index 1e2095a25..000000000
--- a/src/aiidalab_qe/plugins/pdos/code.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Panel for PDOS plugin."""
-
-from aiidalab_qe.common.code.model import CodeModel, PwCodeModel
-from aiidalab_qe.common.panel import ResourceSettingsModel, ResourceSettingsPanel
-
-
-class PdosResourceSettingsModel(ResourceSettingsModel):
- """Model for the pdos code setting plugin."""
-
- codes = {
- "pw": PwCodeModel(
- name="pw.x",
- description="pw.x",
- default_calc_job_plugin="quantumespresso.pw",
- ),
- "dos": CodeModel(
- name="dos.x",
- description="dos.x",
- default_calc_job_plugin="quantumespresso.dos",
- ),
- "projwfc": CodeModel(
- name="projwfc.x",
- description="projwfc.x",
- default_calc_job_plugin="quantumespresso.projwfc",
- ),
- }
-
-
-class PdosResourceSettingsPanel(ResourceSettingsPanel[PdosResourceSettingsModel]):
- title = "PDOS"
- identifier = "pdos"
diff --git a/src/aiidalab_qe/plugins/pdos/resources.py b/src/aiidalab_qe/plugins/pdos/resources.py
new file mode 100644
index 000000000..0d17798f7
--- /dev/null
+++ b/src/aiidalab_qe/plugins/pdos/resources.py
@@ -0,0 +1,41 @@
+"""Panel for PDOS plugin."""
+
+from aiidalab_qe.common.code.model import CodeModel, PwCodeModel
+from aiidalab_qe.common.panel import (
+ PluginResourceSettingsModel,
+ PluginResourceSettingsPanel,
+)
+
+
+class PdosResourceSettingsModel(PluginResourceSettingsModel):
+ """Model for the pdos code setting plugin."""
+
+ identifier = "pdos"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.add_models(
+ {
+ "pw": PwCodeModel(
+ name="pw.x",
+ description="pw.x",
+ default_calc_job_plugin="quantumespresso.pw",
+ ),
+ "dos": CodeModel(
+ name="dos.x",
+ description="dos.x",
+ default_calc_job_plugin="quantumespresso.dos",
+ ),
+ "projwfc": CodeModel(
+ name="projwfc.x",
+ description="projwfc.x",
+ default_calc_job_plugin="quantumespresso.projwfc",
+ ),
+ }
+ )
+
+
+class PdosResourceSettingsPanel(
+ PluginResourceSettingsPanel[PdosResourceSettingsModel],
+):
+ title = "PDOS"
diff --git a/src/aiidalab_qe/plugins/xas/__init__.py b/src/aiidalab_qe/plugins/xas/__init__.py
index 76a4af001..1d2adcac6 100644
--- a/src/aiidalab_qe/plugins/xas/__init__.py
+++ b/src/aiidalab_qe/plugins/xas/__init__.py
@@ -1,17 +1,22 @@
-from importlib import resources
+from importlib import resources as importlib_resources
import yaml
from aiidalab_qe.common.panel import PluginOutline
from aiidalab_qe.plugins import xas as xas_folder
-from .code import XasResourceSettingsModel, XasResourceSettingsPanel
from .model import XasConfigurationSettingsModel
+from .resources import XasResourceSettingsModel, XasResourceSettingsPanel
from .result import XasResultsModel, XasResultsPanel
from .setting import XasConfigurationSettingsPanel
from .workchain import workchain_and_builder
-PSEUDO_TOC = yaml.safe_load(resources.read_text(xas_folder, "pseudo_toc.yaml"))
+PSEUDO_TOC = yaml.safe_load(
+ importlib_resources.read_text(
+ xas_folder,
+ "pseudo_toc.yaml",
+ )
+)
class XasPluginOutline(PluginOutline):
@@ -24,7 +29,7 @@ class XasPluginOutline(PluginOutline):
"panel": XasConfigurationSettingsPanel,
"model": XasConfigurationSettingsModel,
},
- "code": {
+ "resources": {
"panel": XasResourceSettingsPanel,
"model": XasResourceSettingsModel,
},
diff --git a/src/aiidalab_qe/plugins/xas/code.py b/src/aiidalab_qe/plugins/xas/code.py
deleted file mode 100644
index ff07fb8f9..000000000
--- a/src/aiidalab_qe/plugins/xas/code.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Panel for XAS plugin."""
-
-from aiidalab_qe.common.code.model import CodeModel, PwCodeModel
-from aiidalab_qe.common.panel import ResourceSettingsModel, ResourceSettingsPanel
-
-
-class XasResourceSettingsModel(ResourceSettingsModel):
- """Model for the XAS plugin."""
-
- codes = {
- "pw": PwCodeModel(
- name="pw.x",
- description="pw.x",
- default_calc_job_plugin="quantumespresso.pw",
- ),
- "xspectra": CodeModel(
- name="xspectra.x",
- description="xspectra.x",
- default_calc_job_plugin="quantumespresso.xspectra",
- ),
- }
-
-
-class XasResourceSettingsPanel(ResourceSettingsPanel[XasResourceSettingsModel]):
- title = "XAS Structure"
- identifier = "xas"
diff --git a/src/aiidalab_qe/plugins/xas/resources.py b/src/aiidalab_qe/plugins/xas/resources.py
new file mode 100644
index 000000000..7e9f77a9d
--- /dev/null
+++ b/src/aiidalab_qe/plugins/xas/resources.py
@@ -0,0 +1,36 @@
+"""Panel for XAS plugin."""
+
+from aiidalab_qe.common.code.model import CodeModel, PwCodeModel
+from aiidalab_qe.common.panel import (
+ PluginResourceSettingsModel,
+ PluginResourceSettingsPanel,
+)
+
+
+class XasResourceSettingsModel(PluginResourceSettingsModel):
+ """Model for the XAS plugin."""
+
+ identifier = "xas"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.add_models(
+ {
+ "pw": PwCodeModel(
+ name="pw.x",
+ description="pw.x",
+ default_calc_job_plugin="quantumespresso.pw",
+ ),
+ "xspectra": CodeModel(
+ name="xspectra.x",
+ description="xspectra.x",
+ default_calc_job_plugin="quantumespresso.xspectra",
+ ),
+ }
+ )
+
+
+class XasResourceSettingsPanel(
+ PluginResourceSettingsPanel[XasResourceSettingsModel],
+):
+ title = "XAS Structure"
diff --git a/tests/conftest.py b/tests/conftest.py
index bacfeb26c..0494ff941 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -429,19 +429,12 @@ def app(pw_code, dos_code, projwfc_code, projwfc_bands_code):
app.submit_model.qe_installed = True
# set up codes
- pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw")
- dos_code_model = app.submit_model.get_model("global").get_code(
- "quantumespresso.dos"
- )
- projwfc_code_model = app.submit_model.get_model("global").get_code(
- "quantumespresso.projwfc"
- )
-
- pw_code_model.activate()
- dos_code_model.activate()
- projwfc_code_model.activate()
+ global_model = app.submit_model.get_model("global")
+ global_model.get_model("quantumespresso.pw").activate()
+ global_model.get_model("quantumespresso.dos").activate()
+ global_model.get_model("quantumespresso.projwfc").activate()
- app.submit_model.get_model("global").set_selected_codes(
+ global_model.set_selected_codes(
{
"pw": {"code": pw_code.label},
"dos": {"code": dos_code.label},
@@ -509,7 +502,9 @@ def _submit_app_generator(
app.configure_model.confirm()
app.submit_model.input_structure = generate_structure_data()
- app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 2
+ app.submit_model.get_model("global").get_model(
+ "quantumespresso.pw"
+ ).num_cpus = 2
return app
@@ -818,7 +813,9 @@ def _generate_qeapp_workchain(
app.configure_model.confirm()
# step 3 setup code and resources
- app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 4
+ app.submit_model.get_model("global").get_model(
+ "quantumespresso.pw"
+ ).num_cpus = 4
parameters = app.submit_model.get_model_state()
builder = app.submit_model._create_builder(parameters)
diff --git a/tests/test_codes.py b/tests/test_codes.py
index 984215606..8df05c6ec 100644
--- a/tests/test_codes.py
+++ b/tests/test_codes.py
@@ -7,7 +7,7 @@ def test_code_not_selected(submit_app_generator):
"""Test if there is an error when the code is not selected."""
app: App = submit_app_generator(properties=["dos"])
model = app.submit_model
- model.get_model("global").get_code("quantumespresso.dos").selected = None
+ model.get_model("global").get_model("quantumespresso.dos").selected = None
# Check builder construction passes without an error
parameters = model.get_model_state()
model._create_builder(parameters)
@@ -19,8 +19,8 @@ def test_set_selected_codes(submit_app_generator):
parameters = app.submit_model.get_model_state()
model = SubmissionStepModel()
_ = SubmitQeAppWorkChainStep(model=model, qe_auto_setup=False)
- for name, code_model in app.submit_model.get_model("global").codes.items():
- model.get_model("global").get_code(name).is_active = code_model.is_active
+ for identifier, code_model in app.submit_model.get_model("global").get_models():
+ model.get_model("global").get_model(identifier).is_active = code_model.is_active
model.qe_installed = True
model.get_model("global").set_selected_codes(parameters["codes"]["global"]["codes"])
assert model.get_selected_codes() == app.submit_model.get_selected_codes()
@@ -32,23 +32,14 @@ def test_update_codes_display(app: App):
"""
app.submit_step.render()
model = app.submit_model
- model.get_model("global").update_active_codes()
- assert (
- app.submit_step.global_code_settings.code_widgets["dos"].layout.display
- == "none"
- )
+ global_model = model.get_model("global")
+ global_model.update_active_codes()
+ global_resources = app.submit_step.global_resources
+ assert global_resources.code_widgets["dos"].layout.display == "none"
model.input_parameters = {"workchain": {"properties": ["pdos"]}}
- model.get_model("global").update_active_codes()
- assert (
- app.submit_step._model.get_model("global")
- .codes["quantumespresso.dos"]
- .is_active
- is True
- )
- assert (
- app.submit_step.global_code_settings.code_widgets["dos"].layout.display
- == "block"
- )
+ global_model.update_active_codes()
+ assert global_model.get_model("quantumespresso.dos").is_active is True
+ assert global_resources.code_widgets["dos"].layout.display == "block"
def test_check_submission_blockers(app: App):
@@ -63,7 +54,7 @@ def test_check_submission_blockers(app: App):
assert len(model.internal_submission_blockers) == 0
# set dos code to None, will introduce another blocker
- dos_code = model.get_model("global").get_code("quantumespresso.dos")
+ dos_code = model.get_model("global").get_model("quantumespresso.dos")
dos_value = dos_code.selected
dos_code.selected = None
model.update_submission_blockers()
@@ -78,16 +69,16 @@ def test_check_submission_blockers(app: App):
def test_qeapp_computational_resources_widget(app: App):
"""Test QEAppComputationalResourcesWidget."""
app.submit_step.render()
- pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw")
- pw_code_widget = app.submit_step.global_code_settings.code_widgets["pw"]
+ global_model = app.submit_model.get_model("global")
+ global_resources = app.submit_step.global_resources
+ pw_code_model = global_model.get_model("quantumespresso.pw")
+ pw_code_widget = global_resources.code_widgets["pw"]
assert pw_code_widget.parallelization.npool.layout.display == "none"
- pw_code_model.override = True
+ pw_code_model.parallelization_override = True
pw_code_model.npool = 2
assert pw_code_widget.parallelization.npool.layout.display == "block"
assert pw_code_widget.parameters == {
- "code": app.submit_step.global_code_settings.code_widgets[
- "pw"
- ].value, # TODO why None?
+ "code": global_resources.code_widgets["pw"].value,
"cpus": 1,
"cpus_per_task": 1,
"max_wallclock_seconds": 43200,
diff --git a/tests/test_submit_qe_workchain.py b/tests/test_submit_qe_workchain.py
index e0c367a69..90d06d58f 100644
--- a/tests/test_submit_qe_workchain.py
+++ b/tests/test_submit_qe_workchain.py
@@ -16,6 +16,7 @@ def test_create_builder_default(
app.submit_model._create_builder(parameters)
# since uuid is specific to each run, we remove it from the output
ui_parameters = remove_uuid_fields(parameters)
+ remove_code_options(ui_parameters)
# regression test for the parameters generated by the app
# this parameters are passed to the workchain
data_regression.check(ui_parameters)
@@ -144,16 +145,17 @@ def test_warning_messages(
app: App = submit_app_generator(properties=["bands", "pdos"])
submit_model = app.submit_model
+ global_model = submit_model.get_model("global")
- pw_code = submit_model.get_model("global").get_code("quantumespresso.pw")
+ pw_code = global_model.get_model("quantumespresso.pw")
pw_code.num_cpus = 1
- submit_model.get_model("global").check_resources()
+ global_model.check_resources()
# no warning:
assert submit_model.submission_warning_messages == ""
# now we increase the resources, so we should have the Warning-3
pw_code.num_cpus = len(os.sched_getaffinity(0))
- submit_model.get_model("global").check_resources()
+ global_model.check_resources()
for suggestion in ["avoid_overloading", "go_remote"]:
assert suggestions[suggestion] in submit_model.submission_warning_messages
@@ -161,12 +163,10 @@ def test_warning_messages(
structure = generate_structure_data("H2O-larger")
submit_model.input_structure = structure
pw_code.num_cpus = 1
- submit_model.get_model("global").check_resources()
+ global_model.check_resources()
num_sites = len(structure.sites)
volume = structure.get_cell_volume()
- estimated_CPUs = submit_model.get_model("global")._estimate_min_cpus(
- num_sites, volume
- )
+ estimated_CPUs = global_model._estimate_min_cpus(num_sites, volume)
assert estimated_CPUs == 2
for suggestion in ["more_resources", "change_configuration"]:
assert suggestions[suggestion] in submit_model.submission_warning_messages
@@ -232,3 +232,10 @@ def remove_uuid_fields(data):
else:
# Return the value unchanged if it's not a dictionary or list
return data
+
+
+def remove_code_options(parameters):
+ """Remove the code options from the parameters."""
+ for panel in parameters["codes"].values(): # type: ignore
+ for code in panel["codes"].values():
+ del code["options"]