Skip to content

Commit

Permalink
Add an option to stop PyroModules from sharing parameters (#3149)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Nov 9, 2022
1 parent 891880f commit 8b7e564
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 95 deletions.
2 changes: 2 additions & 0 deletions pyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"condition",
"deterministic",
"do",
"enable_module_local_param",
"enable_validation",
"factor",
"get_param_store",
Expand All @@ -51,6 +52,7 @@
"log",
"markov",
"module",
"module_local_param_enabled",
"param",
"plate",
"plate",
Expand Down
54 changes: 54 additions & 0 deletions pyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,26 @@
import warnings
from abc import ABCMeta, abstractmethod

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.util import is_validation_enabled
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_site_shape


class ELBOModule(torch.nn.Module):
def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elbo: "ELBO"):
super().__init__()
self.model = model
self.guide = guide
self.elbo = elbo

def forward(self, *args, **kwargs):
return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs)


class ELBO(object, metaclass=ABCMeta):
"""
:class:`ELBO` is the top-level interface for stochastic variational
Expand All @@ -23,6 +36,40 @@ class ELBO(object, metaclass=ABCMeta):
:class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`, or
:class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`.
.. note:: Derived classes now provide a more idiomatic PyTorch interface via
:meth:`__call__` for (model, guide) pairs that are :class:`~torch.nn.Module` s,
which is useful for integrating Pyro's variational inference tooling with
standard PyTorch interfaces like :class:`~torch.optim.Optimizer` s
and the large ecosystem of libraries like PyTorch Lightning
and the PyTorch JIT that work with these interfaces::
model = Model()
guide = pyro.infer.autoguide.AutoNormal(model)
elbo_ = pyro.infer.Trace_ELBO(num_particles=10)
# Fix the model/guide pair
elbo = elbo_(model, guide)
# perform any data-dependent initialization
elbo(data)
optim = torch.optim.Adam(elbo.parameters(), lr=0.001)
for _ in range(100):
optim.zero_grad()
loss = elbo(data)
loss.backward()
optim.step()
Note that Pyro's global parameter store may cause this new interface to
behave unexpectedly relative to standard PyTorch when working with
:class:`~pyro.nn.PyroModule` s.
Users are therefore strongly encouraged to use this interface in conjunction
with :func:`~pyro.enable_module_local_param` which will override the default
implicit sharing of parameters across :class:`~pyro.nn.PyroModule` instances.
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param int max_plate_nesting: Optional bound on max number of nested
Expand Down Expand Up @@ -86,6 +133,13 @@ def __init__(
self.jit_options = jit_options
self.tail_adaptive_beta = tail_adaptive_beta

def __call__(self, model: torch.nn.Module, guide: torch.nn.Module) -> ELBOModule:
"""
Given a model and guide, returns a :class:`~torch.nn.Module` which
computes the ELBO loss when called with arguments to the model and guide.
"""
return ELBOModule(model, guide, self)

def _guess_max_plate_nesting(self, model, guide, args, kwargs):
"""
Guesses max_plate_nesting by running the (model,guide) pair once
Expand Down
42 changes: 41 additions & 1 deletion pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
from pyro.ops.provenance import detach_provenance
from pyro.poutine.runtime import _PYRO_PARAM_STORE

_MODULE_LOCAL_PARAMS: bool = False


@pyro.settings.register("module_local_params", __name__, "_MODULE_LOCAL_PARAMS")
def _validate_module_local_params(value: bool) -> None:
assert isinstance(value, bool)


def _is_module_local_param_enabled() -> bool:
return pyro.settings.get("module_local_params")


class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))):
"""
Expand Down Expand Up @@ -178,15 +189,23 @@ def __init__(self):
self.active = 0
self.cache = {}
self.used = False
if _is_module_local_param_enabled():
self.param_state = {"params": {}, "constraints": {}}

def __enter__(self):
if not self.active and _is_module_local_param_enabled():
self._param_ctx = pyro.get_param_store().scope(state=self.param_state)
self.param_state = self._param_ctx.__enter__()
self.active += 1
self.used = True

def __exit__(self, type, value, traceback):
self.active -= 1
if not self.active:
self.cache.clear()
if _is_module_local_param_enabled():
self._param_ctx.__exit__(type, value, traceback)
del self._param_ctx

def get(self, name):
if self.active:
Expand Down Expand Up @@ -409,6 +428,8 @@ def named_pyro_params(self, prefix="", recurse=True):
yield elem

def _pyro_set_supermodule(self, name, context):
if _is_module_local_param_enabled() and pyro.settings.get("validate_poutine"):
self._check_module_local_param_usage()
self._pyro_name = name
self._pyro_context = context
for key, value in self._modules.items():
Expand All @@ -424,7 +445,26 @@ def _pyro_get_fullname(self, name):

def __call__(self, *args, **kwargs):
with self._pyro_context:
return super().__call__(*args, **kwargs)
result = super().__call__(*args, **kwargs)
if (
pyro.settings.get("validate_poutine")
and not self._pyro_context.active
and _is_module_local_param_enabled()
):
self._check_module_local_param_usage()
return result

def _check_module_local_param_usage(self) -> None:
self_nn_params = set(id(p) for p in self.parameters())
self_pyro_params = set(
id(p if not hasattr(p, "unconstrained") else p.unconstrained())
for p in self._pyro_context.param_state["params"].values()
)
if not self_pyro_params <= self_nn_params:
raise NotImplementedError(
"Support for global pyro.param statements in PyroModules "
"with local param mode enabled is not yet implemented."
)

def __getattr__(self, name):
# PyroParams trigger pyro.param statements.
Expand Down
Loading

0 comments on commit 8b7e564

Please sign in to comment.