diff --git a/test/stateful_dataloader/test_dataloader.py b/test/stateful_dataloader/test_dataloader.py
index 347241720..181a91f74 100644
--- a/test/stateful_dataloader/test_dataloader.py
+++ b/test/stateful_dataloader/test_dataloader.py
@@ -1638,6 +1638,7 @@ def test_multiprocessing_contexts(self):
                 list(self._get_data_loader(ds_cls(counting_ds_n), multiprocessing_context=ctx, **dl_common_args)),
             )

+    @unittest.skipIf(IS_MACOS, "Not working on macos")
     def _test_multiprocessing_iterdatapipe(self, with_dill):
         # Testing to make sure that function from global scope (e.g. imported from library) can be serialized
         # and used with multiprocess DataLoader
diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py
index 4fbf57fef..9a830945a 100644
--- a/test/stateful_dataloader/test_state_dict.py
+++ b/test/stateful_dataloader/test_state_dict.py
@@ -778,5 +778,17 @@ def test_num_workers_mismatch(self):
             self.assertTrue(False, "Error should be of type AssertionError")


+class TestTorchDataLazyImport(unittest.TestCase):
+    def test_lazy_imports(self) -> None:
+        import torchdata
+
+        self.assertFalse("datapipes" in torchdata.__dict__)
+
+        from torchdata import datapipes as dp, janitor  # noqa
+
+        self.assertTrue("datapipes" in torchdata.__dict__)
+        dp.iter.IterableWrapper([1, 2])
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchdata/__init__.py b/torchdata/__init__.py
index 43aa7b7b0..0d3f45970 100644
--- a/torchdata/__init__.py
+++ b/torchdata/__init__.py
@@ -4,11 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchdata import _extension  # noqa: F401
-
-from . import datapipes
-
-janitor = datapipes.utils.janitor
+import importlib

 try:
     from .version import __version__  # noqa: F401
@@ -22,3 +18,17 @@

 # Please keep this list sorted
 assert __all__ == sorted(__all__)
+
+
+# Lazy import all modules
+def __getattr__(name):
+    if name == "janitor":
+        return importlib.import_module(".datapipes.utils." + name, __name__)
+    else:
+        try:
+            return importlib.import_module("." + name, __name__)
+        except ModuleNotFoundError:
+            if name in globals():
+                return globals()[name]
+            else:
+                raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/torchdata/_torchdata/__init__.pyi b/torchdata/_torchdata/__init__.pyi
index a22edcc8a..fcea8c6c5 100644
--- a/torchdata/_torchdata/__init__.pyi
+++ b/torchdata/_torchdata/__init__.pyi
@@ -6,6 +6,8 @@

 from typing import List

+from torchdata import _extension  # noqa: F401
+
 # TODO: Add pyi generate script
 class S3Handler:
     def __init__(self, request_timeout_ms: int, region: str) -> None: ...
diff --git a/torchdata/dataloader2/__init__.py b/torchdata/dataloader2/__init__.py
index aaaae491e..df2f92e85 100644
--- a/torchdata/dataloader2/__init__.py
+++ b/torchdata/dataloader2/__init__.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.


+from torchdata import _extension  # noqa: F401
 from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator
 from torchdata.dataloader2.error import PauseIteration
 from torchdata.dataloader2.reading_service import (
diff --git a/torchdata/datapipes/__init__.py b/torchdata/datapipes/__init__.py
index 30360840b..754e95b8b 100644
--- a/torchdata/datapipes/__init__.py
+++ b/torchdata/datapipes/__init__.py
@@ -6,6 +6,8 @@

 from torch.utils.data import DataChunk, functional_datapipe

+from torchdata import _extension  # noqa: F401
+
 from . import iter, map, utils

 __all__ = ["DataChunk", "functional_datapipe", "iter", "map", "utils"]
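
Note on the lazy-import change: the new module-level __getattr__ added to torchdata/__init__.py is the PEP 562 hook, which Python only calls when an attribute is not already present in the module namespace, so submodules such as torchdata.datapipes are imported on first access rather than at "import torchdata" time. Below is a minimal, self-contained sketch of the same pattern; the package name "mypkg" is hypothetical and this is not the torchdata source itself.

# mypkg/__init__.py -- minimal sketch of PEP 562 lazy submodule imports.
# "import mypkg" does not import mypkg.datapipes; the submodule is loaded
# the first time mypkg.datapipes (or "from mypkg import datapipes") is accessed.
import importlib


def __getattr__(name):
    # Invoked only when normal lookup on the module namespace fails,
    # i.e. the requested submodule has not been imported yet.
    try:
        return importlib.import_module("." + name, __name__)
    except ModuleNotFoundError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None

This is the behavior the new TestTorchDataLazyImport.test_lazy_imports case asserts: "datapipes" is absent from torchdata.__dict__ right after "import torchdata" and appears only after the first "from torchdata import datapipes".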