Skip to content

Commit

Permalink
make torchdata default imports lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed Apr 23, 2024
1 parent 6355127 commit 90514ff
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
12 changes: 12 additions & 0 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,5 +778,17 @@ def test_num_workers_mismatch(self):
self.assertTrue(False, "Error should be of type AssertionError")


class TestTorchDataLazyImport(unittest.TestCase):
def test_lazy_imports(self) -> None:
import torchdata

self.assertFalse("datapipes" in torchdata.__dict__)

from torchdata import _extension, datapipes as dp, janitor # noqa # noqa

self.assertTrue("datapipes" in torchdata.__dict__)
dp.iter.IterableWrapper([1, 2])


if __name__ == "__main__":
unittest.main()
20 changes: 15 additions & 5 deletions torchdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchdata import _extension # noqa: F401

from . import datapipes

janitor = datapipes.utils.janitor
import importlib

try:
from .version import __version__ # noqa: F401
Expand All @@ -22,3 +18,17 @@

# Please keep this list sorted
assert __all__ == sorted(__all__)


# Lazy import these modules
def __getattr__(name):
if name == "janitor":
return importlib.import_module(".datapipes.utils." + name, __name__)
else:
try:
return importlib.import_module("." + name, __name__)
except ModuleNotFoundError:
if name in globals():
return globals()[name]
else:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None

0 comments on commit 90514ff

Please sign in to comment.