Skip to content

Commit

Permalink
Use itertools.chain for dict_array
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Jun 25, 2024
1 parent 33d3a44 commit 0cf7b85
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3858,42 +3858,43 @@ def read(self, filenames: list[PathLike]):
# Combine arrays along the kpoints axis:
# nbands (axis = 2) could differ between arrays, so set missing values to zero:
max_nbands = max(eig_dict[Spin.up].shape[1] for eig_dict in eigenvalues_list)
for dict_array_list in [occupancies_list, eigenvalues_list, data_list, xyz_data_list, phase_factors_list]:
for dict_array in dict_array_list:
if dict_array:
for key, array in dict_array.items():
if array.shape[1] < max_nbands:
if len(array.shape) == 2: # occupancies, eigenvalues
dict_array[key] = np.pad(
array,
((0, 0), (0, max_nbands - array.shape[2])),
mode="constant",
)
elif len(array.shape) == 4: # data, phase_factors
dict_array[key] = np.pad(
array,
(
(0, 0),
(0, max_nbands - array.shape[2]),
(0, 0),
(0, 0),
),
mode="constant",
)
elif len(array.shape) == 5: # xyz_data
dict_array[key] = np.pad(
array,
(
(0, 0),
(0, max_nbands - array.shape[2]),
(0, 0),
(0, 0),
(0, 0),
),
mode="constant",
)
else:
raise ValueError("Unexpected array shape encountered!")
for dict_array in itertools.chain(
occupancies_list, eigenvalues_list, data_list, xyz_data_list, phase_factors_list
):
if dict_array:
for key, array in dict_array.items():
if array.shape[1] < max_nbands:
if len(array.shape) == 2: # occupancies, eigenvalues
dict_array[key] = np.pad(
array,
((0, 0), (0, max_nbands - array.shape[2])),
mode="constant",
)
elif len(array.shape) == 4: # data, phase_factors
dict_array[key] = np.pad(
array,
(
(0, 0),
(0, max_nbands - array.shape[2]),
(0, 0),
(0, 0),
),
mode="constant",
)
elif len(array.shape) == 5: # xyz_data
dict_array[key] = np.pad(
array,
(
(0, 0),
(0, max_nbands - array.shape[2]),
(0, 0),
(0, 0),
(0, 0),
),
mode="constant",
)
else:
raise ValueError("Unexpected array shape encountered!")

# set nbands, nkpoints, and other attributes:
self.nbands = max_nbands
Expand Down

0 comments on commit 0cf7b85

Please sign in to comment.