feat(datasets): create separate ibis.FileDataset
Signed-off-by: Deepyaman Datta <[email protected]>
deepyaman committed Sep 20, 2024
1 parent 552b973 commit fbcf8ff
Showing 4 changed files with 262 additions and 6 deletions.
4 changes: 3 additions & 1 deletion kedro-datasets/kedro_datasets/ibis/__init__.py
@@ -4,8 +4,10 @@
 import lazy_loader as lazy
 
 # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
+FileDataset: Any
 TableDataset: Any
 
 __getattr__, __dir__, __all__ = lazy.attach(
-    __name__, submod_attrs={"table_dataset": ["TableDataset"]}
+    __name__,
+    submod_attrs={"file_dataset": ["FileDataset"], "table_dataset": ["TableDataset"]},
 )
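For context, this lazy-loading pattern means the ``file_dataset`` submodule is only imported the first time ``FileDataset`` is accessed on the package. A minimal reader-side sketch of that behaviour (assuming a standard ``lazy_loader`` setup without eager imports; the ``sys.modules`` checks are purely illustrative):

import sys

import kedro_datasets.ibis as ibis_datasets

# Importing the package does not import the submodule yet.
assert "kedro_datasets.ibis.file_dataset" not in sys.modules

# Attribute access triggers lazy.attach's __getattr__, which imports
# kedro_datasets.ibis.file_dataset and returns FileDataset from it.
FileDataset = ibis_datasets.FileDataset
assert "kedro_datasets.ibis.file_dataset" in sys.modules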
160 changes: 160 additions & 0 deletions kedro-datasets/kedro_datasets/ibis/file_dataset.py
@@ -0,0 +1,160 @@
"""Provide file loading and saving functionality for Ibis's backends."""
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Any, ClassVar

import ibis.expr.types as ir
from kedro.io import AbstractDataset, DatasetError

if TYPE_CHECKING:
from ibis import BaseBackend


class FileDataset(AbstractDataset[ir.Table, ir.Table]):
"""``FileDataset`` loads/saves data from/to a specified file format.
Example usage for the
`YAML API <https://docs.kedro.org/en/stable/data/data_catalog_yaml_examples.html>`_:
.. code-block:: yaml
cars:
type: ibis.TableDataset
filepath: data/01_raw/company/cars.csv
file_format: csv
table_name: cars
connection:
backend: duckdb
database: company.db
load_args:
sep: ","
nullstr: "#NA"
save_args:
sep: ","
nullstr: "#NA"
Example usage for the
`Python API <https://docs.kedro.org/en/stable/data/\
advanced_data_catalog_usage.html>`_:
.. code-block:: pycon
>>> import ibis
>>> from kedro_datasets.ibis import FileDataset
>>>
>>> data = ibis.memtable({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = FileDataset(
... filepath=tmp_path / "test.csv",
... file_format="csv",
... table_name="test",
... connection={"backend": "duckdb", "database": tmp_path / "file.db"},
... )
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.execute().equals(reloaded.execute())
"""

DEFAULT_LOAD_ARGS: ClassVar[dict[str, Any]] = {}
DEFAULT_SAVE_ARGS: ClassVar[dict[str, Any]] = {}

_connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}

def __init__( # noqa: PLR0913
self,
filepath: str,
file_format: str,
*,
table_name: str | None = None,
connection: dict[str, Any] | None = None,
load_args: dict[str, Any] | None = None,
save_args: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Creates a new ``FileDataset`` pointing to the given filepath.
``FileDataset`` connects to the Ibis backend object constructed
from the connection configuration. The `backend` key provided in
the config can be any of the `supported backends <https://ibis-\
project.org/install>`_. The remaining dictionary entries will be
passed as arguments to the underlying ``connect()`` method (e.g.
`ibis.duckdb.connect() <https://ibis-project.org/backends/duckdb\
#ibis.duckdb.connect>`_).
The read method corresponding to the given ``file_format`` (e.g.
`read_csv() <https://ibis-project.org/backends/\
duckdb#ibis.backends.duckdb.Backend.read_csv>`_) is used to load
the file with the backend. Note that only the data is loaded; no
link to the underlying file exists past ``FileDataset.load()``.
Args:
filepath: Path to a file to register as a table. Most useful
for loading data into your data warehouse (for testing).
On save, the backend exports data to the specified path.
file_format: String specifying the file format for the file.
table_name: The name to use for the created table (on load).
connection: Configuration for connecting to an Ibis backend.
load_args: Additional arguments passed to the Ibis backend's
`read_{file_format}` method.
save_args: Additional arguments passed to the Ibis backend's
`to_{file_format}` method.
metadata: Any arbitrary metadata. This is ignored by Kedro,
but may be consumed by users or external plugins.
"""
self._filepath = filepath
self._file_format = file_format
self._table_name = table_name
self._connection_config = connection
self.metadata = metadata

# Set load and save arguments, overwriting defaults if provided.
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)

self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

    @property
    def connection(self) -> BaseBackend:
        """The ``Backend`` instance for the connection configuration."""

        def hashable(value):
            """Return a hashable key for a potentially-nested object."""
            if isinstance(value, dict):
                return tuple((k, hashable(v)) for k, v in sorted(value.items()))
            if isinstance(value, list):
                return tuple(hashable(x) for x in value)
            return value

        cls = type(self)
        key = hashable(self._connection_config)
        if key not in cls._connections:
            import ibis

            config = deepcopy(self._connection_config)
            backend_attr = config.pop("backend") if config else None
            backend = getattr(ibis, backend_attr)
            cls._connections[key] = backend.connect(**config)

        return cls._connections[key]

def _load(self) -> ir.Table:
reader = getattr(self.connection, f"read_{self._file_format}")
return reader(self._filepath, self._table_name, **self._load_args)

def _save(self, data: ir.Table) -> None:
writer = getattr(self.connection, f"to_{self._file_format}")
writer(data, self._filepath, **self._save_args)

def _describe(self) -> dict[str, Any]:
return {
"filepath": self._filepath,
"file_format": self._file_format,
"table_name": self._table_name,
"backend": self._connection_config.get("backend")
if self._connection_config
else None,
"load_args": self._load_args,
"save_args": self._save_args,
}
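The ``connection`` property memoizes backend objects per configuration at the class level, so dataset instances sharing a config reuse one connection. A standalone sketch (duplicating the nested ``hashable`` helper for illustration only) of how a config dict is frozen into the cache key; the expected tuple mirrors the assertions in the new test file:

def hashable(value):
    """Freeze nested dicts/lists into sorted tuples, as in FileDataset.connection."""
    if isinstance(value, dict):
        return tuple((k, hashable(v)) for k, v in sorted(value.items()))
    if isinstance(value, list):
        return tuple(hashable(x) for x in value)
    return value

config = {"backend": "duckdb", "database": "file.db", "extensions": ["spatial"]}
key = hashable(config)
assert key == (
    ("backend", "duckdb"),
    ("database", "file.db"),
    ("extensions", ("spatial",)),
)
hash(key)  # hashable, so it can index the class-level _connections cache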
5 changes: 0 additions & 5 deletions kedro-datasets/kedro_datasets/ibis/table_dataset.py
@@ -21,15 +21,10 @@ class TableDataset(AbstractDataset[ir.Table, ir.Table]):
         cars:
           type: ibis.TableDataset
-          filepath: data/01_raw/company/cars.csv
-          file_format: csv
           table_name: cars
           connection:
             backend: duckdb
             database: company.db
-          load_args:
-            sep: ","
-            nullstr: "#NA"
           save_args:
             materialized: table
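The removed docstring lines reflect the split: reading and writing files now belongs to ``FileDataset``, while ``TableDataset`` keeps only named tables and views. A hedged sketch of the resulting Python API usage under that split (the entry names and values are hypothetical, not from this diff):

from kedro_datasets.ibis import FileDataset, TableDataset

connection = {"backend": "duckdb", "database": "company.db"}

# The file-facing half: register a CSV as a table on load, export on save.
cars_file = FileDataset(
    filepath="data/01_raw/company/cars.csv",
    file_format="csv",
    table_name="cars",
    connection=connection,
    load_args={"sep": ",", "nullstr": "#NA"},
)

# The table-facing half: load/materialize the named table in the warehouse.
cars_table = TableDataset(
    table_name="cars",
    connection=connection,
    save_args={"materialized": "table"},
)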
99 changes: 99 additions & 0 deletions kedro-datasets/tests/ibis/test_file_dataset.py
@@ -0,0 +1,99 @@
import duckdb
import ibis
import pytest
from kedro.io import DatasetError
from pandas.testing import assert_frame_equal

from kedro_datasets.ibis import FileDataset


@pytest.fixture
def filepath_csv(tmp_path):
return (tmp_path / "test.csv").as_posix()


@pytest.fixture
def database(tmp_path):
return (tmp_path / "file.db").as_posix()


@pytest.fixture(params=[None])
def connection_config(request, database):
return request.param or {"backend": "duckdb", "database": database}


@pytest.fixture
def file_dataset(filepath_csv, connection_config, load_args, save_args):
return FileDataset(
filepath=filepath_csv,
file_format="csv",
connection=connection_config,
load_args=load_args,
save_args=save_args,
)


@pytest.fixture
def dummy_table():
return ibis.memtable({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})


class TestFileDataset:
def test_save_and_load(self, file_dataset, dummy_table, database):
"""Test saving and reloading the data set."""
file_dataset.save(dummy_table)
reloaded = file_dataset.load()
assert_frame_equal(dummy_table.execute(), reloaded.execute())

@pytest.mark.parametrize("load_args", [{"filename": True}], indirect=True)
def test_load_extra_params(self, file_dataset, load_args, dummy_table):
"""Test overriding the default load arguments."""
file_dataset.save(dummy_table)
assert "filename" in file_dataset.load()

@pytest.mark.parametrize("save_args", [{"sep": "|"}], indirect=True)
def test_save_extra_params(
self, file_dataset, save_args, dummy_table, filepath_csv
):
"""Test overriding the default save arguments."""
file_dataset.save(dummy_table)

# Verify that the delimiter character from `save_args` was used.
with open(filepath_csv) as f:
for line in f:
assert save_args["sep"] in line

@pytest.mark.parametrize(
("connection_config", "key"),
[
(
{"backend": "duckdb", "database": "file.db", "extensions": ["spatial"]},
(
("backend", "duckdb"),
("database", "file.db"),
("extensions", ("spatial",)),
),
),
# https://github.com/kedro-org/kedro-plugins/pull/560#discussion_r1536083525
(
{
"host": "xxx.sql.azuresynapse.net",
"database": "xxx",
"query": {"driver": "ODBC Driver 17 for SQL Server"},
"backend": "mssql",
},
(
("backend", "mssql"),
("database", "xxx"),
("host", "xxx.sql.azuresynapse.net"),
("query", (("driver", "ODBC Driver 17 for SQL Server"),)),
),
),
],
indirect=["connection_config"],
)
def test_connection_config(self, mocker, file_dataset, connection_config, key):
"""Test hashing of more complicated connection configuration."""
mocker.patch(f"ibis.{connection_config['backend']}")
file_dataset.load()
assert key in file_dataset._connections

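Note that the tests reference ``load_args`` and ``save_args`` fixtures not defined in this file; presumably they come from a shared conftest. A minimal sketch of what such fixtures typically look like in kedro-datasets test suites (an assumption, not part of this diff), compatible with the ``indirect=True`` overrides above:

import pytest


@pytest.fixture(params=[None])
def load_args(request):
    return request.param


@pytest.fixture(params=[None])
def save_args(request):
    return request.param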