From d7036b8039d19ded99b42ed233dafdf8336a24fe Mon Sep 17 00:00:00 2001 From: anyangml Date: Wed, 17 Apr 2024 07:37:58 +0000 Subject: [PATCH] feat: add UT on out_std --- source/tests/pt/model/test_atomic_model_atomic_stat.py | 6 ++++-- source/tests/pt/model/test_atomic_model_global_stat.py | 5 +++++ source/tests/pt/model/test_polar_atomic_model_stat.py | 6 +++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/model/test_atomic_model_atomic_stat.py b/source/tests/pt/model/test_atomic_model_atomic_stat.py index 8f365a09fe..82b6ce6aa5 100644 --- a/source/tests/pt/model/test_atomic_model_atomic_stat.py +++ b/source/tests/pt/model/test_atomic_model_atomic_stat.py @@ -212,6 +212,8 @@ def cvt_ret(x): self.merged_output_stat, stat_file_path=self.stat_file_path ) ret1 = md0.forward_common_atomic(*args) + expected_std = np.ones((2,2,2)) # 2 keys, 2 atypes, 2 max dims. + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) ret1 = cvt_ret(ret1) # nt x odim foo_bias = np.array([5.0, 6.0]).reshape(2, 1) @@ -221,7 +223,6 @@ def cvt_ret(x): expected_ret1["bar"] = ret0["bar"] + bar_bias[at] for kk in ["foo", "bar"]: np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) - # 3. test bias load from file def raise_error(): raise RuntimeError @@ -231,6 +232,7 @@ def raise_error(): ret2 = cvt_ret(ret2) for kk in ["foo", "bar"]: np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) # 4. test change bias BaseAtomicModel.change_out_bias( @@ -254,7 +256,7 @@ def raise_error(): ).reshape(2, 3, 1) for kk in ["foo"]: np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) - + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) class TestAtomicModelStatMergeGlobalAtomic( unittest.TestCase, TestCaseSingleFrameWithNlist diff --git a/source/tests/pt/model/test_atomic_model_global_stat.py b/source/tests/pt/model/test_atomic_model_global_stat.py index ca71b604ce..bb4ea3ea32 100644 --- a/source/tests/pt/model/test_atomic_model_global_stat.py +++ b/source/tests/pt/model/test_atomic_model_global_stat.py @@ -193,6 +193,7 @@ def cvt_ret(x): # nf x na x odim ret0 = md0.forward_common_atomic(*args) ret0 = cvt_ret(ret0) + expected_ret0 = {} expected_ret0["foo"] = np.array( [ @@ -221,6 +222,7 @@ def cvt_ret(x): ) ret1 = md0.forward_common_atomic(*args) ret1 = cvt_ret(ret1) + expected_std = np.ones((3,2,2)) # 3 keys, 2 atypes, 2 max dims. # nt x odim foo_bias = np.array([1.0, 3.0]).reshape(2, 1) bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) @@ -230,6 +232,7 @@ def cvt_ret(x): expected_ret1["bar"] = ret0["bar"] + bar_bias[at] for kk in ["foo", "pix", "bar"]: np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) # 3. test bias load from file def raise_error(): @@ -240,6 +243,7 @@ def raise_error(): ret2 = cvt_ret(ret2) for kk in ["foo", "pix", "bar"]: np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) # 4. test change bias BaseAtomicModel.change_out_bias( @@ -266,6 +270,7 @@ def raise_error(): for kk in ["foo", "pix"]: np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) # bar is too complicated to be manually computed. + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) def test_preset_bias(self): nf, nloc, nnei = self.nlist.shape diff --git a/source/tests/pt/model/test_polar_atomic_model_stat.py b/source/tests/pt/model/test_polar_atomic_model_stat.py index 43a58d12a7..c130dc24c5 100644 --- a/source/tests/pt/model/test_polar_atomic_model_stat.py +++ b/source/tests/pt/model/test_polar_atomic_model_stat.py @@ -198,6 +198,7 @@ def cvt_ret(x): ) ret1 = md0.forward_common_atomic(*args) ret1 = cvt_ret(ret1) + expected_std = np.ones((1,2,9)) # 1 keys, 2 atypes, 9 max dims. # nt x odim (dia) diagnoal_bias = np.array( [ @@ -210,6 +211,7 @@ def cvt_ret(x): np.testing.assert_almost_equal( ret1["polarizability"], expected_ret1["polarizability"] ) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) # 3. test bias load from file def raise_error(): @@ -219,7 +221,8 @@ def raise_error(): ret2 = md0.forward_common_atomic(*args) ret2 = cvt_ret(ret2) np.testing.assert_almost_equal(ret1["polarizability"], ret2["polarizability"]) - + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) + # 4. test change bias BaseAtomicModel.change_out_bias( md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" @@ -256,3 +259,4 @@ def raise_error(): np.testing.assert_almost_equal( ret3["polarizability"], expected_ret3["polarizability"], decimal=4 ) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)