Skip to content

Commit

Permalink
Mypy 1.7.1 (#905)
Browse files Browse the repository at this point in the history
* update mypy

* make sure functions have type annotations

Otherwise mypy ignores them

* explicit optionals

* more explicit optionals

* make sure functions have type annotations

Otherwise MyPy ignores them

* use mypy version supported by VSCode extension

* fix variable already defined error

* fix type annotation

* add missing type annotation

* white space

* fix errors with `engine` in `sql_shell`

* fix errors in `shell/commands/run.py`

* fix type annotation in `serialisation.py`

* fix return type in `MigrationModule`

* fix error with default not being defined

* fix typo

* fix type annotation

* fix optional value

* fix warnings in `generate.py`

* fix errors in `sql_shell/commands/run.py`

* fix type annotations in `migration_manager.py`

* fix problems with `main.py`

* fix remaining mypy errors in `base.py`

* reorder `staticmethod` and `abstractmethod`

* fix pydantic errors in latest pydantic version

* fix warnings in tests about method signatures without types

* use mypy==1.7.1
  • Loading branch information
dantownsend authored Nov 30, 2023
1 parent 5f972d0 commit b8026da
Show file tree
Hide file tree
Showing 35 changed files with 235 additions and 194 deletions.
3 changes: 1 addition & 2 deletions piccolo/apps/migrations/auto/diffable_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def compare_dicts(
output = {}

for key, value in dict_1.items():

dict_2_value = dict_2.get(key, ...)
if (
dict_2_value is not ...
Expand Down Expand Up @@ -99,7 +98,7 @@ class DiffableTable:
columns: t.List[Column] = field(default_factory=list)
previous_class_name: t.Optional[str] = None

def __post_init__(self):
def __post_init__(self) -> None:
self.columns_map: t.Dict[str, Column] = {
i._meta.name: i for i in self.columns
}
Expand Down
61 changes: 36 additions & 25 deletions piccolo/apps/migrations/auto/migration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from piccolo.engine import engine_finder
from piccolo.query import Query
from piccolo.query.base import DDL
from piccolo.schema import SchemaDDLBase
from piccolo.table import Table, create_table_class, sort_table_classes
from piccolo.utils.warnings import colored_warning

Expand Down Expand Up @@ -123,6 +124,9 @@ def table_class_names(self) -> t.List[str]:
return list({i.table_class_name for i in self.alter_columns})


AsyncFunction = t.Callable[[], t.Coroutine]


@dataclass
class MigrationManager:
"""
Expand Down Expand Up @@ -152,8 +156,10 @@ class MigrationManager:
alter_columns: AlterColumnCollection = field(
default_factory=AlterColumnCollection
)
raw: t.List[t.Union[t.Callable, t.Coroutine]] = field(default_factory=list)
raw_backwards: t.List[t.Union[t.Callable, t.Coroutine]] = field(
raw: t.List[t.Union[t.Callable, AsyncFunction]] = field(
default_factory=list
)
raw_backwards: t.List[t.Union[t.Callable, AsyncFunction]] = field(
default_factory=list
)

Expand Down Expand Up @@ -227,7 +233,7 @@ def add_column(
db_column_name: t.Optional[str] = None,
column_class_name: str = "",
column_class: t.Optional[t.Type[Column]] = None,
params: t.Dict[str, t.Any] = None,
params: t.Optional[t.Dict[str, t.Any]] = None,
schema: t.Optional[str] = None,
):
"""
Expand Down Expand Up @@ -309,8 +315,8 @@ def alter_column(
tablename: str,
column_name: str,
db_column_name: t.Optional[str] = None,
params: t.Dict[str, t.Any] = None,
old_params: t.Dict[str, t.Any] = None,
params: t.Optional[t.Dict[str, t.Any]] = None,
old_params: t.Optional[t.Dict[str, t.Any]] = None,
column_class: t.Optional[t.Type[Column]] = None,
old_column_class: t.Optional[t.Type[Column]] = None,
schema: t.Optional[str] = None,
Expand All @@ -336,14 +342,14 @@ def alter_column(
)
)

def add_raw(self, raw: t.Union[t.Callable, t.Coroutine]):
def add_raw(self, raw: t.Union[t.Callable, AsyncFunction]):
"""
A migration manager can execute arbitrary functions or coroutines when
run. This is useful if you want to execute raw SQL.
"""
self.raw.append(raw)

def add_raw_backwards(self, raw: t.Union[t.Callable, t.Coroutine]):
def add_raw_backwards(self, raw: t.Union[t.Callable, AsyncFunction]):
"""
When reversing a migration, you may want to run extra code to help
clean up.
Expand Down Expand Up @@ -387,13 +393,13 @@ async def get_table_from_snapshot(
###########################################################################

@staticmethod
async def _print_query(query: t.Union[DDL, Query]):
async def _print_query(query: t.Union[DDL, Query, SchemaDDLBase]):
if isinstance(query, DDL):
print("\n", ";".join(query.ddl) + ";")
else:
print(str(query))

async def _run_query(self, query: t.Union[DDL, Query]):
async def _run_query(self, query: t.Union[DDL, Query, SchemaDDLBase]):
"""
If MigrationManager is not in the preview mode,
executes the queries. else, prints the query.
Expand All @@ -403,7 +409,7 @@ async def _run_query(self, query: t.Union[DDL, Query]):
else:
await query.run()

async def _run_alter_columns(self, backwards=False):
async def _run_alter_columns(self, backwards: bool = False):
for table_class_name in self.alter_columns.table_class_names:
alter_columns = self.alter_columns.for_table_class_name(
table_class_name
Expand All @@ -421,7 +427,6 @@ async def _run_alter_columns(self, backwards=False):
)

for alter_column in alter_columns:

params = (
alter_column.old_params
if backwards
Expand Down Expand Up @@ -622,7 +627,7 @@ async def _run_drop_tables(self, backwards=False):
diffable_table.to_table_class().alter().drop_table()
)

async def _run_drop_columns(self, backwards=False):
async def _run_drop_columns(self, backwards: bool = False):
if backwards:
for drop_column in self.drop_columns.drop_columns:
_Table = await self.get_table_from_snapshot(
Expand All @@ -647,7 +652,7 @@ async def _run_drop_columns(self, backwards=False):
if not columns:
continue

_Table: t.Type[Table] = create_table_class(
_Table = create_table_class(
class_name=table_class_name,
class_kwargs={
"tablename": columns[0].tablename,
Expand All @@ -660,7 +665,7 @@ async def _run_drop_columns(self, backwards=False):
_Table.alter().drop_column(column=column.column_name)
)

async def _run_rename_tables(self, backwards=False):
async def _run_rename_tables(self, backwards: bool = False):
for rename_table in self.rename_tables:
class_name = (
rename_table.new_class_name
Expand Down Expand Up @@ -690,7 +695,7 @@ async def _run_rename_tables(self, backwards=False):
_Table.alter().rename_table(new_name=new_tablename)
)

async def _run_rename_columns(self, backwards=False):
async def _run_rename_columns(self, backwards: bool = False):
for table_class_name in self.rename_columns.table_class_names:
columns = self.rename_columns.for_table_class_name(
table_class_name
Expand Down Expand Up @@ -726,7 +731,7 @@ async def _run_rename_columns(self, backwards=False):
)
)

async def _run_add_tables(self, backwards=False):
async def _run_add_tables(self, backwards: bool = False):
table_classes: t.List[t.Type[Table]] = []
for add_table in self.add_tables:
add_columns: t.List[
Expand Down Expand Up @@ -755,7 +760,7 @@ async def _run_add_tables(self, backwards=False):
for _Table in sorted_table_classes:
await self._run_query(_Table.create_table())

async def _run_add_columns(self, backwards=False):
async def _run_add_columns(self, backwards: bool = False):
"""
Add columns, which belong to existing tables
"""
Expand All @@ -768,7 +773,7 @@ async def _run_add_columns(self, backwards=False):
# be deleted.
continue

_Table: t.Type[Table] = create_table_class(
_Table = create_table_class(
class_name=add_column.table_class_name,
class_kwargs={
"tablename": add_column.tablename,
Expand All @@ -790,7 +795,7 @@ async def _run_add_columns(self, backwards=False):

# Define the table, with the columns, so the metaclass
# sets up the columns correctly.
_Table: t.Type[Table] = create_table_class(
_Table = create_table_class(
class_name=add_columns[0].table_class_name,
class_kwargs={
"tablename": add_columns[0].tablename,
Expand Down Expand Up @@ -818,7 +823,7 @@ async def _run_add_columns(self, backwards=False):
_Table.create_index([add_column.column])
)

async def _run_change_table_schema(self, backwards=False):
async def _run_change_table_schema(self, backwards: bool = False):
from piccolo.schema import SchemaManager

schema_manager = SchemaManager()
Expand All @@ -827,15 +832,19 @@ async def _run_change_table_schema(self, backwards=False):
if backwards:
# Note, we don't try dropping any schemas we may have created.
# It's dangerous to do so, just in case the user manually
# added tables etc to the scheme, and we delete them.
# added tables etc to the schema, and we delete them.

if change_table_schema.old_schema not in (None, "public"):
if (
change_table_schema.old_schema
and change_table_schema.old_schema != "public"
):
await self._run_query(
schema_manager.create_schema(
schema_name=change_table_schema.old_schema,
if_not_exists=True,
)
)

await self._run_query(
schema_manager.move_table(
table_name=change_table_schema.tablename,
Expand All @@ -845,7 +854,10 @@ async def _run_change_table_schema(self, backwards=False):
)

else:
if change_table_schema.new_schema not in (None, "public"):
if (
change_table_schema.new_schema
and change_table_schema.new_schema != "public"
):
await self._run_query(
schema_manager.create_schema(
schema_name=change_table_schema.new_schema,
Expand All @@ -861,7 +873,7 @@ async def _run_change_table_schema(self, backwards=False):
)
)

async def run(self, backwards=False):
async def run(self, backwards: bool = False):
direction = "backwards" if backwards else "forwards"
if self.preview:
direction = "preview " + direction
Expand All @@ -873,7 +885,6 @@ async def run(self, backwards=False):
raise Exception("Can't find engine")

async with engine.transaction():

if not self.preview:
if direction == "backwards":
raw_list = self.raw_backwards
Expand Down
12 changes: 6 additions & 6 deletions piccolo/apps/migrations/auto/schema_differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class SchemaDiffer:

###########################################################################

def __post_init__(self):
def __post_init__(self) -> None:
self.schema_snapshot_map: t.Dict[str, DiffableTable] = {
i.class_name: i for i in self.schema_snapshot
}
Expand Down Expand Up @@ -270,7 +270,6 @@ def check_renamed_columns(self) -> RenameColumnCollection:
used_drop_column_names: t.List[str] = []

for add_column in delta.add_columns:

for drop_column in delta.drop_columns:
if drop_column.column_name in used_drop_column_names:
continue
Expand Down Expand Up @@ -455,10 +454,11 @@ def _get_snapshot_table(
class_name = self.rename_tables_collection.renamed_from(
table_class_name
)
snapshot_table = self.schema_snapshot_map.get(class_name)
if snapshot_table:
snapshot_table.class_name = table_class_name
return snapshot_table
if class_name:
snapshot_table = self.schema_snapshot_map.get(class_name)
if snapshot_table:
snapshot_table.class_name = table_class_name
return snapshot_table
return None

@property
Expand Down
4 changes: 1 addition & 3 deletions piccolo/apps/migrations/auto/serialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def __hash__(self):
def __eq__(self, other):
return check_equality(self, other)

def __repr__(self):
def __repr__(self) -> str:
tablename = self.table_type._meta.tablename

# We have to add the primary key column definition too, so foreign
Expand Down Expand Up @@ -493,15 +493,13 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams:
extra_definitions: t.List[Definition] = []

for key, value in params.items():

# Builtins, such as str, list and dict.
if inspect.getmodule(value) == builtins:
params[key] = SerialisedBuiltin(builtin=value)
continue

# Column instances
if isinstance(value, Column):

# For target_column (which is used by ForeignKey), we can just
# serialise it as the column name:
if key == "target_column":
Expand Down
20 changes: 12 additions & 8 deletions piccolo/apps/schema/commands/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class TableConstraints:
tablename: str
constraints: t.List[Constraint]

def __post_init__(self):
def __post_init__(self) -> None:
foreign_key_constraints: t.List[Constraint] = []
unique_constraints: t.List[Constraint] = []
primary_key_constraints: t.List[Constraint] = []
Expand Down Expand Up @@ -127,7 +127,8 @@ def get_foreign_key_constraint_name(self, column_name) -> ConstraintTable:
for i in self.foreign_key_constraints:
if i.column_name == column_name:
return ConstraintTable(
name=i.constraint_name, schema=i.constraint_schema
name=i.constraint_name,
schema=i.constraint_schema or "public",
)

raise ValueError("No matching constraint found")
Expand Down Expand Up @@ -307,7 +308,7 @@ def __add__(self, value: OutputSchema) -> OutputSchema:
}

# Re-map for Cockroach compatibility.
COLUMN_TYPE_MAP_COCKROACH = {
COLUMN_TYPE_MAP_COCKROACH: t.Dict[str, t.Type[Column]] = {
**COLUMN_TYPE_MAP,
**{"integer": BigInt, "json": JSONB},
}
Expand Down Expand Up @@ -374,14 +375,13 @@ def __add__(self, value: OutputSchema) -> OutputSchema:
# Re-map for Cockroach compatibility.
COLUMN_DEFAULT_PARSER_COCKROACH = {
**COLUMN_DEFAULT_PARSER,
**{BigInt: re.compile(r"^(?P<value>-?\d+)$")},
BigInt: re.compile(r"^(?P<value>-?\d+)$"),
}


def get_column_default(
column_type: t.Type[Column], column_default: str, engine_type: str
) -> t.Any:

if engine_type == "cockroach":
pat = COLUMN_DEFAULT_PARSER_COCKROACH.get(column_type)
else:
Expand Down Expand Up @@ -462,6 +462,7 @@ def get_column_default(
"gin": IndexMethod.gin,
}


# 'Indices' seems old-fashioned and obscure in this context.
async def get_indexes( # noqa: E302
table_class: t.Type[Table], tablename: str, schema_name: str = "public"
Expand Down Expand Up @@ -786,9 +787,12 @@ async def create_table_class_from_db(
kwargs["length"] = pg_row_meta.character_maximum_length
elif isinstance(column_type, Numeric):
radix = pg_row_meta.numeric_precision_radix
precision = int(str(pg_row_meta.numeric_precision), radix)
scale = int(str(pg_row_meta.numeric_scale), radix)
kwargs["digits"] = (precision, scale)
if radix:
precision = int(str(pg_row_meta.numeric_precision), radix)
scale = int(str(pg_row_meta.numeric_scale), radix)
kwargs["digits"] = (precision, scale)
else:
kwargs["digits"] = None

if column_default:
default_value = get_column_default(
Expand Down
Loading

0 comments on commit b8026da

Please sign in to comment.