From e0877ee6d1fd778acf84cfd9cac4c640adb8fa8d Mon Sep 17 00:00:00 2001 From: plumbum082 <54125554+plumbum082@users.noreply.github.com> Date: Fri, 20 Oct 2023 09:36:25 +0800 Subject: [PATCH] Fix no axis type bug (#120) * add rules of local axis for NoAxisType * add test for NoAxisType --- dmff/admp/spatial.py | 8 +++++ tests/test_admp/test_noaxistype.py | 50 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 tests/test_admp/test_noaxistype.py diff --git a/dmff/admp/spatial.py b/dmff/admp/spatial.py index 8336e98c4..e11850741 100644 --- a/dmff/admp/spatial.py +++ b/dmff/admp/spatial.py @@ -72,6 +72,7 @@ def generate_construct_local_frames(axis_types, axis_indices): Bisector_filter = (axis_types == Bisector) ZBisect_filter = (axis_types == ZBisect) ThreeFold_filter = (axis_types == ThreeFold) + NoAxisType_filter = (axis_types == NoAxisType) def construct_local_frames(positions, box): ''' @@ -139,6 +140,13 @@ def construct_local_frames(positions, box): vec_x = normalize(vec_x - vec_z * xz_projection, axis=1) # up to this point, x-axis should be ready vec_y = jnp.cross(vec_z, vec_x) + + # NoAxisType + if np.sum(NoAxisType_filter) > 0: + vec_y = vec_y.at[NoAxisType_filter].set(jnp.array([0,1,0])) + vec_z = vec_z.at[NoAxisType_filter].set(jnp.array([0,0,1])) + vec_x = vec_x.at[NoAxisType_filter].set(jnp.array([1,0,0])) + return jnp.stack((vec_x, vec_y, vec_z), axis=1) diff --git a/tests/test_admp/test_noaxistype.py b/tests/test_admp/test_noaxistype.py new file mode 100644 index 000000000..39de50e93 --- /dev/null +++ b/tests/test_admp/test_noaxistype.py @@ -0,0 +1,50 @@ +import jax.numpy as jnp +import numpy as np +import numpy.testing as npt +import pytest +from dmff.admp.spatial import (build_quasi_internal, + generate_construct_local_frames, pbc_shift, + v_pbc_shift) + + +class TestSpatial: + + @pytest.mark.parametrize( + "axis_types, axis_indices, positions, box, expected_local_frames", + [ + ( + np.array([5]), + np.array( + [ + [-1, -1, -1], + ] + ), + jnp.array( + [ + [0.992, 0.068, -0.073], + ] + ), + jnp.array([[50.000, 0.0, 0.0], [0.0, 50.000, 0.0], [0.0, 0.0, 50.000]]), + np.array( + [ + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + ] + ), + ) + ], + ) + def test_generate_construct_local_frames( + self, axis_types, axis_indices, positions, box, expected_local_frames + ): + construct_local_frame_fn = generate_construct_local_frames( + axis_types, axis_indices + ) + assert construct_local_frame_fn + npt.assert_allclose( + construct_local_frame_fn(positions, box), expected_local_frames, rtol=1e-5 + ) +