Skip to content

Commit

Permalink
Add Profiling example (link to wandb report) (#64)
Browse files Browse the repository at this point in the history
* Add experiment config, md link to wandb report

* Replaced wildcard import with explicitly defined imports

* removed no_op network dependency, algorithm.network convention

* Update docs/profiling_test.py

Co-authored-by: Fabrice Normandin <[email protected]>

* Update docs/profiling_test.py

Co-authored-by: Fabrice Normandin <[email protected]>

* Update docs/profiling_test.py

Co-authored-by: Fabrice Normandin <[email protected]>

* Update docs/profiling_test.py

Co-authored-by: Fabrice Normandin <[email protected]>

* Update docs/profiling_test.py

Co-authored-by: Fabrice Normandin <[email protected]>

* pre-commit modifications

* Update docs/profiling_test.py

Co-authored-by: Fabrice Normandin <[email protected]>

* Fix pre-commit issues

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
Co-authored-by: Fabrice Normandin <[email protected]>
  • Loading branch information
cmvcordova and lebrice authored Oct 11, 2024
1 parent 402b177 commit 7c34f38
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/examples/profiling.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
<iframe src="https://wandb.ai/cesar-valdez-mcgill-university/ResearchTemplate/reports/Profiling--Vmlldzo5NDI1MjU0" style="border:none;height:1024px;width:100%"></iframe>
117 changes: 117 additions & 0 deletions docs/profiling_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import shutil

import hydra.errors
import pytest
from omegaconf import DictConfig

from project.conftest import ( # noqa: F401
accelerator,
algorithm_config,
algorithm_network_config,
command_line_arguments,
datamodule_config,
devices,
experiment_dictconfig,
num_devices_to_use,
overrides,
)
from project.experiment import setup_experiment
from project.utils.hydra_utils import resolve_dictconfig


@pytest.mark.skipif(not shutil.which("sbatch"), reason="Needs to be run on a SLURM cluster")
@pytest.mark.parametrize(
    # Indirect parametrization: each string below is consumed by the
    # `command_line_arguments` fixture (imported from project.conftest), which
    # shlex-splits it into the argv that Hydra composes into a config.
    "command_line_arguments",
    [
        # Instrumenting your code -baseline
        """
        experiment=profiling \
        algorithm=example \
        trainer.logger.wandb.name="Baseline" \
        trainer.logger.wandb.tags=["Training","Baseline comparison","CPU/GPU comparison"]
        """,
        # Identifying potential bottlenecks - baseline
        """
        experiment=profiling\
        algorithm=no_op\
        trainer.logger.wandb.name="Baseline without training" \
        trainer.logger.wandb.tags=["No training","Baseline comparison"]
        """,
        # Identifying potential bottlenecks - num_workers multirun
        pytest.param(
            """
            -m experiment=profiling \
            algorithm=no_op \
            trainer.logger.wandb.tags=["1 CPU Dataloading","Worker throughput"] \
            datamodule.num_workers=1,4,8,16,32
            """,
            # Hydra's override parser rejects the `-m` (multirun) flag when it is
            # passed through as an override; strict=True makes an unexpected pass fail.
            marks=pytest.mark.xfail(
                reason="LexerNoViableAltException error caused by the -m flag",
                raises=hydra.errors.OverrideParseException,
                strict=True,
            ),
        ),
        # Identifying potential bottlenecks - num_workers multirun
        pytest.param(
            """
            -m experiment=profiling \
            algorithm=no_op \
            resources=cpu \
            trainer.logger.wandb.tags=["2 CPU Dataloading","Worker throughput"] \
            hydra.launcher.timeout_min=60 \
            hydra.launcher.cpus_per_task=2 \
            hydra.launcher.constraint="sapphire" \
            datamodule.num_workers=1,4,8,16,32
            """,
            # Same `-m` parsing limitation as the previous multirun case.
            marks=pytest.mark.xfail(
                reason="LexerNoViableAltException error caused by the -m flag",
                raises=hydra.errors.OverrideParseException,
                strict=True,
            ),
        ),
        # Identifying potential bottlenecks - fcnet mnist
        """
        experiment=profiling \
        algorithm=example \
        algorithm/network=fcnet \
        datamodule=mnist \
        trainer.logger.wandb.name="FcNet/MNIST baseline with training" \
        trainer.logger.wandb.tags=["CPU/GPU comparison","GPU","MNIST"]
        """,
        # Throughput across GPU types
        """
        experiment=profiling \
        algorithm=example \
        resources=one_gpu \
        hydra.launcher.gres='gpu:a100:1' \
        hydra.launcher.cpus_per_task=4 \
        datamodule.num_workers=8 \
        trainer.logger.wandb.name="A100 training" \
        trainer.logger.wandb.tags=["GPU comparison"]
        """,
        # Making the most out of your GPU
        pytest.param(
            """
            -m experiment=profiling \
            algorithm=example \
            datamodule.num_workers=8 \
            datamodule.batch_size=32,64,128,256 \
            trainer.logger.wandb.tags=["Batch size comparison"]\
            '++trainer.logger.wandb.name=Batch size ${datamodule.batch_size}'
            """,
            # Same `-m` parsing limitation as the other multirun cases.
            marks=pytest.mark.xfail(
                reason="LexerNoViableAltException error caused by the -m flag",
                raises=hydra.errors.OverrideParseException,
                strict=True,
            ),
        ),
    ],
    indirect=True,
)
def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig) -> None:  # noqa
    """Check that the profiling commands shown in the docs produce valid experiment configs.

    The test only composes and instantiates the experiment; it does not run any
    training, so it is cheap relative to the commands it validates.
    """
    # check for any errors related to OmegaConf interpolations and such
    config = resolve_dictconfig(experiment_dictconfig)
    # check for any errors when actually instantiating the components.
    _experiment = setup_experiment(config)
    # Note: Here we don't actually do anything with the objects.
