Skip to content

Commit

Permalink
Merge pull request #37 from PeterJCLaw/type-annotations
Browse files Browse the repository at this point in the history
Add mypy & type annotations.
  • Loading branch information
danpalmer authored Jul 14, 2024
2 parents 6bedbb8 + 03164bd commit 51b01c9
Show file tree
Hide file tree
Showing 29 changed files with 563 additions and 198 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,32 @@ jobs:
run: |
poetry run flake8 . --jobs=auto --format=github
type-check:
name: Type-check
runs-on: ubuntu-20.04

steps:
- name: Checkout
uses: actions/checkout@v3

- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.9

- name: Set up Poetry
uses: abatilo/[email protected]
with:
poetry-version: 1.7.1

- name: Install dependencies
run: |
poetry install
- name: Type-check
run: |
poetry run mypy src tests
validate-dependencies:
name: Check dependency locks
runs-on: ubuntu-20.04
Expand Down
60 changes: 59 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ tox-gh-actions = "^2.4.0"
psycopg2-binary = "^2.8.6"
pytest-check = "^1.0.1"

mypy = "^1.8.0"
types-tqdm = "^4.66.0.20240106"

[tool.poetry.group.dev.dependencies]
tox-pyenv = "^1.1.0"

Expand All @@ -48,6 +51,21 @@ use_parentheses = true
ensure_newline_before_comments = true
line_length = 80

[tool.mypy]
warn_unused_configs = true

# Be fairly strict with our types
strict_optional = true
enable_error_code = "ignore-without-code"
disallow_incomplete_defs = true
disallow_any_generics = true
disallow_untyped_decorators = true
disallow_untyped_defs = true

[[tool.mypy.overrides]]
module = "django.*"
ignore_missing_imports = true

[tool.pytest.ini_options]
DJANGO_SETTINGS_MODULE = "testsite.settings"
django_find_project = false
Expand Down
40 changes: 32 additions & 8 deletions src/devdata/anonymisers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import pathlib
import random
from typing import Any, TypeVar

import faker
from django.db import models

from .types import Anonymiser, GenericAnonymiser
from .utils import get_exported_pks_for_model

T = TypeVar("T")


def faker_anonymise(
generator, *args, preserve_nulls=False, unique=False, **kwargs
):
def anonymise(*, pii_value, fake, **_kwargs):
generator: str,
*args: Any,
preserve_nulls: bool = False,
unique: bool = False,
**kwargs: Any,
) -> Anonymiser:
def anonymise(*, pii_value: T, fake: faker.Faker, **_kwargs: object) -> T:
if preserve_nulls and pii_value is None:
return None

Expand All @@ -16,8 +28,15 @@ def anonymise(*, pii_value, fake, **_kwargs):
return anonymise


def preserve_internal(alternative):
def anonymise(obj, field, pii_value, **kwargs):
def preserve_internal(
alternative: GenericAnonymiser[T],
) -> GenericAnonymiser[T]:
def anonymise(
obj: models.Model,
field: str,
pii_value: T,
**kwargs: Any,
) -> T:
if getattr(obj, "is_superuser", False) or getattr(
obj, "is_staff", False
):
Expand All @@ -27,16 +46,21 @@ def anonymise(obj, field, pii_value, **kwargs):
return anonymise


def const(value, preserve_nulls=False):
def anonymise(*_, pii_value, **_kwargs):
def const(value: T, preserve_nulls: bool = False) -> GenericAnonymiser[T]:
def anonymise(*_: object, pii_value: T, **_kwargs: object) -> T:
if preserve_nulls and pii_value is None:
return None
return value

return anonymise


def random_foreign_key(obj, field, dest, **_kwargs):
def random_foreign_key(
obj: models.Model,
field: str,
dest: pathlib.Path,
**_kwargs: object,
) -> Any:
related_model = obj._meta.get_field(field).related_model
exported_pks = get_exported_pks_for_model(dest, related_model)
return random.choice(exported_pks)
31 changes: 22 additions & 9 deletions src/devdata/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import json
from collections.abc import Collection
from pathlib import Path

from django.core.management import call_command
from django.core.management.color import no_style
Expand All @@ -20,14 +24,14 @@
)


def validate_strategies(only=None):
def validate_strategies(only: Collection[str] = ()) -> None:
not_found = []

for model in get_all_models():
if model._meta.abstract:
continue

