diff --git a/source/tests/pt/model/test_linear_atomic_model_stat.py b/source/tests/pt/model/test_linear_atomic_model_stat.py index d8137c17c0..f3fe4f63d4 100644 --- a/source/tests/pt/model/test_linear_atomic_model_stat.py +++ b/source/tests/pt/model/test_linear_atomic_model_stat.py @@ -80,7 +80,7 @@ def forward( [4.0, 5.0, 6.0], ] ) - .view([nf, nloc] + self.output_def()["energy"].shape) + .view([nf, nloc, *self.output_def()["energy"].shape]) .to(env.GLOBAL_PT_FLOAT_PRECISION) .to(env.DEVICE) ) @@ -124,7 +124,7 @@ def forward( [10.0, 11.0, 12.0], ] ) - .view([nf, nloc] + self.output_def()["energy"].shape) + .view([nf, nloc, *self.output_def()["energy"].shape]) .to(env.GLOBAL_PT_FLOAT_PRECISION) .to(env.DEVICE) ) @@ -206,8 +206,8 @@ def test_linear_atomic_model_stat_with_bias(self): [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], ] - ).reshape([nf, nloc] + linear_model.fitting_output_def()["energy"].shape) - + ).reshape(nf, nloc, *linear_model.fitting_output_def()["energy"].shape) + np.testing.assert_almost_equal(ret0, expected_ret0) # 2. test bias is applied