Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Various Schema improvements (equality/init dtype checks) #19379

Merged
merged 6 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions py-polars/polars/_reexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from polars.dataframe import DataFrame
from polars.expr import Expr, When
from polars.lazyframe import LazyFrame
from polars.schema import Schema
from polars.series import Series

__all__ = [
"DataFrame",
"Expr",
"LazyFrame",
"Schema",
"Series",
"When",
]
2 changes: 1 addition & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def schema(self) -> Schema:
>>> df.schema
Schema({'foo': Int64, 'bar': Float64, 'ham': String})
"""
return Schema(zip(self.columns, self.dtypes))
return Schema(zip(self.columns, self.dtypes), check_dtypes=False)

def __array__(
self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
Expand Down
8 changes: 6 additions & 2 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,14 @@


def is_polars_dtype(
dtype: Any, *, include_unknown: bool = False
dtype: Any,
*,
include_unknown: bool = False,
require_instantiated: bool = False,
) -> TypeGuard[PolarsDataType]:
"""Indicate whether the given input is a Polars dtype, or dtype specialization."""
is_dtype = isinstance(dtype, (DataType, DataTypeClass))
check_classes = DataType if require_instantiated else (DataType, DataTypeClass)
is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type]

if not include_unknown:
return is_dtype and dtype != Unknown
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def _read_spreadsheet(
infer_schema_length=infer_schema_length,
)
engine_options = (engine_options or {}).copy()
schema_overrides = dict(schema_overrides or {})
schema_overrides = pl.Schema(schema_overrides or {})

# establish the reading function, parser, and available worksheets
reader_fn, parser, worksheets = _initialise_spreadsheet_parser(
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2263,7 +2263,7 @@ def collect_schema(self) -> Schema:
>>> schema.len()
3
"""
return Schema(self._ldf.collect_schema())
return Schema(self._ldf.collect_schema(), check_dtypes=False)

@unstable()
def sink_parquet(
Expand Down
79 changes: 67 additions & 12 deletions py-polars/polars/schema.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,54 @@
from __future__ import annotations

import sys
from collections import OrderedDict
from collections.abc import Mapping
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

from polars.datatypes import DataType
from polars._typing import PythonDataType
from polars.datatypes import DataType, DataTypeClass, is_polars_dtype
from polars.datatypes._parse import parse_into_dtype

BaseSchema = OrderedDict[str, DataType]

if TYPE_CHECKING:
from collections.abc import Iterable

from polars._typing import PythonDataType
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias


if sys.version_info >= (3, 10):

    def _required_init_args(tp: DataTypeClass) -> bool:
        """Return True if instantiating *tp* requires explicit arguments."""
        # note: this check is ~20% faster than the check for a
        # custom "__init__", below, but is not available on py39
        return bool(tp.__annotations__)
else:

    def _required_init_args(tp: DataTypeClass) -> bool:
        """Return True if instantiating *tp* requires explicit arguments."""
        # indicates override of the default __init__
        # (eg: this type requires specific args)
        return "__init__" in tp.__dict__


BaseSchema = OrderedDict[str, DataType]
SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType]


__all__ = ["Schema"]


def _check_dtype(tp: DataType | DataTypeClass) -> DataType:
    """
    Return *tp* as an instantiated DataType, instantiating bare classes.

    Raises
    ------
    TypeError
        If *tp* is a dtype class that cannot be instantiated without
        arguments (nested or decimal types, or those with init params).
    """
    if isinstance(tp, DataType):
        # already fully-specified; nothing to do
        return tp
    # bare class: only acceptable when it takes no required init args
    if tp.is_nested() or tp.is_decimal() or _required_init_args(tp):
        msg = f"dtypes must be fully-specified, got: {tp!r}"
        raise TypeError(msg)
    return tp()  # type: ignore[return-value]


class Schema(BaseSchema):
"""
Ordered mapping of column names to their data type.
Expand Down Expand Up @@ -54,18 +85,42 @@ class Schema(BaseSchema):
def __init__(
self,
schema: (
Mapping[str, DataType | PythonDataType]
| Iterable[tuple[str, DataType | PythonDataType]]
Mapping[str, SchemaInitDataType]
| Iterable[tuple[str, SchemaInitDataType]]
| None
) = None,
*,
check_dtypes: bool = True,
) -> None:
input = (
schema.items() if schema and isinstance(schema, Mapping) else (schema or {})
)
super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc]

def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None:
super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment]
for name, tp in input: # type: ignore[misc]
if not check_dtypes:
super().__setitem__(name, tp) # type: ignore[assignment]
elif is_polars_dtype(tp):
super().__setitem__(name, _check_dtype(tp))
else:
self[name] = tp

def __eq__(self, other: object) -> bool:
if not isinstance(other, Mapping):
return False
if len(self) != len(other):
return False
for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items()):
if nm1 != nm2 or not tp1.is_(tp2):
return False
return True

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __setitem__(
self, name: str, dtype: DataType | DataTypeClass | PythonDataType
) -> None:
dtype = _check_dtype(parse_into_dtype(dtype))
super().__setitem__(name, dtype)

def names(self) -> list[str]:
"""Get the column names of the schema."""
Expand All @@ -81,7 +136,7 @@ def len(self) -> int:

def to_python(self) -> dict[str, type]:
"""
Return Schema as a dictionary of column names and their Python types.
Return a dictionary of column names and Python types.

Examples
--------
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def schema(self) -> Schema:
return Schema({})

schema = self._s.dtype().to_schema()
return Schema(schema)
return Schema(schema, check_dtypes=False)

