From 37403b9e6915250bb5caed19abb603804ae18081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Wed, 11 Dec 2024 11:24:38 +0545 Subject: [PATCH] torch-loader(example): use prefetch and try to run example in linux --- examples/get_started/torch-loader.py | 3 ++- src/datachain/asyn.py | 2 +- tests/examples/test_examples.py | 9 +-------- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/examples/get_started/torch-loader.py b/examples/get_started/torch-loader.py index 1a26ed36d..c8649720c 100644 --- a/examples/get_started/torch-loader.py +++ b/examples/get_started/torch-loader.py @@ -54,6 +54,7 @@ def forward(self, x): if __name__ == "__main__": ds = ( DataChain.from_storage(STORAGE, type="image") + .settings(prefetch=25, cache=True) .filter(C("file.path").glob("*.jpg")) .map( label=lambda path: label_to_int(basename(path)[:3], CLASSES), @@ -64,7 +65,7 @@ def forward(self, x): train_loader = DataLoader( ds.to_pytorch(transform=transform), - batch_size=16, + batch_size=25, num_workers=2, ) diff --git a/src/datachain/asyn.py b/src/datachain/asyn.py index 1b87afc41..3c27b7063 100644 --- a/src/datachain/asyn.py +++ b/src/datachain/asyn.py @@ -165,7 +165,7 @@ async def _break_iteration(self) -> None: def iterate(self, timeout=None) -> Generator[ResultT, None, None]: init = asyncio.run_coroutine_threadsafe(self.init(), self.loop) - init.result(timeout=1) + init.result() async_run = asyncio.run_coroutine_threadsafe(self.run(), self.loop) try: while True: diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 701d1307f..d3735f364 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -6,14 +6,7 @@ import pytest -get_started_examples = sorted( - [ - filename - for filename in glob.glob("examples/get_started/**/*.py", recursive=True) - # torch-loader will not finish within an hour on Linux runner - if "torch" not in filename or os.environ.get("RUNNER_OS") != "Linux" - ] -) +get_started_examples = sorted(glob.glob("examples/get_started/**/*.py", recursive=True)) llm_and_nlp_examples = sorted(glob.glob("examples/llm_and_nlp/**/*.py", recursive=True))