From 2f526ce5dc988f46e09240a8dcd0563fa227834e Mon Sep 17 00:00:00 2001 From: Artur Dryomov Date: Mon, 16 Dec 2024 20:11:10 +0100 Subject: [PATCH] Move Black to pre-commit --- .pre-commit-config.yaml | 7 + setup.cfg | 4 +- setup.py | 33 +- tests/integration/conftest.py | 12 +- tests/integration/test_dbapi_integration.py | 895 +++++++------ .../test_sqlalchemy_integration.py | 447 ++++--- tests/integration/test_types_integration.py | 1128 +++++++---------- tests/unit/conftest.py | 4 +- tests/unit/oauth_test_utils.py | 69 +- tests/unit/sqlalchemy/conftest.py | 2 +- tests/unit/sqlalchemy/test_compiler.py | 98 +- tests/unit/sqlalchemy/test_datatype_parse.py | 4 +- tests/unit/sqlalchemy/test_dialect.py | 286 +++-- tests/unit/test_client.py | 259 ++-- tests/unit/test_dbapi.py | 99 +- tests/unit/test_transaction.py | 4 +- trino/auth.py | 108 +- trino/client.py | 85 +- trino/constants.py | 6 +- trino/dbapi.py | 96 +- trino/logging.py | 2 +- trino/mapper.py | 162 +-- trino/sqlalchemy/compiler.py | 36 +- trino/sqlalchemy/datatype.py | 8 +- trino/sqlalchemy/dialect.py | 25 +- trino/sqlalchemy/util.py | 2 +- trino/transaction.py | 12 +- trino/types.py | 9 +- 28 files changed, 1862 insertions(+), 2040 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c12618df..7dfae92c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,13 @@ repos: name: "Python: imports" args: ["--py39-plus"] + - repo: "https://github.com/psf/black-pre-commit-mirror" + rev: "24.10.0" + hooks: + - id: "black" + name: "Python: format" + args: ["--target-version", "py39", "--line-length", "120"] + - repo: "https://github.com/pre-commit/pre-commit-hooks" rev: "v5.0.0" hooks: diff --git a/setup.cfg b/setup.cfg index c7cc318a..e42948ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,9 @@ max-line-length = 120 # # https://www.flake8rules.com/rules/W503.html # https://www.python.org/dev/peps/pep-0008/#should-a-line-break-before-or-after-a-binary-operator -ignore = W503 +# +# E203 and E701 are Black-specific exclusions. +ignore = E203,E701,W503 [mypy] check_untyped_defs = true diff --git a/setup.py b/setup.py index 0a512e6e..d9761ac4 100755 --- a/setup.py +++ b/setup.py @@ -27,27 +27,30 @@ readme = f.read() kerberos_require = ["requests_kerberos"] -gssapi_require = ["" - "requests_gssapi", - # PyPy compatibility issue https://github.com/jborean93/pykrb5/issues/49 - "krb5 == 0.5.1"] +gssapi_require = [ + "" "requests_gssapi", + # PyPy compatibility issue https://github.com/jborean93/pykrb5/issues/49 + "krb5 == 0.5.1", +] sqlalchemy_require = ["sqlalchemy >= 1.3"] external_authentication_token_cache_require = ["keyring"] # We don't add localstorage_require to all_require as users must explicitly opt in to use keyring. 
all_require = kerberos_require + sqlalchemy_require -tests_require = all_require + gssapi_require + [ - # httpretty >= 1.1 duplicates requests in `httpretty.latest_requests` - # https://github.com/gabrielfalcao/HTTPretty/issues/425 - "httpretty < 1.1", - "pytest", - "pytest-runner", - "pre-commit", - "black", - "isort", - "keyring" -] +tests_require = ( + all_require + + gssapi_require + + [ + # httpretty >= 1.1 duplicates requests in `httpretty.latest_requests` + # https://github.com/gabrielfalcao/HTTPretty/issues/425 + "httpretty < 1.1", + "pytest", + "pytest-runner", + "pre-commit", + "keyring", + ] +) setup( name=about["__title__"], diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 23fe3037..31716ded 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -72,13 +72,7 @@ def start_trino(image_tag=None): def wait_for_trino_workers(host, port, timeout=180): - request = TrinoRequest( - host=host, - port=port, - client_session=ClientSession( - user="test_fixture" - ) - ) + request = TrinoRequest(host=host, port=port, client_session=ClientSession(user="test_fixture")) sql = "SELECT state FROM system.runtime.nodes" t0 = time.time() while True: @@ -124,9 +118,7 @@ def start_trino_and_wait(image_tag=None): if host: port = os.environ.get("TRINO_RUNNING_PORT", DEFAULT_PORT) else: - container_id, proc, host, port = start_local_trino_server( - image_tag - ) + container_id, proc, host, port = start_local_trino_server(image_tag) print("trino.server.hostname {}".format(host)) print("trino.server.port {}".format(port)) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 7165408f..c254d768 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -41,9 +41,7 @@ def trino_connection(run_trino): _, host, port = run_trino - yield trino.dbapi.Connection( - host=host, port=port, user="test", source="test", max_attempts=1 - ) + yield trino.dbapi.Connection(host=host, port=port, user="test", source="test", max_attempts=1) @pytest.fixture @@ -111,11 +109,14 @@ def test_select_query_result_iteration(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_select_query_result_iteration_statement_params(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -132,24 +133,27 @@ def test_select_query_result_iteration_statement_params(legacy_prepared_statemen ) x (id, name, letter) WHERE id >= ? 
""", - params=(3,) # expecting all the rows with id >= 3 + params=(3,), # expecting all the rows with id >= 3 ) rows = cur.fetchall() assert len(rows) == 3 - assert [3, 'three', 'c'] in rows - assert [4, 'four', 'd'] in rows - assert [5, 'five', 'e'] in rows + assert [3, "three", "c"] in rows + assert [4, "four", "d"] in rows + assert [5, "five", "e"] in rows @pytest.mark.parametrize( "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_none_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -164,11 +168,14 @@ def test_none_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_string_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -184,11 +191,14 @@ def test_string_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_execute_many(legacy_prepared_statements, run_trino): @@ -223,11 +233,14 @@ def test_execute_many(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_execute_many_without_params(legacy_prepared_statements, run_trino): @@ -248,11 +261,14 @@ def test_execute_many_without_params(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_execute_many_select(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -264,9 +280,7 @@ def test_execute_many_select(legacy_prepared_statements, run_trino): @pytest.mark.parametrize("connection_legacy_primitive_types", [None, True, False]) @pytest.mark.parametrize("cursor_legacy_primitive_types", [None, True, False]) def test_legacy_primitive_types_with_connection_and_cursor( - connection_legacy_primitive_types, - cursor_legacy_primitive_types, - run_trino + connection_legacy_primitive_types, cursor_legacy_primitive_types, run_trino ): _, host, port = run_trino @@ -308,10 +322,10 @@ def 
test_legacy_primitive_types_with_connection_and_cursor( if not expected_legacy_primitive_types: assert len(rows[0]) == 6 - assert rows[0][0] == Decimal('0.142857') + assert rows[0][0] == Decimal("0.142857") assert rows[0][1] == date(2018, 1, 1) assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1))) - assert rows[0][3] == datetime(2019, 1, 1, tzinfo=ZoneInfo('UTC')) + assert rows[0][3] == datetime(2019, 1, 1, tzinfo=ZoneInfo("UTC")) assert rows[0][4] == datetime(2019, 1, 1) assert rows[0][5] == time(0, 0, 0, 0) else: @@ -319,32 +333,35 @@ def test_legacy_primitive_types_with_connection_and_cursor( assert isinstance(value, str) assert len(rows[0]) == 7 - assert rows[0][0] == '0.142857' - assert rows[0][1] == '2018-01-01' - assert rows[0][2] == '2019-01-01 00:00:00.000 +01:00' - assert rows[0][3] == '2019-01-01 00:00:00.000 UTC' - assert rows[0][4] == '2019-01-01 00:00:00.000' - assert rows[0][5] == '00:00:00.000' - assert rows[0][6] == '-2001-08-22' + assert rows[0][0] == "0.142857" + assert rows[0][1] == "2018-01-01" + assert rows[0][2] == "2019-01-01 00:00:00.000 +01:00" + assert rows[0][3] == "2019-01-01 00:00:00.000 UTC" + assert rows[0][4] == "2019-01-01 00:00:00.000" + assert rows[0][5] == "00:00:00.000" + assert rows[0][6] == "-2001-08-22" @pytest.mark.parametrize( "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_decimal_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - cur.execute("SELECT ?", params=(Decimal('1112.142857'),)) + cur.execute("SELECT ?", params=(Decimal("1112.142857"),)) rows = cur.fetchall() - assert rows[0][0] == Decimal('1112.142857') + assert rows[0][0] == Decimal("1112.142857") assert_cursor_description(cur, trino_type="decimal(10, 6)", precision=10, scale=6) @@ -352,28 +369,31 @@ def test_decimal_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_decimal_scientific_notation_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - cur.execute("SELECT ?", params=(Decimal('0E-10'),)) + cur.execute("SELECT ?", params=(Decimal("0E-10"),)) rows = cur.fetchall() - assert rows[0][0] == Decimal('0E-10') + assert rows[0][0] == Decimal("0E-10") assert_cursor_description(cur, trino_type="decimal(10, 10)", precision=10, scale=10) # Ensure we don't convert to floats - assert Decimal('0.1') == Decimal('1E-1') != 0.1 + assert Decimal("0.1") == Decimal("1E-1") != 0.1 - cur.execute("SELECT ?", params=(Decimal('1E-1'),)) + cur.execute("SELECT ?", params=(Decimal("1E-1"),)) rows = cur.fetchall() - assert rows[0][0] == Decimal('1E-1') + assert rows[0][0] == Decimal("1E-1") assert_cursor_description(cur, trino_type="decimal(1, 1)", precision=1, scale=1) @@ -391,16 +411,19 @@ def test_null_decimal(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, 
marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_biggest_decimal(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - params = Decimal('99999999999999999999999999999999999999') + params = Decimal("99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -412,16 +435,19 @@ def test_biggest_decimal(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_smallest_decimal(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - params = Decimal('-99999999999999999999999999999999999999') + params = Decimal("-99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -433,16 +459,19 @@ def test_smallest_decimal(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_highest_precision_decimal(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - params = Decimal('0.99999999999999999999999999999999999999') + params = Decimal("0.99999999999999999999999999999999999999") cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -454,11 +483,14 @@ def test_highest_precision_decimal(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_datetime_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -476,16 +508,19 @@ def test_datetime_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_datetime_with_utc_time_zone_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('UTC')) + params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo("UTC")) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -498,11 +533,14 @@ def 
test_datetime_with_utc_time_zone_query_param(legacy_prepared_statements, run "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_datetime_with_numeric_offset_time_zone_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -522,16 +560,19 @@ def test_datetime_with_numeric_offset_time_zone_query_param(legacy_prepared_stat "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_datetime_with_named_time_zone_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('America/Los_Angeles')) + params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo("America/Los_Angeles")) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -574,17 +615,20 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_datetimes_with_time_zone_in_dst_gap_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) # This is a datetime that lies within a DST transition and not actually exists. - params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=ZoneInfo('Europe/Brussels')) + params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=ZoneInfo("Europe/Brussels")) with pytest.raises(trino.exceptions.TrinoUserError): cur.execute("SELECT ?", params=(params,)) cur.fetchall() @@ -594,35 +638,41 @@ def test_datetimes_with_time_zone_in_dst_gap_query_param(legacy_prepared_stateme "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) -@pytest.mark.parametrize('fold', [0, 1]) +@pytest.mark.parametrize("fold", [0, 1]) def test_doubled_datetimes(fold, legacy_prepared_statements, run_trino): # Trino doesn't distinguish between doubled datetimes that lie within a DST transition. 
# See also https://github.com/trinodb/trino/issues/5781 cur = get_cursor(legacy_prepared_statements, run_trino) - params = datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'), fold=fold) + params = datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo("US/Eastern"), fold=fold) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() - assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern')) + assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo("US/Eastern")) @pytest.mark.parametrize( "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_date_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -651,11 +701,11 @@ def test_unsupported_python_dates(trino_connection): # dates below python min (1-1-1) or above max date (9999-12-31) are not supported for unsupported_date in [ - '-0001-01-01', - '0000-01-01', - '10000-01-01', - '-4999999-01-01', # Trino min date - '5000000-12-31', # Trino max date + "-0001-01-01", + "0000-01-01", + "10000-01-01", + "-4999999-01-01", # Trino min date + "5000000-12-31", # Trino max date ]: with pytest.raises(trino.exceptions.TrinoDataError): cur.execute(f"SELECT DATE '{unsupported_date}'") @@ -666,36 +716,39 @@ def test_unsupported_python_dates(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_supported_special_dates_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) for params in ( - # min python date - date(1, 1, 1), - # before julian->gregorian switch - date(1500, 1, 1), - # During julian->gregorian switch - date(1752, 9, 4), - # before epoch - date(1952, 4, 3), - date(1970, 1, 1), - date(1970, 2, 3), - # summer on northern hemisphere (possible DST) - date(2017, 7, 1), - # winter on northern hemisphere (possible DST on southern hemisphere) - date(2017, 1, 1), - # winter on southern hemisphere (possible DST on northern hemisphere) - date(2017, 12, 31), - date(1983, 4, 1), - date(1983, 10, 1), - # max python date - date(9999, 12, 31), + # min python date + date(1, 1, 1), + # before julian->gregorian switch + date(1500, 1, 1), + # During julian->gregorian switch + date(1752, 9, 4), + # before epoch + date(1952, 4, 3), + date(1970, 1, 1), + date(1970, 2, 3), + # summer on northern hemisphere (possible DST) + date(2017, 7, 1), + # winter on northern hemisphere (possible DST on southern hemisphere) + date(2017, 1, 1), + # winter on southern hemisphere (possible DST on northern hemisphere) + date(2017, 12, 31), + date(1983, 4, 1), + date(1983, 10, 1), + # max python date + date(9999, 12, 31), ): cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -709,7 +762,7 @@ def test_char(trino_connection): cur.execute("SELECT CHAR 'trino'") rows = cur.fetchall() - assert rows[0][0] == 'trino' + assert rows[0][0] == "trino" assert_cursor_description(cur, 
trino_type="char(5)", size=5) @@ -717,11 +770,14 @@ def test_char(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_time_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -739,16 +795,19 @@ def test_time_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_time_with_named_time_zone_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) - params = time(16, 43, 22, 320000, tzinfo=ZoneInfo('Asia/Shanghai')) + params = time(16, 43, 22, 320000, tzinfo=ZoneInfo("Asia/Shanghai")) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() @@ -761,11 +820,14 @@ def test_time_with_named_time_zone_query_param(legacy_prepared_statements, run_t "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_time_with_numeric_offset_time_zone_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -837,18 +899,21 @@ def test_null_date_with_time_zone(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) @pytest.mark.parametrize( "binary_input", [ bytearray("a", "utf-8"), bytearray("a", "ascii"), - bytearray(b'\x00\x00\x00\x00'), + bytearray(b"\x00\x00\x00\x00"), bytearray(4), bytearray([1, 2, 3]), ], @@ -866,11 +931,14 @@ def test_binary_query_param(binary_input, legacy_prepared_statements, run_trino) "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_array_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -895,11 +963,14 @@ def test_array_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in 
version 418" + ), + ), + None, + ], ) def test_array_none_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -921,11 +992,14 @@ def test_array_none_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_array_none_and_another_type_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -947,11 +1021,14 @@ def test_array_none_and_another_type_query_param(legacy_prepared_statements, run "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_array_timestamp_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -973,18 +1050,21 @@ def test_array_timestamp_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_array_timestamp_with_timezone_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) params = [ - datetime(2020, 1, 1, 0, 0, 0, tzinfo=ZoneInfo('UTC')), - datetime(2020, 1, 2, 0, 0, 0, tzinfo=ZoneInfo('UTC')), + datetime(2020, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("UTC")), + datetime(2020, 1, 2, 0, 0, 0, tzinfo=ZoneInfo("UTC")), ] cur.execute("SELECT ?", params=(params,)) @@ -1002,11 +1082,14 @@ def test_array_timestamp_with_timezone_query_param(legacy_prepared_statements, r "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_dict_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1026,11 +1109,14 @@ def test_dict_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_dict_timestamp_query_param_types(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1046,11 +1132,14 @@ def test_dict_timestamp_query_param_types(legacy_prepared_statements, run_trino) "legacy_prepared_statements", [ True, - pytest.param(False, 
marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_boolean_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1070,11 +1159,14 @@ def test_boolean_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_row(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1089,11 +1181,14 @@ def test_row(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_nested_row(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1115,8 +1210,8 @@ def test_named_row(trino_connection): assert rows[0][0].x == 1 assert rows[0][0].y == 2.0 - assert rows[0][0].__annotations__["names"] == ['x', 'y'] - assert rows[0][0].__annotations__["types"] == ['bigint', 'double'] + assert rows[0][0].__annotations__["names"] == ["x", "y"] + assert rows[0][0].__annotations__["types"] == ["bigint", "double"] def test_named_row_duplicate_names(trino_connection): @@ -1128,8 +1223,8 @@ def test_named_row_duplicate_names(trino_connection): with pytest.raises(ValueError, match="Ambiguous row field reference: x"): rows[0][0].x - assert rows[0][0].__annotations__["names"] == ['x', 'x'] - assert rows[0][0].__annotations__["types"] == ['bigint', 'double'] + assert rows[0][0].__annotations__["names"] == ["x", "x"] + assert rows[0][0].__annotations__["types"] == ["bigint", "double"] assert str(rows[0][0]) == "(1, 2.0)" @@ -1138,20 +1233,20 @@ def test_nested_named_row(trino_connection): cur.execute("SELECT CAST(ROW(DECIMAL '2.3', ROW(1, 'test')) AS ROW(x DECIMAL(3,2), y ROW(x BIGINT, y VARCHAR)))") rows = cur.fetchall() - assert rows[0][0] == (Decimal('2.3'), (1, 'test')) - assert rows[0][0][0] == Decimal('2.3') - assert rows[0][0][1] == (1, 'test') + assert rows[0][0] == (Decimal("2.3"), (1, "test")) + assert rows[0][0][0] == Decimal("2.3") + assert rows[0][0][1] == (1, "test") assert rows[0][0][1][0] == 1 - assert rows[0][0][1][1] == 'test' - assert rows[0][0].x == Decimal('2.3') + assert rows[0][0][1][1] == "test" + assert rows[0][0].x == Decimal("2.3") assert rows[0][0].y.x == 1 - assert rows[0][0].y.y == 'test' + assert rows[0][0].y.y == "test" - assert rows[0][0].__annotations__["names"] == ['x', 'y'] - assert rows[0][0].__annotations__["types"] == ['decimal', 'row'] + assert rows[0][0].__annotations__["names"] == ["x", "y"] + assert rows[0][0].__annotations__["types"] == ["decimal", "row"] - assert rows[0][0].y.__annotations__["names"] == ['x', 'y'] - assert rows[0][0].y.__annotations__["types"] == ['bigint', 'varchar'] + assert 
rows[0][0].y.__annotations__["names"] == ["x", "y"] + assert rows[0][0].y.__annotations__["types"] == ["bigint", "varchar"] assert str(rows[0][0]) == "(x: Decimal('2.30'), y: (x: 1, y: 'test'))" @@ -1159,11 +1254,14 @@ def test_nested_named_row(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_float_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1178,11 +1276,14 @@ def test_float_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_float_nan_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1198,11 +1299,14 @@ def test_float_nan_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_float_inf_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1222,11 +1326,14 @@ def test_float_inf_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_int_query_param(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) @@ -1247,21 +1354,27 @@ def test_int_query_param(legacy_prepared_statements, run_trino): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], +) +@pytest.mark.parametrize( + "params", + [ + "NOT A LIST OR TUPPLE", + {"invalid", "params"}, + object, + ], ) -@pytest.mark.parametrize('params', [ - 'NOT A LIST OR TUPPLE', - {'invalid', 'params'}, - object, -]) def test_select_query_invalid_params(params, legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) with pytest.raises(AssertionError): - cur.execute('SELECT ?', params=params) + cur.execute("SELECT ?", params=params) def test_select_cursor_iteration(trino_connection): @@ -1281,7 +1394,7 @@ def test_select_cursor_iteration(trino_connection): def test_execute_chaining(trino_connection): cur = 
trino_connection.cursor() - assert cur.execute('SELECT 1').fetchone()[0] == 1 + assert cur.execute("SELECT 1").fetchone()[0] == 1 def test_select_query_no_result(trino_connection): @@ -1419,8 +1532,7 @@ def test_transaction_multiple(trino_connection_with_transaction): assert len(rows2) == 1000 -@pytest.mark.skipif(trino_version() == '351', reason="Autocommit behaves " - "differently in older Trino versions") +@pytest.mark.skipif(trino_version() == "351", reason="Autocommit behaves " "differently in older Trino versions") def test_transaction_autocommit(trino_connection_in_autocommit): with trino_connection_in_autocommit as connection: connection.start_transaction() @@ -1430,25 +1542,27 @@ def test_transaction_autocommit(trino_connection_in_autocommit): """ CREATE TABLE memory.default.nation AS SELECT * from tpch.tiny.nation - """) + """ + ) cur.fetchall() - assert "Catalog only supports writes using autocommit: memory" \ - in str(transaction_error.value) + assert "Catalog only supports writes using autocommit: memory" in str(transaction_error.value) @pytest.mark.parametrize( "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def test_invalid_query_throws_correct_error(legacy_prepared_statements, run_trino): - """Tests that an invalid query raises the correct exception - """ + """Tests that an invalid query raises the correct exception""" cur = get_cursor(legacy_prepared_statements, run_trino) with pytest.raises(TrinoQueryError): cur.execute( @@ -1461,14 +1575,14 @@ def test_invalid_query_throws_correct_error(legacy_prepared_statements, run_trin def test_eager_loading_cursor_description(trino_connection): description_expected = [ - ('node_id', 'varchar', None, None, None, None, None), - ('http_uri', 'varchar', None, None, None, None, None), - ('node_version', 'varchar', None, None, None, None, None), - ('coordinator', 'boolean', None, None, None, None, None), - ('state', 'varchar', None, None, None, None, None), + ("node_id", "varchar", None, None, None, None, None), + ("http_uri", "varchar", None, None, None, None, None), + ("node_version", "varchar", None, None, None, None, None), + ("coordinator", "boolean", None, None, None, None, None), + ("state", "varchar", None, None, None, None, None), ] cur = trino_connection.cursor() - cur.execute('SELECT * FROM system.runtime.nodes') + cur.execute("SELECT * FROM system.runtime.nodes") description_before = cur.description assert description_before is not None @@ -1485,7 +1599,7 @@ def test_eager_loading_cursor_description(trino_connection): def test_info_uri(trino_connection): cur = trino_connection.cursor() assert cur.info_uri is None - cur.execute('SELECT * FROM system.runtime.nodes') + cur.execute("SELECT * FROM system.runtime.nodes") assert cur.info_uri is not None assert cur._query.query_id in cur.info_uri cur.fetchall() @@ -1522,44 +1636,43 @@ def retrieve_client_tags_from_query(run_trino, client_tags): ) cur = trino_connection.cursor() - cur.execute('SELECT 1') + cur.execute("SELECT 1") cur.fetchall() api_url = "http://" + trino_connection.host + ":" + str(trino_connection.port) - query_info = requests.post(api_url + "/ui/login", data={ - "username": "admin", - "password": "", - "redirectPath": api_url + '/ui/api/query/' + cur._query.query_id - 
}).json() + query_info = requests.post( + api_url + "/ui/login", + data={"username": "admin", "password": "", "redirectPath": api_url + "/ui/api/query/" + cur._query.query_id}, + ).json() - query_client_tags = query_info['session']['clientTags'] + query_client_tags = query_info["session"]["clientTags"] return query_client_tags -@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif(trino_version() == "351", reason="current_catalog not supported in older Trino versions") def test_use_catalog_schema(trino_connection): cur = trino_connection.cursor() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() assert result[0][0] is None assert result[0][1] is None - cur.execute('USE tpch.tiny') + cur.execute("USE tpch.tiny") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'tiny' + assert result[0][0] == "tpch" + assert result[0][1] == "tiny" - cur.execute('USE tpcds.sf1') + cur.execute("USE tpcds.sf1") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpcds' - assert result[0][1] == 'sf1' + assert result[0][0] == "tpcds" + assert result[0][1] == "sf1" -@pytest.mark.skipif(trino_version() == '351', reason="current_catalog not supported in older Trino versions") +@pytest.mark.skipif(trino_version() == "351", reason="current_catalog not supported in older Trino versions") def test_use_schema(run_trino): _, host, port = run_trino @@ -1567,34 +1680,32 @@ def test_use_schema(run_trino): host=host, port=port, user="test", source="test", catalog="tpch", max_attempts=1 ) cur = trino_connection.cursor() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' + assert result[0][0] == "tpch" assert result[0][1] is None - cur.execute('USE tiny') + cur.execute("USE tiny") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'tiny' + assert result[0][0] == "tpch" + assert result[0][1] == "tiny" - cur.execute('USE sf1') + cur.execute("USE sf1") cur.fetchall() - cur.execute('SELECT current_catalog, current_schema') + cur.execute("SELECT current_catalog, current_schema") result = cur.fetchall() - assert result[0][0] == 'tpch' - assert result[0][1] == 'sf1' + assert result[0][0] == "tpch" + assert result[0][1] == "sf1" def test_set_role(run_trino): _, host, port = run_trino - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch" - ) + trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch") cur = trino_connection.cursor() - cur.execute('SHOW TABLES FROM information_schema') + cur.execute("SHOW TABLES FROM information_schema") cur.fetchall() assert cur._request._client_session.roles == {} @@ -1614,7 +1725,7 @@ def test_set_role_in_connection(run_trino): host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"} ) cur = trino_connection.cursor() - cur.execute('SHOW TABLES FROM information_schema') + cur.execute("SHOW TABLES FROM 
information_schema") cur.fetchall() assert_role_headers(cur, "system=ALL") @@ -1622,11 +1733,9 @@ def test_set_role_in_connection(run_trino): def test_set_system_role_in_connection(run_trino): _, host, port = run_trino - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch", roles="ALL" - ) + trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch", roles="ALL") cur = trino_connection.cursor() - cur.execute('SHOW TABLES FROM information_schema') + cur.execute("SHOW TABLES FROM information_schema") cur.fetchall() assert_role_headers(cur, "system=ALL") @@ -1639,37 +1748,37 @@ def assert_role_headers(cursor, expected_header): "legacy_prepared_statements", [ True, - pytest.param(None, marks=pytest.mark.skipif( - trino_version() > '417', - reason="This would use EXECUTE IMMEDIATE")) - ] + pytest.param( + None, marks=pytest.mark.skipif(trino_version() > "417", reason="This would use EXECUTE IMMEDIATE") + ), + ], ) def test_prepared_statements(legacy_prepared_statements, run_trino): cur = get_cursor(legacy_prepared_statements, run_trino) # Implicit prepared statements must work and deallocate statements on finish assert cur._request._client_session.prepared_statements == {} - cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,)) + cur.execute("SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?", (1,)) result = cur.fetchall() assert result[0][0] == 1 assert cur._request._client_session.prepared_statements == {} # Explicit prepared statements must also work - cur.execute('PREPARE test_prepared_statements FROM SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?') + cur.execute("PREPARE test_prepared_statements FROM SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?") cur.fetchall() - assert 'test_prepared_statements' in cur._request._client_session.prepared_statements - cur.execute('EXECUTE test_prepared_statements USING 1') + assert "test_prepared_statements" in cur._request._client_session.prepared_statements + cur.execute("EXECUTE test_prepared_statements USING 1") cur.fetchall() assert result[0][0] == 1 # An implicit prepared statement must not deallocate explicit statements - cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,)) + cur.execute("SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?", (1,)) result = cur.fetchall() assert result[0][0] == 1 - assert 'test_prepared_statements' in cur._request._client_session.prepared_statements + assert "test_prepared_statements" in cur._request._client_session.prepared_statements - assert 'test_prepared_statements' in cur._request._client_session.prepared_statements - cur.execute('DEALLOCATE PREPARE test_prepared_statements') + assert "test_prepared_statements" in cur._request._client_session.prepared_statements + cur.execute("DEALLOCATE PREPARE test_prepared_statements") cur.fetchall() assert cur._request._client_session.prepared_statements == {} @@ -1681,7 +1790,7 @@ def test_set_timezone_in_connection(run_trino): host=host, port=port, user="test", catalog="tpch", timezone="Europe/Brussels" ) cur = trino_connection.cursor() - cur.execute('SELECT current_timezone()') + cur.execute("SELECT current_timezone()") res = cur.fetchall() assert res[0][0] == "Europe/Brussels" @@ -1689,32 +1798,33 @@ def test_set_timezone_in_connection(run_trino): def test_connection_without_timezone(run_trino): _, host, port = run_trino - trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", 
catalog="tpch" - ) + trino_connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch") cur = trino_connection.cursor() - cur.execute('SELECT current_timezone()') + cur.execute("SELECT current_timezone()") res = cur.fetchall() session_tz = res[0][0] localzone = get_localzone_name() - assert session_tz == localzone or \ - (session_tz == "UTC" and localzone == "Etc/UTC") \ - # Workaround for difference between Trino timezone and tzlocal for UTC + assert session_tz == localzone or ( + session_tz == "UTC" and localzone == "Etc/UTC" + ) # Workaround for difference between Trino timezone and tzlocal for UTC def test_describe(run_trino): _, host, port = run_trino trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch", + host=host, + port=port, + user="test", + catalog="tpch", ) cur = trino_connection.cursor() result = cur.describe("SELECT 1, DECIMAL '1.0' as a") assert result == [ - DescribeOutput(name='_col0', catalog='', schema='', table='', type='integer', type_size=4, aliased=False), - DescribeOutput(name='a', catalog='', schema='', table='', type='decimal(2,1)', type_size=8, aliased=True) + DescribeOutput(name="_col0", catalog="", schema="", table="", type="integer", type_size=4, aliased=False), + DescribeOutput(name="a", catalog="", schema="", table="", type="decimal(2,1)", type_size=8, aliased=True), ] @@ -1722,7 +1832,10 @@ def test_describe_table_query(run_trino): _, host, port = run_trino trino_connection = trino.dbapi.Connection( - host=host, port=port, user="test", catalog="tpch", + host=host, + port=port, + user="test", + catalog="tpch", ) cur = trino_connection.cursor() @@ -1730,41 +1843,41 @@ def test_describe_table_query(run_trino): assert result == [ DescribeOutput( - name='nationkey', - catalog='tpch', - schema='tiny', - table='nation', - type='bigint', + name="nationkey", + catalog="tpch", + schema="tiny", + table="nation", + type="bigint", type_size=8, aliased=False, ), DescribeOutput( - name='name', - catalog='tpch', - schema='tiny', - table='nation', - type='varchar(25)', + name="name", + catalog="tpch", + schema="tiny", + table="nation", + type="varchar(25)", type_size=0, aliased=False, ), DescribeOutput( - name='regionkey', - catalog='tpch', - schema='tiny', - table='nation', - type='bigint', + name="regionkey", + catalog="tpch", + schema="tiny", + table="nation", + type="bigint", type_size=8, aliased=False, ), DescribeOutput( - name='comment', - catalog='tpch', - schema='tiny', - table='nation', - type='varchar(152)', + name="comment", + catalog="tpch", + schema="tiny", + table="nation", + type="varchar(152)", type_size=0, aliased=False, - ) + ), ] @@ -1781,10 +1894,10 @@ def test_rowcount_create_table(trino_connection): def test_rowcount_create_table_as_select(trino_connection): - with _TestTable( - trino_connection, - "memory.default.test_rowcount_ctas", "AS SELECT 1 a UNION ALL SELECT 2" - ) as (_, cur): + with _TestTable(trino_connection, "memory.default.test_rowcount_ctas", "AS SELECT 1 a UNION ALL SELECT 2") as ( + _, + cur, + ): assert cur.rowcount == 2 @@ -1798,11 +1911,14 @@ def test_rowcount_insert(trino_connection): "legacy_prepared_statements", [ True, - pytest.param(False, marks=pytest.mark.skipif( - trino_version() <= '417', - reason="EXECUTE IMMEDIATE was introduced in version 418")), - None - ] + pytest.param( + False, + marks=pytest.mark.skipif( + trino_version() <= "417", reason="EXECUTE IMMEDIATE was introduced in version 418" + ), + ), + None, + ], ) def 
test_prepared_statement_capability_autodetection(legacy_prepared_statements, run_trino): # start with an empty cache @@ -1853,14 +1969,11 @@ def assert_cursor_description(cur, trino_type, size=None, precision=None, scale= class _TestTable: def __init__(self, conn, table_name_prefix, table_definition) -> None: self._conn = conn - self._table_name = table_name_prefix + '_' + str(uuid.uuid4().hex) + self._table_name = table_name_prefix + "_" + str(uuid.uuid4().hex) self._table_definition = table_definition def __enter__(self) -> Tuple["_TestTable", Cursor]: - return ( - self, - self._conn.cursor().execute(f"CREATE TABLE {self._table_name} {self._table_definition}") - ) + return (self, self._conn.cursor().execute(f"CREATE TABLE {self._table_name} {self._table_definition}")) def __exit__(self, exc_type, exc_value, exc_tb) -> None: self._conn.cursor().execute(f"DROP TABLE {self._table_name}") diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 9d813365..60f39955 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -31,22 +31,20 @@ def trino_connection(run_trino, request): _, host, port = run_trino connect_args = {"source": "test", "max_attempts": 1} - if trino_version() <= '417': + if trino_version() <= "417": connect_args["legacy_prepared_statements"] = True - engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}", - connect_args=connect_args) + engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}", connect_args=connect_args) yield engine, engine.connect() @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_select_query(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) assert_column(nations, "nationkey", sqla.sql.sqltypes.BigInteger) assert_column(nations, "name", sqla.sql.sqltypes.String) assert_column(nations, "regionkey", sqla.sql.sqltypes.BigInteger) @@ -68,14 +66,13 @@ def assert_column(table, column_name, column_type): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['system'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["system"], indirect=True) def test_select_specific_columns(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nodes = sqla.Table('nodes', metadata, schema='runtime', autoload_with=conn) + nodes = sqla.Table("nodes", metadata, schema="runtime", autoload_with=conn) assert_column(nodes, "node_id", sqla.sql.sqltypes.String) assert_column(nodes, "state", sqla.sql.sqltypes.String) query = sqla.select(nodes.c.node_id, nodes.c.state) @@ -88,10 +85,9 @@ def test_select_specific_columns(trino_connection): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or 
other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_define_and_create_table(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): @@ -99,15 +95,17 @@ def test_define_and_create_table(trino_connection): connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) - assert sqla.inspect(engine).has_table('users', schema="test") - users = sqla.Table('users', metadata, schema='test', autoload_with=conn) + assert sqla.inspect(engine).has_table("users", schema="test") + users = sqla.Table("users", metadata, schema="test", autoload_with=conn) assert_column(users, "id", sqla.sql.sqltypes.Integer) assert_column(users, "name", sqla.sql.sqltypes.String) assert_column(users, "fullname", sqla.sql.sqltypes.String) @@ -116,10 +114,9 @@ def test_define_and_create_table(trino_connection): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_insert(trino_connection): engine, conn = trino_connection @@ -128,12 +125,14 @@ def test_insert(trino_connection): connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - users = sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + users = sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) ins = users.insert() conn.execute(ins, {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}) @@ -146,11 +145,8 @@ def test_insert(trino_connection): metadata.drop_all(engine) -@pytest.mark.skipif( - sqlalchemy_version() < "2.0", - reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above" -) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.skipif(sqlalchemy_version() < "2.0", reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above") +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_define_and_create_table_uuid(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): @@ -158,23 +154,17 @@ def test_define_and_create_table_uuid(trino_connection): connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - sqla.Table('users', - metadata, - sqla.Column('guid', sqla.Uuid), - schema="test") + sqla.Table("users", metadata, sqla.Column("guid", sqla.Uuid), schema="test") metadata.create_all(engine) - assert sqla.inspect(engine).has_table('users', schema="test") - users = 
sqla.Table('users', metadata, schema='test', autoload_with=conn) + assert sqla.inspect(engine).has_table("users", schema="test") + users = sqla.Table("users", metadata, schema="test", autoload_with=conn) assert_column(users, "guid", sqla.sql.sqltypes.Uuid) finally: metadata.drop_all(engine) -@pytest.mark.skipif( - sqlalchemy_version() < "2.0", - reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above" -) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.skipif(sqlalchemy_version() < "2.0", reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above") +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_insert_uuid(trino_connection): engine, conn = trino_connection @@ -183,10 +173,7 @@ def test_insert_uuid(trino_connection): connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() try: - users = sqla.Table('users', - metadata, - sqla.Column('guid', sqla.Uuid), - schema="test") + users = sqla.Table("users", metadata, sqla.Column("guid", sqla.Uuid), schema="test") metadata.create_all(engine) ins = users.insert() guid = uuid.uuid4() @@ -201,50 +188,55 @@ def test_insert_uuid(trino_connection): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_insert_multiple_statements(trino_connection): engine, conn = trino_connection if not engine.dialect.has_schema(conn, "test"): with engine.begin() as connection: connection.execute(sqla.schema.CreateSchema("test")) metadata = sqla.MetaData() - users = sqla.Table('users', - metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('name', sqla.String), - sqla.Column('fullname', sqla.String), - schema="test") + users = sqla.Table( + "users", + metadata, + sqla.Column("id", sqla.Integer), + sqla.Column("name", sqla.String), + sqla.Column("fullname", sqla.String), + schema="test", + ) metadata.create_all(engine) ins = users.insert() - conn.execute(ins, [ - {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}, - {"id": 3, "name": "john", "fullname": "John Doe"}, - {"id": 4, "name": "mary", "fullname": "Mary Hopkins"}, - ]) + conn.execute( + ins, + [ + {"id": 2, "name": "wendy", "fullname": "Wendy Williams"}, + {"id": 3, "name": "john", "fullname": "John Doe"}, + {"id": 4, "name": "mary", "fullname": "Mary Hopkins"}, + ], + ) query = sqla.select(users) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 3 - assert frozenset(rows) == frozenset([ - (2, "wendy", "Wendy Williams"), - (3, "john", "John Doe"), - (4, "mary", "Mary Hopkins"), - ]) + assert frozenset(rows) == frozenset( + [ + (2, "wendy", "Wendy Williams"), + (3, "john", "John Doe"), + (4, "mary", "Mary Hopkins"), + ] + ) metadata.drop_all(engine) @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_operators(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - customers = 
sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + customers = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) query = sqla.select(customers).where(customers.c.nationkey == 2) result = conn.execute(query) rows = result.fetchall() @@ -257,28 +249,27 @@ def test_operators(trino_connection): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_conjunctions(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) - query = sqla.select(customers).where(and_( - customers.c.name.like('%12%'), - customers.c.nationkey == 15, - or_( - customers.c.mktsegment == 'AUTOMOBILE', - customers.c.mktsegment == 'HOUSEHOLD' - ), - not_(customers.c.acctbal < 0))) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) + query = sqla.select(customers).where( + and_( + customers.c.name.like("%12%"), + customers.c.nationkey == 15, + or_(customers.c.mktsegment == "AUTOMOBILE", customers.c.mktsegment == "HOUSEHOLD"), + not_(customers.c.acctbal < 0), + ) + ) result = conn.execute(query) rows = result.fetchall() assert len(rows) == 1 -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_textual_sql(trino_connection): _, conn = trino_connection s = sqla.text("SELECT * from tiny.customer where nationkey = :e1 AND acctbal < :e2") @@ -297,38 +288,41 @@ def test_textual_sql(trino_connection): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_alias(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) nations1 = nations.alias("o1") nations2 = nations.alias("o2") - s = sqla.select(nations1) \ - .join(nations2, and_( - nations1.c.regionkey == nations2.c.regionkey, - nations1.c.nationkey != nations2.c.nationkey, - nations1.c.regionkey == 1 - )) \ + s = ( + sqla.select(nations1) + .join( + nations2, + and_( + nations1.c.regionkey == nations2.c.regionkey, + nations1.c.nationkey != nations2.c.nationkey, + nations1.c.regionkey == 1, + ), + ) .distinct() + ) result = conn.execute(s) rows = result.fetchall() assert len(rows) == 5 @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_subquery(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', 
metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900) automobile_customers_subquery = automobile_customers.subquery() s = sqla.select(nations.c.name).where(nations.c.nationkey.in_(sqla.select(automobile_customers_subquery))) @@ -338,34 +332,34 @@ def test_subquery(trino_connection): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_joins(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) - s = sqla.select(nations.c.name) \ - .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) \ - .where(customers.c.acctbal < -900) \ + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) + s = ( + sqla.select(nations.c.name) + .select_from(nations.join(customers, nations.c.nationkey == customers.c.nationkey)) + .where(customers.c.acctbal < -900) .distinct() + ) result = conn.execute(s) rows = result.fetchall() assert len(rows) == 15 @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["tpch"], indirect=True) def test_cte(trino_connection): _, conn = trino_connection metadata = sqla.MetaData() - nations = sqla.Table('nation', metadata, schema='tiny', autoload_with=conn) - customers = sqla.Table('customer', metadata, schema='tiny', autoload_with=conn) + nations = sqla.Table("nation", metadata, schema="tiny", autoload_with=conn) + customers = sqla.Table("customer", metadata, schema="tiny", autoload_with=conn) automobile_customers = sqla.select(customers.c.nationkey).where(customers.c.acctbal < -900) automobile_customers_cte = automobile_customers.cte() s = sqla.select(nations).where(nations.c.nationkey.in_(sqla.select(automobile_customers_cte))) @@ -375,19 +369,18 @@ def test_cte(trino_connection): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) @pytest.mark.parametrize( - 'trino_connection,json_object', + "trino_connection,json_object", [ - ('memory', None), - ('memory', 1), - ('memory', 'test'), - ('memory', [1, 'test']), - ('memory', {'test': 1}), + ("memory", None), + ("memory", 1), + ("memory", "test"), + ("memory", [1, "test"]), + ("memory", {"test": 1}), ], - indirect=['trino_connection'] + indirect=["trino_connection"], ) def 
test_json_column(trino_connection, json_object): engine, conn = trino_connection @@ -399,11 +392,11 @@ def test_json_column(trino_connection, json_object): try: table_with_json = sqla.Table( - 'table_with_json', + "table_with_json", metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('json_column', JSON), - schema="test" + sqla.Column("id", sqla.Integer), + sqla.Column("json_column", JSON), + schema="test", ) metadata.create_all(engine) ins = table_with_json.insert() @@ -419,30 +412,18 @@ def test_json_column(trino_connection, json_object): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_json_column_operations(trino_connection): engine, conn = trino_connection metadata = sqla.MetaData() - json_object = { - "a": {"c": 1}, - 100: {"z": 200}, - "b": 2, - 10: 20, - "foo-bar": {"z": 200} - } + json_object = {"a": {"c": 1}, 100: {"z": 200}, "b": 2, 10: 20, "foo-bar": {"z": 200}} try: - table_with_json = sqla.Table( - 'table_with_json', - metadata, - sqla.Column('json_column', JSON), - schema="default" - ) + table_with_json = sqla.Table("table_with_json", metadata, sqla.Column("json_column", JSON), schema="default") metadata.create_all(engine) ins = table_with_json.insert() conn.execute(ins, {"json_column": json_object}) @@ -477,33 +458,34 @@ def test_json_column_operations(trino_connection): query = sqla.select(table_with_json.c.json_column["foo-bar"]) conn.execute(query) result = conn.execute(query) - assert result.fetchall()[0][0] == {'z': 200} + assert result.fetchall()[0][0] == {"z": 200} finally: metadata.drop_all(engine) @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) @pytest.mark.parametrize( - 'trino_connection,map_object,sqla_type', + "trino_connection,map_object,sqla_type", [ - ('memory', None, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)), - ('memory', {}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)), - ('memory', {True: False, False: True}, MAP(sqla.sql.sqltypes.Boolean, sqla.sql.sqltypes.Boolean)), - ('memory', {1: 1, 2: None}, MAP(sqla.sql.sqltypes.Integer, sqla.sql.sqltypes.Integer)), - ('memory', {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.Float, sqla.sql.sqltypes.Float)), - ('memory', {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.REAL, sqla.sql.sqltypes.REAL)), - ('memory', - {Decimal("1.2"): Decimal("1.2")}, - MAP(sqla.sql.sqltypes.DECIMAL(2, 1), sqla.sql.sqltypes.DECIMAL(2, 1))), - ('memory', {"hello": "world"}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.String)), - ('memory', {"a ": "a", "null": "n"}, MAP(sqla.sql.sqltypes.CHAR(4), sqla.sql.sqltypes.CHAR(1))), - ('memory', {b'': b'eh?', b'\x00': None}, MAP(sqla.sql.sqltypes.BINARY, sqla.sql.sqltypes.BINARY)), + ("memory", None, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)), + ("memory", {}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)), + ("memory", {True: False, False: True}, MAP(sqla.sql.sqltypes.Boolean, sqla.sql.sqltypes.Boolean)), + ("memory", {1: 1, 2: None}, MAP(sqla.sql.sqltypes.Integer, 
sqla.sql.sqltypes.Integer)), + ("memory", {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.Float, sqla.sql.sqltypes.Float)), + ("memory", {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.REAL, sqla.sql.sqltypes.REAL)), + ( + "memory", + {Decimal("1.2"): Decimal("1.2")}, + MAP(sqla.sql.sqltypes.DECIMAL(2, 1), sqla.sql.sqltypes.DECIMAL(2, 1)), + ), + ("memory", {"hello": "world"}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.String)), + ("memory", {"a ": "a", "null": "n"}, MAP(sqla.sql.sqltypes.CHAR(4), sqla.sql.sqltypes.CHAR(1))), + ("memory", {b"": b"eh?", b"\x00": None}, MAP(sqla.sql.sqltypes.BINARY, sqla.sql.sqltypes.BINARY)), ], - indirect=['trino_connection'] + indirect=["trino_connection"], ) def test_map_column(trino_connection, map_object, sqla_type): engine, conn = trino_connection @@ -515,11 +497,11 @@ def test_map_column(trino_connection, map_object, sqla_type): try: table_with_map = sqla.Table( - 'table_with_map', + "table_with_map", metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('map_column', sqla_type), - schema="test" + sqla.Column("id", sqla.Integer), + sqla.Column("map_column", sqla_type), + schema="test", ) metadata.create_all(engine) ins = table_with_map.insert() @@ -534,23 +516,22 @@ def test_map_column(trino_connection, map_object, sqla_type): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) @pytest.mark.parametrize( - 'trino_connection,array_object,sqla_type', + "trino_connection,array_object,sqla_type", [ - ('memory', None, ARRAY(sqla.sql.sqltypes.String)), - ('memory', [], ARRAY(sqla.sql.sqltypes.String)), - ('memory', [True, False, True], ARRAY(sqla.sql.sqltypes.Boolean)), - ('memory', [1, 2, None], ARRAY(sqla.sql.sqltypes.Integer)), - ('memory', [1.4, 2.3, math.inf], ARRAY(sqla.sql.sqltypes.Float)), - ('memory', [Decimal("1.2"), Decimal("2.3")], ARRAY(sqla.sql.sqltypes.DECIMAL(2, 1))), - ('memory', ["hello", "world"], ARRAY(sqla.sql.sqltypes.String)), - ('memory', ["a ", "null"], ARRAY(sqla.sql.sqltypes.CHAR(4))), - ('memory', [b'eh?', None, b'\x00'], ARRAY(sqla.sql.sqltypes.BINARY)), + ("memory", None, ARRAY(sqla.sql.sqltypes.String)), + ("memory", [], ARRAY(sqla.sql.sqltypes.String)), + ("memory", [True, False, True], ARRAY(sqla.sql.sqltypes.Boolean)), + ("memory", [1, 2, None], ARRAY(sqla.sql.sqltypes.Integer)), + ("memory", [1.4, 2.3, math.inf], ARRAY(sqla.sql.sqltypes.Float)), + ("memory", [Decimal("1.2"), Decimal("2.3")], ARRAY(sqla.sql.sqltypes.DECIMAL(2, 1))), + ("memory", ["hello", "world"], ARRAY(sqla.sql.sqltypes.String)), + ("memory", ["a ", "null"], ARRAY(sqla.sql.sqltypes.CHAR(4))), + ("memory", [b"eh?", None, b"\x00"], ARRAY(sqla.sql.sqltypes.BINARY)), ], - indirect=['trino_connection'] + indirect=["trino_connection"], ) def test_array_column(trino_connection, array_object, sqla_type): engine, conn = trino_connection @@ -562,11 +543,11 @@ def test_array_column(trino_connection, array_object, sqla_type): try: table_with_array = sqla.Table( - 'table_with_array', + "table_with_array", metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('array_column', sqla_type), - schema="test" + sqla.Column("id", sqla.Integer), + sqla.Column("array_column", sqla_type), + schema="test", ) metadata.create_all(engine) ins = table_with_array.insert() @@ -581,32 +562,42 @@ def test_array_column(trino_connection, array_object, sqla_type): 
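# An illustrative sketch (hypothetical code, not part of this patch): the MAP
# and ARRAY cases above round-trip plain Python dicts and lists through the
# trino.sqlalchemy dialect. This assumes SQLAlchemy 1.4+ and a reachable Trino
# at localhost:8080 with the memory catalog; the table and values are made up
# for illustration only.
import sqlalchemy as sqla

from trino.sqlalchemy.datatype import MAP

engine = sqla.create_engine("trino://test@localhost:8080/memory")
metadata = sqla.MetaData()
scores = sqla.Table(
    "scores",
    metadata,
    sqla.Column("player", sqla.String),
    sqla.Column("rounds", sqla.ARRAY(sqla.Integer)),        # Trino ARRAY(integer)
    sqla.Column("totals", MAP(sqla.String, sqla.Integer)),  # Trino MAP(varchar, integer)
    schema="default",
)
metadata.create_all(engine)
with engine.connect() as conn:
    # Dicts and lists are passed through as-is, mirroring the tests above.
    conn.execute(scores.insert(), {"player": "a", "rounds": [10, 7], "totals": {"sum": 17}})
    print(conn.execute(sqla.select(scores.c.totals)).scalar_one())  # {'sum': 17}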
@pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) @pytest.mark.parametrize( - 'trino_connection,row_object,sqla_type', + "trino_connection,row_object,sqla_type", [ - ('memory', None, ROW([('field1', sqla.sql.sqltypes.String), - ('field2', sqla.sql.sqltypes.String)])), - ('memory', ('hello', 'world'), ROW([('field1', sqla.sql.sqltypes.String), - ('field2', sqla.sql.sqltypes.String)])), - ('memory', (True, False), ROW([('field1', sqla.sql.sqltypes.Boolean), - ('field2', sqla.sql.sqltypes.Boolean)])), - ('memory', (1, 2), ROW([('field1', sqla.sql.sqltypes.Integer), - ('field2', sqla.sql.sqltypes.Integer)])), - ('memory', (1.4, float('inf')), ROW([('field1', sqla.sql.sqltypes.Float), - ('field2', sqla.sql.sqltypes.Float)])), - ('memory', (Decimal("1.2"), Decimal("2.3")), ROW([('field1', sqla.sql.sqltypes.DECIMAL(2, 1)), - ('field2', sqla.sql.sqltypes.DECIMAL(3, 1))])), - ('memory', ("hello", "world"), ROW([('field1', sqla.sql.sqltypes.String), - ('field2', sqla.sql.sqltypes.String)])), - ('memory', ("a ", "null"), ROW([('field1', sqla.sql.sqltypes.CHAR(4)), - ('field2', sqla.sql.sqltypes.CHAR(4))])), - ('memory', (b'eh?', b'oh?'), ROW([('field1', sqla.sql.sqltypes.BINARY), - ('field2', sqla.sql.sqltypes.BINARY)])), + ("memory", None, ROW([("field1", sqla.sql.sqltypes.String), ("field2", sqla.sql.sqltypes.String)])), + ( + "memory", + ("hello", "world"), + ROW([("field1", sqla.sql.sqltypes.String), ("field2", sqla.sql.sqltypes.String)]), + ), + ("memory", (True, False), ROW([("field1", sqla.sql.sqltypes.Boolean), ("field2", sqla.sql.sqltypes.Boolean)])), + ("memory", (1, 2), ROW([("field1", sqla.sql.sqltypes.Integer), ("field2", sqla.sql.sqltypes.Integer)])), + ( + "memory", + (1.4, float("inf")), + ROW([("field1", sqla.sql.sqltypes.Float), ("field2", sqla.sql.sqltypes.Float)]), + ), + ( + "memory", + (Decimal("1.2"), Decimal("2.3")), + ROW([("field1", sqla.sql.sqltypes.DECIMAL(2, 1)), ("field2", sqla.sql.sqltypes.DECIMAL(3, 1))]), + ), + ( + "memory", + ("hello", "world"), + ROW([("field1", sqla.sql.sqltypes.String), ("field2", sqla.sql.sqltypes.String)]), + ), + ( + "memory", + ("a ", "null"), + ROW([("field1", sqla.sql.sqltypes.CHAR(4)), ("field2", sqla.sql.sqltypes.CHAR(4))]), + ), + ("memory", (b"eh?", b"oh?"), ROW([("field1", sqla.sql.sqltypes.BINARY), ("field2", sqla.sql.sqltypes.BINARY)])), ], - indirect=['trino_connection'] + indirect=["trino_connection"], ) def test_row_column(trino_connection, row_object, sqla_type): engine, conn = trino_connection @@ -618,11 +609,11 @@ def test_row_column(trino_connection, row_object, sqla_type): try: table_with_row = sqla.Table( - 'table_with_row', + "table_with_row", metadata, - sqla.Column('id', sqla.Integer), - sqla.Column('row_column', sqla_type), - schema="test" + sqla.Column("id", sqla.Integer), + sqla.Column("row_column", sqla_type), + schema="test", ) metadata.create_all(engine) ins = table_with_row.insert() @@ -636,7 +627,7 @@ def test_row_column(trino_connection, row_object, sqla_type): metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['system'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["system"], indirect=True) def test_get_catalog_names(trino_connection): engine, conn = trino_connection @@ -645,7 +636,7 @@ def test_get_catalog_names(trino_connection): assert set(schemas) == {"jmx", "memory", "system", 
"tpcds", "tpch"} -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_get_table_comment(trino_connection): engine, conn = trino_connection @@ -656,22 +647,22 @@ def test_get_table_comment(trino_connection): try: sqla.Table( - 'table_with_id', + "table_with_id", metadata, - sqla.Column('id', sqla.Integer), + sqla.Column("id", sqla.Integer), schema="test", # comment="This is a comment" TODO: Support comment creation through sqlalchemy api ) metadata.create_all(engine) insp = sqla.inspect(engine) - actual = insp.get_table_comment(table_name='table_with_id', schema="test") - assert actual['text'] is None + actual = insp.get_table_comment(table_name="table_with_id", schema="test") + assert actual["text"] is None finally: metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['memory/test'], indirect=True) -@pytest.mark.parametrize('schema', [None, 'test']) +@pytest.mark.parametrize("trino_connection", ["memory/test"], indirect=True) +@pytest.mark.parametrize("schema", [None, "test"]) def test_get_table_names(trino_connection, schema): engine, conn = trino_connection schema_name = schema or engine.dialect._get_default_schema_name(conn) @@ -683,20 +674,20 @@ def test_get_table_names(trino_connection, schema): try: sqla.Table( - 'test_get_table_names', + "test_get_table_names", metadata, - sqla.Column('id', sqla.Integer), + sqla.Column("id", sqla.Integer), ) metadata.create_all(engine) view_name = schema_name + ".test_view" conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_get_table_names")) - assert sqla.inspect(engine).get_table_names(schema_name) == ['test_get_table_names'] + assert sqla.inspect(engine).get_table_names(schema_name) == ["test_get_table_names"] finally: conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}")) metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_get_table_names_raises(trino_connection): engine, _ = trino_connection @@ -704,8 +695,8 @@ def test_get_table_names_raises(trino_connection): sqla.inspect(engine).get_table_names(None) -@pytest.mark.parametrize('trino_connection', ['memory/test'], indirect=True) -@pytest.mark.parametrize('schema', [None, 'test']) +@pytest.mark.parametrize("trino_connection", ["memory/test"], indirect=True) +@pytest.mark.parametrize("schema", [None, "test"]) def test_get_view_names(trino_connection, schema): engine, conn = trino_connection schema_name = schema or engine.dialect._get_default_schema_name(conn) @@ -717,20 +708,20 @@ def test_get_view_names(trino_connection, schema): try: sqla.Table( - 'test_table', + "test_table", metadata, - sqla.Column('id', sqla.Integer), + sqla.Column("id", sqla.Integer), ) metadata.create_all(engine) view_name = schema_name + ".test_get_view_names" conn.execute(sqla.text(f"CREATE VIEW {view_name} AS SELECT * FROM test_table")) - assert sqla.inspect(engine).get_view_names(schema_name) == ['test_get_view_names'] + assert sqla.inspect(engine).get_view_names(schema_name) == ["test_get_view_names"] finally: conn.execute(sqla.text(f"DROP VIEW IF EXISTS {view_name}")) metadata.drop_all(engine) -@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +@pytest.mark.parametrize("trino_connection", ["memory"], indirect=True) def test_get_view_names_raises(trino_connection): engine, _ = trino_connection @@ -738,8 +729,8 @@ def 
test_get_view_names_raises(trino_connection):
     sqla.inspect(engine).get_view_names(None)


-@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
-@pytest.mark.skipif(trino_version() == '351', reason="version() not supported in older Trino versions")
+@pytest.mark.parametrize("trino_connection", ["system"], indirect=True)
+@pytest.mark.skipif(trino_version() == "351", reason="version() not supported in older Trino versions")
 def test_version_is_lazy(trino_connection):
     _, conn = trino_connection
     result = conn.execute(sqla.text("SELECT 1"))
diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py
index 881619d0..4a789b85 100644
--- a/tests/integration/test_types_integration.py
+++ b/tests/integration/test_types_integration.py
@@ -21,136 +21,125 @@ def trino_connection(run_trino):
     _, host, port = run_trino
-    yield trino.dbapi.Connection(
-        host=host, port=port, user="test", source="test", max_attempts=1
-    )
+    yield trino.dbapi.Connection(host=host, port=port, user="test", source="test", max_attempts=1)


 def test_boolean(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS BOOLEAN)", python=None) \
-        .add_field(sql="false", python=False) \
-        .add_field(sql="true", python=True) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS BOOLEAN)", python=None).add_field(
+        sql="false", python=False
+    ).add_field(sql="true", python=True).execute()


 def test_tinyint(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS TINYINT)", python=None) \
-        .add_field(sql="CAST(-128 AS TINYINT)", python=-128) \
-        .add_field(sql="CAST(42 AS TINYINT)", python=42) \
-        .add_field(sql="CAST(127 AS TINYINT)", python=127) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS TINYINT)", python=None).add_field(
+        sql="CAST(-128 AS TINYINT)", python=-128
+    ).add_field(sql="CAST(42 AS TINYINT)", python=42).add_field(sql="CAST(127 AS TINYINT)", python=127).execute()


 def test_smallint(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS SMALLINT)", python=None) \
-        .add_field(sql="CAST(-32768 AS SMALLINT)", python=-32768) \
-        .add_field(sql="CAST(42 AS SMALLINT)", python=42) \
-        .add_field(sql="CAST(32767 AS SMALLINT)", python=32767) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS SMALLINT)", python=None).add_field(
+        sql="CAST(-32768 AS SMALLINT)", python=-32768
+    ).add_field(sql="CAST(42 AS SMALLINT)", python=42).add_field(sql="CAST(32767 AS SMALLINT)", python=32767).execute()


 def test_int(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS INTEGER)", python=None) \
-        .add_field(sql="CAST(-2147483648 AS INTEGER)", python=-2147483648) \
-        .add_field(sql="CAST(83648 AS INTEGER)", python=83648) \
-        .add_field(sql="CAST(2147483647 AS INTEGER)", python=2147483647) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS INTEGER)", python=None).add_field(
+        sql="CAST(-2147483648 AS INTEGER)", python=-2147483648
+    ).add_field(sql="CAST(83648 AS INTEGER)", python=83648).add_field(
+        sql="CAST(2147483647 AS INTEGER)", python=2147483647
+    ).execute()


 def test_bigint(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS BIGINT)", python=None) \
-        .add_field(sql="CAST(-9223372036854775808 AS BIGINT)", python=-9223372036854775808) \
-        .add_field(sql="CAST(9223 AS BIGINT)", python=9223) \
-        .add_field(sql="CAST(9223372036854775807 AS BIGINT)", python=9223372036854775807) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS BIGINT)", python=None).add_field(
+        sql="CAST(-9223372036854775808 AS BIGINT)", python=-9223372036854775808
+    ).add_field(sql="CAST(9223 AS BIGINT)", python=9223).add_field(
+        sql="CAST(9223372036854775807 AS BIGINT)", python=9223372036854775807
+    ).execute()


 def test_real(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS REAL)", python=None) \
-        .add_field(sql="CAST('NaN' AS REAL)", python=math.nan, has_nan=True) \
-        .add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf) \
-        .add_field(sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e+38) \
-        .add_field(sql="CAST(1.4E-45 AS REAL)", python=1.4e-45) \
-        .add_field(sql="CAST('Infinity' AS REAL)", python=math.inf) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS REAL)", python=None).add_field(
+        sql="CAST('NaN' AS REAL)", python=math.nan, has_nan=True
+    ).add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf).add_field(
+        sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e38
+    ).add_field(
+        sql="CAST(1.4E-45 AS REAL)", python=1.4e-45
+    ).add_field(
+        sql="CAST('Infinity' AS REAL)", python=math.inf
+    ).execute()


 def test_double(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS DOUBLE)", python=None) \
-        .add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan, has_nan=True) \
-        .add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf) \
-        .add_field(sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e+308) \
-        .add_field(sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324) \
-        .add_field(sql="CAST('Infinity' AS DOUBLE)", python=math.inf) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS DOUBLE)", python=None).add_field(
+        sql="CAST('NaN' AS DOUBLE)", python=math.nan, has_nan=True
+    ).add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf).add_field(
+        sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e308
+    ).add_field(
+        sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324
+    ).add_field(
+        sql="CAST('Infinity' AS DOUBLE)", python=math.inf
+    ).execute()


 def test_decimal(trino_connection):
-    SqlTest(trino_connection) \
-        .add_field(sql="CAST(null AS DECIMAL)", python=None) \
-        .add_field(sql="CAST(null AS DECIMAL(38,0))", python=None) \
-        .add_field(sql="DECIMAL '10.3'", python=Decimal('10.3')) \
-        .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,18))", python=Decimal('0.123456789123456789')) \
-        .add_field(sql="CAST(null AS DECIMAL(18,18))", python=None) \
-        .add_field(sql="CAST('234.123456789123456789' AS DECIMAL(18,4))", python=Decimal('234.1235')) \
-        .add_field(sql="CAST('10.3' AS DECIMAL(38,1))", python=Decimal('10.3')) \
-        .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,2))", python=Decimal('0.12')) \
-        .add_field(sql="CAST('0.3123' AS DECIMAL(38,38))", python=Decimal('0.3123')) \
-        .execute()
+    SqlTest(trino_connection).add_field(sql="CAST(null AS DECIMAL)", python=None).add_field(
+        sql="CAST(null AS DECIMAL(38,0))", python=None
+    ).add_field(sql="DECIMAL '10.3'", python=Decimal("10.3")).add_field(
+        sql="CAST('0.123456789123456789' AS DECIMAL(18,18))", python=Decimal("0.123456789123456789")
+    ).add_field(
+        sql="CAST(null AS DECIMAL(18,18))", python=None
+    ).add_field(
+        sql="CAST('234.123456789123456789' AS DECIMAL(18,4))", python=Decimal("234.1235")
+    ).add_field(
+        sql="CAST('10.3' AS DECIMAL(38,1))", python=Decimal("10.3")
+    ).add_field(
+        sql="CAST('0.123456789123456789' AS 
DECIMAL(18,2))", python=Decimal("0.12") + ).add_field( + sql="CAST('0.3123' AS DECIMAL(38,38))", python=Decimal("0.3123") + ).execute() def test_varchar(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="'aaa'", python='aaa') \ - .add_field(sql="U&'Hello winter \2603 !'", python='Hello winter °3 !') \ - .add_field(sql="CAST(null AS VARCHAR)", python=None) \ - .add_field(sql="CAST('bbb' AS VARCHAR(1))", python='b') \ - .add_field(sql="CAST(null AS VARCHAR(1))", python=None) \ - .execute() + SqlTest(trino_connection).add_field(sql="'aaa'", python="aaa").add_field( + sql="U&'Hello winter \2603 !'", python="Hello winter °3 !" + ).add_field(sql="CAST(null AS VARCHAR)", python=None).add_field( + sql="CAST('bbb' AS VARCHAR(1))", python="b" + ).add_field( + sql="CAST(null AS VARCHAR(1))", python=None + ).execute() def test_char(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="CAST('ccc' AS CHAR)", python='c') \ - .add_field(sql="CAST('ccc' AS CHAR(5))", python='ccc ') \ - .add_field(sql="CAST(null AS CHAR)", python=None) \ - .add_field(sql="CAST('ddd' AS CHAR(1))", python='d') \ - .add_field(sql="CAST('😂' AS CHAR(1))", python='😂') \ - .add_field(sql="CAST(null AS CHAR(1))", python=None) \ - .execute() + SqlTest(trino_connection).add_field(sql="CAST('ccc' AS CHAR)", python="c").add_field( + sql="CAST('ccc' AS CHAR(5))", python="ccc " + ).add_field(sql="CAST(null AS CHAR)", python=None).add_field(sql="CAST('ddd' AS CHAR(1))", python="d").add_field( + sql="CAST('😂' AS CHAR(1))", python="😂" + ).add_field( + sql="CAST(null AS CHAR(1))", python=None + ).execute() def test_varbinary(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="X'65683F'", python=b'eh?') \ - .add_field(sql="X''", python=b'') \ - .add_field(sql="CAST('' AS VARBINARY)", python=b'') \ - .add_field(sql="CAST(null AS VARBINARY)", python=None) \ - .execute() + SqlTest(trino_connection).add_field(sql="X'65683F'", python=b"eh?").add_field(sql="X''", python=b"").add_field( + sql="CAST('' AS VARBINARY)", python=b"" + ).add_field(sql="CAST(null AS VARBINARY)", python=None).execute() def test_varbinary_failure(trino_connection): - SqlExpectFailureTest(trino_connection) \ - .execute("CAST(42 AS VARBINARY)") + SqlExpectFailureTest(trino_connection).execute("CAST(42 AS VARBINARY)") def test_json(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="CAST('{}' AS JSON)", python='"{}"') \ - .add_field(sql="CAST('null' AS JSON)", python='"null"') \ - .add_field(sql="CAST(null AS JSON)", python=None) \ - .add_field(sql="CAST('3.14' AS JSON)", python='"3.14"') \ - .add_field(sql="CAST('a string' AS JSON)", python='"a string"') \ - .add_field(sql="CAST('a \" complex '' string :' AS JSON)", python='"a \\" complex \' string :"') \ - .add_field(sql="CAST('[]' AS JSON)", python='"[]"') \ - .execute() + SqlTest(trino_connection).add_field(sql="CAST('{}' AS JSON)", python='"{}"').add_field( + sql="CAST('null' AS JSON)", python='"null"' + ).add_field(sql="CAST(null AS JSON)", python=None).add_field(sql="CAST('3.14' AS JSON)", python='"3.14"').add_field( + sql="CAST('a string' AS JSON)", python='"a string"' + ).add_field( + sql="CAST('a \" complex '' string :' AS JSON)", python='"a \\" complex \' string :"' + ).add_field( + sql="CAST('[]' AS JSON)", python='"[]"' + ).execute() def test_date(trino_connection): @@ -168,544 +157,342 @@ def test_date(trino_connection): ).execute() -@pytest.mark.skipif(trino_version() == '351', reason="time not rounded correctly in older Trino versions") 
+@pytest.mark.skipif(trino_version() == "351", reason="time not rounded correctly in older Trino versions") def test_time(trino_connection): ( SqlTest(trino_connection) - .add_field( - sql="CAST(null AS TIME)", - python=None) - .add_field( - sql="TIME '00:00:00'", - python=time(0, 0, 0)) # min supported time(3) - .add_field( - sql="TIME '12:34:56.123'", - python=time(12, 34, 56, 123000)) - .add_field( - sql="TIME '23:59:59.999'", - python=time(23, 59, 59, 999000)) # max supported time(3) + .add_field(sql="CAST(null AS TIME)", python=None) + .add_field(sql="TIME '00:00:00'", python=time(0, 0, 0)) # min supported time(3) + .add_field(sql="TIME '12:34:56.123'", python=time(12, 34, 56, 123000)) + .add_field(sql="TIME '23:59:59.999'", python=time(23, 59, 59, 999000)) # max supported time(3) # min value for each precision - .add_field( - sql="TIME '00:00:00'", - python=time(0, 0, 0)) - .add_field( - sql="TIME '00:00:00.1'", - python=time(0, 0, 0, 100000)) - .add_field( - sql="TIME '00:00:00.01'", - python=time(0, 0, 0, 10000)) - .add_field( - sql="TIME '00:00:00.001'", - python=time(0, 0, 0, 1000)) - .add_field( - sql="TIME '00:00:00.0001'", - python=time(0, 0, 0, 100)) - .add_field( - sql="TIME '00:00:00.00001'", - python=time(0, 0, 0, 10)) - .add_field( - sql="TIME '00:00:00.000001'", - python=time(0, 0, 0, 1)) + .add_field(sql="TIME '00:00:00'", python=time(0, 0, 0)) + .add_field(sql="TIME '00:00:00.1'", python=time(0, 0, 0, 100000)) + .add_field(sql="TIME '00:00:00.01'", python=time(0, 0, 0, 10000)) + .add_field(sql="TIME '00:00:00.001'", python=time(0, 0, 0, 1000)) + .add_field(sql="TIME '00:00:00.0001'", python=time(0, 0, 0, 100)) + .add_field(sql="TIME '00:00:00.00001'", python=time(0, 0, 0, 10)) + .add_field(sql="TIME '00:00:00.000001'", python=time(0, 0, 0, 1)) # max value for each precision - .add_field( - sql="TIME '23:59:59'", - python=time(23, 59, 59)) - .add_field( - sql="TIME '23:59:59.9'", - python=time(23, 59, 59, 900000)) - .add_field( - sql="TIME '23:59:59.99'", - python=time(23, 59, 59, 990000)) - .add_field( - sql="TIME '23:59:59.999'", - python=time(23, 59, 59, 999000)) - .add_field( - sql="TIME '23:59:59.9999'", - python=time(23, 59, 59, 999900)) - .add_field( - sql="TIME '23:59:59.99999'", - python=time(23, 59, 59, 999990)) - .add_field( - sql="TIME '23:59:59.999999'", - python=time(23, 59, 59, 999999)) + .add_field(sql="TIME '23:59:59'", python=time(23, 59, 59)) + .add_field(sql="TIME '23:59:59.9'", python=time(23, 59, 59, 900000)) + .add_field(sql="TIME '23:59:59.99'", python=time(23, 59, 59, 990000)) + .add_field(sql="TIME '23:59:59.999'", python=time(23, 59, 59, 999000)) + .add_field(sql="TIME '23:59:59.9999'", python=time(23, 59, 59, 999900)) + .add_field(sql="TIME '23:59:59.99999'", python=time(23, 59, 59, 999990)) + .add_field(sql="TIME '23:59:59.999999'", python=time(23, 59, 59, 999999)) # round down - .add_field( - sql="TIME '12:34:56.1234561'", - python=time(12, 34, 56, 123456)) + .add_field(sql="TIME '12:34:56.1234561'", python=time(12, 34, 56, 123456)) # round down, min value - .add_field( - sql="TIME '00:00:00.000000000001'", - python=time(0, 0, 0, 0)) - .add_field( - sql="TIME '00:00:00.0000001'", - python=time(0, 0, 0, 0)) + .add_field(sql="TIME '00:00:00.000000000001'", python=time(0, 0, 0, 0)) + .add_field(sql="TIME '00:00:00.0000001'", python=time(0, 0, 0, 0)) # round down, max value - .add_field( - sql="TIME '00:00:00.0000004'", - python=time(0, 0, 0, 0)) - .add_field( - sql="TIME '00:00:00.00000049'", - python=time(0, 0, 0, 0)) - .add_field( - 
sql="TIME '00:00:00.0000005'", - python=time(0, 0, 0, 0)) - .add_field( - sql="TIME '00:00:00.00000050'", - python=time(0, 0, 0, 0)) - .add_field( - sql="TIME '23:59:59.9999994'", - python=time(23, 59, 59, 999999)) - .add_field( - sql="TIME '23:59:59.9999994999'", - python=time(23, 59, 59, 999999)) + .add_field(sql="TIME '00:00:00.0000004'", python=time(0, 0, 0, 0)) + .add_field(sql="TIME '00:00:00.00000049'", python=time(0, 0, 0, 0)) + .add_field(sql="TIME '00:00:00.0000005'", python=time(0, 0, 0, 0)) + .add_field(sql="TIME '00:00:00.00000050'", python=time(0, 0, 0, 0)) + .add_field(sql="TIME '23:59:59.9999994'", python=time(23, 59, 59, 999999)) + .add_field(sql="TIME '23:59:59.9999994999'", python=time(23, 59, 59, 999999)) # round up - .add_field( - sql="TIME '12:34:56.123456789'", - python=time(12, 34, 56, 123457)) + .add_field(sql="TIME '12:34:56.123456789'", python=time(12, 34, 56, 123457)) # round up, min value - .add_field( - sql="TIME '00:00:00.000000500001'", - python=time(0, 0, 0, 1)) + .add_field(sql="TIME '00:00:00.000000500001'", python=time(0, 0, 0, 1)) # round up, max value - .add_field( - sql="TIME '00:00:00.0000009'", - python=time(0, 0, 0, 1)) - .add_field( - sql="TIME '00:00:00.00000099999'", - python=time(0, 0, 0, 1)) + .add_field(sql="TIME '00:00:00.0000009'", python=time(0, 0, 0, 1)) + .add_field(sql="TIME '00:00:00.00000099999'", python=time(0, 0, 0, 1)) # round up to next day, min value - .add_field( - sql="TIME '23:59:59.9999995'", - python=time(0, 0, 0)) - .add_field( - sql="TIME '23:59:59.999999500001'", - python=time(0, 0, 0)) + .add_field(sql="TIME '23:59:59.9999995'", python=time(0, 0, 0)) + .add_field(sql="TIME '23:59:59.999999500001'", python=time(0, 0, 0)) # round up to next day, max value - .add_field( - sql="TIME '23:59:59.9999999'", - python=time(0, 0, 0)) - .add_field( - sql="TIME '23:59:59.999999999'", - python=time(0, 0, 0)) + .add_field(sql="TIME '23:59:59.9999999'", python=time(0, 0, 0)) + .add_field(sql="TIME '23:59:59.999999999'", python=time(0, 0, 0)) ).execute() -@pytest.mark.skipif(trino_version() == '351', reason="time not rounded correctly in older Trino versions") +@pytest.mark.skipif(trino_version() == "351", reason="time not rounded correctly in older Trino versions") @pytest.mark.parametrize( - 'tz_str', + "tz_str", [ - '-08:00', - '+08:00', - '+05:30', - ] + "-08:00", + "+08:00", + "+05:30", + ], ) def test_time_with_timezone(trino_connection, tz_str: str): tz = create_timezone(tz_str) ( SqlTest(trino_connection) - .add_field( - sql="CAST(null AS TIME WITH TIME ZONE)", - python=None) + .add_field(sql="CAST(null AS TIME WITH TIME ZONE)", python=None) # min supported time(3) - .add_field( - sql=f"TIME '00:00:00 {tz_str}'", - python=time(0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '12:34:56.123 {tz_str}'", - python=time(12, 34, 56, 123000, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00 {tz_str}'", python=time(0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '12:34:56.123 {tz_str}'", python=time(12, 34, 56, 123000, tzinfo=tz)) # max supported time(3) - .add_field( - sql=f"TIME '23:59:59.999 {tz_str}'", - python=time(23, 59, 59, 999000, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.999 {tz_str}'", python=time(23, 59, 59, 999000, tzinfo=tz)) # min value for each precision - .add_field( - sql=f"TIME '00:00:00 {tz_str}'", - python=time(0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.1 {tz_str}'", - python=time(0, 0, 0, 100000, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.01 {tz_str}'", - python=time(0, 0, 0, 10000, tzinfo=tz)) - 
.add_field( - sql=f"TIME '00:00:00.001 {tz_str}'", - python=time(0, 0, 0, 1000, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.0001 {tz_str}'", - python=time(0, 0, 0, 100, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.00001 {tz_str}'", - python=time(0, 0, 0, 10, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.000001 {tz_str}'", - python=time(0, 0, 0, 1, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00 {tz_str}'", python=time(0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.1 {tz_str}'", python=time(0, 0, 0, 100000, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.01 {tz_str}'", python=time(0, 0, 0, 10000, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.001 {tz_str}'", python=time(0, 0, 0, 1000, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.0001 {tz_str}'", python=time(0, 0, 0, 100, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.00001 {tz_str}'", python=time(0, 0, 0, 10, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.000001 {tz_str}'", python=time(0, 0, 0, 1, tzinfo=tz)) # max value for each precision - .add_field( - sql=f"TIME '23:59:59 {tz_str}'", - python=time(23, 59, 59, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.9 {tz_str}'", - python=time(23, 59, 59, 900000, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.99 {tz_str}'", - python=time(23, 59, 59, 990000, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.999 {tz_str}'", - python=time(23, 59, 59, 999000, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.9999 {tz_str}'", - python=time(23, 59, 59, 999900, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.99999 {tz_str}'", - python=time(23, 59, 59, 999990, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.999999 {tz_str}'", - python=time(23, 59, 59, 999999, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59 {tz_str}'", python=time(23, 59, 59, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.9 {tz_str}'", python=time(23, 59, 59, 900000, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.99 {tz_str}'", python=time(23, 59, 59, 990000, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.999 {tz_str}'", python=time(23, 59, 59, 999000, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.9999 {tz_str}'", python=time(23, 59, 59, 999900, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.99999 {tz_str}'", python=time(23, 59, 59, 999990, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.999999 {tz_str}'", python=time(23, 59, 59, 999999, tzinfo=tz)) # round down - .add_field( - sql=f"TIME '12:34:56.1234561 {tz_str}'", - python=time(12, 34, 56, 123456, tzinfo=tz)) + .add_field(sql=f"TIME '12:34:56.1234561 {tz_str}'", python=time(12, 34, 56, 123456, tzinfo=tz)) # round down, min value - .add_field( - sql=f"TIME '00:00:00.000000000001 {tz_str}'", - python=time(0, 0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.0000001 {tz_str}'", - python=time(0, 0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.000000000001 {tz_str}'", python=time(0, 0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.0000001 {tz_str}'", python=time(0, 0, 0, 0, tzinfo=tz)) # round down, max value - .add_field( - sql=f"TIME '00:00:00.0000004 {tz_str}'", - python=time(0, 0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.00000049 {tz_str}'", - python=time(0, 0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.0000005 {tz_str}'", - python=time(0, 0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.9999994 {tz_str}'", - python=time(23, 59, 59, 999999, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.9999994999 {tz_str}'", - python=time(23, 59, 59, 999999, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.0000004 {tz_str}'", 
python=time(0, 0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.00000049 {tz_str}'", python=time(0, 0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.0000005 {tz_str}'", python=time(0, 0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.9999994 {tz_str}'", python=time(23, 59, 59, 999999, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.9999994999 {tz_str}'", python=time(23, 59, 59, 999999, tzinfo=tz)) # round up - .add_field( - sql=f"TIME '12:34:56.123456789 {tz_str}'", - python=time(12, 34, 56, 123457, tzinfo=tz)) + .add_field(sql=f"TIME '12:34:56.123456789 {tz_str}'", python=time(12, 34, 56, 123457, tzinfo=tz)) # round up, min value - .add_field( - sql=f"TIME '00:00:00.000000500001 {tz_str}'", - python=time(0, 0, 0, 1, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.000000500001 {tz_str}'", python=time(0, 0, 0, 1, tzinfo=tz)) # round up, max value - .add_field( - sql=f"TIME '00:00:00.0000009 {tz_str}'", - python=time(0, 0, 0, 1, tzinfo=tz)) - .add_field( - sql=f"TIME '00:00:00.00000099999 {tz_str}'", - python=time(0, 0, 0, 1, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.0000009 {tz_str}'", python=time(0, 0, 0, 1, tzinfo=tz)) + .add_field(sql=f"TIME '00:00:00.00000099999 {tz_str}'", python=time(0, 0, 0, 1, tzinfo=tz)) # round up to next day, min value - .add_field( - sql=f"TIME '23:59:59.9999995 {tz_str}'", - python=time(0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.999999500001 {tz_str}'", - python=time(0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.9999995 {tz_str}'", python=time(0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.999999500001 {tz_str}'", python=time(0, 0, 0, tzinfo=tz)) # round up to next day, max value - .add_field( - sql=f"TIME '23:59:59.9999999 {tz_str}'", - python=time(0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIME '23:59:59.999999999 {tz_str}'", - python=time(0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.9999999 {tz_str}'", python=time(0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIME '23:59:59.999999999 {tz_str}'", python=time(0, 0, 0, tzinfo=tz)) ).execute() def test_timestamp(trino_connection): ( SqlTest(trino_connection) - .add_field( - sql="CAST(null AS TIMESTAMP)", - python=None) + .add_field(sql="CAST(null AS TIMESTAMP)", python=None) # min supported timestamp(3) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00'", - python=datetime(2001, 8, 22, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 12:34:56.123'", - python=datetime(2001, 8, 22, 12, 34, 56, 123000)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00'", python=datetime(2001, 8, 22, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 12:34:56.123'", python=datetime(2001, 8, 22, 12, 34, 56, 123000)) # max supported timestamp(3) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.999'", - python=datetime(2001, 8, 22, 23, 59, 59, 999000)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.999'", python=datetime(2001, 8, 22, 23, 59, 59, 999000)) # min value for each precision - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00'", - python=datetime(2001, 8, 22, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.1'", - python=datetime(2001, 8, 22, 0, 0, 0, 100000)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.01'", - python=datetime(2001, 8, 22, 0, 0, 0, 10000)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.001'", - python=datetime(2001, 8, 22, 0, 0, 0, 1000)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.0001'", - python=datetime(2001, 8, 22, 0, 0, 0, 100)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.00001'", - 
python=datetime(2001, 8, 22, 0, 0, 0, 10)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.000001'", - python=datetime(2001, 8, 22, 0, 0, 0, 1)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00'", python=datetime(2001, 8, 22, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.1'", python=datetime(2001, 8, 22, 0, 0, 0, 100000)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.01'", python=datetime(2001, 8, 22, 0, 0, 0, 10000)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.001'", python=datetime(2001, 8, 22, 0, 0, 0, 1000)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.0001'", python=datetime(2001, 8, 22, 0, 0, 0, 100)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.00001'", python=datetime(2001, 8, 22, 0, 0, 0, 10)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.000001'", python=datetime(2001, 8, 22, 0, 0, 0, 1)) # max value for each precision - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59'", - python=datetime(2001, 8, 22, 23, 59, 59)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.9'", - python=datetime(2001, 8, 22, 23, 59, 59, 900000)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.99'", - python=datetime(2001, 8, 22, 23, 59, 59, 990000)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.999'", - python=datetime(2001, 8, 22, 23, 59, 59, 999000)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.9999'", - python=datetime(2001, 8, 22, 23, 59, 59, 999900)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.99999'", - python=datetime(2001, 8, 22, 23, 59, 59, 999990)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.999999'", - python=datetime(2001, 8, 22, 23, 59, 59, 999999)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59'", python=datetime(2001, 8, 22, 23, 59, 59)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.9'", python=datetime(2001, 8, 22, 23, 59, 59, 900000)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.99'", python=datetime(2001, 8, 22, 23, 59, 59, 990000)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.999'", python=datetime(2001, 8, 22, 23, 59, 59, 999000)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.9999'", python=datetime(2001, 8, 22, 23, 59, 59, 999900)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.99999'", python=datetime(2001, 8, 22, 23, 59, 59, 999990)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.999999'", python=datetime(2001, 8, 22, 23, 59, 59, 999999)) # round down - .add_field( - sql="TIMESTAMP '2001-08-22 12:34:56.1234561'", - python=datetime(2001, 8, 22, 12, 34, 56, 123456)) + .add_field(sql="TIMESTAMP '2001-08-22 12:34:56.1234561'", python=datetime(2001, 8, 22, 12, 34, 56, 123456)) # round down, min value - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.000000000001'", - python=datetime(2001, 8, 22, 0, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.0000001'", - python=datetime(2001, 8, 22, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.000000000001'", python=datetime(2001, 8, 22, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.0000001'", python=datetime(2001, 8, 22, 0, 0, 0, 0)) # round down, max value - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.0000004'", - python=datetime(2001, 8, 22, 0, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.00000049'", - python=datetime(2001, 8, 22, 0, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.0000005'", - python=datetime(2001, 8, 22, 0, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.00000050'", - python=datetime(2001, 8, 22, 0, 0, 0, 0)) - .add_field( - sql="TIMESTAMP 
'2001-08-22 23:59:59.9999994'", - python=datetime(2001, 8, 22, 23, 59, 59, 999999)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.9999994999'", - python=datetime(2001, 8, 22, 23, 59, 59, 999999)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.0000004'", python=datetime(2001, 8, 22, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.00000049'", python=datetime(2001, 8, 22, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.0000005'", python=datetime(2001, 8, 22, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.00000050'", python=datetime(2001, 8, 22, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.9999994'", python=datetime(2001, 8, 22, 23, 59, 59, 999999)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.9999994999'", python=datetime(2001, 8, 22, 23, 59, 59, 999999)) # round up - .add_field( - sql="TIMESTAMP '2001-08-22 12:34:56.123456789'", - python=datetime(2001, 8, 22, 12, 34, 56, 123457)) + .add_field(sql="TIMESTAMP '2001-08-22 12:34:56.123456789'", python=datetime(2001, 8, 22, 12, 34, 56, 123457)) # round up, min value - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.000000500001'", - python=datetime(2001, 8, 22, 0, 0, 0, 1)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.000000500001'", python=datetime(2001, 8, 22, 0, 0, 0, 1)) # round up, max value - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.0000009'", - python=datetime(2001, 8, 22, 0, 0, 0, 1)) - .add_field( - sql="TIMESTAMP '2001-08-22 00:00:00.00000099999'", - python=datetime(2001, 8, 22, 0, 0, 0, 1)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.0000009'", python=datetime(2001, 8, 22, 0, 0, 0, 1)) + .add_field(sql="TIMESTAMP '2001-08-22 00:00:00.00000099999'", python=datetime(2001, 8, 22, 0, 0, 0, 1)) # round up to next day, min value - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.9999995'", - python=datetime(2001, 8, 23, 0, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.999999500001'", - python=datetime(2001, 8, 23, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.9999995'", python=datetime(2001, 8, 23, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.999999500001'", python=datetime(2001, 8, 23, 0, 0, 0, 0)) # round up to next day, max value - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.9999999'", - python=datetime(2001, 8, 23, 0, 0, 0, 0)) - .add_field( - sql="TIMESTAMP '2001-08-22 23:59:59.999999999'", - python=datetime(2001, 8, 23, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.9999999'", python=datetime(2001, 8, 23, 0, 0, 0, 0)) + .add_field(sql="TIMESTAMP '2001-08-22 23:59:59.999999999'", python=datetime(2001, 8, 23, 0, 0, 0, 0)) # ce - .add_field( - sql="TIMESTAMP '0001-01-01 01:23:45.123'", - python=datetime(1, 1, 1, 1, 23, 45, 123000)) + .add_field(sql="TIMESTAMP '0001-01-01 01:23:45.123'", python=datetime(1, 1, 1, 1, 23, 45, 123000)) # julian calendar - .add_field( - sql="TIMESTAMP '1582-10-04 01:23:45.123'", - python=datetime(1582, 10, 4, 1, 23, 45, 123000)) + .add_field(sql="TIMESTAMP '1582-10-04 01:23:45.123'", python=datetime(1582, 10, 4, 1, 23, 45, 123000)) # during switch - .add_field( - sql="TIMESTAMP '1582-10-05 01:23:45.123'", - python=datetime(1582, 10, 5, 1, 23, 45, 123000)) + .add_field(sql="TIMESTAMP '1582-10-05 01:23:45.123'", python=datetime(1582, 10, 5, 1, 23, 45, 123000)) # gregorian calendar - .add_field( - sql="TIMESTAMP '1582-10-14 01:23:45.123'", - python=datetime(1582, 10, 14, 1, 23, 45, 123000)) + .add_field(sql="TIMESTAMP '1582-10-14 01:23:45.123'", 
python=datetime(1582, 10, 14, 1, 23, 45, 123000)) ).execute() @pytest.mark.parametrize( - 'tz_str', + "tz_str", [ - '-08:00', - '+08:00', - 'US/Eastern', - 'Asia/Kolkata', - 'GMT', - ] + "-08:00", + "+08:00", + "US/Eastern", + "Asia/Kolkata", + "GMT", + ], ) def test_timestamp_with_timezone(trino_connection, tz_str): tz = create_timezone(tz_str) ( SqlTest(trino_connection) - .add_field( - sql="CAST(null AS TIMESTAMP WITH TIME ZONE)", - python=None) + .add_field(sql="CAST(null AS TIMESTAMP WITH TIME ZONE)", python=None) # min supported timestamp(3) with time zone - .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, tzinfo=tz)) + .add_field(sql=f"TIMESTAMP '2001-08-22 00:00:00 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, tzinfo=tz)) .add_field( sql=f"TIMESTAMP '2001-08-22 12:34:56.123 {tz_str}'", - python=datetime(2001, 8, 22, 12, 34, 56, 123000, tzinfo=tz)) + python=datetime(2001, 8, 22, 12, 34, 56, 123000, tzinfo=tz), + ) # max supported timestamp(3) with time zone .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.999 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 999000, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 999000, tzinfo=tz), + ) # min value for each precision + .add_field(sql=f"TIMESTAMP '2001-08-22 00:00:00 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, tzinfo=tz)) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, tzinfo=tz)) - .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.1 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 100000, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.1 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 100000, tzinfo=tz) + ) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.01 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 10000, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.01 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 10000, tzinfo=tz) + ) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.001 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 1000, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.001 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 1000, tzinfo=tz) + ) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.0001 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 100, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.0001 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 100, tzinfo=tz) + ) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.00001 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 10, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.00001 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 10, tzinfo=tz) + ) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.000001 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 1, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.000001 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 1, tzinfo=tz) + ) # max value for each precision - .add_field( - sql=f"TIMESTAMP '2001-08-22 23:59:59 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, tzinfo=tz)) + .add_field(sql=f"TIMESTAMP '2001-08-22 23:59:59 {tz_str}'", python=datetime(2001, 8, 22, 23, 59, 59, tzinfo=tz)) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.9 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 900000, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 900000, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.99 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 990000, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 
990000, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.999 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 999000, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 999000, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.9999 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 999900, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 999900, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.99999 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 999990, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 999990, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.999999 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 999999, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 999999, tzinfo=tz), + ) # round down .add_field( sql=f"TIMESTAMP '2001-08-22 12:34:56.1234561 {tz_str}'", - python=datetime(2001, 8, 22, 12, 34, 56, 123456, tzinfo=tz)) + python=datetime(2001, 8, 22, 12, 34, 56, 123456, tzinfo=tz), + ) # round down, min value .add_field( sql=f"TIMESTAMP '2001-08-22 00:00:00.000000000001 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz)) + python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz), + ) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.0000001 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.0000001 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz) + ) # round down, max value .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.0000004 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.0000004 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz) + ) .add_field( sql=f"TIMESTAMP '2001-08-22 00:00:00.00000049 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz)) + python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz), + ) .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.0000005 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.0000005 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz) + ) .add_field( sql=f"TIMESTAMP '2001-08-22 00:00:00.00000050 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz)) + python=datetime(2001, 8, 22, 0, 0, 0, 0, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.9999994 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 999999, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 999999, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.9999994999 {tz_str}'", - python=datetime(2001, 8, 22, 23, 59, 59, 999999, tzinfo=tz)) + python=datetime(2001, 8, 22, 23, 59, 59, 999999, tzinfo=tz), + ) # round up .add_field( sql=f"TIMESTAMP '2001-08-22 12:34:56.123456789 {tz_str}'", - python=datetime(2001, 8, 22, 12, 34, 56, 123457, tzinfo=tz)) + python=datetime(2001, 8, 22, 12, 34, 56, 123457, tzinfo=tz), + ) # round up, min value .add_field( sql=f"TIMESTAMP '2001-08-22 00:00:00.000000500001 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 1, tzinfo=tz)) + python=datetime(2001, 8, 22, 0, 0, 0, 1, tzinfo=tz), + ) # round up, max value .add_field( - sql=f"TIMESTAMP '2001-08-22 00:00:00.0000009 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 1, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 00:00:00.0000009 {tz_str}'", python=datetime(2001, 8, 22, 0, 0, 0, 1, tzinfo=tz) + ) .add_field( sql=f"TIMESTAMP '2001-08-22 00:00:00.00000099999 {tz_str}'", - python=datetime(2001, 8, 22, 0, 0, 0, 1, 
tzinfo=tz)) + python=datetime(2001, 8, 22, 0, 0, 0, 1, tzinfo=tz), + ) # round up to next day, min value .add_field( - sql=f"TIMESTAMP '2001-08-22 23:59:59.9999995 {tz_str}'", - python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 23:59:59.9999995 {tz_str}'", python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz) + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.999999500001 {tz_str}'", - python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz)) + python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz), + ) # round up to next day, max value .add_field( - sql=f"TIMESTAMP '2001-08-22 23:59:59.9999999 {tz_str}'", - python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz)) + sql=f"TIMESTAMP '2001-08-22 23:59:59.9999999 {tz_str}'", python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz) + ) .add_field( sql=f"TIMESTAMP '2001-08-22 23:59:59.999999999 {tz_str}'", - python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz)) + python=datetime(2001, 8, 23, 0, 0, 0, 0, tzinfo=tz), + ) # ce .add_field( - sql=f"TIMESTAMP '0001-01-01 01:23:45.123 {tz_str}'", - python=datetime(1, 1, 1, 1, 23, 45, 123000, tzinfo=tz)) + sql=f"TIMESTAMP '0001-01-01 01:23:45.123 {tz_str}'", python=datetime(1, 1, 1, 1, 23, 45, 123000, tzinfo=tz) + ) # Julian calendar .add_field( sql=f"TIMESTAMP '1582-10-04 01:23:45.123 {tz_str}'", - python=datetime(1582, 10, 4, 1, 23, 45, 123000, tzinfo=tz)) + python=datetime(1582, 10, 4, 1, 23, 45, 123000, tzinfo=tz), + ) # during switch .add_field( sql=f"TIMESTAMP '1582-10-05 01:23:45.123 {tz_str}'", - python=datetime(1582, 10, 5, 1, 23, 45, 123000, tzinfo=tz)) + python=datetime(1582, 10, 5, 1, 23, 45, 123000, tzinfo=tz), + ) # Gregorian calendar .add_field( sql=f"TIMESTAMP '1582-10-14 01:23:45.123 {tz_str}'", - python=datetime(1582, 10, 14, 1, 23, 45, 123000, tzinfo=tz)) + python=datetime(1582, 10, 14, 1, 23, 45, 123000, tzinfo=tz), + ) ).execute() @@ -717,16 +504,18 @@ def test_timestamp_with_timezone_dst(trino_connection): .add_field( sql=f"TIMESTAMP '2022-03-27 01:59:59.999999999 {tz_str}'", # 2:00:00 (STD) becomes 3:00:00 (DST)) - python=datetime(2022, 3, 27, 2, 0, 0, tzinfo=tz)) + python=datetime(2022, 3, 27, 2, 0, 0, tzinfo=tz), + ) .add_field( sql=f"TIMESTAMP '2022-10-30 02:59:59.999999999 {tz_str}'", # 3:00:00 (DST) becomes 2:00:00 (STD)) - python=datetime(2022, 10, 30, 3, 0, 0, tzinfo=tz)) + python=datetime(2022, 10, 30, 3, 0, 0, tzinfo=tz), + ) ).execute() def create_timezone(timezone_str: str) -> tzinfo: - if timezone_str.startswith('+') or timezone_str.startswith('-'): + if timezone_str.startswith("+") or timezone_str.startswith("-"): # Trino doesn't support sub-hour offsets # trino> select timestamp '2022-01-01 01:01:01.123 UTC+05:30:10'; # Query 20221118_120049_00002_vk3k4 failed: line 1:8: Time zone not supported: UTC+05:30:10 @@ -742,73 +531,41 @@ def create_timezone(timezone_str: str) -> tzinfo: def test_interval_year_to_month(trino_connection): ( SqlTest(trino_connection) - .add_field( - sql="CAST(null AS INTERVAL YEAR TO MONTH)", - python=None) - .add_field( - sql="INTERVAL '10' YEAR", - python=relativedelta(years=10)) - .add_field( - sql="INTERVAL '-5' YEAR", - python=relativedelta(years=-5)) - .add_field( - sql="INTERVAL '3' MONTH", - python=relativedelta(months=3)) - .add_field( - sql="INTERVAL '-18' MONTH", - python=relativedelta(years=-1, months=-6)) - .add_field( - sql="INTERVAL '30' MONTH", - python=relativedelta(years=2, months=6)) + .add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None) + .add_field(sql="INTERVAL '10' YEAR", 
python=relativedelta(years=10)) + .add_field(sql="INTERVAL '-5' YEAR", python=relativedelta(years=-5)) + .add_field(sql="INTERVAL '3' MONTH", python=relativedelta(months=3)) + .add_field(sql="INTERVAL '-18' MONTH", python=relativedelta(years=-1, months=-6)) + .add_field(sql="INTERVAL '30' MONTH", python=relativedelta(years=2, months=6)) # max supported INTERVAL in Trino - .add_field( - sql="INTERVAL '178956970-7' YEAR TO MONTH", - python=relativedelta(years=178956970, months=7)) + .add_field(sql="INTERVAL '178956970-7' YEAR TO MONTH", python=relativedelta(years=178956970, months=7)) # min supported INTERVAL in Trino - .add_field( - sql="INTERVAL '-178956970-8' YEAR TO MONTH", - python=relativedelta(years=-178956970, months=-8)) + .add_field(sql="INTERVAL '-178956970-8' YEAR TO MONTH", python=relativedelta(years=-178956970, months=-8)) ).execute() def test_interval_day_to_second(trino_connection): ( SqlTest(trino_connection) - .add_field( - sql="CAST(null AS INTERVAL DAY TO SECOND)", - python=None) - .add_field( - sql="INTERVAL '2' DAY", - python=timedelta(days=2)) - .add_field( - sql="INTERVAL '-2' DAY", - python=timedelta(days=-2)) - .add_field( - sql="INTERVAL '-2' SECOND", - python=timedelta(seconds=-2)) + .add_field(sql="CAST(null AS INTERVAL DAY TO SECOND)", python=None) + .add_field(sql="INTERVAL '2' DAY", python=timedelta(days=2)) + .add_field(sql="INTERVAL '-2' DAY", python=timedelta(days=-2)) + .add_field(sql="INTERVAL '-2' SECOND", python=timedelta(seconds=-2)) .add_field( sql="INTERVAL '1 11:11:11.116555' DAY TO SECOND", - python=timedelta(days=1, seconds=40271, microseconds=116000)) - .add_field( - sql="INTERVAL '-5 23:59:57.000' DAY TO SECOND", - python=timedelta(days=-6, seconds=3)) - .add_field( - sql="INTERVAL '12 10:45' DAY TO MINUTE", - python=timedelta(days=12, seconds=38700)) - .add_field( - sql="INTERVAL '45:32.123' MINUTE TO SECOND", - python=timedelta(seconds=2732, microseconds=123000)) - .add_field( - sql="INTERVAL '32.123' SECOND", - python=timedelta(seconds=32, microseconds=123000)) + python=timedelta(days=1, seconds=40271, microseconds=116000), + ) + .add_field(sql="INTERVAL '-5 23:59:57.000' DAY TO SECOND", python=timedelta(days=-6, seconds=3)) + .add_field(sql="INTERVAL '12 10:45' DAY TO MINUTE", python=timedelta(days=12, seconds=38700)) + .add_field(sql="INTERVAL '45:32.123' MINUTE TO SECOND", python=timedelta(seconds=2732, microseconds=123000)) + .add_field(sql="INTERVAL '32.123' SECOND", python=timedelta(seconds=32, microseconds=123000)) # max supported timedelta in Python .add_field( sql="INTERVAL '999999999 23:59:59.999' DAY TO SECOND", - python=timedelta(days=999999999, hours=23, minutes=59, seconds=59, milliseconds=999)) + python=timedelta(days=999999999, hours=23, minutes=59, seconds=59, milliseconds=999), + ) # min supported timedelta in Python - .add_field( - sql="INTERVAL '-999999999' DAY", - python=timedelta(days=-999999999)) + .add_field(sql="INTERVAL '-999999999' DAY", python=timedelta(days=-999999999)) ).execute() SqlExpectFailureTest(trino_connection).execute("INTERVAL '1000000000' DAY") @@ -817,48 +574,53 @@ def test_interval_day_to_second(trino_connection): def test_array(trino_connection): # primitive types - SqlTest(trino_connection) \ - .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \ - .add_field(sql="ARRAY[]", python=[]) \ - .add_field(sql="ARRAY[true, false, null]", python=[True, False, None]) \ - .add_field(sql="ARRAY[1, 2, null]", python=[1, 2, None]) \ - .add_field( - sql="ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS 
REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS " - "REAL), CAST('Infinity' AS REAL), null]", - python=[math.nan, -math.inf, 3.4028235e+38, 1.4e-45, math.inf, None], - has_nan=True) \ - .add_field( - sql="ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), " - "CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null]", - python=[math.nan, -math.inf, 1.7976931348623157e+308, 5e-324, math.inf, None], - has_nan=True) \ - .add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \ - .add_field(sql="ARRAY[CAST('hello' AS VARCHAR), null]", python=["hello", None]) \ - .add_field(sql="ARRAY[CAST('a' AS CHAR(3)), null]", python=['a ', None]) \ - .add_field(sql="ARRAY[X'', X'65683F', null]", python=[b'', b'eh?', None]) \ - .add_field(sql="ARRAY[JSON 'null', JSON '{}', null]", python=['null', '{}', None]) \ - .execute() + SqlTest(trino_connection).add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None).add_field( + sql="ARRAY[]", python=[] + ).add_field(sql="ARRAY[true, false, null]", python=[True, False, None]).add_field( + sql="ARRAY[1, 2, null]", python=[1, 2, None] + ).add_field( + sql="ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS " + "REAL), CAST('Infinity' AS REAL), null]", + python=[math.nan, -math.inf, 3.4028235e38, 1.4e-45, math.inf, None], + has_nan=True, + ).add_field( + sql="ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), " + "CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null]", + python=[math.nan, -math.inf, 1.7976931348623157e308, 5e-324, math.inf, None], + has_nan=True, + ).add_field( + sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None] + ).add_field( + sql="ARRAY[CAST('hello' AS VARCHAR), null]", python=["hello", None] + ).add_field( + sql="ARRAY[CAST('a' AS CHAR(3)), null]", python=["a ", None] + ).add_field( + sql="ARRAY[X'', X'65683F', null]", python=[b"", b"eh?", None] + ).add_field( + sql="ARRAY[JSON 'null', JSON '{}', null]", python=["null", "{}", None] + ).execute() # temporal types - SqlTest(trino_connection) \ - .add_field(sql="ARRAY[DATE '1970-01-01', null]", python=[date(1970, 1, 1), None]) \ - .add_field(sql="ARRAY[TIME '01:01:01', null]", python=[time(1, 1, 1), None]) \ - .add_field(sql="ARRAY[TIME '01:01:01 +05:30', null]", - python=[time(1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \ - .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01', null]", - python=[datetime(1970, 1, 1, 1, 1, 1), None]) \ - .add_field(sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null]", - python=[datetime(1970, 1, 1, 1, 1, 1, tzinfo=create_timezone("+05:30")), None]) \ - .execute() + SqlTest(trino_connection).add_field( + sql="ARRAY[DATE '1970-01-01', null]", python=[date(1970, 1, 1), None] + ).add_field(sql="ARRAY[TIME '01:01:01', null]", python=[time(1, 1, 1), None]).add_field( + sql="ARRAY[TIME '01:01:01 +05:30', null]", python=[time(1, 1, 1, tzinfo=create_timezone("+05:30")), None] + ).add_field( + sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01', null]", python=[datetime(1970, 1, 1, 1, 1, 1), None] + ).add_field( + sql="ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null]", + python=[datetime(1970, 1, 1, 1, 1, 1, tzinfo=create_timezone("+05:30")), None], + ).execute() # structural types - SqlTest(trino_connection) \ - .add_field(sql="ARRAY[ARRAY[1, null], ARRAY[2, 3], null]", python=[[1, None], [2, 3], None]) \ - .add_field( - sql="ARRAY[MAP(ARRAY['foo', 'bar', 'baz'], 
ARRAY['one', 'two', null]), MAP(), null]", - python=[{"foo": "one", "bar": "two", "baz": None}, {}, None]) \ - .add_field(sql="ARRAY[ROW(1, 2), ROW(1, null), null]", python=[(1, 2), (1, None), None]) \ - .execute() + SqlTest(trino_connection).add_field( + sql="ARRAY[ARRAY[1, null], ARRAY[2, 3], null]", python=[[1, None], [2, 3], None] + ).add_field( + sql="ARRAY[MAP(ARRAY['foo', 'bar', 'baz'], ARRAY['one', 'two', null]), MAP(), null]", + python=[{"foo": "one", "bar": "two", "baz": None}, {}, None], + ).add_field( + sql="ARRAY[ROW(1, 2), ROW(1, null), null]", python=[(1, 2), (1, None), None] + ).execute() def test_map(trino_connection): @@ -870,41 +632,61 @@ def test_map(trino_connection): .add_field(sql="MAP(ARRAY[true, false], ARRAY[false, true])", python={True: False, False: True}) .add_field(sql="MAP(ARRAY[true, false], ARRAY[true, null])", python={True: True, False: None}) .add_field(sql="MAP(ARRAY[1, 2], ARRAY[1, null])", python={1: 1, 2: None}) - .add_field(sql="MAP(" - "ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), CAST(1 AS REAL)], " # noqa: E501 - "ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), null])", # noqa: E501 - python={math.nan: math.nan, - -math.inf: -math.inf, - 3.4028235e+38: 3.4028235e+38, - 1.4e-45: 1.4e-45, - math.inf: math.inf, - 1: None}, - has_nan=True) - .add_field(sql="MAP(" - "ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST(1 AS DOUBLE)], " # noqa: E501 - "ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null])", # noqa: E501 - python={math.nan: math.nan, - -math.inf: -math.inf, - 1.7976931348623157e+308: 1.7976931348623157e+308, - 5e-324: 5e-324, - math.inf: math.inf, - 1: None}, - has_nan=True) - .add_field(sql="MAP(ARRAY[CAST('NaN' AS DOUBLE)], ARRAY[CAST('NaN' AS DOUBLE)])", - python={math.nan: math.nan}, - has_nan=True) - .add_field(sql="MAP(ARRAY[1.2, 2.4, 4.8], ARRAY[1.2, 2.4, null])", - python={Decimal("1.2"): Decimal("1.2"), Decimal("2.4"): Decimal("2.4"), Decimal("4.8"): None}) - .add_field(sql="MAP(" - "ARRAY[CAST('hello' AS VARCHAR), CAST('null' AS VARCHAR)], " - "ARRAY[CAST('hello' AS VARCHAR), null])", - python={'hello': 'hello', 'null': None}) - .add_field(sql="MAP(ARRAY[CAST('a' AS CHAR(4)), CAST('null' AS CHAR(4))], ARRAY[CAST('a' AS CHAR), null])", - python={'a ': 'a', 'null': None}) - .add_field(sql="MAP(ARRAY[X'', X'65683F', X'00'], ARRAY[X'', X'65683F', null])", - python={b'': b'', b'eh?': b'eh?', b'\x00': None}) - .add_field(sql="MAP(ARRAY[JSON '1', JSON '{}', JSON 'null'], ARRAY[JSON '1', JSON '{}', null])", - python={'1': '1', '{}': '{}', 'null': None}) + .add_field( + sql="MAP(" + "ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), CAST(1 AS REAL)], " # noqa: E501 + "ARRAY[CAST('NaN' AS REAL), CAST('-Infinity' AS REAL), CAST(3.4028235E38 AS REAL), CAST(1.4E-45 AS REAL), CAST('Infinity' AS REAL), null])", # noqa: E501 + python={ + math.nan: math.nan, + -math.inf: -math.inf, + 3.4028235e38: 3.4028235e38, + 1.4e-45: 1.4e-45, + math.inf: math.inf, + 1: None, + }, + has_nan=True, + ) + .add_field( + sql="MAP(" + "ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), 
CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST(1 AS DOUBLE)], " # noqa: E501 + "ARRAY[CAST('NaN' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST(1.7976931348623157E308 AS DOUBLE), CAST(4.9E-324 AS DOUBLE), CAST('Infinity' AS DOUBLE), null])", # noqa: E501 + python={ + math.nan: math.nan, + -math.inf: -math.inf, + 1.7976931348623157e308: 1.7976931348623157e308, + 5e-324: 5e-324, + math.inf: math.inf, + 1: None, + }, + has_nan=True, + ) + .add_field( + sql="MAP(ARRAY[CAST('NaN' AS DOUBLE)], ARRAY[CAST('NaN' AS DOUBLE)])", + python={math.nan: math.nan}, + has_nan=True, + ) + .add_field( + sql="MAP(ARRAY[1.2, 2.4, 4.8], ARRAY[1.2, 2.4, null])", + python={Decimal("1.2"): Decimal("1.2"), Decimal("2.4"): Decimal("2.4"), Decimal("4.8"): None}, + ) + .add_field( + sql="MAP(" + "ARRAY[CAST('hello' AS VARCHAR), CAST('null' AS VARCHAR)], " + "ARRAY[CAST('hello' AS VARCHAR), null])", + python={"hello": "hello", "null": None}, + ) + .add_field( + sql="MAP(ARRAY[CAST('a' AS CHAR(4)), CAST('null' AS CHAR(4))], ARRAY[CAST('a' AS CHAR), null])", + python={"a ": "a", "null": None}, + ) + .add_field( + sql="MAP(ARRAY[X'', X'65683F', X'00'], ARRAY[X'', X'65683F', null])", + python={b"": b"", b"eh?": b"eh?", b"\x00": None}, + ) + .add_field( + sql="MAP(ARRAY[JSON '1', JSON '{}', JSON 'null'], ARRAY[JSON '1', JSON '{}', null])", + python={"1": "1", "{}": "{}", "null": None}, + ) ).execute() # temporal types @@ -915,26 +697,32 @@ def test_map(trino_connection): time_2 = time(23, 59, 59) datetime_1 = datetime(1970, 1, 1, 1, 1, 1) datetime_2 = datetime(2023, 1, 1, 23, 59, 59) - SqlTest(trino_connection) \ - .add_field(sql="MAP(ARRAY[DATE '1970-01-01', DATE '2023-01-01'], ARRAY[DATE '1970-01-01', null])", - python={date(1970, 1, 1): date(1970, 1, 1), date(2023, 1, 1): None}) \ - .add_field(sql="MAP(ARRAY[TIME '01:01:01', TIME '23:59:59'], ARRAY[TIME '01:01:01', null])", - python={time_1: time_1, time_2: None}) \ - .add_field(sql="MAP(" - "ARRAY[TIME '01:01:01 +05:30', TIME '23:59:59 -05:00'], " - "ARRAY[TIME '01:01:01 +05:30', null])", - python={time_1.replace(tzinfo=tz_india): time_1.replace(tzinfo=tz_india), - time_2.replace(tzinfo=tz_new_york): None}) \ - .add_field(sql="MAP(" - "ARRAY[TIMESTAMP '1970-01-01 01:01:01', TIMESTAMP '2023-01-01 23:59:59'], " - "ARRAY[TIMESTAMP '1970-01-01 01:01:01', null])", - python={datetime_1: datetime_1, datetime_2: None}) \ - .add_field(sql="MAP(" - "ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', TIMESTAMP '2023-01-01 23:59:59 America/Los_Angeles'], " # noqa: E501 - "ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null])", - python={datetime_1.replace(tzinfo=tz_india): datetime_1.replace(tzinfo=tz_india), - datetime_2.replace(tzinfo=tz_los_angeles): None}) \ - .execute() + SqlTest(trino_connection).add_field( + sql="MAP(ARRAY[DATE '1970-01-01', DATE '2023-01-01'], ARRAY[DATE '1970-01-01', null])", + python={date(1970, 1, 1): date(1970, 1, 1), date(2023, 1, 1): None}, + ).add_field( + sql="MAP(ARRAY[TIME '01:01:01', TIME '23:59:59'], ARRAY[TIME '01:01:01', null])", + python={time_1: time_1, time_2: None}, + ).add_field( + sql="MAP(" "ARRAY[TIME '01:01:01 +05:30', TIME '23:59:59 -05:00'], " "ARRAY[TIME '01:01:01 +05:30', null])", + python={ + time_1.replace(tzinfo=tz_india): time_1.replace(tzinfo=tz_india), + time_2.replace(tzinfo=tz_new_york): None, + }, + ).add_field( + sql="MAP(" + "ARRAY[TIMESTAMP '1970-01-01 01:01:01', TIMESTAMP '2023-01-01 23:59:59'], " + "ARRAY[TIMESTAMP '1970-01-01 01:01:01', null])", + python={datetime_1: 
datetime_1, datetime_2: None}, + ).add_field( + sql="MAP(" + "ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', TIMESTAMP '2023-01-01 23:59:59 America/Los_Angeles'], " # noqa: E501 + "ARRAY[TIMESTAMP '1970-01-01 01:01:01 +05:30', null])", + python={ + datetime_1.replace(tzinfo=tz_india): datetime_1.replace(tzinfo=tz_india), + datetime_2.replace(tzinfo=tz_los_angeles): None, + }, + ).execute() # structural types - note that none of these below tests work in the Trino JDBC Driver either. # TODO: https://github.com/trinodb/trino-python-client/issues/442 @@ -949,40 +737,39 @@ def test_map(trino_connection): def test_row(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None) \ - .add_field(sql="CAST(ROW(1, 2e0, null) AS ROW(x BIGINT, y DOUBLE, z DOUBLE))", python=(1, 2.0, None)) \ - .execute() + SqlTest(trino_connection).add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None).add_field( + sql="CAST(ROW(1, 2e0, null) AS ROW(x BIGINT, y DOUBLE, z DOUBLE))", python=(1, 2.0, None) + ).execute() def test_ipaddress(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="CAST(null AS IPADDRESS)", python=None) \ - .add_field(sql="IPADDRESS '2001:db8::1'", python='2001:db8::1') \ - .execute() + SqlTest(trino_connection).add_field(sql="CAST(null AS IPADDRESS)", python=None).add_field( + sql="IPADDRESS '2001:db8::1'", python="2001:db8::1" + ).execute() def test_uuid(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="CAST(null AS UUID)", python=None) \ - .add_field(sql="UUID '12151fd2-7586-11e9-8f9e-2a86e4085a59'", - python=uuid.UUID('12151fd2-7586-11e9-8f9e-2a86e4085a59')) \ - .execute() + SqlTest(trino_connection).add_field(sql="CAST(null AS UUID)", python=None).add_field( + sql="UUID '12151fd2-7586-11e9-8f9e-2a86e4085a59'", python=uuid.UUID("12151fd2-7586-11e9-8f9e-2a86e4085a59") + ).execute() def test_digest(trino_connection): - SqlTest(trino_connection) \ - .add_field(sql="CAST(null AS HyperLogLog)", python=None) \ - .add_field(sql="CAST(null AS P4HyperLogLog)", python=None) \ - .add_field(sql="CAST(null AS SetDigest)", python=None) \ - .add_field(sql="CAST(null AS QDigest(BIGINT))", python=None) \ - .add_field(sql="CAST(null AS TDigest)", python=None) \ - .add_field(sql="approx_set(1)", python='AgwBAIADRAA=') \ - .add_field(sql="CAST(approx_set(1) AS P4HyperLogLog)", python='AwwAAAAg' + 'A' * 2730 + '==') \ - .add_field(sql="make_set_digest(1)", python='AQgAAAACCwEAgANEAAAgAAABAAAASsQF+7cDRAABAA==') \ - .add_field(sql="tdigest_agg(1)", - python='AAAAAAAAAPA/AAAAAAAA8D8AAAAAAABZQAAAAAAAAPA/AQAAAAAAAAAAAPA/AAAAAAAA8D8=') \ - .execute() + SqlTest(trino_connection).add_field(sql="CAST(null AS HyperLogLog)", python=None).add_field( + sql="CAST(null AS P4HyperLogLog)", python=None + ).add_field(sql="CAST(null AS SetDigest)", python=None).add_field( + sql="CAST(null AS QDigest(BIGINT))", python=None + ).add_field( + sql="CAST(null AS TDigest)", python=None + ).add_field( + sql="approx_set(1)", python="AgwBAIADRAA=" + ).add_field( + sql="CAST(approx_set(1) AS P4HyperLogLog)", python="AwwAAAAg" + "A" * 2730 + "==" + ).add_field( + sql="make_set_digest(1)", python="AQgAAAACCwEAgANEAAAgAAABAAAASsQF+7cDRAABAA==" + ).add_field( + sql="tdigest_agg(1)", python="AAAAAAAAAPA/AAAAAAAA8D8AAAAAAABZQAAAAAAAAPA/AQAAAAAAAAAAAPA/AAAAAAAA8D8=" + ).execute() class SqlTest: @@ -999,15 +786,14 @@ def add_field(self, sql, python, has_nan=False): return self def execute(self): - sql = 'SELECT ' + ',\n'.join(self.sql_args) + sql = 
"SELECT " + ",\n".join(self.sql_args) self.cur.execute(sql) actual_results = self.cur.fetchall() self._compare_results(actual_results[0], self.expected_results) def _are_equal_ignoring_nan(self, actual, expected) -> bool: - if isinstance(actual, float) and math.isnan(actual) \ - and isinstance(expected, float) and math.isnan(expected): + if isinstance(actual, float) and math.isnan(actual) and isinstance(expected, float) and math.isnan(expected): # Consider NaNs equal since we only want to make sure values round-trip return True return actual == expected @@ -1031,9 +817,11 @@ def _compare_results(self, actual_results, expected_results): # False # We create the NaNs using float("nan") which means PyTest's assert # will always fail on collections containing nan. - if (isinstance(actual, list) and isinstance(expected, list)) \ - or (isinstance(actual, set) and isinstance(expected, set)) \ - or (isinstance(actual, tuple) and isinstance(expected, tuple)): + if ( + (isinstance(actual, list) and isinstance(expected, list)) + or (isinstance(actual, set) and isinstance(expected, set)) + or (isinstance(actual, tuple) and isinstance(expected, tuple)) + ): for i, _ in enumerate(actual): if not self._are_equal_ignoring_nan(actual[i], expected[i]): # Will fail, here to provide useful assertion message @@ -1070,7 +858,7 @@ def __init__(self, trino_connection): self.cur = trino_connection.cursor(legacy_primitive_types=False) def execute(self, field): - sql = 'SELECT ' + field + sql = "SELECT " + field try: self.cur.execute(sql) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7360b109..576b0942 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -218,8 +218,7 @@ def sample_get_error_response_data(): "errorType": "USER_ERROR", "failureInfo": { "errorLocation": {"columnNumber": 15, "lineNumber": 1}, - "message": "line 1:15: Schema must be specified " - "when session schema is not set", + "message": "line 1:15: Schema must be specified " "when session schema is not set", "stack": [ "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)", "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)", @@ -298,4 +297,5 @@ def mock_get_and_post(): def sqlalchemy_version() -> str: import sqlalchemy + return sqlalchemy.__version__ diff --git a/tests/unit/oauth_test_utils.py b/tests/unit/oauth_test_utils.py index 956fee78..07064222 100644 --- a/tests/unit/oauth_test_utils.py +++ b/tests/unit/oauth_test_utils.py @@ -53,25 +53,28 @@ def __call__(self, request, uri, response_headers): if authorization and authorization.replace("Bearer ", "") in self.tokens: return [200, response_headers, json.dumps(self.sample_post_response_data)] elif self.redirect_server is None and self.token_server is not None: - return [401, - { - 'Www-Authenticate': ( - 'Bearer realm="Trino", token_type="JWT", ' - f'Bearer x_token_server="{self.token_server}"' - ), - 'Basic realm': '"Trino"' - }, - ""] - return [401, + return [ + 401, { - 'Www-Authenticate': ( - 'Bearer realm="Trino", token_type="JWT", ' - f'Bearer x_redirect_server="{self.redirect_server}", ' - f'x_token_server="{self.token_server}"' + "Www-Authenticate": ( + 'Bearer realm="Trino", token_type="JWT", ' f'Bearer x_token_server="{self.token_server}"' ), - 'Basic realm': '"Trino"' + "Basic realm": '"Trino"', }, - ""] + "", + ] + return [ + 401, + { + "Www-Authenticate": ( + 'Bearer realm="Trino", token_type="JWT", ' + f'Bearer x_redirect_server="{self.redirect_server}", ' + 
f'x_token_server="{self.token_server}"' + ), + "Basic realm": '"Trino"', + }, + "", + ] class GetTokenCallback: @@ -90,19 +93,19 @@ def __call__(self, request, uri, response_headers): def _get_token_requests(challenge_id): - return list(filter( - lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", - httpretty.latest_requests())) + return list( + filter(lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", httpretty.latest_requests()) + ) def _post_statement_requests(): - return list(filter( - lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, - httpretty.latest_requests())) + return list( + filter(lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, httpretty.latest_requests()) + ) class MultithreadedTokenServer: - Challenge = namedtuple('Challenge', ['token', 'attempts']) + Challenge = namedtuple("Challenge", ["token", "attempts"]) def __init__(self, sample_post_response_data, attempts=1): self.tokens = set() @@ -114,13 +117,13 @@ def __init__(self, sample_post_response_data, attempts=1): httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=self.post_statement_callback) + body=self.post_statement_callback, + ) # bind get token httpretty.register_uri( - method=httpretty.GET, - uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), - body=self.get_token_callback) + method=httpretty.GET, uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), body=self.get_token_callback + ) # noinspection PyUnusedLocal def post_statement_callback(self, request, uri, response_headers): @@ -135,9 +138,15 @@ def post_statement_callback(self, request, uri, response_headers): self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts) redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", ' - f'x_token_server="{token_server}"', - 'Basic realm': '"Trino"'}, ""] + return [ + 401, + { + "Www-Authenticate": f'Bearer x_redirect_server="{redirect_server}", ' + f'x_token_server="{token_server}"', + "Basic realm": '"Trino"', + }, + "", + ] # noinspection PyUnusedLocal def get_token_callback(self, request, uri, response_headers): diff --git a/tests/unit/sqlalchemy/conftest.py b/tests/unit/sqlalchemy/conftest.py index 4103b10a..9ea0ff3f 100644 --- a/tests/unit/sqlalchemy/conftest.py +++ b/tests/unit/sqlalchemy/conftest.py @@ -41,7 +41,7 @@ def _assert_sqltype(this: SQLType, that: SQLType): _assert_sqltype(this.value_type, that.value_type) elif isinstance(this, ROW): assert len(this.attr_types) == len(that.attr_types) - for (this_attr, that_attr) in zip(this.attr_types, that.attr_types): + for this_attr, that_attr in zip(this.attr_types, that.attr_types): assert this_attr[0] == that_attr[0] _assert_sqltype(this_attr[1], that_attr[1]) diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 4da5232e..24c259f5 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -27,18 +27,12 @@ metadata = MetaData() table_without_catalog = Table( - 'table', + "table", metadata, - Column('id', Integer), - Column('name', String), -) -table_with_catalog = Table( - 'table', - metadata, - Column('id', Integer), - schema='default', - trino_catalog='other' + Column("id", Integer), + Column("name", String), ) +table_with_catalog = Table("table", metadata, Column("id", Integer), 
schema="default", trino_catalog="other") @pytest.fixture @@ -47,8 +41,7 @@ def dialect(): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) def test_limit_offset(dialect): statement = select(table_without_catalog).limit(10).offset(0) @@ -57,8 +50,7 @@ def test_limit_offset(dialect): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) def test_limit(dialect): statement = select(table_without_catalog).limit(10) @@ -67,8 +59,7 @@ def test_limit(dialect): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) def test_offset(dialect): statement = select(table_without_catalog).offset(0) @@ -77,24 +68,23 @@ def test_offset(dialect): @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) def test_cte_insert_order(dialect): - cte = select(table_without_catalog).cte('cte') + cte = select(table_without_catalog).cte("cte") statement = insert(table_without_catalog).from_select(table_without_catalog.columns, cte) query = statement.compile(dialect=dialect) - assert str(query) == \ - 'INSERT INTO "table" (id, name) WITH cte AS \n'\ - '(SELECT "table".id AS id, "table".name AS name \n'\ - 'FROM "table")\n'\ - ' SELECT cte.id, cte.name \n'\ - 'FROM cte' + assert ( + str(query) == 'INSERT INTO "table" (id, name) WITH cte AS \n' + '(SELECT "table".id AS id, "table".name AS name \n' + 'FROM "table")\n' + " SELECT cte.id, cte.name \n" + "FROM cte" + ) @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) def test_catalogs_argument(dialect): statement = select(table_with_catalog) @@ -105,68 +95,62 @@ def test_catalogs_argument(dialect): def test_catalogs_create_table(dialect): statement = CreateTable(table_with_catalog) query = statement.compile(dialect=dialect) - assert str(query) == \ - '\n'\ - 'CREATE TABLE "other".default."table" (\n'\ - '\tid INTEGER\n'\ - ')\n'\ - '\n' + assert str(query) == "\n" 'CREATE TABLE "other".default."table" (\n' "\tid INTEGER\n" ")\n" "\n" @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) def test_table_clause(dialect): statement = select(table("user", column("id"), column("name"), column("description"))) query = statement.compile(dialect=dialect) - assert str(query) == 'SELECT user.id, user.name, user.description \nFROM user' + assert str(query) == "SELECT user.id, user.name, user.description \nFROM user" @pytest.mark.skipif( - sqlalchemy_version() < "1.4", - reason="columns argument to select() must be a Python list or other iterable" + 
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable" ) @pytest.mark.parametrize( - 'function,element', + "function,element", [ - ('first_value', func.first_value), - ('last_value', func.last_value), - ('nth_value', func.nth_value), - ('lead', func.lead), - ('lag', func.lag), - ] + ("first_value", func.first_value), + ("last_value", func.last_value), + ("nth_value", func.nth_value), + ("lead", func.lead), + ("lag", func.lag), + ], ) def test_ignore_nulls(dialect, function, element): statement = select( element( table_without_catalog.c.id, ignore_nulls=True, - ).over(partition_by=table_without_catalog.c.name).label('window') + ) + .over(partition_by=table_without_catalog.c.name) + .label("window") ) query = statement.compile(dialect=dialect) - assert str(query) == \ - f'SELECT {function}("table".id) IGNORE NULLS OVER (PARTITION BY "table".name) AS window '\ - f'\nFROM "table"' + assert ( + str(query) == f'SELECT {function}("table".id) IGNORE NULLS OVER (PARTITION BY "table".name) AS window ' + f'\nFROM "table"' + ) statement = select( element( table_without_catalog.c.id, ignore_nulls=False, - ).over(partition_by=table_without_catalog.c.name).label('window') + ) + .over(partition_by=table_without_catalog.c.name) + .label("window") ) query = statement.compile(dialect=dialect) - assert str(query) == \ - f'SELECT {function}("table".id) OVER (PARTITION BY "table".name) AS window ' \ - f'\nFROM "table"' + assert str(query) == f'SELECT {function}("table".id) OVER (PARTITION BY "table".name) AS window ' f'\nFROM "table"' -@pytest.mark.skipif( - sqlalchemy_version() < "2.0", - reason="ImportError: cannot import name 'try_cast' from 'sqlalchemy'" -) +@pytest.mark.skipif(sqlalchemy_version() < "2.0", reason="ImportError: cannot import name 'try_cast' from 'sqlalchemy'") def test_try_cast(dialect): from sqlalchemy import try_cast + statement = select(try_cast(table_without_catalog.c.id, String)) query = statement.compile(dialect=dialect) assert str(query) == 'SELECT try_cast("table".id as VARCHAR) AS id \nFROM "table"' diff --git a/tests/unit/sqlalchemy/test_datatype_parse.py b/tests/unit/sqlalchemy/test_datatype_parse.py index 99b4b50d..320f90b4 100644 --- a/tests/unit/sqlalchemy/test_datatype_parse.py +++ b/tests/unit/sqlalchemy/test_datatype_parse.py @@ -69,7 +69,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype): "CHAR(10)": CHAR(10), "VARCHAR(10)": VARCHAR(10), "DECIMAL(20)": DECIMAL(20), - "DECIMAL(20, 3)": DECIMAL(20, 3) + "DECIMAL(20, 3)": DECIMAL(20, 3), } @@ -188,7 +188,7 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype): "timestamp(3)": TIMESTAMP(3, timezone=False), "timestamp(6)": TIMESTAMP(6), "timestamp(12) with time zone": TIMESTAMP(12, timezone=True), - "timestamp with time zone": TIMESTAMP(timezone=True) + "timestamp with time zone": TIMESTAMP(timezone=True), } diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index 294c016f..e78f64aa 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -25,32 +25,38 @@ def setup_method(self): "url, generated_url, expected_args, expected_kwargs", [ ( - make_url(trino_url( - user="user", - host="localhost", - )), - 'trino://user@localhost:8080/?source=trino-sqlalchemy', + make_url( + trino_url( + user="user", + host="localhost", + ) + ), + "trino://user@localhost:8080/?source=trino-sqlalchemy", list(), dict(host="localhost", catalog="system", user="user", port=8080, 
source="trino-sqlalchemy"), ), ( - make_url(trino_url( - user="user", - host="localhost", - port=443, - )), - 'trino://user@localhost:443/?source=trino-sqlalchemy', + make_url( + trino_url( + user="user", + host="localhost", + port=443, + ) + ), + "trino://user@localhost:443/?source=trino-sqlalchemy", list(), dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"), ), ( - make_url(trino_url( - user="user", - password="pass", - host="localhost", - source="trino-rulez", - )), - 'trino://user:***@localhost:8080/?source=trino-rulez', + make_url( + trino_url( + user="user", + password="pass", + host="localhost", + source="trino-rulez", + ) + ), + "trino://user:***@localhost:8080/?source=trino-rulez", list(), dict( host="localhost", @@ -59,20 +65,22 @@ def setup_method(self): user="user", auth=BasicAuthentication("user", "pass"), http_scheme="https", - source="trino-rulez" + source="trino-rulez", ), ), ( - make_url(trino_url( - user="user", - host="localhost", - cert="/my/path/to/cert", - key="afdlsdfk%4#'", - )), - 'trino://user@localhost:8080/' - '?cert=%2Fmy%2Fpath%2Fto%2Fcert' - '&key=afdlsdfk%254%23%27' - '&source=trino-sqlalchemy', + make_url( + trino_url( + user="user", + host="localhost", + cert="/my/path/to/cert", + key="afdlsdfk%4#'", + ) + ), + "trino://user@localhost:8080/" + "?cert=%2Fmy%2Fpath%2Fto%2Fcert" + "&key=afdlsdfk%254%23%27" + "&source=trino-sqlalchemy", list(), dict( host="localhost", @@ -81,18 +89,18 @@ def setup_method(self): user="user", auth=CertificateAuthentication("/my/path/to/cert", "afdlsdfk%4#'"), http_scheme="https", - source="trino-sqlalchemy" + source="trino-sqlalchemy", ), ), ( - make_url(trino_url( - user="user", - host="localhost", - access_token="afdlsdfk%4#'", - )), - 'trino://user@localhost:8080/' - '?access_token=afdlsdfk%254%23%27' - '&source=trino-sqlalchemy', + make_url( + trino_url( + user="user", + host="localhost", + access_token="afdlsdfk%4#'", + ) + ), + "trino://user@localhost:8080/" "?access_token=afdlsdfk%254%23%27" "&source=trino-sqlalchemy", list(), dict( host="localhost", @@ -101,26 +109,28 @@ def setup_method(self): user="user", auth=JWTAuthentication("afdlsdfk%4#'"), http_scheme="https", - source="trino-sqlalchemy" + source="trino-sqlalchemy", ), ), ( - make_url(trino_url( - user="user", - host="localhost", - session_properties={"query_max_run_time": "1d"}, - http_headers={"trino": 1}, - extra_credential=[("a", "b"), ("c", "d")], - client_tags=["1", "sql"], - legacy_primitive_types=False, - )), - 'trino://user@localhost:8080/' - '?client_tags=%5B%221%22%2C+%22sql%22%5D' - '&extra_credential=%5B%5B%22a%22%2C+%22b%22%5D%2C+%5B%22c%22%2C+%22d%22%5D%5D' - '&http_headers=%7B%22trino%22%3A+1%7D' - '&legacy_primitive_types=false' - '&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D' - '&source=trino-sqlalchemy', + make_url( + trino_url( + user="user", + host="localhost", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[("a", "b"), ("c", "d")], + client_tags=["1", "sql"], + legacy_primitive_types=False, + ) + ), + "trino://user@localhost:8080/" + "?client_tags=%5B%221%22%2C+%22sql%22%5D" + "&extra_credential=%5B%5B%22a%22%2C+%22b%22%5D%2C+%5B%22c%22%2C+%22d%22%5D%5D" + "&http_headers=%7B%22trino%22%3A+1%7D" + "&legacy_primitive_types=false" + "&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D" + "&source=trino-sqlalchemy", list(), dict( host="localhost", @@ -137,30 +147,33 @@ def setup_method(self): ), # url encoding ( - 
make_url(trino_url( - user="user@test.org/my_role", - password="pass /*&", - host="localhost", - session_properties={"query_max_run_time": "1d"}, - http_headers={"trino": 1}, - extra_credential=[ - ("user1@test.org/my_role", "user2@test.org/my_role"), - ("user3@test.org/my_role", "user36@test.org/my_role")], - legacy_primitive_types=False, - client_tags=["1 @& /\"", "sql"], - verify=False, - )), - 'trino://user%40test.org%2Fmy_role:***@localhost:8080/' - '?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D' - '&extra_credential=%5B%5B%22user1%40test.org%2Fmy_role%22%2C+' - '%22user2%40test.org%2Fmy_role%22%5D%2C+' - '%5B%22user3%40test.org%2Fmy_role%22%2C+' - '%22user36%40test.org%2Fmy_role%22%5D%5D' - '&http_headers=%7B%22trino%22%3A+1%7D' - '&legacy_primitive_types=false' - '&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D' - '&source=trino-sqlalchemy' - '&verify=false', + make_url( + trino_url( + user="user@test.org/my_role", + password="pass /*&", + host="localhost", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[ + ("user1@test.org/my_role", "user2@test.org/my_role"), + ("user3@test.org/my_role", "user36@test.org/my_role"), + ], + legacy_primitive_types=False, + client_tags=['1 @& /"', "sql"], + verify=False, + ) + ), + "trino://user%40test.org%2Fmy_role:***@localhost:8080/" + "?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D" + "&extra_credential=%5B%5B%22user1%40test.org%2Fmy_role%22%2C+" + "%22user2%40test.org%2Fmy_role%22%5D%2C+" + "%5B%22user3%40test.org%2Fmy_role%22%2C+" + "%22user36%40test.org%2Fmy_role%22%5D%5D" + "&http_headers=%7B%22trino%22%3A+1%7D" + "&legacy_primitive_types=false" + "&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D" + "&source=trino-sqlalchemy" + "&verify=false", list(), dict( host="localhost", @@ -174,23 +187,26 @@ def setup_method(self): http_headers={"trino": 1}, extra_credential=[ ("user1@test.org/my_role", "user2@test.org/my_role"), - ("user3@test.org/my_role", "user36@test.org/my_role")], + ("user3@test.org/my_role", "user36@test.org/my_role"), + ], legacy_primitive_types=False, - client_tags=["1 @& /\"", "sql"], + client_tags=['1 @& /"', "sql"], verify=False, ), ), ( - make_url(trino_url( - user="user", - host="localhost", - roles={ - "hive": "finance", - "system": "analyst", - } - )), - 'trino://user@localhost:8080/' - '?roles=%7B%22hive%22%3A+%22finance%22%2C+%22system%22%3A+%22analyst%22%7D&source=trino-sqlalchemy', + make_url( + trino_url( + user="user", + host="localhost", + roles={ + "hive": "finance", + "system": "analyst", + }, + ) + ), + "trino://user@localhost:8080/" + "?roles=%7B%22hive%22%3A+%22finance%22%2C+%22system%22%3A+%22analyst%22%7D&source=trino-sqlalchemy", list(), dict( host="localhost", @@ -202,16 +218,18 @@ def setup_method(self): ), ), ( - make_url(trino_url( - user="user", - host="localhost", - client_tags=["1", "sql"], - legacy_prepared_statements=False, - )), - 'trino://user@localhost:8080/' - '?client_tags=%5B%221%22%2C+%22sql%22%5D' - '&legacy_prepared_statements=false' - '&source=trino-sqlalchemy', + make_url( + trino_url( + user="user", + host="localhost", + client_tags=["1", "sql"], + legacy_prepared_statements=False, + ) + ), + "trino://user@localhost:8080/" + "?client_tags=%5B%221%22%2C+%22sql%22%5D" + "&legacy_prepared_statements=false" + "&source=trino-sqlalchemy", list(), dict( host="localhost", @@ -226,11 +244,7 @@ def setup_method(self): ], ) def test_create_connect_args( - self, - url: URL, - generated_url: str, - 
expected_args: List[Any], - expected_kwargs: Dict[str, Any] + self, url: URL, generated_url: str, expected_args: List[Any], expected_kwargs: Dict[str, Any] ): assert repr(url) == generated_url @@ -265,62 +279,62 @@ def test_isolation_level(self): def test_trino_connection_basic_auth(): dialect = TrinoDialect() - username = 'trino-user' - password = 'trino-bunny' - url = make_url(f'trino://{username}:{password}@host') + username = "trino-user" + password = "trino-bunny" + url = make_url(f"trino://{username}:{password}@host") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], BasicAuthentication) - assert cparams['auth']._username == username - assert cparams['auth']._password == password + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], BasicAuthentication) + assert cparams["auth"]._username == username + assert cparams["auth"]._password == password def test_trino_connection_jwt_auth(): dialect = TrinoDialect() - access_token = 'sample-token' - url = make_url(f'trino://host/?access_token={access_token}') + access_token = "sample-token" + url = make_url(f"trino://host/?access_token={access_token}") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], JWTAuthentication) - assert cparams['auth'].token == access_token + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], JWTAuthentication) + assert cparams["auth"].token == access_token def test_trino_connection_certificate_auth(): dialect = TrinoDialect() - cert = '/path/to/cert.pem' - key = '/path/to/key.pem' - url = make_url(f'trino://host/?cert={cert}&key={key}') + cert = "/path/to/cert.pem" + key = "/path/to/key.pem" + url = make_url(f"trino://host/?cert={cert}&key={key}") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], CertificateAuthentication) - assert cparams['auth']._cert == cert - assert cparams['auth']._key == key + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], CertificateAuthentication) + assert cparams["auth"]._cert == cert + assert cparams["auth"]._key == key def test_trino_connection_certificate_auth_cert_and_key_required(): dialect = TrinoDialect() - cert = '/path/to/cert.pem' - key = '/path/to/key.pem' - url = make_url(f'trino://host/?cert={cert}') + cert = "/path/to/cert.pem" + key = "/path/to/key.pem" + url = make_url(f"trino://host/?cert={cert}") _, cparams = dialect.create_connect_args(url) - assert 'http_scheme' not in cparams - assert 'auth' not in cparams + assert "http_scheme" not in cparams + assert "auth" not in cparams - url = make_url(f'trino://host/?key={key}') + url = make_url(f"trino://host/?key={key}") _, cparams = dialect.create_connect_args(url) - assert 'http_scheme' not in cparams - assert 'auth' not in cparams + assert "http_scheme" not in cparams + assert "auth" not in cparams def test_trino_connection_oauth2_auth(): dialect = TrinoDialect() - url = make_url('trino://host/?externalAuthentication=true') + url = make_url("trino://host/?externalAuthentication=true") _, cparams = dialect.create_connect_args(url) - assert cparams['http_scheme'] == "https" - assert isinstance(cparams['auth'], OAuth2Authentication) + assert cparams["http_scheme"] == "https" + assert isinstance(cparams["auth"], OAuth2Authentication) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 653423a0..b277ed17 
100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -119,7 +119,7 @@ def test_request_headers(mock_get_and_post):
                 "catalog1": "NONE",
                 # ensure backwards compatibility
                 "catalog2": "ROLE{catalog2_role}",
-            }
+            },
         ),
         http_scheme="http",
     )
@@ -160,14 +160,7 @@ def test_request_session_properties_headers(mock_get_and_post):
     req = TrinoRequest(
         host="coordinator",
         port=8080,
-        client_session=ClientSession(
-            user="test_user",
-            properties={
-                "a": "1",
-                "b": "2",
-                "c": "more=v1,v2"
-            }
-        )
+        client_session=ClientSession(user="test_user", properties={"a": "1", "b": "2", "c": "more=v1,v2"}),
     )
 
     def assert_headers(headers):
@@ -202,10 +195,10 @@ def test_additional_request_post_headers(mock_get_and_post):
         http_scheme="http",
     )
 
-    sql = 'select 1'
+    sql = "select 1"
     additional_headers = {
-        'X-Trino-Fake-1': 'one',
-        'X-Trino-Fake-2': 'two',
+        "X-Trino-Fake-1": "one",
+        "X-Trino-Fake-2": "two",
     }
 
     combined_headers = req.http_headers
@@ -215,7 +208,7 @@ def test_additional_request_post_headers(mock_get_and_post):
 
     # Validate that the post call was performed including the additional headers
     _, post_kwargs = post.call_args
-    assert post_kwargs['headers'] == combined_headers
+    assert post_kwargs["headers"] == combined_headers
 
 
 def test_request_invalid_http_headers():
@@ -237,10 +230,7 @@ def test_request_client_tags_headers(mock_get_and_post):
     req = TrinoRequest(
         host="coordinator",
         port=8080,
-        client_session=ClientSession(
-            user="test_user",
-            client_tags=["tag1", "tag2"]
-        ),
+        client_session=ClientSession(user="test_user", client_tags=["tag1", "tag2"]),
     )
 
     def assert_headers(headers):
@@ -263,7 +253,7 @@ def test_request_client_tags_headers_no_client_tags(mock_get_and_post):
         port=8080,
         client_session=ClientSession(
             user="test_user",
-        )
+        ),
     )
 
     def assert_headers(headers):
@@ -387,16 +377,12 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
 
     # bind post statement
     httpretty.register_uri(
-        method=httpretty.POST,
-        uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement_callback
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token, attempts)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
@@ -407,10 +393,11 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
             user="test",
         ),
         http_scheme=constants.HTTPS,
-        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))
+        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
+    )
 
     response = request.post("select 1")
 
-    assert response.request.headers['Authorization'] == f"Bearer {token}"
+    assert response.request.headers["Authorization"] == f"Bearer {token}"
     assert redirect_handler.redirect_server == redirect_server
     assert get_token_callback.attempts == 0
     assert len(_post_statement_requests()) == 2
@@ -428,20 +415,16 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
 
     # bind post statement
     httpretty.register_uri(
-        method=httpretty.POST,
-        uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        body=post_statement_callback)
+        method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement_callback
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, 
token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandlerWithException( - trino.exceptions.TrinoAuthError( - "Do not use redirect handler when there is no redirect_uri in the response")) + trino.exceptions.TrinoAuthError("Do not use redirect handler when there is no redirect_uri in the response") + ) request = TrinoRequest( host="coordinator", @@ -450,11 +433,12 @@ def test_oauth2_refresh_token_flow(sample_post_response_data): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) response = request.post("select 1") - assert response.request.headers['Authorization'] == f"Bearer {token}" + assert response.request.headers["Authorization"] == f"Bearer {token}" assert get_token_callback.attempts == 0 assert len(_post_statement_requests()) == 2 @@ -472,16 +456,12 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): # bind post statement httpretty.register_uri( - method=httpretty.POST, - uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement_callback + ) # bind get token get_token_callback = GetTokenCallback(token_server, token, attempts) - httpretty.register_uri( - method=httpretty.GET, - uri=f"{TOKEN_RESOURCE}/{challenge_id}", - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=f"{TOKEN_RESOURCE}/{challenge_id}", body=get_token_callback) redirect_handler = RedirectHandler() @@ -492,7 +472,8 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") @@ -503,20 +484,27 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data): assert len(_get_token_requests(challenge_id)) == _OAuth2TokenBearer.MAX_OAUTH_ATTEMPTS -@pytest.mark.parametrize("header,error", [ - ("", "Error: header WWW-Authenticate not available in the response."), - ('Bearer"', 'Error: header info didn\'t have x_token_server'), - ('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501 - ('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'), -]) +@pytest.mark.parametrize( + "header,error", + [ + ("", "Error: header WWW-Authenticate not available in the response."), + ('Bearer"', "Error: header info didn't have x_token_server"), + ( + 'x_redirect_server="redirect_server", x_token_server="token_server"', + 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"', + ), # noqa: E501 + ('Bearer x_redirect_server="redirect_server"', "Error: header info didn't have x_token_server"), + ], +) @httprettified def test_oauth2_authentication_missing_headers(header, error): # bind post statement httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", 
-        adding_headers={'WWW-Authenticate': header},
-        status=401)
+        adding_headers={"WWW-Authenticate": header},
+        status=401,
+    )
 
     request = TrinoRequest(
         host="coordinator",
@@ -525,7 +513,8 @@ def test_oauth2_authentication_missing_headers(header, error):
             user="test",
         ),
         http_scheme=constants.HTTPS,
-        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler()))
+        auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=RedirectHandler()),
+    )
 
     with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
         request.post("select 1")
@@ -533,16 +522,19 @@ def test_oauth2_authentication_missing_headers(header, error):
     assert str(exp.value) == error
 
 
-@pytest.mark.parametrize("header", [
-    'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge',
-    'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"',
-    'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"',
-    'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"',
-    'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"',
-    'Basic realm="Trino", Bearer realm="Trino", token_type="JWT", Bearer x_redirect_server="{redirect_server}", '
-    'x_token_server="{token_server}"'
-    'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge',
-])
+@pytest.mark.parametrize(
+    "header",
+    [
+        'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge',
+        'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", additional_challenge="value"',
+        'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"',
+        'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"',
+        'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"',
+        'Basic realm="Trino", Bearer realm="Trino", token_type="JWT", Bearer x_redirect_server="{redirect_server}", '
+        'x_token_server="{token_server}"',
+        'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge',
+    ],
+)
 @httprettified
 def test_oauth2_header_parsing(header, sample_post_response_data):
     token = str(uuid.uuid4())
@@ -556,21 +548,23 @@ def post_statement(request, uri, response_headers):
         authorization = request.headers.get("Authorization")
         if authorization and authorization.replace("Bearer ", "") in token:
             return [200, response_headers, json.dumps(sample_post_response_data)]
-        return [401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server),
-                      'Basic realm': '"Trino"'}, ""]
+        return [
+            401,
+            {
+                "Www-Authenticate": header.format(redirect_server=redirect_server, token_server=token_server),
+                "Basic realm": '"Trino"',
+            },
+            "",
+        ]
 
     # bind post statement
     httpretty.register_uri(
-        method=httpretty.POST,
-        uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
-        body=post_statement)
+        method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement
+    )
 
     # bind get token
     get_token_callback = GetTokenCallback(token_server, token)
-    httpretty.register_uri(
-        method=httpretty.GET,
-        uri=token_server,
-        body=get_token_callback)
+    httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)
 
     redirect_handler = RedirectHandler()
 
@@ -581,10 +575,10 @@ def 
post_statement(request, uri, response_headers): user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), ).post("select 1") - assert response.request.headers['Authorization'] == f"Bearer {token}" + assert response.request.headers["Authorization"] == f"Bearer {token}" assert redirect_handler.redirect_server == redirect_server assert get_token_callback.attempts == 0 assert len(_post_statement_requests()) == 2 @@ -604,15 +598,12 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon # bind post statement httpretty.register_uri( - method=httpretty.POST, - uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + method=httpretty.POST, uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", body=post_statement_callback + ) httpretty.register_uri( - method=httpretty.GET, - uri=f"{TOKEN_RESOURCE}/{challenge_id}", - status=http_status, - body="error") + method=httpretty.GET, uri=f"{TOKEN_RESOURCE}/{challenge_id}", status=http_status, body="error" + ) redirect_handler = RedirectHandler() @@ -623,7 +614,8 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon user="test", ), http_scheme=constants.HTTPS, - auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)) + auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + ) with pytest.raises(trino.exceptions.TrinoAuthError) as exp: request.post("select 1") @@ -656,7 +648,8 @@ def run(self) -> None: user="test", ), http_scheme=constants.HTTPS, - auth=auth) + auth=auth, + ) for i in range(10): # apparently HTTPretty in the current version is not thread-safe # https://github.com/gabrielfalcao/HTTPretty/issues/209 @@ -768,10 +761,7 @@ def test_trino_fetch_error(mock_requests, sample_get_error_response_data): assert "stack" in error.failure_info assert len(error.failure_info["stack"]) == 36 assert "suppressed" in error.failure_info - assert ( - error.message - == "line 1:15: Schema must be specified when session schema is not set" - ) + assert error.message == "line 1:15: Schema must be specified when session schema is not set" assert error.error_location == (1, 15) assert error.query_id == "20210817_140827_00000_arvdv" @@ -887,10 +877,7 @@ def __str__(self): req = TrinoRequest( host="coordinator", port=constants.DEFAULT_TLS_PORT, - client_session=ClientSession( - user="test", - extra_credential=[("foo", credential)] - ) + client_session=ClientSession(user="test", extra_credential=[("foo", credential)]), ) req.post("SELECT 1") @@ -939,22 +926,36 @@ def _gssapi_sname(principal: str): "options, expected_credentials, expected_hostname, expected_exception", [ ( - {}, None, None, does_not_raise(), + {}, + None, + None, + does_not_raise(), ), ( - {"hostname_override": "foo"}, None, "foo", does_not_raise(), + {"hostname_override": "foo"}, + None, + "foo", + does_not_raise(), ), ( - {"service_name": "bar"}, None, None, + {"service_name": "bar"}, + None, + None, pytest.raises(ValueError, match=r"must be used together with hostname_override"), ), ( - {"hostname_override": "foo", "service_name": "bar"}, None, _gssapi_sname("bar@foo"), does_not_raise(), + {"hostname_override": "foo", "service_name": "bar"}, + None, + _gssapi_sname("bar@foo"), + does_not_raise(), ), ( - {"principal": "foo"}, MockGssapiCredentials(_gssapi_uname("foo"), "initial"), None, does_not_raise(), + 
{"principal": "foo"}, + MockGssapiCredentials(_gssapi_uname("foo"), "initial"), + None, + does_not_raise(), ), - ] + ], ) def test_authentication_gssapi_init_arguments( options, @@ -999,7 +1000,7 @@ def retry_count(self): [ (KerberosAuthentication, KerberosExchangeError), (GSSAPIAuthentication, SPNEGOExchangeError), - ] + ], ) def test_authentication_fail_retry(auth_class, retry_exception_class, monkeypatch): post_retry = RetryRecorder(error=retry_exception_class()) @@ -1030,11 +1031,14 @@ def test_authentication_fail_retry(auth_class, retry_exception_class, monkeypatc assert post_retry.retry_count == attempts -@pytest.mark.parametrize("status_code, attempts", [ - (502, 3), - (503, 3), - (504, 3), -]) +@pytest.mark.parametrize( + "status_code, attempts", + [ + (502, 3), + (503, 3), + (504, 3), + ], +) def test_5XX_error_retry(status_code, attempts, monkeypatch): http_resp = TrinoRequest.http.Response() http_resp.status_code = status_code @@ -1051,7 +1055,7 @@ def test_5XX_error_retry(status_code, attempts, monkeypatch): client_session=ClientSession( user="test", ), - max_attempts=attempts + max_attempts=attempts, ) req.post("URL") @@ -1078,7 +1082,7 @@ def test_429_error_retry(monkeypatch): client_session=ClientSession( user="test", ), - max_attempts=3 + max_attempts=3, ) req.post("URL") @@ -1088,9 +1092,7 @@ def test_429_error_retry(monkeypatch): assert post_retry.retry_count == 3 -@pytest.mark.parametrize("status_code", [ - 501 -]) +@pytest.mark.parametrize("status_code", [501]) def test_error_no_retry(status_code, monkeypatch): http_resp = TrinoRequest.http.Response() http_resp.status_code = status_code @@ -1147,8 +1149,8 @@ class MockResponse(mock.Mock): @property def headers(self): return { - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', + "X-Trino-Fake-1": "one", + "X-Trino-Fake-2": "two", } def json(self): @@ -1167,19 +1169,16 @@ def json(self): http_scheme="http", ) - sql = 'execute my_stament using 1, 2, 3' + sql = "execute my_stament using 1, 2, 3" additional_headers = { - constants.HEADER_PREPARED_STATEMENT: 'my_statement=added_prepare_statement_header', - constants.HEADER_CLIENT_CAPABILITIES: 'PARAMETRIC_DATETIME' + constants.HEADER_PREPARED_STATEMENT: "my_statement=added_prepare_statement_header", + constants.HEADER_CLIENT_CAPABILITIES: "PARAMETRIC_DATETIME", } # Patch the post function to avoid making the requests, as well as to # validate that the function was called with the right arguments. 
- with mock.patch.object(req, 'post', return_value=MockResponse()) as mock_post: - query = TrinoQuery( - request=req, - query=sql - ) + with mock.patch.object(req, "post", return_value=MockResponse()) as mock_post: + query = TrinoQuery(request=req, query=sql) result = query.execute(additional_http_headers=additional_headers) # Validate the the post function was called with the right argguments @@ -1256,10 +1255,7 @@ def test_request_headers_role_hive_all(mock_get_and_post): req = TrinoRequest( host="coordinator", port=8080, - client_session=ClientSession( - user="test_user", - roles={"hive": "ALL"} - ), + client_session=ClientSession(user="test_user", roles={"hive": "ALL"}), ) req.post("URL") @@ -1277,10 +1273,7 @@ def test_request_headers_role_admin(mock_get_and_post): req = TrinoRequest( host="coordinator", port=8080, - client_session=ClientSession( - user="test_user", - roles={"system": "admin"} - ), + client_session=ClientSession(user="test_user", roles={"system": "admin"}), ) roles = "system=" + urllib.parse.quote("ROLE{admin}") @@ -1324,10 +1317,7 @@ def test_request_headers_with_timezone(mock_get_and_post): req = TrinoRequest( host="coordinator", port=8080, - client_session=ClientSession( - user="test_user", - timezone="Europe/Brussels" - ), + client_session=ClientSession(user="test_user", timezone="Europe/Brussels"), ) req.post("URL") @@ -1365,10 +1355,7 @@ def test_request_with_invalid_timezone(mock_get_and_post): TrinoRequest( host="coordinator", port=8080, - client_session=ClientSession( - user="test_user", - timezone="INVALID_TIMEZONE" - ), + client_session=ClientSession(user="test_user", timezone="INVALID_TIMEZONE"), ) assert str(zinfo_error.value).startswith("'No time zone found with key") diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index ccd44d10..faf63c18 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -70,30 +70,27 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl # bind post statement to submit query httpretty.register_uri( - method=httpretty.POST, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback + ) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", - body=get_statement_callback) + body=get_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() with connect( - "coordinator", - user="test", - auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), - http_scheme=constants.HTTPS + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS, ) as conn: conn.cursor().execute("SELECT 1") conn.cursor().execute("SELECT 2") @@ -101,18 +98,15 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, 
body=get_token_callback) redirect_handler = RedirectHandler() with connect( - "coordinator", - user="test", - auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), - http_scheme=constants.HTTPS + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS, ) as conn2: conn2.cursor().execute("SELECT 1") conn2.cursor().execute("SELECT 2") @@ -122,8 +116,9 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl @httprettified -def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data, - sample_get_response_data): +def test_token_retrieved_once_when_authentication_instance_is_shared( + sample_post_response_data, sample_get_response_data +): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) @@ -135,50 +130,34 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post # bind post statement to submit query httpretty.register_uri( - method=httpretty.POST, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback + ) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", - body=get_statement_callback) + body=get_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) - with connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS - ) as conn: + with connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) as conn: conn.cursor().execute("SELECT 1") conn.cursor().execute("SELECT 2") conn.cursor().execute("SELECT 3") # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) - with connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS - ) as conn2: + with connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) as conn2: conn2.cursor().execute("SELECT 1") conn2.cursor().execute("SELECT 2") conn2.cursor().execute("SELECT 3") @@ -200,33 +179,25 @@ def test_token_retrieved_once_when_multithreaded(sample_post_response_data, samp # bind post statement to submit query httpretty.register_uri( - method=httpretty.POST, - uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback + ) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", - body=get_statement_callback) + body=get_statement_callback, + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) - httpretty.register_uri( - 
method=httpretty.GET, - uri=token_server, - body=get_token_callback) + httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback) redirect_handler = RedirectHandler() authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) - conn = connect( - "coordinator", - user="test", - auth=authentication, - http_scheme=constants.HTTPS - ) + conn = connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) class RunningThread(threading.Thread): lock = threading.Lock() @@ -238,11 +209,7 @@ def run(self) -> None: with RunningThread.lock: conn.cursor().execute("SELECT 1") - threads = [ - RunningThread(), - RunningThread(), - RunningThread() - ] + threads = [RunningThread(), RunningThread(), RunningThread()] # run and join all threads for thread in threads: @@ -313,4 +280,4 @@ def test_hostname_parsing(): def test_description_is_none_when_cursor_is_not_executed(): connection = Connection("sample_trino_cluster:443") cursor = connection.cursor() - assert hasattr(cursor, 'description') + assert hasattr(cursor, "description") diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index a9f00586..f89cae8d 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -27,9 +27,7 @@ def test_isolation_level_levels() -> None: def test_isolation_level_values() -> None: - values = { - 0, 1, 2, 3, 4 - } + values = {0, 1, 2, 3, 4} assert IsolationLevel.values() == values diff --git a/trino/auth.py b/trino/auth.py index 99c5d2d6..31560b2e 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -98,22 +98,24 @@ def get_exceptions(self) -> Tuple[Any, ...]: try: from requests_kerberos.exceptions import KerberosExchangeError - return KerberosExchangeError, + return (KerberosExchangeError,) except ImportError: raise RuntimeError("unable to import requests_kerberos") def __eq__(self, other: object) -> bool: if not isinstance(other, KerberosAuthentication): return False - return (self._config == other._config - and self._service_name == other._service_name - and self._mutual_authentication == other._mutual_authentication - and self._force_preemptive == other._force_preemptive - and self._hostname_override == other._hostname_override - and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response - and self._principal == other._principal - and self._delegate == other._delegate - and self._ca_bundle == other._ca_bundle) + return ( + self._config == other._config + and self._service_name == other._service_name + and self._mutual_authentication == other._mutual_authentication + and self._force_preemptive == other._force_preemptive + and self._hostname_override == other._hostname_override + and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response + and self._principal == other._principal + and self._delegate == other._delegate + and self._ca_bundle == other._ca_bundle + ) class GSSAPIAuthentication(Authentication): @@ -173,9 +175,9 @@ def _get_credentials(self, principal: Optional[str] = None) -> Any: return None def _get_target_name( - self, - hostname_override: Optional[str] = None, - service_name: Optional[str] = None, + self, + hostname_override: Optional[str] = None, + service_name: Optional[str] = None, ) -> Any: if service_name is not None: try: @@ -195,22 +197,24 @@ def get_exceptions(self) -> Tuple[Any, ...]: try: from requests_gssapi.exceptions import SPNEGOExchangeError - return SPNEGOExchangeError, + return (SPNEGOExchangeError,) except ImportError: raise 
RuntimeError("unable to import requests_kerberos") def __eq__(self, other: object) -> bool: if not isinstance(other, GSSAPIAuthentication): return False - return (self._config == other._config - and self._service_name == other._service_name - and self._mutual_authentication == other._mutual_authentication - and self._force_preemptive == other._force_preemptive - and self._hostname_override == other._hostname_override - and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response - and self._principal == other._principal - and self._delegate == other._delegate - and self._ca_bundle == other._ca_bundle) + return ( + self._config == other._config + and self._service_name == other._service_name + and self._mutual_authentication == other._mutual_authentication + and self._force_preemptive == other._force_preemptive + and self._hostname_override == other._hostname_override + and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response + and self._principal == other._principal + and self._delegate == other._delegate + and self._ca_bundle == other._ca_bundle + ) class BasicAuthentication(Authentication): @@ -353,8 +357,9 @@ def __init__(self) -> None: logger.info("keyring module not found. OAuth2 token will not be stored in keyring.") def is_keyring_available(self) -> bool: - return self._keyring is not None \ - and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring) + return self._keyring is not None and not isinstance( + self._keyring.get_keyring(), self._keyring.backends.fail.Keyring + ) def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: password = self._keyring.get_password(key, "token") @@ -370,9 +375,11 @@ def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: password += str(self._keyring.get_password(key, f"token__{i}")) except self._keyring.errors.NoKeyringError as e: - raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " - "detected, check https://pypi.org/project/keyring/ for more " - "information.") from e + raise trino.exceptions.NotSupportedError( + "Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information." + ) from e except ValueError: pass @@ -388,7 +395,7 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None: logger.debug(f"password is {len(token)} characters, sharding it.") password_shards = [ - token[i: i + MAX_NT_PASSWORD_SIZE] for i in range(0, len(token), MAX_NT_PASSWORD_SIZE) + token[i : i + MAX_NT_PASSWORD_SIZE] for i in range(0, len(token), MAX_NT_PASSWORD_SIZE) ] shard_info = { "sharded_password": True, @@ -401,15 +408,18 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None: for i, s in enumerate(password_shards): self._keyring.set_password(key, f"token__{i}", s) except self._keyring.errors.NoKeyringError as e: - raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " - "detected, check https://pypi.org/project/keyring/ for more " - "information.") from e + raise trino.exceptions.NotSupportedError( + "Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information." 
+ ) from e class _OAuth2TokenBearer(AuthBase): """ Custom implementation of Trino OAuth2 based authentication to get the token """ + MAX_OAUTH_ATTEMPTS = 5 _BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE) @@ -428,9 +438,9 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: token = self._get_token_from_cache(key) if token is not None: - r.headers['Authorization'] = "Bearer " + token + r.headers["Authorization"] = "Bearer " + token - r.register_hook('response', self._authenticate) + r.register_hook("response", self._authenticate) return r @@ -455,7 +465,7 @@ def _authenticate(self, response: Response, **kwargs: Any) -> Optional[Response] def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: # we have to handle the authentication, may be token the token expired, or it wasn't there at all - auth_info = response.headers.get('WWW-Authenticate') + auth_info = response.headers.get("WWW-Authenticate") if not auth_info: raise exceptions.TrinoAuthError("Error: header WWW-Authenticate not available in the response.") @@ -468,8 +478,8 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: # x_token_server="https://trino.com/oauth2/token/uuid4"' auth_info_headers = self._parse_authenticate_header(auth_info) - auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server')) - token_server = auth_info_headers.get('bearer x_token_server', auth_info_headers.get('x_token_server')) + auth_server = auth_info_headers.get("bearer x_redirect_server", auth_info_headers.get("x_redirect_server")) + token_server = auth_info_headers.get("bearer x_token_server", auth_info_headers.get("x_token_server")) if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server") @@ -500,7 +510,7 @@ def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response key = self._construct_cache_key(host, user) token = self._get_token_from_cache(key) if token is not None: - request.headers['Authorization'] = "Bearer " + token + request.headers["Authorization"] = "Bearer " + token retry_response = response.connection.send(request, **kwargs) retry_response.history.append(response) retry_response.request = request @@ -510,24 +520,24 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st attempts = 0 while attempts < self.MAX_OAUTH_ATTEMPTS: attempts += 1 - with response.connection.send(Request( - method='GET', url=token_server).prepare(), **kwargs) as response: + with response.connection.send(Request(method="GET", url=token_server).prepare(), **kwargs) as response: if response.status_code == 200: token_response = json.loads(response.text) - token = token_response.get('token') + token = token_response.get("token") if token: return token - error = token_response.get('error') + error = token_response.get("error") if error: raise exceptions.TrinoAuthError(f"Error while getting the token: {error}") else: - token_server = token_response.get('nextUri') + token_server = token_response.get("nextUri") logger.debug(f"nextURi auth token server: {token_server}") else: raise exceptions.TrinoAuthError( f"Error while getting the token response " f"status code: {response.status_code}, " - f"body: {response.text}") + f"body: {response.text}" + ) raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token") @@ -571,10 +581,12 @@ def _parse_authenticate_header(header: str) -> Dict[str, str]: class OAuth2Authentication(Authentication): - def __init__(self, 
redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([ - WebBrowserRedirectHandler(), - ConsoleRedirectHandler() - ])): + def __init__( + self, + redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler( + [WebBrowserRedirectHandler(), ConsoleRedirectHandler()] + ), + ): self._redirect_auth_url = redirect_auth_url_handler self._bearer = _OAuth2TokenBearer(self._redirect_auth_url) diff --git a/trino/client.py b/trino/client.py index 8dc4af80..d21eaf68 100644 --- a/trino/client.py +++ b/trino/client.py @@ -75,7 +75,7 @@ else: PROXIES = {} -_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') +_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$") ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$") @@ -250,10 +250,12 @@ def _format_roles(self, roles): is_legacy_role_pattern = ROLE_PATTERN.match(role) is not None if role in ("NONE", "ALL") or is_legacy_role_pattern: if is_legacy_role_pattern: - warnings.warn(f"A role '{role}' is provided using a legacy format. " - "Please remove the ROLE{} wrapping. Support for the legacy format might be " - "removed in a future release.", - DeprecationWarning) + warnings.warn( + f"A role '{role}' is provided using a legacy format. " + "Please remove the ROLE{} wrapping. Support for the legacy format might be " + "removed in a future release.", + DeprecationWarning, + ) formatted_roles[catalog] = role else: formatted_roles[catalog] = f"ROLE{{{role}}}" @@ -275,26 +277,17 @@ def get_header_values(headers, header): def get_session_property_values(headers, header): kvs = get_header_values(headers, header) - return [ - (k.strip(), urllib.parse.unquote_plus(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs if kv) - ] + return [(k.strip(), urllib.parse.unquote_plus(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs if kv)] def get_prepared_statement_values(headers, header): kvs = get_header_values(headers, header) - return [ - (k.strip(), urllib.parse.unquote_plus(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs if kv) - ] + return [(k.strip(), urllib.parse.unquote_plus(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs if kv)] def get_roles_values(headers, header): kvs = get_header_values(headers, header) - return [ - (k.strip(), urllib.parse.unquote_plus(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs if kv) - ] + return [(k.strip(), urllib.parse.unquote_plus(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs if kv)] @dataclass @@ -324,16 +317,14 @@ def __repr__(self): class _DelayExponential: - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min - ): + def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=1800): # 100ms # 30 min self._base = base self._exponent = exponent self._jitter = jitter self._max_delay = max_delay def __call__(self, attempt): - delay = float(self._base) * (self._exponent ** attempt) + delay = float(self._base) * (self._exponent**attempt) if self._jitter: delay *= random.random() delay = min(float(self._max_delay), delay) @@ -341,9 +332,7 @@ def __call__(self, attempt): class _RetryWithExponentialBackoff: - def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min - ): + def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=1800): # 100ms # 30 min self._get_delay = _DelayExponential(base, exponent, jitter, max_delay) def retry(self, func, args, kwargs, err, attempt): @@ -469,7 +458,7 @@ def http_headers(self) -> Dict[str, str]: 
headers[constants.HEADER_USER] = self._client_session.user headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user headers[constants.HEADER_TIMEZONE] = self._client_session.timezone - headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME' + headers[constants.HEADER_CLIENT_CAPABILITIES] = "PARAMETRIC_DATETIME" headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}" if len(self._client_session.roles.values()): headers[constants.HEADER_ROLE] = ",".join( @@ -502,8 +491,7 @@ def http_headers(self) -> Dict[str, str]: transaction_id = self._client_session.transaction_id headers[constants.HEADER_TRANSACTION] = transaction_id - if self._client_session.extra_credential is not None and \ - len(self._client_session.extra_credential) > 0: + if self._client_session.extra_credential is not None and len(self._client_session.extra_credential) > 0: for tup in self._client_session.extra_credential: self._verify_extra_credential(tup) @@ -511,10 +499,9 @@ def http_headers(self) -> Dict[str, str]: # HTTP 1.1 section 4.2 combine multiple extra credentials into a # comma-separated value # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format) - headers[constants.HEADER_EXTRA_CREDENTIAL] = \ - ", ".join( - [f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}" - for tup in self._client_session.extra_credential]) + headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join( + [f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}" for tup in self._client_session.extra_credential] + ) return headers @@ -623,15 +610,11 @@ def process(self, http_response) -> TrinoStatus: raise self._process_error(response["error"], response.get("id")) if constants.HEADER_CLEAR_SESSION in http_response.headers: - for prop in get_header_values( - http_response.headers, constants.HEADER_CLEAR_SESSION - ): + for prop in get_header_values(http_response.headers, constants.HEADER_CLEAR_SESSION): self._client_session.properties.pop(prop, None) if constants.HEADER_SET_SESSION in http_response.headers: - for key, value in get_session_property_values( - http_response.headers, constants.HEADER_SET_SESSION - ): + for key, value in get_session_property_values(http_response.headers, constants.HEADER_SET_SESSION): self._client_session.properties[key] = value if constants.HEADER_SET_CATALOG in http_response.headers: @@ -641,21 +624,15 @@ def process(self, http_response) -> TrinoStatus: self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA] if constants.HEADER_SET_ROLE in http_response.headers: - for key, value in get_roles_values( - http_response.headers, constants.HEADER_SET_ROLE - ): + for key, value in get_roles_values(http_response.headers, constants.HEADER_SET_ROLE): self._client_session.roles[key] = value if constants.HEADER_ADDED_PREPARE in http_response.headers: - for name, statement in get_prepared_statement_values( - http_response.headers, constants.HEADER_ADDED_PREPARE - ): + for name, statement in get_prepared_statement_values(http_response.headers, constants.HEADER_ADDED_PREPARE): self._client_session.prepared_statements[name] = statement if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers: - for name in get_header_values( - http_response.headers, constants.HEADER_DEALLOCATED_PREPARE - ): + for name in get_header_values(http_response.headers, constants.HEADER_DEALLOCATED_PREPARE): self._client_session.prepared_statements.pop(name, None) if constants.HEADER_SET_AUTHORIZATION_USER in http_response.headers: @@ -690,7 +667,7 
@@ def _verify_extra_credential(self, header): raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'") try: - key.encode().decode('ascii') + key.encode().decode("ascii") except UnicodeDecodeError: raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'") @@ -737,10 +714,10 @@ class TrinoQuery: """Represent the execution of a SQL statement by Trino.""" def __init__( - self, - request: TrinoRequest, - query: str, - legacy_primitive_types: bool = False, + self, + request: TrinoRequest, + query: str, + legacy_primitive_types: bool = False, ) -> None: self._query_id: Optional[str] = None self._stats: Dict[Any, Any] = {} @@ -837,8 +814,9 @@ def _update_state(self, status): self._update_count = status.update_count self._next_uri = status.next_uri if not self._row_mapper and status.columns: - self._row_mapper = RowMapperFactory().create(columns=status.columns, - legacy_primitive_types=self._legacy_primitive_types) + self._row_mapper = RowMapperFactory().create( + columns=status.columns, legacy_primitive_types=self._legacy_primitive_types + ) if status.columns: self._columns = status.columns @@ -877,6 +855,7 @@ def cancel(self) -> None: def is_finished(self) -> bool: import warnings + warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning) return self.finished diff --git a/trino/constants.py b/trino/constants.py index 8193f218..529a1aa5 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -48,9 +48,9 @@ HEADER_STARTED_TRANSACTION = "X-Trino-Started-Transaction-Id" HEADER_TRANSACTION = "X-Trino-Transaction-Id" -HEADER_PREPARED_STATEMENT = 'X-Trino-Prepared-Statement' -HEADER_ADDED_PREPARE = 'X-Trino-Added-Prepare' -HEADER_DEALLOCATED_PREPARE = 'X-Trino-Deallocated-Prepare' +HEADER_PREPARED_STATEMENT = "X-Trino-Prepared-Statement" +HEADER_ADDED_PREPARE = "X-Trino-Added-Prepare" +HEADER_DEALLOCATED_PREPARE = "X-Trino-Deallocated-Prepare" HEADER_SET_SCHEMA = "X-Trino-Set-Schema" HEADER_SET_CATALOG = "X-Trino-Set-Catalog" diff --git a/trino/dbapi.py b/trino/dbapi.py index fb24f867..34179ed8 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -87,6 +87,7 @@ class TimeBoundLRUCache: """A bounded LRU cache which expires entries after a configured number of seconds. 
Note that expired entries will be evicted only on an attempted access (or through the LRU policy).""" + def __init__(self, capacity: int, ttl_seconds: int): self.capacity = capacity self.ttl_seconds = ttl_seconds @@ -268,7 +269,7 @@ def cursor(self, legacy_primitive_types: bool = None): self, request, # if legacy params are not explicitly set in Cursor, take them from Connection - legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types + legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types, ) def _use_legacy_prepared_statements(self): @@ -278,15 +279,16 @@ def _use_legacy_prepared_statements(self): value = must_use_legacy_prepared_statements.get((self.host, self.port)) if value is None: try: - query = trino.client.TrinoQuery( - self._create_request(), - query="EXECUTE IMMEDIATE 'SELECT 1'") + query = trino.client.TrinoQuery(self._create_request(), query="EXECUTE IMMEDIATE 'SELECT 1'") query.execute() value = False except Exception as e: logger.warning( "EXECUTE IMMEDIATE not available for %s:%s; defaulting to legacy prepared statements (%s)", - self.host, self.port, e) + self.host, + self.port, + e, + ) value = True must_use_legacy_prepared_statements.put((self.host, self.port), value) return value @@ -327,7 +329,7 @@ def from_column(cls, column: Dict[str, Any]): arguments[0]["value"] if raw_type in LENGTH_TYPES else None, # internal_size arguments[0]["value"] if raw_type in PRECISION_TYPES else None, # precision arguments[1]["value"] if raw_type in SCALE_TYPES else None, # scale - None # null_ok + None, # null_ok ) @@ -339,15 +341,9 @@ class Cursor: """ - def __init__( - self, - connection, - request, - legacy_primitive_types: bool = False): + def __init__(self, connection, request, legacy_primitive_types: bool = False): if not isinstance(connection, Connection): - raise ValueError( - "connection must be a Connection object: {}".format(type(connection)) - ) + raise ValueError("connection must be a Connection object: {}".format(type(connection))) self._connection = connection self._request = request @@ -381,9 +377,7 @@ def description(self) -> List[ColumnDescription]: return None # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ] - return [ - ColumnDescription.from_column(col) for col in self._query.columns - ] + return [ColumnDescription.from_column(col) for col in self._query.columns] @property def rowcount(self): @@ -442,16 +436,13 @@ def _prepare_statement(self, statement: str, name: str) -> None: :param name: name that will be assigned to the prepared statement. 
""" sql = f"PREPARE {name} FROM {statement}" - query = trino.client.TrinoQuery(self.connection._create_request(), query=sql, - legacy_primitive_types=self._legacy_primitive_types) + query = trino.client.TrinoQuery( + self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types + ) query.execute() - def _execute_prepared_statement( - self, - statement_name, - params - ): - sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) + def _execute_prepared_statement(self, statement_name, params): + sql = "EXECUTE " + statement_name + " USING " + ",".join(map(self._format_prepared_param, params)) return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types) def _execute_immediate_statement(self, statement: str, params): @@ -461,10 +452,15 @@ def _execute_immediate_statement(self, statement: str, params): :param statement: sql to be executed. :param params: parameters to be bound. """ - sql = "EXECUTE IMMEDIATE '" + statement.replace("'", "''") + \ - "' USING " + ",".join(map(self._format_prepared_param, params)) + sql = ( + "EXECUTE IMMEDIATE '" + + statement.replace("'", "''") + + "' USING " + + ",".join(map(self._format_prepared_param, params)) + ) return trino.client.TrinoQuery( - self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types) + self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types + ) def _format_prepared_param(self, param): """ @@ -491,7 +487,7 @@ def _format_prepared_param(self, param): return "DOUBLE '%s'" % param if isinstance(param, str): - return ("'%s'" % param.replace("'", "''")) + return "'%s'" % param.replace("'", "''") if isinstance(param, (bytes, bytearray)): return "X'%s'" % param.hex() @@ -517,28 +513,25 @@ def _format_prepared_param(self, param): time_str = param.strftime("%H:%M:%S.%f") # named timezones if isinstance(param.tzinfo, ZoneInfo): - utc_offset = datetime.datetime.now(tz=param.tzinfo).strftime('%z') + utc_offset = datetime.datetime.now(tz=param.tzinfo).strftime("%z") return "TIME '%s %s:%s'" % (time_str, utc_offset[:3], utc_offset[3:]) # offset-based timezones - return "TIME '%s %s'" % (time_str, param.strftime('%Z')[3:]) + return "TIME '%s %s'" % (time_str, param.strftime("%Z")[3:]) if isinstance(param, datetime.date): date_str = param.strftime("%Y-%m-%d") return "DATE '%s'" % date_str if isinstance(param, list): - return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param)) + return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param)) if isinstance(param, tuple): - return "ROW(%s)" % ','.join(map(self._format_prepared_param, param)) + return "ROW(%s)" % ",".join(map(self._format_prepared_param, param)) if isinstance(param, dict): keys = list(param.keys()) values = [param[key] for key in keys] - return "MAP({}, {})".format( - self._format_prepared_param(keys), - self._format_prepared_param(values) - ) + return "MAP({}, {})".format(self._format_prepared_param(keys), self._format_prepared_param(values)) if isinstance(param, uuid.UUID): return "UUID '%s'" % param @@ -549,19 +542,19 @@ def _format_prepared_param(self, param): raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." 
% type(param))
 
     def _deallocate_prepared_statement(self, statement_name: str) -> None:
-        sql = 'DEALLOCATE PREPARE ' + statement_name
-        query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
-                                        legacy_primitive_types=self._legacy_primitive_types)
+        sql = "DEALLOCATE PREPARE " + statement_name
+        query = trino.client.TrinoQuery(
+            self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types
+        )
         query.execute()
 
     def _generate_unique_statement_name(self):
-        return 'st_' + uuid.uuid4().hex.replace('-', '')
+        return "st_" + uuid.uuid4().hex.replace("-", "")
 
     def execute(self, operation, params=None):
         if params:
             assert isinstance(params, (list, tuple)), (
-                'params must be a list or tuple containing the query '
-                'parameter values'
+                "params must be a list or tuple containing the query parameter values"
             )
 
             if self.connection._use_legacy_prepared_statements():
@@ -571,9 +564,7 @@ def execute(self, operation, params=None):
                 try:
                     # Send execute statement and assign the return value to `results`
                     # as it will be returned by the function
-                    self._query = self._execute_prepared_statement(
-                        statement_name, params
-                    )
+                    self._query = self._execute_prepared_statement(statement_name, params)
                     self._iterator = iter(self._query.execute())
                 finally:
                     # Send deallocate statement
@@ -586,8 +577,9 @@ def execute(self, operation, params=None):
                 self._iterator = iter(self._query.execute())
 
         else:
-            self._query = trino.client.TrinoQuery(self._request, query=operation,
-                                                  legacy_primitive_types=self._legacy_primitive_types)
+            self._query = trino.client.TrinoQuery(
+                self._request, query=operation, legacy_primitive_types=self._legacy_primitive_types
+            )
             self._iterator = iter(self._query.execute())
         return self
 
@@ -726,13 +718,9 @@ def __eq__(self, other):
 
 STRING = DBAPITypeObject("VARCHAR", "CHAR", "VARBINARY", "JSON", "IPADDRESS")
 
-BINARY = DBAPITypeObject(
-    "ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest"
-)
+BINARY = DBAPITypeObject("ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest")
 
-NUMBER = DBAPITypeObject(
-    "BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL"
-)
+NUMBER = DBAPITypeObject("BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL")
 
 DATETIME = DBAPITypeObject(
     "DATE",
diff --git a/trino/logging.py b/trino/logging.py
index 5cd45bb9..a840047f 100644
--- a/trino/logging.py
+++ b/trino/logging.py
@@ -27,4 +27,4 @@ def get_logger(name: str, log_level: Optional[int] = None) -> logging.Logger:
 
 
 # set default log level to LEVEL
-trino_root_logger = get_logger('trino', LEVEL)
+trino_root_logger = get_logger("trino", LEVEL)
diff --git a/trino/mapper.py b/trino/mapper.py
index e3f89b30..0c4cafaa 100644
--- a/trino/mapper.py
+++ b/trino/mapper.py
@@ -44,9 +44,9 @@ def map(self, value: Any) -> Optional[bool]:
             return None
         if isinstance(value, bool):
             return value
-        if str(value).lower() == 'true':
+        if str(value).lower() == "true":
             return True
-        if str(value).lower() == 'false':
+        if str(value).lower() == "false":
             return False
         raise ValueError(f"Server sent unexpected value {value} of type {type(value)} for boolean")
@@ -65,11 +65,11 @@ class DoubleValueMapper(ValueMapper[float]):
     def map(self, value: Any) -> Optional[float]:
         if value is None:
             return None
-        if value == 'Infinity':
+        if value == "Infinity":
             return float("inf")
-        if value == '-Infinity':
+        if value == "-Infinity":
             return float("-inf")
-        if value == 'NaN':
+        if value == "NaN":
             return float("nan")
         return float(value)
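
A detail worth restating from the DoubleValueMapper hunk just above: JSON has no encoding for IEEE 754 infinities or NaN, so Trino transmits them as the strings "Infinity", "-Infinity" and "NaN", and the client maps them back to Python floats. Below is a minimal, self-contained sketch of that mapping; the map_double name is made up for illustration, and only the behaviour visible in the hunk is assumed.

import math
from typing import Any, Optional


def map_double(value: Any) -> Optional[float]:
    # Mirrors DoubleValueMapper.map above: the sentinel strings cover the
    # values JSON cannot represent, everything else goes through float().
    if value is None:
        return None
    if value == "Infinity":
        return float("inf")
    if value == "-Infinity":
        return float("-inf")
    if value == "NaN":
        return float("nan")
    return float(value)


assert map_double("Infinity") == float("inf")
assert map_double("-Infinity") == float("-inf")
assert math.isnan(map_double("NaN"))  # NaN never compares equal, so use isnan
assert map_double("1.25") == 1.25
assert map_double(None) is None
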
@@ -110,12 +110,13 @@ def __init__(self, precision: int): def map(self, value: Any) -> Optional[time]: if value is None: return None - whole_python_temporal_value = value[:self.time_default_size] - remaining_fractional_seconds = value[self.time_default_size + 1:] - return Time( - time.fromisoformat(whole_python_temporal_value), - _fraction_to_decimal(remaining_fractional_seconds) - ).round_to(self.precision).to_python_type() + whole_python_temporal_value = value[: self.time_default_size] + remaining_fractional_seconds = value[self.time_default_size + 1 :] + return ( + Time(time.fromisoformat(whole_python_temporal_value), _fraction_to_decimal(remaining_fractional_seconds)) + .round_to(self.precision) + .to_python_type() + ) def _add_second(self, time_value: time) -> time: return (datetime.combine(datetime(1, 1, 1), time_value) + timedelta(seconds=1)).time() @@ -125,13 +126,17 @@ class TimeWithTimeZoneValueMapper(TimeValueMapper): def map(self, value: Any) -> Optional[time]: if value is None: return None - whole_python_temporal_value = value[:self.time_default_size] - remaining_fractional_seconds = value[self.time_default_size + 1:len(value) - 6] - timezone_part = value[len(value) - 6:] - return TimeWithTimeZone( - time.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), - _fraction_to_decimal(remaining_fractional_seconds), - ).round_to(self.precision).to_python_type() + whole_python_temporal_value = value[: self.time_default_size] + remaining_fractional_seconds = value[self.time_default_size + 1 : len(value) - 6] + timezone_part = value[len(value) - 6 :] + return ( + TimeWithTimeZone( + time.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), + _fraction_to_decimal(remaining_fractional_seconds), + ) + .round_to(self.precision) + .to_python_type() + ) class TimestampValueMapper(ValueMapper[datetime]): @@ -142,25 +147,33 @@ def __init__(self, precision: int): def map(self, value: Any) -> Optional[datetime]: if value is None: return None - whole_python_temporal_value = value[:self.datetime_default_size] - remaining_fractional_seconds = value[self.datetime_default_size + 1:] - return Timestamp( - datetime.fromisoformat(whole_python_temporal_value), - _fraction_to_decimal(remaining_fractional_seconds), - ).round_to(self.precision).to_python_type() + whole_python_temporal_value = value[: self.datetime_default_size] + remaining_fractional_seconds = value[self.datetime_default_size + 1 :] + return ( + Timestamp( + datetime.fromisoformat(whole_python_temporal_value), + _fraction_to_decimal(remaining_fractional_seconds), + ) + .round_to(self.precision) + .to_python_type() + ) class TimestampWithTimeZoneValueMapper(TimestampValueMapper): def map(self, value: Any) -> Optional[datetime]: if value is None: return None - datetime_with_fraction, timezone_part = value.rsplit(' ', 1) - whole_python_temporal_value = datetime_with_fraction[:self.datetime_default_size] - remaining_fractional_seconds = datetime_with_fraction[self.datetime_default_size + 1:] - return TimestampWithTimeZone( - datetime.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), - _fraction_to_decimal(remaining_fractional_seconds), - ).round_to(self.precision).to_python_type() + datetime_with_fraction, timezone_part = value.rsplit(" ", 1) + whole_python_temporal_value = datetime_with_fraction[: self.datetime_default_size] + remaining_fractional_seconds = datetime_with_fraction[self.datetime_default_size + 1 :] + return ( + 
TimestampWithTimeZone( + datetime.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)), + _fraction_to_decimal(remaining_fractional_seconds), + ) + .round_to(self.precision) + .to_python_type() + ) def _create_tzinfo(timezone_str: str) -> tzinfo: @@ -183,7 +196,7 @@ def map(self, value: Any) -> Optional[relativedelta]: if value is None: return None is_negative = value[0] == "-" - years, months = (value[1:] if is_negative else value).split('-') + years, months = (value[1:] if is_negative else value).split("-") years, months = int(years), int(months) if is_negative: years, months = -years, -months @@ -195,11 +208,16 @@ def map(self, value: Any) -> Optional[timedelta]: if value is None: return None is_negative = value[0] == "-" - days, time = (value[1:] if is_negative else value).split(' ') - hours, minutes, seconds_milliseconds = time.split(':') - seconds, milliseconds = seconds_milliseconds.split('.') - days, hours, minutes, seconds, milliseconds = (int(days), int(hours), int(minutes), int(seconds), - int(milliseconds)) + days, time = (value[1:] if is_negative else value).split(" ") + hours, minutes, seconds_milliseconds = time.split(":") + seconds, milliseconds = seconds_milliseconds.split(".") + days, hours, minutes, seconds, milliseconds = ( + int(days), + int(hours), + int(minutes), + int(seconds), + int(milliseconds), + ) if is_negative: days, hours, minutes, seconds, milliseconds = -days, -hours, -minutes, -seconds, -milliseconds try: @@ -230,9 +248,7 @@ def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]) def map(self, value: Any) -> Optional[Dict[Any, Optional[Any]]]: if value is None: return None - return { - self.key_mapper.map(k): self.value_mapper.map(v) for k, v in value.items() - } + return {self.key_mapper.map(k): self.value_mapper.map(v) for k, v in value.items()} class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): @@ -244,11 +260,7 @@ def __init__(self, mappers: List[ValueMapper[Any]], names: List[Optional[str]], def map(self, value: Optional[List[Any]]) -> Optional[Tuple[Optional[Any], ...]]: if value is None: return None - return NamedRowTuple( - list(self.mappers[i].map(v) for i, v in enumerate(value)), - self.names, - self.types - ) + return NamedRowTuple(list(self.mappers[i].map(v) for i, v in enumerate(value)), self.names, self.types) class UuidValueMapper(ValueMapper[uuid.UUID]): @@ -279,82 +291,84 @@ class RowMapperFactory: lambda functions (one for each column) which will process a data value and returns a RowMapper instance which will process rows of data """ + NO_OP_ROW_MAPPER = NoOpRowMapper() def create(self, columns: List[Any], legacy_primitive_types: bool) -> RowMapper | NoOpRowMapper: assert columns is not None if not legacy_primitive_types: - return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns]) + return RowMapper([self._create_value_mapper(column["typeSignature"]) for column in columns]) return RowMapperFactory.NO_OP_ROW_MAPPER def _create_value_mapper(self, column: Dict[str, Any]) -> ValueMapper[Any]: - col_type = column['rawType'] + col_type = column["rawType"] # primitive types - if col_type == 'boolean': + if col_type == "boolean": return BooleanValueMapper() - if col_type in {'tinyint', 'smallint', 'integer', 'bigint'}: + if col_type in {"tinyint", "smallint", "integer", "bigint"}: return IntegerValueMapper() - if col_type in {'double', 'real'}: + if col_type in {"double", "real"}: return DoubleValueMapper() - if col_type == 'decimal': + if 
col_type == "decimal": return DecimalValueMapper() - if col_type in {'varchar', 'char'}: + if col_type in {"varchar", "char"}: return StringValueMapper() - if col_type == 'varbinary': + if col_type == "varbinary": return BinaryValueMapper() - if col_type == 'json': + if col_type == "json": return StringValueMapper() - if col_type == 'date': + if col_type == "date": return DateValueMapper() - if col_type == 'time': + if col_type == "time": return TimeValueMapper(self._get_precision(column)) - if col_type == 'time with time zone': + if col_type == "time with time zone": return TimeWithTimeZoneValueMapper(self._get_precision(column)) - if col_type == 'timestamp': + if col_type == "timestamp": return TimestampValueMapper(self._get_precision(column)) - if col_type == 'timestamp with time zone': + if col_type == "timestamp with time zone": return TimestampWithTimeZoneValueMapper(self._get_precision(column)) - if col_type == 'interval year to month': + if col_type == "interval year to month": return IntervalYearToMonthMapper() - if col_type == 'interval day to second': + if col_type == "interval day to second": return IntervalDayToSecondMapper() # structural types - if col_type == 'array': - value_mapper = self._create_value_mapper(column['arguments'][0]['value']) + if col_type == "array": + value_mapper = self._create_value_mapper(column["arguments"][0]["value"]) return ArrayValueMapper(value_mapper) - if col_type == 'map': - key_mapper = self._create_value_mapper(column['arguments'][0]['value']) - value_mapper = self._create_value_mapper(column['arguments'][1]['value']) + if col_type == "map": + key_mapper = self._create_value_mapper(column["arguments"][0]["value"]) + value_mapper = self._create_value_mapper(column["arguments"][1]["value"]) return MapValueMapper(key_mapper, value_mapper) - if col_type == 'row': + if col_type == "row": mappers: List[ValueMapper[Any]] = [] names: List[Optional[str]] = [] types: List[str] = [] - for arg in column['arguments']: - mappers.append(self._create_value_mapper(arg['value']['typeSignature'])) - names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None) - types.append(arg['value']['typeSignature']['rawType']) + for arg in column["arguments"]: + mappers.append(self._create_value_mapper(arg["value"]["typeSignature"])) + names.append(arg["value"]["fieldName"]["name"] if "fieldName" in arg["value"] else None) + types.append(arg["value"]["typeSignature"]["rawType"]) return RowValueMapper(mappers, names, types) # others - if col_type == 'uuid': + if col_type == "uuid": return UuidValueMapper() return NoOpValueMapper() def _get_precision(self, column: Dict[str, Any]) -> int: - args = column['arguments'] + args = column["arguments"] if len(args) == 0: return 3 - return args[0]['value'] + return args[0]["value"] class RowMapper: """ Maps a row of data given a list of mapping functions """ + def __init__(self, columns: List[ValueMapper[Any]]): self.columns = columns diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 19737761..f34cc91d 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -106,11 +106,8 @@ def limit_clause(self, select, **kw): text += "\nLIMIT " + self.process(select._limit_clause, **kw) return text - def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, - fromhints=None, use_schema=True, **kwargs): - sql = super(TrinoSQLCompiler, self).visit_table( - table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs - ) + def visit_table(self, table, 
asfrom=False, iscrud=False, ashint=False, fromhints=None, use_schema=True, **kwargs): + sql = super(TrinoSQLCompiler, self).visit_table(table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs) return self.add_catalog(sql, table) @staticmethod @@ -118,13 +115,10 @@ def add_catalog(sql, table): if table is None or not isinstance(table, DialectKWArgs): return sql - if ( - 'trino' not in table.dialect_options - or 'catalog' not in table.dialect_options['trino'] - ): + if "trino" not in table.dialect_options or "catalog" not in table.dialect_options["trino"]: return sql - catalog = table.dialect_options['trino']['catalog'] + catalog = table.dialect_options["trino"]["catalog"] sql = f'"{catalog}".{sql}' return sql @@ -146,23 +140,23 @@ class GenericIgnoreNulls(GenericFunction): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if kwargs.get('ignore_nulls'): + if kwargs.get("ignore_nulls"): self.ignore_nulls = True class FirstValue(GenericIgnoreNulls): - name = 'first_value' + name = "first_value" class LastValue(GenericIgnoreNulls): - name = 'last_value' + name = "last_value" class NthValue(GenericIgnoreNulls): - name = 'nth_value' + name = "nth_value" class Lead(GenericIgnoreNulls): - name = 'lead' + name = "lead" class Lag(GenericIgnoreNulls): - name = 'lag' + name = "lag" @staticmethod @compiles(FirstValue) @@ -171,9 +165,9 @@ class Lag(GenericIgnoreNulls): @compiles(Lead) @compiles(Lag) def compile_ignore_nulls(element, compiler, **kwargs): - compiled = f'{element.name}({compiler.process(element.clauses)})' + compiled = f"{element.name}({compiler.process(element.clauses)})" if element.ignore_nulls: - compiled += ' IGNORE NULLS' + compiled += " IGNORE NULLS" return compiled def visit_try_cast(self, element, **kw): @@ -248,17 +242,17 @@ def visit_TIME(self, type_, **kw): return datatype def visit_JSON(self, type_, **kw): - return 'JSON' + return "JSON" def visit_MAP(self, type_, **kw): # the key and value types themselves need to be processed otherwise sqltypes.MAP(Float, Float) will get # rendered as MAP(FLOAT, FLOAT) instead of MAP(REAL, REAL) or MAP(DOUBLE, DOUBLE) key_type = self.process(type_.key_type, **kw) value_type = self.process(type_.value_type, **kw) - return f'MAP({key_type}, {value_type})' + return f"MAP({key_type}, {value_type})" def visit_ARRAY(self, type_, **kw): - return f'ARRAY({self.process(type_.item_type, **kw)})' + return f"ARRAY({self.process(type_.item_type, **kw)})" def visit_ROW(self, type_, **kw): return f'ROW({", ".join(f"{name} {self.process(attr_type, **kw)}" for name, attr_type in type_.attr_types)})' diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index f5ecf433..bb06f9bc 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -119,13 +119,11 @@ def process(value): class _JSONFormatter: @staticmethod def format_index(value): - return "$[\"%s\"]" % value + return '$["%s"]' % value @staticmethod def format_path(value): - return "$%s" % ( - "".join(["[\"%s\"]" % elem for elem in value]) - ) + return "$%s" % ("".join(['["%s"]' % elem for elem in value])) class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): @@ -238,7 +236,7 @@ def aware_split( elif character == close_bracket: parens -= 1 elif character == quote: - if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote: + if quotes and string[j - len(escaped_quote) + 1 : j + 1] != escaped_quote: quotes = False elif not quotes: quotes = True diff --git a/trino/sqlalchemy/dialect.py 
b/trino/sqlalchemy/dialect.py index ad28b18a..4baf4804 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -52,10 +52,7 @@ class TrinoDialect(DefaultDialect): - def __init__(self, - json_serializer=None, - json_deserializer=None, - **kwargs): + def __init__(self, json_serializer=None, json_deserializer=None, **kwargs): DefaultDialect.__init__(self, **kwargs) self._json_serializer = json_serializer self._json_deserializer = json_deserializer @@ -142,7 +139,7 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "cert" in url.query and "key" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key'])) + kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query["cert"]), unquote_plus(url.query["key"])) if "externalAuthentication" in url.query: kwargs["http_scheme"] = "https" @@ -214,10 +211,7 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No return columns def _get_partitions( - self, - connection: Connection, - table_name: str, - schema: str = None + self, connection: Connection, table_name: str, schema: str = None ) -> List[Dict[str, List[Any]]]: schema = schema or self._get_default_schema_name(connection) query = dedent( @@ -330,11 +324,7 @@ def get_indexes(self, connection: Connection, table_name: str, schema: str = Non logger.debug("Couldn't fetch partition columns. schema: %s, table: %s, error: %s", schema, table_name, e) if not partitioned_columns: return [] - partition_index = dict( - name="partition", - column_names=partitioned_columns, - unique=False - ) + partition_index = dict(name="partition", column_names=partitioned_columns, unique=False) return [partition_index] def get_sequence_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: @@ -371,14 +361,11 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str ).strip() try: res = connection.execute( - sql.text(query), - {"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name} + sql.text(query), {"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name} ) return dict(text=res.scalar()) except error.TrinoQueryError as e: - if e.error_name in ( - error.PERMISSION_DENIED, - ): + if e.error_name in (error.PERMISSION_DENIED,): return dict(text=None) raise diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py index 58af9206..fa71b6f7 100644 --- a/trino/sqlalchemy/util.py +++ b/trino/sqlalchemy/util.py @@ -32,7 +32,7 @@ def _url( cert: Optional[str] = None, key: Optional[str] = None, verify: Optional[bool] = None, - roles: Optional[Dict[str, str]] = None + roles: Optional[Dict[str, str]] = None, ) -> str: """ Composes a SQLAlchemy connection string from the given database connection diff --git a/trino/transaction.py b/trino/transaction.py index b6b506d6..6e1fb68f 100644 --- a/trino/transaction.py +++ b/trino/transaction.py @@ -66,9 +66,7 @@ def request(self) -> trino.client.TrinoRequest: def begin(self) -> None: response = self._request.post(START_TRANSACTION) if not response.ok: - raise trino.exceptions.DatabaseError( - "failed to start transaction: {}".format(response.status_code) - ) + raise trino.exceptions.DatabaseError("failed to start transaction: {}".format(response.status_code)) transaction_id = response.headers.get(constants.HEADER_STARTED_TRANSACTION) if transaction_id and transaction_id != NO_TRANSACTION: self._id = 
response.headers[constants.HEADER_STARTED_TRANSACTION]
@@ -87,9 +85,7 @@ def commit(self) -> None:
         try:
             list(query.execute())
         except Exception as err:
-            raise trino.exceptions.DatabaseError(
-                "failed to commit transaction {}: {}".format(self._id, err)
-            )
+            raise trino.exceptions.DatabaseError("failed to commit transaction {}: {}".format(self._id, err))
         self._id = NO_TRANSACTION
         self._request.transaction_id = self._id
@@ -98,8 +94,6 @@ def rollback(self) -> None:
         try:
             list(query.execute())
         except Exception as err:
-            raise trino.exceptions.DatabaseError(
-                "failed to rollback transaction {}: {}".format(self._id, err)
-            )
+            raise trino.exceptions.DatabaseError("failed to rollback transaction {}: {}".format(self._id, err))
         self._id = NO_TRANSACTION
         self._request.transaction_id = self._id
diff --git a/trino/types.py b/trino/types.py
index 64d72cb4..f8cd8a01 100644
--- a/trino/types.py
+++ b/trino/types.py
@@ -36,9 +36,9 @@ def to_python_type(self) -> PythonTemporalType:
 
     def round_to(self, precision: int) -> TemporalType[PythonTemporalType]:
         """
-            Python datetime and time only support up to microsecond precision
-            In case the supplied value exceeds the specified precision,
-            the value needs to be rounded.
+        Python datetime and time only support up to microsecond precision.
+        In case the supplied value exceeds the specified precision,
+        the value needs to be rounded.
         """
         precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER)
         remaining_fractional_seconds = self._remaining_fractional_seconds
@@ -53,7 +53,7 @@ def round_to(self, precision: int) -> TemporalType[PythonTemporalType]:
     @abc.abstractmethod
     def add_time_delta(self, time_delta: timedelta) -> PythonTemporalType:
         """
-            This method shall be overriden to implement fraction arithmetics.
+        This method shall be overridden to implement fraction arithmetic.
         """
         pass
 
@@ -99,6 +99,7 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZone:
 
 class NamedRowTuple(Tuple[Any, ...]):
     """Custom tuple class as namedtuple doesn't support missing or duplicate names"""
+
     def __new__(cls, values: List[Any], names: List[str], types: List[str]) -> NamedRowTuple:
         return cast(NamedRowTuple, super().__new__(cls, values))
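
A closing note on the TimeBoundLRUCache that trino/dbapi.py gains in the hunks above: the diff shows only its docstring, its (capacity, ttl_seconds) constructor and the get/put call sites in _use_legacy_prepared_statements, not its internals. The sketch below is one plausible reading of that contract, assuming the documented lazy expiry; the class name and implementation details are illustrative, not the library's actual code, and the real class may add locking for thread safety.

import time
from collections import OrderedDict
from typing import Any, Hashable, Optional, Tuple


class TimeBoundLRUCacheSketch:
    """A get/put cache bounded by size and entry age; expired entries are
    only evicted when next accessed, matching the docstring in the diff."""

    def __init__(self, capacity: int, ttl_seconds: int):
        self.capacity = capacity
        self.ttl_seconds = ttl_seconds
        self._entries: "OrderedDict[Hashable, Tuple[float, Any]]" = OrderedDict()

    def put(self, key: Hashable, value: Any) -> None:
        self._entries[key] = (time.monotonic(), value)
        self._entries.move_to_end(key)  # newest write becomes most recently used
        if len(self._entries) > self.capacity:
            self._entries.popitem(last=False)  # drop the least recently used key

    def get(self, key: Hashable) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        inserted_at, value = entry
        if time.monotonic() - inserted_at > self.ttl_seconds:
            del self._entries[key]  # lazy expiry: only noticed on access
            return None
        self._entries.move_to_end(key)
        return value


cache = TimeBoundLRUCacheSketch(capacity=2, ttl_seconds=3600)
cache.put(("coordinator", 8080), False)
assert cache.get(("coordinator", 8080)) is False
assert cache.get(("other-host", 8080)) is None

Since dbapi.py keys the cache by (host, port), the EXECUTE IMMEDIATE capability probe runs at most once per coordinator until the cached answer expires or is evicted.
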