Skip to content

Commit

Permalink
Use psycopg rather than psycopg2 for Risingwave
Browse files Browse the repository at this point in the history
  • Loading branch information
judahrand committed Jan 9, 2025
1 parent 4c83730 commit e74e91f
Show file tree
Hide file tree
Showing 21 changed files with 102 additions and 140 deletions.
2 changes: 1 addition & 1 deletion .github/renovate.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
},
{
"addLabels": ["risingwave"],
"matchPackageNames": ["/psycopg2/", "/risingwave/"]
"matchPackageNames": ["/psycopg/", "/risingwave/"]
},
{
"addLabels": ["snowflake"],
Expand Down
9 changes: 4 additions & 5 deletions ibis/backends/risingwave/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from operator import itemgetter
from typing import TYPE_CHECKING

import psycopg2
import psycopg
import sqlglot as sg
import sqlglot.expressions as sge
from psycopg2 import extras

import ibis
import ibis.backends.sql.compilers as sc
Expand Down Expand Up @@ -110,12 +109,12 @@ def do_connect(
month int32
"""

self.con = psycopg2.connect(
self.con = psycopg.connect(
host=host,
port=port,
user=user,
password=password,
database=database,
dbname=database,
options=(f"-csearch_path={schema}" * (schema is not None)) or None,
)

Expand Down Expand Up @@ -289,7 +288,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
)
with self.begin() as cur:
cur.execute(create_stmt_sql)
extras.execute_batch(cur, sql, data, 128)
cur.executemany(sql, data)

def list_databases(
self, *, like: str | None = None, catalog: str | None = None
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/risingwave/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TestConf(ServiceBackendTest):
supports_structs = False
rounding_method = "half_to_even"
service_name = "risingwave"
deps = ("psycopg2",)
deps = ("psycopg",)

@property
def test_files(self) -> Iterable[Path]:
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/risingwave/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import ibis.expr.types as ir
from ibis.util import gen_name

pytest.importorskip("psycopg2")
pytest.importorskip("psycopg")

RISINGWAVE_TEST_DB = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev")
IBIS_RISINGWAVE_HOST = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/risingwave/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import ibis.expr.datatypes as dt
from ibis import literal as L

pytest.importorskip("psycopg2")
pytest.importorskip("psycopg")


@pytest.mark.parametrize(("value", "expected"), [(0, None), (5.5, 5.5)])
Expand Down
19 changes: 0 additions & 19 deletions ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,25 +131,6 @@
PsycoPgOperationalError
) = PsycoPgUndefinedObject = PsycoPgArraySubscriptError = None

try:
from psycopg2.errors import ArraySubscriptError as PsycoPg2ArraySubscriptError
from psycopg2.errors import DivisionByZero as PsycoPg2DivisionByZero
from psycopg2.errors import IndeterminateDatatype as PsycoPg2IndeterminateDatatype
from psycopg2.errors import InternalError_ as PsycoPg2InternalError
from psycopg2.errors import (
InvalidTextRepresentation as PsycoPg2InvalidTextRepresentation,
)
from psycopg2.errors import OperationalError as PsycoPg2OperationalError
from psycopg2.errors import ProgrammingError as PsycoPg2ProgrammingError
from psycopg2.errors import SyntaxError as PsycoPg2SyntaxError
from psycopg2.errors import UndefinedObject as PsycoPg2UndefinedObject
except ImportError:
PsycoPg2SyntaxError = PsycoPg2IndeterminateDatatype = (
PsycoPg2InvalidTextRepresentation
) = PsycoPg2DivisionByZero = PsycoPg2InternalError = PsycoPg2ProgrammingError = (
PsycoPg2OperationalError
) = PsycoPg2UndefinedObject = PsycoPg2ArraySubscriptError = None

try:
from MySQLdb import NotSupportedError as MySQLNotSupportedError
from MySQLdb import OperationalError as MySQLOperationalError
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
MySQLNotSupportedError,
OracleDatabaseError,
PolarsInvalidOperationError,
PsycoPg2InternalError,
PsycoPgInternalError,
Py4JError,
Py4JJavaError,
PyAthenaOperationalError,
Expand Down Expand Up @@ -963,7 +963,7 @@ def test_approx_quantile(con, filtered, multi):
),
pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function covar_pop(integer, integer) does not exist",
),
],
Expand All @@ -983,7 +983,7 @@ def test_approx_quantile(con, filtered, multi):
),
pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function covar_pop(integer, integer) does not exist",
),
],
Expand All @@ -1005,7 +1005,7 @@ def test_approx_quantile(con, filtered, multi):
),
pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function covar_pop(integer, integer) does not exist",
),
],
Expand Down Expand Up @@ -1062,7 +1062,7 @@ def test_approx_quantile(con, filtered, multi):
),
pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function covar_pop(integer, integer) does not exist",
),
],
Expand All @@ -1088,7 +1088,7 @@ def test_approx_quantile(con, filtered, multi):
),
pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function covar_pop(integer, integer) does not exist",
),
],
Expand Down
44 changes: 21 additions & 23 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
GoogleBadRequest,
MySQLOperationalError,
PolarsComputeError,
PsycoPg2IndeterminateDatatype,
PsycoPg2InternalError,
PsycoPg2ProgrammingError,
PsycoPgIndeterminateDatatype,
PsycoPgInternalError,
PsycoPgInvalidTextRepresentation,
PsycoPgProgrammingError,
PsycoPgSyntaxError,
Py4JJavaError,
PyAthenaDatabaseError,
Expand Down Expand Up @@ -506,7 +504,7 @@ def test_array_slice(backend, start, stop):
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="TODO(Kexiang): seems a bug",
)
@pytest.mark.notimpl(["athena"], raises=PyAthenaDatabaseError)
Expand Down Expand Up @@ -565,7 +563,7 @@ def test_array_map(con, input, output, func):
)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="TODO(Kexiang): seems a bug",
)
@pytest.mark.notimpl(["athena"], raises=PyAthenaDatabaseError)
Expand Down Expand Up @@ -646,7 +644,7 @@ def test_array_map_with_index(con, input, output, func):
)
@pytest.mark.notyet(
"risingwave",
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="no support for not null column constraint",
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -693,7 +691,7 @@ def test_array_filter(con, input, output, predicate):
)
@pytest.mark.notyet(
"risingwave",
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="no support for not null column constraint",
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -740,7 +738,7 @@ def test_array_filter_with_index(con, input, output, predicate):
)
@pytest.mark.notyet(
"risingwave",
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="no support for not null column constraint",
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -1097,7 +1095,7 @@ def test_array_intersect(con, data):

@builtin_array
@pytest.mark.notimpl(["postgres"], raises=PsycoPgSyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPgInternalError)
@pytest.mark.notimpl(
["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError
)
Expand All @@ -1117,7 +1115,7 @@ def test_unnest_struct(con):

@builtin_array
@pytest.mark.notimpl(["postgres"], raises=PsycoPgSyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2InternalError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPgInternalError)
@pytest.mark.notimpl(
["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError
)
Expand Down Expand Up @@ -1208,7 +1206,7 @@ def test_zip_null(con, fn):

@builtin_array
@pytest.mark.notimpl(["postgres"], raises=PsycoPgSyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPgProgrammingError)
@pytest.mark.notimpl(["datafusion"], raises=Exception, reason="not yet supported")
@pytest.mark.notimpl(
["polars"],
Expand Down Expand Up @@ -1291,8 +1289,8 @@ def flatten_data():
reason="Risingwave doesn't truly support arrays of arrays",
raises=(
com.OperationNotDefinedError,
PsycoPg2IndeterminateDatatype,
PsycoPg2InternalError,
PsycoPgIndeterminateDatatype,
PsycoPgInternalError,
),
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -1399,7 +1397,7 @@ def test_range_start_stop_step(con, start, stop, step):
@pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError)
@pytest.mark.never(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="Invalid parameter step: step size cannot equal zero",
)
def test_range_start_stop_step_zero(con, start, stop):
Expand Down Expand Up @@ -1432,7 +1430,7 @@ def test_unnest_empty_array(con):
@pytest.mark.notimpl(["sqlite"], raises=com.UnsupportedBackendType)
@pytest.mark.notyet(
"risingwave",
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="no support for not null column constraint",
)
@pytest.mark.notimpl(["athena"], raises=PyAthenaDatabaseError)
Expand Down Expand Up @@ -1515,7 +1513,7 @@ def swap(token):
id="pos",
marks=pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function make_interval() does not exist",
),
),
Expand All @@ -1533,7 +1531,7 @@ def swap(token):
),
pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function neg(interval) does not exist",
),
],
Expand All @@ -1553,7 +1551,7 @@ def swap(token):
),
pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function neg(interval) does not exist",
),
],
Expand Down Expand Up @@ -1585,7 +1583,7 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo):
pytest.mark.notyet(["polars"], raises=PolarsComputeError),
pytest.mark.notyet(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function make_interval() does not exist",
),
],
Expand All @@ -1604,7 +1602,7 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo):
),
pytest.mark.notyet(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="function neg(interval) does not exist",
),
],
Expand Down Expand Up @@ -1760,7 +1758,7 @@ def test_table_unnest_with_keep_empty(con):
["datafusion", "polars", "flink"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
["risingwave"], raises=PsycoPg2InternalError, reason="not supported in risingwave"
["risingwave"], raises=PsycoPgInternalError, reason="not supported in risingwave"
)
@pytest.mark.notimpl(
["athena"],
Expand All @@ -1781,9 +1779,9 @@ def test_table_unnest_column_expr(backend):
@pytest.mark.notimpl(["trino"], raises=TrinoUserError)
@pytest.mark.notimpl(["athena"], raises=PyAthenaOperationalError)
@pytest.mark.notimpl(["postgres"], raises=PsycoPgSyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPgProgrammingError)
@pytest.mark.notyet(
["risingwave"], raises=PsycoPg2InternalError, reason="not supported in risingwave"
["risingwave"], raises=PsycoPgInternalError, reason="not supported in risingwave"
)
def test_table_unnest_array_of_struct_of_array(con):
t = ibis.memtable(
Expand Down
12 changes: 6 additions & 6 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ExaQueryError,
ImpalaHiveServer2Error,
OracleDatabaseError,
PsycoPg2InternalError,
PsycoPgInternalError,
PsycoPgUndefinedObject,
Py4JJavaError,
PyAthenaDatabaseError,
Expand Down Expand Up @@ -415,7 +415,7 @@ def test_rename_table(con, temp_table, temp_table_orig):
)
@pytest.mark.never(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason='Feature is not yet implemented: column constraints "NOT NULL"',
)
def test_nullable_input_output(con, temp_table):
Expand Down Expand Up @@ -538,7 +538,7 @@ def test_insert_no_overwrite_from_dataframe(
@pytest.mark.notimpl(["polars"], reason="`insert` method not implemented")
@pytest.mark.notyet(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="truncate not supported upstream",
)
@pytest.mark.notyet(
Expand Down Expand Up @@ -584,7 +584,7 @@ def test_insert_no_overwrite_from_expr(
)
@pytest.mark.notyet(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="truncate not supported upstream",
)
def test_insert_overwrite_from_expr(
Expand All @@ -608,7 +608,7 @@ def test_insert_overwrite_from_expr(
)
@pytest.mark.notyet(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="truncate not supported upstream",
)
def test_insert_overwrite_from_list(con, employee_data_1_temp_table):
Expand Down Expand Up @@ -737,7 +737,7 @@ def test_list_database_contents(con):
@pytest.mark.notyet(["impala"], raises=ImpalaHiveServer2Error)
@pytest.mark.notyet(
["risingwave"],
raises=PsycoPg2InternalError,
raises=PsycoPgInternalError,
reason="unsigned integers are not supported",
)
@pytest.mark.notimpl(
Expand Down
Loading

0 comments on commit e74e91f

Please sign in to comment.