Skip to content

Commit

Permalink
Make BandsData accept a generic BZ object
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Sep 27, 2023
1 parent 078d745 commit 2c07837
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions src/sisl/viz/data/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
class Aiida_node: pass
AIIDA_AVAILABLE = False


class BandsData(XarrayData):

def sanity_check(self,
Expand Down Expand Up @@ -276,29 +275,34 @@ def from_siesta_bands(cls, bands_file: bandsSileSiesta):
@new.register
@classmethod
def from_hamiltonian(cls,
band_structure: sisl.BandStructure,
bz: sisl.BrillouinZone,
H: Union[sisl.Hamiltonian, None] = None,
extra_vars: Sequence[Union[Dict, str]] = ()
):
"""Uses a sisl's `BandStructure` object to calculate the bands."""
if band_structure is None:
"""Uses a sisl's `BrillouinZone` object to calculate the bands."""
if bz is None:
raise ValueError("No band structure (k points path) was provided")

if not isinstance(getattr(band_structure, "parent", None), sisl.Hamiltonian):
if not isinstance(getattr(bz, "parent", None), sisl.Hamiltonian):
H = HamiltonianDataSource(H=H)
band_structure.set_parent(H)
bz.set_parent(H)
else:
H = band_structure.parent
H = bz.parent

# Define the spin class of this calculation.
spin = H.spin

ticks = band_structure.lineartick()
if isinstance(bz, sisl.BandStructure):
ticks = bz.lineartick()
kticks = bz.lineark()
else:
ticks = (None, None)
kticks = np.arange(0, len(bz))

# Get the wrapper function that we should call on each eigenstate.
# This also returns the coordinates and names to build the final dataset.
bands_wrapper, all_vars, coords_values = _get_eigenstate_wrapper(
band_structure.lineark(), spin, extra_vars=extra_vars
kticks, spin, extra_vars=extra_vars
)

# Get a dataset with all values for all spin indices
Expand All @@ -312,7 +316,7 @@ def from_hamiltonian(cls,
if not spin.is_diagonal:
spin_kwarg = {}

with band_structure.apply(pool=_do_parallel_calc, zip=True) as parallel:
with bz.apply(pool=_do_parallel_calc, zip=True) as parallel:
spin_bands = parallel.dataarray.eigenstate(
wrap=partial(bands_wrapper, spin_index=spin_index),
**spin_kwarg,
Expand All @@ -326,18 +330,18 @@ def from_hamiltonian(cls,

# If the band structure contains discontinuities, we will copy the dataset
# adding the discontinuities.
if len(band_structure._jump_idx) > 0:
if isinstance(bz, sisl.BandStructure) and len(bz._jump_idx) > 0:

old_coords = bands_data.coords
coords = {
name: band_structure.insert_jump(old_coords[name]) if name == "k" else old_coords[name].values
name: bz.insert_jump(old_coords[name]) if name == "k" else old_coords[name].values
for name in old_coords
}

def _add_jump(array):
if "k" in array.coords:
array = array.transpose("k", ...)
return (array.dims, band_structure.insert_jump(array))
return (array.dims, bz.insert_jump(array))
else:
return array

Expand Down

0 comments on commit 2c07837

Please sign in to comment.