Skip to content

Commit

Permalink
dielectric results model and widget
Browse files Browse the repository at this point in the history
  • Loading branch information
AndresOrtegaGuerrero committed Nov 30, 2024
1 parent 8ea879c commit 04d5d17
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 186 deletions.
4 changes: 3 additions & 1 deletion src/aiidalab_qe_vibroscopy/app/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ class VibroConfigurationSettingsModel(ConfigurationSettingsModel, HasInputStruct
trait=tl.Int(),
default_value=[2, 2, 2],
)
supercell_number_estimator = tl.Unicode("?")
supercell_number_estimator = tl.Unicode(
"Click the button to estimate the supercell size."
)

def get_model_state(self):
return {
Expand Down
7 changes: 0 additions & 7 deletions src/aiidalab_qe_vibroscopy/app/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np

from ..utils.raman.result import export_iramanworkchain_data
from ..utils.dielectric.result import export_dielectric_data, DielectricResults
from ..utils.phonons.result import export_phononworkchain_data

from ..utils.euphonic import (
Expand Down Expand Up @@ -86,7 +85,6 @@ def _update_view(self):
spectra_data = export_iramanworkchain_data(self.node)
phonon_data = export_phononworkchain_data(self.node)
ins_data = export_euphonic_data(self.node)
dielectric_data = export_dielectric_data(self.node)

if phonon_data:
phonon_children = ()
Expand Down Expand Up @@ -208,11 +206,6 @@ def _update_view(self):
)
tab_titles.append("Raman/IR spectra")

if dielectric_data:
dielectric_results = DielectricResults(dielectric_data)
children_result_widget += (dielectric_results,)
tab_titles.append("Dielectric properties")

# euphonic
if ins_data:
intensity_maps = EuphonicSuperWidget(
Expand Down
21 changes: 19 additions & 2 deletions src/aiidalab_qe_vibroscopy/app/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from aiidalab_qe_vibroscopy.app.result.model import VibroResultsModel
from aiidalab_qe.common.panel import ResultsPanel

from aiidalab_qe_vibroscopy.app.widgets.dielectricwidget import DielectricWidget
from aiidalab_qe_vibroscopy.app.widgets.dielectricmodel import DielectricModel
import ipywidgets as ipw


Expand All @@ -19,6 +21,10 @@ def render(self):
layout=ipw.Layout(min_height="250px"),
selected_index=None,
)
self.tabs.observe(
self._on_tab_change,
"selected_index",
)

tab_data = []
# vibro_node = self._model.get_vibro_node()
Expand All @@ -29,8 +35,14 @@ def render(self):
if self._model.needs_raman_tab():
tab_data.append(("Raman", ipw.HTML("raman_data")))

if self._model.needs_dielectric_tab():
tab_data.append(("Dielectric", ipw.HTML("dielectric_data")))
dielectric_data = self._model.needs_dielectric_tab()

if dielectric_data:
dielectric_model = DielectricModel()
dielectric_widget = DielectricWidget(
model=dielectric_model, dielectric_data=dielectric_data
)
tab_data.append(("Dielectric Properties", dielectric_widget))

if self._model.needs_euphonic_tab():
tab_data.append(("Euphonic", ipw.HTML("euphonic_data")))
Expand All @@ -43,3 +55,8 @@ def render(self):

self.children = [self.tabs]
self.rendered = True

def _on_tab_change(self, change):
if (tab_index := change["new"]) is None:
return
self.tabs.children[tab_index].render() # type: ignore
153 changes: 150 additions & 3 deletions src/aiidalab_qe_vibroscopy/app/widgets/dielectricmodel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,156 @@
from aiidalab_qe.common.mvc import Model
from aiida.common.extendeddicts import AttributeDict
import traitlets as tl

from aiidalab_qe_vibroscopy.utils.dielectric.result import NumpyEncoder
import numpy as np
import base64
import json
from IPython.display import display

class DielectricModel(Model):
vibro = tl.Instance(AttributeDict, allow_none=True)

class DielectricModel(Model):
dielectric_data = {}

site_selector_options = tl.List(
trait=tl.Tuple((tl.Unicode(), tl.Int())),
)

dielectric_tensor_table = tl.Unicode("")
born_charges_table = tl.Unicode("")
raman_tensors_table = tl.Unicode("")
site = tl.Int()

def set_initial_values(self):
"""Set the initial values for the model."""

self.dielectric_tensor_table = self._create_dielectric_tensor_table()
self.born_charges_table = self._create_born_charges_table(0)
self.raman_tensors_table = self._create_raman_tensors_table(0)
self.site_selector_options = self._get_site_selector_options()

def _get_site_selector_options(self):
"""Get the site selector options."""
if not self.dielectric_data:
return []

unit_cell_sites = self.dielectric_data["unit_cell"]
decimal_places = 5
# Create the options with rounded positions
site_selector_options = [
(
f"{site.kind_name} @ ({', '.join(f'{coord:.{decimal_places}f}' for coord in site.position)})",
index,
)
for index, site in enumerate(unit_cell_sites)
]
return site_selector_options

def _create_dielectric_tensor_table(self):
"""Create the HTML table for the dielectric tensor."""
if not self.dielectric_data:
return ""

dielectric_tensor = self.dielectric_data["dielectric_tensor"]
table_data = self._generate_table(dielectric_tensor)
return table_data

def _create_born_charges_table(self, site_index):
"""Create the HTML table for the Born charges."""
if not self.dielectric_data:
return ""

born_charges = self.dielectric_data["born_charges"]
round_data = born_charges[site_index].round(6)
table_data = self._generate_table(round_data)
return table_data

def _create_raman_tensors_table(self, site_index):
"""Create the HTML table for the Raman tensors."""
if not self.dielectric_data:
return ""

raman_tensors = self.dielectric_data["raman_tensors"]
round_data = raman_tensors[site_index].round(6)
table_data = self._generate_table(round_data, cell_width="200px")
return table_data

def download_data(self, _=None):
"""Function to download the data."""
if self.dielectric_data:
data_to_print = {
key: value
for key, value in self.dielectric_data.items()
if key != "unit_cell"
}
file_name = "dielectric_data.json"
json_str = json.dumps(data_to_print, cls=NumpyEncoder)
b64_str = base64.b64encode(json_str.encode()).decode()
self._download(payload=b64_str, filename=file_name)

@staticmethod
def _download(payload, filename):
"""Download payload as a file named as filename."""
from IPython.display import Javascript

javas = Javascript(
f"""
var link = document.createElement('a');
link.href = 'data:text/json;charset=utf-8;base64,{payload}'
link.download = "{filename}"
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
"""
)
display(javas)

def on_site_selection_change(self, site):
self.site = site
self.born_charges_table = self._create_born_charges_table(site)
self.raman_tensors_table = self._create_raman_tensors_table(site)

def _generate_table(self, data, cell_width="50px"):
rows = []
for row in data:
cells = []
for value in row:
# Check if value is a numpy array
if isinstance(value, np.ndarray):
# Format the numpy array as a string, e.g., "[0, 0, 1]"
value_str = np.array2string(
value, separator=", ", formatter={"all": lambda x: f"{x:.6g}"}
)
cell = f"<td>{value_str}</td>"
elif isinstance(value, str) and value == "special":
# Handle the "special" keyword
cell = f'<td class="blue-cell">{value}</td>'
else:
# Handle other types (numbers, strings, etc.)
cell = f"<td>{value}</td>"
cells.append(cell)
rows.append(f"<tr>{''.join(cells)}</tr>")

# Define the HTML with styles, using the dynamic cell width
table_html = f"""
<style>
table {{
border-collapse: collapse;
width: auto; /* Adjust to content */
}}
th, td {{
border: 1px solid black;
text-align: center;
padding: 4px;
height: 12px;
width: {cell_width}; /* Set custom cell width */
}}
.blue-cell {{
background-color: gray;
color: white;
}}
</style>
<table>
{''.join(rows)}
</table>
"""
return table_html
73 changes: 72 additions & 1 deletion src/aiidalab_qe_vibroscopy/app/widgets/dielectricwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ class DielectricWidget(ipw.VBox):
Widget for displaying dielectric properties results
"""

def __init__(self, model: DielectricModel, dielectric_node: None, **kwargs):
def __init__(self, model: DielectricModel, dielectric_data: None, **kwargs):
super().__init__(
children=[LoadingWidget("Loading widgets")],
**kwargs,
)
self._model = model

self.rendered = False
self._model.dielectric_data = dielectric_data

def render(self):
if self.rendered:
return
Expand All @@ -31,8 +34,76 @@ def render(self):
</div>"""
)

self.site_selector = ipw.Dropdown(
layout=ipw.Layout(width="450px"),
description="Select atom site:",
style={"description_width": "initial"},
)
ipw.dlink(
(self._model, "site_selector_options"),
(self.site_selector, "options"),
)
self.site_selector.observe(self._on_site_change, names="value")

self.download_button = ipw.Button(
description="Download Data", icon="download", button_style="primary"
)

self.download_button.on_click(self._model.download_data)

# HTML table with the dielectric tensor
self.dielectric_tensor_table = ipw.HTML()
ipw.link(
(self._model, "dielectric_tensor_table"),
(self.dielectric_tensor_table, "value"),
)

# HTML table with the Born charges @ site
self.born_charges_table = ipw.HTML()
ipw.link(
(self._model, "born_charges_table"),
(self.born_charges_table, "value"),
)

# HTML table with the Raman tensors @ site
self.raman_tensors_table = ipw.HTML()
ipw.link(
(self._model, "raman_tensors_table"),
(self.raman_tensors_table, "value"),
)

self.children = [
self.dielectric_results_help,
ipw.HTML("<h3>Dielectric tensor</h3>"),
self.dielectric_tensor_table,
self.site_selector,
ipw.HBox(
[
ipw.VBox(
[
ipw.HTML("<h3>Born effective charges</h3>"),
self.born_charges_table,
]
),
ipw.VBox(
[
ipw.HTML("<h3>Raman Tensor </h3>"),
self.raman_tensors_table,
]
),
]
),
self.download_button,
]

self.rendered = True
self._initial_view()

def _initial_view(self):
self._model.set_initial_values()
self.dielectric_tensor_table.layout = ipw.Layout(width="300px", height="auto")
self.born_charges_table.layout = ipw.Layout(width="300px", height="auto")
# self.raman_tensors_table.layout = ipw.Layout(width="auto", height="auto")

def _on_site_change(self, change):
self._model.on_site_selection_change(change["new"])
Loading

0 comments on commit 04d5d17

Please sign in to comment.