Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 17, 2024
1 parent d7036b8 commit 09d775d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_atomic_model_atomic_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def cvt_ret(x):
self.merged_output_stat, stat_file_path=self.stat_file_path
)
ret1 = md0.forward_common_atomic(*args)
expected_std = np.ones((2,2,2)) # 2 keys, 2 atypes, 2 max dims.
expected_std = np.ones((2, 2, 2)) # 2 keys, 2 atypes, 2 max dims.
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)
ret1 = cvt_ret(ret1)
# nt x odim
Expand All @@ -223,6 +223,7 @@ def cvt_ret(x):
expected_ret1["bar"] = ret0["bar"] + bar_bias[at]
for kk in ["foo", "bar"]:
np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk])

# 3. test bias load from file
def raise_error():
raise RuntimeError
Expand Down Expand Up @@ -258,6 +259,7 @@ def raise_error():
np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4)
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)


class TestAtomicModelStatMergeGlobalAtomic(
unittest.TestCase, TestCaseSingleFrameWithNlist
):
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_atomic_model_global_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def cvt_ret(x):
)
ret1 = md0.forward_common_atomic(*args)
ret1 = cvt_ret(ret1)
expected_std = np.ones((3,2,2)) # 3 keys, 2 atypes, 2 max dims.
expected_std = np.ones((3, 2, 2)) # 3 keys, 2 atypes, 2 max dims.
# nt x odim
foo_bias = np.array([1.0, 3.0]).reshape(2, 1)
bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2)
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_polar_atomic_model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def cvt_ret(x):
)
ret1 = md0.forward_common_atomic(*args)
ret1 = cvt_ret(ret1)
expected_std = np.ones((1,2,9)) # 1 keys, 2 atypes, 9 max dims.
expected_std = np.ones((1, 2, 9)) # 1 keys, 2 atypes, 9 max dims.
# nt x odim (dia)
diagnoal_bias = np.array(
[
Expand All @@ -222,7 +222,7 @@ def raise_error():
ret2 = cvt_ret(ret2)
np.testing.assert_almost_equal(ret1["polarizability"], ret2["polarizability"])
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

# 4. test change bias
BaseAtomicModel.change_out_bias(
md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic"
Expand Down

0 comments on commit 09d775d

Please sign in to comment.