From b67990224232b78069900d35c006c3a86e45be77 Mon Sep 17 00:00:00 2001 From: Ted Conbeer Date: Mon, 6 Nov 2023 16:14:19 -0700 Subject: [PATCH] fix: do not crash if table is initialized with no data --- CHANGELOG.md | 2 ++ pyproject.toml | 1 + src/textual_fastdatatable/backend.py | 11 +++++++--- stubs/pyarrow/__init__.pyi | 14 +++++++------ stubs/pyarrow/parquet.pyi | 29 +++++++++++++++++++++++++- tests/unit_tests/test_arrow_backend.py | 6 ++++++ 6 files changed, 53 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2215b5..6ed979d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +- Fixes a crash when computing the widths of columns with no rows ([#19](https://github.com/tconbeer/textual-fastdatatable/issues/19)). + ## [0.1.3] - 2023-10-09 - Fixes a crash when creating a column from a null or complex type. diff --git a/pyproject.toml b/pyproject.toml index 05657da..33ab1a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ target-version = "py38" python_version = "3.8" files = [ "src/**/*.py", + "tests/unit_tests/**/*.py" ] mypy_path = "src:stubs" diff --git a/src/textual_fastdatatable/backend.py b/src/textual_fastdatatable/backend.py index bb98968..4749458 100644 --- a/src/textual_fastdatatable/backend.py +++ b/src/textual_fastdatatable/backend.py @@ -177,10 +177,14 @@ def column_content_widths(self) -> list[int]: ], names=self.data.column_names, ) - self._column_content_widths = [ + content_widths = [ pc.max(pc.utf8_length(arr).fill_null(0)).as_py() for arr in self._string_data.itercolumns() ] + # pc.max returns None for each column without rows; we need to return 0 + # instead. + self._column_content_widths = [cw or 0 for cw in content_widths] + return self._column_content_widths def get_row_at(self, index: int) -> Sequence[Any]: @@ -279,11 +283,12 @@ def _safe_cast_arr_to_str(arr: pa._PandasConvertible) -> pa._PandasConvertible: and other nested types), we fall back to Python. """ try: - return arr.cast( + arr = arr.cast( pa.string(), safe=False, ) except pl.ArrowNotImplementedError: # todo: vectorize this with a pyarrow udf native_list = arr.to_pylist() - return pa.array([str(i) for i in native_list], type=pa.string()) + arr = pa.array([str(i) for i in native_list], type=pa.string()) + return arr.fill_null("") diff --git a/stubs/pyarrow/__init__.pyi b/stubs/pyarrow/__init__.pyi index 5ff89e7..1fb1e13 100644 --- a/stubs/pyarrow/__init__.pyi +++ b/stubs/pyarrow/__init__.pyi @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Iterable, Iterator, Literal, Mapping, Type, TypeVar +from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence, Type, TypeVar from .compute import CastOptions from .types import DataType as DataType @@ -69,20 +69,22 @@ class _Tabular: self: T, field_: str | Field, column: Array | ChunkedArray ) -> T: ... def column(self, i: int | str) -> _PandasConvertible: ... + def equals(self: T, other: T, check_metadata: bool = False) -> bool: ... def itercolumns(self) -> Iterator[_PandasConvertible]: ... + def select(self: T, columns: Sequence[str | int]) -> T: ... def set_column( self: T, i: int, field_: str | Field, column: Array | ChunkedArray ) -> T: ... - def sort_by( - self: T, - sorting: str | list[tuple[str, Literal["ascending", "descending"]]], - **kwargs: Any, - ) -> T: ... def slice( # noqa: A003 self: T, offset: int = 0, length: int | None = None, ) -> T: ... + def sort_by( + self: T, + sorting: str | list[tuple[str, Literal["ascending", "descending"]]], + **kwargs: Any, + ) -> T: ... def to_pylist(self) -> list[dict[str, Any]]: ... class RecordBatch(_Tabular): ... diff --git a/stubs/pyarrow/parquet.pyi b/stubs/pyarrow/parquet.pyi index 6ffa9e0..7c2ef52 100644 --- a/stubs/pyarrow/parquet.pyi +++ b/stubs/pyarrow/parquet.pyi @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, BinaryIO +from typing import Any, BinaryIO, Literal from . import NativeFile, Schema, Table from .compute import Expression @@ -31,3 +31,30 @@ def read_table( thrift_string_size_limit: int | None = None, thrift_container_size_limit: int | None = None, ) -> Table: ... +def write_table( + table: Table, + where: str | NativeFile, + row_group_size: int | None = None, + version: Literal["1.0", "2.4", "2.6"] = "2.6", + use_dictionary: bool | list = True, + compression: Literal["none", "snappy", "gzip", "brotli", "lz4", "zstd"] + | dict[str, Literal["none", "snappy", "gzip", "brotli", "lz4", "zstd"]] = "snappy", + write_statistics: bool | list = True, + use_deprecated_int96_timestamps: bool | None = None, + coerce_timestamps: str | None = None, + allow_truncated_timestamps: bool = False, + data_page_size: int | None = None, + flavor: Literal["spark"] | None = None, + filesystem: FileSystem | None = None, + compression_level: int | dict | None = None, + use_byte_stream_split: bool | list = False, + column_encoding: str | dict | None = None, + data_page_version: Literal["1.0", "2.0"] = "1.0", + use_compliant_nested_type: bool = True, + encryption_properties: Any | None = None, + write_batch_size: int | None = None, + dictionary_pagesize_limit: int | None = None, + store_schema: bool = True, + write_page_index: bool = False, + **kwargs: Any, +) -> None: ... diff --git a/tests/unit_tests/test_arrow_backend.py b/tests/unit_tests/test_arrow_backend.py index 97f02b7..90eece9 100644 --- a/tests/unit_tests/test_arrow_backend.py +++ b/tests/unit_tests/test_arrow_backend.py @@ -157,3 +157,9 @@ def test_sort(backend: ArrowBackend) -> None: backend.sort(by=[("first column", "ascending")]) assert backend.data.equals(original_table) + + +def test_empty_query() -> None: + data: dict[str, list] = {"a": []} + backend = ArrowBackend.from_pydict(data) + assert backend.column_content_widths == [0]