Skip to content

Commit

Permalink
RHOAIENG-3771 - Reduce execution time of E2E tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jiripetrlik authored and openshift-merge-bot[bot] committed Mar 4, 2024
1 parent 0beece1 commit 47381fb
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/e2e/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torch.utils.data import DataLoader, random_split, RandomSampler
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
Expand Down Expand Up @@ -127,7 +127,11 @@ def setup(self, stage=None):
)

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
return DataLoader(
self.mnist_train,
batch_size=BATCH_SIZE,
sampler=RandomSampler(self.mnist_train, num_samples=1000),
)

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
Expand All @@ -147,10 +151,11 @@ def test_dataloader(self):
trainer = Trainer(
accelerator="auto",
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
max_epochs=5,
max_epochs=3,
callbacks=[TQDMProgressBar(refresh_rate=20)],
num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
replace_sampler_ddp=False,
strategy="ddp",
)

Expand Down

0 comments on commit 47381fb

Please sign in to comment.