fix: suppress buggy output from pytensor.function
cpaniaguam committed Oct 18, 2024
1 parent fa12018 commit 929d2b8
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions tests/test_likelihoods.py
@@ -4,6 +4,9 @@
 old implementation of WFPT from (https://github.com/hddm-devs/hddm)
 """

+import contextlib
+import logging
+import os
 from pathlib import Path
 from itertools import product

@@ -25,6 +28,25 @@

 hssm.set_floatX("float32")

+
+# Temporary measure to suppress output from pytensor.function
+# See issues #594 in hssm and #1037 in pymc-devs/pytensor repos
+class SuppressOutput:
+    def __enter__(self):
+        self._null_file = open(os.devnull, "w")
+        self._stdout_context = contextlib.redirect_stdout(self._null_file)
+        self._stderr_context = contextlib.redirect_stderr(self._null_file)
+        self._stdout_context.__enter__()
+        self._stderr_context.__enter__()
+        logging.disable(logging.CRITICAL)  # Disable logging
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._stdout_context.__exit__(exc_type, exc_value, traceback)
+        self._stderr_context.__exit__(exc_type, exc_value, traceback)
+        self._null_file.close()
+        logging.disable(logging.NOTSET)  # Re-enable logging
+
+
 # def test_logp(data_fixture):
 #     """
 #     This function compares new and old implementation of logp calculation
@@ -109,11 +131,13 @@ def test_analytical_gradient():
     size = cav_data_numpy.shape[0]
     logp = logp_ddm(cav_data_numpy, v, a, z, t).sum()
     grad = pt.grad(logp, wrt=[v, a, z, t])
-    grad_func = pytensor.function(
-        [v, a, z, t],
-        grad,
-        mode=nan_guard_mode,
-    )
+
+    with SuppressOutput():
+        grad_func = pytensor.function(
+            [v, a, z, t],
+            grad,
+            mode=nan_guard_mode,
+        )
     v_test = np.random.normal(size=size)
     a_test = np.random.uniform(0.0001, 2, size=size)
     z_test = np.random.uniform(0.1, 1.0, size=size)
@@ -123,13 +147,14 @@ def test_analytical_gradient():

     assert np.all(np.isfinite(grad), axis=None), "Gradient contains non-finite values."

-    grad_func_sdv = pytensor.function(
-        [v, a, z, t, sv],
-        pt.grad(
-            logp_ddm_sdv(cav_data_numpy, v, a, z, t, sv).sum(), wrt=[v, a, z, t, sv]
-        ),
-        mode=nan_guard_mode,
-    )
+    with SuppressOutput():
+        grad_func_sdv = pytensor.function(
+            [v, a, z, t, sv],
+            pt.grad(
+                logp_ddm_sdv(cav_data_numpy, v, a, z, t, sv).sum(), wrt=[v, a, z, t, sv]
+            ),
+            mode=nan_guard_mode,
+        )

     grad_sdv = np.array(grad_func_sdv(v_test, a_test, z_test, t_test, sv_test))

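For context, the suppression pattern added in this commit can be exercised on its own. Below is a minimal, self-contained sketch of the same technique; the `noisy` helper and the "pytensor" logger name are hypothetical stand-ins for the compile-time chatter that `pytensor.function` emits, not part of this commit:

import contextlib
import logging
import os


class SuppressOutput:
    """Silence stdout, stderr, and logging for the duration of a block."""

    def __enter__(self):
        self._null_file = open(os.devnull, "w")
        self._stdout_context = contextlib.redirect_stdout(self._null_file)
        self._stderr_context = contextlib.redirect_stderr(self._null_file)
        self._stdout_context.__enter__()
        self._stderr_context.__enter__()
        logging.disable(logging.CRITICAL)  # Drop all log records

    def __exit__(self, exc_type, exc_value, traceback):
        self._stdout_context.__exit__(exc_type, exc_value, traceback)
        self._stderr_context.__exit__(exc_type, exc_value, traceback)
        self._null_file.close()
        logging.disable(logging.NOTSET)  # Restore logging


def noisy():
    # Hypothetical stand-in for the output pytensor.function produces
    print("compiler chatter")
    logging.getLogger("pytensor").warning("spurious warning")


with SuppressOutput():
    noisy()  # prints and logs nothing

print("stdout and logging are restored here")

The same effect could be had by stacking `contextlib.redirect_stdout` and `contextlib.redirect_stderr` directly in each `with` statement; wrapping all three suppressions in one class keeps the call sites in the tests to a single line.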
