From c715733f3173a8b3d06d3bb41dc0de51e2c300a4 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Fri, 4 Oct 2024 15:20:45 -0600 Subject: [PATCH] Update index reflection to work only on hybrid tables --- .github/workflows/build_test.yml | 2 - src/snowflake/sqlalchemy/snowdialect.py | 57 +++++++++++++++++++ .../sql/custom_schema/custom_table_base.py | 22 +++---- .../sql/custom_schema/custom_table_prefix.py | 13 +++++ .../sql/custom_schema/dynamic_table.py | 3 +- .../sql/custom_schema/hybrid_table.py | 3 +- .../custom_tables/test_create_hybrid_table.py | 3 +- tests/test_orm.py | 2 +- 8 files changed, 88 insertions(+), 17 deletions(-) create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 78952a9a..bad10947 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -113,7 +113,6 @@ jobs: .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - name: Run tests run: hatch run test-dialect - if: matrix.cloud-provider != 'aws' - name: Run test for AWS run: hatch run test-dialect-aws if: matrix.cloud-provider == 'aws' @@ -209,7 +208,6 @@ jobs: python -m hatch env create default - name: Run tests run: hatch run sa14:test-dialect - if: matrix.cloud-provider != 'aws' - name: Run test for AWS run: hatch run sa14:test-dialect-aws if: matrix.cloud-provider == 'aws' diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 7510eea2..0e00aaed 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -65,6 +65,7 @@ _CUSTOM_Float, _CUSTOM_Time, ) +from .sql.custom_schema.custom_table_prefix import CustomTablePrefix from .util import ( _update_connection_application_name, parse_url_boolean, @@ -898,6 +899,12 @@ def get_multi_indexes( """ Gets the indexes definition """ + + table_prefixes = self.get_multi_prefixes( + connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name + ) + if len(table_prefixes) == 0: + return [] schema = schema or self.default_schema_name if not schema: result = connection.execute( @@ -918,6 +925,12 @@ def get_multi_indexes( if ( row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY' or table not in filter_names + or (schema, table) not in table_prefixes + or ( + (schema, table) in table_prefixes + and CustomTablePrefix.HYBRID.name + not in table_prefixes[(schema, table)] + ) ): continue index = { @@ -942,6 +955,50 @@ def _value_or_default(self, data, table, schema): else: return [] + def get_prefixes_from_data(self, n2i, row, **kw): + prefixes_found = [] + for valid_prefix in CustomTablePrefix: + key = f"is_{valid_prefix.name.lower()}" + if key in n2i and row[n2i[key]] == "Y": + prefixes_found.append(valid_prefix.name) + return prefixes_found + + @reflection.cache + def get_multi_prefixes( + self, connection, schema, table_name=None, filter_prefix=None, **kw + ): + """ + Gets all table prefixes + """ + schema = schema or self.default_schema_name + filter = f"LIKE '{table_name}'" if table_name else "" + if schema: + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}" + ) + ) + else: + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'" + ) + ) + + n2i = self.__class__._map_name_to_idx(result) + tables_prefixes = {} + for row in result.cursor.fetchall(): + table = self.normalize_name(str(row[n2i["name"]])) + table_prefixes = self.get_prefixes_from_data(n2i, row) + if filter_prefix and filter_prefix not in table_prefixes: + continue + if (schema, table) in tables_prefixes: + tables_prefixes[(schema, table)].append(table_prefixes) + else: + tables_prefixes[(schema, table)] = table_prefixes + + return tables_prefixes + @reflection.cache def get_indexes(self, connection, tablename, schema, **kw): """ diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index 87153f20..f620da2a 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -10,12 +10,13 @@ from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from .custom_table_prefix import CustomTablePrefix from .options.table_option import TableOption class CustomTableBase(Table): - __table_prefix__ = "" - _support_primary_and_foreign_keys = True + __table_prefixes__: typing.List[CustomTablePrefix] = [] + _support_primary_and_foreign_keys: bool = True def __init__( self, @@ -24,8 +25,10 @@ def __init__( *args: SchemaItem, **kw: Any, ) -> None: - if self.__table_prefix__ != "": - prefixes = kw.get("prefixes", []) + [self.__table_prefix__] + if len(self.__table_prefixes__) > 0: + prefixes = kw.get("prefixes", []) + [ + prefix.name for prefix in self.__table_prefixes__ + ] kw.update(prefixes=prefixes) if not IS_VERSION_20 and hasattr(super(), "_init"): super()._init(name, metadata, *args, **kw) @@ -52,9 +55,8 @@ def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]: @classmethod def is_equal_type(cls, table: Table) -> bool: - if isinstance(table, cls.__class__): - return True - for prefix in table._prefixes: - if prefix == cls.__table_prefix__: - return True - return False + for prefix in cls.__table_prefixes__: + if prefix.name not in table._prefixes: + return False + + return True diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py new file mode 100644 index 00000000..fea13a8a --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from enum import Enum + + +class CustomTablePrefix(Enum): + DEFAULT = 0 + EXTERNAL = (1,) + EVENT = (2,) + HYBRID = (3,) + ICEBERG = (4,) + DYNAMIC = (5,) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py index 7d0a02e6..1a2248fc 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -10,6 +10,7 @@ from snowflake.sqlalchemy.custom_commands import NoneType +from .custom_table_prefix import CustomTablePrefix from .options.target_lag import TargetLag from .options.warehouse import Warehouse from .table_from_query import TableFromQueryBase @@ -27,7 +28,7 @@ class DynamicTable(TableFromQueryBase): """ - __table_prefix__ = "DYNAMIC" + __table_prefixes__ = [CustomTablePrefix.DYNAMIC] _support_primary_and_foreign_keys = False diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py index 0cff1c3a..bd49a420 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py @@ -10,6 +10,7 @@ from snowflake.sqlalchemy.custom_commands import NoneType from .custom_table_base import CustomTableBase +from .custom_table_prefix import CustomTablePrefix class HybridTable(CustomTableBase): @@ -22,7 +23,7 @@ class HybridTable(CustomTableBase): interface for creating dynamic tables and management. """ - __table_prefix__ = "HYBRID" + __table_prefixes__ = [CustomTablePrefix.HYBRID] _support_primary_and_foreign_keys = True diff --git a/tests/custom_tables/test_create_hybrid_table.py b/tests/custom_tables/test_create_hybrid_table.py index 81da7624..43ae3ab6 100644 --- a/tests/custom_tables/test_create_hybrid_table.py +++ b/tests/custom_tables/test_create_hybrid_table.py @@ -6,7 +6,7 @@ from sqlalchemy import Column, Index, Integer, MetaData, String, select from sqlalchemy.orm import Session, declarative_base -from src.snowflake.sqlalchemy import HybridTable +from snowflake.sqlalchemy import HybridTable @pytest.mark.aws @@ -25,7 +25,6 @@ def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot): with engine_testaccount.connect() as conn: ins = dynamic_test_table_1.insert().values(id=1, name="test") - conn.execute(ins) conn.commit() diff --git a/tests/test_orm.py b/tests/test_orm.py index d10650e5..cb3a7768 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -21,7 +21,7 @@ ) from sqlalchemy.orm import Session, declarative_base, relationship -from src.snowflake.sqlalchemy import HybridTable +from snowflake.sqlalchemy import HybridTable def test_basic_orm(engine_testaccount):