Skip to content

Commit

Permalink
Getting closer with model injection
Browse files Browse the repository at this point in the history
  • Loading branch information
edan-bainglass committed Sep 22, 2024
1 parent c3bc184 commit 1f1206a
Show file tree
Hide file tree
Showing 17 changed files with 473 additions and 400 deletions.
185 changes: 79 additions & 106 deletions src/aiidalab_qe/app/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import ipywidgets as ipw
import traitlets as tl

from aiida_quantumespresso.common.types import RelaxType
from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.app.utils import get_entry_items
from aiidalab_widgets_base import WizardAppWidgetStep

from .advanced import AdvancedSettings
from .model import AdvancedModel, ConfigurationModel, WorkChainModel
from .model import ConfigurationModel
from .workflow import WorkChainSettings

DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
Expand All @@ -40,32 +39,18 @@ def __init__(self, model: ConfigurationModel, **kwargs):

self._model = model

workchain_model = WorkChainModel()
advanced_model = AdvancedModel()
ipw.dlink(
(self._model, "input_structure"),
(advanced_model, "input_structure"),
)
ipw.dlink(
(workchain_model, "protocol"),
(advanced_model, "protocol"),
)
ipw.dlink(
(workchain_model, "spin_type"),
(advanced_model, "spin_type"),
)
ipw.dlink(
(workchain_model, "electronic_type"),
(advanced_model, "electronic_type"),
)
# TODO necessary?
self._model.observe(
lambda _: advanced_model.update_kpoints_mesh(),
self._on_input_structure_change,
"input_structure",
)

self.workchain_settings = WorkChainSettings(model=workchain_model)
self.advanced_settings = AdvancedSettings(model=advanced_model)
self._model.workchain.observe(
self._on_protocol_change,
"protocol",
)

self.workchain_settings = WorkChainSettings(model=model)
self.advanced_settings = AdvancedSettings(model=model)

