diff --git a/examples/get_started/torch-loader.py b/examples/get_started/torch-loader.py index 1a26ed36d..39ca670e7 100644 --- a/examples/get_started/torch-loader.py +++ b/examples/get_started/torch-loader.py @@ -9,7 +9,7 @@ from posixpath import basename import torch -from torch import nn, optim +from torch import multiprocessing, nn, optim from torch.utils.data import DataLoader from torchvision.transforms import v2 @@ -54,6 +54,7 @@ def forward(self, x): if __name__ == "__main__": ds = ( DataChain.from_storage(STORAGE, type="image") + .settings(cache=True, prefetch=25) .filter(C("file.path").glob("*.jpg")) .map( label=lambda path: label_to_int(basename(path)[:3], CLASSES), @@ -64,8 +65,9 @@ def forward(self, x): train_loader = DataLoader( ds.to_pytorch(transform=transform), - batch_size=16, - num_workers=2, + batch_size=25, + num_workers=4, + multiprocessing_context=multiprocessing.get_context("spawn"), ) model = CNN() diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 701d1307f..d3735f364 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -6,14 +6,7 @@ import pytest -get_started_examples = sorted( - [ - filename - for filename in glob.glob("examples/get_started/**/*.py", recursive=True) - # torch-loader will not finish within an hour on Linux runner - if "torch" not in filename or os.environ.get("RUNNER_OS") != "Linux" - ] -) +get_started_examples = sorted(glob.glob("examples/get_started/**/*.py", recursive=True)) llm_and_nlp_examples = sorted(glob.glob("examples/llm_and_nlp/**/*.py", recursive=True))