Skip to content

Commit

Permalink
chore(mysql): port to MySQLdb instead of pymysql (#10077)
Browse files Browse the repository at this point in the history
Replaces pymysql with mysqlclient, mostly out of frustration with bizarre GC behavior discovered during #10055. I think this is probably a breaking change due to some changes in how types are inferred for JSON, INET and UUID types.

BREAKING CHANGE: Ibis now uses the `MySQLdb` driver. You may need to install MySQL client libraries to **build** the extension.
  • Loading branch information
cpcloud authored Sep 18, 2024
1 parent 966c5e8 commit 2b6633c
Show file tree
Hide file tree
Showing 14 changed files with 136 additions and 141 deletions.
2 changes: 1 addition & 1 deletion .github/renovate.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"addLabels": ["druid"]
},
{
"matchPackagePatterns": ["pymysql", "mariadb"],
"matchPackagePatterns": ["mysqlclient", "mariadb"],
"addLabels": ["mysql"]
},
{
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ jobs:
- polars
sys-deps:
- libgeos-dev
- default-libmysqlclient-dev
- name: postgres
title: PostgreSQL
extras:
Expand Down Expand Up @@ -271,6 +272,7 @@ jobs:
- mysql
sys-deps:
- libgeos-dev
- default-libmysqlclient-dev
- os: windows-latest
backend:
name: clickhouse
Expand Down
2 changes: 1 addition & 1 deletion conda/environment-arm64-flink.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- pyarrow-hotfix >=0.4
- pydata-google-auth
- pydruid >=0.6.5
- pymysql >=1
- mysqlclient >=2.2.4
- pyspark >=3
- python-dateutil >=2.8.2
- python-duckdb >=0.8.1
Expand Down
2 changes: 1 addition & 1 deletion conda/environment-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- pyarrow-hotfix >=0.4
- pydata-google-auth
- pydruid >=0.6.5
- pymysql >=1
- mysqlclient >=2.2.4
- pyodbc >=4.0.39
- pyspark >=3
- python-dateutil >=2.8.2
Expand Down
2 changes: 1 addition & 1 deletion conda/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- pyarrow-hotfix >=0.4
- pydata-google-auth
- pydruid >=0.6.5
- pymysql >=1
- mysqlclient >=2.2.4
- pyodbc >=4.0.39
- pyspark >=3
- python >=3.10
Expand Down
158 changes: 76 additions & 82 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from __future__ import annotations

import contextlib
import re
import warnings
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote_plus

import pymysql
import MySQLdb
import sqlglot as sg
import sqlglot.expressions as sge
from pymysql.constants import ER
from pymysql.err import ProgrammingError
from MySQLdb import ProgrammingError
from MySQLdb.constants import ER

import ibis
import ibis.backends.sql.compilers as sc
Expand All @@ -24,7 +23,6 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateDatabase
from ibis.backends.mysql.datatypes import _type_from_cursor_info
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import STAR, TRUE, C

Expand Down Expand Up @@ -89,16 +87,14 @@ def _from_url(self, url: ParseResult, **kwargs):

@cached_property
def version(self):
matched = re.search(r"(\d+)\.(\d+)\.(\d+)", self.con.server_version)
return ".".join(matched.groups())
return ".".join(map(str, self.con._server_version))

def do_connect(
self,
host: str = "localhost",
user: str | None = None,
password: str | None = None,
port: int = 3306,
database: str | None = None,
autocommit: bool = True,
**kwargs,
) -> None:
Expand All @@ -114,12 +110,10 @@ def do_connect(
Password
port
Port
database
Database to connect to
autocommit
Autocommit mode
kwargs
Additional keyword arguments passed to `pymysql.connect`
Additional keyword arguments passed to `MySQLdb.connect`
Examples
--------
Expand Down Expand Up @@ -149,22 +143,20 @@ def do_connect(
year int32
month int32
"""
self.con = pymysql.connect(
self.con = MySQLdb.connect(
user=user,
host=host,
host="127.0.0.1" if host == "localhost" else host,
port=port,
password=password,
database=database,
autocommit=autocommit,
conv=pymysql.converters.conversions,
**kwargs,
)

self._post_connect()

@util.experimental
@classmethod
def from_connection(cls, con: pymysql.Connection) -> Backend:
def from_connection(cls, con) -> Backend:
"""Create an Ibis client from an existing connection to a MySQL database.
Parameters
Expand All @@ -179,7 +171,7 @@ def from_connection(cls, con: pymysql.Connection) -> Backend:
return new_backend

def _post_connect(self) -> None:
with contextlib.closing(self.con.cursor()) as cur:
with self.con.cursor() as cur:
try:
cur.execute("SET @@session.time_zone = 'UTC'")
except Exception as e: # noqa: BLE001
Expand All @@ -198,23 +190,34 @@ def list_databases(self, like: str | None = None) -> list[str]:
return self._filter_with_like(databases, like)

def _get_schema_using_query(self, query: str) -> sch.Schema:
with self.begin() as cur:
cur.execute(
sg.select(STAR)
.from_(
sg.parse_one(query, dialect=self.dialect).subquery(
sg.to_identifier("tmp", quoted=self.compiler.quoted)
)
from ibis.backends.mysql.datatypes import _type_from_cursor_info

sql = (
sg.select(STAR)
.from_(
sg.parse_one(query, dialect=self.dialect).subquery(
sg.to_identifier("tmp", quoted=self.compiler.quoted)
)
.limit(0)
.sql(self.dialect)
)
return sch.Schema(
{
field.name: _type_from_cursor_info(descr, field)
for descr, field in zip(cur.description, cur._result.fields)
}
.limit(0)
.sql(self.dialect)
)
with self.begin() as cur:
cur.execute(sql)
descr, flags = cur.description, cur.description_flags

items = {}
for (name, type_code, _, _, field_length, scale, _), raw_flags in zip(
descr, flags
):
item = _type_from_cursor_info(
flags=raw_flags,
type_code=type_code,
field_length=field_length,
scale=scale,
)
items[name] = item
return sch.Schema(items)

def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
Expand Down Expand Up @@ -258,38 +261,52 @@ def drop_database(self, name: str, force: bool = False) -> None:
def begin(self):
con = self.con
cur = con.cursor()
autocommit = con.get_autocommit()

if not autocommit:
con.begin()

try:
yield cur
except Exception:
con.rollback()
if not autocommit:
con.rollback()
raise
else:
con.commit()
if not autocommit:
con.commit()
finally:
cur.close()

# TODO(kszucs): should make it an abstract method or remove the use of it
# from .execute()
@contextlib.contextmanager
def _safe_raw_sql(self, *args, **kwargs):
with contextlib.closing(self.raw_sql(*args, **kwargs)) as result:
with self.raw_sql(*args, **kwargs) as result:
yield result

def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
with contextlib.suppress(AttributeError):
query = query.sql(dialect=self.name)

con = self.con
autocommit = con.get_autocommit()

cursor = con.cursor()

if not autocommit:
con.begin()

try:
cursor.execute(query, **kwargs)
except Exception:
con.rollback()
if not autocommit:
con.rollback()
cursor.close()
raise
else:
con.commit()
if not autocommit:
con.commit()
return cursor

# TODO: disable positional arguments
Expand Down Expand Up @@ -406,11 +423,9 @@ def create_table(
if temp:
properties.append(sge.TemporaryProperty())

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

Expand All @@ -428,39 +443,33 @@ def create_table(
if not schema:
schema = table.schema()

table_expr = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
target = sge.Schema(
this=table_expr, expressions=schema.to_sqlglot(self.dialect)
)
quoted = self.compiler.quoted
dialect = self.dialect

table_expr = sg.table(temp_name, catalog=database, quoted=quoted)
target = sge.Schema(this=table_expr, expressions=schema.to_sqlglot(dialect))

create_stmt = sge.Create(
kind="TABLE",
this=target,
properties=sge.Properties(expressions=properties),
kind="TABLE", this=target, properties=sge.Properties(expressions=properties)
)

this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
this = sg.table(name, catalog=database, quoted=quoted)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table_expr, expression=query).sql(
self.name
)
cur.execute(insert_stmt)
cur.execute(sge.Insert(this=table_expr, expression=query).sql(dialect))

if overwrite:
cur.execute(sge.Drop(kind="TABLE", this=this, exists=True).sql(dialect))
cur.execute(
sge.Drop(kind="TABLE", this=this, exists=True).sql(self.name)
)
cur.execute(
f"ALTER TABLE IF EXISTS {table_expr.sql(self.name)} RENAME TO {this.sql(self.name)}"
sge.Alter(
kind="TABLE",
this=table_expr,
exists=True,
actions=[sge.RenameTable(this=this)],
).sql(dialect)
)

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)

return self.table(name, database=database)

# preserve the input schema if it was provided
Expand All @@ -477,7 +486,7 @@ def _in_memory_table_exists(self, name: str) -> bool:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except pymysql.err.ProgrammingError as e:
except MySQLdb.ProgrammingError as e:
err_code, _ = e.args
if err_code == ER.NO_SUCH_TABLE:
return False
Expand All @@ -495,16 +504,17 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:

name = op.name
quoted = self.compiler.quoted
dialect = self.dialect

create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted),
expressions=schema.to_sqlglot(self.dialect),
expressions=schema.to_sqlglot(dialect),
),
properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
)
create_stmt_sql = create_stmt.sql(self.name)
create_stmt_sql = create_stmt.sql(dialect)

df = op.data.to_frame()
# nan can not be used with MySQL
Expand Down Expand Up @@ -549,23 +559,7 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:

from ibis.backends.mysql.converter import MySQLPandasData

try:
df = pd.DataFrame.from_records(
cursor, columns=schema.names, coerce_float=True
)
except Exception:
# clean up the cursor if we fail to create the DataFrame
#
# in the sqlite case failing to close the cursor results in
# artificially locked tables
cursor.close()
raise
df = MySQLPandasData.convert_table(df, schema)
return df

def _finalize_memtable(self, name: str) -> None:
"""No-op.
Executing **any** SQL in a finalizer causes the underlying connection
socket to be set to `None`. It is unclear why this happens.
"""
df = pd.DataFrame.from_records(
cursor.fetchall(), columns=schema.names, coerce_float=True
)
return MySQLPandasData.convert_table(df, schema)
Loading

0 comments on commit 2b6633c

Please sign in to comment.