diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3be42964..70b75ce8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ exclude: '^(.*egg.info.*|.*/parameters.py).*$' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -9,23 +9,23 @@ repos: exclude: .github/repo_meta.yaml - id: debug-statements - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.37.3 + rev: v3.15.1 hooks: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 24.2.0 hooks: - id: black args: - --safe language_version: python3 - repo: https://github.com/Lucas-C/pre-commit-hooks.git - rev: v1.3.0 + rev: v1.5.5 hooks: - id: insert-license name: insert-py-license @@ -39,7 +39,7 @@ repos: - --license-filepath - license_header.txt - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 7.0.0 hooks: - id: flake8 additional_dependencies: diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 5751354c..422fe807 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,10 @@ Source code is also available at: # Release Notes +- 1.5.2 + + - Add support for sequence ordering in tests + - v1.5.1(November 03, 2023) - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..907176c3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ruff] +line-length = 88 + +[tool.black] +line-length = 88 diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index f229fb93..2a1bb51a 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -112,7 +112,7 @@ @CompileState.plugin_for("default", "select") class SnowflakeSelectState(SelectState): def _setup_joins(self, args, raw_columns): - for (right, onclause, left, flags) in args: + for right, onclause, left, flags in args: isouter = flags["isouter"] full = flags["full"] @@ -579,9 +579,11 @@ def visit_copy_into(self, copy_into, **kw): [ "{} = {}".format( n, - v._compiler_dispatch(self, **kw) - if getattr(v, "compiler_dispatch", False) - else str(v), + ( + v._compiler_dispatch(self, **kw) + if getattr(v, "compiler_dispatch", False) + else str(v) + ), ) for n, v in options_list ] @@ -604,20 +606,24 @@ def visit_copy_formatter(self, formatter, **kw): return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})" return "FILE_FORMAT=(TYPE={}{})".format( formatter.file_format, - " " - + " ".join( - [ - "{}={}".format( - name, - value._compiler_dispatch(self, **kw) - if hasattr(value, "_compiler_dispatch") - else formatter.value_repr(name, value), - ) - for name, value in options_list - ] - ) - if formatter.options - else "", + ( + " " + + " ".join( + [ + "{}={}".format( + name, + ( + value._compiler_dispatch(self, **kw) + if hasattr(value, "_compiler_dispatch") + else formatter.value_repr(name, value) + ), + ) + for name, value in options_list + ] + ) + if formatter.options + else "" + ), ) def visit_aws_bucket(self, aws_bucket, **kw): @@ -967,6 +973,29 @@ def visit_identity_column(self, identity, **kw): text += f"({start},{increment})" return text + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append(f"INCREMENT BY {identity_options.increment:d}") + if identity_options.start is not None: + text.append(f"START WITH {identity_options.start:d}") + if identity_options.minvalue is not None: + text.append(f"MINVALUE {identity_options.minvalue:d}") + if identity_options.maxvalue is not None: + text.append(f"MAXVALUE {identity_options.maxvalue:d}") + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append(f"CACHE {identity_options.cache:d}") + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + if identity_options.order is not None: + text.append("ORDER" if identity_options.order else "NOORDER") + + return " ".join(text) + class SnowflakeTypeCompiler(compiler.GenericTypeCompiler): def visit_BYTEINT(self, type_, **kw): diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 9bb60916..cec16673 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -259,7 +259,8 @@ def field_delimiter(self, deli_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. + """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext @@ -386,7 +387,8 @@ def compression(self, comp_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. + """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 4fefa07f..2e40d03c 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -595,11 +595,13 @@ def _get_schema_columns(self, connection, schema, **kw): "autoincrement": is_identity == "YES", "comment": comment, "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + ( + column_name + in schema_primary_keys[table_name]["constrained_columns"] + ) + if current_table_pks + else False + ), } ) if is_identity == "YES": @@ -688,11 +690,13 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw): "autoincrement": is_identity == "YES", "comment": comment if comment != "" else None, "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + ( + column_name + in schema_primary_keys[table_name]["constrained_columns"] + ) + if current_table_pks + else False + ), } ) @@ -876,18 +880,22 @@ def get_table_comment(self, connection, table_name, schema=None, **kw): result = self._get_view_comment(connection, table_name, schema) return { - "text": result._mapping["comment"] - if result and result._mapping["comment"] - else None + "text": ( + result._mapping["comment"] + if result and result._mapping["comment"] + else None + ) } def connect(self, *cargs, **cparams): return ( super().connect( *cargs, - **_update_connection_application_name(**cparams) - if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME - else cparams, + **( + _update_connection_application_name(**cparams) + if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME + else cparams + ), ) if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME else super().connect(*cargs, **cparams) diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 54044349..32e07373 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -235,7 +235,14 @@ def __init__( else: adapt_from = left_selectable - (pj, sj, source, dest, secondary, target_adapter,) = prop._create_joins( + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = prop._create_joins( source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index 6aea4f54..d4318b86 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = (1, 5, 1, None) +VERSION = (1, 5, 2, None) diff --git a/tests/test_core.py b/tests/test_core.py index 29c55ae9..60b4fea4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -29,6 +29,7 @@ Table, UniqueConstraint, dialects, + insert, inspect, text, ) @@ -1424,47 +1425,51 @@ def test_special_schema_character(db_parameters, on_public_ci): def test_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. + https://docs.snowflake.com/en/user-guide/querying-sequences + """ metadata = MetaData() users = Table( "users", metadata, - Column("uid", Integer, Sequence("id_seq"), primary_key=True), + Column("uid", Integer, Sequence("id_seq", order=True), primary_key=True), Column("name", String(39)), ) try: - users.create(engine_testaccount) - - with engine_testaccount.connect() as connection: - with connection.begin(): - connection.execute(users.insert(), [{"name": "sf1"}]) - assert connection.execute(select(users)).fetchall() == [(1, "sf1")] - connection.execute(users.insert(), [{"name": "sf2"}, {"name": "sf3"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - ] - connection.execute(users.insert(), {"name": "sf4"}) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - ] - - seq = Sequence("id_seq") - nextid = connection.execute(seq) - connection.execute(users.insert(), [{"uid": nextid, "name": "sf5"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - (5, "sf5"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as connection: + connection.execute(insert(users), [{"name": "sf1"}]) + assert connection.execute(select(users)).fetchall() == [(1, "sf1")] + connection.execute(insert(users), [{"name": "sf2"}, {"name": "sf3"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + ] + connection.execute(insert(users), {"name": "sf4"}) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + ] + + seq = Sequence("id_seq") + nextid = connection.execute(seq) + connection.execute(insert(users), [{"uid": nextid, "name": "sf5"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + (5, "sf5"), + ] finally: - users.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) @pytest.mark.skip( @@ -1869,10 +1874,16 @@ def test_snowflake_sqlalchemy_as_valid_client_type(): ) snowflake.connector.connection.DEFAULT_CONFIGURATION[ "internal_application_name" - ] = ("PythonConnector", (type(None), str)) + ] = ( + "PythonConnector", + (type(None), str), + ) snowflake.connector.connection.DEFAULT_CONFIGURATION[ "internal_application_version" - ] = ("3.0.0", (type(None), str)) + ] = ( + "3.0.0", + (type(None), str), + ) engine = create_engine( URL( user=CONNECTION_PARAMETERS["user"], diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 78658012..e428b9d7 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,89 +2,136 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table, select +from sqlalchemy import ( + Column, + Integer, + MetaData, + Sequence, + String, + Table, + insert, + select, +) +from sqlalchemy.sql import text def test_table_with_sequence(engine_testaccount, db_parameters): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. + https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" test_sequence_name = f"{test_table_name}_id_seq" + metadata = MetaData() + sequence_table = Table( test_table_name, - MetaData(), - Column("id", Integer, Sequence(test_sequence_name), primary_key=True), + metadata, + Column( + "id", Integer, Sequence(test_sequence_name, order=True), primary_key=True + ), Column("data", String(39)), ) - sequence_table.create(engine_testaccount) - seq = Sequence(test_sequence_name) + + autoload_metadata = MetaData() + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(sequence_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(sequence_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - nextid = conn.execute(seq) - conn.execute( - autoload_sequence_table.insert(), - [{"id": nextid, "data": "test_insert_seq"}], - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - (5, "test_insert_seq"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as conn: + conn.execute(insert(sequence_table), ({"data": "test_insert_1"})) + result = conn.execute(select(sequence_table)).fetchall() + assert result == [(1, "test_insert_1")], result + + autoload_sequence_table = Table( + test_table_name, + autoload_metadata, + autoload_with=engine_testaccount, + ) + seq = Sequence(test_sequence_name, order=True) + + conn.execute( + insert(autoload_sequence_table), + ( + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ), + ) + conn.execute( + insert(autoload_sequence_table), + ({"data": "test_insert_2"},), + ) + + nextid = conn.execute(seq) + conn.execute( + insert(autoload_sequence_table), + ({"id": nextid, "data": "test_insert_seq"}), + ) + + result = conn.execute(select(sequence_table)).fetchall() + + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + (5, "test_insert_seq"), + ], result + finally: - sequence_table.drop(engine_testaccount) - seq.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) -def test_table_with_autoincrement(engine_testaccount, db_parameters): +def test_table_with_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. + https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" + metadata = MetaData() autoincrement_table = Table( test_table_name, - MetaData(), + metadata, Column("id", Integer, autoincrement=True, primary_key=True), Column("data", String(39)), ) - autoincrement_table.create(engine_testaccount) + + select_stmt = select(autoincrement_table).order_by("id") + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(autoincrement_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(autoincrement_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - ] + with engine_testaccount.begin() as conn: + conn.execute(text("ALTER SESSION SET NOORDER_SEQUENCE_AS_DEFAULT = FALSE")) + metadata.create_all(conn) + + conn.execute(insert(autoincrement_table), ({"data": "test_insert_1"})) + result = conn.execute(select_stmt).fetchall() + assert result == [(1, "test_insert_1")] + + autoload_sequence_table = Table( + test_table_name, MetaData(), autoload_with=engine_testaccount + ) + conn.execute( + insert(autoload_sequence_table), + [ + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ], + ) + conn.execute( + insert(autoload_sequence_table), + [{"data": "test_insert_2"}], + ) + result = conn.execute(select_stmt).fetchall() + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + ], result + finally: - autoincrement_table.drop(engine_testaccount) + metadata.drop_all(engine_testaccount)