Add an option to stop PyroModules from sharing parameters #3149

Merged: 18 commits, Nov 9, 2022
Changes from 11 commits
4 changes: 4 additions & 0 deletions pyro/__init__.py
@@ -9,12 +9,14 @@
barrier,
clear_param_store,
deterministic,
enable_module_local_param,
enable_validation,
factor,
get_param_store,
iarange,
irange,
module,
module_local_param_enabled,
param,
plate,
plate_stack,
@@ -41,6 +43,7 @@
"condition",
"deterministic",
"do",
"enable_module_local_param",
"enable_validation",
"factor",
"get_param_store",
Expand All @@ -49,6 +52,7 @@
"log",
"markov",
"module",
"module_local_param_enabled",
"param",
"plate",
"plate",
56 changes: 56 additions & 0 deletions pyro/infer/elbo.py
@@ -5,13 +5,26 @@
import warnings
from abc import ABCMeta, abstractmethod

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.util import is_validation_enabled
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_site_shape


class _ELBOModule(torch.nn.Module):
def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elbo: "ELBO"):
super().__init__()
self.model = model
self.guide = guide
self.elbo = elbo

def forward(self, *args, **kwargs):
return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs)


class ELBO(object, metaclass=ABCMeta):
"""
:class:`ELBO` is the top-level interface for stochastic variational
@@ -23,6 +36,40 @@ class ELBO(object, metaclass=ABCMeta):
:class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`, or
:class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`.

.. note:: Derived classes now provide a more idiomatic PyTorch interface via
:meth:`__call__` for (model, guide) pairs that are :class:`~torch.nn.Module` s,
which is useful for integrating Pyro's variational inference tooling with
standard PyTorch interfaces like :class:`~torch.optim.Optimizer` s
and the large ecosystem of libraries like PyTorch Lightning
and the PyTorch JIT that work with these interfaces::

model = Model()
guide = pyro.infer.autoguide.AutoNormal(model)

elbo_ = pyro.infer.Trace_ELBO(num_particles=10)

# Fix the model/guide pair
elbo = elbo_(model, guide)

# perform any data-dependent initialization
elbo(data)

optim = torch.optim.Adam(elbo.parameters(), lr=0.001)

for _ in range(100):
optim.zero_grad()
loss = elbo(data)
loss.backward()
optim.step()

Note that Pyro's global parameter store may cause this new interface to
behave unexpectedly relative to standard PyTorch when working with
:class:`~pyro.nn.PyroModule` s.

Users are therefore strongly encouraged to use this interface in conjunction
with :func:`~pyro.enable_module_local_param` which will override the default
implicit sharing of parameters across :class:`~pyro.nn.PyroModule` instances.

[Review comment (Member): nit: override -> disable or avoid?]

:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param int max_plate_nesting: Optional bound on max number of nested
@@ -86,6 +133,15 @@ def __init__(
self.jit_options = jit_options
self.tail_adaptive_beta = tail_adaptive_beta

def __call__(
self, model: torch.nn.Module, guide: torch.nn.Module
) -> torch.nn.Module:
"""
Given a model and guide, returns a :class:`~torch.nn.Module` which
computes the ELBO loss when called with arguments to the model and guide.
"""
return _ELBOModule(model, guide, self)

def _guess_max_plate_nesting(self, model, guide, args, kwargs):
"""
Guesses max_plate_nesting by running the (model,guide) pair once
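For context, here is a minimal end-to-end sketch (not part of the diff) of the new ELBO.__call__ interface combined with the module-local parameter toggle added in pyro/primitives.py below. The Model class and the synthetic data are hypothetical stand-ins for illustration:

import torch

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

pyro.enable_module_local_param(True)  # per the docstring's recommendation

class Model(PyroModule):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.loc = PyroSample(dist.Normal(0.0, 1.0))

    def forward(self, data):
        loc = self.loc  # resolves the latent sample site "loc"
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

model = Model()
guide = pyro.infer.autoguide.AutoNormal(model)

# Fixing the (model, guide) pair yields a torch.nn.Module computing the loss
elbo = pyro.infer.Trace_ELBO(num_particles=10)(model, guide)

data = torch.randn(100)
elbo(data)  # one call performs data-dependent parameter initialization

optim = torch.optim.Adam(elbo.parameters(), lr=0.001)
for _ in range(100):
    optim.zero_grad()
    loss = elbo(data)
    loss.backward()
    optim.step()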
10 changes: 10 additions & 0 deletions pyro/nn/module.py
@@ -169,6 +169,10 @@ def _unconstrain(constrained_value, constraint):
return torch.nn.Parameter(unconstrained_value)


def _is_module_local_param_enabled() -> bool:
return pyro.poutine.runtime._PYRO_MODULE_LOCAL_PARAM


class _Context:
"""
Sometimes-active cache for ``PyroModule.__call__()`` contexts.
@@ -180,13 +184,19 @@ def __init__(self):
self.used = False

def __enter__(self):
if not self.active and _is_module_local_param_enabled():
self._param_ctx = pyro.get_param_store().scope(state=None)
self._param_ctx.__enter__()
self.active += 1
self.used = True

def __exit__(self, type, value, traceback):
self.active -= 1
if not self.active:
self.cache.clear()
if _is_module_local_param_enabled():
self._param_ctx.__exit__(type, value, traceback)
del self._param_ctx

def get(self, name):
if self.active:
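To make the flag's effect concrete, a small behavioral sketch (not part of the diff): under the default global param store, two PyroModule instances that register the same fully qualified parameter name resolve to one underlying tensor, while the module-local mode gives each instance its own torch.nn.Parameter. The Shared class and its pyro name are invented for illustration, and the printed results assume the scoping above works as intended:

import torch

import pyro
from pyro.nn import PyroModule, PyroParam

class Shared(PyroModule):  # hypothetical module for illustration
    def __init__(self):
        super().__init__(name="shared")  # same pyro name for every instance
        self.weight = PyroParam(torch.ones(2))

    def forward(self):
        return self.weight

# Default behavior: both instances read "shared.weight" from the global store
pyro.enable_module_local_param(False)
a, b = Shared(), Shared()
print(a().data_ptr() == b().data_ptr())  # expected: True, one shared tensor

# Module-local behavior: each instance keeps an independent parameter
pyro.clear_param_store()
pyro.enable_module_local_param(True)
a, b = Shared(), Shared()
print(a().data_ptr() == b().data_ptr())  # expected: False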
3 changes: 3 additions & 0 deletions pyro/poutine/runtime.py
@@ -15,6 +15,9 @@
# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()

# toggle usage of local param stores in PyroModules
_PYRO_MODULE_LOCAL_PARAM = False


class _DimAllocator:
"""
32 changes: 32 additions & 0 deletions pyro/primitives.py
@@ -562,3 +562,35 @@ def validation_enabled(is_validate=True):
dist.enable_validation(distribution_validation_status)
infer.enable_validation(infer_validation_status)
poutine.enable_validation(poutine_validation_status)


def enable_module_local_param(is_enabled: bool = False) -> None:
[Review comment (Member): It would be nice to make it super clear that users can now decide between (i) a global param store or (ii) local nn.Module style parameters. Like maybe

with pyro.param_storage("local"): ...
with pyro.param_storage("global"): ...

or pyro.disable_param_store(True) or pyro.enable_param_store(False). Whatever we call it I think it would be good in the first docstring sentence to mention the phrase "param store" and the word "nn.Module".]

"""
Toggles the behavior of :class:`~pyro.nn.module.PyroModule` to use
local parameters instead of global parameters.

When this feature is enabled, :class:`~pyro.nn.module.PyroModule`
instances will not share parameters with other instances of the same
class through Pyro's global parameter store. Instead, each instance
will have its own local parameters, just like a standard :class:`torch.nn.Module`.

.. note:: This feature is disabled by default to ensure backwards compatibility
of :class:`~pyro.nn.module.PyroModule` with existing Pyro code.

:param bool is_enabled: (optional; defaults to False) whether to
enable local parameters.
"""
poutine.runtime._PYRO_MODULE_LOCAL_PARAM = is_enabled


@contextmanager
def module_local_param_enabled(is_enabled=False):
"""
Context manager to temporarily toggle local parameter stores in PyroModules.
"""
old_flag = poutine.runtime._PYRO_MODULE_LOCAL_PARAM
poutine.runtime._PYRO_MODULE_LOCAL_PARAM = is_enabled
try:
yield
finally:
poutine.runtime._PYRO_MODULE_LOCAL_PARAM = old_flag
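
A short usage sketch for the two toggles added above, contrasting the process-wide switch with the temporary one:

import pyro

# Process-wide switch: PyroModules now use per-instance local parameters
pyro.enable_module_local_param(True)

# Temporary switch: restore the legacy shared-parameter behavior in a block,
# e.g. to interoperate with code that reads Pyro's global param store
with pyro.module_local_param_enabled(False):
    pass  # PyroModules here share parameters through the global store

# on exit, the previous setting (True) is restored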
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -34,6 +34,7 @@ def pytest_runtest_setup(item):
if test_initialize_marker:
rng_seed = test_initialize_marker.kwargs["rng_seed"]
pyro.set_rng_seed(rng_seed)
pyro.enable_module_local_param(False)


def pytest_addoption(parser):