Skip to content

Commit

Permalink
Replace manual ldos grid params with MALA parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
timcallow committed Aug 28, 2024
1 parent 8b7cd3a commit 31da07e
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions mala/datahandling/ldos_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
target_calculator=None,
descriptor_calculator=None,
):
self.ldos_parameters = parameters.targets
super(LDOSAlign, self).__init__(
parameters,
target_calculator=target_calculator,
Expand Down Expand Up @@ -91,8 +92,6 @@ def align_ldos_to_ref(
zero_tol=1e-5,
left_truncate=False,
right_truncate_value=None,
egrid_spacing_ev=0.1,
egrid_offset_ev=-10,
number_of_electrons=None,
n_shift_mse=None,
):
Expand Down Expand Up @@ -131,7 +130,6 @@ def align_ldos_to_ref(
vectors based on mean-squared error
computed automatically if None
"""

if self.parameters._configuration["mpi"]:
comm = get_comm()
rank = comm.rank
Expand All @@ -140,15 +138,16 @@ def align_ldos_to_ref(
comm = None
rank = 0
size = 1

if rank == 0:
# load in the reference snapshot
snapshot_ref = self.parameters.snapshot_directories_list[
reference_index
]
ldos_ref = np.load(
os.path.join(
snapshot_ref.output_npy_directory, snapshot_ref.output_npy_file
snapshot_ref.output_npy_directory,
snapshot_ref.output_npy_file,
),
mmap_mode="r",
)
Expand All @@ -166,9 +165,12 @@ def align_ldos_to_ref(
left_index_ref = np.where(ldos_mean_ref > zero_tol)[0][0]

# get the energy grid
emax = egrid_offset_ev + n_target * egrid_spacing_ev
emax = (
self.ldos_parameters.ldos_gridoffset_ev
+ n_target * self.ldos_parameters.ldos_gridspacing_ev
)
e_grid = np.linspace(
egrid_offset_ev,
self.ldos_parameters.ldos_gridoffset_ev,
emax,
n_target,
endpoint=False,
Expand All @@ -192,12 +194,12 @@ def align_ldos_to_ref(
n_shift_mse = comm.bcast(n_shift_mse, root=0)
N_snapshots = comm.bcast(N_snapshots, root=0)
n_target = comm.bcast(n_target, root=0)

local_snapshots = [i for i in range(rank, N_snapshots, size)]

else:
local_snapshots = range(N_snapshots)

for idx in local_snapshots:
snapshot = self.parameters.snapshot_directories_list[idx]
print(f"Aligning snapshot {idx+1} of {N_snapshots}")
Expand Down Expand Up @@ -228,7 +230,7 @@ def align_ldos_to_ref(
n_shift_mse,
)

e_shift = optimal_shift * egrid_spacing_ev
e_shift = optimal_shift * self.ldos_parameters.ldos_gridspacing_ev
if optimal_shift != 0:
ldos_shifted[:, :-optimal_shift] = ldos[:, optimal_shift:]
else:
Expand All @@ -246,11 +248,12 @@ def align_ldos_to_ref(
# get the first non-zero value
ldos_shifted = ldos_shifted[:, left_index_ref:]
new_egrid_offset = (
egrid_offset_ev
+ (left_index_ref + optimal_shift) * egrid_spacing_ev
self.ldos_parameters.ldos_gridoffset_ev
+ (left_index_ref + optimal_shift)
* self.ldos_parameters.ldos_gridspacing_ev
)
else:
new_egrid_offset = egrid_offset_ev
new_egrid_offset = self.ldos_parameters.ldos_gridoffset_ev

# reshape
ldos_shifted = ldos_shifted.reshape(ngrid, ngrid, ngrid, -1)
Expand All @@ -259,7 +262,9 @@ def align_ldos_to_ref(
"ldos_shift_ev": round(e_shift, 4),
"aligned_ldos_gridoffset_ev": round(new_egrid_offset, 4),
"aligned_ldos_gridsize": np.shape(ldos_shifted)[-1],
"aligned_ldos_gridspacing": round(egrid_spacing_ev, 4),
"aligned_ldos_gridspacing": round(
self.ldos_parameters.ldos_gridspacing_ev, 4
),
}

if number_of_electrons is not None:
Expand Down

0 comments on commit 31da07e

Please sign in to comment.