-
-
Notifications
You must be signed in to change notification settings - Fork 984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add an option to stop PyroModules from sharing parameters #3149
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice rethinking towards more idiomatic PyTorch!
|
||
def __enter__(self): | ||
if not self.active and _is_module_local_param_enabled(): | ||
self._param_ctx = pyro.get_param_store().scope(state=self.param_state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Persisting self.param_state
like this (and in _pyro_set_supermodule
below) seems to be a reasonable solution for the behavior of vanilla pyro.param
statements. Values of these parameters are now local to the outermost PyroModule
in a nested PyroModule
instance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey thanks for your patience in reviewing this subtle PR. The ELBOModule
changes look clean. I'm still working through understanding the module_local_param changes...
pyro/primitives.py
Outdated
@@ -562,3 +562,35 @@ def validation_enabled(is_validate=True): | |||
dist.enable_validation(distribution_validation_status) | |||
infer.enable_validation(infer_validation_status) | |||
poutine.enable_validation(poutine_validation_status) | |||
|
|||
|
|||
def enable_module_local_param(is_enabled: bool = False) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to make it super clear that users can now decide between (i) a global param store or (ii) local nn.Module style parameters. Like maybe
with pyro.param_storage("local"): ...
with pyro.param_storage("global"): ...
or pyro.disable_param_store(True)
or pyro.enable_param_store(False)
. Whatever we call it I think it would be good in the first docstring sentence to mention the phrase "param store" and the word "nn.Module".
pyro/nn/module.py
Outdated
if _is_module_local_param_enabled(): | ||
with pyro.get_param_store().scope( | ||
state=self._pyro_context.param_state | ||
) as vanilla_param_state: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Would another word for "vanilla" be "global" or "global-only" or "raw" or "nonmodule" or something? We might want to avoid "vanilla" because PyTorch users new to Pyro might think of "vanilla" as "an nn.Param attribute of an nn.Module".
:class:`~pyro.nn.PyroModule` s. | ||
|
||
Users are therefore strongly encouraged to use this interface in conjunction | ||
with :func:`~pyro.enable_module_local_param` which will override the default |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: override -> disable or avoid?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after minor comment on .set()
vs .context()
in tests
Replaces #2996
This PR adds two small related features for easier Pyro-PyTorch integration:
__call__
method for the basepyro.infer.elbo.ELBO
that bindsELBO
instances to specificnn.Module
model/guide pairs in aModule
that exposes their PyTorch parametersPyroModule
instances from sharing parameter values with one another through the global Pyro parameter store, and a primitive and context manager for toggling it. One context where this is useful is for workflows that involve multiple models and autoguides with overlapping parameter names.An edge case I haven't handled here is the behavior under the new local parameter setting of regular
pyro.param
statements (as opposed toPyroParam
) within aPyroModule
that don't have their data associated with any underlyingnn.Module
. I've raised an error rather than attempt to get this working, since I think it's usually aPyroModule
programming anti-pattern to mix global and local parameter states in this way.I am also hopeful that these changes will simplify the use of Pyro with the PyTorch JIT and other PyTorch compilers, but I have left testing this for future work, since I suspect it will require additional engineering that is out of scope for this PR.
Tasks:
pyro.settings
module from Clean up handling of global settings #3152pyro.param
statement inside aPyroModule
Tested:
PyroModule
tests intests/nn/test_module.py
pyro.param