diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 4fbf57fef..16e32f0c6 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -778,5 +778,17 @@ def test_num_workers_mismatch(self): self.assertTrue(False, "Error should be of type AssertionError") +class TestTorchDataLazyImport(unittest.TestCase): + def test_lazy_imports(self) -> None: + import torchdata + + self.assertFalse("datapipes" in torchdata.__dict__) + + from torchdata import _extension, datapipes as dp, janitor # noqa # noqa + + self.assertTrue("datapipes" in torchdata.__dict__) + dp.iter.IterableWrapper([1, 2]) + + if __name__ == "__main__": unittest.main() diff --git a/torchdata/__init__.py b/torchdata/__init__.py index 43aa7b7b0..a50e48ab8 100644 --- a/torchdata/__init__.py +++ b/torchdata/__init__.py @@ -4,11 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchdata import _extension # noqa: F401 - -from . import datapipes - -janitor = datapipes.utils.janitor +import importlib try: from .version import __version__ # noqa: F401 @@ -22,3 +18,17 @@ # Please keep this list sorted assert __all__ == sorted(__all__) + + +# Lazy import these modules +def __getattr__(name): + if name == "janitor": + return importlib.import_module(".datapipes.utils." + name, __name__) + else: + try: + return importlib.import_module("." + name, __name__) + except ModuleNotFoundError: + if name in globals(): + return globals()[name] + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None