Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch Lightning example #3189

Merged
merged 16 commits into from
Mar 16, 2023
2 changes: 1 addition & 1 deletion examples/svi_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# https://horovod.readthedocs.io/en/stable
#
# This assumes you have installed horovod, e.g. via
# pip install pyro[horovod]
# pip install pyro-ppl[horovod]
# For detailed instructions see
# https://horovod.readthedocs.io/en/stable/install.html
# On my mac laptop I was able to install horovod with
Expand Down
116 changes: 116 additions & 0 deletions examples/svi_lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Distributed training via Pytorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
Comment on lines +4 to +7
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the distributed training in this example? Is it hidden in the default configuration of the DataLoader and TrainingPlan in main below?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argparse arguments are passed to the pl.Trainer:

trainer = pl.Trainer.from_argparse_args(args)

So you can run the script as follows:

$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp

When there are multiple devices DataLoader will use DistributedSampler automatically.

# library. PyTorch Lightning enables data-parallel training by aggregating stochastic
# gradients at each step of training. We focus on integration between PyTorch Lightning and Pyro.
# For further details on distributed computing with PyTorch Lightning, see
# https://lightning.ai/docs/pytorch/latest
#
# This assumes you have installed pytorch lightning, e.g. via
# pip install pyro-ppl[lightning]

import argparse

import pytorch_lightning as pl
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule


# We define a model as usual, with no reference to Pytorch Lightning.
# This model is data parallel and supports subsampling.
class Model(PyroModule):
def __init__(self, size):
super().__init__()
self.size = size

def forward(self, covariates, data=None):
coeff = pyro.sample("coeff", dist.Normal(0, 1))
bias = pyro.sample("bias", dist.Normal(0, 1))
scale = pyro.sample("scale", dist.LogNormal(0, 1))

# Since we'll use a distributed dataloader during training, we need to
# manually pass minibatches of (covariates,data) that are smaller than
# the full self.size. In particular we cannot rely on pyro.plate to
# automatically subsample, since that would lead to all workers drawing
# identical subsamples.
with pyro.plate("data", self.size, len(covariates)):
loc = bias + coeff * covariates
return pyro.sample("obs", dist.Normal(loc, scale), obs=data)


# We define an ELBO loss, a PyTorch optimizer, and a training step in our PyroLightningModule.
# Note that we are using a PyTorch optimizer instead of a Pyro optimizer and
# we are using ``training_step`` instead of Pyro's SVI machinery.
class PyroLightningModule(pl.LightningModule):
def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
super().__init__()
self.loss_fn = loss_fn
self.model = loss_fn.model
self.guide = loss_fn.guide
self.lr = lr
self.predictive = pyro.infer.Predictive(
self.model, guide=self.guide, num_samples=1
)

def forward(self, *args):
return self.predictive(*args)

def training_step(self, batch, batch_idx):
"""Training step for Pyro training."""
loss = self.loss_fn(*batch)
# Logging to TensorBoard by default
self.log("train_loss", loss)
return loss

def configure_optimizers(self):
"""Configure an optimizer."""
return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)


def main(args):
# Create a model, synthetic data, a guide, and a lightning module.
pyro.set_rng_seed(args.seed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This option added in #3149 ensures that parameters of PyroModules will not be implicitly shared across model instances via the Pyro parameter store:

Suggested change
pyro.set_rng_seed(args.seed)
pyro.set_rng_seed(args.seed)
pyro.settings.set(module_local_params=True)

It's not really exercised in this simple example since there's only one model and guide but I think it's good practice to enable it whenever models and guides can be written as PyroModules and trained using generic PyTorch infrastructure like torch.optim and PyTorch Lightning.

pyro.settings.set(module_local_params=True)
model = Model(args.size)
covariates = torch.randn(args.size)
data = model(covariates)
guide = AutoNormal(model)
loss_fn = Trace_ELBO()(model, guide)
training_plan = PyroLightningModule(loss_fn, args.learning_rate)

# Create a dataloader.
dataset = torch.utils.data.TensorDataset(covariates, data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)

# All relevant parameters need to be initialized before ``configure_optimizer`` is called.
# Since we used AutoNormal guide our parameters have not be initialized yet.
# Therefore we initialize the model and guide by running one mini-batch through the loss.
mini_batch = dataset[: args.batch_size]
loss_fn(*mini_batch)

# Run stochastic variational inference using PyTorch Lightning Trainer.
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(training_plan, train_dataloaders=dataloader)


if __name__ == "__main__":
assert pyro.__version__.startswith("1.8.4")
parser = argparse.ArgumentParser(
description="Distributed training via PyTorch Lightning"
)
parser.add_argument("--size", default=1000000, type=int)
parser.add_argument("--batch_size", default=100, type=int)
parser.add_argument("--learning_rate", default=0.01, type=float)
parser.add_argument("--seed", default=20200723, type=int)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
main(args)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
"yapf",
],
"horovod": ["horovod[pytorch]>=0.19"],
"lightning": ["pytorch_lightning"],
"funsor": [
# This must be a released version when Pyro is released.
# "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
Expand Down
8 changes: 8 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def wrapper(*args, **kwargs):
horovod is None, reason="horovod is not available"
)

try:
import pytorch_lightning
except ImportError:
pytorch_lightning = None
requires_lightning = pytest.mark.skipif(
pytorch_lightning is None, reason="pytorch lightning is not available"
)

try:
import funsor
except ImportError:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
requires_cuda,
requires_funsor,
requires_horovod,
requires_lightning,
xfail_param,
)

Expand Down Expand Up @@ -110,6 +111,10 @@
"sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto",
"sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy",
"svi_horovod.py --num-epochs=2 --size=400 --no-horovod",
pytest.param(
"svi_lightning.py --max_epochs=2 --size=400 --accelerator cpu --devices 1",
marks=[requires_lightning],
),
"toy_mixture_model_discrete_enumeration.py --num-steps=1",
"sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11",
"vae/ss_vae_M2.py --num-epochs=1",
Expand Down Expand Up @@ -177,6 +182,10 @@
"sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda",
"sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda",
"svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod",
pytest.param(
"svi_lightning.py --max_epochs=2 --size=400 --accelerator gpu --devices 1",
marks=[requires_lightning],
),
"vae/vae.py --num-epochs=1 --cuda",
"vae/ss_vae_M2.py --num-epochs=1 --cuda",
"vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda",
Expand Down
17 changes: 17 additions & 0 deletions tutorial/source/svi_lightning.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Example: distributed training via PyTorch Lightning
===================================================

This script passes argparse arguments to PyTorch Lightning ``Trainer`` automatically_, for example::

$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp

.. _automatically: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-in-python-scripts

`View svi_lightning.py on github`__

.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_lightning.py

__ github_

.. literalinclude:: ../../examples/svi_lightning.py
:language: python