diff --git a/piccolo/apps/fixtures/commands/load.py b/piccolo/apps/fixtures/commands/load.py
index 1de1d5a44..64e4c3334 100644
--- a/piccolo/apps/fixtures/commands/load.py
+++ b/piccolo/apps/fixtures/commands/load.py
@@ -51,7 +51,7 @@ async def load_json_string(
     finder = Finder()
     engine = engine_finder()

-    if not engine:
+    if engine is None:
         raise Exception("Unable to find the engine.")

     # This is what we want to the insert into the database:
diff --git a/piccolo/apps/migrations/auto/__init__.py b/piccolo/apps/migrations/auto/__init__.py
index cdffc6c1c..1df58816c 100644
--- a/piccolo/apps/migrations/auto/__init__.py
+++ b/piccolo/apps/migrations/auto/__init__.py
@@ -2,3 +2,11 @@
 from .migration_manager import MigrationManager
 from .schema_differ import AlterStatements, SchemaDiffer
 from .schema_snapshot import SchemaSnapshot
+
+__all__ = [
+    "DiffableTable",
+    "MigrationManager",
+    "AlterStatements",
+    "SchemaDiffer",
+    "SchemaSnapshot",
+]
diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py
index fca36e8e7..e8f4931cb 100644
--- a/piccolo/apps/migrations/auto/migration_manager.py
+++ b/piccolo/apps/migrations/auto/migration_manager.py
@@ -261,7 +261,8 @@ def add_column(
         cleaned_params = deserialise_params(params=params)
         column = column_class(**cleaned_params)
         column._meta.name = column_name
-        column._meta.db_column_name = db_column_name
+        if db_column_name:
+            column._meta.db_column_name = db_column_name

         self.add_columns.append(
             AddColumnClass(
diff --git a/piccolo/apps/migrations/commands/backwards.py b/piccolo/apps/migrations/commands/backwards.py
index 6627fe8af..363992510 100644
--- a/piccolo/apps/migrations/commands/backwards.py
+++ b/piccolo/apps/migrations/commands/backwards.py
@@ -32,7 +32,9 @@ def __init__(

     async def run_migrations_backwards(self, app_config: AppConfig):
         migration_modules: t.Dict[str, MigrationModule] = (
-            self.get_migration_modules(app_config.migrations_folder_path)
+            self.get_migration_modules(
+                app_config.resolved_migrations_folder_path
+            )
         )

         ran_migration_ids = await Migration.get_migrations_which_ran(
diff --git a/piccolo/apps/migrations/commands/base.py b/piccolo/apps/migrations/commands/base.py
index a3966f7c3..bcc5cbc55 100644
--- a/piccolo/apps/migrations/commands/base.py
+++ b/piccolo/apps/migrations/commands/base.py
@@ -86,7 +86,7 @@ async def get_migration_managers(
         """
         migration_managers: t.List[MigrationManager] = []

-        migrations_folder = app_config.migrations_folder_path
+        migrations_folder = app_config.resolved_migrations_folder_path

         migration_modules: t.Dict[str, MigrationModule] = (
             self.get_migration_modules(migrations_folder)
diff --git a/piccolo/apps/migrations/commands/check.py b/piccolo/apps/migrations/commands/check.py
index fd2b49c3d..53e20840a 100644
--- a/piccolo/apps/migrations/commands/check.py
+++ b/piccolo/apps/migrations/commands/check.py
@@ -36,7 +36,7 @@ async def get_migration_statuses(self) -> t.List[MigrationStatus]:
                 continue

             migration_modules = self.get_migration_modules(
-                app_config.migrations_folder_path
+                app_config.resolved_migrations_folder_path
             )
             ids = self.get_migration_ids(migration_modules)
             for _id in ids:
diff --git a/piccolo/apps/migrations/commands/clean.py b/piccolo/apps/migrations/commands/clean.py
index e7ef22091..687ff64e9 100644
--- a/piccolo/apps/migrations/commands/clean.py
+++ b/piccolo/apps/migrations/commands/clean.py
@@ -20,7 +20,7 @@ def get_migration_ids_to_remove(self) -> t.List[str]:
         app_config = self.get_app_config(app_name=self.app_name)

         migration_module_dict = self.get_migration_modules(
-            folder_path=app_config.migrations_folder_path
+            folder_path=app_config.resolved_migrations_folder_path
         )

         # The migration IDs which are in migration modules.
diff --git a/piccolo/apps/migrations/commands/forwards.py b/piccolo/apps/migrations/commands/forwards.py
index 6d967dd5e..62278060d 100644
--- a/piccolo/apps/migrations/commands/forwards.py
+++ b/piccolo/apps/migrations/commands/forwards.py
@@ -33,7 +33,9 @@ async def run_migrations(self, app_config: AppConfig) -> MigrationResult:
         )

         migration_modules: t.Dict[str, MigrationModule] = (
-            self.get_migration_modules(app_config.migrations_folder_path)
+            self.get_migration_modules(
+                app_config.resolved_migrations_folder_path
+            )
         )

         ids = self.get_migration_ids(migration_modules)
diff --git a/piccolo/apps/migrations/commands/new.py b/piccolo/apps/migrations/commands/new.py
index ff123aaa2..082868435 100644
--- a/piccolo/apps/migrations/commands/new.py
+++ b/piccolo/apps/migrations/commands/new.py
@@ -98,7 +98,9 @@ def _generate_migration_meta(app_config: AppConfig) -> NewMigrationMeta:

     filename = f"{cleaned_app_name}_{cleaned_id}"

-    path = os.path.join(app_config.migrations_folder_path, f"{filename}.py")
+    path = os.path.join(
+        app_config.resolved_migrations_folder_path, f"{filename}.py"
+    )

     return NewMigrationMeta(
         migration_id=_id, migration_filename=filename, migration_path=path
@@ -255,7 +257,7 @@ async def new(

     app_config = Finder().get_app_config(app_name=app_name)

-    _create_migrations_folder(app_config.migrations_folder_path)
+    _create_migrations_folder(app_config.resolved_migrations_folder_path)

     try:
         await _create_new_migration(
diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py
index 20aea360d..da97d247b 100644
--- a/piccolo/apps/schema/commands/generate.py
+++ b/piccolo/apps/schema/commands/generate.py
@@ -313,7 +313,7 @@ def __add__(self, value: OutputSchema) -> OutputSchema:
     **{"integer": BigInt, "json": JSONB},
 }

-COLUMN_DEFAULT_PARSER = {
+COLUMN_DEFAULT_PARSER: t.Dict[t.Type[Column], t.Any] = {
     BigInt: re.compile(r"^'?(?P<value>-?[0-9]\d*)'?(?:::bigint)?$"),
     Boolean: re.compile(r"^(?P<value>true|false)$"),
     Bytea: re.compile(r"'(?P<value>.*)'::bytea$"),
@@ -373,7 +373,7 @@ def __add__(self, value: OutputSchema) -> OutputSchema:
 }

 # Re-map for Cockroach compatibility.
-COLUMN_DEFAULT_PARSER_COCKROACH = {
+COLUMN_DEFAULT_PARSER_COCKROACH: t.Dict[t.Type[Column], t.Any] = {
     **COLUMN_DEFAULT_PARSER,
     BigInt: re.compile(r"^(?P<value>-?\d+)$"),
 }
diff --git a/piccolo/apps/shell/commands/run.py b/piccolo/apps/shell/commands/run.py
index 38cd1af66..4f86cc23c 100644
--- a/piccolo/apps/shell/commands/run.py
+++ b/piccolo/apps/shell/commands/run.py
@@ -24,7 +24,7 @@ def start_ipython_shell(**tables: t.Type[Table]):  # pragma: no cover
         if table_class_name not in existing_global_names:
             globals()[table_class_name] = table_class

-    IPython.embed(using=_asyncio_runner, colors="neutral")
+    IPython.embed(using=_asyncio_runner, colors="neutral")  # type: ignore


 def run() -> None:
diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py
index c248d3b1b..9c74b4a52 100644
--- a/piccolo/columns/column_types.py
+++ b/piccolo/columns/column_types.py
@@ -1956,7 +1956,9 @@ def _setup(self, table_class: t.Type[Table]) -> ForeignKeySetupResponse:

         if is_table_class:
             # Record the reverse relationship on the target table.
-            references._meta._foreign_key_references.append(self)
+            t.cast(
+                t.Type[Table], references
+            )._meta._foreign_key_references.append(self)

             # Allow columns on the referenced table to be accessed via
             # auto completion.
@@ -2710,7 +2712,7 @@ def all(self, value: t.Any) -> Where:
         else:
             raise ValueError("Unrecognised engine type")

-    def cat(self, value: t.List[t.Any]) -> QueryString:
+    def cat(self, value: t.Union[t.Any, t.List[t.Any]]) -> QueryString:
         """
         Used in an ``update`` query to append items to an array.

@@ -2741,7 +2743,7 @@ def cat(self, value: t.List[t.Any]) -> QueryString:
         db_column_name = self._meta.db_column_name
         return QueryString(f'array_cat("{db_column_name}", {{}})', value)

-    def __add__(self, value: t.List[t.Any]) -> QueryString:
+    def __add__(self, value: t.Union[t.Any, t.List[t.Any]]) -> QueryString:
         return self.cat(value)

     ###########################################################################
diff --git a/piccolo/columns/defaults/base.py b/piccolo/columns/defaults/base.py
index 9ef45ec93..062162032 100644
--- a/piccolo/columns/defaults/base.py
+++ b/piccolo/columns/defaults/base.py
@@ -18,7 +18,7 @@ def sqlite(self) -> str:
         pass

     @abstractmethod
-    def python(self):
+    def python(self) -> t.Any:
         pass

     def get_postgres_interval_string(self, attributes: t.List[str]) -> str:
diff --git a/piccolo/columns/defaults/date.py b/piccolo/columns/defaults/date.py
index 87e431390..423f112ca 100644
--- a/piccolo/columns/defaults/date.py
+++ b/piccolo/columns/defaults/date.py
@@ -102,7 +102,15 @@ def from_date(cls, instance: datetime.date):


 # Might add an enum back which encapsulates all of the options.
-DateArg = t.Union[DateOffset, DateCustom, DateNow, Enum, None, datetime.date]
+DateArg = t.Union[
+    DateOffset,
+    DateCustom,
+    DateNow,
+    Enum,
+    None,
+    datetime.date,
+    t.Callable[[], datetime.date],
+]


 __all__ = ["DateArg", "DateOffset", "DateCustom", "DateNow"]
diff --git a/piccolo/columns/defaults/interval.py b/piccolo/columns/defaults/interval.py
index f3daba639..4d5f72ae8 100644
--- a/piccolo/columns/defaults/interval.py
+++ b/piccolo/columns/defaults/interval.py
@@ -80,6 +80,7 @@ def from_timedelta(cls, instance: datetime.timedelta):
     Enum,
     None,
     datetime.timedelta,
+    t.Callable[[], datetime.timedelta],
 ]
diff --git a/piccolo/columns/defaults/time.py b/piccolo/columns/defaults/time.py
index 25535cb5d..9b72416ea 100644
--- a/piccolo/columns/defaults/time.py
+++ b/piccolo/columns/defaults/time.py
@@ -89,7 +89,15 @@ def from_time(cls, instance: datetime.time):
 )


-TimeArg = t.Union[TimeCustom, TimeNow, TimeOffset, Enum, None, datetime.time]
+TimeArg = t.Union[
+    TimeCustom,
+    TimeNow,
+    TimeOffset,
+    Enum,
+    None,
+    datetime.time,
+    t.Callable[[], datetime.time],
+]


 __all__ = ["TimeArg", "TimeCustom", "TimeNow", "TimeOffset"]
diff --git a/piccolo/columns/defaults/timestamp.py b/piccolo/columns/defaults/timestamp.py
index 9558f4100..73e06d1ef 100644
--- a/piccolo/columns/defaults/timestamp.py
+++ b/piccolo/columns/defaults/timestamp.py
@@ -138,6 +138,7 @@ class DatetimeDefault:
     None,
     datetime.datetime,
     DatetimeDefault,
+    t.Callable[[], datetime.datetime],
 ]
diff --git a/piccolo/columns/defaults/uuid.py b/piccolo/columns/defaults/uuid.py
index 176d282ec..17b07021c 100644
--- a/piccolo/columns/defaults/uuid.py
+++ b/piccolo/columns/defaults/uuid.py
@@ -22,7 +22,7 @@ def python(self):
         return uuid.uuid4()


-UUIDArg = t.Union[UUID4, uuid.UUID, str, Enum, None]
+UUIDArg = t.Union[UUID4, uuid.UUID, str, Enum, None, t.Callable[[], uuid.UUID]]


 __all__ = ["UUIDArg", "UUID4"]
diff --git a/piccolo/columns/m2m.py b/piccolo/columns/m2m.py
index 90469fc1f..29bafe9b5 100644
--- a/piccolo/columns/m2m.py
+++ b/piccolo/columns/m2m.py
@@ -131,6 +131,7 @@ def get_select_string(
         if len(self.columns) > 1 or not self.serialisation_safe:
             column_name = table_2_pk_name
         else:
+            assert len(self.columns) > 0
             column_name = self.columns[0]._meta.db_column_name

         return QueryString(
@@ -256,15 +257,14 @@ def secondary_table(self) -> t.Type[Table]:

 @dataclass
 class M2MAddRelated:
-
     target_row: Table
     m2m: M2M
     rows: t.Sequence[Table]
     extra_column_values: t.Dict[t.Union[Column, str], t.Any]

-    def __post_init__(self) -> None:
-        # Normalise `extra_column_values`, so we just have the column names.
-        self.extra_column_values: t.Dict[str, t.Any] = {
+    @property
+    def resolved_extra_column_values(self) -> t.Dict[str, t.Any]:
+        return {
             i._meta.name if isinstance(i, Column) else i: j
             for i, j in self.extra_column_values.items()
         }
@@ -281,7 +281,9 @@ async def _run(self):
         joining_table_rows = []

         for row in rows:
-            joining_table_row = joining_table(**self.extra_column_values)
+            joining_table_row = joining_table(
+                **self.resolved_extra_column_values
+            )
             setattr(
                 joining_table_row,
                 self.m2m._meta.primary_foreign_key._meta.name,
@@ -323,7 +325,6 @@ def __await__(self):

 @dataclass
 class M2MRemoveRelated:
-
     target_row: Table
     m2m: M2M
     rows: t.Sequence[Table]
@@ -363,7 +364,6 @@ def __await__(self):

 @dataclass
 class M2MGetRelated:
-
     row: Table
     m2m: M2M

diff --git a/piccolo/conf/apps.py b/piccolo/conf/apps.py
index 47631c478..6c16c9e81 100644
--- a/piccolo/conf/apps.py
+++ b/piccolo/conf/apps.py
@@ -157,17 +157,22 @@ class AppConfig:
     """

     app_name: str
-    migrations_folder_path: str
+    migrations_folder_path: t.Union[str, pathlib.Path]
     table_classes: t.List[t.Type[Table]] = field(default_factory=list)
     migration_dependencies: t.List[str] = field(default_factory=list)
     commands: t.List[t.Union[t.Callable, Command]] = field(
         default_factory=list
     )

-    def __post_init__(self) -> None:
-        if isinstance(self.migrations_folder_path, pathlib.Path):
-            self.migrations_folder_path = str(self.migrations_folder_path)
+    @property
+    def resolved_migrations_folder_path(self) -> str:
+        return (
+            str(self.migrations_folder_path)
+            if isinstance(self.migrations_folder_path, pathlib.Path)
+            else self.migrations_folder_path
+        )

+    def __post_init__(self) -> None:
         self._migration_dependency_app_configs: t.Optional[
             t.List[AppConfig]
         ] = None
diff --git a/piccolo/engine/base.py b/piccolo/engine/base.py
index 95d1b8a24..bf59426ad 100644
--- a/piccolo/engine/base.py
+++ b/piccolo/engine/base.py
@@ -7,12 +7,14 @@
 import typing as t
 from abc import ABCMeta, abstractmethod

+from typing_extensions import Self
+
 from piccolo.querystring import QueryString
 from piccolo.utils.sync import run_sync
 from piccolo.utils.warnings import Level, colored_string, colored_warning

 if t.TYPE_CHECKING:  # pragma: no cover
-    from piccolo.query.base import Query
+    from piccolo.query.base import DDL, Query

 logger = logging.getLogger(__name__)

@@ -32,31 +34,76 @@ def validate_savepoint_name(savepoint_name: str) -> None:
         )


-class Batch:
-    pass
+class BaseBatch(metaclass=ABCMeta):
+    @abstractmethod
+    async def __aenter__(self: Self, *args, **kwargs) -> Self: ...

+    @abstractmethod
+    async def __aexit__(self, *args, **kwargs): ...

-TransactionClass = t.TypeVar("TransactionClass")
+    @abstractmethod
+    def __aiter__(self: Self) -> Self: ...
+
+    @abstractmethod
+    async def __anext__(self) -> t.List[t.Dict]: ...


-class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
-    __slots__ = ("query_id",)
+class BaseTransaction(metaclass=ABCMeta):

-    def __init__(self):
-        run_sync(self.check_version())
-        run_sync(self.prep_database())
-        self.query_id = 0
+    __slots__: t.Tuple[str, ...] = tuple()

-    @property
     @abstractmethod
-    def engine_type(self) -> str:
-        pass
+    async def __aenter__(self, *args, **kwargs): ...

-    @property
     @abstractmethod
-    def min_version_number(self) -> float:
-        pass
+    async def __aexit__(self, *args, **kwargs) -> bool: ...
+
+
+class BaseAtomic(metaclass=ABCMeta):
+
+    __slots__: t.Tuple[str, ...] = tuple()
+
+    @abstractmethod
+    def add(self, *query: t.Union[Query, DDL]): ...
+
+    @abstractmethod
+    async def run(self): ...
+
+    @abstractmethod
+    def run_sync(self): ...
+
+    @abstractmethod
+    def __await__(self): ...
+
+
+TransactionClass = t.TypeVar("TransactionClass", bound=BaseTransaction)
+
+
+class Engine(t.Generic[TransactionClass], metaclass=ABCMeta):
+    __slots__ = (
+        "query_id",
+        "log_queries",
+        "log_responses",
+        "engine_type",
+        "min_version_number",
+        "current_transaction",
+    )
+
+    def __init__(
+        self,
+        engine_type: str,
+        min_version_number: t.Union[int, float],
+        log_queries: bool = False,
+        log_responses: bool = False,
+    ):
+        self.log_queries = log_queries
+        self.log_responses = log_responses
+        self.engine_type = engine_type
+        self.min_version_number = min_version_number
+
+        run_sync(self.check_version())
+        run_sync(self.prep_database())
+        self.query_id = 0

     @abstractmethod
     async def get_version(self) -> float:
@@ -76,11 +123,13 @@ async def batch(
         query: Query,
         batch_size: int = 100,
         node: t.Optional[str] = None,
-    ) -> Batch:
+    ) -> BaseBatch:
         pass

     @abstractmethod
-    async def run_querystring(self, querystring: QueryString, in_pool: bool):
+    async def run_querystring(
+        self, querystring: QueryString, in_pool: bool = True
+    ):
         pass

     @abstractmethod
@@ -88,11 +137,11 @@ async def run_ddl(self, ddl: str, in_pool: bool = True):
         pass

     @abstractmethod
-    def transaction(self):
+    def transaction(self, *args, **kwargs) -> TransactionClass:
         pass

     @abstractmethod
-    def atomic(self):
+    def atomic(self) -> BaseAtomic:
         pass

     async def check_version(self):
diff --git a/piccolo/engine/cockroach.py b/piccolo/engine/cockroach.py
index ecbb74ad8..d091527bb 100644
--- a/piccolo/engine/cockroach.py
+++ b/piccolo/engine/cockroach.py
@@ -16,9 +16,6 @@ class CockroachEngine(PostgresEngine):
     :class:`PostgresEngine <piccolo.engine.postgres.PostgresEngine>`.
     """

-    engine_type = "cockroach"
-    min_version_number = 0  # Doesn't seem to work with cockroach versioning.
-
     def __init__(
         self,
         config: t.Dict[str, t.Any],
@@ -34,6 +31,8 @@ def __init__(
             log_responses=log_responses,
             extra_nodes=extra_nodes,
         )
+        self.engine_type = "cockroach"
+        self.min_version_number = 0

     async def prep_database(self):
         try:
diff --git a/piccolo/engine/postgres.py b/piccolo/engine/postgres.py
index 06b8ffb4b..970623535 100644
--- a/piccolo/engine/postgres.py
+++ b/piccolo/engine/postgres.py
@@ -4,7 +4,15 @@
 import typing as t
 from dataclasses import dataclass

-from piccolo.engine.base import Batch, Engine, validate_savepoint_name
+from typing_extensions import Self
+
+from piccolo.engine.base import (
+    BaseAtomic,
+    BaseBatch,
+    BaseTransaction,
+    Engine,
+    validate_savepoint_name,
+)
 from piccolo.engine.exceptions import TransactionError
 from piccolo.query.base import DDL, Query
 from piccolo.querystring import QueryString
@@ -18,16 +26,17 @@
     from asyncpg.connection import Connection
     from asyncpg.cursor import Cursor
     from asyncpg.pool import Pool
+    from asyncpg.transaction import Transaction


 @dataclass
-class AsyncBatch(Batch):
+class AsyncBatch(BaseBatch):
     connection: Connection
     query: Query
     batch_size: int

     # Set internally
-    _transaction = None
+    _transaction: t.Optional[Transaction] = None
     _cursor: t.Optional[Cursor] = None

     @property
@@ -36,20 +45,26 @@ def cursor(self) -> Cursor:
             raise ValueError("_cursor not set")
         return self._cursor

+    @property
+    def transaction(self) -> Transaction:
+        if not self._transaction:
+            raise ValueError("The transaction can't be found.")
+        return self._transaction
+
     async def next(self) -> t.List[t.Dict]:
         data = await self.cursor.fetch(self.batch_size)
         return await self.query._process_results(data)

-    def __aiter__(self):
+    def __aiter__(self: Self) -> Self:
         return self

-    async def __anext__(self):
+    async def __anext__(self) -> t.List[t.Dict]:
         response = await self.next()
         if response == []:
             raise StopAsyncIteration()
         return response

-    async def __aenter__(self):
+    async def __aenter__(self: Self) -> Self:
         self._transaction = self.connection.transaction()
         await self._transaction.start()
         querystring = self.query.querystrings[0]
@@ -60,9 +75,9 @@ async def __aenter__(self):

     async def __aexit__(self, exception_type, exception, traceback):
         if exception:
-            await self._transaction.rollback()
+            await self.transaction.rollback()
         else:
-            await self._transaction.commit()
+            await self.transaction.commit()

         await self.connection.close()

@@ -72,7 +87,7 @@ async def __aexit__(self, exception_type, exception, traceback):
 ###############################################################################


-class Atomic:
+class Atomic(BaseAtomic):
     """
     This is useful if you want to build up a transaction programatically, by
     adding queries to it.
@@ -140,7 +155,7 @@ async def release(self):
         )


-class PostgresTransaction:
+class PostgresTransaction(BaseTransaction):
     """
     Used for wrapping queries in a transaction, using a context manager.
     Currently it's async only.
@@ -243,7 +258,7 @@ async def savepoint(self, name: t.Optional[str] = None) -> Savepoint:

     ###########################################################################

-    async def __aexit__(self, exception_type, exception, traceback):
+    async def __aexit__(self, exception_type, exception, traceback) -> bool:
         if self._parent:
             return exception is None

@@ -269,7 +284,7 @@
 ###############################################################################


-class PostgresEngine(Engine[t.Optional[PostgresTransaction]]):
+class PostgresEngine(Engine[PostgresTransaction]):
     """
     Used to connect to PostgreSQL.

@@ -331,16 +346,10 @@ class PostgresEngine(Engine[t.Optional[PostgresTransaction]]):
     __slots__ = (
         "config",
         "extensions",
-        "log_queries",
-        "log_responses",
         "extra_nodes",
         "pool",
-        "current_transaction",
     )

-    engine_type = "postgres"
-    min_version_number = 10
-
     def __init__(
         self,
         config: t.Dict[str, t.Any],
@@ -362,7 +371,12 @@ def __init__(
         self.current_transaction = contextvars.ContextVar(
             f"pg_current_transaction_{database_name}", default=None
         )
-        super().__init__()
+        super().__init__(
+            engine_type="postgres",
+            log_queries=log_queries,
+            log_responses=log_responses,
+            min_version_number=10,
+        )

     @staticmethod
     def _parse_raw_version_string(version_string: str) -> float:
diff --git a/piccolo/engine/sqlite.py b/piccolo/engine/sqlite.py
index f6fbd4e38..3f7649d76 100644
--- a/piccolo/engine/sqlite.py
+++ b/piccolo/engine/sqlite.py
@@ -11,7 +11,15 @@
 from decimal import Decimal
 from functools import partial, wraps

-from piccolo.engine.base import Batch, Engine, validate_savepoint_name
+from typing_extensions import Self
+
+from piccolo.engine.base import (
+    BaseAtomic,
+    BaseBatch,
+    BaseTransaction,
+    Engine,
+    validate_savepoint_name,
+)
 from piccolo.engine.exceptions import TransactionError
 from piccolo.query.base import DDL, Query
 from piccolo.querystring import QueryString
@@ -309,7 +317,7 @@ def convert_M2M_out(value: str) -> t.List:


 @dataclass
-class AsyncBatch(Batch):
+class AsyncBatch(BaseBatch):
     connection: Connection
     query: Query
     batch_size: int
@@ -327,16 +335,16 @@ async def next(self) -> t.List[t.Dict]:
         data = await self.cursor.fetchmany(self.batch_size)
         return await self.query._process_results(data)

-    def __aiter__(self):
+    def __aiter__(self: Self) -> Self:
         return self

-    async def __anext__(self):
+    async def __anext__(self) -> t.List[t.Dict]:
         response = await self.next()
         if response == []:
             raise StopAsyncIteration()
         return response

-    async def __aenter__(self):
+    async def __aenter__(self: Self) -> Self:
         querystring = self.query.querystrings[0]
         template, template_args = querystring.compile_string()

@@ -344,7 +352,7 @@ async def __aenter__(self):
         return self

     async def __aexit__(self, exception_type, exception, traceback):
-        await self._cursor.close()
+        await self.cursor.close()
         await self.connection.close()
         return exception is not None

@@ -363,7 +371,7 @@ class TransactionType(enum.Enum):
     exclusive = "EXCLUSIVE"


-class Atomic:
+class Atomic(BaseAtomic):
     """
     Usage:

@@ -384,9 +392,9 @@ def __init__(
     ):
         self.engine = engine
         self.transaction_type = transaction_type
-        self.queries: t.List[Query] = []
+        self.queries: t.List[t.Union[Query, DDL]] = []

-    def add(self, *query: Query):
+    def add(self, *query: t.Union[Query, DDL]):
         self.queries += list(query)

     async def run(self):
@@ -434,7 +442,7 @@ async def release(self):
         )


-class SQLiteTransaction:
+class SQLiteTransaction(BaseTransaction):
     """
     Used for wrapping queries in a transaction, using a context manager.
     Currently it's async only.

@@ -534,7 +542,7 @@ async def savepoint(self, name: t.Optional[str] = None) -> Savepoint:

     ###########################################################################

-    async def __aexit__(self, exception_type, exception, traceback):
+    async def __aexit__(self, exception_type, exception, traceback) -> bool:
         if self._parent:
             return exception is None

@@ -560,16 +568,8 @@ def dict_factory(cursor, row) -> t.Dict:
     return {col[0]: row[idx] for idx, col in enumerate(cursor.description)}


-class SQLiteEngine(Engine[t.Optional[SQLiteTransaction]]):
-    __slots__ = (
-        "connection_kwargs",
-        "current_transaction",
-        "log_queries",
-        "log_responses",
-    )
-
-    engine_type = "sqlite"
-    min_version_number = 3.25
+class SQLiteEngine(Engine[SQLiteTransaction]):
+    __slots__ = ("connection_kwargs",)

     def __init__(
         self,
@@ -613,7 +613,12 @@ def __init__(
             f"sqlite_current_transaction_{path}", default=None
         )

-        super().__init__()
+        super().__init__(
+            engine_type="sqlite",
+            min_version_number=3.25,
+            log_queries=log_queries,
+            log_responses=log_responses,
+        )

     @property
     def path(self):
diff --git a/piccolo/query/methods/create_index.py b/piccolo/query/methods/create_index.py
index b10e0c203..c81d38b9d 100644
--- a/piccolo/query/methods/create_index.py
+++ b/piccolo/query/methods/create_index.py
@@ -14,7 +14,7 @@ class CreateIndex(DDL):
     def __init__(
         self,
         table: t.Type[Table],
-        columns: t.List[t.Union[Column, str]],
+        columns: t.Union[t.List[Column], t.List[str]],
         method: IndexMethod = IndexMethod.btree,
         if_not_exists: bool = False,
         **kwargs,
diff --git a/piccolo/query/methods/drop_index.py b/piccolo/query/methods/drop_index.py
index 437728437..049a066dd 100644
--- a/piccolo/query/methods/drop_index.py
+++ b/piccolo/query/methods/drop_index.py
@@ -14,7 +14,7 @@ class DropIndex(Query):
     def __init__(
         self,
         table: t.Type[Table],
-        columns: t.List[t.Union[Column, str]],
+        columns: t.Union[t.List[Column], t.List[str]],
         if_exists: bool = True,
         **kwargs,
     ):
diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py
index f11f78e8e..7f2b5aaed 100644
--- a/piccolo/query/methods/objects.py
+++ b/piccolo/query/methods/objects.py
@@ -5,7 +5,7 @@
 from piccolo.columns.column_types import ForeignKey
 from piccolo.columns.combination import And, Where
 from piccolo.custom_types import Combinable, TableInstance
-from piccolo.engine.base import Batch
+from piccolo.engine.base import BaseBatch
 from piccolo.query.base import Query
 from piccolo.query.methods.select import Select
 from piccolo.query.mixins import (
@@ -268,17 +268,17 @@ def where(self: Self, *where: t.Union[Combinable, QueryString]) -> Self:

     ###########################################################################

-    def first(self: Self) -> First[TableInstance]:
+    def first(self) -> First[TableInstance]:
         self.limit_delegate.limit(1)
         return First[TableInstance](query=self)

-    def get(self: Self, where: Combinable) -> Get[TableInstance]:
+    def get(self, where: Combinable) -> Get[TableInstance]:
         self.where_delegate.where(where)
         self.limit_delegate.limit(1)
         return Get[TableInstance](query=First[TableInstance](query=self))

     def get_or_create(
-        self: Self,
+        self,
         where: Combinable,
         defaults: t.Optional[t.Dict[Column, t.Any]] = None,
     ) -> GetOrCreate[TableInstance]:
@@ -288,17 +288,17 @@ def get_or_create(
             query=self, table_class=self.table, where=where, defaults=defaults
         )

-    def create(self: Self, **columns: t.Any) -> Create[TableInstance]:
+    def create(self, **columns: t.Any) -> Create[TableInstance]:
         return Create[TableInstance](table_class=self.table, columns=columns)

     ###########################################################################

     async def batch(
-        self: Self,
+        self,
         batch_size: t.Optional[int] = None,
         node: t.Optional[str] = None,
         **kwargs,
-    ) -> Batch:
+    ) -> BaseBatch:
         if batch_size:
             kwargs.update(batch_size=batch_size)
         if node:
diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py
index fdb929f8a..05302455f 100644
--- a/piccolo/query/methods/select.py
+++ b/piccolo/query/methods/select.py
@@ -5,11 +5,11 @@
 from collections import OrderedDict

 from piccolo.columns import Column, Selectable
-from piccolo.columns.column_types import JSON, JSONB, PrimaryKey
+from piccolo.columns.column_types import JSON, JSONB
 from piccolo.columns.m2m import M2MSelect
 from piccolo.columns.readable import Readable
 from piccolo.custom_types import TableInstance
-from piccolo.engine.base import Batch
+from piccolo.engine.base import BaseBatch
 from piccolo.query.base import Query
 from piccolo.query.mixins import (
     AsOfDelegate,
@@ -217,7 +217,7 @@ async def _splice_m2m_rows(
         self,
         response: t.List[t.Dict[str, t.Any]],
         secondary_table: t.Type[Table],
-        secondary_table_pk: PrimaryKey,
+        secondary_table_pk: Column,
         m2m_name: str,
         m2m_select: M2MSelect,
         as_list: bool = False,
@@ -386,14 +386,20 @@ def order_by(
         return self

     @t.overload
-    def output(self: Self, *, as_list: bool) -> SelectList: ...
+    def output(self: Self, *, as_list: bool) -> SelectList:  # type: ignore
+        ...

     @t.overload
-    def output(self: Self, *, as_json: bool) -> SelectJSON: ...
+    def output(self: Self, *, as_json: bool) -> SelectJSON:  # type: ignore
+        ...

     @t.overload
     def output(self: Self, *, load_json: bool) -> Self: ...

+    @t.overload
+    def output(self: Self, *, load_json: bool, as_list: bool) -> SelectJSON:  # type: ignore # noqa: E501
+        ...
+
     @t.overload
     def output(self: Self, *, nested: bool) -> Self: ...

@@ -404,7 +410,7 @@ def output(
         as_json: bool = False,
         load_json: bool = False,
         nested: bool = False,
-    ):
+    ) -> t.Union[Self, SelectJSON, SelectList]:
         self.output_delegate.output(
             as_list=as_list,
             as_json=as_json,
@@ -436,7 +442,7 @@ async def batch(
         batch_size: t.Optional[int] = None,
         node: t.Optional[str] = None,
         **kwargs,
-    ) -> Batch:
+    ) -> BaseBatch:
         if batch_size:
             kwargs.update(batch_size=batch_size)
         if node:
diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py
index 214d1b8d7..d9d5f84ca 100644
--- a/piccolo/query/mixins.py
+++ b/piccolo/query/mixins.py
@@ -207,7 +207,6 @@ def __str__(self):

 @dataclass
 class Output:
-
     as_json: bool = False
     as_list: bool = False
     as_objects: bool = False
@@ -236,7 +235,6 @@ class Callback:

 @dataclass
 class WhereDelegate:
-
     _where: t.Optional[Combinable] = None
     _where_columns: t.List[Column] = field(default_factory=list)

@@ -246,7 +244,8 @@ def get_where_columns(self):
         needed.
""" self._where_columns = [] - self._extract_columns(self._where) + if self._where is not None: + self._extract_columns(self._where) return self._where_columns def _extract_columns(self, combinable: Combinable): @@ -277,7 +276,6 @@ def where(self, *where: t.Union[Combinable, QueryString]): @dataclass class OrderByDelegate: - _order_by: OrderBy = field(default_factory=OrderBy) def get_order_by_columns(self) -> t.List[Column]: @@ -303,7 +301,6 @@ def order_by(self, *columns: t.Union[Column, OrderByRaw], ascending=True): @dataclass class LimitDelegate: - _limit: t.Optional[Limit] = None _first: bool = False @@ -330,7 +327,6 @@ def as_of(self, interval: str = "-1s"): @dataclass class DistinctDelegate: - _distinct: Distinct = field( default_factory=lambda: Distinct(enabled=False, on=None) ) @@ -356,7 +352,6 @@ def returning(self, columns: t.Sequence[Column]): @dataclass class CountDelegate: - _count: bool = False def count(self): @@ -365,7 +360,6 @@ def count(self): @dataclass class AddDelegate: - _add: t.List[Table] = field(default_factory=list) def add(self, *instances: Table, table_class: t.Type[Table]): @@ -421,8 +415,7 @@ def output( self._output.nested = bool(nested) def copy(self) -> OutputDelegate: - _output = self._output.copy() if self._output is not None else None - return self.__class__(_output=_output) + return self.__class__(_output=self._output.copy()) @dataclass diff --git a/piccolo/schema.py b/piccolo/schema.py index ef0bd6ab4..01cd0bd91 100644 --- a/piccolo/schema.py +++ b/piccolo/schema.py @@ -10,7 +10,6 @@ class SchemaDDLBase(abc.ABC): - db: Engine @property @@ -132,16 +131,19 @@ def __init__(self, db: Engine, schema_name: str): self.db = db self.schema_name = schema_name - async def run(self): - response = await self.db.run_querystring( - QueryString( - """ + async def run(self) -> t.List[str]: + response = t.cast( + t.List[t.Dict], + await self.db.run_querystring( + QueryString( + """ SELECT table_name FROM information_schema.tables WHERE table_schema = {} """, - self.schema_name, - ) + self.schema_name, + ) + ), ) return [i["table_name"] for i in response] @@ -156,9 +158,14 @@ class ListSchemas: def __init__(self, db: Engine): self.db = db - async def run(self): - response = await self.db.run_querystring( - QueryString("SELECT schema_name FROM information_schema.schemata") + async def run(self) -> t.List[str]: + response = t.cast( + t.List[t.Dict], + await self.db.run_querystring( + QueryString( + "SELECT schema_name FROM information_schema.schemata" + ) + ), ) return [i["schema_name"] for i in response] @@ -180,7 +187,7 @@ def __init__(self, db: t.Optional[Engine] = None): """ db = db or engine_finder() - if not db: + if db is None: raise ValueError("The DB can't be found.") self.db = db diff --git a/piccolo/table.py b/piccolo/table.py index 7882db95e..3b3ff4853 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -143,8 +143,11 @@ def db(self) -> Engine: def db(self, value: Engine): self._db = value - def refresh_db(self): - self.db = engine_finder() + def refresh_db(self) -> None: + engine = engine_finder() + if engine is None: + raise ValueError("The engine can't be found") + self.db = engine def get_column_by_name(self, name: str) -> Column: """ @@ -184,8 +187,8 @@ def get_auto_update_values(self) -> t.Dict[Column, t.Any]: class TableMetaclass(type): - def __str__(cls): - return cls._table_str() + def __str__(cls) -> str: + return cls._table_str() # type: ignore def __repr__(cls): """ @@ -822,7 +825,7 @@ def __repr__(self) -> str: @classmethod def all_related( 
         cls, exclude: t.Optional[t.List[t.Union[str, ForeignKey]]] = None
-    ) -> t.List[Column]:
+    ) -> t.List[ForeignKey]:
         """
         Used in conjunction with ``objects`` queries. Just as we can use
         ``all_related`` on a ``ForeignKey``, you can also use it for the table
@@ -1251,7 +1254,7 @@ def indexes(cls) -> Indexes:
     @classmethod
     def create_index(
         cls,
-        columns: t.List[t.Union[Column, str]],
+        columns: t.Union[t.List[Column], t.List[str]],
         method: IndexMethod = IndexMethod.btree,
         if_not_exists: bool = False,
     ) -> CreateIndex:
@@ -1273,7 +1276,9 @@ def create_index(

     @classmethod
     def drop_index(
-        cls, columns: t.List[t.Union[Column, str]], if_exists: bool = True
+        cls,
+        columns: t.Union[t.List[Column], t.List[str]],
+        if_exists: bool = True,
     ) -> DropIndex:
         """
         Drop a table index. If multiple columns are specified, this refers
@@ -1464,22 +1469,18 @@ async def drop_db_tables(*tables: t.Type[Table]) -> None:
         # SQLite doesn't support CASCADE, so we have to drop them in the
         # correct order.
         sorted_table_classes = reversed(sort_table_classes(list(tables)))
-        atomic = engine.atomic()
-        atomic.add(
-            *[
-                Alter(table=table).drop_table(if_exists=True)
-                for table in sorted_table_classes
-            ]
-        )
+        ddl_statements = [
+            Alter(table=table).drop_table(if_exists=True)
+            for table in sorted_table_classes
+        ]
     else:
-        atomic = engine.atomic()
-        atomic.add(
-            *[
-                table.alter().drop_table(cascade=True, if_exists=True)
-                for table in tables
-            ]
-        )
+        ddl_statements = [
+            table.alter().drop_table(cascade=True, if_exists=True)
+            for table in tables
+        ]
+
+    atomic = engine.atomic()
+    atomic.add(*ddl_statements)
     await atomic.run()
diff --git a/piccolo/utils/encoding.py b/piccolo/utils/encoding.py
index 48a131dc5..97fde4683 100644
--- a/piccolo/utils/encoding.py
+++ b/piccolo/utils/encoding.py
@@ -19,13 +19,15 @@ def dump_json(data: t.Any, pretty: bool = False) -> str:
             orjson_params["option"] = (
                 orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE  # type: ignore
             )
-        return orjson.dumps(data, **orjson_params).decode("utf8")
+        return orjson.dumps(data, **orjson_params).decode(  # type: ignore
+            "utf8"
+        )
     else:
         params: t.Dict[str, t.Any] = {"default": str}
         if pretty:
             params["indent"] = 2
-        return json.dumps(data, **params)
+        return json.dumps(data, **params)  # type: ignore


 def load_json(data: str) -> t.Any:
-    return orjson.loads(data) if ORJSON else json.loads(data)
+    return orjson.loads(data) if ORJSON else json.loads(data)  # type: ignore
diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt
index 6fbda712d..726e62994 100644
--- a/requirements/dev-requirements.txt
+++ b/requirements/dev-requirements.txt
@@ -7,4 +7,5 @@ slotscheck==0.17.1
 twine==3.8.0
 mypy==1.7.1
 pip-upgrader==1.4.15
+pyright==1.1.367
 wheel==0.38.1
diff --git a/scripts/pyright.sh b/scripts/pyright.sh
new file mode 100755
index 000000000..616652eb8
--- /dev/null
+++ b/scripts/pyright.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+# We have a separate script for pyright vs lint.sh, as it's hard to get 100%
+# success in pyright. In the future we might merge them.
+
+set -e
+
+MODULES="piccolo"
+SOURCES="$MODULES tests"
+
+echo "Running pyright..."
+pyright $SOURCES
+echo "-----"
+
+echo "All passed!"
diff --git a/tests/apps/migrations/auto/integration/test_migrations.py b/tests/apps/migrations/auto/integration/test_migrations.py
index 2851aaee9..84b194ea8 100644
--- a/tests/apps/migrations/auto/integration/test_migrations.py
+++ b/tests/apps/migrations/auto/integration/test_migrations.py
@@ -145,7 +145,7 @@ def _test_migrations(
         """
         app_config = self._get_app_config()

-        migrations_folder_path = app_config.migrations_folder_path
+        migrations_folder_path = app_config.resolved_migrations_folder_path

         if os.path.exists(migrations_folder_path):
             shutil.rmtree(migrations_folder_path)
diff --git a/tests/conf/test_apps.py b/tests/conf/test_apps.py
index 44b2d4a4a..0749f5f25 100644
--- a/tests/conf/test_apps.py
+++ b/tests/conf/test_apps.py
@@ -85,7 +85,7 @@ def test_pathlib(self):
         config = AppConfig(
             app_name="music", migrations_folder_path=pathlib.Path(__file__)
         )
-        self.assertEqual(config.migrations_folder_path, __file__)
+        self.assertEqual(config.resolved_migrations_folder_path, __file__)

     def test_get_table_with_name(self):
         """
diff --git a/tests/engine/test_nested_transaction.py b/tests/engine/test_nested_transaction.py
index 23bee59a4..71d519b79 100644
--- a/tests/engine/test_nested_transaction.py
+++ b/tests/engine/test_nested_transaction.py
@@ -45,10 +45,12 @@ async def run_nested(self):

         self.assertTrue(await Musician.table_exists().run())
         musician = await Musician.select("name").first().run()
+        assert musician is not None
         self.assertEqual(musician["name"], "Bob")

         self.assertTrue(await Roadie.table_exists().run())
         roadie = await Roadie.select("name").first().run()
+        assert roadie is not None
         self.assertEqual(roadie["name"], "Dave")

     def test_nested(self):
diff --git a/tests/engine/test_transaction.py b/tests/engine/test_transaction.py
index 4b47f8759..88e4cff15 100644
--- a/tests/engine/test_transaction.py
+++ b/tests/engine/test_transaction.py
@@ -4,7 +4,6 @@

 import pytest

-from piccolo.engine.postgres import Atomic
 from piccolo.engine.sqlite import SQLiteEngine, TransactionType
 from piccolo.table import drop_db_tables_sync
 from piccolo.utils.sync import run_sync
@@ -58,7 +57,7 @@ async def run() -> None:
             engine = Band._meta.db
             await engine.start_connection_pool()

-            atomic: Atomic = engine.atomic()
+            atomic = engine.atomic()
             atomic.add(
                 Manager.create_table(),
                 Band.create_table(),
diff --git a/tests/table/test_indexes.py b/tests/table/test_indexes.py
index 6aebd350c..13d1758de 100644
--- a/tests/table/test_indexes.py
+++ b/tests/table/test_indexes.py
@@ -1,5 +1,7 @@
+import typing as t
 from unittest import TestCase

+from piccolo.columns.base import Column
 from piccolo.columns.column_types import Integer
 from piccolo.table import Table
 from tests.example_apps.music.tables import Manager
@@ -45,12 +47,12 @@ def setUp(self):
     def tearDown(self):
         Concert.alter().drop_table().run_sync()

-    def test_problematic_name(self):
+    def test_problematic_name(self) -> None:
         """
         Make sure we can add an index to a column with a problematic name
         (which clashes with a SQL keyword).
         """
-        columns = [Concert.order]
+        columns: t.List[Column] = [Concert.order]
         Concert.create_index(columns=columns).run_sync()

         index_name = Concert._get_index_name([i._meta.name for i in columns])
diff --git a/tests/utils/test_pydantic.py b/tests/utils/test_pydantic.py
index 5447361e9..82096603d 100644
--- a/tests/utils/test_pydantic.py
+++ b/tests/utils/test_pydantic.py
@@ -274,6 +274,7 @@ class Ticket(Table):
         # We'll also fetch it from the DB in case the database adapter's UUID
         # is used.
         ticket_from_db = Ticket.objects().first().run_sync()
+        assert ticket_from_db is not None

         for ticket_ in (ticket, ticket_from_db):
             json = pydantic_model(**ticket_.to_dict()).model_dump_json()
@@ -368,8 +369,8 @@ class Movie(Table):
         json_string = '{"code": 12345}'

         model_instance = pydantic_model(meta=json_string, meta_b=json_string)
-        self.assertEqual(model_instance.meta, json_string)
-        self.assertEqual(model_instance.meta_b, json_string)
+        self.assertEqual(model_instance.meta, json_string)  # type: ignore
+        self.assertEqual(model_instance.meta_b, json_string)  # type: ignore

     def test_deserialize_json(self):
         class Movie(Table):
@@ -384,8 +385,8 @@ class Movie(Table):
         output = {"code": 12345}

         model_instance = pydantic_model(meta=json_string, meta_b=json_string)
-        self.assertEqual(model_instance.meta, output)
-        self.assertEqual(model_instance.meta_b, output)
+        self.assertEqual(model_instance.meta, output)  # type: ignore
+        self.assertEqual(model_instance.meta_b, output)  # type: ignore

     def test_validation(self):
         class Movie(Table):
@@ -428,8 +429,8 @@ class Movie(Table):
         pydantic_model = create_pydantic_model(table=Movie)

         movie = pydantic_model(meta=None, meta_b=None)
-        self.assertIsNone(movie.meta)
-        self.assertIsNone(movie.meta_b)
+        self.assertIsNone(movie.meta)  # type: ignore
+        self.assertIsNone(movie.meta_b)  # type: ignore


 class TestExcludeColumns(TestCase):
@@ -490,7 +491,7 @@ class Computer(Table):
         with self.assertRaises(ValueError):
             create_pydantic_model(
                 Computer,
-                exclude_columns=("CPU",),
+                exclude_columns=("CPU",),  # type: ignore
             )

     def test_invalid_column_different_table(self):
@@ -629,7 +630,10 @@ class Band(Table):

         #######################################################################

-        ManagerModel = BandModel.model_fields["manager"].annotation
+        ManagerModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            BandModel.model_fields["manager"].annotation,
+        )
         self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in ManagerModel.model_fields.keys()], ["name", "country"]
         )
@@ -637,7 +641,10 @@ class Band(Table):

         #######################################################################

-        CountryModel = ManagerModel.model_fields["country"].annotation
+        CountryModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            ManagerModel.model_fields["country"].annotation,
+        )
         self.assertTrue(issubclass(CountryModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in CountryModel.model_fields.keys()], ["name"]
         )
@@ -674,7 +681,10 @@ class Concert(Table):

         BandModel = create_pydantic_model(table=Band, nested=(Band.manager,))

-        ManagerModel = BandModel.model_fields["manager"].annotation
+        ManagerModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            BandModel.model_fields["manager"].annotation,
+        )
         self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in ManagerModel.model_fields.keys()], ["name", "country"]
         )
@@ -690,22 +700,29 @@ class Concert(Table):
         # Test two levels deep

         BandModel = create_pydantic_model(
-            table=Band, nested=(Band.manager.country,)
+            table=Band, nested=(Band.manager._.country,)
         )

-        ManagerModel = BandModel.model_fields["manager"].annotation
+        ManagerModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            BandModel.model_fields["manager"].annotation,
+        )
         self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in ManagerModel.model_fields.keys()], ["name", "country"]
         )
         self.assertEqual(ManagerModel.__qualname__, "Band.manager")

-        AssistantManagerType = BandModel.model_fields[
-            "assistant_manager"
-        ].annotation
+        AssistantManagerType = t.cast(
+            t.Type[pydantic.BaseModel],
+            BandModel.model_fields["assistant_manager"].annotation,
+        )
         self.assertIs(AssistantManagerType, t.Optional[int])

-        CountryModel = ManagerModel.model_fields["country"].annotation
+        CountryModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            ManagerModel.model_fields["country"].annotation,
+        )
         self.assertTrue(issubclass(CountryModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in CountryModel.model_fields.keys()], ["name"]
         )
@@ -716,13 +733,16 @@ class Concert(Table):
         # Test three levels deep

         ConcertModel = create_pydantic_model(
-            Concert, nested=(Concert.band_1.manager,)
+            Concert, nested=(Concert.band_1._.manager,)
         )

         VenueModel = ConcertModel.model_fields["venue"].annotation
         self.assertIs(VenueModel, t.Optional[int])

-        BandModel = ConcertModel.model_fields["band_1"].annotation
+        BandModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            ConcertModel.model_fields["band_1"].annotation,
+        )
         self.assertTrue(issubclass(BandModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in BandModel.model_fields.keys()],
@@ -730,7 +750,10 @@ class Concert(Table):
         )
         self.assertEqual(BandModel.__qualname__, "Concert.band_1")

