Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 8, 2024
1 parent 80159e1 commit a83e757
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
42 changes: 28 additions & 14 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,8 @@ def compute_output_stats(

# failed to restore the bias from stat file. compute
if bias_atom_e is None:


# only get data once, sampled is a list of dict[str, torch.Tensor]
sampled = merged() if callable(merged) else merged
sampled = merged() if callable(merged) else merged
if model_forward is not None:
model_pred = _compute_model_predict(sampled, keys, model_forward)

Check warning on line 252 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L251-L252

Added lines #L251 - L252 were not covered by tests
else:
Expand All @@ -267,11 +265,19 @@ def compute_output_stats(
global_sampled : [sys1, sys2]
atomic_sampled : [sys1]
"""
for kk in keys:
for kk in keys:
for idx, system in enumerate(sampled):
if (("find_atom_" + kk) in system) and (system["find_atom_" + kk] > 0.0) and (idx not in atomic_sampled):
atomic_sampled[idx] = system
elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0) and (idx not in global_sampled):
if (

Check warning on line 270 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L269-L270

Added lines #L269 - L270 were not covered by tests
(("find_atom_" + kk) in system)
and (system["find_atom_" + kk] > 0.0)
and (idx not in atomic_sampled)
):
atomic_sampled[idx] = system
elif (

Check warning on line 276 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L275-L276

Added lines #L275 - L276 were not covered by tests
(("find_" + kk) in system)
and (system["find_" + kk] > 0.0)
and (idx not in global_sampled)
):
global_sampled[idx] = system

Check warning on line 281 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L281

Added line #L281 was not covered by tests
else:
continue

Check warning on line 283 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L283

Added line #L283 was not covered by tests
Expand All @@ -287,7 +293,7 @@ def compute_output_stats(
preset_bias,
model_pred,
)

if len(atomic_sampled) > 0:
bias_atom_e, std_atom_e = compute_output_stats_atomic(

Check warning on line 298 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L297-L298

Added lines #L297 - L298 were not covered by tests
global_sampled,
Expand All @@ -297,14 +303,15 @@ def compute_output_stats(
preset_bias,
model_pred,
)

Check warning

Code scanning / CodeQL

Use of the return value of a procedure Warning

The result of
compute_output_stats_atomic
is used even though it is always None.

# need to merge dict
if stat_file_path is not None:
_save_to_file(stat_file_path, bias_atom_e, std_atom_e)

bias_atom_e = {kk: to_torch_tensor(vv) for kk, vv in bias_atom_e.items()}
std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()}
return bias_atom_e, std_atom_e
return bias_atom_e, std_atom_e

Check warning on line 313 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L311-L313

Added lines #L311 - L313 were not covered by tests


def compute_output_stats_global(

Check warning on line 316 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L316

Added line #L316 was not covered by tests
sampled: List[dict],
Expand All @@ -315,7 +322,6 @@ def compute_output_stats_global(
model_pred: Optional[Dict[str, np.ndarray]] = None,
):
"""This function only handle stat computation from reduced global labels."""

# remove the keys that are not in the sample
keys = [keys] if isinstance(keys, str) else keys
assert isinstance(keys, list)
Expand All @@ -324,7 +330,14 @@ def compute_output_stats_global(
keys = new_keys

Check warning on line 330 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L326-L330

Added lines #L326 - L330 were not covered by tests

# get label dict from sample; for each key, only picking the system with global labels.
outputs = {kk: [system[kk] for system in sampled if kk in system and system.get(f"find_{kk}", 0) > 0] for kk in keys}
outputs = {

Check warning on line 333 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L333

Added line #L333 was not covered by tests
kk: [
system[kk]
for system in sampled
if kk in system and system.get(f"find_{kk}", 0) > 0
]
for kk in keys
}

data_mixed_type = "real_natoms_vec" in sampled[0]
natoms_key = "natoms" if not data_mixed_type else "real_natoms_vec"
Expand Down Expand Up @@ -367,7 +380,7 @@ def compute_output_stats_global(
rcond=rcond,
)
bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e)

Check warning on line 382 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L382

Added line #L382 was not covered by tests

# unbias_e is only used for print rmse
if model_pred is None:
unbias_e = {

Check warning on line 386 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L385-L386

Added lines #L385 - L386 were not covered by tests
Expand All @@ -394,6 +407,7 @@ def rmse(x):
)
return bias_atom_e, std_atom_e

Check warning on line 408 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L408

Added line #L408 was not covered by tests


def compute_output_stats_atomic(

Check warning on line 411 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L411

Added line #L411 was not covered by tests
sampled: List[dict],
ntypes: int,
Expand All @@ -402,4 +416,4 @@ def compute_output_stats_atomic(
preset_bias: Optional[Dict[str, List[Optional[torch.Tensor]]]] = None,
model_pred: Optional[Dict[str, np.ndarray]] = None,
):
pass
pass

Check warning on line 419 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L419

Added line #L419 was not covered by tests
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_atomic_model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def setUp(self):
np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2)
),
"find_foo": np.float32(1.0),
"find_bar": np.float32(1.0)
"find_bar": np.float32(1.0),
}
]
self.tempdir = tempfile.TemporaryDirectory()
Expand Down

0 comments on commit a83e757

Please sign in to comment.