From 47381fbedc929972e2eb0a6ec72b72d043a8c85b Mon Sep 17 00:00:00 2001 From: Jiri Petrlik Date: Fri, 1 Mar 2024 16:01:30 +0100 Subject: [PATCH] RHOAIENG-3771 - Reduce execution time of E2E tests --- tests/e2e/mnist.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/e2e/mnist.py b/tests/e2e/mnist.py index a99589659..2971d9c98 100644 --- a/tests/e2e/mnist.py +++ b/tests/e2e/mnist.py @@ -19,7 +19,7 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader, random_split, RandomSampler from torchmetrics import Accuracy from torchvision import transforms from torchvision.datasets import MNIST @@ -127,7 +127,11 @@ def setup(self, stage=None): ) def train_dataloader(self): - return DataLoader(self.mnist_train, batch_size=BATCH_SIZE) + return DataLoader( + self.mnist_train, + batch_size=BATCH_SIZE, + sampler=RandomSampler(self.mnist_train, num_samples=1000), + ) def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=BATCH_SIZE) @@ -147,10 +151,11 @@ def test_dataloader(self): trainer = Trainer( accelerator="auto", # devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs - max_epochs=5, + max_epochs=3, callbacks=[TQDMProgressBar(refresh_rate=20)], num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)), devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)), + replace_sampler_ddp=False, strategy="ddp", )