app_model_label = to_app_model_label(model)
app_model_label = to_app_model_label(model) # type: ignore[arg-type] # mypy can't see that models are hashable

if app_model_label not in settings.strategies:
if only and app_model_label not in only:
Expand All @@ -49,7 +53,7 @@ def validate_strategies(only=None):
)


def export_migration_state(django_dbname, dest):
def export_migration_state(django_dbname: str, dest: Path) -> None:
file_path = migrations_file_path(dest)
file_path.parent.mkdir(parents=True, exist_ok=True)

Expand All @@ -69,7 +73,12 @@ def export_migration_state(django_dbname, dest):
json.dump(migration_state, f, indent=4, cls=DjangoJSONEncoder)


def export_data(django_dbname, dest, only=None, no_update=False):
def export_data(
django_dbname: str,
dest: Path,
only: Collection[str] = (),
no_update: bool = False,
) -> None:
model_strategies = sort_model_strategies(settings.strategies)
bar = progress(model_strategies)
for app_model_label, strategy in bar:
Expand Down Expand Up @@ -100,7 +109,11 @@ def export_data(django_dbname, dest, only=None, no_update=False):
)


def export_extras(django_dbname, dest, no_update=False):
def export_extras(
django_dbname: str,
dest: Path,
no_update: bool = False,
) -> None:
bar = progress(settings.extra_strategies)
for strategy in bar:
bar.set_postfix({"extra": strategy.name})
Expand All @@ -114,7 +127,7 @@ def export_extras(django_dbname, dest, no_update=False):
)


def import_schema(src, django_dbname):
def import_schema(src: Path, django_dbname: str) -> None:
connection = connections[django_dbname]

with disable_migrations():
Expand Down Expand Up @@ -149,7 +162,7 @@ def import_schema(src, django_dbname):
)


def import_data(src, django_dbname):
def import_data(src: Path, django_dbname: str) -> None:
model_strategies = sort_model_strategies(settings.strategies)
bar = progress(model_strategies)
for app_model_label, strategy in bar:
Expand All @@ -160,14 +173,14 @@ def import_data(src, django_dbname):
strategy.import_data(django_dbname, src, model)


def import_extras(src, django_dbname):
def import_extras(src: Path, django_dbname: str) -> None:
bar = progress(settings.extra_strategies)
for strategy in bar:
bar.set_postfix({"extra": strategy.name})
strategy.import_data(django_dbname, src)


def import_cleanup(src, django_dbname):
def import_cleanup(src: Path, django_dbname: str) -> None:
conn = connections[django_dbname]
with conn.cursor() as cursor:
for reset_sql in conn.ops.sequence_reset_sql(
Expand Down
22 changes: 15 additions & 7 deletions src/devdata/extras.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from __future__ import annotations

import json
import textwrap
from pathlib import Path
from typing import Callable, Dict, Set, Tuple
from typing import Any, Callable

from django.db import connections

Logger = Callable[[object], None]
Logger = Callable[[str], None]


class ExtraImport:
"""
Base extra defining how to get data into a fresh database.
"""

depends_on = () # type: Tuple[str, ...]
name: str
depends_on: tuple[str, ...] = ()

def __init__(self) -> None:
pass
Expand All @@ -28,9 +31,9 @@ class ExtraExport:
Base extra defining how to get data out of an existing database.
"""

seen_names = set() # type: Set[Tuple[str, str]]
seen_names: set[str] = set()

def __init__(self, *args, name, **kwargs):
def __init__(self, *args: Any, name: str, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

self.name = name
Expand Down Expand Up @@ -76,7 +79,12 @@ class PostgresSequences(ExtraExport, ExtraImport):
matching primary keys.
"""

def __init__(self, *args, name="postgres-sequences", **kwargs):
def __init__(
self,
*args: Any,
name: str = "postgres-sequences",
**kwargs: Any,
) -> None:
super().__init__(*args, name=name, **kwargs)

def export_data(
Expand Down Expand Up @@ -154,7 +162,7 @@ def import_data(self, django_dbname: str, src: Path) -> None:
with self.data_file(src).open() as f:
sequences = json.load(f)

def check_simple_value(mapping: Dict[str, str], *, key: str) -> str:
def check_simple_value(mapping: dict[str, str], *, key: str) -> str:
value = mapping[key]
if not value.replace("_", "").isalnum():
raise ValueError(f"{key} is not alphanumeric")
Expand Down
Loading

0 comments on commit 51b01c9

Please sign in to comment.