Skip to content

Commit

Permalink
feat: add UT on out_std
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Apr 17, 2024
1 parent 3f11f7b commit d7036b8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
6 changes: 4 additions & 2 deletions source/tests/pt/model/test_atomic_model_atomic_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def cvt_ret(x):
self.merged_output_stat, stat_file_path=self.stat_file_path
)
ret1 = md0.forward_common_atomic(*args)
expected_std = np.ones((2,2,2)) # 2 keys, 2 atypes, 2 max dims.
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)
ret1 = cvt_ret(ret1)
# nt x odim
foo_bias = np.array([5.0, 6.0]).reshape(2, 1)
Expand All @@ -221,7 +223,6 @@ def cvt_ret(x):
expected_ret1["bar"] = ret0["bar"] + bar_bias[at]
for kk in ["foo", "bar"]:
np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk])

# 3. test bias load from file
def raise_error():
raise RuntimeError
Expand All @@ -231,6 +232,7 @@ def raise_error():
ret2 = cvt_ret(ret2)
for kk in ["foo", "bar"]:
np.testing.assert_almost_equal(ret1[kk], ret2[kk])
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

# 4. test change bias
BaseAtomicModel.change_out_bias(
Expand All @@ -254,7 +256,7 @@ def raise_error():
).reshape(2, 3, 1)
for kk in ["foo"]:
np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4)

np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

class TestAtomicModelStatMergeGlobalAtomic(
unittest.TestCase, TestCaseSingleFrameWithNlist
Expand Down
5 changes: 5 additions & 0 deletions source/tests/pt/model/test_atomic_model_global_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def cvt_ret(x):
# nf x na x odim
ret0 = md0.forward_common_atomic(*args)
ret0 = cvt_ret(ret0)

expected_ret0 = {}
expected_ret0["foo"] = np.array(
[
Expand Down Expand Up @@ -221,6 +222,7 @@ def cvt_ret(x):
)
ret1 = md0.forward_common_atomic(*args)
ret1 = cvt_ret(ret1)
expected_std = np.ones((3,2,2)) # 3 keys, 2 atypes, 2 max dims.
# nt x odim
foo_bias = np.array([1.0, 3.0]).reshape(2, 1)
bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2)
Expand All @@ -230,6 +232,7 @@ def cvt_ret(x):
expected_ret1["bar"] = ret0["bar"] + bar_bias[at]
for kk in ["foo", "pix", "bar"]:
np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk])
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

# 3. test bias load from file
def raise_error():
Expand All @@ -240,6 +243,7 @@ def raise_error():
ret2 = cvt_ret(ret2)
for kk in ["foo", "pix", "bar"]:
np.testing.assert_almost_equal(ret1[kk], ret2[kk])
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

# 4. test change bias
BaseAtomicModel.change_out_bias(
Expand All @@ -266,6 +270,7 @@ def raise_error():
for kk in ["foo", "pix"]:
np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk])
# bar is too complicated to be manually computed.
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

def test_preset_bias(self):
nf, nloc, nnei = self.nlist.shape
Expand Down
6 changes: 5 additions & 1 deletion source/tests/pt/model/test_polar_atomic_model_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def cvt_ret(x):
)
ret1 = md0.forward_common_atomic(*args)
ret1 = cvt_ret(ret1)
expected_std = np.ones((1,2,9)) # 1 keys, 2 atypes, 9 max dims.
# nt x odim (dia)
diagnoal_bias = np.array(
[
Expand All @@ -210,6 +211,7 @@ def cvt_ret(x):
np.testing.assert_almost_equal(
ret1["polarizability"], expected_ret1["polarizability"]
)
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

# 3. test bias load from file
def raise_error():
Expand All @@ -219,7 +221,8 @@ def raise_error():
ret2 = md0.forward_common_atomic(*args)
ret2 = cvt_ret(ret2)
np.testing.assert_almost_equal(ret1["polarizability"], ret2["polarizability"])

np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

# 4. test change bias
BaseAtomicModel.change_out_bias(
md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic"
Expand Down Expand Up @@ -256,3 +259,4 @@ def raise_error():
np.testing.assert_almost_equal(
ret3["polarizability"], expected_ret3["polarizability"], decimal=4
)
np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std)

0 comments on commit d7036b8

Please sign in to comment.