From 48a85315f2d4d2cc110b48e8458d881082ffc110 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 20 Oct 2024 17:16:35 -0500 Subject: [PATCH 1/3] Improve `read_database_uri` overloads --- py-polars/polars/io/database/functions.py | 30 +++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index ac5dcaaac3e7..5cc33d94bedf 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -263,6 +263,36 @@ def read_database( ) +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["adbc"], + schema_overrides: SchemaDict | None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + +@overload +def read_database_uri( + query: list[str] | str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: Literal["connectorx"] | None = None, + schema_overrides: SchemaDict | None = None, + execute_options: None = None, +) -> DataFrame: ... + + def read_database_uri( query: list[str] | str, uri: str, From 48c6f71fc52361b28f59205d393c1ca1863619b9 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 20 Oct 2024 17:26:01 -0500 Subject: [PATCH 2/3] Typing fixes --- py-polars/polars/io/database/functions.py | 15 +++++++++++++++ py-polars/tests/unit/io/database/test_read.py | 5 +++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index 5cc33d94bedf..aa686e2f813e 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -293,6 +293,21 @@ def read_database_uri( ) -> DataFrame: ... +@overload +def read_database_uri( + query: str, + uri: str, + *, + partition_on: str | None = None, + partition_range: tuple[int, int] | None = None, + partition_num: int | None = None, + protocol: str | None = None, + engine: DbReadEngine | None = None, + schema_overrides: None = None, + execute_options: dict[str, Any] | None = None, +) -> DataFrame: ... + + def read_database_uri( query: list[str] | str, uri: str, diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index deb44a5a79f4..fd20cf643add 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -292,17 +292,18 @@ def test_read_database( tmp_sqlite_db: Path, ) -> None: if read_method == "read_database_uri": + connect_using = cast(DbReadEngine, connect_using) # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) df_empty = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data WHERE name LIKE '%polars%'", - engine=str(connect_using), # type: ignore[arg-type] + engine=connect_using, schema_overrides=schema_overrides, ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: From c0a029523eb83a9a1043fc844ef218e0a8c4c83d Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 20 Oct 2024 17:27:42 -0500 Subject: [PATCH 3/3] Linting --- py-polars/tests/unit/io/database/test_read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index fd20cf643add..fb6fa8dca0ed 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -292,7 +292,7 @@ def test_read_database( tmp_sqlite_db: Path, ) -> None: if read_method == "read_database_uri": - connect_using = cast(DbReadEngine, connect_using) + connect_using = cast("DbReadEngine", connect_using) # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( uri=f"sqlite:///{tmp_sqlite_db}",