From 873d39b6223dc42065604f7cf4056694ae5e7ab0 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 22 Oct 2024 16:21:52 +0400 Subject: [PATCH] feat(python): Various `Schema` improvements (new `base_types` method, improved equality/init dtype checks) --- py-polars/polars/_reexport.py | 2 + py-polars/polars/dataframe/frame.py | 2 +- py-polars/polars/datatypes/convert.py | 8 ++- py-polars/polars/io/spreadsheet/functions.py | 2 +- py-polars/polars/lazyframe/frame.py | 2 +- py-polars/polars/schema.py | 66 +++++++++++++++++-- py-polars/polars/series/struct.py | 2 +- .../unit/constructors/test_constructors.py | 5 +- py-polars/tests/unit/datatypes/test_list.py | 4 +- py-polars/tests/unit/datatypes/test_struct.py | 2 +- .../tests/unit/datatypes/test_temporal.py | 23 ++----- .../tests/unit/interop/test_from_pandas.py | 2 +- py-polars/tests/unit/interop/test_interop.py | 4 +- py-polars/tests/unit/io/test_spreadsheet.py | 4 +- .../tests/unit/operations/test_group_by.py | 2 +- py-polars/tests/unit/test_schema.py | 66 ++++++++++++++----- py-polars/tests/unit/test_selectors.py | 8 +-- 17 files changed, 138 insertions(+), 66 deletions(-) diff --git a/py-polars/polars/_reexport.py b/py-polars/polars/_reexport.py index 408fead781def..10818f4731662 100644 --- a/py-polars/polars/_reexport.py +++ b/py-polars/polars/_reexport.py @@ -3,12 +3,14 @@ from polars.dataframe import DataFrame from polars.expr import Expr, When from polars.lazyframe import LazyFrame +from polars.schema import Schema from polars.series import Series __all__ = [ "DataFrame", "Expr", "LazyFrame", + "Schema", "Series", "When", ] diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 8709af81983f5..c5bcd01f26d2a 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -912,7 +912,7 @@ def schema(self) -> Schema: >>> df.schema Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ - return Schema(zip(self.columns, self.dtypes)) + return Schema(zip(self.columns, self.dtypes), check_dtypes=False) def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index d46d8c1115810..f773cc28b6ca0 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -65,10 +65,14 @@ def is_polars_dtype( - dtype: Any, *, include_unknown: bool = False + dtype: Any, + *, + include_unknown: bool = False, + require_instantiated: bool = False, ) -> TypeGuard[PolarsDataType]: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" - is_dtype = isinstance(dtype, (DataType, DataTypeClass)) + check_classes = DataType if require_instantiated else (DataType, DataTypeClass) + is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type] if not include_unknown: return is_dtype and dtype != Unknown diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 8b3df004f4b3c..1d4bb5fe90b63 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -565,7 +565,7 @@ def _read_spreadsheet( infer_schema_length=infer_schema_length, ) engine_options = (engine_options or {}).copy() - schema_overrides = dict(schema_overrides or {}) + schema_overrides = pl.Schema(schema_overrides or {}) # establish the reading function, parser, and available worksheets reader_fn, parser, worksheets = _initialise_spreadsheet_parser( diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index ad513bc1ff55a..64608ae825fb3 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2263,7 +2263,7 @@ def collect_schema(self) -> Schema: >>> schema.len() 3 """ - return Schema(self._ldf.collect_schema()) + return Schema(self._ldf.collect_schema(), check_dtypes=False) @unstable() def sink_parquet( diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 72eb8b86d25e3..39620850ae719 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -4,20 +4,30 @@ from collections.abc import Mapping from typing import TYPE_CHECKING -from polars.datatypes import DataType +from polars.datatypes import DataType, is_polars_dtype from polars.datatypes._parse import parse_into_dtype -BaseSchema = OrderedDict[str, DataType] - if TYPE_CHECKING: from collections.abc import Iterable from polars._typing import PythonDataType + from polars.datatypes import DataTypeClass +BaseSchema = OrderedDict[str, DataType] + __all__ = ["Schema"] +def _check_dtype(tp: DataType | DataTypeClass) -> bool: + if not isinstance(tp, DataType): + # note: if nested, or has annotations, this implies required init params + if tp.is_nested() or tp.__annotations__: + msg = f"dtypes must be fully-specified, got: {tp!r}" + raise TypeError(msg) + return tp + + class Schema(BaseSchema): """ Ordered mapping of column names to their data type. @@ -58,15 +68,57 @@ def __init__( | Iterable[tuple[str, DataType | PythonDataType]] | None ) = None, + *, + check_dtypes: bool = True, ) -> None: input = ( schema.items() if schema and isinstance(schema, Mapping) else (schema or {}) ) - super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc] - - def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None: + for name, tp in input: # type: ignore[misc] + if not check_dtypes: + super().__setitem__(name, tp) # type: ignore[assignment] + elif is_polars_dtype(tp): + super().__setitem__(name, _check_dtype(tp)) # type: ignore[assignment] + else: + self[name] = tp + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Mapping): + return False + if len(self) != len(other): + return False + for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items()): + if nm1 != nm2 or not tp1.is_(tp2): + return False + return True + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __setitem__( + self, name: str, dtype: DataType | DataTypeClass | PythonDataType + ) -> None: + dtype = _check_dtype(parse_into_dtype(dtype)) super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment] + def base_types(self) -> dict[str, DataTypeClass]: + """ + Return a dictionary of column names and the fundamental/root type class. + + Examples + -------- + >>> s = pl.Schema( + ... { + ... "x": pl.Float64(), + ... "y": pl.List(pl.Int32), + ... "z": pl.Struct([pl.Field("a", pl.Int8), pl.Field("b", pl.Boolean)]), + ... } + ... ) + >>> s.base_types() + {'x': Float64, 'y': List, 'z': Struct} + """ + return {name: tp.base_type() for name, tp in self.items()} + def names(self) -> list[str]: """Get the column names of the schema.""" return list(self.keys()) @@ -81,7 +133,7 @@ def len(self) -> int: def to_python(self) -> dict[str, type]: """ - Return Schema as a dictionary of column names and their Python types. + Return a dictionary of column names and Python types. Examples -------- diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index e8137a23be32f..a04d254808d37 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -107,7 +107,7 @@ def schema(self) -> Schema: return Schema({}) schema = self._s.dtype().to_schema() - return Schema(schema) + return Schema(schema, check_dtypes=False) def unnest(self) -> DataFrame: """ diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index d340433ddf10b..3ab507f31fa8a 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -150,7 +150,7 @@ def test_init_dict() -> None: data={"dt": dates, "dtm": datetimes}, schema=coldefs, ) - assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime} + assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")} assert df.rows() == list(zip(py_dates, py_datetimes)) # Overriding dict column names/types @@ -251,7 +251,7 @@ class TradeNT(NamedTuple): ) assert df.schema == { "ts": pl.Datetime("ms"), - "tk": pl.Categorical, + "tk": pl.Categorical(ordering="physical"), "pc": pl.Decimal(scale=1), "sz": pl.UInt16, } @@ -284,7 +284,6 @@ class PageView(BaseModel): models = adapter.validate_json(data_json) result = pl.DataFrame(models) - expected = pl.DataFrame( { "user_id": ["x"], diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 8c5502d698fd9..f7774fd70191c 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -49,7 +49,7 @@ def test_dtype() -> None: "u": pl.List(pl.UInt64), "tm": pl.List(pl.Time), "dt": pl.List(pl.Date), - "dtm": pl.List(pl.Datetime), + "dtm": pl.List(pl.Datetime("us")), } assert all(tp.is_nested() for tp in df.dtypes) assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined] @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None: assert df.to_dict(as_series=False) == expected df = pl.DataFrame(schema=[("col", pl.List)]) - assert df.schema == {"col": pl.List} + assert df.schema == {"col": pl.List(pl.Null)} assert df.rows() == [] diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 149367621aa06..605c898eaaa28 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -211,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame: # struct column df = build_struct_df([{"struct_col": {"inner": 1}}]) assert df.columns == ["struct_col"] - assert df.schema == {"struct_col": pl.Struct} + assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})} assert df["struct_col"].struct.field("inner").to_list() == [1] # struct in struct diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index a925c6f18781c..042a0fca786b5 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -632,16 +632,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.072", "2016-05-25 13:30:00.075", ] - ticker = [ - "GOOG", - "MSFT", - "MSFT", - "MSFT", - "GOOG", - "AAPL", - "GOOG", - "MSFT", - ] + ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"] quotes = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -656,13 +647,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.048", "2016-05-25 13:30:00.048", ] - ticker = [ - "MSFT", - "MSFT", - "GOOG", - "GOOG", - "AAPL", - ] + ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"] trades = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -678,11 +663,11 @@ def test_asof_join() -> None: out = trades.join_asof(quotes, on="dates", strategy="backward") assert out.schema == { - "bid": pl.Float64, - "bid_right": pl.Float64, "dates": pl.Datetime("ms"), "ticker": pl.String, + "bid": pl.Float64, "ticker_right": pl.String, + "bid_right": pl.Float64, } assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"] assert (out["dates"].cast(int)).to_list() == [ diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py index c22d8abafd219..aa0ab7e8210ab 100644 --- a/py-polars/tests/unit/interop/test_from_pandas.py +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -60,7 +60,7 @@ def test_from_pandas() -> None: "floats_nulls": pl.Float64, "strings": pl.String, "strings_nulls": pl.String, - "strings-cat": pl.Categorical, + "strings-cat": pl.Categorical(ordering="physical"), } assert out.rows() == [ (False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"), diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index eff356391b4bc..b69a10671ca71 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -122,7 +122,7 @@ def test_from_dict_struct() -> None: assert df.shape == (2, 2) assert df["a"][0] == {"b": 1, "c": 2} assert df["a"][1] == {"b": 3, "c": 4} - assert df.schema == {"a": pl.Struct, "d": pl.Int64} + assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64} def test_from_dicts() -> None: @@ -397,7 +397,7 @@ def test_dataframe_from_repr() -> None: assert frame.schema == { "a": pl.Int64, "b": pl.Float64, - "c": pl.Categorical, + "c": pl.Categorical(ordering="physical"), "d": pl.Boolean, "e": pl.String, "f": pl.Date, diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index a764f3c707559..dfb3e005370d3 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -232,7 +232,7 @@ def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None: xls = BytesIO() df.write_excel(xls, position="C5") - schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean} + schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean} df_compare = df.with_columns( pl.col(nm).cast(tp) for nm, tp in schema_overrides.items() ) @@ -325,7 +325,7 @@ def test_read_mixed_dtype_columns( "Employee ID": pl.Utf8, "Employee Name": pl.Utf8, "Date": pl.Date, - "Details": pl.Categorical, + "Details": pl.Categorical("physical"), "Asset ID": pl.Utf8, } diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 5ed57b374149a..cff43b43274c4 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None: ) assert df.rows() == [] assert df.shape == (0, 2) - assert df.schema == {"key": pl.Categorical, "val": pl.Float64} + assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64} def test_group_by_custom_agg_empty_list() -> None: diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 7b6a1583ef117..3477dafe85042 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -1,6 +1,8 @@ import pickle from datetime import datetime +import pytest + import polars as pl @@ -13,24 +15,28 @@ def test_schema() -> None: assert s.names() == ["foo", "bar"] assert s.dtypes() == [pl.Int8(), pl.String()] + with pytest.raises( + TypeError, + match="dtypes must be fully-specified, got: List", + ): + pl.Schema({"foo": pl.String, "bar": pl.List}) -def test_schema_parse_nonpolars_dtypes() -> None: - cardinal_directions = pl.Enum(["north", "south", "east", "west"]) - - s = pl.Schema({"foo": pl.List, "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] - s["ham"] = datetime - - assert s["foo"] == pl.List - assert s["bar"] == pl.Int64 - assert s["baz"] == cardinal_directions - assert s["ham"] == pl.Datetime("us") - - assert s.len() == 4 - assert s.names() == ["foo", "bar", "baz", "ham"] - assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")] - assert list(s.to_python().values()) == [list, int, str, datetime] - assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] +def test_schema_base_types() -> None: + s = pl.Schema( + { + "a": pl.Int8(), + "b": pl.Datetime("us"), + "c": pl.Array(pl.Int8(), shape=(4,)), + "d": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}), + } + ) + assert s.base_types() == { + "a": pl.Int8, + "b": pl.Datetime, + "c": pl.Array, + "d": pl.Struct, + } def test_schema_equality() -> None: @@ -45,6 +51,32 @@ def test_schema_equality() -> None: assert s1 != s3 assert s2 != s3 + s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")}) + s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")}) + s6 = {"foo": pl.Datetime, "bar": pl.Duration} + + assert s4 != s5 + assert s4 != s6 + + +def test_schema_parse_python_dtypes() -> None: + cardinal_directions = pl.Enum(["north", "south", "east", "west"]) + + s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] + s["ham"] = datetime + + assert s["foo"] == pl.List(pl.Int32) + assert s["bar"] == pl.Int64 + assert s["baz"] == cardinal_directions + assert s["ham"] == pl.Datetime("us") + + assert s.len() == 4 + assert s.names() == ["foo", "bar", "baz", "ham"] + assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")] + + assert list(s.to_python().values()) == [list, int, str, datetime] + assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] + def test_schema_picklable() -> None: s = pl.Schema( @@ -88,7 +120,6 @@ def test_schema_in_map_elements_returns_scalar() -> None: "amounts": [100.0, -110.0] * 2, } ) - q = ldf.group_by("portfolio").agg( pl.col("amounts") .map_elements( @@ -112,7 +143,6 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None: .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i") .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2")) ) - assert q.collect_schema() == pl.Schema( [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))] ) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index bf44ff87bac58..dd2c415c9a138 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -515,7 +515,7 @@ def test_selector_temporal(df: pl.DataFrame) -> None: assert df.select(cs.temporal()).schema == { "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), } all_columns = set(df.columns) @@ -611,7 +611,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: "eee": pl.Boolean, "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), "qqR": pl.String, } @@ -629,7 +629,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: assert df.select(cs.temporal() - cs.matches("opp|JJK")).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) @@ -639,7 +639,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: ).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } )