From 9deb4c16f37fdca663f9600781fbdd88cb022474 Mon Sep 17 00:00:00 2001 From: Danylo Baibak Date: Fri, 26 Jan 2024 01:12:03 -0800 Subject: [PATCH 1/3] Forward fix / Update dill_available API for torchdata (#1222) Summary: Pull Request resolved: https://github.com/pytorch/data/pull/1222 Changes from the PyTorch repo (D53082622) broke torchdata. I updated the dill_available API for torchdata to keep everything in sync. Reviewed By: atalman, ejguan Differential Revision: D53086369 fbshipit-source-id: 7344c4cd3205a38689722330721257b5a01bd32f --- test/test_serialization.py | 8 ++++---- torchdata/datapipes/iter/util/cacheholder.py | 5 +++-- torchdata/datapipes/iter/util/converter.py | 10 ++++++---- torchdata/datapipes/iter/util/dataframemaker.py | 8 ++++---- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index bc0248cfe..cb786cede 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -17,11 +17,11 @@ import torchdata.datapipes.iter as iterdp import torchdata.datapipes.map as mapdp from _utils._common_utils_for_test import create_temp_dir, create_temp_files -from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE +from torch.utils._import_utils import dill_available from torchdata.datapipes.iter import IterableWrapper from torchdata.datapipes.map import SequenceWrapper -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) @@ -87,7 +87,7 @@ def _filter_by_module_availability(datapipes): filter_set.update([iterdp.IoPathFileLister, iterdp.IoPathFileOpener, iterdp.IoPathSaver]) if rarfile is None: filter_set.update([iterdp.RarArchiveLoader]) - if torcharrow is None or not DILL_AVAILABLE: + if torcharrow is None or not dill_available(): filter_set.update([iterdp.DataFrameMaker, iterdp.ParquetDataFrameLoader]) return [dp for dp in datapipes if dp[0] not in filter_set] @@ -374,7 +374,7 @@ def test_serializable_with_dill(self): # Skipping value comparison for these DataPipes dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator} for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: - if DILL_AVAILABLE: + if dill_available(): try: if dpipe in dp_skip_comparison: # Make sure they are picklable/loadable (no value comparison) datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] diff --git a/torchdata/datapipes/iter/util/cacheholder.py b/torchdata/datapipes/iter/util/cacheholder.py index f12a6a38c..5ddcf13cd 100644 --- a/torchdata/datapipes/iter/util/cacheholder.py +++ b/torchdata/datapipes/iter/util/cacheholder.py @@ -21,13 +21,14 @@ except ImportError: portalocker = None -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE +from torch.utils._import_utils import dill_available +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from torch.utils.data.graph import traverse_dps from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index 0721e1741..1f5e7c9ef 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -8,10 +8,12 @@ from typing import Callable, Dict, Optional +from torch.utils._import_utils import dill_available + from torch.utils.data import IterDataPipe, MapDataPipe -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) @@ -108,7 +110,7 @@ def __len__(self): return len(self._map) # type: ignore[arg-type] def __getstate__(self): - if DILL_AVAILABLE: + if dill_available(): dill_key_value_fn = dill.dumps(self.key_value_fn) else: dill_key_value_fn = self.key_value_fn @@ -120,7 +122,7 @@ def __getstate__(self): def __setstate__(self, state): (self.datapipe, dill_key_value_fn, self._map) = state - if DILL_AVAILABLE: + if dill_available(): self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment] else: self.key_value_fn = dill_key_value_fn # type: ignore[assignment] diff --git a/torchdata/datapipes/iter/util/dataframemaker.py b/torchdata/datapipes/iter/util/dataframemaker.py index 5e24e8496..a7e5c27b0 100644 --- a/torchdata/datapipes/iter/util/dataframemaker.py +++ b/torchdata/datapipes/iter/util/dataframemaker.py @@ -7,7 +7,7 @@ from functools import partial from typing import List, Optional, TypeVar -from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE +from torch.utils._import_utils import dill_available from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -19,7 +19,7 @@ torcharrow = None parquet = None -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) @@ -150,7 +150,7 @@ def __iter__(self): yield torcharrow.from_arrow(row_group, dtype=self.dtype) def __getstate__(self): - if DILL_AVAILABLE: + if dill_available(): dill_dtype = dill.dumps(self.dtype) else: dill_dtype = self.dtype @@ -161,7 +161,7 @@ def __getstate__(self): def __setstate__(self, state): (self.source_dp, dill_dtype, self.columns, self.device, self.use_threads) = state - if DILL_AVAILABLE: + if dill_available(): self.dtype = dill.loads(dill_dtype) # type: ignore[assignment] else: self.dtype = dill_dtype # type: ignore[assignment] From 2a8818fbe24e1a5d961f6554e5e05f3cd6e76d24 Mon Sep 17 00:00:00 2001 From: Yingxin Kang Date: Mon, 26 Feb 2024 10:50:53 -0800 Subject: [PATCH 2/3] Data loader directly use loaded state dict when save right after load Summary: If no iterator is created in the middle of load_state_dict and state_dict calls, we should be able to directly return the original state dict without triggering reading service because the states are unchanged Reviewed By: xunnanxu Differential Revision: D54102267 fbshipit-source-id: b7402975b0f871d58a7b6452e29dc1e029733a9b --- test/dataloader2/test_dataloader2.py | 2 ++ torchdata/dataloader2/dataloader2.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 2c28b0169..a438b1ead 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -136,6 +136,8 @@ def test_dataloader2_load_state_dict(self) -> None: restored_data_loader: DataLoader2 = DataLoader2(datapipe=None, reading_service=reading_service) restored_data_loader.load_state_dict(state) + new_state = restored_data_loader.state_dict() + self.assertDictEqual(state, new_state) restored_data_loader_datapipe = restored_data_loader.datapipe deserialized_datapipe = pickle.loads(state[SERIALIZED_DATAPIPE_KEY_NAME]) diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py index f8844a14a..a8767a093 100644 --- a/torchdata/dataloader2/dataloader2.py +++ b/torchdata/dataloader2/dataloader2.py @@ -193,6 +193,7 @@ def __init__( self._reset_seed: bool = True # Seed generator as of beginning of each epoch self._initial_seed_generator: SeedGenerator = clone(self._seed_generator) + self._state_dict: Optional[Dict[str, Any]] = None def __iter__(self) -> DataLoader2Iterator[T_co]: r""" @@ -283,6 +284,13 @@ def state_dict(self) -> Dict[str, Any]: - ``serialized_datapipe``:Serialized ``DataPipe`` before ``ReadingService`` adaption. - ``reading_service_state``: The state of ``ReadingService`` and adapted ``DataPipe``. """ + + # If state_dict is called right after load_state_dict calls, without iterator created in the middle, + # we should directly return the original state dict without triggering reading_service.checkpoint + # because the states are unchanged + if self.valid_iterator_id is None and self._state_dict is not None: + return self._state_dict + reading_service_state = None if self.reading_service is not None and isinstance(self.reading_service, CheckpointableReadingServiceInterface): reading_service_state = self.reading_service.checkpoint() @@ -329,6 +337,8 @@ def from_state( data_loader._seed_generator = pickle.loads(randomness_state[2]) data_loader._initial_seed_generator = pickle.loads(randomness_state[3]) + data_loader._state_dict = state + return data_loader def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -344,6 +354,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: "Please create a new dataloader in order to use load state dict." ) + self._state_dict = state_dict + serialized_datapipe = state_dict[SERIALIZED_DATAPIPE_KEY_NAME] reading_service_state = state_dict[READING_SERVICE_STATE_KEY_NAME] From 7ef198ac165998fd74852b9d4bae69d4d6dfb66a Mon Sep 17 00:00:00 2001 From: DanilBaibak Date: Tue, 5 Mar 2024 01:05:57 -0800 Subject: [PATCH 3/3] Migrate the macOS runners label from macos-m1-12 to macos-m1-stable (#1224) Summary: There is a new label for our macOS runners: "macos-m1-stable". All runners labeled "macos-m1-12" should be switched to "macos-m1-stable". [Here](https://fb.workplace.com/groups/pytorch.dev.perf.infra.teams/permalink/7546708885348237/) you can find more detailed info. Pull Request resolved: https://github.com/pytorch/data/pull/1224 Reviewed By: huydhn, ejguan Differential Revision: D53476591 Pulled By: DanilBaibak fbshipit-source-id: 5e25c8867425bbf09cae83d80b42a1e89cb68332 --- .github/workflows/build-conda-m1.yml | 2 +- .github/workflows/build-wheels-m1.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-conda-m1.yml b/.github/workflows/build-conda-m1.yml index f9c337b44..1f014acb9 100644 --- a/.github/workflows/build-conda-m1.yml +++ b/.github/workflows/build-conda-m1.yml @@ -37,7 +37,7 @@ jobs: package-name: torchdata env-var-script: packaging/env-var-script.txt smoke-test-script: test/smoke_test/smoke_test.py - runner-type: macos-m1-12 + runner-type: macos-m1-stable trigger-event: ${{ github.event_name }} secrets: CONDA_PYTORCHBOT_TOKEN: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index d75f0dac0..05be60b51 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -39,6 +39,6 @@ jobs: post-script: "" package-name: torchdata env-var-script: packaging/env-var-script.txt - runner-type: macos-m1-12 + runner-type: macos-m1-stable smoke-test-script: test/smoke_test/smoke_test.py trigger-event: ${{ github.event_name }}