diff --git a/src/hssm/likelihoods/analytical.py b/src/hssm/likelihoods/analytical.py
index d626d26f..860e8f07 100644
--- a/src/hssm/likelihoods/analytical.py
+++ b/src/hssm/likelihoods/analytical.py
@@ -18,6 +18,17 @@
 LOGP_LB = pm.floatX(-66.1)
 
+π = np.pi
+τ = 2 * π
+sqrt_τ = pt.sqrt(τ)
+log_π = pt.log(π)
+log_τ = pt.log(τ)
+log_4 = pt.log(4)
+
+
+def _max(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+    return pt.max(pt.stack([a, b]), axis=0)
+
 
 def k_small(rt: np.ndarray, err: float) -> np.ndarray:
     """Determine number of terms needed for small-t expansion.
@@ -34,9 +45,15 @@ def k_small(rt: np.ndarray, err: float) -> np.ndarray:
     np.ndarray
         A 1D at array of k_small.
     """
-    ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err))
-    ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0)
-    ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2)
+    sqrt_rt = pt.sqrt(rt)
+    log_rt = pt.log(rt)
+    log_2_sqrt_τ_rt_err_times_2 = log_4 + log_τ + log_rt + 2 * pt.log(err)
+
+    ks = 2 + pt.sqrt(-rt * log_2_sqrt_τ_rt_err_times_2)
+    ks = _max(ks, sqrt_rt + 1)
+
+    condition = 2 * sqrt_τ * sqrt_rt * err < 1
+    ks = pt.switch(condition, ks, 2)
 
     return ks
 
@@ -56,9 +73,16 @@ def k_large(rt: np.ndarray, err: float) -> np.ndarray:
     np.ndarray
         A 1D at array of k_large.
     """
-    kl = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt))
-    kl = pt.max(pt.stack([kl, 1.0 / (np.pi * pt.sqrt(rt))]), axis=0)
-    kl = pt.switch(np.pi * rt * err < 1, kl, 1.0 / (np.pi * pt.sqrt(rt)))
+    log_rt = pt.log(rt)
+    sqrt_rt = pt.sqrt(rt)
+    log_err = pt.log(err)
+
+    π_rt_err = π * rt * err
+    π_sqrt_rt = π * sqrt_rt
+
+    kl = pt.sqrt(-2 * (log_π + log_err + log_rt)) / π_sqrt_rt
+    kl = _max(kl, 1.0 / π_sqrt_rt)
+    kl = pt.switch(π_rt_err < 1, kl, 1.0 / π_sqrt_rt)
 
     return kl
 
@@ -141,7 +165,7 @@ def ftt01w_fast(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
     c = pt.max(r, axis=0)
     p = pt.exp(c) * pt.sum(y * pt.exp(r - c), axis=0)
     # Normalize p
-    p = p / pt.sqrt(2 * np.pi * pt.power(tt, 3))
+    p = p / pt.sqrt(2 * π * pt.power(tt, 3))
 
     return p
 
@@ -167,9 +191,9 @@ def ftt01w_slow(tt: np.ndarray, w: float, k_terms: int) -> np.ndarray:
         The approximated function f(tt|0, 1, w).
     """
     k = get_ks(k_terms, fast=False)
-    y = k * pt.sin(k * np.pi * w)
-    r = -pt.power(k, 2) * pt.power(np.pi, 2) * tt / 2
-    p = pt.sum(y * pt.exp(r), axis=0) * np.pi
+    y = k * pt.sin(k * π * w)
+    r = -pt.power(k, 2) * pt.power(π, 2) * tt / 2
+    p = pt.sum(y * pt.exp(r), axis=0) * π
 
     return p
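The two k-term rewrites above move the error term into log space. The identities can be checked numerically with a quick standalone sketch (plain NumPy standing in for `pt`; not part of the patch, names are illustrative):

```python
import numpy as np

rt = np.array([0.1, 0.5, 1.0, 2.5])
err = 1e-7
tau = 2 * np.pi

# k_small: -2 * rt * log(2 * sqrt(2 * pi * rt) * err)
#       == -rt * (log 4 + log tau + log rt + 2 * log err)
ks_old = 2 + np.sqrt(-2 * rt * np.log(2 * np.sqrt(2 * np.pi * rt) * err))
ks_new = 2 + np.sqrt(-rt * (np.log(4) + np.log(tau) + np.log(rt) + 2 * np.log(err)))
assert np.allclose(ks_old, ks_new)

# k_large: sqrt(-2 * log(pi * rt * err) / (pi**2 * rt))
#       == sqrt(-2 * (log pi + log err + log rt)) / (pi * sqrt(rt))
kl_old = np.sqrt(-2 * np.log(np.pi * rt * err) / (np.pi**2 * rt))
kl_new = np.sqrt(-2 * (np.log(np.pi) + np.log(err) + np.log(rt))) / (np.pi * np.sqrt(rt))
assert np.allclose(kl_old, kl_new)
```

Both identities follow from log(abc) = log a + log b + log c and sqrt(x / y**2) = sqrt(x) / y.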
""" k = get_ks(k_terms, fast=False) - y = k * pt.sin(k * np.pi * w) - r = -pt.power(k, 2) * pt.power(np.pi, 2) * tt / 2 - p = pt.sum(y * pt.exp(r), axis=0) * np.pi + y = k * pt.sin(k * π * w) + r = -pt.power(k, 2) * pt.power(π, 2) * tt / 2 + p = pt.sum(y * pt.exp(r), axis=0) * π return p diff --git a/tests/test_likelihoods_lba.py b/tests/test_likelihoods_lba.py index 1508b095..4cbef88a 100644 --- a/tests/test_likelihoods_lba.py +++ b/tests/test_likelihoods_lba.py @@ -1,143 +1,82 @@ """Unit testing for LBA likelihood functions.""" -from pathlib import Path from itertools import product import numpy as np -import pandas as pd import pymc as pm -import pytensor -import pytensor.tensor as pt import pytest -import arviz as az -from pytensor.compile.nanguardmode import NanGuardMode import hssm # pylint: disable=C0413 from hssm.likelihoods.analytical import logp_lba2, logp_lba3 -from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox -from hssm.distribution_utils import make_likelihood_callable hssm.set_floatX("float32") CLOSE_TOLERANCE = 1e-4 -def test_lba2_basic(): - size = 1000 +def filter_theta(theta, exclude_keys=["A", "b"]): + """Filter out specific keys from the theta dictionary.""" + return {k: v for k, v in theta.items() if k not in exclude_keys} - lba_data_out = hssm.simulate_data( - model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=size - ) - - # Test if vectorization ok across parameters - out_A_vec = logp_lba2( - lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0 - ).eval() - out_base = logp_lba2(lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0).eval() - assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE) - - out_b_vec = logp_lba2( - lba_data_out.values, - A=np.array([0.2] * size), - b=np.array([0.5] * size), - v0=1.0, - v1=1.0, - ).eval() - assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE) - - out_v_vec = logp_lba2( - lba_data_out.values, - A=np.array([0.2] * size), - b=np.array([0.5] * size), - v0=np.array([1.0] * size), - v1=np.array([1.0] * size), - ).eval() - assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE) - # Test A > b leads to error +def assert_parameter_value_error(logp_func, lba_data_out, A, b, theta): + """Helper function to assert ParameterValueError for given parameters.""" with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2( - lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0 + logp_func( + lba_data_out.values, + A=A, + b=b, + **filter_theta(theta, ["A", "b"]), ).eval() - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2(lba_data_out.values, A=0.6, b=0.5, v0=1.0, v1=1.0).eval() - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2( - lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0 - ).eval() +def vectorize_param(theta, param, size): + """ + Vectorize a specific parameter in the theta dictionary. - with pytest.raises(pm.logprob.utils.ParameterValueError): - logp_lba2( - lba_data_out.values, - A=np.array([0.6] * 1000), - b=np.array([0.5] * 1000), - v0=1.0, - v1=1.0, - ).eval() + Parameters: + theta (dict): Dictionary of parameters. + param (str): The parameter to vectorize. + size (int): The size of the vector. + Returns: + dict: A new dictionary with the specified parameter vectorized. 
 
-def test_lba3_basic():
-    size = 1000
+    Examples:
+    >>> theta = {"A": 0.2, "b": 0.5, "v0": 1.0, "v1": 1.0}
+    >>> vectorize_param(theta, "A", 3)
+    {'A': array([0.2, 0.2, 0.2]), 'b': 0.5, 'v0': 1.0, 'v1': 1.0}
 
-    lba_data_out = hssm.simulate_data(
-        model="lba3", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0), size=size
-    )
-
-    # Test if vectorization ok across parameters
-    out_A_vec = logp_lba3(
-        lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0, v2=1.0
-    ).eval()
-
-    out_base = logp_lba3(
-        lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0
-    ).eval()
-
-    assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE)
-
-    out_b_vec = logp_lba3(
-        lba_data_out.values,
-        A=np.array([0.2] * size),
-        b=np.array([0.5] * size),
-        v0=1.0,
-        v1=1.0,
-        v2=1.0,
-    ).eval()
-    assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE)
-
-    out_v_vec = logp_lba3(
-        lba_data_out.values,
-        A=np.array([0.2] * size),
-        b=np.array([0.5] * size),
-        v0=np.array([1.0] * size),
-        v1=np.array([1.0] * size),
-        v2=np.array([1.0] * size),
-    ).eval()
-    assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE)
+    >>> vectorize_param(theta, "v0", 2)
+    {'A': 0.2, 'b': 0.5, 'v0': array([1., 1.]), 'v1': 1.0}
+    """
+    return {k: (np.full(size, v) if k == param else v) for k, v in theta.items()}
 
-    # Test A > b leads to error
-    with pytest.raises(pm.logprob.utils.ParameterValueError):
-        logp_lba3(
-            lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0, v2=1.0
-        ).eval()
 
-    with pytest.raises(pm.logprob.utils.ParameterValueError):
-        logp_lba3(lba_data_out.values, b=0.5, A=0.6, v0=1.0, v1=1.0, v2=1.0).eval()
+theta_lba2 = dict(A=0.2, b=0.5, v0=1.0, v1=1.0)
+theta_lba3 = theta_lba2 | {"v2": 1.0}
 
-    with pytest.raises(pm.logprob.utils.ParameterValueError):
-        logp_lba3(
-            lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0, v2=1.0
-        ).eval()
 
-    with pytest.raises(pm.logprob.utils.ParameterValueError):
-        logp_lba3(
-            lba_data_out.values,
-            A=np.array([0.6] * 1000),
-            b=np.array([0.5] * 1000),
-            v0=1.0,
-            v1=1.0,
-            v2=1.0,
-        ).eval()
+@pytest.mark.parametrize(
+    "logp_func, model, theta",
+    [(logp_lba2, "lba2", theta_lba2), (logp_lba3, "lba3", theta_lba3)],
+)
+def test_lba(logp_func, model, theta):
+    size = 1000
+    lba_data_out = hssm.simulate_data(model=model, theta=theta, size=size)
+
+    # Test if vectorization is ok across parameters
+    for param in theta:
+        param_vec = vectorize_param(theta, param, size)
+        out_vec = logp_func(lba_data_out.values, **param_vec).eval()
+        out_base = logp_func(lba_data_out.values, **theta).eval()
+        assert np.allclose(out_vec, out_base, atol=CLOSE_TOLERANCE)
+
+    # Test A > b leads to error
+    A_values = [np.full(size, 0.6), 0.6]
+    b_values = [np.full(size, 0.5), 0.5]
+
+    for A, b in product(A_values, b_values):
+        assert_parameter_value_error(logp_func, lba_data_out, A, b, theta)
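The `product` grid at the end of `test_lba` reproduces the four scalar/vector `(A, b)` combinations that the removed per-model tests wrote out as separate `pytest.raises` blocks. A minimal standalone sketch of that expansion (illustrative only, no hssm required):

```python
from itertools import product

import numpy as np

size = 1000
A_values = [np.full(size, 0.6), 0.6]  # vector and scalar A, each violating A < b
b_values = [np.full(size, 0.5), 0.5]  # vector and scalar b

# product() yields vec/vec, vec/scalar, scalar/vec, scalar/scalar in turn
for A, b in product(A_values, b_values):
    print(np.shape(A), np.shape(b))
# (1000,) (1000,)
# (1000,) ()
# () (1000,)
# () ()
```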