fix: precommit
anyangml committed Apr 8, 2024
1 parent f9278eb commit f6ebec1
Showing 2 changed files with 21 additions and 38 deletions.
51 changes: 17 additions & 34 deletions deepmd/pt/utils/stat.py
@@ -264,43 +264,26 @@ def compute_output_stats(
     keys = new_keys

     # split system based on label
-    atomic_sampled_idx = defaultdict(set)
-    global_sampled_idx = defaultdict(set)
+    atomic_sampled_idx = defaultdict(list)
+    global_sampled_idx = defaultdict(list)

     for kk in keys:
         for idx, system in enumerate(sampled):
-            if (
-                (("find_atom_" + kk) in system)
-                and (system["find_atom_" + kk] > 0.0)
-                and (idx not in atomic_sampled_idx[kk])
-            ):
-                atomic_sampled_idx[kk].add(idx)
-            elif (
-                (("find_" + kk) in system)
-                and (system["find_" + kk] > 0.0)
-                and (idx not in global_sampled_idx[kk])
-            ):
-                global_sampled_idx[kk].add(idx)
+            if (("find_atom_" + kk) in system) and (system["find_atom_" + kk] > 0.0) and (len(atomic_sampled_idx[kk])==0 or idx > atomic_sampled_idx[kk][-1]):
+                atomic_sampled_idx[kk].append(idx)
+            elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0) and (len(global_sampled_idx[kk])==0 or idx > global_sampled_idx[kk][-1]):
+                global_sampled_idx[kk].append(idx)
             else:
                 continue

     # use index to gather model predictions for the corresponding systems.
-    model_pred_g = (
-        {
-            kk: [vv[idx] for idx in sorted(list(global_sampled_idx[kk]))]
-            for kk, vv in model_pred.items()
-        }
-        if model_pred
-        else None
-    )
-    model_pred_a = (
-        {
-            kk: [vv[idx] for idx in sorted(list(atomic_sampled_idx[kk]))]
-            for kk, vv in model_pred.items()
-        }
-        if model_pred
-        else None
-    )
+    model_pred_g = {kk: [vv[idx] for idx in global_sampled_idx[kk]] for kk, vv in model_pred.items()} if model_pred else None
+    model_pred_a = {kk: [vv[idx] for idx in atomic_sampled_idx[kk]] for kk, vv in model_pred.items()} if model_pred else None

     # concat all frames within those systmes
     model_pred_g = (
         {
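Note on the new indexing pattern (an editorial sketch, not part of the commit): because `enumerate(sampled)` yields strictly increasing indices in a single pass, appending to a `defaultdict(list)` already produces sorted, duplicate-free index lists per key, which appears to be why the `set` plus `sorted(list(...))` round-trip on the removed side is no longer needed. The `find_atom_energy`/`find_energy` toy data below is hypothetical.

    from collections import defaultdict

    # Toy stand-in for `sampled`: each dict plays the role of one system's labels.
    sampled = [
        {"find_atom_energy": 1.0},
        {"find_energy": 1.0},
        {"find_atom_energy": 1.0},
    ]

    atomic_sampled_idx = defaultdict(list)
    global_sampled_idx = defaultdict(list)

    for idx, system in enumerate(sampled):
        # `idx` strictly increases, so plain `append` keeps each list sorted
        # and free of duplicates without a set or a later sort.
        if system.get("find_atom_energy", 0.0) > 0.0:
            atomic_sampled_idx["energy"].append(idx)
        elif system.get("find_energy", 0.0) > 0.0:
            global_sampled_idx["energy"].append(idx)

    print(atomic_sampled_idx["energy"])  # [0, 2]
    print(global_sampled_idx["energy"])  # [1]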
@@ -396,11 +379,11 @@ def compute_output_stats_global(
         for kk in keys
     }
     # shape: (nframes, ndim)
-    merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys}
+    merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys if len(outputs[kk])>0}
     # shape: (nframes, ntypes)
-    merged_natoms = {
-        kk: to_numpy_array(torch.cat(input_natoms[kk])[:, 2:]) for kk in keys
-    }
+    merged_natoms = {kk: to_numpy_array(torch.cat(input_natoms[kk])[:, 2:]) for kk in keys if len(input_natoms[kk])>0}

     nf = {kk: merged_natoms[kk].shape[0] for kk in keys}
     if preset_bias is not None:
         assigned_atom_ener = {
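Why the new `if len(...)>0` guards matter (a minimal sketch, not from the repository): `torch.cat` raises a RuntimeError on an empty list of tensors, so keys for which no sampled system supplied data have to be skipped before concatenation. The `outputs` dict below is invented for illustration.

    import torch

    # Hypothetical per-key frame outputs; no system provided data for "dos".
    outputs = {"energy": [torch.zeros(2, 1)], "dos": []}

    # torch.cat([]) would raise a RuntimeError, so keys with empty lists are skipped.
    merged_output = {
        kk: torch.cat(vv).numpy() for kk, vv in outputs.items() if len(vv) > 0
    }
    print(merged_output.keys())  # dict_keys(['energy'])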
8 changes: 4 additions & 4 deletions source/tests/pt/model/test_atomic_bias.py
@@ -87,7 +87,7 @@ def forward(
                     [4.0, 5.0, 6.0],
                 ]
             )
-            .view([nf, nloc] + self.output_def()["foo"].shape)
+            .view([nf, nloc, *self.output_def()["foo"].shape])
             .to(env.GLOBAL_PT_FLOAT_PRECISION)
             .to(env.DEVICE)
         )
@@ -98,7 +98,7 @@ def forward(
                     [4.0, 5.0, 6.0, 10.0, 11.0, 12.0],
                 ]
             )
-            .view([nf, nloc] + self.output_def()["bar"].shape)
+            .view([nf, nloc, *self.output_def()["bar"].shape])
             .to(env.GLOBAL_PT_FLOAT_PRECISION)
             .to(env.DEVICE)
         )
@@ -198,13 +198,13 @@ def cvt_ret(x):
                 [1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0],
             ]
-        ).reshape([nf, nloc] + md0.fitting_output_def()["foo"].shape)
+        ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape])
         expected_ret0["bar"] = np.array(
             [
                 [1.0, 2.0, 3.0, 7.0, 8.0, 9.0],
                 [4.0, 5.0, 6.0, 10.0, 11.0, 12.0],
             ]
-        ).reshape([nf, nloc] + md0.fitting_output_def()["bar"].shape)
+        ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape])
         for kk in ["foo", "bar"]:
             np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk])

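The test-file changes replace list concatenation with iterable unpacking when building the view/reshape shape; this is the rewrite that ruff's RUF005 rule, run through pre-commit, typically suggests (the attribution to that specific rule is an inference, not stated in the commit). A small sketch with made-up shapes:

    import torch

    nf, nloc = 1, 2
    shape_list = [3]    # e.g. an output definition whose .shape is a list
    shape_tuple = (3,)  # unpacking also handles tuples; concatenation would not

    # Both spellings build the same shape when .shape is a list ...
    a = torch.arange(6.0).view([nf, nloc] + shape_list)
    b = torch.arange(6.0).view([nf, nloc, *shape_list])
    assert a.shape == b.shape == (1, 2, 3)

    # ... but only unpacking still works if .shape were a tuple:
    c = torch.arange(6.0).view([nf, nloc, *shape_tuple])
    # [nf, nloc] + shape_tuple  # TypeError: can only concatenate list (not "tuple") to list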
