Skip to content

Commit

Permalink
add tests for criticaldamped
Browse files Browse the repository at this point in the history
  • Loading branch information
emptymalei committed Feb 20, 2024
1 parent 4a2cfee commit 1c2b202
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
24 changes: 22 additions & 2 deletions hamiltonian_flow/models/harmonic_oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,28 @@ def _x_under_damped(self, t: Union[float, np.ndarray]) -> Union[float, np.ndarra
* np.sin(omega_damp * t)
) * np.exp(-self.system.zeta * self.system.omega * t)

def _x_critical_damped(
self, t: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
r"""Solution to critical damped harmonic oscillators:
$$
x(t) = \left( x_0 \cos(\Omega t) + \frac{\zeta \omega x_0 + v_0}{\Omega} \sin(\Omega t) \right)
e^{-\zeta \omega t},
$$
where
$$
\Omega = \omega\sqrt{ 1 - \zeta^2}.
$$
"""
return self.initial_condition.x0 * np.exp(
-self.system.zeta * self.system.omega * t
)

def _x_over_damped(self, t: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
r"""Solution to under over harmonic oscillators:
r"""Solution to over harmonic oscillators:
$$
x(t) = \left( x_0 \cosh(\Gamma t) + \frac{\zeta \omega x_0 + v_0}{\Gamma} \sinh(\Gamma t) \right)
Expand Down Expand Up @@ -222,7 +242,7 @@ def __call__(self, n_periods: int, n_samples_per_period: int) -> pd.DataFrame:
elif self.system.type == "over_damped":
data = self._x_over_damped(time_steps)
elif self.system.type == "critical_damped":
data = self._x_under_damped(time_steps)
data = self._x_critical_damped(time_steps)
else:
raise ValueError(f"system type is not defined: {self.system.type}")

Expand Down
29 changes: 29 additions & 0 deletions tests/test_models/test_harmonic_oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,32 @@ def test_overdamped_harmonic_oscillator(omega, zeta, expected):
df = ho(n_periods=1, n_samples_per_period=10)

pd.testing.assert_frame_equal(df, pd.DataFrame(expected))


@pytest.mark.parametrize(
"omega,zeta,expected",
[
(
0.5,
1.0,
[
{"t": 0.0, "x": 1.0},
{"t": 1.2566370614359172, "x": 0.5334880910911033},
{"t": 2.5132741228718345, "x": 0.2846095433360293},
{"t": 3.7699111843077517, "x": 0.1518358019806489},
{"t": 5.026548245743669, "x": 0.08100259215794314},
{"t": 6.283185307179586, "x": 0.04321391826377226},
{"t": 7.5398223686155035, "x": 0.023054110763106823},
{"t": 8.79645943005142, "x": 0.012299093542812717},
{"t": 10.053096491487338, "x": 0.006561419936306071},
{"t": 11.309733552923255, "x": 0.003500439396667034},
],
),
],
)
def test_criticaldamped_harmonic_oscillator(omega, zeta, expected):
ho = HarmonicOscillator(system={"omega": omega, "zeta": zeta})

df = ho(n_periods=1, n_samples_per_period=10)

pd.testing.assert_frame_equal(df, pd.DataFrame(expected))

0 comments on commit 1c2b202

Please sign in to comment.