Skip to content

Commit

Permalink
Update index reflection to work only on hybrid tables
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jvasquezrojas committed Oct 4, 2024
1 parent 444a300 commit c715733
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 17 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ jobs:
.github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py
- name: Run tests
run: hatch run test-dialect
if: matrix.cloud-provider != 'aws'
- name: Run test for AWS
run: hatch run test-dialect-aws
if: matrix.cloud-provider == 'aws'
Expand Down Expand Up @@ -209,7 +208,6 @@ jobs:
python -m hatch env create default
- name: Run tests
run: hatch run sa14:test-dialect
if: matrix.cloud-provider != 'aws'
- name: Run test for AWS
run: hatch run sa14:test-dialect-aws
if: matrix.cloud-provider == 'aws'
Expand Down
57 changes: 57 additions & 0 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
_CUSTOM_Float,
_CUSTOM_Time,
)
from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
from .util import (
_update_connection_application_name,
parse_url_boolean,
Expand Down Expand Up @@ -898,6 +899,12 @@ def get_multi_indexes(
"""
Gets the indexes definition
"""

table_prefixes = self.get_multi_prefixes(
connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name
)
if len(table_prefixes) == 0:
return []
schema = schema or self.default_schema_name
if not schema:
result = connection.execute(
Expand All @@ -918,6 +925,12 @@ def get_multi_indexes(
if (
row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
or table not in filter_names
or (schema, table) not in table_prefixes
or (
(schema, table) in table_prefixes
and CustomTablePrefix.HYBRID.name
not in table_prefixes[(schema, table)]
)
):
continue
index = {
Expand All @@ -942,6 +955,50 @@ def _value_or_default(self, data, table, schema):
else:
return []

def get_prefixes_from_data(self, n2i, row, **kw):
prefixes_found = []
for valid_prefix in CustomTablePrefix:
key = f"is_{valid_prefix.name.lower()}"
if key in n2i and row[n2i[key]] == "Y":
prefixes_found.append(valid_prefix.name)
return prefixes_found

@reflection.cache
def get_multi_prefixes(
self, connection, schema, table_name=None, filter_prefix=None, **kw
):
"""
Gets all table prefixes
"""
schema = schema or self.default_schema_name
filter = f"LIKE '{table_name}'" if table_name else ""
if schema:
result = connection.execute(
text(
f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}"
)
)
else:
result = connection.execute(
text(
f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'"
)
)

n2i = self.__class__._map_name_to_idx(result)
tables_prefixes = {}
for row in result.cursor.fetchall():
table = self.normalize_name(str(row[n2i["name"]]))
table_prefixes = self.get_prefixes_from_data(n2i, row)
if filter_prefix and filter_prefix not in table_prefixes:
continue
if (schema, table) in tables_prefixes:
tables_prefixes[(schema, table)].append(table_prefixes)
else:
tables_prefixes[(schema, table)] = table_prefixes

return tables_prefixes

@reflection.cache
def get_indexes(self, connection, tablename, schema, **kw):
"""
Expand Down
22 changes: 12 additions & 10 deletions src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
from ..._constants import DIALECT_NAME
from ...compat import IS_VERSION_20
from ...custom_commands import NoneType
from .custom_table_prefix import CustomTablePrefix
from .options.table_option import TableOption


class CustomTableBase(Table):
__table_prefix__ = ""
_support_primary_and_foreign_keys = True
__table_prefixes__: typing.List[CustomTablePrefix] = []
_support_primary_and_foreign_keys: bool = True

def __init__(
self,
Expand All @@ -24,8 +25,10 @@ def __init__(
*args: SchemaItem,
**kw: Any,
) -> None:
if self.__table_prefix__ != "":
prefixes = kw.get("prefixes", []) + [self.__table_prefix__]
if len(self.__table_prefixes__) > 0:
prefixes = kw.get("prefixes", []) + [
prefix.name for prefix in self.__table_prefixes__
]
kw.update(prefixes=prefixes)
if not IS_VERSION_20 and hasattr(super(), "_init"):
super()._init(name, metadata, *args, **kw)
Expand All @@ -52,9 +55,8 @@ def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]:

@classmethod
def is_equal_type(cls, table: Table) -> bool:
if isinstance(table, cls.__class__):
return True
for prefix in table._prefixes:
if prefix == cls.__table_prefix__:
return True
return False
for prefix in cls.__table_prefixes__:
if prefix.name not in table._prefixes:
return False

return True
13 changes: 13 additions & 0 deletions src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.

from enum import Enum


class CustomTablePrefix(Enum):
DEFAULT = 0
EXTERNAL = (1,)
EVENT = (2,)
HYBRID = (3,)
ICEBERG = (4,)
DYNAMIC = (5,)
3 changes: 2 additions & 1 deletion src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from snowflake.sqlalchemy.custom_commands import NoneType

from .custom_table_prefix import CustomTablePrefix
from .options.target_lag import TargetLag
from .options.warehouse import Warehouse
from .table_from_query import TableFromQueryBase
Expand All @@ -27,7 +28,7 @@ class DynamicTable(TableFromQueryBase):
"""

__table_prefix__ = "DYNAMIC"
__table_prefixes__ = [CustomTablePrefix.DYNAMIC]

_support_primary_and_foreign_keys = False

Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from snowflake.sqlalchemy.custom_commands import NoneType

from .custom_table_base import CustomTableBase
from .custom_table_prefix import CustomTablePrefix


class HybridTable(CustomTableBase):
Expand All @@ -22,7 +23,7 @@ class HybridTable(CustomTableBase):
interface for creating dynamic tables and management.
"""

__table_prefix__ = "HYBRID"
__table_prefixes__ = [CustomTablePrefix.HYBRID]

_support_primary_and_foreign_keys = True

Expand Down
3 changes: 1 addition & 2 deletions tests/custom_tables/test_create_hybrid_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy import Column, Index, Integer, MetaData, String, select
from sqlalchemy.orm import Session, declarative_base

from src.snowflake.sqlalchemy import HybridTable
from snowflake.sqlalchemy import HybridTable


@pytest.mark.aws
Expand All @@ -25,7 +25,6 @@ def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot):

with engine_testaccount.connect() as conn:
ins = dynamic_test_table_1.insert().values(id=1, name="test")

conn.execute(ins)
conn.commit()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from sqlalchemy.orm import Session, declarative_base, relationship

from src.snowflake.sqlalchemy import HybridTable
from snowflake.sqlalchemy import HybridTable


def test_basic_orm(engine_testaccount):
Expand Down

0 comments on commit c715733

Please sign in to comment.