Skip to content

Commit

Permalink
Downgrade python requirement from 3.12 --> 3.10 (#36)
Browse files Browse the repository at this point in the history
* Reduce requirement to python 3.10

Signed-off-by: Fabrice Normandin <[email protected]>

* Update devcontainer and github workflows for 3.10

Signed-off-by: Fabrice Normandin <[email protected]>

* Update dependencies and lockfile

Signed-off-by: Fabrice Normandin <[email protected]>

* Create a separate dep group for docs

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix bug with import in generate_reference_docs.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Add missing mktestdocs deps in dev group

Signed-off-by: Fabrice Normandin <[email protected]>

* Skip tests with missing files instead of passing

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice authored Aug 8, 2024
1 parent 264b5a1 commit 1241d58
Show file tree
Hide file tree
Showing 18 changed files with 1,374 additions and 1,240 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"name": "Research Template",
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye",
"image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye",
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {
"ghcr.io/devcontainers-contrib/features/pdm:2": {},
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: '3.12'
python-version: '3.10'
- run: pip install pre-commit
- run: pre-commit --version
- run: pre-commit install
Expand All @@ -41,7 +41,7 @@ jobs:
max-parallel: 4
matrix:
platform: [ubuntu-latest]
python-version: ['3.12']
python-version: ['3.10']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
strategy:
max-parallel: 1
matrix:
python-version: ['3.12']
python-version: ['3.10']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -132,16 +132,16 @@ jobs:
cluster: ['mila']
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.12
python-version: "3.10"
- run: pip install pdm
- name: Install dependencies
run: pdm install

- name: Test with pytest
run: pdm run pytest -v --cov=project --cov-report=xml --cov-append --gen-missing
run: pdm run pytest -v --cov=project --cov-report=xml --cov-append --skip-if-files-missing

# TODO: Re-enable this later
# - name: Test with pytest (only slow tests)
Expand Down
3 changes: 1 addition & 2 deletions docs/generate_reference_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from logging import getLogger as get_logger
from pathlib import Path

import mkdocs_gen_files

from project.utils.env_vars import REPO_ROOTDIR

logger = get_logger(__name__)
Expand All @@ -25,6 +23,7 @@ def add_doc_for_module(module_path: Path) -> None:
- [ ] We don't currently see the docs from the docstrings of __init__.py files.
- [ ] Might be nice to show the config files also?
"""
import mkdocs_gen_files

assert module_path.is_dir() # and (module_path / "__init__.py").exists(), module_path

Expand Down
2,434 changes: 1,265 additions & 1,169 deletions pdm.lock

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions project/algorithms/callbacks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
from collections.abc import Mapping
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, Generic, Literal, override
from typing import Any, Generic, Literal

import torch
from lightning import LightningModule, Trainer
from lightning import pytorch as pl
from typing_extensions import TypeVar
from typing_extensions import TypeVar, override

from project.utils.types import PyTree
from project.utils.types import NestedMapping
from project.utils.utils import get_log_dir

logger = get_logger(__name__)

BatchType = TypeVar("BatchType", bound=PyTree[torch.Tensor], contravariant=True)
BatchType = TypeVar(
"BatchType",
bound=torch.Tensor | tuple[torch.Tensor, ...] | NestedMapping[str, torch.Tensor],
contravariant=True,
)
StepOutputType = TypeVar(
"StepOutputType",
bound=torch.Tensor | Mapping[str, Any] | None,
Expand Down
3 changes: 2 additions & 1 deletion project/algorithms/callbacks/classification_metrics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import warnings
from logging import getLogger as get_logger
from typing import Literal, NotRequired, Required, TypedDict, override
from typing import Literal, TypedDict

import torch
import torchmetrics
from lightning import LightningModule, Trainer
from torch import Tensor
from torchmetrics.classification import MulticlassAccuracy
from typing_extensions import NotRequired, Required, override

from project.algorithms.callbacks.callback import BatchType, Callback
from project.utils.types.protocols import ClassificationDataModule
Expand Down
3 changes: 2 additions & 1 deletion project/algorithms/callbacks/samples_per_second.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import time
from typing import Literal, override
from typing import Literal

from lightning import LightningModule, Trainer
from torch import Tensor
from torch.optim import Optimizer
from typing_extensions import override

from project.algorithms.callbacks.callback import BatchType, Callback, StepOutputType
from project.utils.types import is_sequence_of
Expand Down
14 changes: 11 additions & 3 deletions project/algorithms/jax_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from collections.abc import Callable
from typing import Concatenate, Literal
from typing import Concatenate, Literal, ParamSpec, TypeVar

import flax.linen
import jax
Expand Down Expand Up @@ -203,14 +203,22 @@ def to_channels_last(x: jax.Array) -> jax.Array:
return x.transpose(0, 2, 3, 1)


def jit[**P, Out](
P = ParamSpec("P")
Out = TypeVar("Out")


def jit(
fn: Callable[P, Out],
) -> Callable[P, Out]:
"""Small type hint fix for jax's `jit` (preserves the signature of the callable)."""
return jax.jit(fn) # type: ignore


def value_and_grad[In, **P, Out, Aux](
In = TypeVar("In")
Aux = TypeVar("Aux")


def value_and_grad(
fn: Callable[Concatenate[In, P], tuple[Out, Aux]],
argnums: Literal[0] = 0,
has_aux: Literal[True] = True,
Expand Down
4 changes: 2 additions & 2 deletions project/algorithms/testsuites/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import NotRequired, Protocol, TypedDict
from typing import Protocol, TypedDict

import torch
from lightning import LightningDataModule, LightningModule, Trainer
from torch import Tensor
from typing_extensions import TypeVar
from typing_extensions import NotRequired, TypeVar

from project.utils.types import PyTree
from project.utils.types.protocols import DataModule, Module
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import TypeVar

from torch import Tensor
from torchvision.tv_tensors import Image

Expand All @@ -10,9 +12,10 @@
# todo: need to decide whether this should be a base class or just a protocol.
# - IF this is a protocol, then we can't use issubclass with it, so it can't be used in the
# `supported_datamodule_types` field on AlgorithmTests subclasses (for example `ClassificationAlgorithmTests`).
BatchType = TypeVar("BatchType", bound=tuple[Image, Tensor])


class ImageClassificationDataModule[BatchType: tuple[Image, Tensor]](
class ImageClassificationDataModule(
VisionDataModule[BatchType], ClassificationDataModule[BatchType]
):
"""Lightning data modules for image classification."""
Expand Down
6 changes: 3 additions & 3 deletions project/datamodules/image_classification/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Callable
from logging import getLogger as get_logger
from pathlib import Path
from typing import ClassVar, Literal
from typing import ClassVar, Literal, NewType

import rich
import rich.logging
Expand All @@ -33,8 +33,8 @@ def imagenet_normalization():
return transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


type ClassIndex = int
type ImageIndex = int
ClassIndex = NewType("ClassIndex", int)
ImageIndex = NewType("ImageIndex", int)


class ImageNetDataModule(VisionDataModule):
Expand Down
15 changes: 9 additions & 6 deletions project/datamodules/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable
from logging import getLogger as get_logger
from pathlib import Path
from typing import ClassVar, Concatenate, Literal
from typing import ClassVar, Concatenate, Literal, ParamSpec, TypeVar

import torch
from lightning import LightningDataModule
Expand All @@ -22,8 +22,11 @@

logger = get_logger(__name__)

BatchType_co = TypeVar("BatchType_co", covariant=True)
P = ParamSpec("P")

class VisionDataModule[BatchType_co](LightningDataModule, DataModule[BatchType_co]):

class VisionDataModule(LightningDataModule, DataModule[BatchType_co]):
"""A LightningDataModule for image datasets.
(Taken from pl_bolts which is not very well maintained.)
Expand Down Expand Up @@ -199,7 +202,7 @@ def _get_splits(self, len_dataset: int) -> list[int]:
def default_transforms(self) -> Callable:
"""Default transform for the dataset."""

def train_dataloader[**P](
def train_dataloader(
self,
_dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
*args: P.args,
Expand All @@ -217,7 +220,7 @@ def train_dataloader[**P](
**kwargs,
)

def val_dataloader[**P](
def val_dataloader(
self,
_dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
*args: P.args,
Expand All @@ -234,7 +237,7 @@ def val_dataloader[**P](
**kwargs,
)

def test_dataloader[**P](
def test_dataloader(
self,
_dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
*args: P.args,
Expand All @@ -251,7 +254,7 @@ def test_dataloader[**P](
**kwargs,
)

def _data_loader[**P](
def _data_loader(
self,
dataset: Dataset,
_dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
Expand Down
5 changes: 4 additions & 1 deletion project/utils/hydra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,10 @@ def _default_factory(
return default # type: ignore


def make_config_and_store[Target](
Target = TypeVar("Target")


def make_config_and_store(
target: Callable[..., Target], *, store: hydra_zen.ZenStore, **overrides
):
"""Creates a config dataclass for the given target and stores it in the config store.
Expand Down
7 changes: 5 additions & 2 deletions project/utils/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import Mapping, Sequence
from contextlib import contextmanager
from logging import getLogger as get_logger
from typing import Any
from typing import Any, Generic, TypeVar

import lightning
import numpy as np
Expand Down Expand Up @@ -111,7 +111,10 @@ def _parametrized_fixture(request: pytest.FixtureRequest):
return _parametrized_fixture


class ParametrizedFixture[T]:
T = TypeVar("T")


class ParametrizedFixture(Generic[T]):
"""Small helper function that creates a parametrized pytest fixture for the given values.
The name of the fixture will be the name that is used for this variable on a class.
Expand Down
20 changes: 10 additions & 10 deletions project/utils/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, NewType, TypeGuard, Unpack
from typing import Annotated, Any, NewType, TypeAlias, TypeGuard

import annotated_types
from torch import Tensor
from typing_extensions import TypeVar, TypeVarTuple
from typing_extensions import TypeVar, TypeVarTuple, Unpack

from .protocols import Dataclass, DataModule, Module

Expand All @@ -22,28 +22,28 @@
OutT = TypeVar("OutT", default=Tensor, covariant=True)
Ts = TypeVarTuple("Ts", default=Unpack[tuple[Tensor, ...]])
T = TypeVar("T", default=Tensor)
K = TypeVar("K")
V = TypeVar("V")

type NestedDict[K, V] = dict[K, V | NestedDict[K, V]]
type NestedMapping[K, V] = Mapping[K, V | NestedMapping[K, V]]
type PyTree[T] = T | Iterable[PyTree[T]] | Mapping[Any, PyTree[T]]
NestedDict: TypeAlias = dict[K, V | "NestedDict[K, V]"]
NestedMapping = Mapping[K, V | "NestedMapping[K, V]"]
PyTree = T | Iterable["PyTree[T]"] | Mapping[Any, "PyTree[T]"]


def is_list_of[V](object: Any, item_type: type[V] | tuple[type[V], ...]) -> TypeGuard[list[V]]:
def is_list_of(object: Any, item_type: type[V] | tuple[type[V], ...]) -> TypeGuard[list[V]]:
"""Used to check (and tell the type checker) that `object` is a list of items of this type."""
return isinstance(object, list) and is_sequence_of(object, item_type)


def is_sequence_of[V](
def is_sequence_of(
object: Any, item_type: type[V] | tuple[type[V], ...]
) -> TypeGuard[Sequence[V]]:
"""Used to check (and tell the type checker) that `object` is a sequence of items of this
type."""
return isinstance(object, Sequence) and all(isinstance(value, item_type) for value in object)


def is_mapping_of[K, V](
object: Any, key_type: type[K], value_type: type[V]
) -> TypeGuard[Mapping[K, V]]:
def is_mapping_of(object: Any, key_type: type[K], value_type: type[V]) -> TypeGuard[Mapping[K, V]]:
"""Used to check (and tell the type checker) that `object` is a mapping with keys and values of
the given types."""
return isinstance(object, Mapping) and all(
Expand Down
Loading

0 comments on commit 1241d58

Please sign in to comment.