Skip to content

Commit

Permalink
Merge pull request #266 from Modalities/fix/tests
Browse files Browse the repository at this point in the history
fix: towards torch 2.5 compatibility / fix unit tests (github actions)
  • Loading branch information
le1nux authored Nov 12, 2024
2 parents 7f85f33 + da5862e commit fabe182
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
sudo apt-get update
sudo apt-get install curl -y # required by coveralls
sudo apt-get install git -y
python -m pip install torch
python -m pip install torch~=2.4.1
python -m pip install --upgrade pip setuptools wheel
export FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
python -m pip install -e .[tests]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Currently, the flash attention dependency cannot be installed without torch bein
Until the flash attention developers fix this, we have to run

```sh
pip install torch
pip install torch~=2.4.1
```
beforehand.

Expand All @@ -75,7 +75,7 @@ pip install -e .
To install Modalities via pip, run

```sh
pip install torch
pip install torch~=2.4.1
pip install modalities
```

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ description = "Modalities, a PyTorch-native framework for distributed and reprod
readme = "README.md"
dependencies = [
"numpy<2.0",
"torch>=2.3",
"torch~=2.4.1",
"packaging",
"tqdm",
"pyyaml",
Expand Down
7 changes: 6 additions & 1 deletion src/modalities/dataloader/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Iterable, Optional, Union

from torch.utils.data import Dataset, DistributedSampler, Sampler
from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t
from torch.utils.data.dataloader import DataLoader, _collate_fn_t, _worker_init_fn_t

try: # torch <= 2.4
from torch.utils.data.dataloader import T_co
except ImportError: # torch >= 2.5
from torch.utils.data.dataloader import _T_co as T_co

from modalities.dataloader.samplers import ResumableBatchSampler

Expand Down

0 comments on commit fabe182

Please sign in to comment.