From 883296394d756da5d7ec802a456891edf7c07b2c Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 22 Oct 2024 16:21:52 +0400 Subject: [PATCH 1/6] feat(python): Various `Schema` improvements (new `base_types` method, improved equality/init dtype checks) --- py-polars/polars/_reexport.py | 2 + py-polars/polars/dataframe/frame.py | 2 +- py-polars/polars/datatypes/convert.py | 8 ++- py-polars/polars/io/spreadsheet/functions.py | 2 +- py-polars/polars/lazyframe/frame.py | 2 +- py-polars/polars/schema.py | 71 ++++++++++++++++--- py-polars/polars/series/struct.py | 2 +- .../unit/constructors/test_constructors.py | 5 +- py-polars/tests/unit/datatypes/test_list.py | 4 +- py-polars/tests/unit/datatypes/test_struct.py | 2 +- .../tests/unit/datatypes/test_temporal.py | 23 ++---- .../tests/unit/interop/test_from_pandas.py | 2 +- py-polars/tests/unit/interop/test_interop.py | 4 +- py-polars/tests/unit/io/test_spreadsheet.py | 5 +- .../tests/unit/operations/test_group_by.py | 2 +- py-polars/tests/unit/test_schema.py | 66 ++++++++++++----- py-polars/tests/unit/test_selectors.py | 8 +-- 17 files changed, 141 insertions(+), 69 deletions(-) diff --git a/py-polars/polars/_reexport.py b/py-polars/polars/_reexport.py index 408fead781de..10818f473166 100644 --- a/py-polars/polars/_reexport.py +++ b/py-polars/polars/_reexport.py @@ -3,12 +3,14 @@ from polars.dataframe import DataFrame from polars.expr import Expr, When from polars.lazyframe import LazyFrame +from polars.schema import Schema from polars.series import Series __all__ = [ "DataFrame", "Expr", "LazyFrame", + "Schema", "Series", "When", ] diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 8709af81983f..c5bcd01f26d2 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -912,7 +912,7 @@ def schema(self) -> Schema: >>> df.schema Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ - return Schema(zip(self.columns, self.dtypes)) + return Schema(zip(self.columns, self.dtypes), check_dtypes=False) def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index d46d8c111581..f773cc28b6ca 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -65,10 +65,14 @@ def is_polars_dtype( - dtype: Any, *, include_unknown: bool = False + dtype: Any, + *, + include_unknown: bool = False, + require_instantiated: bool = False, ) -> TypeGuard[PolarsDataType]: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" - is_dtype = isinstance(dtype, (DataType, DataTypeClass)) + check_classes = DataType if require_instantiated else (DataType, DataTypeClass) + is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type] if not include_unknown: return is_dtype and dtype != Unknown diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 8b3df004f4b3..1d4bb5fe90b6 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -565,7 +565,7 @@ def _read_spreadsheet( infer_schema_length=infer_schema_length, ) engine_options = (engine_options or {}).copy() - schema_overrides = dict(schema_overrides or {}) + schema_overrides = pl.Schema(schema_overrides or {}) # establish the reading function, parser, and available worksheets reader_fn, parser, worksheets = _initialise_spreadsheet_parser( diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index ad513bc1ff55..64608ae825fb 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2263,7 +2263,7 @@ def collect_schema(self) -> Schema: >>> schema.len() 3 """ - return Schema(self._ldf.collect_schema()) + return Schema(self._ldf.collect_schema(), check_dtypes=False) @unstable() def sink_parquet( diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 72eb8b86d25e..1e589bc198fc 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -4,20 +4,31 @@ from collections.abc import Mapping from typing import TYPE_CHECKING -from polars.datatypes import DataType +from polars.datatypes import DataType, is_polars_dtype from polars.datatypes._parse import parse_into_dtype -BaseSchema = OrderedDict[str, DataType] - if TYPE_CHECKING: from collections.abc import Iterable from polars._typing import PythonDataType + from polars.datatypes import DataTypeClass + +BaseSchema = OrderedDict[str, DataType] __all__ = ["Schema"] +def _check_dtype(tp: DataType | DataTypeClass) -> DataType: + if not isinstance(tp, DataType): + # note: if nested, or has annotations, this implies required init params + if tp.is_nested() or tp.__annotations__: + msg = f"dtypes must be fully-specified, got: {tp!r}" + raise TypeError(msg) + tp = tp() + return tp # type: ignore[return-value] + + class Schema(BaseSchema): """ Ordered mapping of column names to their data type. @@ -54,18 +65,60 @@ class Schema(BaseSchema): def __init__( self, schema: ( - Mapping[str, DataType | PythonDataType] - | Iterable[tuple[str, DataType | PythonDataType]] + Mapping[str, DataType | DataTypeClass | PythonDataType] + | Iterable[tuple[str, DataType | DataTypeClass | PythonDataType]] | None ) = None, + *, + check_dtypes: bool = True, ) -> None: input = ( schema.items() if schema and isinstance(schema, Mapping) else (schema or {}) ) - super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc] + for name, tp in input: # type: ignore[misc] + if not check_dtypes: + super().__setitem__(name, tp) # type: ignore[assignment] + elif is_polars_dtype(tp): + super().__setitem__(name, _check_dtype(tp)) + else: + self[name] = tp + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Mapping): + return False + if len(self) != len(other): + return False + for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items()): + if nm1 != nm2 or not tp1.is_(tp2): + return False + return True + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __setitem__( + self, name: str, dtype: DataType | DataTypeClass | PythonDataType + ) -> None: + dtype = _check_dtype(parse_into_dtype(dtype)) + super().__setitem__(name, dtype) + + def base_types(self) -> dict[str, DataTypeClass]: + """ + Return a dictionary of column names and the fundamental/root type class. - def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None: - super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment] + Examples + -------- + >>> s = pl.Schema( + ... { + ... "x": pl.Float64(), + ... "y": pl.List(pl.Int32), + ... "z": pl.Struct([pl.Field("a", pl.Int8), pl.Field("b", pl.Boolean)]), + ... } + ... ) + >>> s.base_types() + {'x': Float64, 'y': List, 'z': Struct} + """ + return {name: tp.base_type() for name, tp in self.items()} def names(self) -> list[str]: """Get the column names of the schema.""" @@ -81,7 +134,7 @@ def len(self) -> int: def to_python(self) -> dict[str, type]: """ - Return Schema as a dictionary of column names and their Python types. + Return a dictionary of column names and Python types. Examples -------- diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index e8137a23be32..a04d254808d3 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -107,7 +107,7 @@ def schema(self) -> Schema: return Schema({}) schema = self._s.dtype().to_schema() - return Schema(schema) + return Schema(schema, check_dtypes=False) def unnest(self) -> DataFrame: """ diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index d340433ddf10..3ab507f31fa8 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -150,7 +150,7 @@ def test_init_dict() -> None: data={"dt": dates, "dtm": datetimes}, schema=coldefs, ) - assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime} + assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")} assert df.rows() == list(zip(py_dates, py_datetimes)) # Overriding dict column names/types @@ -251,7 +251,7 @@ class TradeNT(NamedTuple): ) assert df.schema == { "ts": pl.Datetime("ms"), - "tk": pl.Categorical, + "tk": pl.Categorical(ordering="physical"), "pc": pl.Decimal(scale=1), "sz": pl.UInt16, } @@ -284,7 +284,6 @@ class PageView(BaseModel): models = adapter.validate_json(data_json) result = pl.DataFrame(models) - expected = pl.DataFrame( { "user_id": ["x"], diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 8c5502d698fd..f7774fd70191 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -49,7 +49,7 @@ def test_dtype() -> None: "u": pl.List(pl.UInt64), "tm": pl.List(pl.Time), "dt": pl.List(pl.Date), - "dtm": pl.List(pl.Datetime), + "dtm": pl.List(pl.Datetime("us")), } assert all(tp.is_nested() for tp in df.dtypes) assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined] @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None: assert df.to_dict(as_series=False) == expected df = pl.DataFrame(schema=[("col", pl.List)]) - assert df.schema == {"col": pl.List} + assert df.schema == {"col": pl.List(pl.Null)} assert df.rows() == [] diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 149367621aa0..605c898eaaa2 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -211,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame: # struct column df = build_struct_df([{"struct_col": {"inner": 1}}]) assert df.columns == ["struct_col"] - assert df.schema == {"struct_col": pl.Struct} + assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})} assert df["struct_col"].struct.field("inner").to_list() == [1] # struct in struct diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index a925c6f18781..042a0fca786b 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -632,16 +632,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.072", "2016-05-25 13:30:00.075", ] - ticker = [ - "GOOG", - "MSFT", - "MSFT", - "MSFT", - "GOOG", - "AAPL", - "GOOG", - "MSFT", - ] + ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"] quotes = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -656,13 +647,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.048", "2016-05-25 13:30:00.048", ] - ticker = [ - "MSFT", - "MSFT", - "GOOG", - "GOOG", - "AAPL", - ] + ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"] trades = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -678,11 +663,11 @@ def test_asof_join() -> None: out = trades.join_asof(quotes, on="dates", strategy="backward") assert out.schema == { - "bid": pl.Float64, - "bid_right": pl.Float64, "dates": pl.Datetime("ms"), "ticker": pl.String, + "bid": pl.Float64, "ticker_right": pl.String, + "bid_right": pl.Float64, } assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"] assert (out["dates"].cast(int)).to_list() == [ diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py index c22d8abafd21..aa0ab7e8210a 100644 --- a/py-polars/tests/unit/interop/test_from_pandas.py +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -60,7 +60,7 @@ def test_from_pandas() -> None: "floats_nulls": pl.Float64, "strings": pl.String, "strings_nulls": pl.String, - "strings-cat": pl.Categorical, + "strings-cat": pl.Categorical(ordering="physical"), } assert out.rows() == [ (False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"), diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index eff356391b4b..b69a10671ca7 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -122,7 +122,7 @@ def test_from_dict_struct() -> None: assert df.shape == (2, 2) assert df["a"][0] == {"b": 1, "c": 2} assert df["a"][1] == {"b": 3, "c": 4} - assert df.schema == {"a": pl.Struct, "d": pl.Int64} + assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64} def test_from_dicts() -> None: @@ -397,7 +397,7 @@ def test_dataframe_from_repr() -> None: assert frame.schema == { "a": pl.Int64, "b": pl.Float64, - "c": pl.Categorical, + "c": pl.Categorical(ordering="physical"), "d": pl.Boolean, "e": pl.String, "f": pl.Date, diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index a764f3c70755..2f957268c4a1 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -232,7 +232,7 @@ def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None: xls = BytesIO() df.write_excel(xls, position="C5") - schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean} + schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean()} df_compare = df.with_columns( pl.col(nm).cast(tp) for nm, tp in schema_overrides.items() ) @@ -325,10 +325,9 @@ def test_read_mixed_dtype_columns( "Employee ID": pl.Utf8, "Employee Name": pl.Utf8, "Date": pl.Date, - "Details": pl.Categorical, + "Details": pl.Categorical("lexical"), "Asset ID": pl.Utf8, } - df = read_spreadsheet( spreadsheet_path, sheet_id=0, diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 5ed57b374149..cff43b43274c 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None: ) assert df.rows() == [] assert df.shape == (0, 2) - assert df.schema == {"key": pl.Categorical, "val": pl.Float64} + assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64} def test_group_by_custom_agg_empty_list() -> None: diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 7b6a1583ef11..3477dafe8504 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -1,6 +1,8 @@ import pickle from datetime import datetime +import pytest + import polars as pl @@ -13,24 +15,28 @@ def test_schema() -> None: assert s.names() == ["foo", "bar"] assert s.dtypes() == [pl.Int8(), pl.String()] + with pytest.raises( + TypeError, + match="dtypes must be fully-specified, got: List", + ): + pl.Schema({"foo": pl.String, "bar": pl.List}) -def test_schema_parse_nonpolars_dtypes() -> None: - cardinal_directions = pl.Enum(["north", "south", "east", "west"]) - - s = pl.Schema({"foo": pl.List, "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] - s["ham"] = datetime - - assert s["foo"] == pl.List - assert s["bar"] == pl.Int64 - assert s["baz"] == cardinal_directions - assert s["ham"] == pl.Datetime("us") - - assert s.len() == 4 - assert s.names() == ["foo", "bar", "baz", "ham"] - assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")] - assert list(s.to_python().values()) == [list, int, str, datetime] - assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] +def test_schema_base_types() -> None: + s = pl.Schema( + { + "a": pl.Int8(), + "b": pl.Datetime("us"), + "c": pl.Array(pl.Int8(), shape=(4,)), + "d": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}), + } + ) + assert s.base_types() == { + "a": pl.Int8, + "b": pl.Datetime, + "c": pl.Array, + "d": pl.Struct, + } def test_schema_equality() -> None: @@ -45,6 +51,32 @@ def test_schema_equality() -> None: assert s1 != s3 assert s2 != s3 + s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")}) + s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")}) + s6 = {"foo": pl.Datetime, "bar": pl.Duration} + + assert s4 != s5 + assert s4 != s6 + + +def test_schema_parse_python_dtypes() -> None: + cardinal_directions = pl.Enum(["north", "south", "east", "west"]) + + s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] + s["ham"] = datetime + + assert s["foo"] == pl.List(pl.Int32) + assert s["bar"] == pl.Int64 + assert s["baz"] == cardinal_directions + assert s["ham"] == pl.Datetime("us") + + assert s.len() == 4 + assert s.names() == ["foo", "bar", "baz", "ham"] + assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")] + + assert list(s.to_python().values()) == [list, int, str, datetime] + assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] + def test_schema_picklable() -> None: s = pl.Schema( @@ -88,7 +120,6 @@ def test_schema_in_map_elements_returns_scalar() -> None: "amounts": [100.0, -110.0] * 2, } ) - q = ldf.group_by("portfolio").agg( pl.col("amounts") .map_elements( @@ -112,7 +143,6 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None: .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i") .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2")) ) - assert q.collect_schema() == pl.Schema( [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))] ) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index bf44ff87bac5..dd2c415c9a13 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -515,7 +515,7 @@ def test_selector_temporal(df: pl.DataFrame) -> None: assert df.select(cs.temporal()).schema == { "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), } all_columns = set(df.columns) @@ -611,7 +611,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: "eee": pl.Boolean, "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), "qqR": pl.String, } @@ -629,7 +629,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: assert df.select(cs.temporal() - cs.matches("opp|JJK")).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) @@ -639,7 +639,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: ).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) From ea52dcea6ffed69aa98b8ea4596a505a9c621dc5 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 22 Oct 2024 19:30:59 +0400 Subject: [PATCH 2/6] fix some test lint --- py-polars/polars/schema.py | 18 ++++++++++++------ py-polars/tests/unit/dataframe/test_df.py | 2 +- py-polars/tests/unit/io/test_spreadsheet.py | 8 ++++---- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 1e589bc198fc..c94035f54f7f 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -2,19 +2,25 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union -from polars.datatypes import DataType, is_polars_dtype +from polars._typing import PythonDataType +from polars.datatypes import DataType, DataTypeClass, is_polars_dtype from polars.datatypes._parse import parse_into_dtype if TYPE_CHECKING: + import sys from collections.abc import Iterable - from polars._typing import PythonDataType - from polars.datatypes import DataTypeClass + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias BaseSchema = OrderedDict[str, DataType] +SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType] + __all__ = ["Schema"] @@ -65,8 +71,8 @@ class Schema(BaseSchema): def __init__( self, schema: ( - Mapping[str, DataType | DataTypeClass | PythonDataType] - | Iterable[tuple[str, DataType | DataTypeClass | PythonDataType]] + Mapping[str, SchemaInitDataType] + | Iterable[tuple[str, SchemaInitDataType]] | None ) = None, *, diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 29856bf3eb4e..d8910cda4fb2 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -760,7 +760,7 @@ def test_to_dummies() -> None: "i": [1, 2, 3], "category": ["dog", "cat", "cat"], }, - schema={"i": pl.Int32, "category": pl.Categorical}, + schema={"i": pl.Int32, "category": pl.Categorical("lexical")}, ) expected = pl.DataFrame( { diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 2f957268c4a1..b7b03a0bd02e 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -322,11 +322,11 @@ def test_read_mixed_dtype_columns( ) -> None: spreadsheet_path = request.getfixturevalue(source) schema_overrides = { - "Employee ID": pl.Utf8, - "Employee Name": pl.Utf8, - "Date": pl.Date, + "Employee ID": pl.Utf8(), + "Employee Name": pl.Utf8(), + "Date": pl.Date(), "Details": pl.Categorical("lexical"), - "Asset ID": pl.Utf8, + "Asset ID": pl.Utf8(), } df = read_spreadsheet( spreadsheet_path, From 456219bc01d4b263afd9dce5a9b9c849cd8c63c3 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 22 Oct 2024 19:49:49 +0400 Subject: [PATCH 3/6] py3.9 lint --- py-polars/polars/schema.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index c94035f54f7f..83f99def44f6 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -2,6 +2,7 @@ from collections import OrderedDict from collections.abc import Mapping +from inspect import signature from typing import TYPE_CHECKING, Union from polars._typing import PythonDataType @@ -27,8 +28,8 @@ def _check_dtype(tp: DataType | DataTypeClass) -> DataType: if not isinstance(tp, DataType): - # note: if nested, or has annotations, this implies required init params - if tp.is_nested() or tp.__annotations__: + # note: if nested, or has signature params, this implies required init args + if tp.is_nested() or signature(tp).parameters: msg = f"dtypes must be fully-specified, got: {tp!r}" raise TypeError(msg) tp = tp() From 8faee2cac349fd0d89176ae2f272f5705b8171ec Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 22 Oct 2024 20:10:21 +0400 Subject: [PATCH 4/6] optimise check on >=py3.10 --- py-polars/polars/schema.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 83f99def44f6..07628b21fbab 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -1,8 +1,8 @@ from __future__ import annotations +import sys from collections import OrderedDict from collections.abc import Mapping -from inspect import signature from typing import TYPE_CHECKING, Union from polars._typing import PythonDataType @@ -10,7 +10,6 @@ from polars.datatypes._parse import parse_into_dtype if TYPE_CHECKING: - import sys from collections.abc import Iterable if sys.version_info >= (3, 10): @@ -19,6 +18,21 @@ from typing_extensions import TypeAlias +if sys.version_info >= (3, 10): + from inspect import get_annotations + + def _required_init_args(tp: DataTypeClass) -> bool: + # note: this check is ~10x faster than using 'signature', + # but is not available on py39 + return bool(get_annotations(tp)) + +else: + from inspect import signature + + def _required_init_args(tp: DataTypeClass) -> bool: + return bool(signature(tp).parameters) + + BaseSchema = OrderedDict[str, DataType] SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType] @@ -28,8 +42,8 @@ def _check_dtype(tp: DataType | DataTypeClass) -> DataType: if not isinstance(tp, DataType): - # note: if nested, or has signature params, this implies required init args - if tp.is_nested() or signature(tp).parameters: + # note: if nested/decimal, or has signature params, this implies required args + if tp.is_nested() or tp.is_decimal() or _required_init_args(tp): msg = f"dtypes must be fully-specified, got: {tp!r}" raise TypeError(msg) tp = tp() From 215186e3e379c89a7ec5ab646fff7500c92abdb3 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Wed, 23 Oct 2024 10:58:05 +0400 Subject: [PATCH 5/6] optimise check on py3.9 --- py-polars/polars/schema.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 07628b21fbab..3df42d6607d7 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -19,18 +19,17 @@ if sys.version_info >= (3, 10): - from inspect import get_annotations def _required_init_args(tp: DataTypeClass) -> bool: - # note: this check is ~10x faster than using 'signature', - # but is not available on py39 - return bool(get_annotations(tp)) - + # note: this check is ~20% faster than the check for a + # custom "__init__", below, but is not available on py39 + return bool(tp.__annotations__) else: - from inspect import signature def _required_init_args(tp: DataTypeClass) -> bool: - return bool(signature(tp).parameters) + # indicates override of the default __init__ + # (eg: this type requires specific args) + return "__init__" in tp.__dict__ BaseSchema = OrderedDict[str, DataType] From db59be1d97a5e21c2213fa602490895e266741e4 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Wed, 23 Oct 2024 11:04:22 +0400 Subject: [PATCH 6/6] drop `base_types` method --- py-polars/polars/schema.py | 18 ------------------ py-polars/tests/unit/test_schema.py | 17 ----------------- 2 files changed, 35 deletions(-) diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 3df42d6607d7..81ade5a6b206 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -122,24 +122,6 @@ def __setitem__( dtype = _check_dtype(parse_into_dtype(dtype)) super().__setitem__(name, dtype) - def base_types(self) -> dict[str, DataTypeClass]: - """ - Return a dictionary of column names and the fundamental/root type class. - - Examples - -------- - >>> s = pl.Schema( - ... { - ... "x": pl.Float64(), - ... "y": pl.List(pl.Int32), - ... "z": pl.Struct([pl.Field("a", pl.Int8), pl.Field("b", pl.Boolean)]), - ... } - ... ) - >>> s.base_types() - {'x': Float64, 'y': List, 'z': Struct} - """ - return {name: tp.base_type() for name, tp in self.items()} - def names(self) -> list[str]: """Get the column names of the schema.""" return list(self.keys()) diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 3477dafe8504..9c5848382f42 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -22,23 +22,6 @@ def test_schema() -> None: pl.Schema({"foo": pl.String, "bar": pl.List}) -def test_schema_base_types() -> None: - s = pl.Schema( - { - "a": pl.Int8(), - "b": pl.Datetime("us"), - "c": pl.Array(pl.Int8(), shape=(4,)), - "d": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}), - } - ) - assert s.base_types() == { - "a": pl.Int8, - "b": pl.Datetime, - "c": pl.Array, - "d": pl.Struct, - } - - def test_schema_equality() -> None: s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})