Skip to content

Commit

Permalink
Refactor LBA tests to improve parameter handling and add vectorizatio…
Browse files Browse the repository at this point in the history
…n helpers
  • Loading branch information
cpaniaguam committed Sep 26, 2024
1 parent c9c590f commit c56d654
Showing 1 changed file with 50 additions and 111 deletions.
161 changes: 50 additions & 111 deletions tests/test_likelihoods_lba.py
Original file line number Diff line number Diff line change
@@ -1,143 +1,82 @@
"""Unit testing for LBA likelihood functions."""

from pathlib import Path
from itertools import product

import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
import pytest
import arviz as az
from pytensor.compile.nanguardmode import NanGuardMode

import hssm

# pylint: disable=C0413
from hssm.likelihoods.analytical import logp_lba2, logp_lba3
from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
from hssm.distribution_utils import make_likelihood_callable

hssm.set_floatX("float32")

CLOSE_TOLERANCE = 1e-4


def test_lba2_basic():
size = 1000
def filter_theta(theta, exclude_keys=["A", "b"]):
"""Filter out specific keys from the theta dictionary."""
return {k: v for k, v in theta.items() if k not in exclude_keys}

lba_data_out = hssm.simulate_data(
model="lba2", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0), size=size
)

# Test if vectorization ok across parameters
out_A_vec = logp_lba2(
lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0
).eval()
out_base = logp_lba2(lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0).eval()
assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE)

out_b_vec = logp_lba2(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=1.0,
v1=1.0,
).eval()
assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE)

out_v_vec = logp_lba2(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=np.array([1.0] * size),
v1=np.array([1.0] * size),
).eval()
assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE)

# Test A > b leads to error
def assert_parameter_value_error(logp_func, lba_data_out, A, b, theta):
"""Helper function to assert ParameterValueError for given parameters."""
with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba2(
lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0
logp_func(
lba_data_out.values,
A=A,
b=b,
**filter_theta(theta, ["A", "b"]),
).eval()

with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba2(lba_data_out.values, A=0.6, b=0.5, v0=1.0, v1=1.0).eval()

with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba2(
lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0
).eval()
def vectorize_param(theta, param, size):
"""
Vectorize a specific parameter in the theta dictionary.
with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba2(
lba_data_out.values,
A=np.array([0.6] * 1000),
b=np.array([0.5] * 1000),
v0=1.0,
v1=1.0,
).eval()
Parameters:
theta (dict): Dictionary of parameters.
param (str): The parameter to vectorize.
size (int): The size of the vector.
Returns:
dict: A new dictionary with the specified parameter vectorized.
def test_lba3_basic():
size = 1000
Examples:
>>> theta = {"A": 0.2, "b": 0.5, "v0": 1.0, "v1": 1.0}
>>> vectorize_param(theta, "A", 3)
{'A': array([0.2, 0.2, 0.2]), 'b': 0.5, 'v0': 1.0, 'v1': 1.0}
lba_data_out = hssm.simulate_data(
model="lba3", theta=dict(A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0), size=size
)

# Test if vectorization ok across parameters
out_A_vec = logp_lba3(
lba_data_out.values, A=np.array([0.2] * size), b=0.5, v0=1.0, v1=1.0, v2=1.0
).eval()

out_base = logp_lba3(
lba_data_out.values, A=0.2, b=0.5, v0=1.0, v1=1.0, v2=1.0
).eval()

assert np.allclose(out_A_vec, out_base, atol=CLOSE_TOLERANCE)

out_b_vec = logp_lba3(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=1.0,
v1=1.0,
v2=1.0,
).eval()
assert np.allclose(out_b_vec, out_base, atol=CLOSE_TOLERANCE)

out_v_vec = logp_lba3(
lba_data_out.values,
A=np.array([0.2] * size),
b=np.array([0.5] * size),
v0=np.array([1.0] * size),
v1=np.array([1.0] * size),
v2=np.array([1.0] * size),
).eval()
assert np.allclose(out_v_vec, out_base, atol=CLOSE_TOLERANCE)
>>> vectorize_param(theta, "v0", 2)
{'A': 0.2, 'b': 0.5, 'v0': array([1., 1.]), 'v1': 1.0}
"""
return {k: (np.full(size, v) if k == param else v) for k, v in theta.items()}

# Test A > b leads to error
with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba3(
lba_data_out.values, A=np.array([0.6] * 1000), b=0.5, v0=1.0, v1=1.0, v2=1.0
).eval()

with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba3(lba_data_out.values, b=0.5, A=0.6, v0=1.0, v1=1.0, v2=1.0).eval()
theta_lba2 = dict(A=0.2, b=0.5, v0=1.0, v1=1.0)
theta_lba3 = theta_lba2 | {"v2": 1.0}

with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba3(
lba_data_out.values, A=0.6, b=np.array([0.5] * 1000), v0=1.0, v1=1.0, v2=1.0
).eval()

with pytest.raises(pm.logprob.utils.ParameterValueError):
logp_lba3(
lba_data_out.values,
A=np.array([0.6] * 1000),
b=np.array([0.5] * 1000),
v0=1.0,
v1=1.0,
v2=1.0,
).eval()
@pytest.mark.parametrize(
"logp_func, model, theta",
[(logp_lba2, "lba2", theta_lba2), (logp_lba3, "lba3", theta_lba3)],
)
def test_lba(logp_func, model, theta):
size = 1000
lba_data_out = hssm.simulate_data(model=model, theta=theta, size=size)

# Test if vectorization is ok across parameters
for param in theta:
param_vec = vectorize_param(theta, param, size)
out_vec = logp_func(lba_data_out.values, **param_vec).eval()
out_base = logp_func(lba_data_out.values, **theta).eval()
assert np.allclose(out_vec, out_base, atol=CLOSE_TOLERANCE)

# Test A > b leads to error
A_values = [np.full(size, 0.6), 0.6]
b_values = [np.full(size, 0.5), 0.5]

for A, b in product(A_values, b_values):
assert_parameter_value_error(logp_func, lba_data_out, A, b, theta)

0 comments on commit c56d654

Please sign in to comment.