4 changes: 1 addition & 3 deletions project/algorithms/no_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
from lightning import Callback, LightningModule
from torch import nn

from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback
from project.utils.typing_utils.protocols import DataModule
Expand All @@ -11,10 +10,9 @@
class NoOp(LightningModule):
"""No-op algorithm that does no learning and is used to benchmark the dataloading speed."""

def __init__(self, datamodule: DataModule, network: nn.Module):
def __init__(self, datamodule: DataModule):
super().__init__()
self.datamodule = datamodule
self.network = network
# Set this so PyTorch-Lightning doesn't try to train the model using our 'loss'
self.automatic_optimization = False

Expand Down
13 changes: 13 additions & 0 deletions project/configs/experiment/profiling.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# @package _global_
# Experiment config used by the profiling docs example: a short, bounded run
# whose purpose is to measure throughput, not to train a model to convergence.

defaults:
  - override /datamodule: imagenet
  - override /algorithm: example
  - override /trainer/logger: wandb  # results are reported via a wandb report

trainer:
  min_epochs: 1
  max_epochs: 2
  # Keep the run short: only a small, fixed number of batches per phase.
  limit_train_batches: 30
  limit_val_batches: 2
  # Skip the sanity-check validation pass so it doesn't pollute the timings.
  num_sanity_val_steps: 0
6 changes: 6 additions & 0 deletions project/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

import operator
import os
import shlex
import sys
import typing
from collections import defaultdict
Expand Down Expand Up @@ -187,6 +188,11 @@ def command_line_arguments(
function so that the respective components are created in the same way as they
would be by Hydra in a regular run.
"""
if param := getattr(request, "param", None):
# If we manually overwrite the command-line arguments with indirect parametrization,
# then ignore the rest of the stuff here and just use the provided command-line args.
# Split the string into a list of command-line arguments if needed.
return shlex.split(param) if isinstance(param, str) else param

combination = set([datamodule_config, algorithm_network_config, algorithm_config])
for configs, marks in default_marks_for_config_combinations.items():
Expand Down

0 comments on commit 7c34f38

Please sign in to comment.