Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup sample and allow specifying compile_kwargs (several major changes related to step samplers) #7578

Merged
merged 11 commits into from
Nov 29, 2024

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Nov 19, 2024

Major changes

  1. internal uses of logp_dlogp_function now work with raveled inputs. External use will issue a warning unless ravel_inputs is specified explicitly. Eventually it will only be possible to use ravel_inputs=True.
  2. Step samplers arguments besides vars must be passed by keyword
  3. RaveledVars point_map_info is now a 4-n tuple, with size introduced.
  4. assign_step_method does not call instantiate_steppers, but returns arguments needed for the latter.
  5. Allow passing compile_kwargs to pm.sample which is then forwarded to the step samplers functions

Enhancement

This PR speedups NUTS (and other step samplers), by:

  1. Avoiding many variable unravel and copies, by doing it inside PyTensor
  2. Avoiding copies when setting shared variables (borrow=True)
  3. Setting trust_input=True which can have a large overhead.
  4. Disabling GC collection for the C-backend function (related to Consider disabling PyTensor GC in sampling functions #7539)
  5. Using slots for faster attribute access in the Tree class (and smaller footprint)
  6. Inlining some functions and being more lazy when possible

This PR speedups sample by:

  1. Avoiding way too many pytensor function compilations (model.initial_point() and very silly trace.fn after slicing at the end. It's also silly to compile the same function for every trace. We should just copy it.
  2. Avoid initializing NUTS step sampler just for most of the times then immediately discard it and using the one inside init_nuts. This will also reduce the path towards external samplers with nutpie/numpyro as it avoids the costly and useless compilation of the logp_dlogp_function
  3. Using trust_input and avoiding deepcopies in the trace function by using pytensor.In(borrow=True) and pytensor.Out(borrow=True).

Further speedups should come for free from #7539, specially for the Numba backend.

Benchmark

In the example below, sampling time is now only 7x slower than nutpie (5s vs 0.7s), compared to 13.5x slower (9.45s vs 0.7s) before. This assuming the same number of logp evals, in fact nutpie tuning allows us to get out with half the evals! We can hopefully bring it over.

Full time until from pm.sample to getting a trace is roughly halved as well (7.5s vs 14.4s), although this gain is not proportional to the number of draws.

With compile_kwargs=(mode="NUMBA"), sampling time is only 3x slower (2.3s).

import time
import pymc as pm
import numpy as np
import nutpie
import pandas as pd

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )

from pymc.model.transform.optimization import freeze_dims_and_data
model = freeze_dims_and_data(model)
compiled_model = nutpie.compile_pymc_model(model)

start = time.perf_counter()
# More draws to make up for the fact that nutpie tunes better
trace_pymc = nutpie.sample(compiled_model, chains=1, tune=500, draws=1500, progress_bar=False)
end = time.perf_counter()
print(end - start)
idata = pm.sample(
    model=model, 
    chains=1,
    tune=500, 
    draws=500, 
    progressbar=False, 
    compute_convergence_checks=False, 
    return_inferencedata=False,
    # compile_kwargs=dict(mode="NUMBA")
)
print(idata._report.t_sampling)

📚 Documentation preview 📚: https://pymc--7578.org.readthedocs.build/en/7578/

@ricardoV94 ricardoV94 changed the title WIP Speedup NUTS Speedup sample Nov 22, 2024
@ricardoV94 ricardoV94 force-pushed the speedup_nuts branch 4 times, most recently from d6f9e14 to 87fd299 Compare November 23, 2024 23:06
@ricardoV94 ricardoV94 added major Include in major changes release notes section enhancements samplers labels Nov 23, 2024
@ricardoV94 ricardoV94 force-pushed the speedup_nuts branch 3 times, most recently from 874ae65 to cb8d51e Compare November 24, 2024 01:20
@ricardoV94
Copy link
Member Author

10 minutes seem to be saved in pytest CI time compared to previous runs

pymc/blocking.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 force-pushed the speedup_nuts branch 2 times, most recently from f5be4ab to 95ce8bc Compare November 25, 2024 23:33
@ricardoV94 ricardoV94 marked this pull request as ready for review November 25, 2024 23:53
Copy link

codecov bot commented Nov 26, 2024

Codecov Report

Attention: Patch coverage is 94.28571% with 14 lines in your changes missing coverage. Please review.

Project coverage is 92.84%. Comparing base (fe0e0d7) to head (bd232d2).
Report is 14 commits behind head on main.

Files with missing lines Patch % Lines
pymc/step_methods/hmc/quadpotential.py 0.00% 8 Missing ⚠️
pymc/sampling/mcmc.py 89.36% 5 Missing ⚠️
pymc/model/core.py 96.15% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7578      +/-   ##
==========================================
- Coverage   92.84%   92.84%   -0.01%     
==========================================
  Files         106      106              
  Lines       17686    17719      +33     
==========================================
+ Hits        16421    16451      +30     
- Misses       1265     1268       +3     
Files with missing lines Coverage Δ
pymc/backends/__init__.py 91.89% <100.00%> (+0.22%) ⬆️
pymc/backends/base.py 88.69% <100.00%> (+0.25%) ⬆️
pymc/backends/ndarray.py 80.00% <100.00%> (+0.90%) ⬆️
pymc/blocking.py 97.67% <100.00%> (+1.92%) ⬆️
pymc/pytensorf.py 90.97% <100.00%> (ø)
pymc/sampling/parallel.py 88.73% <100.00%> (-0.12%) ⬇️
pymc/step_methods/arraystep.py 96.10% <100.00%> (+1.43%) ⬆️
pymc/step_methods/hmc/base_hmc.py 91.91% <100.00%> (ø)
pymc/step_methods/hmc/integration.py 84.12% <100.00%> (+2.46%) ⬆️
pymc/step_methods/hmc/nuts.py 97.43% <100.00%> (+0.03%) ⬆️
... and 7 more

... and 1 file with indirect coverage changes

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, left a few small comments

pymc/backends/base.py Outdated Show resolved Hide resolved
pymc/backends/base.py Show resolved Hide resolved
pymc/step_methods/arraystep.py Show resolved Hide resolved
pymc/step_methods/hmc/nuts.py Show resolved Hide resolved
@ricardoV94 ricardoV94 changed the title Speedup sample Speedup sample (with several major changes related to step samplers) Nov 27, 2024
@ricardoV94 ricardoV94 changed the title Speedup sample (with several major changes related to step samplers) Speedup sample and allow specifying compile_kwargs (several major changes related to step samplers) Nov 27, 2024
@ricardoV94 ricardoV94 changed the title Speedup sample and allow specifying compile_kwargs (several major changes related to step samplers) Speedup sample and allow specifying compile_kwargs (several major changes related to step samplers) Nov 27, 2024
@ricardoV94 ricardoV94 merged commit 7c369c8 into pymc-devs:main Nov 29, 2024
23 checks passed
@ricardoV94 ricardoV94 deleted the speedup_nuts branch December 3, 2024 15:09
@ricardoV94 ricardoV94 mentioned this pull request Dec 21, 2024
19 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements maintenance major Include in major changes release notes section samplers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants