diff --git a/py-polars/polars/_reexport.py b/py-polars/polars/_reexport.py index 408fead781de..10818f473166 100644 --- a/py-polars/polars/_reexport.py +++ b/py-polars/polars/_reexport.py @@ -3,12 +3,14 @@ from polars.dataframe import DataFrame from polars.expr import Expr, When from polars.lazyframe import LazyFrame +from polars.schema import Schema from polars.series import Series __all__ = [ "DataFrame", "Expr", "LazyFrame", + "Schema", "Series", "When", ] diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 8709af81983f..c5bcd01f26d2 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -912,7 +912,7 @@ def schema(self) -> Schema: >>> df.schema Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ - return Schema(zip(self.columns, self.dtypes)) + return Schema(zip(self.columns, self.dtypes), check_dtypes=False) def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index d46d8c111581..f773cc28b6ca 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -65,10 +65,14 @@ def is_polars_dtype( - dtype: Any, *, include_unknown: bool = False + dtype: Any, + *, + include_unknown: bool = False, + require_instantiated: bool = False, ) -> TypeGuard[PolarsDataType]: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" - is_dtype = isinstance(dtype, (DataType, DataTypeClass)) + check_classes = DataType if require_instantiated else (DataType, DataTypeClass) + is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type] if not include_unknown: return is_dtype and dtype != Unknown diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 8b3df004f4b3..1d4bb5fe90b6 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -565,7 +565,7 @@ def _read_spreadsheet( infer_schema_length=infer_schema_length, ) engine_options = (engine_options or {}).copy() - schema_overrides = dict(schema_overrides or {}) + schema_overrides = pl.Schema(schema_overrides or {}) # establish the reading function, parser, and available worksheets reader_fn, parser, worksheets = _initialise_spreadsheet_parser( diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index ad513bc1ff55..64608ae825fb 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2263,7 +2263,7 @@ def collect_schema(self) -> Schema: >>> schema.len() 3 """ - return Schema(self._ldf.collect_schema()) + return Schema(self._ldf.collect_schema(), check_dtypes=False) @unstable() def sink_parquet( diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 72eb8b86d25e..81ade5a6b206 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -1,23 +1,54 @@ from __future__ import annotations +import sys from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union -from polars.datatypes import DataType +from polars._typing import PythonDataType +from polars.datatypes import DataType, DataTypeClass, is_polars_dtype from polars.datatypes._parse import parse_into_dtype -BaseSchema = OrderedDict[str, DataType] - if TYPE_CHECKING: from collections.abc import Iterable - from polars._typing import PythonDataType + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + +if sys.version_info >= (3, 10): + + def _required_init_args(tp: DataTypeClass) -> bool: + # note: this check is ~20% faster than the check for a + # custom "__init__", below, but is not available on py39 + return bool(tp.__annotations__) +else: + + def _required_init_args(tp: DataTypeClass) -> bool: + # indicates override of the default __init__ + # (eg: this type requires specific args) + return "__init__" in tp.__dict__ + + +BaseSchema = OrderedDict[str, DataType] +SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType] __all__ = ["Schema"] +def _check_dtype(tp: DataType | DataTypeClass) -> DataType: + if not isinstance(tp, DataType): + # note: if nested/decimal, or has signature params, this implies required args + if tp.is_nested() or tp.is_decimal() or _required_init_args(tp): + msg = f"dtypes must be fully-specified, got: {tp!r}" + raise TypeError(msg) + tp = tp() + return tp # type: ignore[return-value] + + class Schema(BaseSchema): """ Ordered mapping of column names to their data type. @@ -54,18 +85,42 @@ class Schema(BaseSchema): def __init__( self, schema: ( - Mapping[str, DataType | PythonDataType] - | Iterable[tuple[str, DataType | PythonDataType]] + Mapping[str, SchemaInitDataType] + | Iterable[tuple[str, SchemaInitDataType]] | None ) = None, + *, + check_dtypes: bool = True, ) -> None: input = ( schema.items() if schema and isinstance(schema, Mapping) else (schema or {}) ) - super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc] - - def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None: - super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment] + for name, tp in input: # type: ignore[misc] + if not check_dtypes: + super().__setitem__(name, tp) # type: ignore[assignment] + elif is_polars_dtype(tp): + super().__setitem__(name, _check_dtype(tp)) + else: + self[name] = tp + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Mapping): + return False + if len(self) != len(other): + return False + for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items()): + if nm1 != nm2 or not tp1.is_(tp2): + return False + return True + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __setitem__( + self, name: str, dtype: DataType | DataTypeClass | PythonDataType + ) -> None: + dtype = _check_dtype(parse_into_dtype(dtype)) + super().__setitem__(name, dtype) def names(self) -> list[str]: """Get the column names of the schema.""" @@ -81,7 +136,7 @@ def len(self) -> int: def to_python(self) -> dict[str, type]: """ - Return Schema as a dictionary of column names and their Python types. + Return a dictionary of column names and Python types. Examples -------- diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index e8137a23be32..a04d254808d3 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -107,7 +107,7 @@ def schema(self) -> Schema: return Schema({}) schema = self._s.dtype().to_schema() - return Schema(schema) + return Schema(schema, check_dtypes=False) def unnest(self) -> DataFrame: """ diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index d340433ddf10..3ab507f31fa8 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -150,7 +150,7 @@ def test_init_dict() -> None: data={"dt": dates, "dtm": datetimes}, schema=coldefs, ) - assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime} + assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")} assert df.rows() == list(zip(py_dates, py_datetimes)) # Overriding dict column names/types @@ -251,7 +251,7 @@ class TradeNT(NamedTuple): ) assert df.schema == { "ts": pl.Datetime("ms"), - "tk": pl.Categorical, + "tk": pl.Categorical(ordering="physical"), "pc": pl.Decimal(scale=1), "sz": pl.UInt16, } @@ -284,7 +284,6 @@ class PageView(BaseModel): models = adapter.validate_json(data_json) result = pl.DataFrame(models) - expected = pl.DataFrame( { "user_id": ["x"], diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 29856bf3eb4e..d8910cda4fb2 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -760,7 +760,7 @@ def test_to_dummies() -> None: "i": [1, 2, 3], "category": ["dog", "cat", "cat"], }, - schema={"i": pl.Int32, "category": pl.Categorical}, + schema={"i": pl.Int32, "category": pl.Categorical("lexical")}, ) expected = pl.DataFrame( { diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 8c5502d698fd..f7774fd70191 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -49,7 +49,7 @@ def test_dtype() -> None: "u": pl.List(pl.UInt64), "tm": pl.List(pl.Time), "dt": pl.List(pl.Date), - "dtm": pl.List(pl.Datetime), + "dtm": pl.List(pl.Datetime("us")), } assert all(tp.is_nested() for tp in df.dtypes) assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined] @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None: assert df.to_dict(as_series=False) == expected df = pl.DataFrame(schema=[("col", pl.List)]) - assert df.schema == {"col": pl.List} + assert df.schema == {"col": pl.List(pl.Null)} assert df.rows() == [] diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 149367621aa0..605c898eaaa2 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -211,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame: # struct column df = build_struct_df([{"struct_col": {"inner": 1}}]) assert df.columns == ["struct_col"] - assert df.schema == {"struct_col": pl.Struct} + assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})} assert df["struct_col"].struct.field("inner").to_list() == [1] # struct in struct diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index a925c6f18781..042a0fca786b 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -632,16 +632,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.072", "2016-05-25 13:30:00.075", ] - ticker = [ - "GOOG", - "MSFT", - "MSFT", - "MSFT", - "GOOG", - "AAPL", - "GOOG", - "MSFT", - ] + ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"] quotes = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -656,13 +647,7 @@ def test_asof_join() -> None: "2016-05-25 13:30:00.048", "2016-05-25 13:30:00.048", ] - ticker = [ - "MSFT", - "MSFT", - "GOOG", - "GOOG", - "AAPL", - ] + ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"] trades = pl.DataFrame( { "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), @@ -678,11 +663,11 @@ def test_asof_join() -> None: out = trades.join_asof(quotes, on="dates", strategy="backward") assert out.schema == { - "bid": pl.Float64, - "bid_right": pl.Float64, "dates": pl.Datetime("ms"), "ticker": pl.String, + "bid": pl.Float64, "ticker_right": pl.String, + "bid_right": pl.Float64, } assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"] assert (out["dates"].cast(int)).to_list() == [ diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py index c22d8abafd21..aa0ab7e8210a 100644 --- a/py-polars/tests/unit/interop/test_from_pandas.py +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -60,7 +60,7 @@ def test_from_pandas() -> None: "floats_nulls": pl.Float64, "strings": pl.String, "strings_nulls": pl.String, - "strings-cat": pl.Categorical, + "strings-cat": pl.Categorical(ordering="physical"), } assert out.rows() == [ (False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"), diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index eff356391b4b..b69a10671ca7 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -122,7 +122,7 @@ def test_from_dict_struct() -> None: assert df.shape == (2, 2) assert df["a"][0] == {"b": 1, "c": 2} assert df["a"][1] == {"b": 3, "c": 4} - assert df.schema == {"a": pl.Struct, "d": pl.Int64} + assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64} def test_from_dicts() -> None: @@ -397,7 +397,7 @@ def test_dataframe_from_repr() -> None: assert frame.schema == { "a": pl.Int64, "b": pl.Float64, - "c": pl.Categorical, + "c": pl.Categorical(ordering="physical"), "d": pl.Boolean, "e": pl.String, "f": pl.Date, diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index a764f3c70755..b7b03a0bd02e 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -232,7 +232,7 @@ def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None: xls = BytesIO() df.write_excel(xls, position="C5") - schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean} + schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean()} df_compare = df.with_columns( pl.col(nm).cast(tp) for nm, tp in schema_overrides.items() ) @@ -322,13 +322,12 @@ def test_read_mixed_dtype_columns( ) -> None: spreadsheet_path = request.getfixturevalue(source) schema_overrides = { - "Employee ID": pl.Utf8, - "Employee Name": pl.Utf8, - "Date": pl.Date, - "Details": pl.Categorical, - "Asset ID": pl.Utf8, + "Employee ID": pl.Utf8(), + "Employee Name": pl.Utf8(), + "Date": pl.Date(), + "Details": pl.Categorical("lexical"), + "Asset ID": pl.Utf8(), } - df = read_spreadsheet( spreadsheet_path, sheet_id=0, diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 5ed57b374149..cff43b43274c 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None: ) assert df.rows() == [] assert df.shape == (0, 2) - assert df.schema == {"key": pl.Categorical, "val": pl.Float64} + assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64} def test_group_by_custom_agg_empty_list() -> None: diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 7b6a1583ef11..9c5848382f42 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -1,6 +1,8 @@ import pickle from datetime import datetime +import pytest + import polars as pl @@ -13,14 +15,40 @@ def test_schema() -> None: assert s.names() == ["foo", "bar"] assert s.dtypes() == [pl.Int8(), pl.String()] + with pytest.raises( + TypeError, + match="dtypes must be fully-specified, got: List", + ): + pl.Schema({"foo": pl.String, "bar": pl.List}) + + +def test_schema_equality() -> None: + s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) + s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) + + assert s1 == s1 + assert s2 == s2 + assert s3 == s3 + assert s1 != s2 + assert s1 != s3 + assert s2 != s3 + + s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")}) + s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")}) + s6 = {"foo": pl.Datetime, "bar": pl.Duration} + + assert s4 != s5 + assert s4 != s6 -def test_schema_parse_nonpolars_dtypes() -> None: + +def test_schema_parse_python_dtypes() -> None: cardinal_directions = pl.Enum(["north", "south", "east", "west"]) - s = pl.Schema({"foo": pl.List, "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] + s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] s["ham"] = datetime - assert s["foo"] == pl.List + assert s["foo"] == pl.List(pl.Int32) assert s["bar"] == pl.Int64 assert s["baz"] == cardinal_directions assert s["ham"] == pl.Datetime("us") @@ -33,19 +61,6 @@ def test_schema_parse_nonpolars_dtypes() -> None: assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] -def test_schema_equality() -> None: - s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) - s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) - s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) - - assert s1 == s1 - assert s2 == s2 - assert s3 == s3 - assert s1 != s2 - assert s1 != s3 - assert s2 != s3 - - def test_schema_picklable() -> None: s = pl.Schema( { @@ -88,7 +103,6 @@ def test_schema_in_map_elements_returns_scalar() -> None: "amounts": [100.0, -110.0] * 2, } ) - q = ldf.group_by("portfolio").agg( pl.col("amounts") .map_elements( @@ -112,7 +126,6 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None: .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i") .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2")) ) - assert q.collect_schema() == pl.Schema( [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))] ) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index bf44ff87bac5..dd2c415c9a13 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -515,7 +515,7 @@ def test_selector_temporal(df: pl.DataFrame) -> None: assert df.select(cs.temporal()).schema == { "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), } all_columns = set(df.columns) @@ -611,7 +611,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: "eee": pl.Boolean, "ghi": pl.Time, "JJK": pl.Date, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), "opp": pl.Datetime("ms"), "qqR": pl.String, } @@ -629,7 +629,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: assert df.select(cs.temporal() - cs.matches("opp|JJK")).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } ) @@ -639,7 +639,7 @@ def test_selector_sets(df: pl.DataFrame) -> None: ).schema == OrderedDict( { "ghi": pl.Time, - "Lmn": pl.Duration, + "Lmn": pl.Duration("us"), } )