Skip to content

Commit

Permalink
Fix pathfinder wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 1, 2023
1 parent b08610c commit c8c601d
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 66 deletions.
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.8.1 # CI was failing to resolve
- pymc>=5.9.0 # CI was failing to resolve
- blackjax
- scikit-learn
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.8.1 # CI was failing to resolve
- pymc>=5.9.0 # CI was failing to resolve
- scikit-learn
5 changes: 4 additions & 1 deletion pymc_experimental/inference/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def fit(method, **kwargs):
"""
if method == "pathfinder":
try:
from pymc_experimental.inference.pathfinder import fit_pathfinder
import blackjax
except ImportError as exc:
raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc

from pymc_experimental.inference.pathfinder import fit_pathfinder

return fit_pathfinder(**kwargs)
85 changes: 36 additions & 49 deletions pymc_experimental/inference/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import sys
from typing import Optional

import arviz as az
import blackjax
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import pymc as pm
from pymc import modelcontext
from packaging import version
from pymc.backends.arviz import coords_and_dims_for_inferencedata
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.model import modelcontext
from pymc.sampling.jax import get_jaxified_graph
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames


def convert_flat_trace_to_idata(
samples,
dims=None,
coords=None,
include_transformed=False,
postprocessing_backend="cpu",
model=None,
):

model = modelcontext(model)
init_position_dict = model.initial_point()
ip = model.initial_point()
ip_point_map_info = pm.blocking.DictToArrayBijection.map(ip).point_map_info
trace = collections.defaultdict(list)
astart = pm.blocking.DictToArrayBijection.map(init_position_dict)
for sample in samples:
raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info)
point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict)
raveld_vars = RaveledVars(sample, ip_point_map_info)
point = DictToArrayBijection.rmap(raveld_vars, ip)
for p, v in point.items():
trace[p].append(v.tolist())

Expand All @@ -57,19 +55,19 @@ def convert_flat_trace_to_idata(
result = jax.vmap(jax.vmap(jax_fn))(
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
)

trace = {v.name: r for v, r in zip(vars_to_sample, result)}
coords, dims = coords_and_dims_for_inferencedata(model)
idata = az.from_dict(trace, dims=dims, coords=coords)

return idata


def fit_pathfinder(
iterations=5_000,
samples=1000,
random_seed: Optional[RandomSeed] = None,
postprocessing_backend="cpu",
ftol=1e-4,
model=None,
**pathfinder_kwargs,
):
"""
Fit the pathfinder algorithm as implemented in blackjax
Expand All @@ -78,15 +76,15 @@ def fit_pathfinder(
Parameters
----------
iterations : int
Number of iterations to run.
samples : int
Number of samples to draw from the fitted approximation.
random_seed : int
Random seed to set.
postprocessing_backend : str
Where to compute transformations of the trace.
"cpu" or "gpu".
ftol : float
Floating point tolerance
pathfinder_kwargs:
kwargs for blackjax.vi.pathfinder.approximate
Returns
-------
Expand All @@ -96,53 +94,42 @@ def fit_pathfinder(
---------
https://arxiv.org/abs/2108.03782
"""

(random_seed,) = _get_seeds_per_chain(random_seed, 1)
# Temporary helper: reject blackjax versions older than 1.0
if version.parse(blackjax.__version__).major < 1:
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")

model = modelcontext(model)

rvs = [rv.name for rv in model.value_vars]
init_position_dict = model.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]
ip = model.initial_point()
ip_map = DictToArrayBijection.map(ip)

new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
init_position_dict, (model.logp(),), model.value_vars, ()
ip, (model.logp(),), model.value_vars, ()
)

logprob_fn_list = get_jaxified_graph([new_input], new_logprob)

def logprob_fn(x):
return logprob_fn_list(x)[0]

dim = sum(v.size for v in init_position_dict.values())

rng_key = random.PRNGKey(random_seed)
w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim))
path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol)

pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol)
state = pathfinder.init(w0)

def inference_loop(rng_key, kernel, initial_state, num_samples):
@jax.jit
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)
[pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 2)

keys = jax.random.split(rng_key, num_samples)
return jax.lax.scan(one_step, initial_state, keys)

_, rng_key = random.split(rng_key)
print("Running pathfinder...", file=sys.stdout)
_, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations)

dims = {
var_name: [dim for dim in dims if dim is not None]
for var_name, dims in model.named_vars_to_dims.items()
}
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
rng_key=jax.random.key(pathfinder_seed),
logdensity_fn=logprob_fn,
initial_position=ip_map.data,
**pathfinder_kwargs,
)
samples, _ = blackjax.vi.pathfinder.sample(
rng_key=jax.random.key(sample_seed),
state=pathfinder_state,
num_samples=samples,
)

idata = convert_flat_trace_to_idata(
samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims
samples,
postprocessing_backend=postprocessing_backend,
model=model,
)

return idata
20 changes: 7 additions & 13 deletions pymc_experimental/tests/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@
import pymc_experimental as pmx


# TODO: Remove this filterwarning after pytensor uses jnp.prod instead of jnp.product
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="pymc.sampling.jax does not currently support python < 3.10"
)
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_pathfinder():
# Data of the Eight Schools Model
J = 8
Expand All @@ -41,12 +36,11 @@ def test_pathfinder():
theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)

idata = pmx.fit(method="pathfinder", iterations=100)
idata = pmx.fit(method="pathfinder", random_seed=41)

assert idata is not None
assert "theta" in idata.posterior._variables.keys()
assert "tau" in idata.posterior._variables.keys()
assert "mu" in idata.posterior._variables.keys()
assert idata.posterior["mu"].shape == (1, 100)
assert idata.posterior["tau"].shape == (1, 100)
assert idata.posterior["theta"].shape == (1, 100, 8)
assert idata.posterior["mu"].shape == (1, 1000)
assert idata.posterior["tau"].shape == (1, 1000)
assert idata.posterior["theta"].shape == (1, 1000, 8)
# FIXME: pathfinder doesn't find a reasonable mean! Fix the bug or choose a model that pathfinder can handle
# np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0)
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pymc>=5.8.1
pymc>=5.8.2
scikit-learn

0 comments on commit c8c601d

Please sign in to comment.