diff --git a/src/aiidalab_qe_vibroscopy/app/result/model.py b/src/aiidalab_qe_vibroscopy/app/result/model.py index 35f0b92..d4862c3 100644 --- a/src/aiidalab_qe_vibroscopy/app/result/model.py +++ b/src/aiidalab_qe_vibroscopy/app/result/model.py @@ -2,7 +2,6 @@ import traitlets as tl -from aiidalab_qe_vibroscopy.utils.raman.result import export_iramanworkchain_data from aiidalab_qe_vibroscopy.utils.phonons.result import export_phononworkchain_data from aiidalab_qe_vibroscopy.utils.euphonic import export_euphonic_data @@ -24,7 +23,10 @@ def needs_dielectric_tab(self): return True def needs_raman_tab(self): - return export_iramanworkchain_data(self.get_vibro_node()) + node = self.get_vibro_node() + if not any(key in node for key in ["iraman", "harmonic"]): + return False + return True # Here we use _fetch_child_process_node() since the function needs the input_structure in inputs def needs_phonons_tab(self): diff --git a/src/aiidalab_qe_vibroscopy/app/result/result.py b/src/aiidalab_qe_vibroscopy/app/result/result.py index 21ced89..ab24155 100644 --- a/src/aiidalab_qe_vibroscopy/app/result/result.py +++ b/src/aiidalab_qe_vibroscopy/app/result/result.py @@ -7,6 +7,9 @@ from aiidalab_qe_vibroscopy.app.widgets.dielectricmodel import DielectricModel import ipywidgets as ipw +from aiidalab_qe_vibroscopy.app.widgets.ramanwidget import RamanWidget +from aiidalab_qe_vibroscopy.app.widgets.ramanmodel import RamanModel + class VibroResultsPanel(ResultsPanel[VibroResultsModel]): title = "Vibronic" @@ -32,8 +35,14 @@ def render(self): if self._model.needs_phonons_tab(): tab_data.append(("Phonons", ipw.HTML("phonon_data"))) - if self._model.needs_raman_tab(): - tab_data.append(("Raman", ipw.HTML("raman_data"))) + needs_raman_tab = self._model.needs_raman_tab() + if needs_raman_tab: + raman_model = RamanModel() + raman_widget = RamanWidget( + model=raman_model, + node=vibro_node, + ) + tab_data.append(("Raman", raman_widget)) needs_dielectri_tab = self._model.needs_dielectric_tab() diff --git a/src/aiidalab_qe_vibroscopy/app/widgets/ramanmodel.py b/src/aiidalab_qe_vibroscopy/app/widgets/ramanmodel.py new file mode 100644 index 0000000..46e9281 --- /dev/null +++ b/src/aiidalab_qe_vibroscopy/app/widgets/ramanmodel.py @@ -0,0 +1,307 @@ +from __future__ import annotations +from aiidalab_qe.common.mvc import Model +import traitlets as tl +from aiida.common.extendeddicts import AttributeDict +from IPython.display import display +import numpy as np +from aiida_vibroscopy.utils.broadenings import multilorentz +import plotly.graph_objects as go +import base64 +import json + + +class RamanModel(Model): + vibro = tl.Instance(AttributeDict, allow_none=True) + + raman_plot_type_options = tl.List( + trait=tl.List(tl.Unicode()), + default_value=[ + ("Powder", "powder"), + ("Single Crystal", "single_crystal"), + ], + ) + raman_plot_type = tl.Unicode("powder") + raman_temperature = tl.Float(300) + raman_frequency_laser = tl.Float(532) + raman_pol_incoming = tl.Unicode("0 0 1") + raman_pol_outgoing = tl.Unicode("0 0 1") + raman_broadening = tl.Float(10.0) + raman_separate_polarizations = tl.Bool(False) + + frequencies = [] + intensities = [] + + frequencies_depolarized = [] + intensities_depolarized = [] + + def fetch_data(self): + """Fetch the Raman data from the VibroWorkChain""" + self.raman_data = self.get_vibrational_data(self.vibro) + + def update_data(self): + """ + Update the Raman plot data based on the selected plot type and configuration. + """ + if self.raman_plot_type == "powder": + self._update_powder_data() + else: + self._update_single_crystal_data() + + def _update_powder_data(self): + """ + Update data for the powder Raman plot. + """ + ( + polarized_intensities, + depolarized_intensities, + frequencies, + _, + ) = self.raman_data.run_powder_raman_intensities( + frequencies=self.raman_frequency_laser, + temperature=self.raman_temperature, + ) + + if self.raman_separate_polarizations: + self.frequencies, self.intensities = self.generate_plot_data( + frequencies, + polarized_intensities, + self.raman_broadening, + ) + self.frequencies_depolarized, self.intensities_depolarized = ( + self.generate_plot_data( + frequencies, + depolarized_intensities, + self.raman_broadening, + ) + ) + else: + combined_intensities = polarized_intensities + depolarized_intensities + self.frequencies, self.intensities = self.generate_plot_data( + frequencies, + combined_intensities, + self.raman_broadening, + ) + self.frequencies_depolarized, self.intensities_depolarized = [], [] + + def _update_single_crystal_data(self): + """ + Update data for the single crystal Raman plot. + """ + dir_incoming, _ = self._check_inputs_correct(self.raman_pol_incoming) + dir_outgoing, _ = self._check_inputs_correct(self.raman_pol_outgoing) + + ( + intensities, + frequencies, + labels, + ) = self.raman_data.run_single_crystal_raman_intensities( + pol_incoming=dir_incoming, + pol_outgoing=dir_outgoing, + frequencies=self.raman_frequency_laser, + temperature=self.raman_temperature, + ) + self.frequencies, self.intensities = self.generate_plot_data( + frequencies, intensities + ) + self.frequencies_depolarized, self.intensities_depolarized = [], [] + + def update_plot(self, plot): + """ + Update the Raman plot based on the selected plot type and configuration. + + Parameters: + plot: The plotly.graph_objs.Figure widget to update. + """ + update_function = ( + self._update_powder_plot + if self.raman_plot_type == "powder" + else self._update_single_crystal_plot + ) + update_function(plot) + + def _update_powder_plot(self, plot): + """ + Update the powder Raman plot. + + Parameters: + plot: The plotly.graph_objs.Figure widget to update. + """ + if self.raman_separate_polarizations: + self._update_polarized_and_depolarized(plot) + else: + self._clear_depolarized_and_update(plot) + + def _update_polarized_and_depolarized(self, plot): + """ + Update the plot when polarized and depolarized data are separate. + + Parameters: + plot: The plotly.graph_objs.Figure widget to update. + """ + if len(plot.data) == 1: + self._update_trace( + plot.data[0], self.frequencies, self.intensities, "Polarized" + ) + plot.add_trace( + go.Scatter( + x=self.frequencies_depolarized, + y=self.intensities_depolarized, + name="Depolarized", + ) + ) + plot.layout.title.text = "Powder Raman Spectrum" + elif len(plot.data) == 2: + self._update_trace( + plot.data[0], self.frequencies, self.intensities, "Polarized" + ) + self._update_trace( + plot.data[1], + self.frequencies_depolarized, + self.intensities_depolarized, + "Depolarized", + ) + plot.layout.title.text = "Powder Raman Spectrum" + + def _clear_depolarized_and_update(self, plot): + """ + Clear depolarized data and update the plot. + + Parameters: + plot: The plotly.graph_objs.Figure widget to update. + """ + if len(plot.data) == 2: + self._update_trace(plot.data[0], self.frequencies, self.intensities, "") + plot.data[1].x = [] + plot.data[1].y = [] + plot.layout.title.text = "Powder Raman Spectrum" + elif len(plot.data) == 1: + self._update_trace(plot.data[0], self.frequencies, self.intensities, "") + plot.layout.title.text = "Powder Raman Spectrum" + + def _update_single_crystal_plot(self, plot): + """ + Update the single crystal Raman plot. + + Parameters: + plot: The plotly.graph_objs.Figure widget to update. + """ + if len(plot.data) == 2: + self._update_trace(plot.data[0], self.frequencies, self.intensities, "") + plot.data[1].x = [] + plot.data[1].y = [] + plot.layout.title.text = "Single Crystal Raman Spectrum" + elif len(plot.data) == 1: + self._update_trace(plot.data[0], self.frequencies, self.intensities, "") + plot.layout.title.text = "Single Crystal Raman Spectrum" + + def _update_trace(self, trace, x_data, y_data, name): + """ + Helper function to update a single trace in the plot. + + Parameters: + trace: The trace to update. + x_data: The new x-axis data. + y_data: The new y-axis data. + name: The name of the trace. + """ + trace.x = x_data + trace.y = y_data + trace.name = name + + def get_vibrational_data(self, node): + """ + Extract vibrational data from an IRamanWorkChain or HarmonicWorkChain node. + + Parameters: + node: The workchain node containing IRaman or Harmonic data. + + Returns: + The vibrational accuracy data (vibro) or None if not available. + """ + # Determine the output node + output_node = getattr(node, "iraman", None) or getattr(node, "harmonic", None) + if not output_node: + return None + + # Check for vibrational data and extract accuracy + vibrational_data = getattr(output_node, "vibrational_data", None) + if not vibrational_data: + return None + + # Extract vibrational accuracy (prefer numerical_accuracy_4 if available) + vibro = getattr(vibrational_data, "numerical_accuracy_4", None) or getattr( + vibrational_data, "numerical_accuracy_2", None + ) + + return vibro + + def _check_inputs_correct(self, polarization): + # Check if the polarization vectors are correct + input_text = polarization + input_values = input_text.split() + dir_values = [] + if len(input_values) == 3: + try: + dir_values = [float(i) for i in input_values] + return dir_values, True + except: # noqa: E722 + return dir_values, False + else: + return dir_values, False + + def generate_plot_data( + self, + frequencies: list[float], + intensities: list[float], + broadening: float = 10.0, + x_range: list[float] | str = "auto", + broadening_function=multilorentz, + normalize: bool = True, + ): + frequencies = np.array(frequencies) + intensities = np.array(intensities) + + if x_range == "auto": + xi = max(0, frequencies.min() - 200) + xf = frequencies.max() + 200 + x_range = np.arange(xi, xf, 1.0) + + y_range = broadening_function(x_range, frequencies, intensities, broadening) + + if normalize: + y_range /= y_range.max() + + return x_range, y_range + + def download_data(self, _=None): + filename = "spectra.json" + if self.raman_separate_polarizations: + my_dict = { + "Frequencies cm-1": self.frequencies.tolist(), + "Polarized intensities": self.intensities.tolist(), + "Depolarized intensities": self.intensities_depolarized.tolist(), + } + else: + my_dict = { + "Frequencies cm-1": self.frequencies.tolist(), + "Intensities": self.intensities.tolist(), + } + json_str = json.dumps(my_dict) + b64_str = base64.b64encode(json_str.encode()).decode() + self._download(payload=b64_str, filename=filename) + + @staticmethod + def _download(payload, filename): + from IPython.display import Javascript + + javas = Javascript( + """ + var link = document.createElement('a'); + link.href = 'data:text/json;charset=utf-8;base64,{payload}' + link.download = "{filename}" + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + """.format(payload=payload, filename=filename) + ) + display(javas) diff --git a/src/aiidalab_qe_vibroscopy/app/widgets/ramanwidget.py b/src/aiidalab_qe_vibroscopy/app/widgets/ramanwidget.py new file mode 100644 index 0000000..085fa7e --- /dev/null +++ b/src/aiidalab_qe_vibroscopy/app/widgets/ramanwidget.py @@ -0,0 +1,176 @@ +import ipywidgets as ipw +from aiidalab_qe_vibroscopy.app.widgets.ramanmodel import RamanModel +import plotly.graph_objects as go +from aiidalab_widgets_base.utils import StatusHTML + + +class RamanWidget(ipw.VBox): + """ + Widget to display Raman properties Tab + """ + + def __init__(self, model: RamanModel, node: None, **kwargs): + super().__init__( + children=[ipw.HTML("Loading Raman data...")], + **kwargs, + ) + self._model = model + self._model.vibro = node + self.rendered = False + + def render(self): + if self.rendered: + return + + self.raman_plot_type = ipw.ToggleButtons( + description="Spectrum type:", + style={"description_width": "initial"}, + ) + ipw.dlink( + (self._model, "raman_plot_type_options"), + (self.raman_plot_type, "options"), + ) + ipw.link( + (self._model, "raman_plot_type"), + (self.raman_plot_type, "value"), + ) + self.raman_plot_type.observe(self._on_raman_plot_type_change, names="value") + self.raman_temperature = ipw.FloatText( + description="Temperature (K):", + style={"description_width": "initial"}, + ) + ipw.link( + (self._model, "raman_temperature"), + (self.raman_temperature, "value"), + ) + self.raman_frequency_laser = ipw.FloatText( + description="Laser frequency (nm):", + style={"description_width": "initial"}, + ) + ipw.link( + (self._model, "raman_frequency_laser"), + (self.raman_frequency_laser, "value"), + ) + self.raman_pol_incoming = ipw.Text( + description="Incoming polarization:", + style={"description_width": "initial"}, + layout=ipw.Layout(visibility="hidden"), + ) + ipw.link( + (self._model, "raman_pol_incoming"), + (self.raman_pol_incoming, "value"), + ) + self.raman_pol_outgoing = ipw.Text( + description="Outgoing polarization:", + style={"description_width": "initial"}, + layout=ipw.Layout(visibility="hidden"), + ) + ipw.link( + (self._model, "raman_pol_outgoing"), + (self.raman_pol_outgoing, "value"), + ) + self.raman_plot_button = ipw.Button( + description="Update Plot", + icon="pencil", + button_style="primary", + layout=ipw.Layout(width="auto"), + ) + self.raman_plot_button.on_click(self._on_raman_plot_button_click) + self.raman_download_button = ipw.Button( + description="Download Data", + icon="download", + button_style="primary", + layout=ipw.Layout(width="auto"), + ) + self.raman_download_button.on_click(self._model.download_data) + self._wrong_syntax = StatusHTML(clear_after=8) + + self.raman_broadening = ipw.FloatText( + description="Broadening (cm-1):", + style={"description_width": "initial"}, + ) + ipw.link( + (self._model, "raman_broadening"), + (self.raman_broadening, "value"), + ) + + self.raman_separate_polarized = ipw.Checkbox( + description="Separate polarized and depolarized intensities", + style={"description_width": "initial"}, + ) + ipw.link( + (self._model, "raman_separate_polarizations"), + (self.raman_separate_polarized, "value"), + ) + self.raman_spectrum = go.FigureWidget( + layout=go.Layout( + title=dict(text="Powder Raman spectrum"), + barmode="overlay", + xaxis=dict( + title="Wavenumber (cm-1)", + nticks=0, + ), + yaxis=dict( + title="Intensity (arb. units)", + ), + height=500, + width=700, + plot_bgcolor="white", + ) + ) + + self.children = [ + ipw.HTML("