-        ManagerModel = BandModel.model_fields["manager"].annotation
+        ManagerModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            BandModel.model_fields["manager"].annotation,
+        )
         self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in ManagerModel.model_fields.keys()],
@@ -751,11 +774,14 @@ class Concert(Table):

         MyConcertModel = create_pydantic_model(
             Concert,
-            nested=(Concert.band_1.manager,),
+            nested=(Concert.band_1._.manager,),
             model_name="MyConcertModel",
         )

-        BandModel = MyConcertModel.model_fields["band_1"].annotation
+        BandModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            MyConcertModel.model_fields["band_1"].annotation,
+        )
         self.assertEqual(BandModel.__qualname__, "MyConcertModel.band_1")

         ManagerModel = BandModel.model_fields["manager"].annotation
         self.assertEqual(
             ManagerModel.__qualname__, "MyConcertModel.band_1.manager"
         )

-    def test_cascaded_args(self):
+    def test_cascaded_args(self) -> None:
         """
         Make sure that arguments passed to ``create_pydantic_model`` are
         cascaded to nested models.
@@ -784,14 +810,20 @@ class Band(Table):
             table=Band, nested=True, include_default_columns=True
         )

-        ManagerModel = BandModel.model_fields["manager"].annotation
+        ManagerModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            BandModel.model_fields["manager"].annotation,
+        )
         self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in ManagerModel.model_fields.keys()],
             ["id", "name", "country"],
         )

-        CountryModel = ManagerModel.model_fields["country"].annotation
+        CountryModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            ManagerModel.model_fields["country"].annotation,
+        )
         self.assertTrue(issubclass(CountryModel, pydantic.BaseModel))
         self.assertEqual(
             [i for i in CountryModel.model_fields.keys()], ["id", "name"]
         )
@@ -823,13 +855,22 @@ class Concert(Table):
             table=Concert, nested=True, max_recursion_depth=2
         )

-        VenueModel = ConcertModel.model_fields["venue"].annotation
+        VenueModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            ConcertModel.model_fields["venue"].annotation,
+        )
         self.assertTrue(issubclass(VenueModel, pydantic.BaseModel))

-        BandModel = ConcertModel.model_fields["band"].annotation
+        BandModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            ConcertModel.model_fields["band"].annotation,
+        )
         self.assertTrue(issubclass(BandModel, pydantic.BaseModel))

-        ManagerModel = BandModel.model_fields["manager"].annotation
+        ManagerModel = t.cast(
+            t.Type[pydantic.BaseModel],
+            BandModel.model_fields["manager"].annotation,
+        )
         self.assertTrue(issubclass(ManagerModel, pydantic.BaseModel))

         # We should have hit the recursion depth:
@@ -851,7 +892,7 @@ class Band(Table):

         model = BandModel(regrettable_column_name="test")

-        self.assertEqual(model.name, "test")
+        self.assertEqual(model.name, "test")  # type: ignore


 class TestJSONSchemaExtra(TestCase):
@@ -885,7 +926,7 @@ class Band(Table):
         config: pydantic.config.ConfigDict = {"extra": "forbid"}
         model = create_pydantic_model(Band, pydantic_config=config)

-        self.assertEqual(model.model_config["extra"], "forbid")
+        self.assertEqual(model.model_config.get("extra"), "forbid")

     def test_pydantic_invalid_extra_fields(self) -> None:
         """