self.built_in_settings = [
self.workchain_settings,
Expand Down Expand Up @@ -126,15 +111,6 @@ def render(self):

self.rendered = True

def reset(self):
"""Reset the widgets in all settings to their initial states."""
with self.hold_trait_notifications():
for _, settings in self.settings.items():
if settings.rendered:
settings.reset()
for key, p in self.properties.items():
p.run.value = key in DEFAULT["workchain"]["properties"]

def is_saved(self):
"""Check if the current step is saved.
That all changes are confirmed.
Expand All @@ -148,27 +124,17 @@ def confirm(self, _=None):
self.state = self.State.SUCCESS

def get_configuration_parameters(self):
parameters = {
setting.identifier: setting.get_panel_value()
for setting in self.settings.values()
}
properties = self._get_properties()
# TODO necessary to store the properties in the workchain settings?
parameters["workchain"].update("properties", properties)
self._model.get_model_state()

def set_configuration_parameters(self, parameters):
"""Set the inputs in the GUI based on a set of parameters."""
self._model.set_model_state(parameters)

# TODO check logic
def reset(self):
with self.hold_trait_notifications():
for identifier, settings in self.settings.items():
if parameters.get(identifier):
settings.set_panel_value(parameters[identifier])
properties = parameters.get("properties", [])
for name in self.properties:
if name in properties:
self.properties[name].run.value = True
else:
self.properties[name].run.value = False
self._model.reset()
for _, settings in self.settings.items():
settings.reset()

@tl.observe("previous_step_state")
def _on_previous_step_state_change(self, change):
Expand All @@ -179,89 +145,96 @@ def _on_tab_change(self, change):
return
self.tab.children[tab].render() # type: ignore

def _on_input_structure_change(self, _):
self._model.advanced.update()

def _on_protocol_change(self, _):
self._model.advanced.update()

def _fetch_setting_entries(self):
"""Handle plugin specific settings."""

self.properties = {}
self.reminder_info = {}
self.property_children = [ipw.HTML("Select which properties to calculate:")]
entries = get_entry_items("aiidalab_qe.properties", "outline")

outlines = get_entry_items("aiidalab_qe.properties", "outline")
models = get_entry_items("aiidalab_qe.properties", "model")
settings = get_entry_items("aiidalab_qe.properties", "setting")
for (name, entry_point), setting in zip(entries.items(), settings.values()):
self.properties[name] = entry_point()
self.properties[name].run.observe(self._update_panel, "value")
self.reminder_info[name] = ipw.HTML()
for identifier in settings:
outline = outlines[identifier]()
model = models[identifier]()
info = ipw.HTML()
ipw.link(
(model, "include_plugin"),
(outline.include_plugin, "value"),
)

def toggle_plugin_model(
change,
identifier=identifier,
model=model,
info=info,
):
self._update_panel()
if change["new"]:
self._model.add_model(identifier, model)
info.value = f"Customize {identifier} settings below"
else:
self._model.remove_model(identifier)
info.value = ""

model.observe(toggle_plugin_model, "include_plugin")

self.properties[identifier] = outline
self.property_children.append(
ipw.HBox(
children=[
self.properties[name],
self.reminder_info[name],
outline,
info,
]
)
)

def update_reminder_info(change, name=name):
info = self.reminder_info[name]
info.value = (
f"Customize {name} settings in the corresponding tab"
if change["new"]
else ""
)

if name in settings:
self.properties[name].run.observe(update_reminder_info, "value")
kwargs = {"parent": self, "identifier": name}
# TODO drop check in the future - plugin models should be required
if name in models:
kwargs["model"] = models[name]()
self.settings[name] = setting(**kwargs)

self.property_children.append(
ipw.HTML("""
<div style="line-height: 140%; padding-top: 10px; padding-bottom: 0px">
The band structure workflow will automatically detect the default
path in reciprocal space using the
<a href="https://www.materialscloud.org/work/tools/seekpath" target="_blank">SeeK-path tool</a>.
</div>
""")
)
kwargs = {
"parent": self,
"identifier": identifier,
"config_model": self._model,
}
self.settings[identifier] = settings[identifier](**kwargs)

# # TODO move this somewhere else (below bands ideally)
# self.property_children.append(
# ipw.HTML("""
# <div style="line-height: 140%; padding-top: 10px; padding-bottom: 0px">
# The band structure workflow will automatically detect the default
# path in reciprocal space using the
# <a href="https://www.materialscloud.org/work/tools/seekpath" target="_blank">SeeK-path tool</a>.
# </div>
# """)
# )

def _update_panel(self, _=None):
"""Dynamic add/remove the panel based on the selected properties."""
self.tab.children = self.built_in_settings
for identifier in self.properties:
if identifier in self.settings and self.properties[identifier].run.value:
self.tab.children += (self.settings[identifier],)
model = self._model.get_model(identifier)
setting = self.settings[identifier]
if model and model.include_plugin:
self.tab.children += (setting,)
self.tab.set_title(
len(self.tab.children) - 1, self.settings[identifier].title
len(self.tab.children) - 1,
setting.title,
)

def _update_state(self, previous_step_state):
if previous_step_state == self.State.SUCCESS:
self.state = self.State.CONFIGURED
for settings in self.settings.values():
settings._update_state()
# # TODO why?
# for settings in self.settings.values():
# settings._update_state()
elif previous_step_state == self.State.FAIL:
self.state = self.State.FAIL
else:
self.state = self.State.INIT
self.reset()

def _get_properties(self):
properties = []
run_bands = False
run_pdos = False
for name in self.properties:
if self.properties[name].run.value:
properties.append(name)
if name == "bands":
run_bands = True
elif name == "pdos":
run_bands = True

if RelaxType(self._model.basic.relax_type) is not RelaxType.NONE or not (
run_bands or run_pdos
):
properties.append("relax")
return properties
# self.reset()
Loading

0 comments on commit 1f1206a

Please sign in to comment.