From f6ebec1db405fc38b6ede126bccb8d33bf6509aa Mon Sep 17 00:00:00 2001
From: Anyang Peng <137014849+anyangml@users.noreply.github.com>
Date: Mon, 8 Apr 2024 21:56:33 +0800
Subject: [PATCH] fix: precommit

---
 deepmd/pt/utils/stat.py                   | 51 ++++++++---------
 source/tests/pt/model/test_atomic_bias.py |  8 ++--
 2 files changed, 21 insertions(+), 38 deletions(-)

diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py
index 70effcf23c..1da48bbb0c 100644
--- a/deepmd/pt/utils/stat.py
+++ b/deepmd/pt/utils/stat.py
@@ -264,43 +264,26 @@ def compute_output_stats(
         keys = new_keys
 
     # split system based on label
-    atomic_sampled_idx = defaultdict(set)
-    global_sampled_idx = defaultdict(set)
+    atomic_sampled_idx = defaultdict(list)
+    global_sampled_idx = defaultdict(list)
 
     for kk in keys:
         for idx, system in enumerate(sampled):
-            if (
-                (("find_atom_" + kk) in system)
-                and (system["find_atom_" + kk] > 0.0)
-                and (idx not in atomic_sampled_idx[kk])
-            ):
-                atomic_sampled_idx[kk].add(idx)
-            elif (
-                (("find_" + kk) in system)
-                and (system["find_" + kk] > 0.0)
-                and (idx not in global_sampled_idx[kk])
-            ):
-                global_sampled_idx[kk].add(idx)
+
+
+            if (("find_atom_" + kk) in system) and (system["find_atom_" + kk] > 0.0) and (len(atomic_sampled_idx[kk])==0 or idx > atomic_sampled_idx[kk][-1]):
+                atomic_sampled_idx[kk].append(idx)
+            elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0) and (len(global_sampled_idx[kk])==0 or idx > global_sampled_idx[kk][-1]):
+                global_sampled_idx[kk].append(idx)
+
             else:
                 continue
 
     # use index to gather model predictions for the corresponding systems.
-    model_pred_g = (
-        {
-            kk: [vv[idx] for idx in sorted(list(global_sampled_idx[kk]))]
-            for kk, vv in model_pred.items()
-        }
-        if model_pred
-        else None
-    )
-    model_pred_a = (
-        {
-            kk: [vv[idx] for idx in sorted(list(atomic_sampled_idx[kk]))]
-            for kk, vv in model_pred.items()
-        }
-        if model_pred
-        else None
-    )
+
+    model_pred_g = {kk: [vv[idx] for idx in global_sampled_idx[kk]] for kk, vv in model_pred.items()} if model_pred else None
+    model_pred_a = {kk: [vv[idx] for idx in atomic_sampled_idx[kk]] for kk, vv in model_pred.items()} if model_pred else None
+
     # concat all frames within those systmes
     model_pred_g = (
         {
@@ -396,11 +379,11 @@ def compute_output_stats_global(
         for kk in keys
     }
     # shape: (nframes, ndim)
-    merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys}
+    merged_output = {kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys if len(outputs[kk])>0}
     # shape: (nframes, ntypes)
-    merged_natoms = {
-        kk: to_numpy_array(torch.cat(input_natoms[kk])[:, 2:]) for kk in keys
-    }
+
+    merged_natoms = {kk: to_numpy_array(torch.cat(input_natoms[kk])[:, 2:]) for kk in keys if len(input_natoms[kk])>0}
+
     nf = {kk: merged_natoms[kk].shape[0] for kk in keys}
     if preset_bias is not None:
         assigned_atom_ener = {
diff --git a/source/tests/pt/model/test_atomic_bias.py b/source/tests/pt/model/test_atomic_bias.py
index 562832b429..82f2bda8ec 100644
--- a/source/tests/pt/model/test_atomic_bias.py
+++ b/source/tests/pt/model/test_atomic_bias.py
@@ -87,7 +87,7 @@ def forward(
                     [4.0, 5.0, 6.0],
                 ]
             )
-            .view([nf, nloc] + self.output_def()["foo"].shape)
+            .view([nf, nloc, *self.output_def()["foo"].shape])
             .to(env.GLOBAL_PT_FLOAT_PRECISION)
             .to(env.DEVICE)
         )
@@ -98,7 +98,7 @@ def forward(
                     [4.0, 5.0, 6.0, 10.0, 11.0, 12.0],
                 ]
             )
-            .view([nf, nloc] + self.output_def()["bar"].shape)
+            .view([nf, nloc, *self.output_def()["bar"].shape])
             .to(env.GLOBAL_PT_FLOAT_PRECISION)
             .to(env.DEVICE)
         )
@@ -198,13 +198,13 @@ def cvt_ret(x):
                 [1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0],
             ]
-        ).reshape([nf, nloc] + md0.fitting_output_def()["foo"].shape)
+        ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape])
         expected_ret0["bar"] = np.array(
             [
                 [1.0, 2.0, 3.0, 7.0, 8.0, 9.0],
                 [4.0, 5.0, 6.0, 10.0, 11.0, 12.0],
             ]
-        ).reshape([nf, nloc] + md0.fitting_output_def()["bar"].shape)
+        ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape])
         for kk in ["foo", "bar"]:
             np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk])
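
Note on the index-gathering change in compute_output_stats (a minimal standalone sketch under assumed inputs, not repository code): because enumerate(sampled) visits systems in increasing order of idx, appending to a defaultdict(list) guarded by "list is empty or idx is larger than the last recorded index" collects each matching system at most once and already in ascending order, which is why the sorted(list(...)) calls around global_sampled_idx / atomic_sampled_idx can be dropped. The label dicts and the "energy" key below are made-up illustrations.

    # Standalone sketch; hypothetical label dicts, not data from the repository.
    from collections import defaultdict

    sampled = [
        {"find_atom_energy": 1.0},  # system 0: has atomic labels
        {"find_energy": 1.0},       # system 1: has global labels only
        {"find_atom_energy": 1.0},  # system 2: has atomic labels
    ]
    keys = ["energy"]

    atomic_sampled_idx = defaultdict(list)
    global_sampled_idx = defaultdict(list)
    for kk in keys:
        for idx, system in enumerate(sampled):
            if (("find_atom_" + kk) in system) and (system["find_atom_" + kk] > 0.0) and (
                len(atomic_sampled_idx[kk]) == 0 or idx > atomic_sampled_idx[kk][-1]
            ):
                atomic_sampled_idx[kk].append(idx)
            elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0) and (
                len(global_sampled_idx[kk]) == 0 or idx > global_sampled_idx[kk][-1]
            ):
                global_sampled_idx[kk].append(idx)

    print(dict(atomic_sampled_idx))  # {'energy': [0, 2]} -- already in ascending order
    print(dict(global_sampled_idx))  # {'energy': [1]}

The test-file edits switch from list concatenation to unpacking: [nf, nloc, *shape] works whether the declared output shape is a list or a tuple, while [nf, nloc] + shape only works when it is a list.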