diff --git a/piccolo/apps/migrations/auto/diffable_table.py b/piccolo/apps/migrations/auto/diffable_table.py index 89b312018..aa609f041 100644 --- a/piccolo/apps/migrations/auto/diffable_table.py +++ b/piccolo/apps/migrations/auto/diffable_table.py @@ -39,7 +39,6 @@ def compare_dicts( output = {} for key, value in dict_1.items(): - dict_2_value = dict_2.get(key, ...) if ( dict_2_value is not ... @@ -99,7 +98,7 @@ class DiffableTable: columns: t.List[Column] = field(default_factory=list) previous_class_name: t.Optional[str] = None - def __post_init__(self): + def __post_init__(self) -> None: self.columns_map: t.Dict[str, Column] = { i._meta.name: i for i in self.columns } diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 8c2e9548c..5c18cf89f 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -18,6 +18,7 @@ from piccolo.engine import engine_finder from piccolo.query import Query from piccolo.query.base import DDL +from piccolo.schema import SchemaDDLBase from piccolo.table import Table, create_table_class, sort_table_classes from piccolo.utils.warnings import colored_warning @@ -123,6 +124,9 @@ def table_class_names(self) -> t.List[str]: return list({i.table_class_name for i in self.alter_columns}) +AsyncFunction = t.Callable[[], t.Coroutine] + + @dataclass class MigrationManager: """ @@ -152,8 +156,10 @@ class MigrationManager: alter_columns: AlterColumnCollection = field( default_factory=AlterColumnCollection ) - raw: t.List[t.Union[t.Callable, t.Coroutine]] = field(default_factory=list) - raw_backwards: t.List[t.Union[t.Callable, t.Coroutine]] = field( + raw: t.List[t.Union[t.Callable, AsyncFunction]] = field( + default_factory=list + ) + raw_backwards: t.List[t.Union[t.Callable, AsyncFunction]] = field( default_factory=list ) @@ -227,7 +233,7 @@ def add_column( db_column_name: t.Optional[str] = None, column_class_name: str 
= "", column_class: t.Optional[t.Type[Column]] = None, - params: t.Dict[str, t.Any] = None, + params: t.Optional[t.Dict[str, t.Any]] = None, schema: t.Optional[str] = None, ): """ @@ -309,8 +315,8 @@ def alter_column( tablename: str, column_name: str, db_column_name: t.Optional[str] = None, - params: t.Dict[str, t.Any] = None, - old_params: t.Dict[str, t.Any] = None, + params: t.Optional[t.Dict[str, t.Any]] = None, + old_params: t.Optional[t.Dict[str, t.Any]] = None, column_class: t.Optional[t.Type[Column]] = None, old_column_class: t.Optional[t.Type[Column]] = None, schema: t.Optional[str] = None, @@ -336,14 +342,14 @@ def alter_column( ) ) - def add_raw(self, raw: t.Union[t.Callable, t.Coroutine]): + def add_raw(self, raw: t.Union[t.Callable, AsyncFunction]): """ A migration manager can execute arbitrary functions or coroutines when run. This is useful if you want to execute raw SQL. """ self.raw.append(raw) - def add_raw_backwards(self, raw: t.Union[t.Callable, t.Coroutine]): + def add_raw_backwards(self, raw: t.Union[t.Callable, AsyncFunction]): """ When reversing a migration, you may want to run extra code to help clean up. @@ -387,13 +393,13 @@ async def get_table_from_snapshot( ########################################################################### @staticmethod - async def _print_query(query: t.Union[DDL, Query]): + async def _print_query(query: t.Union[DDL, Query, SchemaDDLBase]): if isinstance(query, DDL): print("\n", ";".join(query.ddl) + ";") else: print(str(query)) - async def _run_query(self, query: t.Union[DDL, Query]): + async def _run_query(self, query: t.Union[DDL, Query, SchemaDDLBase]): """ If MigrationManager is not in the preview mode, executes the queries. else, prints the query. 
@@ -403,7 +409,7 @@ async def _run_query(self, query: t.Union[DDL, Query]): else: await query.run() - async def _run_alter_columns(self, backwards=False): + async def _run_alter_columns(self, backwards: bool = False): for table_class_name in self.alter_columns.table_class_names: alter_columns = self.alter_columns.for_table_class_name( table_class_name @@ -421,7 +427,6 @@ async def _run_alter_columns(self, backwards=False): ) for alter_column in alter_columns: - params = ( alter_column.old_params if backwards @@ -622,7 +627,7 @@ async def _run_drop_tables(self, backwards=False): diffable_table.to_table_class().alter().drop_table() ) - async def _run_drop_columns(self, backwards=False): + async def _run_drop_columns(self, backwards: bool = False): if backwards: for drop_column in self.drop_columns.drop_columns: _Table = await self.get_table_from_snapshot( @@ -647,7 +652,7 @@ async def _run_drop_columns(self, backwards=False): if not columns: continue - _Table: t.Type[Table] = create_table_class( + _Table = create_table_class( class_name=table_class_name, class_kwargs={ "tablename": columns[0].tablename, @@ -660,7 +665,7 @@ async def _run_drop_columns(self, backwards=False): _Table.alter().drop_column(column=column.column_name) ) - async def _run_rename_tables(self, backwards=False): + async def _run_rename_tables(self, backwards: bool = False): for rename_table in self.rename_tables: class_name = ( rename_table.new_class_name @@ -690,7 +695,7 @@ async def _run_rename_tables(self, backwards=False): _Table.alter().rename_table(new_name=new_tablename) ) - async def _run_rename_columns(self, backwards=False): + async def _run_rename_columns(self, backwards: bool = False): for table_class_name in self.rename_columns.table_class_names: columns = self.rename_columns.for_table_class_name( table_class_name @@ -726,7 +731,7 @@ async def _run_rename_columns(self, backwards=False): ) ) - async def _run_add_tables(self, backwards=False): + async def _run_add_tables(self, 
backwards: bool = False): table_classes: t.List[t.Type[Table]] = [] for add_table in self.add_tables: add_columns: t.List[ @@ -755,7 +760,7 @@ async def _run_add_tables(self, backwards=False): for _Table in sorted_table_classes: await self._run_query(_Table.create_table()) - async def _run_add_columns(self, backwards=False): + async def _run_add_columns(self, backwards: bool = False): """ Add columns, which belong to existing tables """ @@ -768,7 +773,7 @@ async def _run_add_columns(self, backwards=False): # be deleted. continue - _Table: t.Type[Table] = create_table_class( + _Table = create_table_class( class_name=add_column.table_class_name, class_kwargs={ "tablename": add_column.tablename, @@ -790,7 +795,7 @@ async def _run_add_columns(self, backwards=False): # Define the table, with the columns, so the metaclass # sets up the columns correctly. - _Table: t.Type[Table] = create_table_class( + _Table = create_table_class( class_name=add_columns[0].table_class_name, class_kwargs={ "tablename": add_columns[0].tablename, @@ -818,7 +823,7 @@ async def _run_add_columns(self, backwards=False): _Table.create_index([add_column.column]) ) - async def _run_change_table_schema(self, backwards=False): + async def _run_change_table_schema(self, backwards: bool = False): from piccolo.schema import SchemaManager schema_manager = SchemaManager() @@ -827,15 +832,19 @@ async def _run_change_table_schema(self, backwards=False): if backwards: # Note, we don't try dropping any schemas we may have created. # It's dangerous to do so, just in case the user manually - # added tables etc to the scheme, and we delete them. + # added tables etc to the schema, and we delete them. 
- if change_table_schema.old_schema not in (None, "public"): + if ( + change_table_schema.old_schema + and change_table_schema.old_schema != "public" + ): await self._run_query( schema_manager.create_schema( schema_name=change_table_schema.old_schema, if_not_exists=True, ) ) + await self._run_query( schema_manager.move_table( table_name=change_table_schema.tablename, @@ -845,7 +854,10 @@ async def _run_change_table_schema(self, backwards=False): ) else: - if change_table_schema.new_schema not in (None, "public"): + if ( + change_table_schema.new_schema + and change_table_schema.new_schema != "public" + ): await self._run_query( schema_manager.create_schema( schema_name=change_table_schema.new_schema, @@ -861,7 +873,7 @@ async def _run_change_table_schema(self, backwards=False): ) ) - async def run(self, backwards=False): + async def run(self, backwards: bool = False): direction = "backwards" if backwards else "forwards" if self.preview: direction = "preview " + direction @@ -873,7 +885,6 @@ async def run(self, backwards=False): raise Exception("Can't find engine") async with engine.transaction(): - if not self.preview: if direction == "backwards": raw_list = self.raw_backwards diff --git a/piccolo/apps/migrations/auto/schema_differ.py b/piccolo/apps/migrations/auto/schema_differ.py index 2ee4bd3e1..1d095b938 100644 --- a/piccolo/apps/migrations/auto/schema_differ.py +++ b/piccolo/apps/migrations/auto/schema_differ.py @@ -123,7 +123,7 @@ class SchemaDiffer: ########################################################################### - def __post_init__(self): + def __post_init__(self) -> None: self.schema_snapshot_map: t.Dict[str, DiffableTable] = { i.class_name: i for i in self.schema_snapshot } @@ -270,7 +270,6 @@ def check_renamed_columns(self) -> RenameColumnCollection: used_drop_column_names: t.List[str] = [] for add_column in delta.add_columns: - for drop_column in delta.drop_columns: if drop_column.column_name in used_drop_column_names: continue @@ -455,10 
+454,11 @@ def _get_snapshot_table( class_name = self.rename_tables_collection.renamed_from( table_class_name ) - snapshot_table = self.schema_snapshot_map.get(class_name) - if snapshot_table: - snapshot_table.class_name = table_class_name - return snapshot_table + if class_name: + snapshot_table = self.schema_snapshot_map.get(class_name) + if snapshot_table: + snapshot_table.class_name = table_class_name + return snapshot_table return None @property diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index 320cf74f7..d1fd5ee47 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -347,7 +347,7 @@ def __hash__(self): def __eq__(self, other): return check_equality(self, other) - def __repr__(self): + def __repr__(self) -> str: tablename = self.table_type._meta.tablename # We have to add the primary key column definition too, so foreign @@ -493,7 +493,6 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: extra_definitions: t.List[Definition] = [] for key, value in params.items(): - # Builtins, such as str, list and dict. 
if inspect.getmodule(value) == builtins: params[key] = SerialisedBuiltin(builtin=value) @@ -501,7 +500,6 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: # Column instances if isinstance(value, Column): - # For target_column (which is used by ForeignKey), we can just # serialise it as the column name: if key == "target_column": diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py index 2f2fc9596..20aea360d 100644 --- a/piccolo/apps/schema/commands/generate.py +++ b/piccolo/apps/schema/commands/generate.py @@ -91,7 +91,7 @@ class TableConstraints: tablename: str constraints: t.List[Constraint] - def __post_init__(self): + def __post_init__(self) -> None: foreign_key_constraints: t.List[Constraint] = [] unique_constraints: t.List[Constraint] = [] primary_key_constraints: t.List[Constraint] = [] @@ -127,7 +127,8 @@ def get_foreign_key_constraint_name(self, column_name) -> ConstraintTable: for i in self.foreign_key_constraints: if i.column_name == column_name: return ConstraintTable( - name=i.constraint_name, schema=i.constraint_schema + name=i.constraint_name, + schema=i.constraint_schema or "public", ) raise ValueError("No matching constraint found") @@ -307,7 +308,7 @@ def __add__(self, value: OutputSchema) -> OutputSchema: } # Re-map for Cockroach compatibility. -COLUMN_TYPE_MAP_COCKROACH = { +COLUMN_TYPE_MAP_COCKROACH: t.Dict[str, t.Type[Column]] = { **COLUMN_TYPE_MAP, **{"integer": BigInt, "json": JSONB}, } @@ -374,14 +375,13 @@ def __add__(self, value: OutputSchema) -> OutputSchema: # Re-map for Cockroach compatibility. 
COLUMN_DEFAULT_PARSER_COCKROACH = { **COLUMN_DEFAULT_PARSER, - **{BigInt: re.compile(r"^(?P<value>-?\d+)$")}, + BigInt: re.compile(r"^(?P<value>-?\d+)$"), } def get_column_default( column_type: t.Type[Column], column_default: str, engine_type: str ) -> t.Any: - if engine_type == "cockroach": pat = COLUMN_DEFAULT_PARSER_COCKROACH.get(column_type) else: @@ -462,6 +462,7 @@ def get_column_default( "gin": IndexMethod.gin, } + # 'Indices' seems old-fashioned and obscure in this context. async def get_indexes( # noqa: E302 table_class: t.Type[Table], tablename: str, schema_name: str = "public" @@ -786,9 +787,12 @@ async def create_table_class_from_db( kwargs["length"] = pg_row_meta.character_maximum_length elif isinstance(column_type, Numeric): radix = pg_row_meta.numeric_precision_radix - precision = int(str(pg_row_meta.numeric_precision), radix) - scale = int(str(pg_row_meta.numeric_scale), radix) - kwargs["digits"] = (precision, scale) + if radix: + precision = int(str(pg_row_meta.numeric_precision), radix) + scale = int(str(pg_row_meta.numeric_scale), radix) + kwargs["digits"] = (precision, scale) + else: + kwargs["digits"] = None if column_default: default_value = get_column_default( diff --git a/piccolo/apps/shell/commands/run.py b/piccolo/apps/shell/commands/run.py index 566c36c29..38cd1af66 100644 --- a/piccolo/apps/shell/commands/run.py +++ b/piccolo/apps/shell/commands/run.py @@ -1,7 +1,7 @@ import sys import typing as t -from piccolo.conf.apps import AppConfig, AppRegistry, Finder +from piccolo.conf.apps import Finder from piccolo.table import Table try: @@ -13,9 +13,7 @@ IPYTHON = False -def start_ipython_shell( - **tables: t.Dict[str, t.Type[Table]] -): # pragma: no cover +def start_ipython_shell(**tables: t.Type[Table]): # pragma: no cover if not IPYTHON: sys.exit( "Install iPython using `pip install ipython` to use this feature." 
@@ -29,12 +27,12 @@ def start_ipython_shell( IPython.embed(using=_asyncio_runner, colors="neutral") -def run(): +def run() -> None: """ Runs an iPython shell, and automatically imports all of the Table classes from your project. """ - app_registry: AppRegistry = Finder().get_app_registry() + app_registry = Finder().get_app_registry() tables = {} if app_registry.app_configs: @@ -43,7 +41,6 @@ def run(): print(spacer) for app_name, app_config in app_registry.app_configs.items(): - app_config: AppConfig = app_config print(f"Importing {app_name} tables:") if app_config.table_classes: for table_class in sorted( diff --git a/piccolo/apps/sql_shell/commands/run.py b/piccolo/apps/sql_shell/commands/run.py index dd3c09d17..7de03dfcd 100644 --- a/piccolo/apps/sql_shell/commands/run.py +++ b/piccolo/apps/sql_shell/commands/run.py @@ -1,22 +1,20 @@ import os import signal import subprocess +import sys import typing as t from piccolo.engine.finder import engine_finder from piccolo.engine.postgres import PostgresEngine from piccolo.engine.sqlite import SQLiteEngine -if t.TYPE_CHECKING: # pragma: no cover - from piccolo.engine.base import Engine - -def run(): +def run() -> None: """ Launch the SQL shell for the configured engine. For Postgres this will be psql, and for SQLite it will be sqlite3. 
""" - engine: t.Optional[Engine] = engine_finder() + engine = engine_finder() if engine is None: raise ValueError( @@ -26,7 +24,7 @@ def run(): # Heavily inspired by Django's dbshell command if isinstance(engine, PostgresEngine): - engine: PostgresEngine = engine + engine = t.cast(PostgresEngine, engine) args = ["psql"] @@ -42,7 +40,8 @@ def run(): args += ["-h", host] if port: args += ["-p", str(port)] - args += [database] + if database: + args += [database] sigint_handler = signal.getsignal(signal.SIGINT) subprocess_env = os.environ.copy() @@ -58,8 +57,11 @@ def run(): signal.signal(signal.SIGINT, sigint_handler) elif isinstance(engine, SQLiteEngine): - engine: SQLiteEngine = engine + engine = t.cast(SQLiteEngine, engine) + + database = t.cast(str, engine.connection_kwargs.get("database")) + if not database: + sys.exit("Unable to determine which database to connect to.") + print("Enter .quit to exit") - subprocess.run( - ["sqlite3", engine.connection_kwargs.get("database")], check=True - ) + subprocess.run(["sqlite3", database], check=True) diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 8c5b84064..eb4dbb693 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -469,6 +469,7 @@ class Band(Table): """ value_type: t.Type = int + default: t.Any def __init__( self, diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 9934eebed..4b316eba4 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1959,7 +1959,7 @@ def copy(self) -> ForeignKey: return column def all_columns( - self, exclude: t.List[t.Union[Column, str]] = None + self, exclude: t.Optional[t.List[t.Union[Column, str]]] = None ) -> t.List[Column]: """ Allow a user to access all of the columns on the related table. 
This is @@ -2010,7 +2010,7 @@ def all_columns( ] def all_related( - self, exclude: t.List[t.Union[ForeignKey, str]] = None + self, exclude: t.Optional[t.List[t.Union[ForeignKey, str]]] = None ) -> t.List[ForeignKey]: """ Returns each ``ForeignKey`` column on the related table. This is @@ -2065,7 +2065,7 @@ class Tour(Table): if fk_column._meta.name not in excluded_column_names ] - def set_proxy_columns(self): + def set_proxy_columns(self) -> None: """ In order to allow a fluent interface, where tables can be traversed using ForeignKeys (e.g. ``Band.manager.name``), we add attributes to diff --git a/piccolo/columns/m2m.py b/piccolo/columns/m2m.py index 69a647244..0eefd22e7 100644 --- a/piccolo/columns/m2m.py +++ b/piccolo/columns/m2m.py @@ -252,7 +252,7 @@ class M2MAddRelated: rows: t.Sequence[Table] extra_column_values: t.Dict[t.Union[Column, str], t.Any] - def __post_init__(self): + def __post_init__(self) -> None: # Normalise `extra_column_values`, so we just have the column names. self.extra_column_values: t.Dict[str, t.Any] = { i._meta.name if isinstance(i, Column) else i: j diff --git a/piccolo/conf/apps.py b/piccolo/conf/apps.py index b88d50ba2..df8ddad10 100644 --- a/piccolo/conf/apps.py +++ b/piccolo/conf/apps.py @@ -6,10 +6,12 @@ import pathlib import traceback import typing as t +from abc import abstractmethod from dataclasses import dataclass, field from importlib import import_module from types import ModuleType +from piccolo.apps.migrations.auto.migration_manager import MigrationManager from piccolo.engine.base import Engine from piccolo.table import Table from piccolo.utils.graphlib import TopologicalSorter @@ -22,8 +24,9 @@ class MigrationModule(ModuleType): DESCRIPTION: str @staticmethod - async def forwards() -> None: - pass + @abstractmethod + async def forwards() -> MigrationManager: + ... 
class PiccoloAppModule(ModuleType): @@ -32,8 +35,8 @@ class PiccoloAppModule(ModuleType): def table_finder( modules: t.Sequence[str], - include_tags: t.Sequence[str] = None, - exclude_tags: t.Sequence[str] = None, + include_tags: t.Optional[t.Sequence[str]] = None, + exclude_tags: t.Optional[t.Sequence[str]] = None, exclude_imported: bool = False, ) -> t.List[t.Type[Table]]: """ @@ -151,11 +154,7 @@ class AppConfig: default_factory=list ) - def __post_init__(self): - self.commands = [ - i if isinstance(i, Command) else Command(i) for i in self.commands - ] - + def __post_init__(self) -> None: if isinstance(self.migrations_folder_path, pathlib.Path): self.migrations_folder_path = str(self.migrations_folder_path) @@ -167,6 +166,11 @@ def register_table(self, table_class: t.Type[Table]): self.table_classes.append(table_class) return table_class + def get_commands(self) -> t.List[Command]: + return [ + i if isinstance(i, Command) else Command(i) for i in self.commands + ] + @property def migration_dependency_app_configs(self) -> t.List[AppConfig]: """ @@ -176,7 +180,6 @@ def migration_dependency_app_configs(self) -> t.List[AppConfig]: # We cache the value so it's more efficient, and also so we can set the # underlying value in unit tests for easier mocking. 
if self._migration_dependency_app_configs is None: - modules: t.List[PiccoloAppModule] = [ t.cast(PiccoloAppModule, import_module(module_path)) for module_path in self.migration_dependencies @@ -214,7 +217,7 @@ class AppRegistry: """ - def __init__(self, apps: t.List[str] = None): + def __init__(self, apps: t.Optional[t.List[str]] = None): self.apps = apps or [] self.app_configs: t.Dict[str, AppConfig] = {} app_names = [] diff --git a/piccolo/engine/cockroach.py b/piccolo/engine/cockroach.py index 6c5019531..ecbb74ad8 100644 --- a/piccolo/engine/cockroach.py +++ b/piccolo/engine/cockroach.py @@ -25,7 +25,7 @@ def __init__( extensions: t.Sequence[str] = (), log_queries: bool = False, log_responses: bool = False, - extra_nodes: t.Dict[str, CockroachEngine] = None, + extra_nodes: t.Optional[t.Dict[str, CockroachEngine]] = None, ) -> None: super().__init__( config=config, diff --git a/piccolo/engine/postgres.py b/piccolo/engine/postgres.py index b5c179703..06b8ffb4b 100644 --- a/piccolo/engine/postgres.py +++ b/piccolo/engine/postgres.py @@ -22,7 +22,6 @@ @dataclass class AsyncBatch(Batch): - connection: Connection query: Query batch_size: int @@ -93,9 +92,9 @@ class Atomic: def __init__(self, engine: PostgresEngine): self.engine = engine - self.queries: t.List[Query] = [] + self.queries: t.List[t.Union[Query, DDL]] = [] - def add(self, *query: Query): + def add(self, *query: t.Union[Query, DDL]): self.queries += list(query) async def run(self): @@ -348,7 +347,7 @@ def __init__( extensions: t.Sequence[str] = ("uuid-ossp",), log_queries: bool = False, log_responses: bool = False, - extra_nodes: t.Mapping[str, PostgresEngine] = None, + extra_nodes: t.Optional[t.Mapping[str, PostgresEngine]] = None, ) -> None: if extra_nodes is None: extra_nodes = {} @@ -489,7 +488,9 @@ async def batch( ########################################################################### - async def _run_in_pool(self, query: str, args: t.Sequence[t.Any] = None): + async def _run_in_pool( + self, 
query: str, args: t.Optional[t.Sequence[t.Any]] = None + ): if args is None: args = [] if not self.pool: @@ -501,7 +502,7 @@ async def _run_in_pool(self, query: str, args: t.Sequence[t.Any] = None): return response async def _run_in_new_connection( - self, query: str, args: t.Sequence[t.Any] = None + self, query: str, args: t.Optional[t.Sequence[t.Any]] = None ): if args is None: args = [] diff --git a/piccolo/engine/sqlite.py b/piccolo/engine/sqlite.py index 7d0b3eae2..ffc6606b0 100644 --- a/piccolo/engine/sqlite.py +++ b/piccolo/engine/sqlite.py @@ -196,7 +196,6 @@ def convert_M2M_out(value: bytes) -> t.List: @dataclass class AsyncBatch(Batch): - connection: Connection query: Query batch_size: int @@ -448,7 +447,6 @@ def dict_factory(cursor, row) -> t.Dict: class SQLiteEngine(Engine[t.Optional[SQLiteTransaction]]): - __slots__ = ( "connection_kwargs", "current_transaction", @@ -585,7 +583,7 @@ async def _get_inserted_pk(self, cursor, table: t.Type[Table]) -> t.Any: async def _run_in_new_connection( self, query: str, - args: t.List[t.Any] = None, + args: t.Optional[t.List[t.Any]] = None, query_type: str = "generic", table: t.Optional[t.Type[Table]] = None, ): @@ -611,7 +609,7 @@ async def _run_in_existing_connection( self, connection, query: str, - args: t.List[t.Any] = None, + args: t.Optional[t.List[t.Any]] = None, query_type: str = "generic", table: t.Optional[t.Type[Table]] = None, ): diff --git a/piccolo/main.py b/piccolo/main.py index 0766a0aa7..556b86f59 100644 --- a/piccolo/main.py +++ b/piccolo/main.py @@ -34,7 +34,7 @@ def get_diagnose_flag() -> bool: return DIAGNOSE_FLAG in sys.argv -def main(): +def main() -> None: """ The entrypoint to the Piccolo CLI. 
""" @@ -72,7 +72,7 @@ def main(): tester_config, user_config, ]: - for command in _app_config.commands: + for command in _app_config.get_commands(): cli.register( command.callable, group_name=_app_config.app_name, @@ -92,12 +92,14 @@ def main(): ) else: for app_name, _app_config in APP_REGISTRY.app_configs.items(): - for command in _app_config.commands: + for command in _app_config.get_commands(): if cli.command_exists( - group_name=app_name, command_name=command.callable.__name__ + group_name=app_name, + command_name=command.callable.__name__, ): # Skipping - already registered. continue + cli.register( command.callable, group_name=app_name, diff --git a/piccolo/query/base.py b/piccolo/query/base.py index a7761b8c6..b10d42ee0 100644 --- a/piccolo/query/base.py +++ b/piccolo/query/base.py @@ -26,7 +26,6 @@ def __exit__(self, exception_type, exception, traceback): class Query(t.Generic[TableInstance, QueryResponseType]): - __slots__ = ("table", "_frozen_querystrings") def __init__( @@ -45,7 +44,7 @@ def engine_type(self) -> str: else: raise ValueError("Engine isn't defined.") - async def _process_results(self, results): + async def _process_results(self, results) -> QueryResponseType: if results: keys = results[0].keys() keys = [i.replace("$", ".") for i in keys] @@ -118,14 +117,20 @@ async def _process_results(self, results): if output: if output._output.as_objects: if output._output.nested: - raw = [make_nested_object(row, self.table) for row in raw] + return t.cast( + QueryResponseType, + [make_nested_object(row, self.table) for row in raw], + ) else: - raw = [ - self.table(**columns, _exists_in_db=True) - for columns in raw - ] + return t.cast( + QueryResponseType, + [ + self.table(**columns, _exists_in_db=True) + for columns in raw + ], + ) - return raw + return t.cast(QueryResponseType, raw) def _validate(self): """ @@ -222,7 +227,7 @@ def run_sync( with Timer(): return run_sync(coroutine) - async def response_handler(self, response): + async def 
response_handler(self, response: t.List) -> t.Any: """ Subclasses can override this to modify the raw response returned by the database driver. @@ -370,7 +375,6 @@ def __str__(self) -> str: class DDL: - __slots__ = ("table",) def __init__(self, table: t.Type[Table], **kwargs): diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index 5c06169d8..d7c655b80 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -81,7 +81,7 @@ def on_conflict( ########################################################################### - def _raw_response_callback(self, results): + def _raw_response_callback(self, results: t.List): """ Assign the ids of the created rows to the model instances. """ diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 6892fc95c..5b1c96002 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -280,7 +280,7 @@ def get(self: Self, where: Combinable) -> Get[TableInstance]: def get_or_create( self: Self, where: Combinable, - defaults: t.Dict[Column, t.Any] = None, + defaults: t.Optional[t.Dict[Column, t.Any]] = None, ) -> GetOrCreate[TableInstance]: if defaults is None: defaults = {} diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 85b2ac6b4..a00745e48 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -367,7 +367,7 @@ class Select(Query[TableInstance, t.List[t.Dict[str, t.Any]]]): def __init__( self, table: t.Type[TableInstance], - columns_list: t.Sequence[t.Union[Selectable, str]] = None, + columns_list: t.Optional[t.Sequence[t.Union[Selectable, str]]] = None, exclude_secrets: bool = False, **kwargs, ): diff --git a/piccolo/query/methods/update.py b/piccolo/query/methods/update.py index 0d914b858..ff6a10589 100644 --- a/piccolo/query/methods/update.py +++ b/piccolo/query/methods/update.py @@ -20,7 +20,6 @@ class UpdateError(Exception): class 
Update(Query[TableInstance, t.List[t.Any]]): - __slots__ = ( "force", "returning_delegate", @@ -41,7 +40,9 @@ def __init__( # Clauses def values( - self, values: t.Dict[t.Union[Column, str], t.Any] = None, **kwargs + self, + values: t.Optional[t.Dict[t.Union[Column, str], t.Any]] = None, + **kwargs, ) -> Update: if values is None: values = {} diff --git a/piccolo/table.py b/piccolo/table.py index 92590b2c2..abae84a52 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -224,7 +224,7 @@ def __init_subclass__( cls, tablename: t.Optional[str] = None, db: t.Optional[Engine] = None, - tags: t.List[str] = None, + tags: t.Optional[t.List[str]] = None, help_text: t.Optional[str] = None, schema: t.Optional[str] = None, ): # sourcery no-metrics @@ -364,7 +364,7 @@ def __init_subclass__( def __init__( self, - _data: t.Dict[Column, t.Any] = None, + _data: t.Optional[t.Dict[Column, t.Any]] = None, _ignore_missing: bool = False, _exists_in_db: bool = False, **kwargs, @@ -826,7 +826,7 @@ def __repr__(self) -> str: @classmethod def all_related( - cls, exclude: t.List[t.Union[str, ForeignKey]] = None + cls, exclude: t.Optional[t.List[t.Union[str, ForeignKey]]] = None ) -> t.List[Column]: """ Used in conjunction with ``objects`` queries. Just as we can use @@ -876,7 +876,7 @@ def all_related( @classmethod def all_columns( - cls, exclude: t.Sequence[t.Union[str, Column]] = None + cls, exclude: t.Optional[t.Sequence[t.Union[str, Column]]] = None ) -> t.List[Column]: """ Used in conjunction with ``select`` queries. 
Just as we can use @@ -1120,7 +1120,6 @@ def count( column: t.Optional[Column] = None, distinct: t.Optional[t.Sequence[Column]] = None, ) -> Count: - """ Count the number of matching rows:: @@ -1191,7 +1190,7 @@ def table_exists(cls) -> TableExists: @classmethod def update( cls, - values: t.Dict[t.Union[Column, str], t.Any] = None, + values: t.Optional[t.Dict[t.Union[Column, str], t.Any]] = None, force: bool = False, use_auto_update: bool = True, **kwargs, @@ -1303,7 +1302,7 @@ def _get_index_name(cls, column_names: t.List[str]) -> str: @classmethod def _table_str( - cls, abbreviated=False, excluded_params: t.List[str] = None + cls, abbreviated=False, excluded_params: t.Optional[t.List[str]] = None ): """ Returns a basic string representation of the table and its columns. diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py index 8010f2139..15b50416c 100644 --- a/piccolo/testing/model_builder.py +++ b/piccolo/testing/model_builder.py @@ -30,7 +30,7 @@ class ModelBuilder: async def build( cls, table_class: t.Type[TableInstance], - defaults: t.Dict[t.Union[Column, str], t.Any] = None, + defaults: t.Optional[t.Dict[t.Union[Column, str], t.Any]] = None, persist: bool = True, minimal: bool = False, ) -> TableInstance: @@ -81,7 +81,7 @@ async def build( def build_sync( cls, table_class: t.Type[TableInstance], - defaults: t.Dict[t.Union[Column, str], t.Any] = None, + defaults: t.Optional[t.Dict[t.Union[Column, str], t.Any]] = None, persist: bool = True, minimal: bool = False, ) -> TableInstance: @@ -101,7 +101,7 @@ def build_sync( async def _build( cls, table_class: t.Type[TableInstance], - defaults: t.Dict[t.Union[Column, str], t.Any] = None, + defaults: t.Optional[t.Dict[t.Union[Column, str], t.Any]] = None, minimal: bool = False, persist: bool = True, ) -> TableInstance: @@ -115,7 +115,6 @@ async def _build( setattr(model, column._meta.name, value) for column in model._meta.columns: - if column._meta.null and minimal: continue diff --git 
a/piccolo/utils/pydantic.py b/piccolo/utils/pydantic.py index 9f88ce35a..e245e522d 100644 --- a/piccolo/utils/pydantic.py +++ b/piccolo/utils/pydantic.py @@ -24,6 +24,11 @@ from piccolo.table import Table from piccolo.utils.encoding import load_json +try: + from pydantic.config import JsonDict +except ImportError: + JsonDict = dict # type: ignore + def pydantic_json_validator(value: t.Optional[str], required: bool = True): if value is None: @@ -243,7 +248,7 @@ def create_pydantic_model( if column._meta.db_column_name != column._meta.name: params["alias"] = column._meta.db_column_name - extra = { + extra: JsonDict = { "help_text": column._meta.help_text, "choices": column._meta.get_choices_dict(), "secret": column._meta.secret, @@ -320,7 +325,7 @@ def create_pydantic_model( pydantic_config["json_schema_extra"] = dict(json_schema_extra_) - model = pydantic.create_model( # type: ignore + model = pydantic.create_model( model_name, __config__=pydantic_config, __validators__=validators, diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index e8c4f5445..853726ee9 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -5,6 +5,6 @@ flake8==6.1.0 isort==5.10.1 slotscheck==0.17.1 twine==3.8.0 -mypy==0.961 +mypy==1.7.1 pip-upgrader==1.4.15 wheel==0.38.1 diff --git a/tests/apps/migrations/auto/test_schema_differ.py b/tests/apps/migrations/auto/test_schema_differ.py index cf621a916..9cf6d26f2 100644 --- a/tests/apps/migrations/auto/test_schema_differ.py +++ b/tests/apps/migrations/auto/test_schema_differ.py @@ -16,10 +16,9 @@ class TestSchemaDiffer(TestCase): - maxDiff = None - def test_add_table(self): + def test_add_table(self) -> None: """ Test adding a new table. 
""" @@ -49,7 +48,7 @@ def test_add_table(self): "manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)", # noqa ) - def test_drop_table(self): + def test_drop_table(self) -> None: """ Test dropping an existing table. """ @@ -67,7 +66,7 @@ def test_drop_table(self): "manager.drop_table(class_name='Band', tablename='band', schema=None)", # noqa: E501 ) - def test_rename_table(self): + def test_rename_table(self) -> None: """ Test renaming a table. """ @@ -98,7 +97,7 @@ def test_rename_table(self): self.assertEqual(schema_differ.create_tables.statements, []) self.assertEqual(schema_differ.drop_tables.statements, []) - def test_change_schema(self): + def test_change_schema(self) -> None: """ Testing changing the schema. """ @@ -133,7 +132,7 @@ def test_change_schema(self): self.assertListEqual(schema_differ.create_tables.statements, []) self.assertListEqual(schema_differ.drop_tables.statements, []) - def test_add_column(self): + def test_add_column(self) -> None: """ Test adding a column to an existing table. """ @@ -168,7 +167,7 @@ def test_add_column(self): "manager.add_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False}, schema=None)", # noqa: E501 ) - def test_drop_column(self): + def test_drop_column(self) -> None: """ Test dropping a column from an existing table. 
""" @@ -203,7 +202,7 @@ def test_drop_column(self): "manager.drop_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', schema=None)", # noqa: E501 ) - def test_rename_column(self): + def test_rename_column(self) -> None: """ Test renaming a column in an existing table. """ @@ -261,7 +260,7 @@ def test_rename_column(self): self.assertTrue(schema_differ.rename_columns.statements == []) @patch("piccolo.apps.migrations.auto.schema_differ.input") - def test_rename_multiple_columns(self, input: MagicMock): + def test_rename_multiple_columns(self, input: MagicMock) -> None: """ Make sure renaming columns works when several columns have been renamed. @@ -419,7 +418,7 @@ def mock_input(value: str): ], ) - def test_alter_column_precision(self): + def test_alter_column_precision(self) -> None: price_1 = Numeric(digits=(4, 2)) price_1._meta.name = "price" @@ -451,7 +450,7 @@ def test_alter_column_precision(self): "manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', db_column_name='price', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric, schema=None)", # noqa ) - def test_db_column_name(self): + def test_db_column_name(self) -> None: """ Make sure alter statements use the ``db_column_name`` if provided. diff --git a/tests/apps/schema/commands/test_generate.py b/tests/apps/schema/commands/test_generate.py index a08521038..29784fa8a 100644 --- a/tests/apps/schema/commands/test_generate.py +++ b/tests/apps/schema/commands/test_generate.py @@ -64,7 +64,7 @@ def _compare_table_columns( # Make sure the unique constraint is the same self.assertEqual(col_1._meta.unique, col_2._meta.unique) - def test_get_output_schema(self): + def test_get_output_schema(self) -> None: """ Make sure that the a Piccolo schema can be generated from the database. 
""" @@ -75,9 +75,11 @@ def test_get_output_schema(self): self.assertTrue(len(output_schema.imports) > 0) MegaTable_ = output_schema.get_table_with_name("MegaTable") + assert MegaTable_ is not None self._compare_table_columns(MegaTable, MegaTable_) SmallTable_ = output_schema.get_table_with_name("SmallTable") + assert SmallTable_ is not None self._compare_table_columns(SmallTable, SmallTable_) @patch("piccolo.apps.schema.commands.generate.print") @@ -94,7 +96,7 @@ def test_generate_command(self, print_: MagicMock): # Cockroach throws FeatureNotSupportedError, which does not pass this test. @engines_skip("cockroach") - def test_unknown_column_type(self): + def test_unknown_column_type(self) -> None: """ Make sure unknown column types are handled gracefully. """ @@ -119,11 +121,13 @@ class Box(Column): for table in output_schema.tables: if table.__name__ == "MegaTable": self.assertEqual( - output_schema.tables[1].my_column.__class__.__name__, + output_schema.tables[1] + ._meta.get_column_by_name("my_column") + .__class__.__name__, "Column", ) - def test_generate_required_tables(self): + def test_generate_required_tables(self) -> None: """ Make sure only tables passed to `tablenames` are created. """ @@ -132,9 +136,10 @@ def test_generate_required_tables(self): ) self.assertEqual(len(output_schema.tables), 1) SmallTable_ = output_schema.get_table_with_name("SmallTable") + assert SmallTable_ is not None self._compare_table_columns(SmallTable, SmallTable_) - def test_exclude_table(self): + def test_exclude_table(self) -> None: """ Make sure exclude works. 
""" @@ -143,10 +148,11 @@ def test_exclude_table(self): ) self.assertEqual(len(output_schema.tables), 1) SmallTable_ = output_schema.get_table_with_name("SmallTable") + assert SmallTable_ is not None self._compare_table_columns(SmallTable, SmallTable_) @engines_skip("cockroach") - def test_self_referencing_fk(self): + def test_self_referencing_fk(self) -> None: """ Make sure self-referencing foreign keys are handled correctly. """ @@ -160,12 +166,15 @@ def test_self_referencing_fk(self): # Make sure the 'references' value of the generated column is "self". for table in output_schema.tables: if table.__name__ == "MegaTable": - column: ForeignKey = output_schema.tables[ - 1 - ].self_referencing_fk + column = t.cast( + ForeignKey, + output_schema.tables[1]._meta.get_column_by_name( + "self_referencing_fk" + ), + ) self.assertEqual( - column._foreign_key_meta.references._meta.tablename, + column._foreign_key_meta.resolved_references._meta.tablename, # noqa: E501 MegaTable._meta.tablename, ) self.assertEqual(column._meta.params["references"], "self") @@ -190,23 +199,24 @@ def setUp(self): def tearDown(self): Concert.alter().drop_table(if_exists=True).run_sync() - def test_index(self): + def test_index(self) -> None: """ Make sure that a table with an index is reflected correctly. 
""" output_schema: OutputSchema = run_sync(get_output_schema()) Concert_ = output_schema.tables[0] - self.assertEqual(Concert_.name._meta.index, True) - self.assertEqual(Concert_.name._meta.index_method, IndexMethod.hash) + name_column = Concert_._meta.get_column_by_name("name") + self.assertTrue(name_column._meta.index) + self.assertEqual(name_column._meta.index_method, IndexMethod.hash) - self.assertEqual(Concert_.time._meta.index, True) - self.assertEqual(Concert_.time._meta.index_method, IndexMethod.btree) + time_column = Concert_._meta.get_column_by_name("time") + self.assertTrue(time_column._meta.index) + self.assertEqual(time_column._meta.index_method, IndexMethod.btree) - self.assertEqual(Concert_.capacity._meta.index, False) - self.assertEqual( - Concert_.capacity._meta.index_method, IndexMethod.btree - ) + capacity_column = Concert_._meta.get_column_by_name("capacity") + self.assertEqual(capacity_column._meta.index, False) + self.assertEqual(capacity_column._meta.index_method, IndexMethod.btree) ############################################################################### @@ -229,7 +239,6 @@ class Book(Table): @engines_only("postgres") class TestGenerateWithSchema(TestCase): - tables = [Publication, Writer, Book] schema_manager = SchemaManager() @@ -250,7 +259,7 @@ def tearDown(self) -> None: schema_name=schema_name, if_exists=True, cascade=True ).run_sync() - def test_reference_to_another_schema(self): + def test_reference_to_another_schema(self) -> None: output_schema: OutputSchema = run_sync(get_output_schema()) self.assertEqual(len(output_schema.tables), 3) publication = output_schema.tables[0] @@ -263,8 +272,10 @@ def test_reference_to_another_schema(self): self.assertEqual(Writer._meta.tablename, writer._meta.tablename) # Make sure foreign key values are correct. 
- self.assertEqual(writer.publication, publication) - self.assertEqual(book.writer, writer) + self.assertEqual( + writer._meta.get_column_by_name("publication"), publication + ) + self.assertEqual(book._meta.get_column_by_name("writer"), writer) @engines_only("postgres", "cockroach") diff --git a/tests/columns/foreign_key/test_attribute_access.py b/tests/columns/foreign_key/test_attribute_access.py index 87caeb78e..3f9d8afab 100644 --- a/tests/columns/foreign_key/test_attribute_access.py +++ b/tests/columns/foreign_key/test_attribute_access.py @@ -40,7 +40,7 @@ def test_attribute_access(self): for band_table in (BandA, BandB, BandC, BandD): self.assertIsInstance(band_table.manager.name, Varchar) - def test_recursion_limit(self): + def test_recursion_limit(self) -> None: """ When a table has a ForeignKey to itself, an Exception should be raised if the call chain is too large. @@ -51,7 +51,7 @@ def test_recursion_limit(self): self.assertIsInstance(column, Varchar) with self.assertRaises(Exception): - Manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.name # noqa + Manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.manager.name # type: ignore # noqa: E501 def test_recursion_time(self): """ diff --git a/tests/columns/m2m/base.py b/tests/columns/m2m/base.py index f91d8d7d6..a3f282d23 100644 --- a/tests/columns/m2m/base.py +++ b/tests/columns/m2m/base.py @@ -292,7 +292,7 @@ def test_add_m2m(self): Genre = self.genre GenreToBand = self.genre_to_band - band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() band.add_m2m(Genre(name="Punk Rock"), m2m=Band.genres).run_sync() self.assertTrue( @@ -320,7 +320,7 @@ def test_extra_columns_str(self): reason = "Their second album was very punk rock." 
- band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() band.add_m2m( Genre(name="Punk Rock"), m2m=Band.genres, @@ -351,7 +351,7 @@ def test_extra_columns_class(self): reason = "Their second album was very punk rock." - band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() band.add_m2m( Genre(name="Punk Rock"), m2m=Band.genres, @@ -379,11 +379,9 @@ def test_add_m2m_existing(self): Genre = self.genre GenreToBand = self.genre_to_band - band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() - genre: Genre = ( - Genre.objects().get(Genre.name == "Classical").run_sync() - ) + genre = Genre.objects().get(Genre.name == "Classical").run_sync() band.add_m2m(genre, m2m=Band.genres).run_sync() @@ -408,7 +406,7 @@ def test_get_m2m(self): """ Band = self.band - band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() genres = band.get_m2m(Band.genres).run_sync() @@ -424,7 +422,7 @@ def test_remove_m2m(self): Genre = self.genre GenreToBand = self.genre_to_band - band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() genre = Genre.objects().get(Genre.name == "Rock").run_sync() diff --git a/tests/columns/m2m/test_m2m.py b/tests/columns/m2m/test_m2m.py index aaec9fc84..d897eace6 100644 --- a/tests/columns/m2m/test_m2m.py +++ b/tests/columns/m2m/test_m2m.py @@ -160,9 +160,7 @@ def test_add_m2m(self): """ Make sure we can add items to the joining table. 
""" - customer: Customer = ( - Customer.objects().get(Customer.name == "Bob").run_sync() - ) + customer = Customer.objects().get(Customer.name == "Bob").run_sync() customer.add_m2m( Concert(name="Jazzfest"), m2m=Customer.concerts ).run_sync() @@ -193,9 +191,7 @@ def test_add_m2m_within_transaction(self): async def add_m2m_in_transaction(): async with engine.transaction(): - customer: Customer = await Customer.objects().get( - Customer.name == "Bob" - ) + customer = await Customer.objects().get(Customer.name == "Bob") await customer.add_m2m( Concert(name="Jazzfest"), m2m=Customer.concerts ) @@ -220,9 +216,7 @@ def test_get_m2m(self): """ Make sure we can get related items via the joining table. """ - customer: Customer = ( - Customer.objects().get(Customer.name == "Bob").run_sync() - ) + customer = Customer.objects().get(Customer.name == "Bob").run_sync() concerts = customer.get_m2m(Customer.concerts).run_sync() diff --git a/tests/engine/test_pool.py b/tests/engine/test_pool.py index adddc6bda..510a77072 100644 --- a/tests/engine/test_pool.py +++ b/tests/engine/test_pool.py @@ -1,6 +1,7 @@ import asyncio import os import tempfile +import typing as t from unittest import TestCase from unittest.mock import call, patch @@ -12,8 +13,8 @@ @engines_only("postgres", "cockroach") class TestPool(DBTestCase): - async def _create_pool(self): - engine: PostgresEngine = Manager._meta.db + async def _create_pool(self) -> None: + engine = t.cast(PostgresEngine, Manager._meta.db) await engine.start_connection_pool() assert engine.pool is not None @@ -70,8 +71,8 @@ def test_many_queries(self): @engines_only("postgres", "cockroach") class TestPoolProxyMethods(DBTestCase): - async def _create_pool(self): - engine: PostgresEngine = Manager._meta.db + async def _create_pool(self) -> None: + engine = t.cast(PostgresEngine, Manager._meta.db) # Deliberate typo ('nnn'): await engine.start_connnection_pool() diff --git a/tests/engine/test_transaction.py b/tests/engine/test_transaction.py 
index 3cba32c86..4b47f8759 100644 --- a/tests/engine/test_transaction.py +++ b/tests/engine/test_transaction.py @@ -45,12 +45,12 @@ def test_succeeds(self): drop_db_tables_sync(Band, Manager) @engines_only("postgres", "cockroach") - def test_pool(self): + def test_pool(self) -> None: """ Make sure atomic works correctly when a connection pool is active. """ - async def run(): + async def run() -> None: """ We have to run this async function, so we can use a connection pool. diff --git a/tests/example_apps/music/tables.py b/tests/example_apps/music/tables.py index 8e2870ea1..dff416c2f 100644 --- a/tests/example_apps/music/tables.py +++ b/tests/example_apps/music/tables.py @@ -7,6 +7,7 @@ ForeignKey, Integer, Numeric, + Serial, Text, Varchar, ) @@ -21,6 +22,7 @@ class Manager(Table): + id: Serial name = Varchar(length=50) @classmethod @@ -29,6 +31,7 @@ def get_readable(cls) -> Readable: class Band(Table): + id: Serial name = Varchar(length=50) manager = ForeignKey(Manager, null=True) popularity = ( @@ -47,6 +50,7 @@ def get_readable(cls) -> Readable: class Venue(Table): + id: Serial name = Varchar(length=100) capacity = Integer(default=0, secret=True) @@ -56,6 +60,7 @@ def get_readable(cls) -> Readable: class Concert(Table): + id: Serial band_1 = ForeignKey(Band) band_2 = ForeignKey(Band) venue = ForeignKey(Venue) @@ -74,6 +79,7 @@ def get_readable(cls) -> Readable: class Ticket(Table): + id: Serial concert = ForeignKey(Concert) price = Numeric(digits=(5, 2)) @@ -83,6 +89,7 @@ class Poster(Table, tags=["special"]): Has tags for tests which need it. """ + id: Serial content = Text() @@ -96,6 +103,7 @@ class Size(str, Enum): medium = "m" large = "l" + id: Serial size = Varchar(length=1, choices=Size, default=Size.large) @@ -104,5 +112,6 @@ class RecordingStudio(Table): Used for testing JSON and JSONB columns. 
""" + id: Serial facilities = JSON() facilities_b = JSONB() diff --git a/tests/query/test_freeze.py b/tests/query/test_freeze.py index 29cb5271f..ca916ba5e 100644 --- a/tests/query/test_freeze.py +++ b/tests/query/test_freeze.py @@ -4,7 +4,7 @@ from unittest import mock from piccolo.columns import Integer, Varchar -from piccolo.query.base import Query +from piccolo.query.base import FrozenQuery, Query from piccolo.table import Table from tests.base import AsyncMock, DBTestCase, sqlite_only from tests.example_apps.music.tables import Band @@ -12,12 +12,12 @@ @dataclass class QueryResponse: - query: Query + query: t.Union[Query, FrozenQuery] response: t.Any class TestFreeze(DBTestCase): - def test_frozen_select_queries(self): + def test_frozen_select_queries(self) -> None: """ Make sure a variety of select queries work as expected when frozen. """ diff --git a/tests/table/test_refresh.py b/tests/table/test_refresh.py index 69deb9a73..ffb78ddb8 100644 --- a/tests/table/test_refresh.py +++ b/tests/table/test_refresh.py @@ -7,12 +7,13 @@ def setUp(self): super().setUp() self.insert_rows() - def test_refresh(self): + def test_refresh(self) -> None: """ Make sure ``refresh`` works, with no columns specified. """ # Fetch an instance from the database. - band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None initial_data = band.to_dict() # Modify the data in the database. @@ -27,12 +28,13 @@ def test_refresh(self): self.assertTrue(band.popularity == 8000) self.assertTrue(band.id == initial_data["id"]) - def test_columns(self): + def test_columns(self) -> None: """ Make sure ``refresh`` works, when columns are specified. """ # Fetch an instance from the database. 
- band: Band = Band.objects().get(Band.name == "Pythonistas").run_sync() + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None initial_data = band.to_dict() # Modify the data in the database. @@ -52,7 +54,7 @@ def test_columns(self): self.assertTrue(band.popularity == initial_data["popularity"]) self.assertTrue(band.id == initial_data["id"]) - def test_error_when_not_in_db(self): + def test_error_when_not_in_db(self) -> None: """ Make sure we can't refresh an instance which hasn't been saved in the database. @@ -67,12 +69,13 @@ def test_error_when_not_in_db(self): str(manager.exception), ) - def test_error_when_pk_in_none(self): + def test_error_when_pk_in_none(self) -> None: """ Make sure we can't refresh an instance when the primary key value isn't set. """ - band: Band = Band.objects().first().run_sync() + band = Band.objects().first().run_sync() + assert band is not None band.id = None with self.assertRaises(ValueError) as manager: diff --git a/tests/utils/test_pydantic.py b/tests/utils/test_pydantic.py index 35aff0fa4..a6e003f7f 100644 --- a/tests/utils/test_pydantic.py +++ b/tests/utils/test_pydantic.py @@ -819,7 +819,7 @@ class Band(Table): class TestPydanticExtraFields(TestCase): - def test_pydantic_extra_fields(self): + def test_pydantic_extra_fields(self) -> None: """ Make sure that the value of ``extra`` in the config class is correctly propagated to the generated model. @@ -833,7 +833,7 @@ class Band(Table): self.assertEqual(model.model_config["extra"], "forbid") - def test_pydantic_invalid_extra_fields(self): + def test_pydantic_invalid_extra_fields(self) -> None: """ Make sure that invalid values for ``extra`` in the config class are rejected. 
@@ -842,7 +842,9 @@ def test_pydantic_invalid_extra_fields(self): class Band(Table): name = Varchar() - config: pydantic.config.ConfigDict = {"extra": "foobar"} + config: pydantic.config.ConfigDict = { + "extra": "foobar" # type: ignore + } with pytest.raises(pydantic_core._pydantic_core.SchemaError): create_pydantic_model(Band, pydantic_config=config)