Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(memtables): track memtables with a weakset to allow overwriting tables with the same name but different data #10133

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions ibis/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import urllib.parse
import weakref
from collections import Counter
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple

Expand Down Expand Up @@ -863,6 +864,10 @@ def __init__(self, *args, **kwargs):
self._con_args: tuple[Any] = args
self._con_kwargs: dict[str, Any] = kwargs
self._can_reconnect: bool = True
# mapping of memtable names to their finalizers
self._finalizers = {}
self._memtables = weakref.WeakSet()
self._current_memtables = weakref.WeakValueDictionary()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can reduce this state to just a single weakset and dict. Mind if I push up a commit?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go for it!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jcrist I am still open to you pushing up a commit if you've already got it locally; otherwise, no big deal.

super().__init__()

@property
Expand Down Expand Up @@ -1110,16 +1115,47 @@ def _register_udfs(self, expr: ir.Expr) -> None:
if self.supports_python_udfs:
raise NotImplementedError(self.name)

def _in_memory_table_exists(self, name: str) -> bool:
return name in self.list_tables()
def _verify_in_memory_tables_are_unique(
    self, expr: ir.Expr
) -> list[ops.InMemoryTable]:
    """Find the in-memory tables in `expr` and reject duplicate names.

    Parameters
    ----------
    expr
        Expression whose operation graph is searched for
        `ops.InMemoryTable` nodes.

    Returns
    -------
    list[ops.InMemoryTable]
        The in-memory table operations found in `expr`, exactly as
        returned by `expr.op().find(ops.InMemoryTable)`; assumed to be a
        list (TODO confirm against `find`). Callers iterate over this
        value, so the previous `-> None` annotation was incorrect.

    Raises
    ------
    exc.IbisError
        If two or more of the found operations share the same name. The
        message lists the offending names in sorted order.
    """
    memtables = expr.op().find(ops.InMemoryTable)
    name_counts = Counter(op.name for op in memtables)

    # `sorted` makes the error message deterministic; an empty list is
    # falsy, so nothing is raised when every name is unique
    if duplicate_names := sorted(
        name for name, count in name_counts.items() if count > 1
    ):
        raise exc.IbisError(f"Duplicate in-memory table names: {duplicate_names}")
    return memtables

def _register_in_memory_tables(self, expr: ir.Expr) -> None:
for memtable in expr.op().find(ops.InMemoryTable):
if not self._in_memory_table_exists(memtable.name):
for memtable in self._verify_in_memory_tables_are_unique(expr):
name = memtable.name

# this particular memtable has never been registered
if memtable not in self._memtables:
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
# but we have a memtable with the same name
if (
current_memtable := self._current_memtables.pop(name, None)
) is not None:
# if we're here this means we overwrite, so do the following:
# 1. remove the old memtable from the set of memtables
# 2. grab the old finalizer and invoke it
self._memtables.remove(current_memtable)
finalizer = self._finalizers.pop(name)
finalizer()
else:
# if memtable is in the set, then by construction it must be
# true that the name of this memtable is in the current
# memtables mapping
assert name in self._current_memtables

