diff --git a/.github/workflows/build-conda-m1.yml b/.github/workflows/build-conda-m1.yml index f9c337b44..1f014acb9 100644 --- a/.github/workflows/build-conda-m1.yml +++ b/.github/workflows/build-conda-m1.yml @@ -37,7 +37,7 @@ jobs: package-name: torchdata env-var-script: packaging/env-var-script.txt smoke-test-script: test/smoke_test/smoke_test.py - runner-type: macos-m1-12 + runner-type: macos-m1-stable trigger-event: ${{ github.event_name }} secrets: CONDA_PYTORCHBOT_TOKEN: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index d75f0dac0..05be60b51 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -39,6 +39,6 @@ jobs: post-script: "" package-name: torchdata env-var-script: packaging/env-var-script.txt - runner-type: macos-m1-12 + runner-type: macos-m1-stable smoke-test-script: test/smoke_test/smoke_test.py trigger-event: ${{ github.event_name }} diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 2c28b0169..a438b1ead 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -136,6 +136,8 @@ def test_dataloader2_load_state_dict(self) -> None: restored_data_loader: DataLoader2 = DataLoader2(datapipe=None, reading_service=reading_service) restored_data_loader.load_state_dict(state) + new_state = restored_data_loader.state_dict() + self.assertDictEqual(state, new_state) restored_data_loader_datapipe = restored_data_loader.datapipe deserialized_datapipe = pickle.loads(state[SERIALIZED_DATAPIPE_KEY_NAME]) diff --git a/test/test_serialization.py b/test/test_serialization.py index bc0248cfe..cb786cede 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -17,11 +17,11 @@ import torchdata.datapipes.iter as iterdp import torchdata.datapipes.map as mapdp from _utils._common_utils_for_test import create_temp_dir, create_temp_files -from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE +from torch.utils._import_utils import dill_available from torchdata.datapipes.iter import IterableWrapper from torchdata.datapipes.map import SequenceWrapper -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) @@ -87,7 +87,7 @@ def _filter_by_module_availability(datapipes): filter_set.update([iterdp.IoPathFileLister, iterdp.IoPathFileOpener, iterdp.IoPathSaver]) if rarfile is None: filter_set.update([iterdp.RarArchiveLoader]) - if torcharrow is None or not DILL_AVAILABLE: + if torcharrow is None or not dill_available(): filter_set.update([iterdp.DataFrameMaker, iterdp.ParquetDataFrameLoader]) return [dp for dp in datapipes if dp[0] not in filter_set] @@ -374,7 +374,7 @@ def test_serializable_with_dill(self): # Skipping value comparison for these DataPipes dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator} for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: - if DILL_AVAILABLE: + if dill_available(): try: if dpipe in dp_skip_comparison: # Make sure they are picklable/loadable (no value comparison) datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py index f8844a14a..a8767a093 100644 --- a/torchdata/dataloader2/dataloader2.py +++ b/torchdata/dataloader2/dataloader2.py @@ -193,6 +193,7 @@ def __init__( self._reset_seed: bool = True # Seed generator as of beginning of each epoch self._initial_seed_generator: SeedGenerator = clone(self._seed_generator) + self._state_dict: Optional[Dict[str, Any]] = None def __iter__(self) -> DataLoader2Iterator[T_co]: r""" @@ -283,6 +284,13 @@ def state_dict(self) -> Dict[str, Any]: - ``serialized_datapipe``:Serialized ``DataPipe`` before ``ReadingService`` adaption. - ``reading_service_state``: The state of ``ReadingService`` and adapted ``DataPipe``. """ + + # If state_dict is called right after load_state_dict calls, without iterator created in the middle, + # we should directly return the original state dict without triggering reading_service.checkpoint + # because the states are unchanged + if self.valid_iterator_id is None and self._state_dict is not None: + return self._state_dict + reading_service_state = None if self.reading_service is not None and isinstance(self.reading_service, CheckpointableReadingServiceInterface): reading_service_state = self.reading_service.checkpoint() @@ -329,6 +337,8 @@ def from_state( data_loader._seed_generator = pickle.loads(randomness_state[2]) data_loader._initial_seed_generator = pickle.loads(randomness_state[3]) + data_loader._state_dict = state + return data_loader def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -344,6 +354,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: "Please create a new dataloader in order to use load state dict." ) + self._state_dict = state_dict + serialized_datapipe = state_dict[SERIALIZED_DATAPIPE_KEY_NAME] reading_service_state = state_dict[READING_SERVICE_STATE_KEY_NAME] diff --git a/torchdata/datapipes/iter/util/cacheholder.py b/torchdata/datapipes/iter/util/cacheholder.py index f12a6a38c..5ddcf13cd 100644 --- a/torchdata/datapipes/iter/util/cacheholder.py +++ b/torchdata/datapipes/iter/util/cacheholder.py @@ -21,13 +21,14 @@ except ImportError: portalocker = None -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE +from torch.utils._import_utils import dill_available +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from torch.utils.data.graph import traverse_dps from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index 0721e1741..1f5e7c9ef 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -8,10 +8,12 @@ from typing import Callable, Dict, Optional +from torch.utils._import_utils import dill_available + from torch.utils.data import IterDataPipe, MapDataPipe -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) @@ -108,7 +110,7 @@ def __len__(self): return len(self._map) # type: ignore[arg-type] def __getstate__(self): - if DILL_AVAILABLE: + if dill_available(): dill_key_value_fn = dill.dumps(self.key_value_fn) else: dill_key_value_fn = self.key_value_fn @@ -120,7 +122,7 @@ def __getstate__(self): def __setstate__(self, state): (self.datapipe, dill_key_value_fn, self._map) = state - if DILL_AVAILABLE: + if dill_available(): self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment] else: self.key_value_fn = dill_key_value_fn # type: ignore[assignment] diff --git a/torchdata/datapipes/iter/util/dataframemaker.py b/torchdata/datapipes/iter/util/dataframemaker.py index 5e24e8496..a7e5c27b0 100644 --- a/torchdata/datapipes/iter/util/dataframemaker.py +++ b/torchdata/datapipes/iter/util/dataframemaker.py @@ -7,7 +7,7 @@ from functools import partial from typing import List, Optional, TypeVar -from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE +from torch.utils._import_utils import dill_available from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -19,7 +19,7 @@ torcharrow = None parquet = None -if DILL_AVAILABLE: +if dill_available(): import dill dill.extend(use_dill=False) @@ -150,7 +150,7 @@ def __iter__(self): yield torcharrow.from_arrow(row_group, dtype=self.dtype) def __getstate__(self): - if DILL_AVAILABLE: + if dill_available(): dill_dtype = dill.dumps(self.dtype) else: dill_dtype = self.dtype @@ -161,7 +161,7 @@ def __getstate__(self): def __setstate__(self, state): (self.source_dp, dill_dtype, self.columns, self.device, self.use_threads) = state - if DILL_AVAILABLE: + if dill_available(): self.dtype = dill.loads(dill_dtype) # type: ignore[assignment] else: self.dtype = dill_dtype # type: ignore[assignment]