def unnest(self) -> DataFrame:
"""
Expand Down
5 changes: 2 additions & 3 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_init_dict() -> None:
data={"dt": dates, "dtm": datetimes},
schema=coldefs,
)
assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime}
assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")}
assert df.rows() == list(zip(py_dates, py_datetimes))

# Overriding dict column names/types
Expand Down Expand Up @@ -251,7 +251,7 @@ class TradeNT(NamedTuple):
)
assert df.schema == {
"ts": pl.Datetime("ms"),
"tk": pl.Categorical,
"tk": pl.Categorical(ordering="physical"),
"pc": pl.Decimal(scale=1),
"sz": pl.UInt16,
}
Expand Down Expand Up @@ -284,7 +284,6 @@ class PageView(BaseModel):
models = adapter.validate_json(data_json)

result = pl.DataFrame(models)

expected = pl.DataFrame(
{
"user_id": ["x"],
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def test_to_dummies() -> None:
"i": [1, 2, 3],
"category": ["dog", "cat", "cat"],
},
schema={"i": pl.Int32, "category": pl.Categorical},
schema={"i": pl.Int32, "category": pl.Categorical("lexical")},
)
expected = pl.DataFrame(
{
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_dtype() -> None:
"u": pl.List(pl.UInt64),
"tm": pl.List(pl.Time),
"dt": pl.List(pl.Date),
"dtm": pl.List(pl.Datetime),
"dtm": pl.List(pl.Datetime("us")),
}
assert all(tp.is_nested() for tp in df.dtypes)
assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined]
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None:
assert df.to_dict(as_series=False) == expected

df = pl.DataFrame(schema=[("col", pl.List)])
assert df.schema == {"col": pl.List}
assert df.schema == {"col": pl.List(pl.Null)}
assert df.rows() == []


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame:
# struct column
df = build_struct_df([{"struct_col": {"inner": 1}}])
assert df.columns == ["struct_col"]
assert df.schema == {"struct_col": pl.Struct}
assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})}
assert df["struct_col"].struct.field("inner").to_list() == [1]

# struct in struct
Expand Down
23 changes: 4 additions & 19 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,16 +632,7 @@ def test_asof_join() -> None:
"2016-05-25 13:30:00.072",
"2016-05-25 13:30:00.075",
]
ticker = [
"GOOG",
"MSFT",
"MSFT",
"MSFT",
"GOOG",
"AAPL",
"GOOG",
"MSFT",
]
ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"]
quotes = pl.DataFrame(
{
"dates": pl.Series(dates).str.strptime(pl.Datetime, format=format),
Expand All @@ -656,13 +647,7 @@ def test_asof_join() -> None:
"2016-05-25 13:30:00.048",
"2016-05-25 13:30:00.048",
]
ticker = [
"MSFT",
"MSFT",
"GOOG",
"GOOG",
"AAPL",
]
ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"]
trades = pl.DataFrame(
{
"dates": pl.Series(dates).str.strptime(pl.Datetime, format=format),
Expand All @@ -678,11 +663,11 @@ def test_asof_join() -> None:
out = trades.join_asof(quotes, on="dates", strategy="backward")

assert out.schema == {
"bid": pl.Float64,
"bid_right": pl.Float64,
"dates": pl.Datetime("ms"),
"ticker": pl.String,
"bid": pl.Float64,
"ticker_right": pl.String,
"bid_right": pl.Float64,
}
assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"]
assert (out["dates"].cast(int)).to_list() == [
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/interop/test_from_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_from_pandas() -> None:
"floats_nulls": pl.Float64,
"strings": pl.String,
"strings_nulls": pl.String,
"strings-cat": pl.Categorical,
"strings-cat": pl.Categorical(ordering="physical"),
}
assert out.rows() == [
(False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"),
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_from_dict_struct() -> None:
assert df.shape == (2, 2)
assert df["a"][0] == {"b": 1, "c": 2}
assert df["a"][1] == {"b": 3, "c": 4}
assert df.schema == {"a": pl.Struct, "d": pl.Int64}
assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64}


def test_from_dicts() -> None:
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_dataframe_from_repr() -> None:
assert frame.schema == {
"a": pl.Int64,
"b": pl.Float64,
"c": pl.Categorical,
"c": pl.Categorical(ordering="physical"),
"d": pl.Boolean,
"e": pl.String,
"f": pl.Date,
Expand Down
13 changes: 6 additions & 7 deletions py-polars/tests/unit/io/test_spreadsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None:
xls = BytesIO()
df.write_excel(xls, position="C5")

schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean}
schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean()}
df_compare = df.with_columns(
pl.col(nm).cast(tp) for nm, tp in schema_overrides.items()
)
Expand Down Expand Up @@ -322,13 +322,12 @@ def test_read_mixed_dtype_columns(
) -> None:
spreadsheet_path = request.getfixturevalue(source)
schema_overrides = {
"Employee ID": pl.Utf8,
"Employee Name": pl.Utf8,
"Date": pl.Date,
"Details": pl.Categorical,
"Asset ID": pl.Utf8,
"Employee ID": pl.Utf8(),
"Employee Name": pl.Utf8(),
"Date": pl.Date(),
"Details": pl.Categorical("lexical"),
"Asset ID": pl.Utf8(),
}

df = read_spreadsheet(
spreadsheet_path,
sheet_id=0,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None:
)
assert df.rows() == []
assert df.shape == (0, 2)
assert df.schema == {"key": pl.Categorical, "val": pl.Float64}
assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64}


def test_group_by_custom_agg_empty_list() -> None:
Expand Down
Loading