-
-
Notifications
You must be signed in to change notification settings - Fork 984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PyTorch Lightning example #3189
Changes from 6 commits
585beb9
260d05b
8e310b1
a48dc7c
ff4325d
b91c739
086a684
78aabaa
99e9ec2
dbf878a
562d1bb
92d3065
be091d1
7cb5f86
c56582d
e7e44b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,116 @@ | ||||||||
# Copyright Contributors to the Pyro project. | ||||||||
# SPDX-License-Identifier: Apache-2.0 | ||||||||
|
||||||||
# Distributed training via Pytorch Lightning. | ||||||||
# | ||||||||
# This tutorial demonstrates how to distribute SVI training across multiple | ||||||||
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning | ||||||||
# library. PyTorch Lightning enables data-parallel training by aggregating stochastic | ||||||||
# gradients at each step of training. We focus on integration between PyTorch Lightning and Pyro. | ||||||||
# For further details on distributed computing with PyTorch Lightning, see | ||||||||
# https://lightning.ai/docs/pytorch/latest | ||||||||
# | ||||||||
# This assumes you have installed pytorch lightning, e.g. via | ||||||||
# pip install pyro-ppl[lightning] | ||||||||
|
||||||||
import argparse | ||||||||
|
||||||||
import pytorch_lightning as pl | ||||||||
import torch | ||||||||
|
||||||||
import pyro | ||||||||
import pyro.distributions as dist | ||||||||
from pyro.infer import Trace_ELBO | ||||||||
from pyro.infer.autoguide import AutoNormal | ||||||||
from pyro.nn import PyroModule | ||||||||
|
||||||||
|
||||||||
# We define a model as usual, with no reference to Pytorch Lightning. | ||||||||
# This model is data parallel and supports subsampling. | ||||||||
class Model(PyroModule): | ||||||||
def __init__(self, size): | ||||||||
super().__init__() | ||||||||
self.size = size | ||||||||
|
||||||||
def forward(self, covariates, data=None): | ||||||||
coeff = pyro.sample("coeff", dist.Normal(0, 1)) | ||||||||
bias = pyro.sample("bias", dist.Normal(0, 1)) | ||||||||
scale = pyro.sample("scale", dist.LogNormal(0, 1)) | ||||||||
|
||||||||
# Since we'll use a distributed dataloader during training, we need to | ||||||||
# manually pass minibatches of (covariates,data) that are smaller than | ||||||||
# the full self.size. In particular we cannot rely on pyro.plate to | ||||||||
# automatically subsample, since that would lead to all workers drawing | ||||||||
# identical subsamples. | ||||||||
with pyro.plate("data", self.size, len(covariates)): | ||||||||
loc = bias + coeff * covariates | ||||||||
return pyro.sample("obs", dist.Normal(loc, scale), obs=data) | ||||||||
|
||||||||
|
||||||||
# We define an ELBO loss, a PyTorch optimizer, and a training step in our PyroLightningModule. | ||||||||
# Note that we are using a PyTorch optimizer instead of a Pyro optimizer and | ||||||||
# we are using ``training_step`` instead of Pyro's SVI machinery. | ||||||||
class PyroLightningModule(pl.LightningModule): | ||||||||
def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float): | ||||||||
super().__init__() | ||||||||
self.loss_fn = loss_fn | ||||||||
self.model = loss_fn.model | ||||||||
self.guide = loss_fn.guide | ||||||||
self.lr = lr | ||||||||
self.predictive = pyro.infer.Predictive( | ||||||||
self.model, guide=self.guide, num_samples=1 | ||||||||
) | ||||||||
|
||||||||
def forward(self, *args): | ||||||||
return self.predictive(*args) | ||||||||
|
||||||||
def training_step(self, batch, batch_idx): | ||||||||
"""Training step for Pyro training.""" | ||||||||
loss = self.loss_fn(*batch) | ||||||||
# Logging to TensorBoard by default | ||||||||
self.log("train_loss", loss) | ||||||||
return loss | ||||||||
|
||||||||
def configure_optimizers(self): | ||||||||
"""Configure an optimizer.""" | ||||||||
return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr) | ||||||||
|
||||||||
|
||||||||
def main(args): | ||||||||
# Create a model, synthetic data, a guide, and a lightning module. | ||||||||
pyro.set_rng_seed(args.seed) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This option added in #3149 ensures that parameters of
Suggested change
It's not really exercised in this simple example since there's only one model and guide but I think it's good practice to enable it whenever models and guides can be written as |
||||||||
pyro.settings.set(module_local_params=True) | ||||||||
model = Model(args.size) | ||||||||
covariates = torch.randn(args.size) | ||||||||
data = model(covariates) | ||||||||
guide = AutoNormal(model) | ||||||||
loss_fn = Trace_ELBO()(model, guide) | ||||||||
training_plan = PyroLightningModule(loss_fn, args.learning_rate) | ||||||||
|
||||||||
# Create a dataloader. | ||||||||
dataset = torch.utils.data.TensorDataset(covariates, data) | ||||||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size) | ||||||||
|
||||||||
# All relevant parameters need to be initialized before ``configure_optimizer`` is called. | ||||||||
# Since we used AutoNormal guide our parameters have not be initialized yet. | ||||||||
# Therefore we initialize the model and guide by running one mini-batch through the loss. | ||||||||
mini_batch = dataset[: args.batch_size] | ||||||||
loss_fn(*mini_batch) | ||||||||
|
||||||||
# Run stochastic variational inference using PyTorch Lightning Trainer. | ||||||||
trainer = pl.Trainer.from_argparse_args(args) | ||||||||
trainer.fit(training_plan, train_dataloaders=dataloader) | ||||||||
|
||||||||
|
||||||||
if __name__ == "__main__": | ||||||||
assert pyro.__version__.startswith("1.8.4") | ||||||||
parser = argparse.ArgumentParser( | ||||||||
description="Distributed training via PyTorch Lightning" | ||||||||
) | ||||||||
parser.add_argument("--size", default=1000000, type=int) | ||||||||
parser.add_argument("--batch_size", default=100, type=int) | ||||||||
parser.add_argument("--learning_rate", default=0.01, type=float) | ||||||||
parser.add_argument("--seed", default=20200723, type=int) | ||||||||
parser = pl.Trainer.add_argparse_args(parser) | ||||||||
args = parser.parse_args() | ||||||||
main(args) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
Example: distributed training via PyTorch Lightning | ||
=================================================== | ||
|
||
This script passes argparse arguments to PyTorch Lightning ``Trainer`` automatically_, for example:: | ||
|
||
$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp | ||
|
||
.. _automatically: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-in-python-scripts | ||
|
||
`View svi_lightning.py on github`__ | ||
|
||
.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_lightning.py | ||
|
||
__ github_ | ||
|
||
.. literalinclude:: ../../examples/svi_lightning.py | ||
:language: python |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the distributed training in this example? Is it hidden in the default configuration of the
DataLoader
andTrainingPlan
inmain
below?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Argparse arguments are passed to the
pl.Trainer
:So you can run the script as follows:
When there are multiple devices
DataLoader
will useDistributedSampler
automatically.