From 6d13673d22f550313b29d6b5312fd3ef43454a44 Mon Sep 17 00:00:00 2001 From: Junfeng Qiao Date: Mon, 22 Apr 2024 14:03:27 +0200 Subject: [PATCH] Shift bands with different Fermi energies When plotting the band structure comparison. --- src/aiida_wannier90_workflows/cli/group.py | 2 +- src/aiida_wannier90_workflows/cli/list.py | 2 +- .../utils/workflows/plot/bands.py | 43 ++++++++++++------- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/aiida_wannier90_workflows/cli/group.py b/src/aiida_wannier90_workflows/cli/group.py index 5bff6c3..59296dc 100644 --- a/src/aiida_wannier90_workflows/cli/group.py +++ b/src/aiida_wannier90_workflows/cli/group.py @@ -4,7 +4,7 @@ from aiida import orm from aiida.cmdline.params import types from aiida.cmdline.utils import decorators, echo -from aiida.cmdline.utils.query.calculation import CalculationQueryBuilder +from aiida.tools.query.calculation import CalculationQueryBuilder from .root import cmd_root diff --git a/src/aiida_wannier90_workflows/cli/list.py b/src/aiida_wannier90_workflows/cli/list.py index f47f9a4..02554a4 100644 --- a/src/aiida_wannier90_workflows/cli/list.py +++ b/src/aiida_wannier90_workflows/cli/list.py @@ -4,7 +4,7 @@ from aiida import orm from aiida.cmdline.params import options as options_core from aiida.cmdline.utils import decorators, echo -from aiida.cmdline.utils.query.calculation import CalculationQueryBuilder +from aiida.tools.query.calculation import CalculationQueryBuilder from .root import cmd_root diff --git a/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py b/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py index 7436c62..36a4f93 100755 --- a/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py +++ b/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py @@ -138,6 +138,7 @@ def get_mpl_code_for_bands( wan_bands, *, fermi_energy=None, + fermi_energy2=None, shift_fermi=False, title=None, save=False, @@ -170,6 +171,14 @@ def get_mpl_code_for_bands( replacement += "p.axhline(y=0, color='blue', linestyle='--', label='Fermi', zorder=-1)\n" else: replacement += "p.axhline(y=fermi_energy, color='blue', linestyle='--', label='Fermi', zorder=-1)\n" + if (fermi_energy2 is not None) and abs(fermi_energy2 - fermi_energy) > 1e-3: + replacement += f"fermi_energy2 = {fermi_energy2}\n" + if shift_fermi: + replacement += "p.axhline(y=0, color='cyan', linestyle='--', label='Fermi2', zorder=-1)\n" + else: + replacement += "p.axhline(y=fermi_energy2, color='cyan', linestyle='--', label='Fermi2', zorder=-1)\n" + else: + replacement += "fermi_energy2 = fermi_energy\n" replacement += "pl.legend()\n\n" replacement += "for path in paths:" dft_mpl_code = dft_mpl_code.replace(b"for path in paths:", replacement.encode()) @@ -186,7 +195,7 @@ def get_mpl_code_for_bands( ) wan_mpl_code = wan_mpl_code.replace( b"p.plot(x, band, label=label,", - b"p.plot(x, [_-fermi_energy for _ in band], label=label,", + b"p.plot(x, [_-fermi_energy2 for _ in band], label=label,", ) dft_mpl_code = dft_mpl_code.replace( b"p.set_ylim([all_data['y_min_lim'], all_data['y_max_lim']])", @@ -194,7 +203,7 @@ def get_mpl_code_for_bands( ) wan_mpl_code = wan_mpl_code.replace( b"p.set_ylim([all_data['y_min_lim'], all_data['y_max_lim']])", - b"p.set_ylim([all_data['y_min_lim']-fermi_energy, all_data['y_max_lim']-fermi_energy])", + b"p.set_ylim([all_data['y_min_lim']-fermi_energy2, all_data['y_max_lim']-fermi_energy2])", ) mpl_code = dft_mpl_code + wan_mpl_code @@ -232,7 +241,12 @@ def get_output_bands(workchain): def get_mpl_code_for_workchains( - workchain0, workchain1, title=None, save=False, filename=None + workchain0, + workchain1, + title=None, + save=False, + filename=None, + shift_fermi=False, ): """Return matplotlib code for comparing band structures of two workchains.""" # assume workchain0 is pw, workchain1 is wannier @@ -250,23 +264,15 @@ def get_mpl_code_for_workchains( if save and (filename is None): filename = f"bandsdiff_{formula}_{workchain0.pk}_{workchain1.pk}.py" - if workchain1.process_class in ( - Wannier90BaseWorkChain, - Wannier90BandsWorkChain, - Wannier90OptimizeWorkChain, - ): - fermi_energy = get_workchain_fermi_energy(workchain1) - else: - if workchain0.process_class in [PwBandsWorkChain, ProjwfcBandsWorkChain]: - fermi_energy = workchain0.outputs["scf_parameters"]["fermi_energy"] - else: - raise ValueError(f"Cannot find fermi energy from {workchain0}") + fermi_energy = get_workchain_fermi_energy(workchain0) + fermi_energy2 = get_workchain_fermi_energy(workchain1) mpl_code = get_mpl_code_for_bands( dft_bands, wan_bands, fermi_energy=fermi_energy, - shift_fermi=False, + fermi_energy2=fermi_energy2, + shift_fermi=shift_fermi, title=title, save=save, filename=filename, @@ -276,7 +282,12 @@ def get_mpl_code_for_workchains( def get_workchain_fermi_energy( - workchain: ty.Union[Wannier90BaseWorkChain, Wannier90BandsWorkChain] + workchain: ty.Union[ + Wannier90BaseWorkChain, + Wannier90BandsWorkChain, + PwBandsWorkChain, + ProjwfcBandsWorkChain, + ] ) -> float: """Get Fermi energy of Wannier90BandsWorkChain.