make torchdata default imports lazy #1243

Closed · wants to merge 10 commits
1 change: 1 addition & 0 deletions test/stateful_dataloader/test_dataloader.py
@@ -1638,6 +1638,7 @@ def test_multiprocessing_contexts(self):
list(self._get_data_loader(ds_cls(counting_ds_n), multiprocessing_context=ctx, **dl_common_args)),
)

@unittest.skipIf(IS_MACOS, "Not working on macos")
Contributor Author

See test_multiprocessing_contexts() above on line 1614; a similar non-datapipe test is already skipped on macOS, so this test follows the same approach.

@andrewkho (Contributor Author) · Apr 25, 2024

An alternative way to make this test pass is to skip ctx == "fork" or to import torchdata.datapipes locally, even though the test has absolutely no dependency on it; this is probably related to https://www.wefearchange.org/2018/11/forkmacos.rst.html. A rough sketch of both options follows the diff hunk below.

def _test_multiprocessing_iterdatapipe(self, with_dill):
# Testing to make sure that function from global scope (e.g. imported from library) can be serialized
# and used with multiprocess DataLoader
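For illustration, a minimal sketch of the two alternatives mentioned in the review comment above. The loop over start methods, the test class name, and the IS_MACOS flag are assumptions about the test's shape, not code from this PR:

```python
import sys
import unittest

IS_MACOS = sys.platform == "darwin"


class MultiprocessingContextSketch(unittest.TestCase):
    def test_contexts(self):
        # Hypothetical loop over start methods; the real test body is not
        # shown in this diff.
        for ctx in ("spawn", "forkserver", "fork"):
            if IS_MACOS and ctx == "fork":
                # Option 1: skip the fork start method on macOS, where
                # forking after threads/Objective-C runtime initialization
                # is unreliable (see the linked wefearchange.org post).
                continue
            # Option 2: import torchdata.datapipes in the parent process
            # before workers are created, even though the test itself does
            # not use it.
            import torchdata.datapipes  # noqa: F401
```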
12 changes: 12 additions & 0 deletions test/stateful_dataloader/test_state_dict.py
@@ -778,5 +778,17 @@ def test_num_workers_mismatch(self):
self.assertTrue(False, "Error should be of type AssertionError")


class TestTorchDataLazyImport(unittest.TestCase):
def test_lazy_imports(self) -> None:
import torchdata

self.assertFalse("datapipes" in torchdata.__dict__)

from torchdata import datapipes as dp, janitor # noqa

self.assertTrue("datapipes" in torchdata.__dict__)
dp.iter.IterableWrapper([1, 2])


if __name__ == "__main__":
unittest.main()
20 changes: 15 additions & 5 deletions torchdata/__init__.py
@@ -4,11 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchdata import _extension # noqa: F401

from . import datapipes

janitor = datapipes.utils.janitor
import importlib

try:
from .version import __version__ # noqa: F401
@@ -22,3 +18,17 @@

# Please keep this list sorted
assert __all__ == sorted(__all__)


# Lazy import all modules
def __getattr__(name):
if name == "janitor":
return importlib.import_module(".datapipes.utils." + name, __name__)
else:
try:
return importlib.import_module("." + name, __name__)
except ModuleNotFoundError:
if name in globals():
return globals()[name]
else:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
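For context, the lazy loading above relies on module-level `__getattr__` (PEP 562): it is called only for names not already bound on the module, so each submodule is imported on first attribute access. A minimal standalone sketch of the same pattern, using a hypothetical package name (`mypkg` is illustrative, not part of this PR):

```python
# mypkg/__init__.py -- illustrative package, not torchdata itself.
import importlib


def __getattr__(name):
    # Invoked only when `name` is not found through normal module lookup
    # (PEP 562), so submodules are imported lazily on first access.
    try:
        return importlib.import_module("." + name, __name__)
    except ModuleNotFoundError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
```

With this in place, `import mypkg` stays cheap, and accessing `mypkg.some_submodule` triggers the submodule import the first time; the import machinery then binds the submodule on the parent package, which is the behavior the new `test_lazy_imports` test checks via `torchdata.__dict__`.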
2 changes: 2 additions & 0 deletions torchdata/_torchdata/__init__.pyi
@@ -6,6 +6,8 @@

from typing import List

from torchdata import _extension # noqa: F401

# TODO: Add pyi generate script
class S3Handler:
def __init__(self, request_timeout_ms: int, region: str) -> None: ...
1 change: 1 addition & 0 deletions torchdata/dataloader2/__init__.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.


from torchdata import _extension # noqa: F401
from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.reading_service import (
2 changes: 2 additions & 0 deletions torchdata/datapipes/__init__.py
@@ -6,6 +6,8 @@

from torch.utils.data import DataChunk, functional_datapipe

from torchdata import _extension # noqa: F401

from . import iter, map, utils

__all__ = ["DataChunk", "functional_datapipe", "iter", "map", "utils"]