Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1212541-ordered-sequence: add support for creating sequences order #473

Merged
merged 4 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
exclude: '^(.*egg.info.*|.*/parameters.py).*$'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
exclude: .github/repo_meta.yaml
- id: debug-statements
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v2.37.3
rev: v3.15.1
hooks:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 24.2.0
hooks:
- id: black
args:
- --safe
language_version: python3
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.3.0
rev: v1.5.5
hooks:
- id: insert-license
name: insert-py-license
Expand All @@ -39,7 +39,7 @@ repos:
- --license-filepath
- license_header.txt
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
Expand Down
4 changes: 4 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Source code is also available at:

# Release Notes

- 1.5.2

- Add support for sequence ordering in tests

- v1.5.1(November 03, 2023)

- Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057.
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[tool.ruff]
line-length = 88

[tool.black]
line-length = 88
65 changes: 47 additions & 18 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
@CompileState.plugin_for("default", "select")
class SnowflakeSelectState(SelectState):
def _setup_joins(self, args, raw_columns):
for (right, onclause, left, flags) in args:
for right, onclause, left, flags in args:
isouter = flags["isouter"]
full = flags["full"]

Expand Down Expand Up @@ -579,9 +579,11 @@ def visit_copy_into(self, copy_into, **kw):
[
"{} = {}".format(
n,
v._compiler_dispatch(self, **kw)
if getattr(v, "compiler_dispatch", False)
else str(v),
(
v._compiler_dispatch(self, **kw)
if getattr(v, "compiler_dispatch", False)
else str(v)
),
)
for n, v in options_list
]
Expand All @@ -604,20 +606,24 @@ def visit_copy_formatter(self, formatter, **kw):
return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})"
return "FILE_FORMAT=(TYPE={}{})".format(
formatter.file_format,
" "
+ " ".join(
[
"{}={}".format(
name,
value._compiler_dispatch(self, **kw)
if hasattr(value, "_compiler_dispatch")
else formatter.value_repr(name, value),
)
for name, value in options_list
]
)
if formatter.options
else "",
(
" "
+ " ".join(
[
"{}={}".format(
name,
(
value._compiler_dispatch(self, **kw)
if hasattr(value, "_compiler_dispatch")
else formatter.value_repr(name, value)
),
)
for name, value in options_list
]
)
if formatter.options
else ""
),
)

def visit_aws_bucket(self, aws_bucket, **kw):
Expand Down Expand Up @@ -967,6 +973,29 @@ def visit_identity_column(self, identity, **kw):
text += f"({start},{increment})"
return text

def get_identity_options(self, identity_options):
text = []
if identity_options.increment is not None:
text.append(f"INCREMENT BY {identity_options.increment:d}")
if identity_options.start is not None:
text.append(f"START WITH {identity_options.start:d}")
if identity_options.minvalue is not None:
text.append(f"MINVALUE {identity_options.minvalue:d}")
if identity_options.maxvalue is not None:
text.append(f"MAXVALUE {identity_options.maxvalue:d}")
if identity_options.nominvalue is not None:
text.append("NO MINVALUE")
if identity_options.nomaxvalue is not None:
text.append("NO MAXVALUE")
if identity_options.cache is not None:
text.append(f"CACHE {identity_options.cache:d}")
if identity_options.cycle is not None:
text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
if identity_options.order is not None:
text.append("ORDER" if identity_options.order else "NOORDER")

return " ".join(text)


class SnowflakeTypeCompiler(compiler.GenericTypeCompiler):
def visit_BYTEINT(self, type_, **kw):
Expand Down
6 changes: 4 additions & 2 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def field_delimiter(self, deli_type):

def file_extension(self, ext):
"""String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
responsible for specifying a valid file extension that can be read by the desired software or service."""
responsible for specifying a valid file extension that can be read by the desired software or service.
"""
if not isinstance(ext, (NoneType, string_types)):
raise TypeError("File extension should be a string")
self.options["FILE_EXTENSION"] = ext
Expand Down Expand Up @@ -386,7 +387,8 @@ def compression(self, comp_type):

def file_extension(self, ext):
"""String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
responsible for specifying a valid file extension that can be read by the desired software or service."""
responsible for specifying a valid file extension that can be read by the desired software or service.
"""
if not isinstance(ext, (NoneType, string_types)):
raise TypeError("File extension should be a string")
self.options["FILE_EXTENSION"] = ext
Expand Down
40 changes: 24 additions & 16 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,11 +595,13 @@ def _get_schema_columns(self, connection, schema, **kw):
"autoincrement": is_identity == "YES",
"comment": comment,
"primary_key": (
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False,
(
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False
),
}
)
if is_identity == "YES":
Expand Down Expand Up @@ -688,11 +690,13 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
"autoincrement": is_identity == "YES",
"comment": comment if comment != "" else None,
"primary_key": (
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False,
(
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False
),
}
)

