PyTorch Lightning example #3189
Conversation
Addresses #3171.
Nice! I've been using Lightning recently as well, so I left some (optional) suggestions aimed at making the example slightly more PyTorch-idiomatic using new features from #3149
    def main(args):
        # Create a model, synthetic data, a guide, and a lightning module.
        pyro.set_rng_seed(args.seed)
This option added in #3149 ensures that parameters of PyroModules will not be implicitly shared across model instances via the Pyro parameter store:

    pyro.set_rng_seed(args.seed)
    pyro.settings.set(module_local_params=True)
It's not really exercised in this simple example since there's only one model and guide, but I think it's good practice to enable it whenever models and guides can be written as PyroModules and trained using generic PyTorch infrastructure like torch.optim and PyTorch Lightning.
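For concreteness, here is a minimal sketch of what the setting changes, as I understand it (the toy module and names below are my own, not from the example): with module_local_params=True, parameters of a PyroModule live on the instance rather than in the global parameter store, so two instances no longer alias each other's parameters by name.

```python
import torch
import pyro
from pyro.nn import PyroModule, PyroParam

pyro.settings.set(module_local_params=True)

class Scale(PyroModule):
    """Toy module with a single learnable parameter."""
    def __init__(self):
        super().__init__()
        self.scale = PyroParam(torch.ones(()))

a, b = Scale(), Scale()
_ = (a.scale, b.scale)  # touch the attributes so the parameters get initialized
# With module_local_params=False both instances would register a parameter named
# "scale" in the global store and silently share it; with the setting enabled the
# store stays empty and each instance owns its own parameter.
print(list(pyro.get_param_store().keys()))  # expected: []
```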
examples/svi_lightning.py (Outdated)
    guide = AutoNormal(model)
    training_plan = PyroLightningModule(model, guide, args.learning_rate)
This change uses the new __call__ method added to the base pyro.infer.elbo.ELBO in #3149, which takes a model and guide and returns a torch.nn.Module wrapper around the loss:

    guide = AutoNormal(model)
    loss_fn = Trace_ELBO()(model, guide)
    training_plan = PyroLightningModule(loss_fn, args.learning_rate)
It saves you from having to pass around a model and guide everywhere or deal with the Pyro parameter store, which makes SVI a little easier to use with other PyTorch tools like Lightning and the PyTorch JIT.
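A self-contained sketch of this pattern (toy model and data of my own, not from the example script), showing that the returned module carries both model and guide parameters and can be driven by a plain torch.optim optimizer:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule

class Model(PyroModule):
    def forward(self, x, y):
        slope = pyro.sample("slope", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(x)):
            pyro.sample("obs", dist.Normal(slope * x, 1.0), obs=y)

model = Model()
guide = AutoNormal(model)
loss_fn = Trace_ELBO()(model, guide)  # a torch.nn.Module wrapping the ELBO loss

x, y = torch.randn(10), torch.randn(10)
loss_fn(x, y)  # first call lazily initializes the guide's parameters
optimizer = torch.optim.Adam(loss_fn.parameters(), lr=0.01)  # model + guide params

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(x, y)  # differentiable scalar loss
    loss.backward()
    optimizer.step()
```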
I didn't know about ELBOModule. This is much neater!
examples/svi_lightning.py (Outdated)
    # All relevant parameters need to be initialized before ``configure_optimizer`` is called.
    # Since we used an AutoNormal guide, our parameters have not been initialized yet.
    # Therefore we warm up the guide by running one mini-batch through it.
    mini_batch = dataset[: args.batch_size]
    guide(*mini_batch)
Suggested change:

    # All relevant parameters need to be initialized before ``configure_optimizer`` is called.
    # Since we used an AutoNormal guide, our parameters have not been initialized yet.
    # Therefore we initialize the model and guide by running one mini-batch through the loss.
    mini_batch = dataset[: args.batch_size]
    loss_fn(*mini_batch)
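The reason this warm-up is needed: autoguides (and hence the ELBOModule) create their parameters lazily on the first forward pass, so parameters() is empty until then and configure_optimizers would register nothing. A tiny hedged illustration with a toy model (not from the example):

```python
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal

def model():
    pyro.sample("z", dist.Normal(0.0, 1.0))

guide = AutoNormal(model)
print(len(list(guide.parameters())))  # 0 -- nothing has been initialized yet
guide()                               # first call creates the loc/scale parameters
print(len(list(guide.parameters())))  # > 0 -- now an optimizer can see them
```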
    # Distributed training via Pytorch Lightning.
    #
    # This tutorial demonstrates how to distribute SVI training across multiple
    # machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
Where is the distributed training in this example? Is it hidden in the default configuration of the DataLoader and TrainingPlan in main below?
Argparse arguments are passed to the pl.Trainer:

    trainer = pl.Trainer.from_argparse_args(args)

So you can run the script as follows:

    $ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp

When there are multiple devices, the DataLoader will use a DistributedSampler automatically.
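For context, a hedged sketch of that wiring (this uses the PyTorch Lightning 1.x API referenced in the thread; add_argparse_args/from_argparse_args were removed in Lightning 2.0, and the script-specific flag below is illustrative):

```python
import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=32, type=int)  # script-specific flag (illustrative)
parser = pl.Trainer.add_argparse_args(parser)  # adds --accelerator, --devices, --strategy, ...
args = parser.parse_args()

# All Trainer flags (e.g. --devices 2 --strategy ddp) are forwarded here, which is
# what turns on distributed training without any extra code in the script itself.
trainer = pl.Trainer.from_argparse_args(args)
```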
examples/svi_lightning.py (Outdated)
    def __init__(self, model, guide, lr):
        super().__init__()
        self.pyro_model = model
        self.pyro_guide = guide
        self.loss_fn = Trace_ELBO().differentiable_loss
Suggested change:

    def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
        super().__init__()
        self.loss_fn = loss_fn
        self.model = loss_fn.model
        self.guide = loss_fn.guide
examples/svi_lightning.py (Outdated)
    def training_step(self, batch, batch_idx):
        """Training step for Pyro training."""
        loss = self.loss_fn(self.pyro_model, self.pyro_guide, *batch)
Suggested change:

        loss = self.loss_fn(*batch)
examples/svi_lightning.py (Outdated)
    def configure_optimizers(self):
        """Configure an optimizer."""
        return torch.optim.Adam(self.pyro_guide.parameters(), lr=self.lr)
Suggested change:

        return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)
examples/svi_lightning.py (Outdated)
        self.lr = lr
Adding a forward method that calls Predictive is sometimes helpful:

        self.lr = lr
        self.predictive = pyro.infer.Predictive(self.model, guide=self.guide)

    def forward(self, *args):
        return self.predictive(*args)
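Taken together, the suggestions in this thread amount to a LightningModule roughly like the sketch below. This is my own consolidation, not the merged file; note that pyro.infer.Predictive requires either num_samples or posterior_samples, so I pass num_samples=1 here as an assumption.

```python
import torch
import pytorch_lightning as pl
from pyro.infer import Predictive
from pyro.infer.elbo import ELBOModule


class PyroLightningModule(pl.LightningModule):
    """Sketch of a Lightning module driving SVI through an ELBOModule."""

    def __init__(self, loss_fn: ELBOModule, lr: float):
        super().__init__()
        self.loss_fn = loss_fn          # torch.nn.Module wrapping (model, guide, ELBO)
        self.model = loss_fn.model
        self.guide = loss_fn.guide
        self.lr = lr
        # num_samples=1 is an assumption; Predictive needs it (or posterior_samples).
        self.predictive = Predictive(self.model, guide=self.guide, num_samples=1)

    def forward(self, *args):
        # Posterior-predictive samples from the trained (model, guide) pair.
        return self.predictive(*args)

    def training_step(self, batch, batch_idx):
        """Training step for Pyro training."""
        return self.loss_fn(*batch)

    def configure_optimizers(self):
        """Configure an optimizer over all model and guide parameters."""
        return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)
```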
Thanks for reviewing @eb8680. I think it is much neater now using ELBOModule!
Looks great! Can you just confirm the generated docs are readable, i.e. after running make tutorial? Also ensure the title isn't too long when it appears in the left-hand TOC.
@fritzo There is something wrong with building tutorials when I run the build locally. Trying to figure out what is wrong ... (if you know a quick fix, I would appreciate it)
@ordabayevy not sure what's causing the build issue... Unrelated: could you add svi_lightning to tutorial/source/index.rst so it shows up on the website?
Still no luck with the tutorial build.
I was able to build the tutorial by ignoring warnings and can confirm that the generated doc is readable and the title in the left-hand TOC is not too long.
Looks great, thanks for building tutorials. I'll look into fixing those warnings.
@eb8680 any further comments? I'll hold off merging; feel free to merge.
LGTM
PyTorch Lightning example (pyro-ppl#3189)
This example shows how to train Pyro models using PyTorch Lightning and is adapted from the Horovod example.