diff --git a/aioodbc/cursor.py b/aioodbc/cursor.py index 126f755..1fe70bc 100644 --- a/aioodbc/cursor.py +++ b/aioodbc/cursor.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any, Callable, Coroutine, List, Optional, Tuple, TypeVar + import pyodbc from .log import logger @@ -8,6 +10,9 @@ __all__ = ["Cursor"] +_T = TypeVar("_T") + + class Cursor: """Cursors represent a database cursor (and map to ODBC HSTMTs), which is used to manage the context of a fetch operation. @@ -17,13 +22,23 @@ class Cursor: the other cursors. """ - def __init__(self, pyodbc_cursor: pyodbc.Cursor, connection, echo=False): + def __init__( + self, + pyodbc_cursor: pyodbc.Cursor, + connection: Connection, + echo: bool = False, + ) -> None: self._conn = connection self._impl: pyodbc.Cursor = pyodbc_cursor self._loop = connection.loop self._echo: bool = echo - async def _run_operation(self, func, *args, **kwargs): + async def _run_operation( + self, + func: Callable[..., _T], + *args: Any, + **kwargs: Any, + ) -> _T: # execute func in thread pool of attached to cursor connection if not self._conn: raise pyodbc.OperationalError("Cursor is closed.") @@ -37,17 +52,17 @@ async def _run_operation(self, func, *args, **kwargs): raise @property - def echo(self): + def echo(self) -> bool: """Return echo mode status.""" return self._echo @property - def connection(self): + def connection(self) -> Optional[Connection]: """Cursors database connection""" return self._conn @property - def autocommit(self): + def autocommit(self) -> bool: """Show autocommit mode for current database session. True if connection is in autocommit mode; False otherwse. The default is False. @@ -56,12 +71,12 @@ def autocommit(self): return self._conn.autocommit @autocommit.setter - def autocommit(self, value): + def autocommit(self, value: bool) -> None: assert self._conn is not None # mypy self._conn.autocommit = value @property - def rowcount(self): + def rowcount(self) -> int: """The number of rows modified by the previous DDL statement. This is -1 if no SQL has been executed or if the number of rows is @@ -73,7 +88,7 @@ def rowcount(self): return self._impl.rowcount @property - def description(self): + def description(self) -> Tuple[Tuple[str, Any, int, int, int, int, bool]]: """This read-only attribute is a list of 7-item tuples, each containing (name, type_code, display_size, internal_size, precision, scale, null_ok). @@ -91,12 +106,12 @@ def description(self): return self._impl.description @property - def closed(self): + def closed(self) -> bool: """Read only property indicates if cursor has been closed""" return self._conn is None @property - def arraysize(self): + def arraysize(self) -> int: """This read/write attribute specifies the number of rows to fetch at a time with .fetchmany() . It defaults to 1 meaning to fetch a single row at a time. @@ -104,10 +119,10 @@ def arraysize(self): return self._impl.arraysize @arraysize.setter - def arraysize(self, size): + def arraysize(self, size: int) -> None: self._impl.arraysize = size - async def close(self): + async def close(self) -> None: """Close the cursor now (rather than whenever __del__ is called). The cursor will be unusable from this point forward; an Error @@ -119,7 +134,7 @@ async def close(self): await self._run_operation(self._impl.close) self._conn = None - async def execute(self, sql, *params): + async def execute(self, sql: str, *params: Any) -> Cursor: """Executes the given operation substituting any markers with the given parameters. @@ -136,7 +151,7 @@ async def execute(self, sql, *params): await self._run_operation(self._impl.execute, sql, *params) return self - def executemany(self, sql, *params): + def executemany(self, sql: str, *params: Any) -> Coroutine[Any, Any, None]: """Prepare a database query or command and then execute it against all parameter sequences found in the sequence seq_of_params. @@ -157,7 +172,7 @@ async def setoutputsize(self, *args, **kwargs): """Does nothing, required by DB API.""" return None - def fetchone(self): + def fetchone(self) -> Coroutine[Any, Any, Optional[pyodbc.Row]]: """Returns the next row or None when no more data is available. A ProgrammingError exception is raised if no SQL has been executed @@ -167,7 +182,7 @@ def fetchone(self): fut = self._run_operation(self._impl.fetchone) return fut - def fetchall(self): + def fetchall(self) -> Coroutine[Any, Any, List[pyodbc.Row]]: """Returns a list of all remaining rows. Since this reads all rows into memory, it should not be used if @@ -181,7 +196,7 @@ def fetchall(self): fut = self._run_operation(self._impl.fetchall) return fut - def fetchmany(self, size=0): + def fetchmany(self, size: int = 0) -> Coroutine[Any, Any, List[pyodbc.Row]]: """Returns a list of remaining rows, containing no more than size rows, used to process results in chunks. The list will be empty when there are no more rows. @@ -200,7 +215,7 @@ def fetchmany(self, size=0): fut = self._run_operation(self._impl.fetchmany, size) return fut - def nextset(self): + def nextset(self) -> Coroutine[Any, Any, bool]: """This method will make the cursor skip to the next available set, discarding any remaining rows from the current set. @@ -214,7 +229,13 @@ def nextset(self): fut = self._run_operation(self._impl.nextset) return fut - def tables(self, **kw): + def tables( + self, + table: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + tableType: Optional[str] = None, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Creates a result set of tables in the database that match the given criteria. @@ -223,10 +244,22 @@ def tables(self, **kw): :param schema: the schmea name :param tableType: one of TABLE, VIEW, SYSTEM TABLE ... """ - fut = self._run_operation(self._impl.tables, **kw) + fut = self._run_operation( + self._impl.tables, + table=table, + catalog=catalog, + schema=schema, + tableType=tableType, + ) return fut - def columns(self, **kw): + def columns( + self, + table: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + column: Optional[str] = None, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Creates a results set of column names in specified tables by executing the ODBC SQLColumns function. Each row fetched has the following columns. @@ -236,10 +269,23 @@ def columns(self, **kw): :param schema: the schmea name :param column: string search pattern for column names. """ - fut = self._run_operation(self._impl.columns, **kw) + fut = self._run_operation( + self._impl.columns, + table=table, + catalog=catalog, + schema=schema, + column=column, + ) return fut - def statistics(self, table, catalog=None, schema=None, unique=False, quick=True): + def statistics( + self, + table: str, + catalog: Optional[str] = None, + schema: Optional[str] = None, + unique: bool = False, + quick: bool = True, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Creates a results set of statistics about a single table and the indexes associated with the table by executing SQLStatistics. @@ -262,8 +308,12 @@ def statistics(self, table, catalog=None, schema=None, unique=False, quick=True) return fut def rowIdColumns( - self, table, catalog=None, schema=None, nullable=True # nopep8 - ): + self, + table: str, + catalog: Optional[str] = None, + schema: Optional[str] = None, + nullable: bool = True, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Executes SQLSpecialColumns with SQL_BEST_ROWID which creates a result set of columns that uniquely identify a row """ @@ -277,8 +327,12 @@ def rowIdColumns( return fut def rowVerColumns( - self, table, catalog=None, schema=None, nullable=True # nopep8 - ): + self, + table: str, + catalog: Optional[str] = None, + schema: Optional[str] = None, + nullable: bool = True, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Executes SQLSpecialColumns with SQL_ROWVER which creates a result set of columns that are automatically updated when any value in the row is updated. @@ -292,7 +346,12 @@ def rowVerColumns( ) return fut - def primaryKeys(self, table, catalog=None, schema=None): # nopep8 + def primaryKeys( + self, + table: str, + catalog: Optional[str] = None, + schema: Optional[str] = None, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Creates a result set of column names that make up the primary key for a table by executing the SQLPrimaryKeys function.""" fut = self._run_operation( @@ -300,17 +359,36 @@ def primaryKeys(self, table, catalog=None, schema=None): # nopep8 ) return fut - def foreignKeys(self, *a, **kw): # nopep8 + def foreignKeys( + self, + table: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + foreignTable: Optional[str] = None, + foreignCatalog: Optional[str] = None, + foreignSchema: Optional[str] = None, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Executes the SQLForeignKeys function and creates a result set of column names that are foreign keys in the specified table (columns in the specified table that refer to primary keys in other tables) or foreign keys in other tables that refer to the primary key in the specified table. """ - fut = self._run_operation(self._impl.foreignKeys, *a, **kw) + fut = self._run_operation( + self._impl.foreignKeys, + table=table, + catalog=catalog, + schema=schema, + foreignTable=foreignTable, + foreignCatalog=foreignCatalog, + foreignSchema=foreignSchema, + ) return fut - def getTypeInfo(self, sql_type): # nopep8 + def getTypeInfo( + self, + sql_type: Optional[int] = None, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Executes SQLGetTypeInfo a creates a result set with information about the specified data type or all data types supported by the ODBC driver if not specified. @@ -318,42 +396,62 @@ def getTypeInfo(self, sql_type): # nopep8 fut = self._run_operation(self._impl.getTypeInfo, sql_type) return fut - def procedures(self, *a, **kw): + def procedures( + self, + procedure: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: """Executes SQLProcedures and creates a result set of information about the procedures in the data source. """ - fut = self._run_operation(self._impl.procedures, *a, **kw) + fut = self._run_operation( + self._impl.procedures, + procedure=procedure, + catalog=catalog, + schema=schema, + ) return fut - def procedureColumns(self, *a, **kw): # nopep8 - fut = self._run_operation(self._impl.procedureColumns, *a, **kw) + def procedureColumns( + self, + procedure: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + ) -> Coroutine[Any, Any, pyodbc.Cursor]: + fut = self._run_operation( + self._impl.procedureColumns, + procedure=procedure, + catalog=catalog, + schema=schema, + ) return fut - def skip(self, count): + def skip(self, count: int) -> Coroutine[Any, Any, None]: fut = self._run_operation(self._impl.skip, count) return fut - def commit(self): + def commit(self) -> Coroutine[Any, Any, None]: fut = self._run_operation(self._impl.commit) return fut - def rollback(self): + def rollback(self) -> Coroutine[Any, Any, None]: fut = self._run_operation(self._impl.rollback) return fut - def __aiter__(self): + def __aiter__(self) -> Cursor: return self - async def __anext__(self): + async def __anext__(self) -> pyodbc.Row: ret = await self.fetchone() if ret is not None: return ret else: raise StopAsyncIteration - async def __aenter__(self): + async def __aenter__(self) -> Cursor: return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() return