From a562c56f4d4eace5472e174dacdcb3fcad71125c Mon Sep 17 00:00:00 2001 From: dromanov Date: Fri, 5 May 2023 14:29:08 +0300 Subject: [PATCH 1/4] =?UTF-8?q?=F0=9F=8F=B7=EF=B8=8F=20Add=20types?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiomysql/__init__.py | 73 ++++---- aiomysql/connection.py | 370 +++++++++++++++++++++++-------------- aiomysql/pool.py | 152 +++++++++------ aiomysql/py.typed | 0 aiomysql/sa/__init__.py | 33 ++-- aiomysql/sa/connection.py | 170 +++++++++++------ aiomysql/sa/engine.py | 175 ++++++++++++------ aiomysql/sa/transaction.py | 71 ++++--- aiomysql/utils.py | 245 ++++++++++-------------- 9 files changed, 773 insertions(+), 516 deletions(-) create mode 100644 aiomysql/py.typed diff --git a/aiomysql/__init__.py b/aiomysql/__init__.py index a367fcd2..f0cb58c4 100644 --- a/aiomysql/__init__.py +++ b/aiomysql/__init__.py @@ -23,47 +23,54 @@ """ +from typing import List, Type + from pymysql.converters import escape_dict, escape_sequence, escape_string -from pymysql.err import (Warning, Error, InterfaceError, DataError, - DatabaseError, OperationalError, IntegrityError, - InternalError, - NotSupportedError, ProgrammingError, MySQLError) +from pymysql.err import ( + Warning, + Error, + InterfaceError, + DataError, + DatabaseError, + OperationalError, + IntegrityError, + InternalError, + NotSupportedError, + ProgrammingError, + MySQLError, +) +from aiomysql.pool import create_pool, Pool +from ._version import version from .connection import Connection, connect from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor -from .pool import create_pool, Pool -from ._version import version __version__ = version -__all__ = [ - +__all__: List[Type] = [ # Errors - 'Error', - 'DataError', - 'DatabaseError', - 'IntegrityError', - 'InterfaceError', - 'InternalError', - 'MySQLError', - 'NotSupportedError', - 'OperationalError', - 'ProgrammingError', - 'Warning', + Error, + DataError, + DatabaseError, + 
IntegrityError, + InterfaceError, + InternalError, + MySQLError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, - 'escape_dict', - 'escape_sequence', - 'escape_string', + escape_dict, + escape_sequence, + escape_string, - 'Connection', - 'Pool', - 'connect', - 'create_pool', - 'Cursor', - 'SSCursor', - 'DictCursor', - 'SSDictCursor' + Connection, + Pool, + connect, + create_pool, + Cursor, + SSCursor, + DictCursor, + SSDictCursor, ] - -(Connection, Pool, connect, create_pool, Cursor, SSCursor, DictCursor, - SSDictCursor) # pyflakes diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 2c559f92..0b6eef47 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -2,42 +2,78 @@ # http://dev.mysql.com/doc/internals/en/client-server-protocol.html import asyncio +import configparser +import getpass import os import socket import struct import sys import warnings -import configparser -import getpass from functools import partial +from typing import ( + Union, + Mapping, + Any, + Optional, + List, + Tuple, + Dict, + Type +) + +from pymysql.charset import ( + charset_by_name, + charset_by_id +) +# noinspection PyUnresolvedReferences,PyProtectedMember +from pymysql.connections import ( + EOFPacketWrapper, + FieldDescriptorPacket, + LoadLocalPacketWrapper, + MysqlPacket, + OKPacketWrapper, + TEXT_TYPES, + MAX_PACKET_LEN, + DEFAULT_CHARSET, + _auth +) +# noinspection PyUnresolvedReferences +from pymysql.constants import ( + CLIENT, + COMMAND, + CR, + FIELD_TYPE, + SERVER_STATUS +) +# noinspection PyUnresolvedReferences +from pymysql.converters import ( + escape_item, + encoders, + decoders, + escape_string, + escape_bytes_prefixed, + through +) +from pymysql.err import ( + Warning, + Error, + InterfaceError, + DataError, + DatabaseError, + OperationalError, + IntegrityError, + InternalError, + NotSupportedError, + ProgrammingError +) -from pymysql.charset import charset_by_name, charset_by_id -from pymysql.constants import 
SERVER_STATUS -from pymysql.constants import CLIENT -from pymysql.constants import COMMAND -from pymysql.constants import CR -from pymysql.constants import FIELD_TYPE -from pymysql.converters import (escape_item, encoders, decoders, - escape_string, escape_bytes_prefixed, through) -from pymysql.err import (Warning, Error, - InterfaceError, DataError, DatabaseError, - OperationalError, - IntegrityError, InternalError, NotSupportedError, - ProgrammingError) - -from pymysql.connections import TEXT_TYPES, MAX_PACKET_LEN, DEFAULT_CHARSET -from pymysql.connections import _auth - -from pymysql.connections import MysqlPacket -from pymysql.connections import FieldDescriptorPacket -from pymysql.connections import EOFPacketWrapper -from pymysql.connections import OKPacketWrapper -from pymysql.connections import LoadLocalPacketWrapper - -# from aiomysql.utils import _convert_to_str from .cursors import Cursor -from .utils import _pack_int24, _lenenc_int, _ConnectionContextManager, _ContextManager from .log import logger +from .utils import ( + _pack_int24, + _lenenc_int, + _ContextManager +) try: DEFAULT_USER = getpass.getuser() @@ -45,15 +81,32 @@ DEFAULT_USER = "unknown" -def connect(host="localhost", user=None, password="", - db=None, port=3306, unix_socket=None, - charset='', sql_mode=None, - read_default_file=None, conv=decoders, use_unicode=None, - client_flag=0, cursorclass=Cursor, init_command=None, - connect_timeout=None, read_default_group=None, - autocommit=False, echo=False, - local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): +async def connect( + host: str = "localhost", + user: Optional[str] = None, + password: str = "", + db: Optional[str] = None, + port: int = 3306, + unix_socket: Optional[str] = None, + charset: str = '', + sql_mode: Optional[str] = None, + read_default_file: Optional[str] = None, + conv: Optional[Dict[str, Any]] = None, + use_unicode: Optional[bool] = None, + client_flag: int = 0, + 
cursorclass: Type[Cursor] = Cursor, + init_command: Optional[str] = None, + connect_timeout: Optional[int] = None, + read_default_group: Optional[str] = None, + autocommit: bool = False, + echo: bool = False, + local_infile: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + ssl: Optional[Union[bool, Dict[str, Any]]] = None, + auth_plugin: Optional[str] = '', + program_name: Optional[str] = '', + server_public_key: Optional[Union[str, bytes]] = None, +) -> 'Connection': """See connections.Connection.__init__() for information about defaults.""" coro = _connect(host=host, user=user, password=password, db=db, @@ -67,16 +120,28 @@ def connect(host="localhost", user=None, password="", autocommit=autocommit, echo=echo, local_infile=local_infile, loop=loop, ssl=ssl, auth_plugin=auth_plugin, program_name=program_name) - return _ConnectionContextManager(coro) + return _ContextManager[Connection](coro, disconnect) + + +async def disconnect(c: "Connection") -> None: + c.close() +# todo: Update Any to stricter kwarg +# https://github.com/python/mypy/issues/4441 async def _connect(*args, **kwargs): conn = Connection(*args, **kwargs) await conn._connect() return conn -async def _open_connection(host=None, port=None, **kwds): +# todo: Update Any to stricter kwarg +# https://github.com/python/mypy/issues/4441 +async def _open_connection( + host: Optional[str] = None, + port: Optional[int] = None, + **kwds: Any +) -> Tuple['_StreamReader', asyncio.StreamWriter]: """This is based on asyncio.open_connection, allowing us to use a custom StreamReader. 
@@ -91,13 +156,18 @@ async def _open_connection(host=None, port=None, **kwds): return reader, writer -async def _open_unix_connection(path=None, **kwds): +# todo: Update Any to stricter kwarg +# https://github.com/python/mypy/issues/4441 +async def _open_unix_connection( + path: Optional[str] = None, + **kwds: Any +) -> Tuple['_StreamReader', asyncio.StreamWriter]: """This is based on asyncio.open_unix_connection, allowing us to use a custom StreamReader. `limit` arg has been removed as we don't currently use it. """ - loop = asyncio.events.get_running_loop() + loop = asyncio.get_running_loop() reader = _StreamReader(loop=loop) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) @@ -114,8 +184,9 @@ class _StreamReader(asyncio.StreamReader): `limit` arg has been removed as we don't currently use it. """ - def __init__(self, loop=None): - self._eof_received = False + + def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + self._eof_received: bool = False super().__init__(loop=loop) def feed_eof(self) -> None: @@ -123,7 +194,7 @@ def feed_eof(self) -> None: super().feed_eof() @property - def eof_received(self): + def eof_received(self) -> bool: return self._eof_received @@ -134,15 +205,33 @@ class Connection: connect(). 
""" - def __init__(self, host="localhost", user=None, password="", - db=None, port=3306, unix_socket=None, - charset='', sql_mode=None, - read_default_file=None, conv=decoders, use_unicode=None, - client_flag=0, cursorclass=Cursor, init_command=None, - connect_timeout=None, read_default_group=None, - autocommit=False, echo=False, - local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): + def __init__( + self, + host: str = "localhost", + user: Optional[str] = None, + password: str = "", + db: Optional[str] = None, + port: int = 3306, + unix_socket: Optional[str] = None, + charset: str = '', + sql_mode: Optional[str] = None, + read_default_file: Optional[str] = None, + conv: dict = decoders, + use_unicode: Optional[bool] = None, + client_flag: int = 0, + cursorclass: type = Cursor, + init_command: Optional[str] = None, + connect_timeout: Optional[int] = None, + read_default_group: Optional[str] = None, + autocommit: Optional[bool] = False, + echo: Optional[bool] = False, + local_infile: Optional[bool] = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + ssl: Optional[Union[str, Mapping[str, Any]]] = None, + auth_plugin: Optional[str] = '', + program_name: Optional[str] = '', + server_public_key: Optional[str] = None + ) -> None: """ Establish a connection to the MySQL database. Accepts several arguments: @@ -161,7 +250,7 @@ def __init__(self, host="localhost", user=None, password="", :param conv: Decoders dictionary to use instead of the default one. This is used to provide custom marshalling of types. See converters. - :param use_unicode: Whether or not to default to unicode strings. + :param use_unicode: Whether to default to unicode strings. :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT. :param cursorclass: Custom cursor class to use. 
@@ -278,69 +367,69 @@ def __init__(self, host="localhost", user=None, password="", self._close_reason = None @property - def host(self): + def host(self) -> str: """MySQL server IP address or name""" return self._host @property - def port(self): + def port(self) -> int: """MySQL server TCP/IP port""" return self._port @property - def unix_socket(self): + def unix_socket(self) -> Optional[str]: """MySQL Unix socket file location""" return self._unix_socket @property - def db(self): + def db(self) -> Optional[str]: """Current database name.""" return self._db @property - def user(self): + def user(self) -> Optional[str]: """User used while connecting to MySQL""" return self._user @property - def echo(self): + def echo(self) -> Optional[bool]: """Return echo mode status.""" return self._echo @property - def last_usage(self): + def last_usage(self) -> Optional[float]: """Return time() when connection was used.""" return self._last_usage @property - def loop(self): + def loop(self) -> Optional[asyncio.AbstractEventLoop]: return self._loop @property - def closed(self): + def closed(self) -> Optional[bool]: """The readonly property that returns ``True`` if connections is closed. 
""" return self._writer is None @property - def encoding(self): + def encoding(self) -> str: """Encoding employed for this connection.""" return self._encoding @property - def charset(self): + def charset(self) -> str: """Returns the character set for current connection.""" return self._charset - def close(self): + def close(self) -> None: """Close socket connection""" if self._writer: self._writer.transport.close() self._writer = None self._reader = None - async def ensure_closed(self): + async def ensure_closed(self) -> None: """Send quit command and then close socket connection""" if self._writer is None: # connection has been closed @@ -350,7 +439,7 @@ async def ensure_closed(self): await self._writer.drain() self.close() - async def autocommit(self, value): + async def autocommit(self, value) -> None: """Enable/disable autocommit mode for current MySQL session. :param value: ``bool``, toggle autocommit @@ -360,7 +449,7 @@ async def autocommit(self, value): if value != current: await self._send_autocommit_mode() - def get_autocommit(self): + def get_autocommit(self) -> bool: """Returns autocommit status for current MySQL session. 
:returns bool: current autocommit status.""" @@ -368,7 +457,7 @@ def get_autocommit(self): status = self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT return bool(status) - async def _read_ok_packet(self): + async def _read_ok_packet(self) -> bool: pkt = await self._read_packet() if not pkt.is_ok_packet(): raise OperationalError(2014, "Command Out of Sync") @@ -376,41 +465,41 @@ async def _read_ok_packet(self): self.server_status = ok.server_status return True - async def _send_autocommit_mode(self): - """Set whether or not to commit after every execute() """ + async def _send_autocommit_mode(self) -> None: + """Set whether to commit after every execute() """ await self._execute_command( COMMAND.COM_QUERY, "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)) await self._read_ok_packet() - async def begin(self): + async def begin(self) -> None: """Begin transaction.""" await self._execute_command(COMMAND.COM_QUERY, "BEGIN") await self._read_ok_packet() - async def commit(self): + async def commit(self) -> None: """Commit changes to stable storage.""" await self._execute_command(COMMAND.COM_QUERY, "COMMIT") await self._read_ok_packet() - async def rollback(self): + async def rollback(self) -> None: """Roll back the current transaction.""" await self._execute_command(COMMAND.COM_QUERY, "ROLLBACK") await self._read_ok_packet() - async def select_db(self, db): + async def select_db(self, db) -> None: """Set current db""" await self._execute_command(COMMAND.COM_INIT_DB, db) await self._read_ok_packet() - async def show_warnings(self): + async def show_warnings(self) -> List[Tuple[int, str, Optional[int]]]: """SHOW WARNINGS""" await self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS") result = MySQLResult(self) await result.read() return result.rows - def escape(self, obj): + def escape(self, obj: Union[str, bytes, Any]) -> str: """ Escape whatever value you pass to it""" if isinstance(obj, str): return "'" + self.escape_string(obj) + "'" @@ -418,11 
+507,11 @@ def escape(self, obj): return escape_bytes_prefixed(obj) return escape_item(obj, self._charset) - def literal(self, obj): + def literal(self, obj: Union[str, bytes, Any]) -> str: """Alias for escape()""" return self.escape(obj) - def escape_string(self, s): + def escape_string(self, s: str) -> str: if (self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES): return s.replace("'", "''") @@ -451,14 +540,14 @@ def cursor(self, *cursors): cur = cursors[0](self, self._echo) elif cursors: cursor_name = ''.join(map(lambda x: x.__name__, cursors)) \ - .replace('Cursor', '') + 'Cursor' + .replace('Cursor', '') + 'Cursor' cursor_class = type(cursor_name, cursors, {}) cur = cursor_class(self, self._echo) else: cur = self.cursorclass(self, self._echo) fut = self._loop.create_future() fut.set_result(cur) - return _ContextManager(fut) + return _ContextManager[Cursor](fut, _close_cursor) # The following methods are INTERNAL USE ONLY (called from Cursor) async def query(self, sql, unbuffered=False): @@ -615,7 +704,7 @@ async def _read_packet(self, packet_type=MysqlPacket): ' None: + try: + if unbuffered: result = MySQLResult(self) await result.init_unbuffered_query() - except BaseException: - result.unbuffered_active = False - result.connection = None - raise - else: - result = MySQLResult(self) - await result.read() - self._result = result - self._affected_rows = result.affected_rows - if result.server_status is not None: - self.server_status = result.server_status + else: + result = MySQLResult(self) + await result.read() + + self._result = result + self._affected_rows = result.affected_rows + self.server_status = getattr(result, 'server_status', None) + + except Exception: + self._result = None + self._affected_rows = -1 + self.server_status = None + raise def insert_id(self): if self._result: @@ -989,7 +1079,7 @@ async def caching_sha2_password_auth(self, pkt): pkt = await self._read_packet() pkt.check_error() - async def sha256_password_auth(self, 
pkt): + async def sha256_password_auth(self, pkt: Any) -> Any: if self._secure: logger.debug("sha256: Sending plain password") data = self._password.encode('latin1') + b'\0' @@ -1031,19 +1121,19 @@ async def sha256_password_auth(self, pkt): return pkt # _mysql support - def thread_id(self): + def thread_id(self) -> int: return self.server_thread_id[0] - def character_set_name(self): + def character_set_name(self) -> str: return self._charset - def get_host_info(self): + def get_host_info(self) -> str: return self.host_info - def get_proto_info(self): + def get_proto_info(self) -> Tuple[int, int, int]: return self.protocol_version - async def _get_server_information(self): + async def _get_server_information(self) -> None: i = 0 packet = await self._read_packet() data = packet.get_all_data() @@ -1105,7 +1195,7 @@ async def _get_server_information(self): else: self._server_auth_plugin = data[i:server_end].decode('latin1') - def get_transaction_status(self): + def get_transaction_status(self) -> bool: return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_IN_TRANS) def get_server_info(self): @@ -1113,18 +1203,18 @@ def get_server_info(self): # Just to always have consistent errors 2 helpers - def _close_on_cancel(self): + def _close_on_cancel(self) -> None: self.close() self._close_reason = "Cancelled during execution" - def _ensure_alive(self): + def _ensure_alive(self) -> None: if not self._writer: if self._close_reason is None: raise InterfaceError("(0, 'Not connected')") else: raise InterfaceError(self._close_reason) - def __del__(self): + def __del__(self) -> None: if self._writer: warnings.warn("Unclosed connection {!r}".format(self), ResourceWarning) @@ -1142,24 +1232,28 @@ def __del__(self): NotSupportedError = NotSupportedError +async def _close_cursor(c: Cursor) -> None: + await c.close() + + # TODO: move OK and EOF packet parsing/logic into a proper subclass # of MysqlPacket like has been done with FieldDescriptorPacket. 
class MySQLResult: - def __init__(self, connection): - self.connection = connection - self.affected_rows = None - self.insert_id = None - self.server_status = None - self.warning_count = 0 - self.message = None - self.field_count = 0 - self.description = None - self.rows = None - self.has_next = None - self.unbuffered_active = False - - async def read(self): + def __init__(self, connection: Connection) -> None: + self.connection: Connection = connection + self.affected_rows: Optional[int] = None + self.insert_id: Optional[int] = None + self.server_status: Optional[int] = None + self.warning_count: int = 0 + self.message: Optional[str] = None + self.field_count: int = 0 + self.description: Optional[List] = None + self.rows: Optional[List] = None + self.has_next: Optional[bool] = None + self.unbuffered_active: bool = False + + async def read(self) -> None: try: first_packet = await self.connection._read_packet() @@ -1173,7 +1267,7 @@ async def read(self): finally: self.connection = None - async def init_unbuffered_query(self): + async def init_unbuffered_query(self) -> None: self.unbuffered_active = True first_packet = await self.connection._read_packet() @@ -1194,7 +1288,7 @@ async def init_unbuffered_query(self): # we set it to this instead of None, which would be preferred. 
self.affected_rows = 18446744073709551615 - def _read_ok_packet(self, first_packet): + def _read_ok_packet(self, first_packet) -> None: ok_packet = OKPacketWrapper(first_packet) self.affected_rows = ok_packet.affected_rows self.insert_id = ok_packet.insert_id @@ -1203,7 +1297,7 @@ def _read_ok_packet(self, first_packet): self.message = ok_packet.message self.has_next = ok_packet.has_next - async def _read_load_local_packet(self, first_packet): + async def _read_load_local_packet(self, first_packet) -> None: load_packet = LoadLocalPacketWrapper(first_packet) sender = LoadLocalFile(load_packet.filename, self.connection) try: @@ -1218,7 +1312,7 @@ async def _read_load_local_packet(self, first_packet): raise OperationalError(2014, "Commands Out of Sync") self._read_ok_packet(ok_packet) - def _check_packet_is_eof(self, packet): + def _check_packet_is_eof(self, packet) -> bool: if packet.is_eof_packet(): eof_packet = EOFPacketWrapper(packet) self.warning_count = eof_packet.warning_count @@ -1226,12 +1320,12 @@ def _check_packet_is_eof(self, packet): return True return False - async def _read_result_packet(self, first_packet): + async def _read_result_packet(self, first_packet) -> None: self.field_count = first_packet.read_length_encoded_integer() await self._get_descriptions() await self._read_rowdata_packet() - async def _read_rowdata_packet_unbuffered(self): + async def _read_rowdata_packet_unbuffered(self) -> Optional[tuple[Optional[Any], ...]]: # Check if in an active query if not self.unbuffered_active: return @@ -1249,7 +1343,7 @@ async def _read_rowdata_packet_unbuffered(self): self.rows = (row,) return row - async def _finish_unbuffered_query(self): + async def _finish_unbuffered_query(self) -> None: # After much reading on the MySQL protocol, it appears that there is, # in fact, no way to stop MySQL from sending all the data after # executing a query, so we just spin, and wait for an EOF packet. 
@@ -1260,8 +1354,8 @@ async def _finish_unbuffered_query(self): # TODO: replace these numbers with constants when available # TODO: in a new PyMySQL release if e.args[0] in ( - 3024, # ER.QUERY_TIMEOUT - 1969, # ER.STATEMENT_TIMEOUT + 3024, # ER.QUERY_TIMEOUT + 1969, # ER.STATEMENT_TIMEOUT ): # if the query timed out we can simply ignore this error self.unbuffered_active = False @@ -1275,7 +1369,7 @@ async def _finish_unbuffered_query(self): # release reference to kill cyclic reference. self.connection = None - async def _read_rowdata_packet(self): + async def _read_rowdata_packet(self) -> None: """Read a rowdata packet for each data row in the result set.""" rows = [] while True: @@ -1289,7 +1383,7 @@ async def _read_rowdata_packet(self): self.affected_rows = len(rows) self.rows = tuple(rows) - def _read_row_from_packet(self, packet): + def _read_row_from_packet(self, packet) -> tuple[Optional[Any], ...]: row = [] for encoding, converter in self.converters: try: @@ -1306,7 +1400,7 @@ def _read_row_from_packet(self, packet): row.append(data) return tuple(row) - async def _get_descriptions(self): + async def _get_descriptions(self) -> None: """Read a column descriptor packet for each column in the result.""" self.fields = [] self.converters = [] @@ -1359,9 +1453,9 @@ def __init__(self, filename, connection): self._file_object = None self._executor = None # means use default executor - def _open_file(self): + def _open_file(self) -> asyncio.Future: - def opener(filename): + def opener(filename) -> None: try: self._file_object = open(filename, 'rb') except IOError as e: @@ -1371,7 +1465,7 @@ def opener(filename): fut = self._loop.run_in_executor(self._executor, opener, self.filename) return fut - def _file_read(self, chunk_size): + def _file_read(self, chunk_size: int) -> asyncio.Future: def freader(chunk_size): try: @@ -1391,7 +1485,7 @@ def freader(chunk_size): fut = self._loop.run_in_executor(self._executor, freader, chunk_size) return fut - async def 
send_data(self): + async def send_data(self) -> None: """Send data packets from the local file to the server""" self.connection._ensure_alive() conn = self.connection diff --git a/aiomysql/pool.py b/aiomysql/pool.py index eaaddbe0..f95aaae6 100644 --- a/aiomysql/pool.py +++ b/aiomysql/pool.py @@ -4,74 +4,115 @@ import asyncio import collections import warnings - -from .connection import connect -from .utils import (_PoolContextManager, _PoolConnectionContextManager, - _PoolAcquireContextManager) - - -def create_pool(minsize=1, maxsize=10, echo=False, pool_recycle=-1, - loop=None, **kwargs): +from types import TracebackType +from typing import ( + Optional, + Any, + Deque, + Type +) + +from aiomysql.connection import ( + connect, + Connection +) +from aiomysql.utils import ( + _ContextManager +) + + +# todo: Update Any to stricter kwarg +# https://github.com/python/mypy/issues/4441 +def create_pool( + minsize: int = 1, + maxsize: int = 10, + echo: bool = False, + pool_recycle: int = -1, + loop: Optional[asyncio.AbstractEventLoop] = None, + **kwargs: Any) -> _ContextManager["Pool"]: coro = _create_pool(minsize=minsize, maxsize=maxsize, echo=echo, pool_recycle=pool_recycle, loop=loop, **kwargs) - return _PoolContextManager(coro) + return _ContextManager[Pool](coro, _destroy_pool) + + +async def _destroy_pool(pool: "Pool") -> None: + pool.close() + await pool.wait_closed() -async def _create_pool(minsize=1, maxsize=10, echo=False, pool_recycle=-1, - loop=None, **kwargs): +# todo: Update Any to stricter kwarg +# https://github.com/python/mypy/issues/4441 +async def _create_pool( + minsize: int = 1, + maxsize: int = 10, + echo: bool = False, + pool_recycle: int = -1, + loop: Optional[asyncio.AbstractEventLoop] = None, + **kwargs: Any +) -> 'Pool': if loop is None: loop = asyncio.get_event_loop() - pool = Pool(minsize=minsize, maxsize=maxsize, echo=echo, - pool_recycle=pool_recycle, loop=loop, **kwargs) + pool: Pool = Pool(minsize=minsize, maxsize=maxsize, echo=echo, + 
pool_recycle=pool_recycle, loop=loop, **kwargs) + if minsize > 0: async with pool._cond: await pool._fill_free_pool(False) + return pool class Pool(asyncio.AbstractServer): """Connection pool""" - def __init__(self, minsize, maxsize, echo, pool_recycle, loop, **kwargs): + def __init__( + self, + minsize: int, + maxsize: int, + echo: bool, + pool_recycle: int, + loop: asyncio.AbstractEventLoop, + **kwargs: Any + ) -> None: if minsize < 0: raise ValueError("minsize should be zero or greater") if maxsize < minsize and maxsize != 0: raise ValueError("maxsize should be not less than minsize") - self._minsize = minsize - self._loop = loop - self._conn_kwargs = kwargs - self._acquiring = 0 - self._free = collections.deque(maxlen=maxsize or None) - self._cond = asyncio.Condition() - self._used = set() - self._terminated = set() - self._closing = False - self._closed = False - self._echo = echo - self._recycle = pool_recycle + self._minsize: int = minsize + self._loop: asyncio.AbstractEventLoop = loop + self._conn_kwargs: dict[str, Any] = kwargs + self._acquiring: int = 0 + self._free: Deque[Any] = collections.deque(maxlen=maxsize or None) + self._cond: asyncio.Condition = asyncio.Condition() + self._used: set[Any] = set() + self._terminated: set[Any] = set() + self._closing: bool = False + self._closed: bool = False + self._echo: bool = echo + self._recycle: int = pool_recycle @property - def echo(self): + def echo(self) -> bool: return self._echo @property - def minsize(self): + def minsize(self) -> int: return self._minsize @property - def maxsize(self): + def maxsize(self) -> int: return self._free.maxlen @property - def size(self): + def size(self) -> int: return self.freesize + len(self._used) + self._acquiring @property - def freesize(self): + def freesize(self) -> int: return len(self._free) - async def clear(self): + async def clear(self) -> None: """Close all free connections in pool.""" async with self._cond: while self._free: @@ -80,28 +121,27 @@ async def 
clear(self): self._cond.notify() @property - def closed(self): + def closed(self) -> bool: """ The readonly property that returns ``True`` if connections is closed. """ return self._closed - def close(self): + def close(self) -> None: """Close pool. Mark all pool connections to be closed on getting back to pool. - Closed pool doesn't allow to acquire new connections. + Closed pool doesn't allow acquiring new connections. """ if self._closed: return self._closing = True - def terminate(self): + def terminate(self) -> None: """Terminate pool. Close pool with instantly closing all acquired connections also. """ - self.close() for conn in list(self._used): @@ -110,9 +150,8 @@ def terminate(self): self._used.clear() - async def wait_closed(self): + async def wait_closed(self) -> None: """Wait for closing all pool's connections.""" - if self._closed: return if not self._closing: @@ -129,12 +168,12 @@ async def wait_closed(self): self._closed = True - def acquire(self): + async def acquire(self) -> _ContextManager: """Acquire free connection from the pool.""" coro = self._acquire() - return _PoolAcquireContextManager(coro, self) + return _ContextManager[Connection](coro, self.release) - async def _acquire(self): + async def _acquire(self) -> Connection: if self._closing: raise RuntimeError("Cannot acquire connection after closing pool") async with self._cond: @@ -149,7 +188,7 @@ async def _acquire(self): else: await self._cond.wait() - async def _fill_free_pool(self, override_min): + async def _fill_free_pool(self, override_min: bool) -> None: # iterate over free connections and remove timed out ones free_size = len(self._free) n = 0 @@ -167,8 +206,7 @@ async def _fill_free_pool(self, override_min): self._free.pop() conn.close() - elif (self._recycle > -1 and - self._loop.time() - conn.last_usage > self._recycle): + elif -1 < self._recycle < self._loop.time() - conn.last_usage: self._free.pop() conn.close() @@ -200,11 +238,11 @@ async def _fill_free_pool(self, 
override_min): finally: self._acquiring -= 1 - async def _wakeup(self): + async def _wakeup(self) -> None: async with self._cond: self._cond.notify() - def release(self, conn): + def release(self, conn: Any) -> asyncio.Future: """Release free connection back to the connection pool. This is **NOT** a coroutine. @@ -230,16 +268,18 @@ def release(self, conn): fut = self._loop.create_task(self._wakeup()) return fut - def __enter__(self): + def __enter__(self) -> None: raise RuntimeError( '"yield from" should be used as context manager expression') - def __exit__(self, *args): + # todo: Update Any to stricter kwarg + # https://github.com/python/mypy/issues/4441 + def __exit__(self, *args: Any) -> None: # This must exist because __enter__ exists, even though that # always raises; that's how the with-statement works. pass # pragma: nocover - def __iter__(self): + def __iter__(self) -> _ContextManager: # This is not a coroutine. It is meant to enable the idiom: # # with (yield from pool) as conn: @@ -253,18 +293,20 @@ def __iter__(self): # finally: # conn.release() conn = yield from self.acquire() - return _PoolConnectionContextManager(self, conn) + return _ContextManager[Connection](conn, self.release) - def __await__(self): + def __await__(self) -> _ContextManager: msg = "with await pool as conn deprecated, use" \ "async with pool.acquire() as conn instead" warnings.warn(msg, DeprecationWarning, stacklevel=2) conn = yield from self.acquire() - return _PoolConnectionContextManager(self, conn) + return _ContextManager[Connection](conn, self.release) - async def __aenter__(self): + async def __aenter__(self) -> 'Pool': return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> None: self.close() await self.wait_closed() diff --git a/aiomysql/py.typed b/aiomysql/py.typed new file mode 100644 index 00000000..e69de29b diff 
--git a/aiomysql/sa/__init__.py b/aiomysql/sa/__init__.py index 4927ec9d..d679f4a0 100644 --- a/aiomysql/sa/__init__.py +++ b/aiomysql/sa/__init__.py @@ -1,14 +1,23 @@ -"""Optional support for sqlalchemy.sql dynamic query generation.""" from .connection import SAConnection -from .engine import create_engine, Engine -from .exc import (Error, ArgumentError, InvalidRequestError, - NoSuchColumnError, ResourceClosedError) +from .engine import ( + create_engine, + Engine +) +from .exc import ( + Error, + ArgumentError, + InvalidRequestError, + NoSuchColumnError, + ResourceClosedError +) - -__all__ = ('create_engine', 'SAConnection', 'Error', - 'ArgumentError', 'InvalidRequestError', 'NoSuchColumnError', - 'ResourceClosedError', 'Engine') - - -(SAConnection, Error, ArgumentError, InvalidRequestError, - NoSuchColumnError, ResourceClosedError, create_engine, Engine) +__all__ = [ + "SAConnection", + "Error", + "ArgumentError", + "InvalidRequestError", + "NoSuchColumnError", + "ResourceClosedError", + "create_engine", + "Engine" +] diff --git a/aiomysql/sa/connection.py b/aiomysql/sa/connection.py index cab3961a..47202494 100644 --- a/aiomysql/sa/connection.py +++ b/aiomysql/sa/connection.py @@ -1,25 +1,54 @@ # ported from: # https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/connection.py import weakref +from types import TracebackType +from typing import ( + Any, + Dict, + Union, + List, + Tuple, + Optional, + Type +) from sqlalchemy.sql import ClauseElement -from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.sql.ddl import DDLElement +from sqlalchemy.sql.dml import UpdateBase + +from . import exc, Engine +from .result import create_result_proxy, ResultProxy +from .transaction import ( + RootTransaction, + Transaction, + NestedTransaction, + TwoPhaseTransaction +) +from .. import Cursor, Connection +from ..utils import _IterableContextManager + -from . 
import exc -from .result import create_result_proxy -from .transaction import (RootTransaction, Transaction, - NestedTransaction, TwoPhaseTransaction) -from ..utils import _TransactionContextManager, _SAConnectionContextManager +async def _commit_transaction_if_active(t: Transaction) -> None: + if t.is_active: + await t.commit() -def noop(k): - return k +async def _rollback_transaction(t: Transaction) -> None: + await t.rollback() + + +async def _close_result_proxy(c: 'ResultProxy') -> None: + await c.close() class SAConnection: - def __init__(self, connection, engine, compiled_cache=None): + def __init__( + self, + connection: Connection, + engine: Engine, + compiled_cache: Optional[Any] = None, + ) -> None: self._connection = connection self._transaction = None self._savepoint_seq = 0 @@ -28,7 +57,14 @@ def __init__(self, connection, engine, compiled_cache=None): self._dialect = engine.dialect self._compiled_cache = compiled_cache - def execute(self, query, *multiparams, **params): + # todo: Update Any to stricter kwarg + # https://github.com/python/mypy/issues/4441 + def execute( + self, + query: Union[str, DDLElement, ClauseElement], + *multiparams: Any, + **params: Any + ) -> _IterableContextManager: """Executes a SQL query with optional parameters. query - a SQL query string or any sqlalchemy expression. 
@@ -66,9 +102,15 @@ def execute(self, query, *multiparams, **params): """ coro = self._execute(query, *multiparams, **params) - return _SAConnectionContextManager(coro) + return _IterableContextManager[ResultProxy](coro, _close_result_proxy) - def _base_params(self, query, dp, compiled, is_update): + def _base_params( + self, + query: Union[str, DDLElement, ClauseElement], + dp: Any, + compiled: Any, + is_update: bool, + ) -> Dict[str, Any]: """ handle params """ @@ -84,24 +126,29 @@ def _base_params(self, query, dp, compiled, is_update): compiled_params = compiled.construct_params(dp) processors = compiled._bind_processors params = [{ - key: processors.get(key, noop)(compiled_params[key]) + key: processors.get(key, lambda val: val)(compiled_params[key]) for key in compiled_params }] post_processed_params = self._dialect.execute_sequence_format(params) return post_processed_params[0] - async def _executemany(self, query, dps, cursor): + async def _executemany( + self, + query: Union[str, DDLElement, ClauseElement], + dps: Any, + cursor: Cursor, + ) -> ResultProxy: """ executemany """ - result_map = None + result_map: Optional[Dict[str, Any]] = None if isinstance(query, str): await cursor.executemany(query, dps) elif isinstance(query, DDLElement): raise exc.ArgumentError( - "Don't mix sqlalchemy DDL clause " - "and execution with parameters" - ) + "Don't mix sqlalchemy DDL clause " + "and execution with parameters" + ) elif isinstance(query, ClauseElement): compiled = query.compile(dialect=self._dialect) params = [] @@ -132,7 +179,14 @@ async def _executemany(self, query, dps, cursor): self._weak_results.add(ret) return ret - async def _execute(self, query, *multiparams, **params): + # todo: Update Any to stricter kwarg + # https://github.com/python/mypy/issues/4441 + async def _execute( + self, + query: Union[str, DDLElement, ClauseElement], + *multiparams: Any, + **params: Any + ): cursor = await self._connection.cursor() dp = _distill_params(multiparams, params) 
if len(dp) > 1: @@ -182,10 +236,17 @@ async def _execute(self, query, *multiparams, **params): self._weak_results.add(ret) return ret - async def scalar(self, query, *multiparams, **params): + # todo: Update Any to stricter kwarg + # https://github.com/python/mypy/issues/4441 + async def scalar( + self, + query: Union[str, DDLElement, ClauseElement], + *multiparams: Any, + **params: Any + ) -> Optional[Any]: """Executes a SQL query and returns a scalar value.""" res = await self.execute(query, *multiparams, **params) - return (await res.scalar()) + return await res.scalar() @property def closed(self): @@ -216,7 +277,7 @@ def begin(self): Calls to .commit only have an effect when invoked via the outermost Transaction object, though the .rollback method of - any of the Transaction objects will roll back the transaction. + the Transaction objects will roll back the transaction. See also: .begin_nested - use a SAVEPOINT @@ -224,9 +285,9 @@ def begin(self): """ coro = self._begin() - return _TransactionContextManager(coro) + return _IterableContextManager(coro) - async def _begin(self): + async def _begin(self) -> 'Transaction': if self._transaction is None: self._transaction = RootTransaction(self) await self._begin_impl() @@ -234,14 +295,14 @@ async def _begin(self): else: return Transaction(self, self._transaction) - async def _begin_impl(self): + async def _begin_impl(self) -> None: cur = await self._connection.cursor() try: await cur.execute('BEGIN') finally: await cur.close() - async def _commit_impl(self): + async def _commit_impl(self) -> None: cur = await self._connection.cursor() try: await cur.execute('COMMIT') @@ -249,7 +310,7 @@ async def _commit_impl(self): await cur.close() self._transaction = None - async def _rollback_impl(self): + async def _rollback_impl(self) -> None: cur = await self._connection.cursor() try: await cur.execute('ROLLBACK') @@ -257,7 +318,7 @@ async def _rollback_impl(self): await cur.close() self._transaction = None - async def 
begin_nested(self): + async def begin_nested(self) -> 'Transaction': """Begin a nested transaction and return a transaction handle. The returned object is an instance of :class:`.NestedTransaction`. @@ -276,34 +337,34 @@ async def begin_nested(self): self._transaction._savepoint = await self._savepoint_impl() return self._transaction - async def _savepoint_impl(self, name=None): + async def _savepoint_impl(self, name: Optional[str] = None) -> str: self._savepoint_seq += 1 - name = 'aiomysql_sa_savepoint_%s' % self._savepoint_seq + name = f'aiomysql_sa_savepoint_{self._savepoint_seq}' cur = await self._connection.cursor() try: - await cur.execute('SAVEPOINT ' + name) + await cur.execute(f'SAVEPOINT {name}') return name finally: await cur.close() - async def _rollback_to_savepoint_impl(self, name, parent): + async def _rollback_to_savepoint_impl(self, name: str, parent: Optional[Transaction] = None) -> None: cur = await self._connection.cursor() try: - await cur.execute('ROLLBACK TO SAVEPOINT ' + name) + await cur.execute(f'ROLLBACK TO SAVEPOINT {name}') finally: await cur.close() self._transaction = parent - async def _release_savepoint_impl(self, name, parent): + async def _release_savepoint_impl(self, name: str, parent: Optional[Transaction] = None) -> None: cur = await self._connection.cursor() try: - await cur.execute('RELEASE SAVEPOINT ' + name) + await cur.execute(f'RELEASE SAVEPOINT {name}') finally: await cur.close() self._transaction = parent - async def begin_twophase(self, xid=None): + async def begin_twophase(self, xid: Optional[str] = None) -> TwoPhaseTransaction: """Begin a two-phase or XA transaction and return a transaction handle. 
@@ -323,36 +384,36 @@ async def begin_twophase(self, xid=None): if xid is None: xid = self._dialect.create_xid() self._transaction = TwoPhaseTransaction(self, xid) - await self.execute("XA START %s", xid) + await self.execute(f"XA START {xid}") return self._transaction - async def _prepare_twophase_impl(self, xid): - await self.execute("XA END '%s'" % xid) - await self.execute("XA PREPARE '%s'" % xid) + async def _prepare_twophase_impl(self, xid: str) -> None: + await self.execute(f"XA END '{xid}'") + await self.execute(f"XA PREPARE '{xid}'") - async def recover_twophase(self): + async def recover_twophase(self) -> List[str]: """Return a list of prepared twophase transaction ids.""" result = await self.execute("XA RECOVER;") return [row[0] for row in result] - async def rollback_prepared(self, xid, *, is_prepared=True): + async def rollback_prepared(self, xid: str, *, is_prepared: bool = True) -> None: """Rollback prepared twophase transaction.""" if not is_prepared: - await self.execute("XA END '%s'" % xid) - await self.execute("XA ROLLBACK '%s'" % xid) + await self.execute(f"XA END '{xid}'") + await self.execute(f"XA ROLLBACK '{xid}'") - async def commit_prepared(self, xid, *, is_prepared=True): + async def commit_prepared(self, xid: str, *, is_prepared: bool = True) -> None: """Commit prepared twophase transaction.""" if not is_prepared: - await self.execute("XA END '%s'" % xid) - await self.execute("XA COMMIT '%s'" % xid) + await self.execute(f"XA END '{xid}'") + await self.execute(f"XA COMMIT '{xid}'") @property - def in_transaction(self): + def in_transaction(self) -> bool: """Return True if a transaction is in progress.""" return self._transaction is not None and self._transaction.is_active - async def close(self): + async def close(self) -> None: """Close this SAConnection. 
This results in a release of the underlying database @@ -378,14 +439,19 @@ async def close(self): self._connection = None self._engine = None - async def __aenter__(self): + async def __aenter__(self) -> 'SAConnection': return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: await self.close() -def _distill_params(multiparams, params): +def _distill_params(multiparams: Any, params: Any) -> List[Union[Dict[str, Any], List[Tuple[Any, ...]]]]: """Given arguments from the calling form *multiparams, **params, return a list of bind parameter structures, usually a list of dictionaries. diff --git a/aiomysql/sa/engine.py b/aiomysql/sa/engine.py index 243e5001..1367c9eb 100644 --- a/aiomysql/sa/engine.py +++ b/aiomysql/sa/engine.py @@ -1,14 +1,34 @@ # ported from: # https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/engine.py import asyncio +from types import TracebackType +from typing import ( + Optional, + Dict, + Any, + MutableMapping, + Union +) + +from sqlalchemy import ( + Engine, + Dialect +) import aiomysql from .connection import SAConnection -from .exc import InvalidRequestError, ArgumentError -from ..utils import _PoolContextManager, _PoolAcquireContextManager +from .exc import ( + InvalidRequestError, + ArgumentError +) from ..cursors import ( - Cursor, DeserializationCursor, DictCursor, SSCursor, SSDictCursor) - + Cursor, + DeserializationCursor, + DictCursor, + SSCursor, + SSDictCursor +) +from ..utils import _ContextManager try: from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql @@ -17,8 +37,15 @@ raise ImportError('aiomysql.sa requires sqlalchemy') +# noinspection PyPep8Naming,PyAbstractClass class MySQLCompiler_pymysql(MySQLCompiler_mysqldb): - def construct_params(self, params=None, _group_number=None, _check=True): + def construct_params( + self, + params: 
Optional[Dict[str, Any]] = None, + _group_number: Optional[int] = None, + _check: bool = True, + **kwargs: Any + ) -> MutableMapping[str, Any]: pd = super().construct_params(params, _group_number, _check) for column in self.prefetch: @@ -26,7 +53,7 @@ def construct_params(self, params=None, _group_number=None, _check=True): return pd - def _exec_default(self, default): + def _exec_default(self, default: Any) -> Any: if default.is_callable: return default.arg(self.dialect) else: @@ -38,9 +65,17 @@ def _exec_default(self, default): _dialect.default_paramstyle = 'pyformat' -def create_engine(minsize=1, maxsize=10, loop=None, - dialect=_dialect, pool_recycle=-1, compiled_cache=None, - **kwargs): +# todo: Update Any to stricter kwarg +# https://github.com/python/mypy/issues/4441 +def create_engine( + minsize: int = 1, + maxsize: int = 10, + loop: Optional[asyncio.AbstractEventLoop] = None, + dialect: Dialect = _dialect, + pool_recycle: int = -1, + compiled_cache: Optional[Dict[str, Any]] = None, + **kwargs: Union[str, int, bool, Any] +): """A coroutine for Engine creation. Returns Engine instance with embedded connection pool. 
@@ -53,22 +88,37 @@ def create_engine(minsize=1, maxsize=10, loop=None, cursorclass = kwargs.get('cursorclass', Cursor) if not issubclass(cursorclass, Cursor) or any( - issubclass(cursorclass, cursor_class) - for cursor_class in deprecated_cursor_classes + issubclass(cursorclass, cursor_class) + for cursor_class in deprecated_cursor_classes ): - raise ArgumentError('SQLAlchemy engine does not support ' - 'this cursor class') + raise ArgumentError(f"The cursor class '{cursorclass.__name__}' is not supported by the SQLAlchemy engine.") coro = _create_engine(minsize=minsize, maxsize=maxsize, loop=loop, dialect=dialect, pool_recycle=pool_recycle, compiled_cache=compiled_cache, **kwargs) - return _EngineContextManager(coro) + return _ContextManager(coro, _close_engine) + + +async def _close_engine(engine: 'Engine') -> None: + engine.close() + await engine.wait_closed() -async def _create_engine(minsize=1, maxsize=10, loop=None, - dialect=_dialect, pool_recycle=-1, - compiled_cache=None, **kwargs): +async def _close_connection(c: SAConnection) -> None: + await c.close() + +# todo: Update Any to stricter kwarg +# https://github.com/python/mypy/issues/4441 +async def _create_engine( + minsize: int = 1, + maxsize: int = 10, + loop: Optional[asyncio.AbstractEventLoop] = None, + dialect: Dialect = _dialect, + pool_recycle: int = -1, + compiled_cache: Optional[Dict[str, Any]] = None, + **kwargs: Any +) -> Engine: if loop is None: loop = asyncio.get_event_loop() pool = await aiomysql.create_pool(minsize=minsize, maxsize=maxsize, @@ -78,7 +128,7 @@ async def _create_engine(minsize=1, maxsize=10, loop=None, try: return Engine(dialect, pool, compiled_cache=compiled_cache, **kwargs) finally: - pool.release(conn) + await pool.release(conn) class Engine: @@ -90,52 +140,60 @@ class Engine: create_engine coroutine. 
""" - def __init__(self, dialect, pool, compiled_cache=None, **kwargs): + # todo: Update Any to stricter kwarg + # https://github.com/python/mypy/issues/4441 + def __init__( + self, + dialect: Dialect, + pool: Any, + compiled_cache: Any = None, + **kwargs: Any + ) -> None: self._dialect = dialect self._pool = pool self._compiled_cache = compiled_cache self._conn_kw = kwargs @property - def dialect(self): + def dialect(self) -> Dialect: """An dialect for engine.""" return self._dialect @property - def name(self): + def name(self) -> Dialect.name: """A name of the dialect.""" return self._dialect.name @property - def driver(self): + def driver(self) -> Dialect.driver: """A driver of the dialect.""" return self._dialect.driver @property - def minsize(self): + def minsize(self) -> int: return self._pool.minsize @property - def maxsize(self): + def maxsize(self) -> int: return self._pool.maxsize @property - def size(self): + def size(self) -> int: return self._pool.size @property - def freesize(self): + def freesize(self) -> int: return self._pool.freesize - def close(self): + def close(self) -> None: """Close engine. Mark all engine connections to be closed on getting back to pool. - Closed engine doesn't allow to acquire new connections. + Closed engine doesn't allow acquiring new connections. """ self._pool.close() - def terminate(self): + def terminate(self) -> None: """Terminate engine. 
Terminate engine pool with instantly closing all acquired @@ -143,22 +201,21 @@ def terminate(self): """ self._pool.terminate() - async def wait_closed(self): + async def wait_closed(self) -> None: """Wait for closing all engine's connections.""" await self._pool.wait_closed() - def acquire(self): + def acquire(self) -> _ContextManager: """Get a connection from pool.""" coro = self._acquire() - return _EngineAcquireContextManager(coro, self) + return _ContextManager[SAConnection](coro, _close_connection) - async def _acquire(self): + async def _acquire(self) -> SAConnection: raw = await self._pool.acquire() - conn = SAConnection(raw, self, compiled_cache=self._compiled_cache) - return conn + return SAConnection(raw, self, compiled_cache=self._compiled_cache) - def release(self, conn): - """Revert back connection to pool.""" + def release(self, conn: SAConnection) -> None: + """Revert connection to pool.""" if conn.in_transaction: raise InvalidRequestError("Cannot release a connection with " "not finished transaction") @@ -167,41 +224,47 @@ def release(self, conn): def __enter__(self): raise RuntimeError( - '"yield from" should be used as context manager expression') - - def __exit__(self, *args): + '"await" should be used as context manager expression') + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType] + ) -> None: # This must exist because __enter__ exists, even though that # always raises; that's how the with-statement works. pass # pragma: nocover - def __iter__(self): - # This is not a coroutine. It is meant to enable the idiom: + async def __aiter__(self) -> '_ConnectionContextManager': + # This is not a coroutine. 
It is meant to enable the idiom: # - # with (yield from engine) as conn: + # async with engine as conn: # # # as an alternative to: # - # conn = yield from engine.acquire() + # conn = await engine.acquire() # try: # # finally: # engine.release(conn) - conn = yield from self.acquire() + conn = await self.acquire() return _ConnectionContextManager(self, conn) - async def __aenter__(self): + async def __aenter__(self) -> 'Engine': return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: Optional[type], + exc_val: Optional[BaseException], + exc_tb: Optional[Any] + ) -> None: self.close() await self.wait_closed() -_EngineContextManager = _PoolContextManager -_EngineAcquireContextManager = _PoolAcquireContextManager - - class _ConnectionContextManager: """Context manager. @@ -219,15 +282,19 @@ class _ConnectionContextManager: __slots__ = ('_engine', '_conn') - def __init__(self, engine, conn): + def __init__( + self, + engine: Engine, + conn: SAConnection + ): self._engine = engine self._conn = conn - def __enter__(self): + def __enter__(self) -> SAConnection: assert self._conn is not None return self._conn - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: try: self._engine.release(self._conn) finally: diff --git a/aiomysql/sa/transaction.py b/aiomysql/sa/transaction.py index ff15ac08..9ef4643c 100644 --- a/aiomysql/sa/transaction.py +++ b/aiomysql/sa/transaction.py @@ -1,5 +1,12 @@ # ported from: # https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/transaction.py +from types import TracebackType +from typing import ( + Any, + Optional, + Type +) + from . import exc @@ -26,22 +33,26 @@ class Transaction(object): SAConnection.begin_nested(). 
""" - def __init__(self, connection, parent): + def __init__( + self, + connection: Any, + parent: Optional['Transaction'] + ) -> None: self._connection = connection self._parent = parent or self self._is_active = True @property - def is_active(self): + def is_active(self) -> bool: """Return ``True`` if a transaction is active.""" return self._is_active @property - def connection(self): + def connection(self) -> Any: """Return transaction's connection (SAConnection instance).""" return self._connection - async def close(self): + async def close(self) -> None: """Close this transaction. If this transaction is the base transaction in a begin/commit @@ -58,17 +69,17 @@ async def close(self): else: self._is_active = False - async def rollback(self): + async def rollback(self) -> None: """Roll back this transaction.""" if not self._parent._is_active: return await self._do_rollback() self._is_active = False - async def _do_rollback(self): + async def _do_rollback(self) -> None: await self._parent.rollback() - async def commit(self): + async def commit(self) -> None: """Commit this transaction.""" if not self._parent._is_active: @@ -76,13 +87,18 @@ async def commit(self): await self._do_commit() self._is_active = False - async def _do_commit(self): + async def _do_commit(self) -> None: pass - async def __aenter__(self): + async def __aenter__(self) -> 'Transaction': return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType] + ) -> None: if exc_type: await self.rollback() else: @@ -92,13 +108,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): class RootTransaction(Transaction): - def __init__(self, connection): + def __init__( + self, + connection: Any + ) -> None: super().__init__(connection, None) - async def _do_rollback(self): + async def _do_rollback(self) -> None: await self._connection._rollback_impl() - 
async def _do_commit(self): + async def _do_commit(self) -> None: await self._connection._commit_impl() @@ -111,18 +130,22 @@ class NestedTransaction(Transaction): The interface is the same as that of Transaction class. """ - _savepoint = None + _savepoint: Optional[Any] = None - def __init__(self, connection, parent): + def __init__( + self, + connection: Any, + parent: Optional['Transaction'] + ): super(NestedTransaction, self).__init__(connection, parent) - async def _do_rollback(self): + async def _do_rollback(self) -> None: assert self._savepoint is not None, "Broken transaction logic" if self._is_active: await self._connection._rollback_to_savepoint_impl( self._savepoint, self._parent) - async def _do_commit(self): + async def _do_commit(self) -> None: assert self._savepoint is not None, "Broken transaction logic" if self._is_active: await self._connection._release_savepoint_impl( @@ -139,17 +162,21 @@ class TwoPhaseTransaction(Transaction): with the addition of the .prepare() method. """ - def __init__(self, connection, xid): + def __init__( + self, + connection: Any, + xid: Any + ) -> None: super().__init__(connection, None) self._is_prepared = False self._xid = xid @property - def xid(self): + def xid(self) -> 'xid': """Returns twophase transaction id.""" return self._xid - async def prepare(self): + async def prepare(self) -> None: """Prepare this TwoPhaseTransaction. After a PREPARE, the transaction can be committed. 
@@ -160,10 +187,10 @@ async def prepare(self): await self._connection._prepare_twophase_impl(self._xid) self._is_prepared = True - async def _do_rollback(self): + async def _do_rollback(self) -> None: await self._connection.rollback_prepared( self._xid, is_prepared=self._is_prepared) - async def _do_commit(self): + async def _do_commit(self) -> None: await self._connection.commit_prepared( self._xid, is_prepared=self._is_prepared) diff --git a/aiomysql/utils.py b/aiomysql/utils.py index 74ad99a7..61640237 100644 --- a/aiomysql/utils.py +++ b/aiomysql/utils.py @@ -1,44 +1,44 @@ -from collections.abc import Coroutine - import struct - - -def _pack_int24(n): - return struct.pack(" 'Any': return self._coro.send(value) - def throw(self, typ, val=None, tb=None): + def throw( + self, + typ: Type[BaseException], + val: Optional[BaseException] = None, + tb: Optional[TracebackType] = None + ) -> Any: if val is None: return self._coro.throw(typ) elif tb is None: @@ -46,142 +46,87 @@ def throw(self, typ, val=None, tb=None): else: return self._coro.throw(typ, val, tb) - def close(self): + def close(self) -> None: return self._coro.close() - @property - def gi_frame(self): - return self._coro.gi_frame - - @property - def gi_running(self): - return self._coro.gi_running - - @property - def gi_code(self): - return self._coro.gi_code - - def __next__(self): - return self.send(None) + async def __anext__(self) -> _Tobj: + try: + value = self._coro.send(None) + except StopAsyncIteration: + self._obj = None + raise + else: + return value - def __iter__(self): - return self._coro.__await__() + def __aiter__(self) -> AsyncGenerator[None, _Tobj]: + return self._obj - def __await__(self): + def __await__(self) -> Generator[Any, None, _Tobj]: return self._coro.__await__() - async def __aenter__(self): + async def __aenter__(self) -> _Tobj: self._obj = await self._coro + assert self._obj return self._obj - async def __aexit__(self, exc_type, exc, tb): - await self._obj.close() - self._obj = 
None - - -class _ConnectionContextManager(_ContextManager): - async def __aexit__(self, exc_type, exc, tb): - if exc_type is not None: - self._obj.close() - else: - await self._obj.ensure_closed() - self._obj = None + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + try: + if exc_type is not None and exc is not None and tb is not None: + await self._release_on_exception(self._obj) + else: + await self._release(self._obj) + finally: + await self._obj.close() + self._obj = None -class _PoolContextManager(_ContextManager): - async def __aexit__(self, exc_type, exc, tb): - self._obj.close() - await self._obj.wait_closed() - self._obj = None +class _IterableContextManager(_ContextManager[_Tobj]): + __slots__ = () + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) -class _SAConnectionContextManager(_ContextManager): - def __aiter__(self): + def __aiter__(self) -> '_IterableContextManager[_Tobj]': return self - async def __anext__(self): + async def __anext__(self) -> _Tobj: if self._obj is None: self._obj = await self._coro try: - return await self._obj.__anext__() + return await self._obj.__anext__() # type: ignore except StopAsyncIteration: - await self._obj.close() - self._obj = None + try: + await self._release(self._obj) + finally: + self._obj = None raise -class _TransactionContextManager(_ContextManager): - async def __aexit__(self, exc_type, exc, tb): - if exc_type: - await self._obj.rollback() - else: - if self._obj.is_active: - await self._obj.commit() - self._obj = None - - -class _PoolAcquireContextManager(_ContextManager): - - __slots__ = ('_coro', '_conn', '_pool') - - def __init__(self, coro, pool): - self._coro = coro - self._conn = None - self._pool = pool - - async def __aenter__(self): - self._conn = await self._coro - return self._conn - - async def __aexit__(self, exc_type, exc, tb): - try: - await 
self._pool.release(self._conn) - finally: - self._pool = None - self._conn = None - - -class _PoolConnectionContextManager: - """Context manager. - - This enables the following idiom for acquiring and releasing a - connection around a block: - - with (yield from pool) as conn: - cur = yield from conn.cursor() - - while failing loudly when accidentally using: - - with pool: - - """ - - __slots__ = ('_pool', '_conn') - - def __init__(self, pool, conn): - self._pool = pool - self._conn = conn - - def __enter__(self): - assert self._conn - return self._conn - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - self._pool.release(self._conn) - finally: - self._pool = None - self._conn = None +def _pack_int24(n: int) -> bytes: + return struct.pack(" bytes: + if i < 0: + raise ValueError( + "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i + ) + elif i < 0xFB: + return bytes([i]) + elif i < (1 << 16): + return b"\xfc" + struct.pack(" Date: Fri, 5 May 2023 14:53:41 +0300 Subject: [PATCH 2/4] update flake8 --- .flake8 | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 2bcd70e3..43fbd362 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,4 @@ [flake8] -max-line-length = 88 +max-line-length = 128 +exclude = venv/*, tests/*, docs/* + From facede80b599af0fc34b3ac426b8e38fe9a771e3 Mon Sep 17 00:00:00 2001 From: dromanov Date: Fri, 5 May 2023 15:11:58 +0300 Subject: [PATCH 3/4] =?UTF-8?q?=E2=8F=AA=EF=B8=8F=20Revert=20changes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiomysql/__init__.py | 52 ++++++++++++++++++++------------------- aiomysql/connection.py | 13 ++++------ aiomysql/cursors.py | 30 +++++++++++++--------- aiomysql/sa/__init__.py | 16 +++++------- aiomysql/sa/connection.py | 3 +-- aiomysql/sa/result.py | 7 +++--- 6 files changed, 61 insertions(+), 60 deletions(-) diff --git a/aiomysql/__init__.py b/aiomysql/__init__.py index f0cb58c4..4d7e3ab9 
100644 --- a/aiomysql/__init__.py +++ b/aiomysql/__init__.py @@ -23,8 +23,6 @@ """ -from typing import List, Type - from pymysql.converters import escape_dict, escape_sequence, escape_string from pymysql.err import ( Warning, @@ -47,30 +45,34 @@ __version__ = version -__all__: List[Type] = [ +__all__ = [ + # Errors - Error, - DataError, - DatabaseError, - IntegrityError, - InterfaceError, - InternalError, - MySQLError, - NotSupportedError, - OperationalError, - ProgrammingError, - Warning, + 'Error', + 'DataError', + 'DatabaseError', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'MySQLError', + 'NotSupportedError', + 'OperationalError', + 'ProgrammingError', + 'Warning', - escape_dict, - escape_sequence, - escape_string, + 'escape_dict', + 'escape_sequence', + 'escape_string', - Connection, - Pool, - connect, - create_pool, - Cursor, - SSCursor, - DictCursor, - SSDictCursor, + 'Connection', + 'Pool', + 'connect', + 'create_pool', + 'Cursor', + 'SSCursor', + 'DictCursor', + 'SSDictCursor' ] + +(Connection, Pool, connect, create_pool, Cursor, SSCursor, DictCursor, + SSDictCursor) # pyflakes diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 0b6eef47..9caafefd 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -512,8 +512,7 @@ def literal(self, obj: Union[str, bytes, Any]) -> str: return self.escape(obj) def escape_string(self, s: str) -> str: - if (self.server_status & - SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES): + if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES: return s.replace("'", "''") return escape_string(s) @@ -539,8 +538,7 @@ def cursor(self, *cursors): if cursors and len(cursors) == 1: cur = cursors[0](self, self._echo) elif cursors: - cursor_name = ''.join(map(lambda x: x.__name__, cursors)) \ - .replace('Cursor', '') + 'Cursor' + cursor_name = ''.join(map(lambda x: x.__name__, cursors)).replace('Cursor', '') + 'Cursor' cursor_class = type(cursor_name, cursors, {}) cur = 
cursor_class(self, self._echo) else: @@ -940,8 +938,7 @@ async def _request_authentication(self): # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest auth_packet.read_uint8() # 0xfe packet identifier plugin_name = auth_packet.read_string() - if (self.server_capabilities & CLIENT.PLUGIN_AUTH and - plugin_name is not None): + if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None: await self._process_auth(plugin_name, auth_packet) else: # send legacy handshake @@ -1325,7 +1322,7 @@ async def _read_result_packet(self, first_packet) -> None: await self._get_descriptions() await self._read_rowdata_packet() - async def _read_rowdata_packet_unbuffered(self) -> Optional[tuple[Optional[Any], ...]]: + async def _read_rowdata_packet_unbuffered(self) -> Optional[Tuple[Optional[Any], ...]]: # Check if in an active query if not self.unbuffered_active: return @@ -1383,7 +1380,7 @@ async def _read_rowdata_packet(self) -> None: self.affected_rows = len(rows) self.rows = tuple(rows) - def _read_row_from_packet(self, packet) -> tuple[Optional[Any], ...]: + def _read_row_from_packet(self, packet) -> Tuple[Optional[Any], ...]: row = [] for encoding, converter in self.converters: try: diff --git a/aiomysql/cursors.py b/aiomysql/cursors.py index 3401bdbf..854be749 100644 --- a/aiomysql/cursors.py +++ b/aiomysql/cursors.py @@ -1,21 +1,29 @@ -import re +# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18 +import contextlib import json +import re import warnings -import contextlib from pymysql.err import ( - Warning, Error, InterfaceError, DataError, - DatabaseError, OperationalError, IntegrityError, InternalError, - NotSupportedError, ProgrammingError) + Warning, + Error, + InterfaceError, + DataError, + DatabaseError, + OperationalError, + IntegrityError, + InternalError, + NotSupportedError, + ProgrammingError +) -from .log import logger from .connection import FIELD_TYPE - -# 
https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18 +from .log import logger #: Regular expression for :meth:`Cursor.executemany`. #: executemany only supports simple bulk insert. #: You can use it to load large dataset. +# flake8: noqa RE_INSERT_VALUES = re.compile( r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + @@ -155,7 +163,7 @@ async def close(self): if conn is None: return try: - while (await self.nextset()): + while await self.nextset(): pass finally: self._connection = None @@ -198,7 +206,7 @@ def _escape_args(self, args, conn): elif isinstance(args, dict): return dict((key, conn.escape(val)) for (key, val) in args.items()) else: - # If it's not a dictionary let's try escaping it anyways. + # If it's not a dictionary let's try escaping it anyway. # Worst case it will throw a Value error return conn.escape(args) @@ -230,7 +238,7 @@ async def execute(self, query, args=None): """ conn = self._get_db() - while (await self.nextset()): + while await self.nextset(): pass if args is not None: diff --git a/aiomysql/sa/__init__.py b/aiomysql/sa/__init__.py index d679f4a0..281646a1 100644 --- a/aiomysql/sa/__init__.py +++ b/aiomysql/sa/__init__.py @@ -11,13 +11,9 @@ ResourceClosedError ) -__all__ = [ - "SAConnection", - "Error", - "ArgumentError", - "InvalidRequestError", - "NoSuchColumnError", - "ResourceClosedError", - "create_engine", - "Engine" -] +__all__ = ('create_engine', 'SAConnection', 'Error', + 'ArgumentError', 'InvalidRequestError', 'NoSuchColumnError', + 'ResourceClosedError', 'Engine') + +(SAConnection, Error, ArgumentError, InvalidRequestError, + NoSuchColumnError, ResourceClosedError, create_engine, Engine) diff --git a/aiomysql/sa/connection.py b/aiomysql/sa/connection.py index 47202494..f641bd58 100644 --- a/aiomysql/sa/connection.py +++ b/aiomysql/sa/connection.py @@ -484,8 +484,7 @@ def _distill_params(multiparams: Any, params: Any) -> List[Union[Dict[str, Any], # 
execute(stmt, "value") return [[zero]] else: - if (hasattr(multiparams[0], '__iter__') and - not hasattr(multiparams[0], 'strip')): + if hasattr(multiparams[0], '__iter__') and not hasattr(multiparams[0], 'strip'): return multiparams else: return [multiparams] diff --git a/aiomysql/sa/result.py b/aiomysql/sa/result.py index f34d3ff9..24ee7b63 100644 --- a/aiomysql/sa/result.py +++ b/aiomysql/sa/result.py @@ -180,17 +180,16 @@ def _key_fallback(self, key, raiseerr=True): # or colummn('name') constructs to ColumnElements, or after a # pickle/unpickle roundtrip elif isinstance(key, expression.ColumnElement): - if (key._label and key._label in map): + if key._label and key._label in map: result = map[key._label] - elif (hasattr(key, 'name') and key.name in map): + elif hasattr(key, 'name') and key.name in map: # match is only on name. result = map[key.name] # search extra hard to make sure this # isn't a column/label name overlap. # this check isn't currently available if the row # was unpickled. 
- if (result is not None and - result[1] is not None): + if result is not None and result[1] is not None: for obj in result[1]: if key._compare_name_for_result(obj): break From 0a13cc4a545f349b3c908e2c701654caefc78bf8 Mon Sep 17 00:00:00 2001 From: dromanov Date: Sun, 7 May 2023 14:00:36 +0300 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=9A=91=EF=B8=8F=20Fix=20bug=20with=20?= =?UTF-8?q?circular=20import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiomysql/__init__.py | 3 +- aiomysql/connection.py | 1126 ++++++++++++++++++++++++------ aiomysql/cursors.py | 714 ------------------- aiomysql/pool.py | 103 ++- aiomysql/sa/connection.py | 10 +- aiomysql/sa/engine.py | 57 +- aiomysql/sa/exc.py | 12 +- aiomysql/sa/result.py | 37 +- aiomysql/sa/transaction.py | 18 +- aiomysql/utils.py | 92 ++- tests/test_connection.py | 4 +- tests/test_cursor.py | 6 +- tests/test_deserialize_cursor.py | 24 +- tests/test_dictcursor.py | 6 +- tests/test_issues.py | 2 +- tests/test_sha_connection.py | 2 +- tests/test_sscursor.py | 2 +- 17 files changed, 1162 insertions(+), 1056 deletions(-) delete mode 100644 aiomysql/cursors.py diff --git a/aiomysql/__init__.py b/aiomysql/__init__.py index 4d7e3ab9..82e2e217 100644 --- a/aiomysql/__init__.py +++ b/aiomysql/__init__.py @@ -40,8 +40,7 @@ from aiomysql.pool import create_pool, Pool from ._version import version -from .connection import Connection, connect -from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor +from .connection import Connection, connect, SSDictCursor, Cursor, SSCursor, DictCursor __version__ = version diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 9caafefd..00569ee3 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -1,25 +1,19 @@ # Python implementation of the MySQL client-server protocol # http://dev.mysql.com/doc/internals/en/client-server-protocol.html - +# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18 
import asyncio import configparser +import contextlib import getpass +import json import os +import re import socket import struct import sys import warnings from functools import partial -from typing import ( - Union, - Mapping, - Any, - Optional, - List, - Tuple, - Dict, - Type -) +from typing import Optional, Union from pymysql.charset import ( charset_by_name, @@ -67,7 +61,6 @@ ProgrammingError ) -from .cursors import Cursor from .log import logger from .utils import ( _pack_int24, @@ -80,8 +73,718 @@ except KeyError: DEFAULT_USER = "unknown" +#: Regular expression for :meth:`Cursor.executemany`. +#: executemany only supports simple bulk insert. +#: You can use it to load large dataset. +# flake8: noqa +RE_INSERT_VALUES = re.compile( + r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" + + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", + re.IGNORECASE | re.DOTALL) + + +class Cursor: + """Cursor is used to interact with the database.""" + + #: Max statement size which :meth:`executemany` generates. + #: + #: Max size of allowed statement is max_allowed_packet - + # packet_header_size. + #: Default value of max_allowed_packet is 1048576. + max_stmt_length = 1024000 + + def __init__(self, connection, echo=False): + """ + Do not create an instance of a Cursor yourself. Call + connections.Connection.cursor(). + """ + self._connection = connection + self._loop = self._connection.loop + self._description = None + self._rownumber = 0 + self._rowcount = -1 + self._arraysize = 1 + self._executed = None + self._result = None + self._rows = None + self._lastrowid = None + self._echo = echo + + @property + def connection(self): + """ + This read-only attribute return a reference to the Connection + object on which the cursor was created. + """ + return self._connection + + @property + def description(self): + """ + This read-only attribute is a sequence of 7-item sequences. 
+ + Each of these sequences is a collections.namedtuple containing + information describing one result column: + + 0. name: the name of the column returned. + 1. type_code: the type of the column. + 2. display_size: the actual length of the column in bytes. + 3. internal_size: the size in bytes of the column associated to + this column on the server. + 4. precision: total number of significant digits in columns of + type NUMERIC. None for other types. + 5. scale: count of decimal digits in the fractional part in + columns of type NUMERIC. None for other types. + 6. null_ok: always None as not easy to retrieve from the libpq. + + This attribute will be None for operations that do not + return rows or if the cursor has not had an operation invoked + via the execute() method yet. + """ + return self._description + + @property + def rowcount(self): + """ + Returns the number of rows that has been produced of affected. + + This read-only attribute specifies the number of rows that the + last :meth:`execute` produced (for Data Query Language + statements like SELECT) or affected (for Data Manipulation + Language statements like UPDATE or INSERT). + + The attribute is -1 in case no .execute() has been performed + on the cursor or the row count of the last operation if it + can't be determined by the interface. + """ + return self._rowcount + + @property + def rownumber(self): + """ + Row index. + + This read-only attribute provides the current 0-based index of the + cursor in the result set or ``None`` if the index cannot be + determined. + """ + + return self._rownumber + + @property + def arraysize(self): + """ + How many rows will be returned by fetchmany() call. + + This read/write attribute specifies the number of rows to + fetch at a time with fetchmany(). It defaults to + 1 meaning to fetch a single row at a time. + + """ + return self._arraysize + + @arraysize.setter + def arraysize(self, val): + """ + How many rows will be returned by fetchmany() call. 
+ + This read/write attribute specifies the number of rows to + fetch at a time with fetchmany(). It defaults to + 1 meaning to fetch a single row at a time. + + """ + self._arraysize = val + + @property + def lastrowid(self): + """ + This read-only property returns the value generated for an + AUTO_INCREMENT column by the previous INSERT or UPDATE statement + or None when there is no such value available. For example, + if you perform an INSERT into a table that contains an AUTO_INCREMENT + column, lastrowid returns the AUTO_INCREMENT value for the new row. + """ + return self._lastrowid + + @property + def echo(self): + """Return echo mode status.""" + return self._echo + + @property + def closed(self): + """ + The readonly property that returns ``True`` if connections was + detached from current cursor + """ + return True if not self._connection else False + + async def close(self): + """Closing a cursor just exhausts all remaining data.""" + conn = self._connection + if conn is None: + return + try: + while await self.nextset(): + pass + finally: + self._connection = None + + def _get_db(self): + if not self._connection: + raise ProgrammingError("Cursor closed") + return self._connection + + def _check_executed(self): + if not self._executed: + raise ProgrammingError("execute() first") + + def _conv_row(self, row): + return row + + def setinputsizes(self, *args): + """Does nothing, required by DB API.""" + + def setoutputsizes(self, *args): + """Does nothing, required by DB API.""" -async def connect( + async def nextset(self): + """Get the next query set""" + conn = self._get_db() + current_result = self._result + if current_result is None or current_result is not conn._result: + return + if not current_result.has_next: + return + self._result = None + self._clear_result() + await conn.next_result() + await self._do_get_result() + return True + + def _escape_args(self, args, conn): + if isinstance(args, (tuple, list)): + return tuple(conn.escape(arg) for arg 
in args) + elif isinstance(args, dict): + return dict((key, conn.escape(val)) for (key, val) in args.items()) + else: + # If it's not a dictionary let's try escaping it anyway. + # Worst case it will throw a Value error + return conn.escape(args) + + def mogrify(self, query, args=None): + """ + Returns the exact string that is sent to the database by calling + to execute() method. This method follows the extension to the DB + API 2.0 followed by Psycopg. + + :param query: ``str`` sql statement + :param args: ``tuple`` or ``list`` of arguments for sql query + """ + conn = self._get_db() + if args is not None: + query = query % self._escape_args(args, conn) + return query + + async def execute(self, query, args=None): + """ + Executes the given operation + + Executes the given operation substituting any markers with + the given parameters. + + For example, getting all rows where id is 5: + cursor.execute("SELECT * FROM t1 WHERE id = %s", (5,)) + + :param query: ``str`` sql statement + :param args: ``tuple`` or ``list`` of arguments for sql query + :returns: ``int``, number of rows that has been produced of affected + """ + conn = self._get_db() + + while await self.nextset(): + pass + + if args is not None: + query = query % self._escape_args(args, conn) + + await self._query(query) + self._executed = query + if self._echo: + logger.info(query) + logger.info("%r", args) + return self._rowcount + + async def executemany(self, query, args): + """ + Execute the given operation multiple times + + The executemany() method will execute the operation iterating + over the list of parameters in seq_params. + + Example: Inserting 3 new employees and their phone number + + data = [ + ('Jane','555-001'), + ('Joe', '555-001'), + ('John', '555-003') + ] + stmt = "INSERT INTO employees (name, phone) VALUES ('%s','%s')" + await cursor.executemany(stmt, data) + + INSERT or REPLACE statements are optimized by batching the data, + that is using the MySQL multiple rows syntax. 
+ + :param query: `str`, sql statement + :param args: ``tuple`` or ``list`` of arguments for sql query + """ + if not args: + return + + if self._echo: + logger.info("CALL %s", query) + logger.info("%r", args) + + m = RE_INSERT_VALUES.match(query) + if m: + q_prefix = m.group(1) % () + q_values = m.group(2).rstrip() + q_postfix = m.group(3) or '' + assert q_values[0] == '(' and q_values[-1] == ')' + return (await self._do_execute_many( + q_prefix, q_values, q_postfix, args, self.max_stmt_length, + self._get_db().encoding)) + else: + rows = 0 + for arg in args: + await self.execute(query, arg) + rows += self._rowcount + self._rowcount = rows + return self._rowcount + + async def _do_execute_many(self, prefix, values, postfix, args, + max_stmt_length, encoding): + conn = self._get_db() + escape = self._escape_args + if isinstance(prefix, str): + prefix = prefix.encode(encoding) + if isinstance(postfix, str): + postfix = postfix.encode(encoding) + sql = bytearray(prefix) + args = iter(args) + v = values % escape(next(args), conn) + if isinstance(v, str): + v = v.encode(encoding, 'surrogateescape') + sql += v + rows = 0 + for arg in args: + v = values % escape(arg, conn) + if isinstance(v, str): + v = v.encode(encoding, 'surrogateescape') + if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: + r = await self.execute(sql + postfix) + rows += r + sql = bytearray(prefix) + else: + sql += b',' + sql += v + r = await self.execute(sql + postfix) + rows += r + self._rowcount = rows + return rows + + async def callproc(self, procname, args=()): + """ + Execute stored procedure procname with args + + Compatibility warning: PEP-249 specifies that any modified + parameters must be returned. This is currently impossible + as they are only available by storing them in a server + variable and then retrieved by a query. Since stored + procedures return zero or more result sets, there is no + reliable way to get at OUT or INOUT parameters via callproc. 
+ The server variables are named @_procname_n, where procname + is the parameter above and n is the position of the parameter + (from zero). Once all result sets generated by the procedure + have been fetched, you can issue a SELECT @_procname_0, ... + query using .execute() to get any OUT or INOUT values. + + Compatibility warning: The act of calling a stored procedure + itself creates an empty result set. This appears after any + result sets generated by the procedure. This is non-standard + behavior with respect to the DB-API. Be sure to use nextset() + to advance through all result sets; otherwise you may get + disconnected. + + :param procname: ``str``, name of procedure to execute on server + :param args: `sequence of parameters to use with procedure + :returns: the original args. + """ + conn = self._get_db() + if self._echo: + logger.info("CALL %s", procname) + logger.info("%r", args) + + for index, arg in enumerate(args): + q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg)) + await self._query(q) + await self.nextset() + + _args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args))) + q = "CALL %s(%s)" % (procname, _args) + await self._query(q) + self._executed = q + return args + + def fetchone(self): + """Fetch the next row """ + self._check_executed() + fut = self._loop.create_future() + + if self._rows is None or self._rownumber >= len(self._rows): + fut.set_result(None) + return fut + result = self._rows[self._rownumber] + self._rownumber += 1 + + fut = self._loop.create_future() + fut.set_result(result) + return fut + + def fetchmany(self, size=None): + """ + Returns the next set of rows of a query result, returning a + list of tuples. When no more rows are available, it returns an + empty list. 
+ + The number of rows returned can be specified using the size argument, + which defaults to one + + :param size: ``int`` number of rows to return + :returns: ``list`` of fetched rows + """ + self._check_executed() + fut = self._loop.create_future() + if self._rows is None: + fut.set_result([]) + return fut + end = self._rownumber + (size or self._arraysize) + result = self._rows[self._rownumber:end] + self._rownumber = min(end, len(self._rows)) + + fut.set_result(result) + return fut + + def fetchall(self): + """ + Returns all rows of a query result set + + :returns: ``list`` of fetched rows + """ + self._check_executed() + fut = self._loop.create_future() + if self._rows is None: + fut.set_result([]) + return fut + + if self._rownumber: + result = self._rows[self._rownumber:] + else: + result = self._rows + self._rownumber = len(self._rows) + + fut.set_result(result) + return fut + + def scroll(self, value, mode='relative'): + """Scroll the cursor in the result set to a new position according + to mode. + + If mode is relative (default), value is taken as offset to the + current position in the result set, if set to absolute, value + states an absolute target position. An IndexError should be raised in + case a scroll operation would leave the result set. In this case, + the cursor position is left undefined (ideal would be to + not move the cursor at all). + + :param int value: move cursor to next position according to mode. 
+ :param str mode: scroll mode, possible modes: `relative` and `absolute` + """ + self._check_executed() + if mode == 'relative': + r = self._rownumber + value + elif mode == 'absolute': + r = value + else: + raise ProgrammingError("unknown scroll mode %s" % mode) + + if not (0 <= r < len(self._rows)): + raise IndexError("out of range") + self._rownumber = r + + fut = self._loop.create_future() + fut.set_result(None) + return fut + + async def _query(self, q): + conn = self._get_db() + self._last_executed = q + self._clear_result() + await conn.query(q) + await self._do_get_result() + + def _clear_result(self): + self._rownumber = 0 + self._result = None + + self._rowcount = 0 + self._description = None + self._lastrowid = None + self._rows = None + + async def _do_get_result(self): + conn = self._get_db() + self._rownumber = 0 + self._result = result = conn._result + self._rowcount = result.affected_rows + self._description = result.description + self._lastrowid = result.insert_id + self._rows = result.rows + + if result.warning_count > 0: + await self._show_warnings(conn) + + async def _show_warnings(self, conn): + if self._result and self._result.has_next: + return + ws = await conn.show_warnings() + if ws is None: + return + for w in ws: + msg = w[-1] + warnings.warn(str(msg), Warning, 4) + + Warning = Warning + Error = Error + InterfaceError = InterfaceError + DatabaseError = DatabaseError + DataError = DataError + OperationalError = OperationalError + IntegrityError = IntegrityError + InternalError = InternalError + ProgrammingError = ProgrammingError + NotSupportedError = NotSupportedError + + def __aiter__(self): + return self + + async def __anext__(self): + ret = await self.fetchone() + if ret is not None: + return ret + else: + raise StopAsyncIteration # noqa + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + return + + +class _DeserializationCursorMixin: + async def 
_do_get_result(self): + await super()._do_get_result() + if self._rows: + self._rows = [self._deserialization_row(r) for r in self._rows] + + def _deserialization_row(self, row): + if row is None: + return None + if isinstance(row, dict): + dict_flag = True + else: + row = list(row) + dict_flag = False + for index, (name, field_type, *n) in enumerate(self._description): + if field_type == FIELD_TYPE.JSON: + point = name if dict_flag else index + with contextlib.suppress(ValueError, TypeError): + row[point] = json.loads(row[point]) + if dict_flag: + return row + else: + return tuple(row) + + def _conv_row(self, row): + if row is None: + return None + row = super()._conv_row(row) + return self._deserialization_row(row) + + +class DeserializationCursor(_DeserializationCursorMixin, Cursor): + """A cursor automatic deserialization of json type fields""" + + +class _DictCursorMixin: + # You can override this to use OrderedDict or other dict-like types. + dict_type = dict + + async def _do_get_result(self): + await super()._do_get_result() + fields = [] + if self._description: + for f in self._result.fields: + name = f.name + if name in fields: + name = f.table_name + '.' + name + fields.append(name) + self._fields = fields + + if fields and self._rows: + self._rows = [self._conv_row(r) for r in self._rows] + + def _conv_row(self, row): + if row is None: + return None + row = super()._conv_row(row) + return self.dict_type(zip(self._fields, row)) + + +class DictCursor(_DictCursorMixin, Cursor): + """A cursor which returns results as a dictionary""" + + +class SSCursor(Cursor): + """ + Unbuffered Cursor, mainly useful for queries that return a lot of + data, or for connections to remote servers over a slow network. + + Instead of copying every row of data into a buffer, this will fetch + rows as needed. The upside of this, is the client uses much less memory, + and rows are returned much faster when traveling over a slow network, + or if the result set is very big. 
+ + There are limitations, though. The MySQL protocol doesn't support + returning the total number of rows, so the only way to tell how many rows + there are is to iterate over every row returned. Also, it currently isn't + possible to scroll backwards, as only the current row is held in memory. + """ + + async def close(self): + conn = self._connection + if conn is None: + return + + if self._result is not None and self._result is conn._result: + await self._result._finish_unbuffered_query() + + try: + while await self.nextset(): + pass + finally: + self._connection = None + + async def _query(self, q): + conn = self._get_db() + self._last_executed = q + await conn.query(q, unbuffered=True) + await self._do_get_result() + return self._rowcount + + async def _read_next(self): + """Read next row """ + row = await self._result._read_rowdata_packet_unbuffered() + row = self._conv_row(row) + return row + + async def fetchone(self): + """ Fetch next row """ + self._check_executed() + row = await self._read_next() + if row is None: + return + self._rownumber += 1 + return row + + async def fetchall(self): + """Fetch all, as per MySQLdb. Pretty useless for large queries, as + it is buffered. + """ + rows = [] + while True: + row = await self.fetchone() + if row is None: + break + rows.append(row) + return rows + + async def fetchmany(self, size=None): + """Returns the next set of rows of a query result, returning a + list of tuples. When no more rows are available, it returns an + empty list. 
+ + The number of rows returned can be specified using the size argument, + which defaults to one + + :param size: ``int`` number of rows to return + :returns: ``list`` of fetched rows + """ + self._check_executed() + if size is None: + size = self._arraysize + + rows = [] + for i in range(size): + row = await self._read_next() + if row is None: + break + rows.append(row) + self._rownumber += 1 + return rows + + async def scroll(self, value, mode='relative'): + """Scroll the cursor in the result set to a new position + according to mode . Same as :meth:`Cursor.scroll`, but move cursor + on server side one by one row. If you want to move 20 rows forward + scroll will make 20 queries to move cursor. Currently, only forward + scrolling is supported. + + :param int value: move cursor to next position according to mode. + :param str mode: scroll mode, possible modes: `relative` and `absolute` + """ + + self._check_executed() + + if mode == 'relative': + if value < 0: + raise NotSupportedError("Backwards scrolling not supported " + "by this cursor") + + for _ in range(value): + await self._read_next() + self._rownumber += value + elif mode == 'absolute': + if value < self._rownumber: + raise NotSupportedError( + "Backwards scrolling not supported by this cursor") + + end = value - self._rownumber + for _ in range(end): + await self._read_next() + self._rownumber = value + else: + raise ProgrammingError(f"unknown scroll {mode}") + + +class SSDictCursor(_DictCursorMixin, SSCursor): + """An unbuffered cursor, which returns results as a dictionary """ + + +def connect( host: str = "localhost", user: Optional[str] = None, password: str = "", @@ -91,10 +794,10 @@ async def connect( charset: str = '', sql_mode: Optional[str] = None, read_default_file: Optional[str] = None, - conv: Optional[Dict[str, Any]] = None, - use_unicode: Optional[bool] = None, + conv: Optional[dict] = decoders, + use_unicode: Optional[Union[bool, str]] = None, client_flag: int = 0, - cursorclass: 
Type[Cursor] = Cursor, + cursorclass: "Cursor" = Cursor, init_command: Optional[str] = None, connect_timeout: Optional[int] = None, read_default_group: Optional[str] = None, @@ -102,50 +805,61 @@ async def connect( echo: bool = False, local_infile: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, - ssl: Optional[Union[bool, Dict[str, Any]]] = None, - auth_plugin: Optional[str] = '', - program_name: Optional[str] = '', - server_public_key: Optional[Union[str, bytes]] = None, -) -> 'Connection': + ssl: Optional[dict] = None, + auth_plugin: str = '', + program_name: str = '', + server_public_key: Optional[bytes] = None +) -> _ContextManager["Connection"]: """See connections.Connection.__init__() for information about defaults.""" - coro = _connect(host=host, user=user, password=password, db=db, - port=port, unix_socket=unix_socket, charset=charset, - sql_mode=sql_mode, read_default_file=read_default_file, - conv=conv, use_unicode=use_unicode, - client_flag=client_flag, cursorclass=cursorclass, - init_command=init_command, - connect_timeout=connect_timeout, - read_default_group=read_default_group, - autocommit=autocommit, echo=echo, - local_infile=local_infile, loop=loop, ssl=ssl, - auth_plugin=auth_plugin, program_name=program_name) - return _ContextManager[Connection](coro, disconnect) - - -async def disconnect(c: "Connection") -> None: - c.close() - - -# todo: Update Any to stricter kwarg -# https://github.com/python/mypy/issues/4441 + coro = _connect( + host=host, + user=user, + password=password, + db=db, + port=port, + unix_socket=unix_socket, + charset=charset, + sql_mode=sql_mode, + read_default_file=read_default_file, + conv=conv, + use_unicode=use_unicode, + client_flag=client_flag, + cursorclass=cursorclass, + init_command=init_command, + connect_timeout=connect_timeout, + read_default_group=read_default_group, + autocommit=autocommit, + echo=echo, + local_infile=local_infile, + loop=loop, + ssl=ssl, + auth_plugin=auth_plugin, + 
program_name=program_name + ) + return _ContextManager[Connection](coro, _disconnect) + + +# noinspection PyUnresolvedReferences +async def _disconnect(c: "Connection"): + await c.close() + + +async def _close_cursor(c: Cursor): + await c.close() + + +# noinspection PyProtectedMember async def _connect(*args, **kwargs): conn = Connection(*args, **kwargs) await conn._connect() return conn -# todo: Update Any to stricter kwarg -# https://github.com/python/mypy/issues/4441 -async def _open_connection( - host: Optional[str] = None, - port: Optional[int] = None, - **kwds: Any -) -> Tuple['_StreamReader', asyncio.StreamWriter]: - """This is based on asyncio.open_connection, allowing us to use a custom +async def _open_connection(host=None, port=None, **kwds): + """ + This is based on asyncio.open_connection, allowing us to use a custom StreamReader. - - `limit` arg has been removed as we don't currently use it. """ loop = asyncio.events.get_running_loop() reader = _StreamReader(loop=loop) @@ -156,18 +870,12 @@ async def _open_connection( return reader, writer -# todo: Update Any to stricter kwarg -# https://github.com/python/mypy/issues/4441 -async def _open_unix_connection( - path: Optional[str] = None, - **kwds: Any -) -> Tuple['_StreamReader', asyncio.StreamWriter]: - """This is based on asyncio.open_unix_connection, allowing us to use a custom +async def _open_unix_connection(path=None, **kwds): + """ + This is based on asyncio.open_unix_connection, allowing us to use a custom StreamReader. - - `limit` arg has been removed as we don't currently use it. 
""" - loop = asyncio.get_running_loop() + loop = asyncio.events.get_running_loop() reader = _StreamReader(loop=loop) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) @@ -178,15 +886,14 @@ async def _open_unix_connection( class _StreamReader(asyncio.StreamReader): - """This StreamReader exposes whether EOF was received, allowing us to + """ + This StreamReader exposes whether EOF was received, allowing us to discard the associated connection instead of returning it from the pool when checking free connections in Pool._fill_free_pool(). - - `limit` arg has been removed as we don't currently use it. """ - def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: - self._eof_received: bool = False + def __init__(self, loop=None): + self._eof_received = False super().__init__(loop=loop) def feed_eof(self) -> None: @@ -194,12 +901,13 @@ def feed_eof(self) -> None: super().feed_eof() @property - def eof_received(self) -> bool: + def eof_received(self): return self._eof_received class Connection: - """Representation of a socket with a mysql server. + """ + Representation of a socket with a mysql server. The proper way to get an instance of this class is to call connect(). 
@@ -207,31 +915,31 @@ class Connection: def __init__( self, - host: str = "localhost", - user: Optional[str] = None, - password: str = "", - db: Optional[str] = None, - port: int = 3306, - unix_socket: Optional[str] = None, - charset: str = '', - sql_mode: Optional[str] = None, - read_default_file: Optional[str] = None, - conv: dict = decoders, - use_unicode: Optional[bool] = None, - client_flag: int = 0, - cursorclass: type = Cursor, - init_command: Optional[str] = None, - connect_timeout: Optional[int] = None, - read_default_group: Optional[str] = None, - autocommit: Optional[bool] = False, - echo: Optional[bool] = False, - local_infile: Optional[bool] = False, - loop: Optional[asyncio.AbstractEventLoop] = None, - ssl: Optional[Union[str, Mapping[str, Any]]] = None, - auth_plugin: Optional[str] = '', - program_name: Optional[str] = '', - server_public_key: Optional[str] = None - ) -> None: + host="localhost", + user=None, + password="", + db=None, + port=3306, + unix_socket=None, + charset='', + sql_mode=None, + read_default_file=None, + conv=decoders, + use_unicode=None, + client_flag=0, + cursorclass=Cursor, + init_command=None, + connect_timeout=None, + read_default_group=None, + autocommit=False, + echo=False, + local_infile=False, + loop=None, + ssl=None, + auth_plugin='', + program_name='', + server_public_key=None + ): """ Establish a connection to the MySQL database. 
Accepts several arguments: @@ -367,69 +1075,70 @@ def __init__( self._close_reason = None @property - def host(self) -> str: + def host(self): """MySQL server IP address or name""" return self._host @property - def port(self) -> int: + def port(self): """MySQL server TCP/IP port""" return self._port @property - def unix_socket(self) -> Optional[str]: + def unix_socket(self): """MySQL Unix socket file location""" return self._unix_socket @property - def db(self) -> Optional[str]: + def db(self): """Current database name.""" return self._db @property - def user(self) -> Optional[str]: + def user(self): """User used while connecting to MySQL""" return self._user @property - def echo(self) -> Optional[bool]: + def echo(self): """Return echo mode status.""" return self._echo @property - def last_usage(self) -> Optional[float]: + def last_usage(self): """Return time() when connection was used.""" return self._last_usage @property - def loop(self) -> Optional[asyncio.AbstractEventLoop]: + def loop(self): return self._loop @property - def closed(self) -> Optional[bool]: - """The readonly property that returns ``True`` if connections is + def closed(self): + """ + The readonly property that returns ``True`` if connections is closed. 
""" return self._writer is None @property - def encoding(self) -> str: + def encoding(self): """Encoding employed for this connection.""" return self._encoding @property - def charset(self) -> str: + def charset(self): """Returns the character set for current connection.""" return self._charset - def close(self) -> None: + def close(self): """Close socket connection""" if self._writer: self._writer.transport.close() self._writer = None self._reader = None - async def ensure_closed(self) -> None: + async def ensure_closed(self): """Send quit command and then close socket connection""" if self._writer is None: # connection has been closed @@ -439,8 +1148,9 @@ async def ensure_closed(self) -> None: await self._writer.drain() self.close() - async def autocommit(self, value) -> None: - """Enable/disable autocommit mode for current MySQL session. + async def autocommit(self, value): + """ + Enable/disable autocommit mode for current MySQL session. :param value: ``bool``, toggle autocommit """ @@ -450,14 +1160,12 @@ async def autocommit(self, value) -> None: await self._send_autocommit_mode() def get_autocommit(self) -> bool: - """Returns autocommit status for current MySQL session. 
- - :returns bool: current autocommit status.""" + """Returns autocommit status for current MySQL session.""" status = self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT return bool(status) - async def _read_ok_packet(self) -> bool: + async def _read_ok_packet(self): pkt = await self._read_packet() if not pkt.is_ok_packet(): raise OperationalError(2014, "Command Out of Sync") @@ -465,65 +1173,67 @@ async def _read_ok_packet(self) -> bool: self.server_status = ok.server_status return True - async def _send_autocommit_mode(self) -> None: + async def _send_autocommit_mode(self): """Set whether to commit after every execute() """ await self._execute_command( COMMAND.COM_QUERY, "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)) await self._read_ok_packet() - async def begin(self) -> None: + async def begin(self): """Begin transaction.""" await self._execute_command(COMMAND.COM_QUERY, "BEGIN") await self._read_ok_packet() - async def commit(self) -> None: + async def commit(self): """Commit changes to stable storage.""" await self._execute_command(COMMAND.COM_QUERY, "COMMIT") await self._read_ok_packet() - async def rollback(self) -> None: + async def rollback(self): """Roll back the current transaction.""" await self._execute_command(COMMAND.COM_QUERY, "ROLLBACK") await self._read_ok_packet() - async def select_db(self, db) -> None: + async def select_db(self, db): """Set current db""" await self._execute_command(COMMAND.COM_INIT_DB, db) await self._read_ok_packet() - async def show_warnings(self) -> List[Tuple[int, str, Optional[int]]]: + async def show_warnings(self): """SHOW WARNINGS""" await self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS") result = MySQLResult(self) await result.read() return result.rows - def escape(self, obj: Union[str, bytes, Any]) -> str: - """ Escape whatever value you pass to it""" + def escape(self, obj): + """Escape whatever value you pass to it""" if isinstance(obj, str): return "'" + self.escape_string(obj) + "'" 
if isinstance(obj, bytes): return escape_bytes_prefixed(obj) return escape_item(obj, self._charset) - def literal(self, obj: Union[str, bytes, Any]) -> str: + def literal(self, obj): """Alias for escape()""" return self.escape(obj) - def escape_string(self, s: str) -> str: - if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES: + def escape_string(self, s): + if (self.server_status & + SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES): return s.replace("'", "''") return escape_string(s) def cursor(self, *cursors): - """Instantiates and returns a cursor + """ + Instantiates and returns a cursor By default, :class:`Cursor` is returned. It is possible to also give a custom cursor through the cursor_class parameter, but it needs to be a subclass of :class:`Cursor` - :param cursor: custom cursor class. + :param cursors: custom cursor class. :returns: instance of cursor, by default :class:`Cursor` :raises TypeError: cursor_class is not a subclass of Cursor. """ @@ -538,7 +1248,8 @@ def cursor(self, *cursors): if cursors and len(cursors) == 1: cur = cursors[0](self, self._echo) elif cursors: - cursor_name = ''.join(map(lambda x: x.__name__, cursors)).replace('Cursor', '') + 'Cursor' + cursor_name = ''.join(map(lambda x: x.__name__, cursors)) \ + .replace('Cursor', '') + 'Cursor' cursor_class = type(cursor_name, cursors, {}) cur = cursor_class(self, self._echo) else: @@ -579,12 +1290,12 @@ async def ping(self, reconnect=True): try: await self._execute_command(COMMAND.COM_PING, "") await self._read_ok_packet() - except Exception: + except Exception as ex: if reconnect: await self._connect() await self.ping(False) else: - raise + raise Error(f"Ping failed and cannot reconnect: {ex}") async def set_charset(self, charset): """Sets the character set for the current connection""" @@ -652,10 +1363,11 @@ async def _connect(self): "Can't connect to MySQL server on %r" % self._host, ) from e - # If e is neither IOError nor OSError, it's a bug. 
- # Raising AssertionError would hide the original error, so we just - # reraise it. - raise + # Raise an error for any other exceptions that occurred during the connection process + raise OperationalError( + CR.CR_CONN_HOST_ERROR, + "Encountered an error while connecting to MySQL server on %r: %r" % (self._host, e), + ) from e def _set_keep_alive(self): transport = self._writer.transport @@ -677,8 +1389,9 @@ def _set_nodelay(self, value): transport.resume_reading() def write_packet(self, payload): - """Writes an entire "mysql packet" in its entirety to the network - addings its length and sequence number. + """ + Writes an entire "mysql packet" in its entirety to the network + adding its length and sequence number. """ # Internal note: when you build packet manually and calls # _write_bytes() directly, you should set self._next_seq_id properly. @@ -687,7 +1400,8 @@ def write_packet(self, payload): self._next_seq_id = (self._next_seq_id + 1) % 256 async def _read_packet(self, packet_type=MysqlPacket): - """Read an entire "mysql packet" in its entirety from the network + """ + Read an entire "mysql packet" in its entirety from the network and return a MysqlPacket type that represents the results. 
""" buff = b'' @@ -755,24 +1469,23 @@ async def _read_bytes(self, num_bytes): def _write_bytes(self, data): return self._writer.write(data) - async def _read_query_result(self, unbuffered: bool = False) -> None: - try: - if unbuffered: + async def _read_query_result(self, unbuffered=False): + result = None + if unbuffered: + try: result = MySQLResult(self) await result.init_unbuffered_query() - else: - result = MySQLResult(self) - await result.read() - - self._result = result - self._affected_rows = result.affected_rows - self.server_status = getattr(result, 'server_status', None) - - except Exception: - self._result = None - self._affected_rows = -1 - self.server_status = None - raise + except BaseException: + result.unbuffered_active = False + result.connection = None + raise + else: + result = MySQLResult(self) + await result.read() + self._result = result + self._affected_rows = result.affected_rows + if result.server_status is not None: + self.server_status = result.server_status def insert_id(self): if self._result: @@ -798,6 +1511,7 @@ async def _execute_command(self, command, sql): if self._result is not None: if self._result.unbuffered_active: warnings.warn("Previous unbuffered result was left incomplete") + # noinspection PyProtectedMember await self._result._finish_unbuffered_query() while self._result.has_next: await self.next_result() @@ -895,6 +1609,7 @@ async def _request_authentication(self): elif auth_plugin in ('', 'mysql_clear_password'): authresp = self._password.encode('latin1') + b'\0' + # noinspection PyUnresolvedReferences if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: data += _lenenc_int(len(authresp)) + authresp elif self.server_capabilities & CLIENT.SECURE_CONNECTION: @@ -910,7 +1625,7 @@ async def _request_authentication(self): else: db = self._db data += db + b'\0' - + # noinspection PyUnresolvedReferences if self.server_capabilities & CLIENT.PLUGIN_AUTH: name = auth_plugin if isinstance(name, str): @@ -920,6 
+1635,7 @@ async def _request_authentication(self): self._auth_plugin_used = auth_plugin # Sends the server a few pieces of client info + # noinspection PyUnresolvedReferences if self.server_capabilities & CLIENT.CONNECT_ATTRS: connect_attrs = b'' for k, v in self._connect_attrs.items(): @@ -938,7 +1654,9 @@ async def _request_authentication(self): # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest auth_packet.read_uint8() # 0xfe packet identifier plugin_name = auth_packet.read_string() - if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None: + # noinspection PyUnresolvedReferences + if (self.server_capabilities & CLIENT.PLUGIN_AUTH and + plugin_name is not None): await self._process_auth(plugin_name, auth_packet) else: # send legacy handshake @@ -1076,7 +1794,7 @@ async def caching_sha2_password_auth(self, pkt): pkt = await self._read_packet() pkt.check_error() - async def sha256_password_auth(self, pkt: Any) -> Any: + async def sha256_password_auth(self, pkt): if self._secure: logger.debug("sha256: Sending plain password") data = self._password.encode('latin1') + b'\0' @@ -1118,19 +1836,19 @@ async def sha256_password_auth(self, pkt: Any) -> Any: return pkt # _mysql support - def thread_id(self) -> int: + def thread_id(self): return self.server_thread_id[0] - def character_set_name(self) -> str: + def character_set_name(self): return self._charset - def get_host_info(self) -> str: + def get_host_info(self): return self.host_info - def get_proto_info(self) -> Tuple[int, int, int]: + def get_proto_info(self): return self.protocol_version - async def _get_server_information(self) -> None: + async def _get_server_information(self): i = 0 packet = await self._read_packet() data = packet.get_all_data() @@ -1150,7 +1868,7 @@ async def _get_server_information(self) -> None: self.server_capabilities = struct.unpack('= i + 6: lang, stat, cap_h, salt_len = struct.unpack(' None: i += 1 # AUTH PLUGIN NAME may appear here. 
+ # noinspection PyUnresolvedReferences if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i: # Due to Bug#59453 the auth-plugin-name is missing the terminating # NUL-char in versions prior to 5.5.10 and 5.6.2. @@ -1192,7 +1911,7 @@ async def _get_server_information(self) -> None: else: self._server_auth_plugin = data[i:server_end].decode('latin1') - def get_transaction_status(self) -> bool: + def get_transaction_status(self): return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_IN_TRANS) def get_server_info(self): @@ -1200,18 +1919,18 @@ def get_server_info(self): # Just to always have consistent errors 2 helpers - def _close_on_cancel(self) -> None: + def _close_on_cancel(self): self.close() self._close_reason = "Cancelled during execution" - def _ensure_alive(self) -> None: + def _ensure_alive(self): if not self._writer: if self._close_reason is None: raise InterfaceError("(0, 'Not connected')") else: raise InterfaceError(self._close_reason) - def __del__(self) -> None: + def __del__(self): if self._writer: warnings.warn("Unclosed connection {!r}".format(self), ResourceWarning) @@ -1229,29 +1948,26 @@ def __del__(self) -> None: NotSupportedError = NotSupportedError -async def _close_cursor(c: Cursor) -> None: - await c.close() - - # TODO: move OK and EOF packet parsing/logic into a proper subclass # of MysqlPacket like has been done with FieldDescriptorPacket. 
class MySQLResult: - def __init__(self, connection: Connection) -> None: - self.connection: Connection = connection - self.affected_rows: Optional[int] = None - self.insert_id: Optional[int] = None - self.server_status: Optional[int] = None - self.warning_count: int = 0 - self.message: Optional[str] = None - self.field_count: int = 0 - self.description: Optional[List] = None - self.rows: Optional[List] = None - self.has_next: Optional[bool] = None - self.unbuffered_active: bool = False - - async def read(self) -> None: + def __init__(self, connection): + self.connection = connection + self.affected_rows = None + self.insert_id = None + self.server_status = None + self.warning_count = 0 + self.message = None + self.field_count = 0 + self.description = None + self.rows = None + self.has_next = None + self.unbuffered_active = False + + async def read(self): try: + # noinspection PyProtectedMember first_packet = await self.connection._read_packet() # TODO: use classes for different packet types? @@ -1264,8 +1980,9 @@ async def read(self) -> None: finally: self.connection = None - async def init_unbuffered_query(self) -> None: + async def init_unbuffered_query(self): self.unbuffered_active = True + # noinspection PyProtectedMember first_packet = await self.connection._read_packet() if first_packet.is_ok_packet(): @@ -1285,7 +2002,7 @@ async def init_unbuffered_query(self) -> None: # we set it to this instead of None, which would be preferred. 
self.affected_rows = 18446744073709551615 - def _read_ok_packet(self, first_packet) -> None: + def _read_ok_packet(self, first_packet): ok_packet = OKPacketWrapper(first_packet) self.affected_rows = ok_packet.affected_rows self.insert_id = ok_packet.insert_id @@ -1294,22 +2011,23 @@ def _read_ok_packet(self, first_packet) -> None: self.message = ok_packet.message self.has_next = ok_packet.has_next - async def _read_load_local_packet(self, first_packet) -> None: + async def _read_load_local_packet(self, first_packet): load_packet = LoadLocalPacketWrapper(first_packet) sender = LoadLocalFile(load_packet.filename, self.connection) try: await sender.send_data() except Exception: # Skip ok packet + # noinspection PyProtectedMember await self.connection._read_packet() raise - + # noinspection PyProtectedMember ok_packet = await self.connection._read_packet() if not ok_packet.is_ok_packet(): raise OperationalError(2014, "Commands Out of Sync") self._read_ok_packet(ok_packet) - def _check_packet_is_eof(self, packet) -> bool: + def _check_packet_is_eof(self, packet): if packet.is_eof_packet(): eof_packet = EOFPacketWrapper(packet) self.warning_count = eof_packet.warning_count @@ -1317,16 +2035,16 @@ def _check_packet_is_eof(self, packet) -> bool: return True return False - async def _read_result_packet(self, first_packet) -> None: + async def _read_result_packet(self, first_packet): self.field_count = first_packet.read_length_encoded_integer() await self._get_descriptions() await self._read_rowdata_packet() - async def _read_rowdata_packet_unbuffered(self) -> Optional[Tuple[Optional[Any], ...]]: + async def _read_rowdata_packet_unbuffered(self): # Check if in an active query if not self.unbuffered_active: return - + # noinspection PyProtectedMember packet = await self.connection._read_packet() if self._check_packet_is_eof(packet): self.unbuffered_active = False @@ -1340,12 +2058,13 @@ async def _read_rowdata_packet_unbuffered(self) -> Optional[Tuple[Optional[Any], self.rows 
= (row,) return row - async def _finish_unbuffered_query(self) -> None: + async def _finish_unbuffered_query(self): # After much reading on the MySQL protocol, it appears that there is, # in fact, no way to stop MySQL from sending all the data after # executing a query, so we just spin, and wait for an EOF packet. while self.unbuffered_active: try: + # noinspection PyProtectedMember packet = await self.connection._read_packet() except OperationalError as e: # TODO: replace these numbers with constants when available @@ -1366,10 +2085,11 @@ async def _finish_unbuffered_query(self) -> None: # release reference to kill cyclic reference. self.connection = None - async def _read_rowdata_packet(self) -> None: + async def _read_rowdata_packet(self): """Read a rowdata packet for each data row in the result set.""" rows = [] while True: + # noinspection PyProtectedMember packet = await self.connection._read_packet() if self._check_packet_is_eof(packet): # release reference to kill cyclic reference. 
@@ -1380,7 +2100,7 @@ async def _read_rowdata_packet(self) -> None: self.affected_rows = len(rows) self.rows = tuple(rows) - def _read_row_from_packet(self, packet) -> Tuple[Optional[Any], ...]: + def _read_row_from_packet(self, packet): row = [] for encoding, converter in self.converters: try: @@ -1397,7 +2117,7 @@ def _read_row_from_packet(self, packet) -> Tuple[Optional[Any], ...]: row.append(data) return tuple(row) - async def _get_descriptions(self) -> None: + async def _get_descriptions(self): """Read a column descriptor packet for each column in the result.""" self.fields = [] self.converters = [] @@ -1405,12 +2125,14 @@ async def _get_descriptions(self) -> None: conn_encoding = self.connection.encoding description = [] for i in range(self.field_count): + # noinspection PyProtectedMember field = await self.connection._read_packet( FieldDescriptorPacket) self.fields.append(field) description.append(field.description()) field_type = field.type_code if use_unicode: + # noinspection PyUnresolvedReferences if field_type == FIELD_TYPE.JSON: # When SELECT from JSON column: charset = binary # When SELECT CAST(... 
AS JSON): charset = connection @@ -1436,7 +2158,7 @@ async def _get_descriptions(self) -> None: if converter is through: converter = None self.converters.append((encoding, converter)) - + # noinspection PyProtectedMember eof_packet = await self.connection._read_packet() assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF' self.description = tuple(description) @@ -1450,9 +2172,9 @@ def __init__(self, filename, connection): self._file_object = None self._executor = None # means use default executor - def _open_file(self) -> asyncio.Future: + def _open_file(self): - def opener(filename) -> None: + def opener(filename): try: self._file_object = open(filename, 'rb') except IOError as e: @@ -1462,7 +2184,7 @@ def opener(filename) -> None: fut = self._loop.run_in_executor(self._executor, opener, self.filename) return fut - def _file_read(self, chunk_size: int) -> asyncio.Future: + def _file_read(self, chunk_size): def freader(chunk_size): try: @@ -1482,8 +2204,9 @@ def freader(chunk_size): fut = self._loop.run_in_executor(self._executor, freader, chunk_size) return fut - async def send_data(self) -> None: + async def send_data(self): """Send data packets from the local file to the server""" + # noinspection PyProtectedMember self.connection._ensure_alive() conn = self.connection @@ -1498,8 +2221,9 @@ async def send_data(self) -> None: # TODO: consider drain data conn.write_packet(chunk) except asyncio.CancelledError: + # noinspection PyProtectedMember self.connection._close_on_cancel() - raise + raise asyncio.CancelledError("send_data method was cancelled") from None finally: # send the empty packet to signify we are done sending data conn.write_packet(b"") diff --git a/aiomysql/cursors.py b/aiomysql/cursors.py deleted file mode 100644 index 854be749..00000000 --- a/aiomysql/cursors.py +++ /dev/null @@ -1,714 +0,0 @@ -# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18 -import contextlib -import json -import re -import warnings - 
-from pymysql.err import ( - Warning, - Error, - InterfaceError, - DataError, - DatabaseError, - OperationalError, - IntegrityError, - InternalError, - NotSupportedError, - ProgrammingError -) - -from .connection import FIELD_TYPE -from .log import logger - -#: Regular expression for :meth:`Cursor.executemany`. -#: executemany only supports simple bulk insert. -#: You can use it to load large dataset. -# flake8: noqa -RE_INSERT_VALUES = re.compile( - r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" + - r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + - r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", - re.IGNORECASE | re.DOTALL) - - -class Cursor: - """Cursor is used to interact with the database.""" - - #: Max statement size which :meth:`executemany` generates. - #: - #: Max size of allowed statement is max_allowed_packet - - # packet_header_size. - #: Default value of max_allowed_packet is 1048576. - max_stmt_length = 1024000 - - def __init__(self, connection, echo=False): - """Do not create an instance of a Cursor yourself. Call - connections.Connection.cursor(). - """ - self._connection = connection - self._loop = self._connection.loop - self._description = None - self._rownumber = 0 - self._rowcount = -1 - self._arraysize = 1 - self._executed = None - self._result = None - self._rows = None - self._lastrowid = None - self._echo = echo - - @property - def connection(self): - """This read-only attribute return a reference to the Connection - object on which the cursor was created.""" - return self._connection - - @property - def description(self): - """This read-only attribute is a sequence of 7-item sequences. - - Each of these sequences is a collections.namedtuple containing - information describing one result column: - - 0. name: the name of the column returned. - 1. type_code: the type of the column. - 2. display_size: the actual length of the column in bytes. - 3. internal_size: the size in bytes of the column associated to - this column on the server. - 4. 
precision: total number of significant digits in columns of - type NUMERIC. None for other types. - 5. scale: count of decimal digits in the fractional part in - columns of type NUMERIC. None for other types. - 6. null_ok: always None as not easy to retrieve from the libpq. - - This attribute will be None for operations that do not - return rows or if the cursor has not had an operation invoked - via the execute() method yet. - """ - return self._description - - @property - def rowcount(self): - """Returns the number of rows that has been produced of affected. - - This read-only attribute specifies the number of rows that the - last :meth:`execute` produced (for Data Query Language - statements like SELECT) or affected (for Data Manipulation - Language statements like UPDATE or INSERT). - - The attribute is -1 in case no .execute() has been performed - on the cursor or the row count of the last operation if it - can't be determined by the interface. - """ - return self._rowcount - - @property - def rownumber(self): - """Row index. - - This read-only attribute provides the current 0-based index of the - cursor in the result set or ``None`` if the index cannot be - determined. - """ - - return self._rownumber - - @property - def arraysize(self): - """How many rows will be returned by fetchmany() call. - - This read/write attribute specifies the number of rows to - fetch at a time with fetchmany(). It defaults to - 1 meaning to fetch a single row at a time. - - """ - return self._arraysize - - @arraysize.setter - def arraysize(self, val): - """How many rows will be returned by fetchmany() call. - - This read/write attribute specifies the number of rows to - fetch at a time with fetchmany(). It defaults to - 1 meaning to fetch a single row at a time. 
- - """ - self._arraysize = val - - @property - def lastrowid(self): - """This read-only property returns the value generated for an - AUTO_INCREMENT column by the previous INSERT or UPDATE statement - or None when there is no such value available. For example, - if you perform an INSERT into a table that contains an AUTO_INCREMENT - column, lastrowid returns the AUTO_INCREMENT value for the new row. - """ - return self._lastrowid - - @property - def echo(self): - """Return echo mode status.""" - return self._echo - - @property - def closed(self): - """The readonly property that returns ``True`` if connections was - detached from current cursor - """ - return True if not self._connection else False - - async def close(self): - """Closing a cursor just exhausts all remaining data.""" - conn = self._connection - if conn is None: - return - try: - while await self.nextset(): - pass - finally: - self._connection = None - - def _get_db(self): - if not self._connection: - raise ProgrammingError("Cursor closed") - return self._connection - - def _check_executed(self): - if not self._executed: - raise ProgrammingError("execute() first") - - def _conv_row(self, row): - return row - - def setinputsizes(self, *args): - """Does nothing, required by DB API.""" - - def setoutputsizes(self, *args): - """Does nothing, required by DB API.""" - - async def nextset(self): - """Get the next query set""" - conn = self._get_db() - current_result = self._result - if current_result is None or current_result is not conn._result: - return - if not current_result.has_next: - return - self._result = None - self._clear_result() - await conn.next_result() - await self._do_get_result() - return True - - def _escape_args(self, args, conn): - if isinstance(args, (tuple, list)): - return tuple(conn.escape(arg) for arg in args) - elif isinstance(args, dict): - return dict((key, conn.escape(val)) for (key, val) in args.items()) - else: - # If it's not a dictionary let's try escaping it anyway. 
- # Worst case it will throw a Value error - return conn.escape(args) - - def mogrify(self, query, args=None): - """ Returns the exact string that is sent to the database by calling - the execute() method. This method follows the extension to the DB - API 2.0 followed by Psycopg. - - :param query: ``str`` sql statement - :param args: ``tuple`` or ``list`` of arguments for sql query - """ - conn = self._get_db() - if args is not None: - query = query % self._escape_args(args, conn) - return query - - async def execute(self, query, args=None): - """Executes the given operation - - Executes the given operation substituting any markers with - the given parameters. - - For example, getting all rows where id is 5: - cursor.execute("SELECT * FROM t1 WHERE id = %s", (5,)) - - :param query: ``str`` sql statement - :param args: ``tuple`` or ``list`` of arguments for sql query - :returns: ``int``, number of rows that has been produced of affected - """ - conn = self._get_db() - - while await self.nextset(): - pass - - if args is not None: - query = query % self._escape_args(args, conn) - - await self._query(query) - self._executed = query - if self._echo: - logger.info(query) - logger.info("%r", args) - return self._rowcount - - async def executemany(self, query, args): - """Execute the given operation multiple times - - The executemany() method will execute the operation iterating - over the list of parameters in seq_params. - - Example: Inserting 3 new employees and their phone number - - data = [ - ('Jane','555-001'), - ('Joe', '555-001'), - ('John', '555-003') - ] - stmt = "INSERT INTO employees (name, phone) VALUES ('%s','%s')" - await cursor.executemany(stmt, data) - - INSERT or REPLACE statements are optimized by batching the data, - that is using the MySQL multiple rows syntax. 
- - :param query: `str`, sql statement - :param args: ``tuple`` or ``list`` of arguments for sql query - """ - if not args: - return - - if self._echo: - logger.info("CALL %s", query) - logger.info("%r", args) - - m = RE_INSERT_VALUES.match(query) - if m: - q_prefix = m.group(1) % () - q_values = m.group(2).rstrip() - q_postfix = m.group(3) or '' - assert q_values[0] == '(' and q_values[-1] == ')' - return (await self._do_execute_many( - q_prefix, q_values, q_postfix, args, self.max_stmt_length, - self._get_db().encoding)) - else: - rows = 0 - for arg in args: - await self.execute(query, arg) - rows += self._rowcount - self._rowcount = rows - return self._rowcount - - async def _do_execute_many(self, prefix, values, postfix, args, - max_stmt_length, encoding): - conn = self._get_db() - escape = self._escape_args - if isinstance(prefix, str): - prefix = prefix.encode(encoding) - if isinstance(postfix, str): - postfix = postfix.encode(encoding) - sql = bytearray(prefix) - args = iter(args) - v = values % escape(next(args), conn) - if isinstance(v, str): - v = v.encode(encoding, 'surrogateescape') - sql += v - rows = 0 - for arg in args: - v = values % escape(arg, conn) - if isinstance(v, str): - v = v.encode(encoding, 'surrogateescape') - if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: - r = await self.execute(sql + postfix) - rows += r - sql = bytearray(prefix) - else: - sql += b',' - sql += v - r = await self.execute(sql + postfix) - rows += r - self._rowcount = rows - return rows - - async def callproc(self, procname, args=()): - """Execute stored procedure procname with args - - Compatibility warning: PEP-249 specifies that any modified - parameters must be returned. This is currently impossible - as they are only available by storing them in a server - variable and then retrieved by a query. Since stored - procedures return zero or more result sets, there is no - reliable way to get at OUT or INOUT parameters via callproc. 
- The server variables are named @_procname_n, where procname - is the parameter above and n is the position of the parameter - (from zero). Once all result sets generated by the procedure - have been fetched, you can issue a SELECT @_procname_0, ... - query using .execute() to get any OUT or INOUT values. - - Compatibility warning: The act of calling a stored procedure - itself creates an empty result set. This appears after any - result sets generated by the procedure. This is non-standard - behavior with respect to the DB-API. Be sure to use nextset() - to advance through all result sets; otherwise you may get - disconnected. - - :param procname: ``str``, name of procedure to execute on server - :param args: `sequence of parameters to use with procedure - :returns: the original args. - """ - conn = self._get_db() - if self._echo: - logger.info("CALL %s", procname) - logger.info("%r", args) - - for index, arg in enumerate(args): - q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg)) - await self._query(q) - await self.nextset() - - _args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args))) - q = "CALL %s(%s)" % (procname, _args) - await self._query(q) - self._executed = q - return args - - def fetchone(self): - """Fetch the next row """ - self._check_executed() - fut = self._loop.create_future() - - if self._rows is None or self._rownumber >= len(self._rows): - fut.set_result(None) - return fut - result = self._rows[self._rownumber] - self._rownumber += 1 - - fut = self._loop.create_future() - fut.set_result(result) - return fut - - def fetchmany(self, size=None): - """Returns the next set of rows of a query result, returning a - list of tuples. When no more rows are available, it returns an - empty list. 
- - The number of rows returned can be specified using the size argument, - which defaults to one - - :param size: ``int`` number of rows to return - :returns: ``list`` of fetched rows - """ - self._check_executed() - fut = self._loop.create_future() - if self._rows is None: - fut.set_result([]) - return fut - end = self._rownumber + (size or self._arraysize) - result = self._rows[self._rownumber:end] - self._rownumber = min(end, len(self._rows)) - - fut.set_result(result) - return fut - - def fetchall(self): - """Returns all rows of a query result set - - :returns: ``list`` of fetched rows - """ - self._check_executed() - fut = self._loop.create_future() - if self._rows is None: - fut.set_result([]) - return fut - - if self._rownumber: - result = self._rows[self._rownumber:] - else: - result = self._rows - self._rownumber = len(self._rows) - - fut.set_result(result) - return fut - - def scroll(self, value, mode='relative'): - """Scroll the cursor in the result set to a new position according - to mode. - - If mode is relative (default), value is taken as offset to the - current position in the result set, if set to absolute, value - states an absolute target position. An IndexError should be raised in - case a scroll operation would leave the result set. In this case, - the cursor position is left undefined (ideal would be to - not move the cursor at all). - - :param int value: move cursor to next position according to mode. 
- :param str mode: scroll mode, possible modes: `relative` and `absolute` - """ - self._check_executed() - if mode == 'relative': - r = self._rownumber + value - elif mode == 'absolute': - r = value - else: - raise ProgrammingError("unknown scroll mode %s" % mode) - - if not (0 <= r < len(self._rows)): - raise IndexError("out of range") - self._rownumber = r - - fut = self._loop.create_future() - fut.set_result(None) - return fut - - async def _query(self, q): - conn = self._get_db() - self._last_executed = q - self._clear_result() - await conn.query(q) - await self._do_get_result() - - def _clear_result(self): - self._rownumber = 0 - self._result = None - - self._rowcount = 0 - self._description = None - self._lastrowid = None - self._rows = None - - async def _do_get_result(self): - conn = self._get_db() - self._rownumber = 0 - self._result = result = conn._result - self._rowcount = result.affected_rows - self._description = result.description - self._lastrowid = result.insert_id - self._rows = result.rows - - if result.warning_count > 0: - await self._show_warnings(conn) - - async def _show_warnings(self, conn): - if self._result and self._result.has_next: - return - ws = await conn.show_warnings() - if ws is None: - return - for w in ws: - msg = w[-1] - warnings.warn(str(msg), Warning, 4) - - Warning = Warning - Error = Error - InterfaceError = InterfaceError - DatabaseError = DatabaseError - DataError = DataError - OperationalError = OperationalError - IntegrityError = IntegrityError - InternalError = InternalError - ProgrammingError = ProgrammingError - NotSupportedError = NotSupportedError - - def __aiter__(self): - return self - - async def __anext__(self): - ret = await self.fetchone() - if ret is not None: - return ret - else: - raise StopAsyncIteration # noqa - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - return - - -class _DeserializationCursorMixin: - async def 
_do_get_result(self): - await super()._do_get_result() - if self._rows: - self._rows = [self._deserialization_row(r) for r in self._rows] - - def _deserialization_row(self, row): - if row is None: - return None - if isinstance(row, dict): - dict_flag = True - else: - row = list(row) - dict_flag = False - for index, (name, field_type, *n) in enumerate(self._description): - if field_type == FIELD_TYPE.JSON: - point = name if dict_flag else index - with contextlib.suppress(ValueError, TypeError): - row[point] = json.loads(row[point]) - if dict_flag: - return row - else: - return tuple(row) - - def _conv_row(self, row): - if row is None: - return None - row = super()._conv_row(row) - return self._deserialization_row(row) - - -class DeserializationCursor(_DeserializationCursorMixin, Cursor): - """A cursor automatic deserialization of json type fields""" - - -class _DictCursorMixin: - # You can override this to use OrderedDict or other dict-like types. - dict_type = dict - - async def _do_get_result(self): - await super()._do_get_result() - fields = [] - if self._description: - for f in self._result.fields: - name = f.name - if name in fields: - name = f.table_name + '.' + name - fields.append(name) - self._fields = fields - - if fields and self._rows: - self._rows = [self._conv_row(r) for r in self._rows] - - def _conv_row(self, row): - if row is None: - return None - row = super()._conv_row(row) - return self.dict_type(zip(self._fields, row)) - - -class DictCursor(_DictCursorMixin, Cursor): - """A cursor which returns results as a dictionary""" - - -class SSCursor(Cursor): - """Unbuffered Cursor, mainly useful for queries that return a lot of - data, or for connections to remote servers over a slow network. - - Instead of copying every row of data into a buffer, this will fetch - rows as needed. The upside of this, is the client uses much less memory, - and rows are returned much faster when traveling over a slow network, - or if the result set is very big. 
- - There are limitations, though. The MySQL protocol doesn't support - returning the total number of rows, so the only way to tell how many rows - there are is to iterate over every row returned. Also, it currently isn't - possible to scroll backwards, as only the current row is held in memory. - """ - - async def close(self): - conn = self._connection - if conn is None: - return - - if self._result is not None and self._result is conn._result: - await self._result._finish_unbuffered_query() - - try: - while (await self.nextset()): - pass - finally: - self._connection = None - - async def _query(self, q): - conn = self._get_db() - self._last_executed = q - await conn.query(q, unbuffered=True) - await self._do_get_result() - return self._rowcount - - async def _read_next(self): - """Read next row """ - row = await self._result._read_rowdata_packet_unbuffered() - row = self._conv_row(row) - return row - - async def fetchone(self): - """ Fetch next row """ - self._check_executed() - row = await self._read_next() - if row is None: - return - self._rownumber += 1 - return row - - async def fetchall(self): - """Fetch all, as per MySQLdb. Pretty useless for large queries, as - it is buffered. - """ - rows = [] - while True: - row = await self.fetchone() - if row is None: - break - rows.append(row) - return rows - - async def fetchmany(self, size=None): - """Returns the next set of rows of a query result, returning a - list of tuples. When no more rows are available, it returns an - empty list. 
- - The number of rows returned can be specified using the size argument, - which defaults to one - - :param size: ``int`` number of rows to return - :returns: ``list`` of fetched rows - """ - self._check_executed() - if size is None: - size = self._arraysize - - rows = [] - for i in range(size): - row = await self._read_next() - if row is None: - break - rows.append(row) - self._rownumber += 1 - return rows - - async def scroll(self, value, mode='relative'): - """Scroll the cursor in the result set to a new position - according to mode . Same as :meth:`Cursor.scroll`, but move cursor - on server side one by one row. If you want to move 20 rows forward - scroll will make 20 queries to move cursor. Currently only forward - scrolling is supported. - - :param int value: move cursor to next position according to mode. - :param str mode: scroll mode, possible modes: `relative` and `absolute` - """ - - self._check_executed() - - if mode == 'relative': - if value < 0: - raise NotSupportedError("Backwards scrolling not supported " - "by this cursor") - - for _ in range(value): - await self._read_next() - self._rownumber += value - elif mode == 'absolute': - if value < self._rownumber: - raise NotSupportedError( - "Backwards scrolling not supported by this cursor") - - end = value - self._rownumber - for _ in range(end): - await self._read_next() - self._rownumber = value - else: - raise ProgrammingError("unknown scroll mode %s" % mode) - - -class SSDictCursor(_DictCursorMixin, SSCursor): - """An unbuffered cursor, which returns results as a dictionary """ diff --git a/aiomysql/pool.py b/aiomysql/pool.py index f95aaae6..ff0e0cca 100644 --- a/aiomysql/pool.py +++ b/aiomysql/pool.py @@ -4,12 +4,10 @@ import asyncio import collections import warnings -from types import TracebackType from typing import ( Optional, Any, - Deque, - Type + Deque ) from aiomysql.connection import ( @@ -17,12 +15,10 @@ Connection ) from aiomysql.utils import ( - _ContextManager + _ContextManager, 
_Release, _TObj ) -# todo: Update Any to stricter kwarg -# https://github.com/python/mypy/issues/4441 def create_pool( minsize: int = 1, maxsize: int = 10, @@ -40,8 +36,6 @@ async def _destroy_pool(pool: "Pool") -> None: await pool.wait_closed() -# todo: Update Any to stricter kwarg -# https://github.com/python/mypy/issues/4441 async def _create_pool( minsize: int = 1, maxsize: int = 10, @@ -63,6 +57,79 @@ async def _create_pool( return pool +class _PoolContextManager(_ContextManager): + async def __aexit__(self, exc_type, exc, tb): + self._obj.close() + await self._obj.wait_closed() + self._obj = None + + +class _PoolAcquireContextManager(_ContextManager): + __slots__ = ('_coro', '_conn', '_pool') + + def __init__(self, coro, pool, release: _Release[_TObj]): + super().__init__(coro, release) + self._coro = coro + self._conn = None + self._pool = pool + + async def __aenter__(self): + self._conn = await self._coro + return self._conn + + async def __aexit__(self, exc_type, exc, tb): + try: + await self._pool.release(self._conn) + finally: + self._pool = None + self._conn = None + + +class _PoolConnectionContextManager: + """Context manager. 
+ + This enables the following idiom for acquiring and releasing a + connection around a block: + + with (yield from pool) as conn: + cur = yield from conn.cursor() + + while failing loudly when accidentally using: + + with pool: + + """ + + __slots__ = ('_pool', '_conn') + + def __init__(self, pool, conn): + self._pool = pool + self._conn = conn + + def __enter__(self): + assert self._conn + return self._conn + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + self._pool.release(self._conn) + finally: + self._pool = None + self._conn = None + + async def __aenter__(self): + assert not self._conn + self._conn = await self._pool.acquire() + return self._conn + + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + await self._pool.release(self._conn) + finally: + self._pool = None + self._conn = None + + class Pool(asyncio.AbstractServer): """Connection pool""" @@ -188,6 +255,7 @@ async def _acquire(self) -> Connection: else: await self._cond.wait() + # noinspection PyProtectedMember async def _fill_free_pool(self, override_min: bool) -> None: # iterate over free connections and remove timed out ones free_size = len(self._free) @@ -243,7 +311,8 @@ async def _wakeup(self) -> None: self._cond.notify() def release(self, conn: Any) -> asyncio.Future: - """Release free connection back to the connection pool. + """ + Release free connection back to the connection pool. This is **NOT** a coroutine. """ @@ -270,7 +339,7 @@ def release(self, conn: Any) -> asyncio.Future: def __enter__(self) -> None: raise RuntimeError( - '"yield from" should be used as context manager expression') + '"await" should be used as context manager expression') # todo: Update Any to stricter kwarg # https://github.com/python/mypy/issues/4441 @@ -279,7 +348,7 @@ def __exit__(self, *args: Any) -> None: # always raises; that's how the with-statement works. pass # pragma: nocover - def __iter__(self) -> _ContextManager: + def __iter__(self): # This is not a coroutine. 
It is meant to enable the idiom: # # with (yield from pool) as conn: @@ -293,20 +362,18 @@ def __iter__(self) -> _ContextManager: # finally: # conn.release() conn = yield from self.acquire() - return _ContextManager[Connection](conn, self.release) + return _PoolConnectionContextManager(self, conn) - def __await__(self) -> _ContextManager: + def __await__(self): msg = "with await pool as conn deprecated, use" \ "async with pool.acquire() as conn instead" warnings.warn(msg, DeprecationWarning, stacklevel=2) conn = yield from self.acquire() - return _ContextManager[Connection](conn, self.release) + return _PoolConnectionContextManager(self, conn) - async def __aenter__(self) -> 'Pool': + async def __aenter__(self): return self - async def __aexit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__(self, exc_type, exc_val, exc_tb): self.close() await self.wait_closed() diff --git a/aiomysql/sa/connection.py b/aiomysql/sa/connection.py index f641bd58..f6cec390 100644 --- a/aiomysql/sa/connection.py +++ b/aiomysql/sa/connection.py @@ -16,7 +16,7 @@ from sqlalchemy.sql.ddl import DDLElement from sqlalchemy.sql.dml import UpdateBase -from . import exc, Engine +from . 
import exc from .result import create_result_proxy, ResultProxy from .transaction import ( RootTransaction, @@ -46,7 +46,7 @@ class SAConnection: def __init__( self, connection: Connection, - engine: Engine, + engine, compiled_cache: Optional[Any] = None, ) -> None: self._connection = connection @@ -57,8 +57,6 @@ def __init__( self._dialect = engine.dialect self._compiled_cache = compiled_cache - # todo: Update Any to stricter kwarg - # https://github.com/python/mypy/issues/4441 def execute( self, query: Union[str, DDLElement, ClauseElement], @@ -179,8 +177,6 @@ async def _executemany( self._weak_results.add(ret) return ret - # todo: Update Any to stricter kwarg - # https://github.com/python/mypy/issues/4441 async def _execute( self, query: Union[str, DDLElement, ClauseElement], @@ -236,8 +232,6 @@ async def _execute( self._weak_results.add(ret) return ret - # todo: Update Any to stricter kwarg - # https://github.com/python/mypy/issues/4441 async def scalar( self, query: Union[str, DDLElement, ClauseElement], diff --git a/aiomysql/sa/engine.py b/aiomysql/sa/engine.py index 1367c9eb..53620d32 100644 --- a/aiomysql/sa/engine.py +++ b/aiomysql/sa/engine.py @@ -10,18 +10,13 @@ Union ) -from sqlalchemy import ( - Engine, - Dialect -) - import aiomysql from .connection import SAConnection from .exc import ( InvalidRequestError, ArgumentError ) -from ..cursors import ( +from ..connection import ( Cursor, DeserializationCursor, DictCursor, @@ -65,18 +60,17 @@ def _exec_default(self, default: Any) -> Any: _dialect.default_paramstyle = 'pyformat' -# todo: Update Any to stricter kwarg -# https://github.com/python/mypy/issues/4441 def create_engine( minsize: int = 1, maxsize: int = 10, loop: Optional[asyncio.AbstractEventLoop] = None, - dialect: Dialect = _dialect, + dialect=_dialect, pool_recycle: int = -1, compiled_cache: Optional[Dict[str, Any]] = None, **kwargs: Union[str, int, bool, Any] ): - """A coroutine for Engine creation. + """ + A coroutine for Engine creation. 
Returns Engine instance with embedded connection pool. @@ -108,17 +102,15 @@ async def _close_connection(c: SAConnection) -> None: await c.close() -# todo: Update Any to stricter kwarg -# https://github.com/python/mypy/issues/4441 async def _create_engine( minsize: int = 1, maxsize: int = 10, loop: Optional[asyncio.AbstractEventLoop] = None, - dialect: Dialect = _dialect, + dialect=_dialect, pool_recycle: int = -1, compiled_cache: Optional[Dict[str, Any]] = None, **kwargs: Any -) -> Engine: +): if loop is None: loop = asyncio.get_event_loop() pool = await aiomysql.create_pool(minsize=minsize, maxsize=maxsize, @@ -132,7 +124,8 @@ async def _create_engine( class Engine: - """Connects a aiomysql.Pool and + """ + Connects a aiomysql.Pool and sqlalchemy.engine.interfaces.Dialect together to provide a source of database connectivity and behavior. @@ -140,11 +133,9 @@ class Engine: create_engine coroutine. """ - # todo: Update Any to stricter kwarg - # https://github.com/python/mypy/issues/4441 def __init__( self, - dialect: Dialect, + dialect, pool: Any, compiled_cache: Any = None, **kwargs: Any @@ -155,17 +146,17 @@ def __init__( self._conn_kw = kwargs @property - def dialect(self) -> Dialect: - """An dialect for engine.""" + def dialect(self): + """A dialect for engine.""" return self._dialect @property - def name(self) -> Dialect.name: + def name(self): """A name of the dialect.""" return self._dialect.name @property - def driver(self) -> Dialect.driver: + def driver(self): """A driver of the dialect.""" return self._dialect.driver @@ -186,7 +177,8 @@ def freesize(self) -> int: return self._pool.freesize def close(self) -> None: - """Close engine. + """ + Close engine. Mark all engine connections to be closed on getting back to pool. Closed engine doesn't allow acquiring new connections. @@ -194,7 +186,8 @@ def close(self) -> None: self._pool.close() def terminate(self) -> None: - """Terminate engine. + """ + Terminate engine. 
Terminate engine pool with instantly closing all acquired connections also. @@ -214,13 +207,12 @@ async def _acquire(self) -> SAConnection: raw = await self._pool.acquire() return SAConnection(raw, self, compiled_cache=self._compiled_cache) - def release(self, conn: SAConnection) -> None: + def release(self, conn: SAConnection): """Revert connection to pool.""" if conn.in_transaction: raise InvalidRequestError("Cannot release a connection with " "not finished transaction") - raw = conn.connection - return self._pool.release(raw) + return self._pool.release(conn.connection) def __enter__(self): raise RuntimeError( @@ -266,7 +258,8 @@ async def __aexit__( class _ConnectionContextManager: - """Context manager. + """ + Context manager. This enables the following idiom for acquiring and releasing a connection around a block: @@ -280,8 +273,6 @@ class _ConnectionContextManager: """ - __slots__ = ('_engine', '_conn') - def __init__( self, engine: Engine, @@ -290,13 +281,13 @@ def __init__( self._engine = engine self._conn = conn - def __enter__(self) -> SAConnection: + async def __aenter__(self) -> SAConnection: assert self._conn is not None return self._conn - def __exit__(self, *args: Any) -> None: + async def __aexit__(self, *args: Any) -> None: try: - self._engine.release(self._conn) + await self._engine.release(self._conn) finally: self._engine = None self._conn = None diff --git a/aiomysql/sa/exc.py b/aiomysql/sa/exc.py index f4141f31..ca6ac4ce 100644 --- a/aiomysql/sa/exc.py +++ b/aiomysql/sa/exc.py @@ -6,14 +6,16 @@ class Error(Exception): class ArgumentError(Error): - """Raised when an invalid or conflicting function argument is supplied. + """ + Raised when an invalid or conflicting function argument is supplied. This error generally corresponds to construction time state errors. """ class InvalidRequestError(ArgumentError): - """aiomysql.sa was asked to do something it can't do. + """ + aiomysql.sa was asked to do something it can't do. 
This error generally corresponds to runtime state errors. """ @@ -24,5 +26,7 @@ class NoSuchColumnError(KeyError, InvalidRequestError): class ResourceClosedError(InvalidRequestError): - """An operation was requested from a connection, cursor, or other - object that's in a closed state.""" + """ + An operation was requested from a connection, cursor, or other + object that's in a closed state. + """ diff --git a/aiomysql/sa/result.py b/aiomysql/sa/result.py index 24ee7b63..207efe10 100644 --- a/aiomysql/sa/result.py +++ b/aiomysql/sa/result.py @@ -11,12 +11,12 @@ async def create_result_proxy(connection, cursor, dialect, result_map): result_proxy = ResultProxy(connection, cursor, dialect, result_map) + # noinspection PyProtectedMember await result_proxy._prepare() return result_proxy class RowProxy(Mapping): - __slots__ = ('_result_proxy', '_row', '_processors', '_keymap') def __init__(self, result_proxy, row, processors, keymap): @@ -36,6 +36,7 @@ def __getitem__(self, key): try: processor, obj, index = self._keymap[key] except KeyError: + # noinspection PyProtectedMember processor, obj, index = self._result_proxy._key_fallback(key) # Do we need slicing at all? RowProxy now is Mapping not Sequence # except TypeError: @@ -66,6 +67,7 @@ def __getattr__(self, name): raise AttributeError(e.args[0]) def __contains__(self, key): + # noinspection PyProtectedMember return self._result_proxy._has_key(self._row, key) __hash__ = None @@ -92,6 +94,7 @@ class ResultMetaData: """Handle cursor.description, applying additional info from an execution context.""" + # noinspection PyProtectedMember def __init__(self, result_proxy, metadata): self._processors = processors = [] @@ -170,6 +173,7 @@ def __init__(self, result_proxy, metadata): # high precedence keymap. 
         keymap.update(primary_keymap)
 
+    # noinspection PyProtectedMember
     def _key_fallback(self, key, raiseerr=True):
         map = self._keymap
         result = None
@@ -177,7 +181,7 @@ def _key_fallback(self, key, raiseerr=True):
             result = map.get(key)
         # fallback for targeting a ColumnElement to a textual expression
         # this is a rare use case which only occurs when matching text()
-        # or colummn('name') constructs to ColumnElements, or after a
+        # or column('name') constructs to ColumnElements, or after a
         # pickle/unpickle roundtrip
         elif isinstance(key, expression.ColumnElement):
             if key._label and key._label in map:
@@ -188,7 +192,7 @@ def _key_fallback(self, key, raiseerr=True):
             # search extra hard to make sure this
             # isn't a column/label name overlap.
             # this check isn't currently available if the row
-            # was unpickled.
+            # was unpickled.
             if result is not None and result[1] is not None:
                 for obj in result[1]:
                     if key._compare_name_for_result(obj):
@@ -214,7 +218,8 @@ def _has_key(self, row, key):
 
 
 class ResultProxy:
-    """Wraps a DB-API cursor object to provide easier access to row columns.
+    """
+    Wraps a DB-API cursor object to provide easier access to row columns.
 
     Individual columns may be accessed by their integer position,
     case-insensitive column name, or by sqlalchemy schema.Column
@@ -250,6 +255,7 @@ async def _prepare(self):
 
             def callback(wr):
                 loop.create_task(cursor.close())
+
             self._weak = weakref.ref(self, callback)
         else:
             self._metadata = None
@@ -274,7 +280,8 @@ def keys(self):
 
     @property
     def rowcount(self):
-        """Return the 'rowcount' for this result.
+        """
+        Return the 'rowcount' for this result.
 
         The 'rowcount' reports the number of rows *matched* by the
         WHERE criterion of an UPDATE or DELETE statement.
@@ -315,7 +322,8 @@ def lastrowid(self):
 
     @property
     def returns_rows(self):
-        """True if this ResultProxy returns rows.
+        """
+        True if this ResultProxy returns rows.
 
         I.e. if it is legal to call the methods
         .fetchone(), .fetchmany() and .fetchall()`.
@@ -327,7 +335,8 @@ def closed(self): return self._closed async def close(self): - """Close this ResultProxy. + """ + Close this ResultProxy. Closes the underlying DBAPI cursor corresponding to the execution. @@ -387,7 +396,8 @@ async def fetchall(self): return ret async def fetchone(self): - """Fetch one row, just like DB-API cursor.fetchone(). + """ + Fetch one row, just like DB-API cursor.fetchone(). If a row is present, the cursor remains open after this is called. Else the cursor is automatically closed and None is returned. @@ -404,7 +414,8 @@ async def fetchone(self): return None async def fetchmany(self, size=None): - """Fetch many rows, just like DB-API + """ + Fetch many rows, just like DB-API cursor.fetchmany(size=cursor.arraysize). If rows are present, the cursor remains open after this is called. @@ -424,19 +435,21 @@ async def fetchmany(self, size=None): return ret async def first(self): - """Fetch the first row and then close the result set unconditionally. + """ + Fetch the first row and then close the result set unconditionally. Returns None if no row is present. """ if self._metadata is None: self._non_result() try: - return (await self.fetchone()) + return await self.fetchone() finally: await self.close() async def scalar(self): - """Fetch the first column of the first row, and close the result set. + """ + Fetch the first column of the first row, and close the result set. Returns None if no row is present. """ diff --git a/aiomysql/sa/transaction.py b/aiomysql/sa/transaction.py index 9ef4643c..1ef95f10 100644 --- a/aiomysql/sa/transaction.py +++ b/aiomysql/sa/transaction.py @@ -53,7 +53,8 @@ def connection(self) -> Any: return self._connection async def close(self) -> None: - """Close this transaction. + """ + Close this transaction. If this transaction is the base transaction in a begin/commit nesting, the transaction will rollback(). 
Otherwise, the @@ -115,14 +116,17 @@ def __init__( super().__init__(connection, None) async def _do_rollback(self) -> None: + # noinspection PyProtectedMember await self._connection._rollback_impl() async def _do_commit(self) -> None: + # noinspection PyProtectedMember await self._connection._commit_impl() class NestedTransaction(Transaction): - """Represent a 'nested', or SAVEPOINT transaction. + """ + Represent a 'nested', or SAVEPOINT transaction. A new NestedTransaction object may be procured using the SAConnection.begin_nested() method. @@ -142,18 +146,21 @@ def __init__( async def _do_rollback(self) -> None: assert self._savepoint is not None, "Broken transaction logic" if self._is_active: + # noinspection PyProtectedMember await self._connection._rollback_to_savepoint_impl( self._savepoint, self._parent) async def _do_commit(self) -> None: assert self._savepoint is not None, "Broken transaction logic" if self._is_active: + # noinspection PyProtectedMember await self._connection._release_savepoint_impl( self._savepoint, self._parent) class TwoPhaseTransaction(Transaction): - """Represent a two-phase transaction. + """ + Represent a two-phase transaction. A new TwoPhaseTransaction object may be procured using the SAConnection.begin_twophase() method. @@ -177,13 +184,16 @@ def xid(self) -> 'xid': return self._xid async def prepare(self) -> None: - """Prepare this TwoPhaseTransaction. + """ + Prepare this TwoPhaseTransaction. After a PREPARE, the transaction can be committed. 
""" if not self._parent.is_active: raise exc.InvalidRequestError("This transaction is inactive") + + # noinspection PyProtectedMember await self._connection._prepare_twophase_impl(self._xid) self._is_prepared = True diff --git a/aiomysql/utils.py b/aiomysql/utils.py index 61640237..7616a4d7 100644 --- a/aiomysql/utils.py +++ b/aiomysql/utils.py @@ -1,28 +1,57 @@ +import asyncio import struct +import sys from types import TracebackType from typing import ( - Coroutine, - TypeVar, Any, + Awaitable, + Callable, + Coroutine, + Generator, + Generic, Optional, - Generic, Callable, Awaitable, Type, AsyncGenerator, Generator + Type, + TypeVar, + Union, ) -_Tobj = TypeVar("_Tobj") -_Release = Callable[[_Tobj], Awaitable[None]] +if sys.version_info >= (3, 7, 0): + __get_running_loop = asyncio.get_running_loop +else: + def __get_running_loop() -> asyncio.AbstractEventLoop: + loop = asyncio.get_event_loop() + if not loop.is_running(): + raise RuntimeError('no running event loop') + return loop + + +def get_running_loop() -> asyncio.AbstractEventLoop: + return __get_running_loop() + + +def create_completed_future( + loop: asyncio.AbstractEventLoop +) -> 'asyncio.Future[Any]': + future = loop.create_future() + future.set_result(None) + return future + + +_TObj = TypeVar("_TObj") +_Release = Callable[[_TObj], Awaitable[None]] -class _ContextManager(Coroutine[Any, None, _Tobj], Generic[_Tobj]): +class _ContextManager(Coroutine[Any, None, _TObj], Generic[_TObj]): __slots__ = ('_coro', '_obj', '_release', '_release_on_exception') def __init__( self, - coro: Coroutine[Any, None, _Tobj], - release: _Release[_Tobj], - release_on_exception: Optional[_Release[_Tobj]] = None + coro: Coroutine[Any, None, _TObj], + release: _Release[_TObj], + release_on_exception: Optional[_Release[_TObj]] = None ): self._coro = coro - self._obj: Optional[_Tobj] = None + self._obj: Optional[_TObj] = None self._release = release self._release_on_exception = ( release @@ -33,38 +62,25 @@ def __init__( def 
send(self, value: Any) -> 'Any': return self._coro.send(value) - def throw( + def throw( # type: ignore self, typ: Type[BaseException], - val: Optional[BaseException] = None, + val: Optional[Union[BaseException, object]] = None, tb: Optional[TracebackType] = None ) -> Any: if val is None: return self._coro.throw(typ) - elif tb is None: + if tb is None: return self._coro.throw(typ, val) - else: - return self._coro.throw(typ, val, tb) + return self._coro.throw(typ, val, tb) def close(self) -> None: - return self._coro.close() + self._coro.close() - async def __anext__(self) -> _Tobj: - try: - value = self._coro.send(None) - except StopAsyncIteration: - self._obj = None - raise - else: - return value - - def __aiter__(self) -> AsyncGenerator[None, _Tobj]: - return self._obj - - def __await__(self) -> Generator[Any, None, _Tobj]: + def __await__(self) -> Generator[Any, None, _TObj]: return self._coro.__await__() - async def __aenter__(self) -> _Tobj: + async def __aenter__(self) -> _TObj: self._obj = await self._coro assert self._obj return self._obj @@ -75,26 +91,28 @@ async def __aexit__( exc: Optional[BaseException], tb: Optional[TracebackType], ) -> None: + if self._obj is None: + return + try: - if exc_type is not None and exc is not None and tb is not None: + if exc_type is not None: await self._release_on_exception(self._obj) else: await self._release(self._obj) finally: - await self._obj.close() self._obj = None -class _IterableContextManager(_ContextManager[_Tobj]): +class _IterableContextManager(_ContextManager[_TObj]): __slots__ = () def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - def __aiter__(self) -> '_IterableContextManager[_Tobj]': + def __aiter__(self) -> '_IterableContextManager[_TObj]': return self - async def __anext__(self) -> _Tobj: + async def __anext__(self) -> _TObj: if self._obj is None: self._obj = await self._coro @@ -108,11 +126,11 @@ async def __anext__(self) -> _Tobj: raise -def _pack_int24(n: int) -> 
bytes: +def _pack_int24(n): return struct.pack(" bytes: +def _lenenc_int(i): if i < 0: raise ValueError( "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i diff --git a/tests/test_connection.py b/tests/test_connection.py index c0c1be3d..b3a876d3 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -48,7 +48,7 @@ async def test_config_file(fill_my_cnf, connection_creator, mysql_params): # make sure connection is working cur = await conn.cursor() await cur.execute('SELECT 42;') - (r, ) = await cur.fetchone() + (r,) = await cur.fetchone() assert r == 42 conn.close() @@ -71,7 +71,7 @@ async def test_config_file_with_different_group(fill_my_cnf, # make sure connection is working cur = await conn.cursor() await cur.execute('SELECT 42;') - (r, ) = await cur.fetchone() + (r,) = await cur.fetchone() assert r == 42 conn.close() diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 12fd41ff..4e8c6685 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -3,7 +3,7 @@ import pytest from aiomysql import ProgrammingError, Cursor, InterfaceError, OperationalError -from aiomysql.cursors import RE_INSERT_VALUES +from aiomysql.connection import RE_INSERT_VALUES async def _prepare(conn): @@ -272,11 +272,12 @@ async def test_executemany(connection_creator): async def test_custom_cursor(connection_creator): class MyCursor(Cursor): pass + conn = await connection_creator() cur = await conn.cursor(MyCursor) assert isinstance(cur, MyCursor) await cur.execute("SELECT 42;") - (r, ) = await cur.fetchone() + (r,) = await cur.fetchone() assert r == 42 @@ -284,6 +285,7 @@ class MyCursor(Cursor): async def test_custom_cursor_not_cursor_subclass(connection_creator): class MyCursor2: pass + conn = await connection_creator() with pytest.raises(TypeError): await conn.cursor(MyCursor2) diff --git a/tests/test_deserialize_cursor.py b/tests/test_deserialize_cursor.py index 0fab3181..eab11f82 100644 --- a/tests/test_deserialize_cursor.py +++ 
b/tests/test_deserialize_cursor.py @@ -1,9 +1,8 @@ import copy -import aiomysql.cursors - import pytest +import aiomysql.connection BOB = ("bob", 21, {"k1": "pretty", "k2": [18, 25]}) JIM = ("jim", 56, {"k1": "rich", "k2": [20, 60]}) @@ -12,10 +11,9 @@ @pytest.fixture() async def prepare(connection): - havejson = True - c = await connection.cursor(aiomysql.cursors.DeserializationCursor) + c = await connection.cursor(aiomysql.connection.DeserializationCursor) # create a table ane some data to query await c.execute("drop table if exists deserialize_cursor") @@ -52,7 +50,7 @@ async def test_deserialize_cursor(prepare, connection): # all assert test compare to the structure as would come # out from MySQLdb conn = connection - c = await conn.cursor(aiomysql.cursors.DeserializationCursor) + c = await conn.cursor(aiomysql.connection.DeserializationCursor) # pull back the single row dict for bob and check await c.execute("SELECT * from deserialize_cursor " @@ -88,7 +86,7 @@ async def test_deserialize_cursor_low_version(prepare, connection): # all assert test compare to the structure as would come # out from MySQLdb conn = connection - c = await conn.cursor(aiomysql.cursors.DeserializationCursor) + c = await conn.cursor(aiomysql.connection.DeserializationCursor) # pull back the single row dict for bob and check await c.execute("SELECT * from deserialize_cursor where name='bob'") @@ -120,8 +118,8 @@ async def test_deserializedictcursor(prepare, connection): bob = {'name': 'bob', 'age': 21, 'claim': {"k1": "pretty", "k2": [18, 25]}} conn = connection - c = await conn.cursor(aiomysql.cursors.DeserializationCursor, - aiomysql.cursors.DictCursor) + c = await conn.cursor(aiomysql.connection.DeserializationCursor, + aiomysql.connection.DictCursor) await c.execute("SELECT * from deserialize_cursor " "where name='bob'") r = await c.fetchall() @@ -135,8 +133,8 @@ async def test_ssdeserializecursor(prepare, connection): if not havejson: return conn = connection - c = await 
conn.cursor(aiomysql.cursors.SSCursor, - aiomysql.cursors.DeserializationCursor) + c = await conn.cursor(aiomysql.connection.SSCursor, + aiomysql.connection.DeserializationCursor) await c.execute("SELECT * from deserialize_cursor " "where name='bob'") r = await c.fetchall() @@ -152,9 +150,9 @@ async def test_ssdeserializedictcursor(prepare, connection): bob = {'name': 'bob', 'age': 21, 'claim': {"k1": "pretty", "k2": [18, 25]}} conn = connection - c = await conn.cursor(aiomysql.cursors.SSCursor, - aiomysql.cursors.DeserializationCursor, - aiomysql.cursors.DictCursor) + c = await conn.cursor(aiomysql.connection.SSCursor, + aiomysql.connection.DeserializationCursor, + aiomysql.connection.DictCursor) await c.execute("SELECT * from deserialize_cursor " "where name='bob'") r = await c.fetchall() diff --git a/tests/test_dictcursor.py b/tests/test_dictcursor.py index 5326d5b6..0e50228d 100644 --- a/tests/test_dictcursor.py +++ b/tests/test_dictcursor.py @@ -2,7 +2,7 @@ import pytest -import aiomysql.cursors +import aiomysql.connection BOB = {'name': 'bob', 'age': 21, @@ -12,7 +12,7 @@ FRED = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)} -CURSOR_TYPE = aiomysql.cursors.DictCursor +CURSOR_TYPE = aiomysql.connection.DictCursor async def prepare(conn): @@ -100,7 +100,7 @@ async def test_ssdictcursor(connection): conn = connection await prepare(connection) - c = await conn.cursor(aiomysql.cursors.SSDictCursor) + c = await conn.cursor(aiomysql.connection.SSDictCursor) await c.execute("SELECT * from dictcursor where name='bob'") r = await c.fetchall() assert [BOB] == r,\ diff --git a/tests/test_issues.py b/tests/test_issues.py index e60a5103..59fca111 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -333,7 +333,7 @@ async def test_issue_79(connection): """ Duplicate field overwrites the previous one in the result of DictCursor """ conn = connection - c = await conn.cursor(aiomysql.cursors.DictCursor) + c = await 
conn.cursor(aiomysql.connection.DictCursor) await c.execute("drop table if exists a") await c.execute("drop table if exists b") diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py index 47baa0a6..3ad1fecb 100644 --- a/tests/test_sha_connection.py +++ b/tests/test_sha_connection.py @@ -4,7 +4,7 @@ import pytest -# You could parameterise these tests with this, but then pytest +# You could parameterize these tests with this, but then pytest # does some funky stuff and spins up and tears down containers # per function call. Remember it would be # mysql_versions * event_loops * 4 auth tests ~= 3*2*4 ~= 24 tests diff --git a/tests/test_sscursor.py b/tests/test_sscursor.py index eff2ee33..3593491d 100644 --- a/tests/test_sscursor.py +++ b/tests/test_sscursor.py @@ -4,7 +4,7 @@ from pymysql import NotSupportedError from aiomysql import ProgrammingError, InterfaceError, OperationalError -from aiomysql.cursors import SSCursor +from aiomysql.connection import SSCursor DATA = [