# if there's no memtable named `name` then register it, setup a
# finalizer, and set it as the current memtable with `name`
if self._current_memtables.get(name) is None:
self._register_in_memory_table(memtable)
weakref.finalize(
memtable, self._finalize_in_memory_table, memtable.name
self._memtables.add(memtable)
self._finalizers[name] = weakref.finalize(
memtable, self._finalize_in_memory_table, name
)
self._current_memtables[name] = memtable

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
if self.supports_in_memory_tables:
Expand Down
10 changes: 0 additions & 10 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,6 @@ def _session_dataset(self):
self.__session_dataset = self._make_session()
return self.__session_dataset

def _in_memory_table_exists(self, name: str) -> bool:
table_ref = bq.TableReference(self._session_dataset, name)

try:
self._get_table(table_ref)
except com.TableNotFound:
return False
else:
return True

def _finalize_memtable(self, name: str) -> None:
table_ref = bq.TableReference(self._session_dataset, name)
self.client.delete_table(table_ref, not_found_ok=True)
Expand Down
24 changes: 3 additions & 21 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,9 @@ def _normalize_external_tables(self, external_tables=None) -> ExternalData | Non
def _collect_in_memory_tables(
self, expr: ir.Table | None, external_tables: Mapping | None = None
):
memtables = {op.name: op for op in expr.op().find(ops.InMemoryTable)}
memtables = {
op.name: op for op in self._verify_in_memory_tables_are_unique(expr)
}
externals = toolz.valmap(_to_memtable, external_tables or {})
return toolz.merge(memtables, externals)

Expand Down Expand Up @@ -779,23 +781,3 @@ def create_view(
with self._safe_raw_sql(src, external_tables=external_tables):
pass
return self.table(name, database=database)

def _in_memory_table_exists(self, name: str) -> bool:
name = sg.table(name, quoted=self.compiler.quoted).sql(self.dialect)
try:
# DESCRIBE TABLE $TABLE FORMAT NULL is the fastest way to check
# table existence in clickhouse; FORMAT NULL produces no data which
# is ideal since we don't care about the output for existence
# checking
#
# Other methods compared were
# 1. SELECT 1 FROM $TABLE LIMIT 0
# 2. SHOW TABLES LIKE $TABLE LIMIT 1
#
# if the table exists nothing is returned and there's no error
# otherwise there's an error
self.con.raw_query(f"DESCRIBE {name} FORMAT NULL")
except cc.driver.exceptions.DatabaseError:
return False
else:
return True
9 changes: 0 additions & 9 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,6 @@ def _register_failure(self):
f"please call one of {msg} directly"
)

def _in_memory_table_exists(self, name: str) -> bool:
db = self.con.catalog().database()
try:
db.table(name)
except Exception: # noqa: BLE001 because DataFusion has nothing better
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
# self.con.register_table is broken, so we do this roundabout thing
# of constructing a datafusion DataFrame, which has a side effect
Expand Down
9 changes: 0 additions & 9 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,15 +1606,6 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
}
)

def _in_memory_table_exists(self, name: str) -> bool:
try:
# this handles both tables and views
self.con.table(name)
except (duckdb.CatalogException, duckdb.InvalidInputException):
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
self.con.register(op.name, op.data.to_pyarrow(op.schema))

Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/exasol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,6 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
finally:
self.con.execute(drop_view)

def _in_memory_table_exists(self, name: str) -> bool:
return self.con.meta.table_exists(name)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/flink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@

def execute(self, expr: ir.Expr, **kwargs: Any) -> Any:
"""Execute an expression."""
self._verify_in_memory_tables_are_unique(expr)

Check warning on line 374 in ibis/backends/flink/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/flink/__init__.py#L374

Added line #L374 was not covered by tests
self._register_udfs(expr)