Expand Down Expand Up @@ -876,18 +880,22 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
result = self._get_view_comment(connection, table_name, schema)

return {
"text": result._mapping["comment"]
if result and result._mapping["comment"]
else None
"text": (
result._mapping["comment"]
if result and result._mapping["comment"]
else None
)
}

def connect(self, *cargs, **cparams):
return (
super().connect(
*cargs,
**_update_connection_application_name(**cparams)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else cparams,
**(
_update_connection_application_name(**cparams)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else cparams
),
)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else super().connect(*cargs, **cparams)
Expand Down
9 changes: 8 additions & 1 deletion src/snowflake/sqlalchemy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,14 @@ def __init__(
else:
adapt_from = left_selectable

(pj, sj, source, dest, secondary, target_adapter,) = prop._create_joins(
(
pj,
sj,
source,
dest,
secondary,
target_adapter,
) = prop._create_joins(
source_selectable=adapt_from,
dest_selectable=adapt_to,
source_polymorphic=True,
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/sqlalchemy/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
#
# Update this for the versions
# Don't change the forth version number from None
VERSION = (1, 5, 1, None)
VERSION = (1, 5, 2, None)
79 changes: 45 additions & 34 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Table,
UniqueConstraint,
dialects,
insert,
inspect,
text,
)
Expand Down Expand Up @@ -1424,47 +1425,51 @@ def test_special_schema_character(db_parameters, on_public_ci):


def test_autoincrement(engine_testaccount):
"""Snowflake does not guarantee generating sequence numbers without gaps.

The generated numbers are not necessarily contiguous.
https://docs.snowflake.com/en/user-guide/querying-sequences
"""
metadata = MetaData()
users = Table(
"users",
metadata,
Column("uid", Integer, Sequence("id_seq"), primary_key=True),
Column("uid", Integer, Sequence("id_seq", order=True), primary_key=True),
Column("name", String(39)),
)

try:
users.create(engine_testaccount)

with engine_testaccount.connect() as connection:
with connection.begin():
connection.execute(users.insert(), [{"name": "sf1"}])
assert connection.execute(select(users)).fetchall() == [(1, "sf1")]
connection.execute(users.insert(), [{"name": "sf2"}, {"name": "sf3"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
]
connection.execute(users.insert(), {"name": "sf4"})
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
]

seq = Sequence("id_seq")
nextid = connection.execute(seq)
connection.execute(users.insert(), [{"uid": nextid, "name": "sf5"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
(5, "sf5"),
]
metadata.create_all(engine_testaccount)

with engine_testaccount.begin() as connection:
connection.execute(insert(users), [{"name": "sf1"}])
assert connection.execute(select(users)).fetchall() == [(1, "sf1")]
connection.execute(insert(users), [{"name": "sf2"}, {"name": "sf3"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
]
connection.execute(insert(users), {"name": "sf4"})
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
]

seq = Sequence("id_seq")
nextid = connection.execute(seq)
connection.execute(insert(users), [{"uid": nextid, "name": "sf5"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
(5, "sf5"),
]
finally:
users.drop(engine_testaccount)
metadata.drop_all(engine_testaccount)


@pytest.mark.skip(
Expand Down Expand Up @@ -1869,10 +1874,16 @@ def test_snowflake_sqlalchemy_as_valid_client_type():
)
snowflake.connector.connection.DEFAULT_CONFIGURATION[
"internal_application_name"
] = ("PythonConnector", (type(None), str))
] = (
"PythonConnector",
(type(None), str),
)
snowflake.connector.connection.DEFAULT_CONFIGURATION[
"internal_application_version"
] = ("3.0.0", (type(None), str))
] = (
"3.0.0",
(type(None), str),
)
engine = create_engine(
URL(
user=CONNECTION_PARAMETERS["user"],
Expand Down
Loading
Loading