Skip to content

Commit

Permalink
Add function that caches sampling results
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 4, 2023
1 parent 150fb0f commit 9f0d5c7
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Utils

spline.bspline_interpolation
prior.prior_from_idata
cache.cache_sampling


Statespace Models
Expand Down
44 changes: 44 additions & 0 deletions pymc_experimental/tests/utils/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os

import pymc as pm

from pymc_experimental.utils.cache import cache_sampling


def test_cache_sampling(tmpdir):

with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", mu=x, observed=[0, 1, 2])

cache_prior = cache_sampling(pm.sample_prior_predictive, path=tmpdir)
cache_post = cache_sampling(pm.sample, path=tmpdir)
cache_pred = cache_sampling(pm.sample_posterior_predictive, path=tmpdir)
assert len(os.listdir(tmpdir)) == 0

prior1, prior2 = (cache_prior(samples=5) for _ in range(2))
assert len(os.listdir(tmpdir)) == 1
assert prior1.prior["x"].mean() == prior2.prior["x"].mean()

post1, post2 = (cache_post(tune=5, draws=5, progressbar=False) for _ in range(2))
assert len(os.listdir(tmpdir)) == 2
assert post1.posterior["x"].mean() == post2.posterior["x"].mean()

# Change model
with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", mu=x, observed=[0, 1, 2, 3])

post3 = cache_post(tune=5, draws=5, progressbar=False)
assert len(os.listdir(tmpdir)) == 3
assert post3.posterior["x"].mean() != post1.posterior["x"].mean()

pred1, pred2 = (cache_pred(trace=post3, progressbar=False) for _ in range(2))
assert len(os.listdir(tmpdir)) == 4
assert pred1.posterior_predictive["y"].mean() == pred2.posterior_predictive["y"].mean()
assert "x" not in pred1.posterior_predictive

# Change kwargs
pred3 = cache_pred(trace=post3, progressbar=False, var_names=["x"])
assert len(os.listdir(tmpdir)) == 5
assert "x" in pred3.posterior_predictive
169 changes: 169 additions & 0 deletions pymc_experimental/utils/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import hashlib
import os
import sys
from typing import Literal

import arviz as az
import numpy as np
from pymc import (
modelcontext,
sample,
sample_posterior_predictive,
sample_prior_predictive,
)
from pymc.model.fgraph import fgraph_from_model
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph
from pytensor.scalar import ScalarType
from pytensor.tensor import TensorType
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.type_other import NoneTypeT


def hash_data(c):
if isinstance(c.type, NoneTypeT):
return ""
if isinstance(c.type, (ScalarType, TensorType)):
if isinstance(c, Constant):
arr = c.data
elif isinstance(c, SharedVariable):
arr = c.get_value(borrow=True)
arr_data = arr.view(np.uint8) if arr.size > 1 else arr.tobytes()
return hashlib.sha1(arr_data).hexdigest()
else:
raise NotImplementedError(f"Hashing not implemented for type {c.type}")


def get_name_and_props(obj):
name = str(obj)
props = str(getattr(obj, "_props", lambda: {})())
return name, props


def hash_from_fg(fg: FunctionGraph) -> int:
objects_to_hash = []
for node in fg.toposort():
objects_to_hash.append(
(
get_name_and_props(node.op),
tuple(get_name_and_props(inp.type) for inp in node.inputs),
tuple(get_name_and_props(out.type) for out in node.outputs),
# Name is not a symbolic input in the fgraph representation, maybe it should?
tuple(inp.name for inp in node.inputs if inp.name),
tuple(out.name for out in node.outputs if out.name),
)
)
objects_to_hash.append(
tuple(
hash_data(c)
for c in node.inputs
if (
isinstance(c, (Constant, SharedVariable))
# Ignore RNG values
and not isinstance(c.type, RandomType)
)
)
)
str_hash = "\n".join(map(str, objects_to_hash))
return hashlib.sha1(str_hash.encode()).hexdigest()


def cache_sampling(
sampling_fn: Literal[sample, sample_prior_predictive, sample_posterior_predictive],
path: str = "",
force_sample: bool = False,
):
"""Cache the result of PyMC sampling.
Parameter
---------
sampling_fn: Callable
Must be one of `pymc.sample`, `pymc.sample_prior_predictive` or `pymc.sample_posterior_predictive`.
Positional arguments are disallowed.
path: string, Optional
The path where the results should be saved or retrieved from. Defaults to working directory.
force_sample: bool, Optional
Whether to force sampling even if cache is found. Defaults to False.
Returns
-------
cached_sampling_fn: Callable
Function that wraps the sampling_fn. When called, the wrapped function will look for a valid cached result.
A valid cache requires the same:
1. Model and data
2. Sampling function
3. Sampling kwargs, ignoring ``random_seed``, ``trace``, ``progressbar``, ``extend_inferencedata`` and ``compile_kwargs``.
If o valid cache is found, sampling is bypassed altogether, unless ``force_sample=True``.
Otherwise, sampling is performed and the result cached for future reuse.
Caching is done on the basis of SHA-1 hashing, and there could be unlikely false positives.
Examples
--------
.. code-block:: python
import pymc as pm
from pymc_experimental.utils.cache import cache_sampling
with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", mu=x, observed=[0, 1, 2])
cache_sample = cache_sampling(pm.sample, path="data")
idata1 = cache_sample(chains=2)
# Cache hit! Returning stored result
idata2 = cache_sample(chains=2)
assert idata1 == idata2
"""
allowed_fns = (sample, sample_prior_predictive, sample_posterior_predictive)
if sampling_fn not in allowed_fns:
raise ValueError(f"Cache sampling can only be used with {allowed_fns}")

def wrapped_sampling_fn(*args, model=None, random_seed=None, **kwargs):
if args:
raise ValueError("Non-keyword arguments not allowed in cache_sampling")

extend_inferencedata = kwargs.pop("extend_inferencedata", False)

# Model hash
model = modelcontext(model)
fg, _ = fgraph_from_model(model)
model_hash = hash_from_fg(fg)

# Sampling hash
sampling_hash_kwargs = kwargs.copy()
sampling_hash_kwargs["sampling_fn"] = str(sampling_fn)
sampling_hash_kwargs.pop("trace", None)
sampling_hash_kwargs.pop("random_seed", None)
sampling_hash_kwargs.pop("progressbar", None)
sampling_hash_kwargs.pop("compile_kwargs", None)
sampling_hash = str(sampling_hash_kwargs)

file_name = hashlib.sha1((model_hash + sampling_hash).encode()).hexdigest() + ".nc"
file_path = os.path.join(path, file_name)

if not force_sample and os.path.exists(file_path):
print("Cache hit! Returning stored result", file=sys.stdout)
idata_out = az.from_netcdf(file_path)

else:
idata_out = sampling_fn(*args, **kwargs, model=model, random_seed=random_seed)
if os.path.exists(file_path):
os.remove(file_path)
if not os.path.exists(path):
os.mkdir(path)
az.to_netcdf(idata_out, file_path)

# We save inferencedata separately and extend if needed
if extend_inferencedata:
trace = kwargs["trace"]
trace.extend(idata_out)
idata_out = trace

return idata_out

return wrapped_sampling_fn

0 comments on commit 9f0d5c7

Please sign in to comment.