table_expr = expr.as_table()
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,10 +1223,6 @@ def explain(

return "\n".join(["Query:", util.indent(query, 2), "", *results.iloc[:, 0]])

def _in_memory_table_exists(self, name: str) -> bool:
with contextlib.closing(self.con.cursor()) as cur:
return cur.table_exists(name)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
10 changes: 0 additions & 10 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,16 +738,6 @@ def create_table(
namespace=ops.Namespace(catalog=catalog, database=db),
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
# The single character U here means user-defined table
# see https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-objects-transact-sql?view=sql-server-ver16
sql = sg.select(sg.func("object_id", sge.convert(name), sge.convert("U"))).sql(
self.dialect
)
with self.begin() as cur:
[(result,)] = cur.execute(sql).fetchall()
return result is not None

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
17 changes: 0 additions & 17 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,23 +477,6 @@ def create_table(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
name = sg.to_identifier(name, quoted=self.compiler.quoted).sql(self.dialect)
# just return the single field with column names; no need to bring back
# everything if the command succeeds
sql = f"SHOW COLUMNS FROM {name} LIKE 'Field'"
try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except MySQLdb.ProgrammingError as e:
err_code, _ = e.args
if err_code == ER.NO_SUCH_TABLE:
return False
raise
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
17 changes: 1 addition & 16 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ibis import util
from ibis.backends import CanListDatabase, CanListSchema
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import NULL, STAR, C
from ibis.backends.sql.compilers.base import STAR, C

if TYPE_CHECKING:
from urllib.parse import ParseResult
Expand Down Expand Up @@ -522,21 +522,6 @@ def drop_table(

super().drop_table(name, database=(catalog, db), force=force)

def _in_memory_table_exists(self, name: str) -> bool:
sql = (
sg.select(NULL)
.from_(sg.to_identifier("USER_OBJECTS", quoted=self.compiler.quoted))
.where(
C.OBJECT_TYPE.eq(sge.convert("TABLE")),
C.OBJECT_NAME.eq(sge.convert(name)),
)
.limit(sge.convert(1))
.sql(self.dialect)
)
with self.begin() as cur:
results = cur.execute(sql).fetchall()
return bool(results)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema

Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ def table(self, name: str) -> ir.Table:
schema = sch.infer(table)
return ops.DatabaseTable(name, schema, self).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
return name in self._tables

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
self._add_table(op.name, op.data.to_polars(op.schema).lazy())

Expand Down
15 changes: 0 additions & 15 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,6 @@ def _from_url(self, url: ParseResult, **kwargs):

return self.connect(**kwargs)

def _in_memory_table_exists(self, name: str) -> bool:
import psycopg2.errors

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except psycopg2.errors.UndefinedTable:
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
from psycopg2.extras import execute_batch

Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,10 +446,6 @@ def _register_udfs(self, expr: ir.Expr) -> None:
self._session.udf.register(f"unwrap_json_{typ.__name__}", unwrap_json(typ))
self._session.udf.register("unwrap_json_float", unwrap_json_float)

def _in_memory_table_exists(self, name: str) -> bool:
sql = f"SHOW TABLES IN {self.current_database} LIKE '{name}'"
return bool(self._session.sql(sql).count())

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = PySparkSchema.from_ibis(op.schema)
df = self._session.createDataFrame(data=op.data.to_frame(), schema=schema)
Expand Down
15 changes: 0 additions & 15 deletions ibis/backends/risingwave/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,6 @@ def create_table(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
import psycopg2.errors

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except psycopg2.errors.InternalError:
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
19 changes: 0 additions & 19 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,25 +663,6 @@ def list_tables(

return self._filter_with_like(tables + views, like=like)

def _in_memory_table_exists(self, name: str) -> bool:
import snowflake.connector

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.con.cursor() as cur:
cur.execute(sql).fetchall()
except snowflake.connector.errors.ProgrammingError as e:
# this cryptic error message is the only generic and reliable way
# to tell if the error means "table not found for any reason"
# otherwise, we need to reraise the exception
if e.sqlstate == "42S02":
return False
raise
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
import pyarrow.parquet as pq

Expand Down
12 changes: 0 additions & 12 deletions ibis/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,6 @@ def _generate_create_table(self, table: sge.Table, schema: sch.Schema):

return sge.Create(kind="TABLE", this=target)

def _in_memory_table_exists(self, name: str) -> bool:
ident = sg.to_identifier(name, quoted=self.compiler.quoted)
query = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)
try:
with self.begin() as cur:
cur.execute(query)
cur.fetchall()
except sqlite3.OperationalError:
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
table = sg.table(op.name, quoted=self.compiler.quoted, catalog="temp")
create_stmt = self._generate_create_table(table, op.schema).sql(self.name)
Expand Down
Loading