Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function that caches sampling results #277

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Dec 4, 2023

import pymc as pm
from pymc_experimental.utils.cache import cache_sampling

with pm.Model() as m:
    y_data = pm.MutableData("y_data", [0, 1, 2])
    x = pm.Normal("x", 0, 1)
    y = pm.Normal("y", mu=x, observed=y_data)

    cache_sample = cache_sampling(pm.sample, dir="traces")
    idata1 = cache_sample(chains=2)

    # Cache hit! Returning stored result
    idata2 = cache_sample(chains=2)

    pm.set_data({"y_data": [1, 1, 1]})
    idata3 = cache_sample(chains=2)

assert idata1.posterior["x"].mean() == idata2.posterior["x"].mean()
assert idata1.posterior["x"].mean() != idata3.posterior["x"].mean()

@ricardoV94 ricardoV94 added the enhancements New feature or request label Dec 4, 2023
@twiecki
Copy link
Member

twiecki commented Dec 4, 2023

When would that be useful?

@ricardoV94
Copy link
Member Author

ricardoV94 commented Dec 4, 2023

When rerunning notebooks or any workflow with saving/loading of traces where you might still be tinkering with the model.

You don't need to bother defining the names of the traces, or overriding old traces, since caching is automatically derived from the model and its data

@ricardoV94 ricardoV94 force-pushed the cache_sampling branch 3 times, most recently from df085cc to 5b20bdf Compare December 4, 2023 18:35
return name, props


def hash_from_fg(fg: FunctionGraph) -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe to pytensor?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too experimental for now

@fonnesbeck
Copy link
Member

I usually rely on things like MLFlow for storing artifacts like this.

@ricardoV94
Copy link
Member Author

I'm not familiar with MLflow, the idea here is that it pairs the saved traces to the exact model/sampling function (and arguments) that were used.

Basically the model and the function kwargs are the cache key.

Does this have any parallel to your workflow with MLflow?

Comment on lines +42 to +43
name = str(obj)
props = str(getattr(obj, "_props", lambda: {})())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
name = str(obj)
props = str(getattr(obj, "_props", lambda: {})())
name = str(obj)
if hasattr(obj, "_props"):
prop_dict = obj._prop_dict()
props = str(
{k: get_name_and_props(v) for k, v in prop_dict.items()}
)
else:
props = str({})
name = str(obj)
props = str(getattr(obj, "_props", lambda: {})())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just to make sure that potential recursion of _props is handled.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_props are not recursive

if os.path.exists(file_path):
os.remove(file_path)
if not os.path.exists(dir):
os.mkdir(dir)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
os.mkdir(dir)
os.makedirs(dir, exist_ok=True)

I think that it's better to use os.makedirs because it creates the intermediate directories if they are required.

az.to_netcdf(idata_out, file_path)

# We save inferencedata separately and extend if needed
if extend_inferencedata:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird to me. The sampling_fn will go make up the hash. So if someone first calls sample, and then sample_posterior_predictive using extend, the second cache will include both the posterior and posterior_predictive groups; but the first cache will only include the posterior group. I think that it's cleaner to never allow for extending the idata inplace, and force users to combine the different InferenceData objects themselves.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

Copy link
Member Author

@ricardoV94 ricardoV94 Dec 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait no, I don't see what the problem is. We only ever save new idatas coming out of the sampling_fn and these never extend the previous one.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants