diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index fbf5cc36..ea01f21e 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -1,7 +1,12 @@ from typing import Iterable, Optional, Union from torch.utils.data import Dataset, DistributedSampler, Sampler -from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t +from torch.utils.data.dataloader import DataLoader, _collate_fn_t, _worker_init_fn_t + +try: # torch <= 2.4 + from torch.utils.data.dataloader import T_co +except ImportError: # torch >= 2.5 + from torch.utils.data.dataloader import _T_co as T_co from modalities.dataloader.samplers import ResumableBatchSampler