From 09d775d828f2e0a036849db98274bb1503410ba2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Apr 2024 07:38:29 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_atomic_model_atomic_stat.py | 4 +++- source/tests/pt/model/test_atomic_model_global_stat.py | 2 +- source/tests/pt/model/test_polar_atomic_model_stat.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/source/tests/pt/model/test_atomic_model_atomic_stat.py b/source/tests/pt/model/test_atomic_model_atomic_stat.py index 82b6ce6aa5..15d403f0c1 100644 --- a/source/tests/pt/model/test_atomic_model_atomic_stat.py +++ b/source/tests/pt/model/test_atomic_model_atomic_stat.py @@ -212,7 +212,7 @@ def cvt_ret(x): self.merged_output_stat, stat_file_path=self.stat_file_path ) ret1 = md0.forward_common_atomic(*args) - expected_std = np.ones((2,2,2)) # 2 keys, 2 atypes, 2 max dims. + expected_std = np.ones((2, 2, 2)) # 2 keys, 2 atypes, 2 max dims. np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) ret1 = cvt_ret(ret1) # nt x odim @@ -223,6 +223,7 @@ def cvt_ret(x): expected_ret1["bar"] = ret0["bar"] + bar_bias[at] for kk in ["foo", "bar"]: np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + # 3. test bias load from file def raise_error(): raise RuntimeError @@ -258,6 +259,7 @@ def raise_error(): np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) + class TestAtomicModelStatMergeGlobalAtomic( unittest.TestCase, TestCaseSingleFrameWithNlist ): diff --git a/source/tests/pt/model/test_atomic_model_global_stat.py b/source/tests/pt/model/test_atomic_model_global_stat.py index bb4ea3ea32..799948b14f 100644 --- a/source/tests/pt/model/test_atomic_model_global_stat.py +++ b/source/tests/pt/model/test_atomic_model_global_stat.py @@ -222,7 +222,7 @@ def cvt_ret(x): ) ret1 = md0.forward_common_atomic(*args) ret1 = cvt_ret(ret1) - expected_std = np.ones((3,2,2)) # 3 keys, 2 atypes, 2 max dims. + expected_std = np.ones((3, 2, 2)) # 3 keys, 2 atypes, 2 max dims. # nt x odim foo_bias = np.array([1.0, 3.0]).reshape(2, 1) bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) diff --git a/source/tests/pt/model/test_polar_atomic_model_stat.py b/source/tests/pt/model/test_polar_atomic_model_stat.py index c130dc24c5..f3d4f6955b 100644 --- a/source/tests/pt/model/test_polar_atomic_model_stat.py +++ b/source/tests/pt/model/test_polar_atomic_model_stat.py @@ -198,7 +198,7 @@ def cvt_ret(x): ) ret1 = md0.forward_common_atomic(*args) ret1 = cvt_ret(ret1) - expected_std = np.ones((1,2,9)) # 1 keys, 2 atypes, 9 max dims. + expected_std = np.ones((1, 2, 9)) # 1 keys, 2 atypes, 9 max dims. # nt x odim (dia) diagnoal_bias = np.array( [ @@ -222,7 +222,7 @@ def raise_error(): ret2 = cvt_ret(ret2) np.testing.assert_almost_equal(ret1["polarizability"], ret2["polarizability"]) np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) - + # 4. test change bias BaseAtomicModel.change_out_bias( md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic"