Skip to content

Commit

Permalink
feat(python): Various Schema improvements (new base_types method,…
Browse files Browse the repository at this point in the history
… improved equality/init dtype checks)
  • Loading branch information
alexander-beedie committed Oct 22, 2024
1 parent 27289b2 commit 873d39b
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 66 deletions.
2 changes: 2 additions & 0 deletions py-polars/polars/_reexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from polars.dataframe import DataFrame
from polars.expr import Expr, When
from polars.lazyframe import LazyFrame
from polars.schema import Schema
from polars.series import Series

__all__ = [
"DataFrame",
"Expr",
"LazyFrame",
"Schema",
"Series",
"When",
]
2 changes: 1 addition & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def schema(self) -> Schema:
>>> df.schema
Schema({'foo': Int64, 'bar': Float64, 'ham': String})
"""
return Schema(zip(self.columns, self.dtypes))
return Schema(zip(self.columns, self.dtypes), check_dtypes=False)

def __array__(
self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
Expand Down
8 changes: 6 additions & 2 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,14 @@


def is_polars_dtype(
dtype: Any, *, include_unknown: bool = False
dtype: Any,
*,
include_unknown: bool = False,
require_instantiated: bool = False,
) -> TypeGuard[PolarsDataType]:
"""Indicate whether the given input is a Polars dtype, or dtype specialization."""
is_dtype = isinstance(dtype, (DataType, DataTypeClass))
check_classes = DataType if require_instantiated else (DataType, DataTypeClass)
is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type]

if not include_unknown:
return is_dtype and dtype != Unknown
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def _read_spreadsheet(
infer_schema_length=infer_schema_length,
)
engine_options = (engine_options or {}).copy()
schema_overrides = dict(schema_overrides or {})
schema_overrides = pl.Schema(schema_overrides or {})

# establish the reading function, parser, and available worksheets
reader_fn, parser, worksheets = _initialise_spreadsheet_parser(
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2263,7 +2263,7 @@ def collect_schema(self) -> Schema:
>>> schema.len()
3
"""
return Schema(self._ldf.collect_schema())
return Schema(self._ldf.collect_schema(), check_dtypes=False)

@unstable()
def sink_parquet(
Expand Down
66 changes: 59 additions & 7 deletions py-polars/polars/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,30 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING

from polars.datatypes import DataType
from polars.datatypes import DataType, is_polars_dtype
from polars.datatypes._parse import parse_into_dtype

BaseSchema = OrderedDict[str, DataType]

if TYPE_CHECKING:
from collections.abc import Iterable

from polars._typing import PythonDataType
from polars.datatypes import DataTypeClass


BaseSchema = OrderedDict[str, DataType]

__all__ = ["Schema"]


def _check_dtype(tp: DataType | DataTypeClass) -> DataType | DataTypeClass:
    """
    Validate that a dtype is fully-specified, returning it unchanged.

    Uninstantiated dtype classes are accepted only when they require no init
    parameters; nested dtypes (or classes with annotated init params) must be
    instantiated before being stored in a Schema.

    Raises
    ------
    TypeError
        If `tp` is an uninstantiated dtype class that requires init parameters.
    """
    if not isinstance(tp, DataType):
        # note: if nested, or has annotations, this implies required init params
        if tp.is_nested() or tp.__annotations__:
            msg = f"dtypes must be fully-specified, got: {tp!r}"
            raise TypeError(msg)
    # returned (not a bool) so callers can chain: `_check_dtype(parse_into_dtype(x))`
    return tp


class Schema(BaseSchema):
"""
Ordered mapping of column names to their data type.
Expand Down Expand Up @@ -58,15 +68,57 @@ def __init__(
| Iterable[tuple[str, DataType | PythonDataType]]
| None
) = None,
*,
check_dtypes: bool = True,
) -> None:
input = (
schema.items() if schema and isinstance(schema, Mapping) else (schema or {})
)
super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc]

def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None:
for name, tp in input: # type: ignore[misc]
if not check_dtypes:
super().__setitem__(name, tp) # type: ignore[assignment]
elif is_polars_dtype(tp):
super().__setitem__(name, _check_dtype(tp)) # type: ignore[assignment]
else:
self[name] = tp

def __eq__(self, other: object) -> bool:
    """
    Compare with another mapping for equality.

    Equal when the other object is a mapping of the same length whose
    (name, dtype) pairs match ours in order, using strict dtype identity.
    """
    if not isinstance(other, Mapping) or len(self) != len(other):
        return False
    # compare pairwise, in order; dtypes must match via strict `is_` check
    return all(
        nm1 == nm2 and tp1.is_(tp2)
        for (nm1, tp1), (nm2, tp2) in zip(self.items(), other.items())
    )

def __ne__(self, other: object) -> bool:
    """Inverse of `__eq__`; True when the schemas differ."""
    return not (self == other)

def __setitem__(
    self, name: str, dtype: DataType | DataTypeClass | PythonDataType
) -> None:
    """
    Set the dtype for a column, parsing and validating it first.

    The dtype is parsed into a canonical polars dtype and checked for being
    fully-specified before storage.
    """
    # parse/validate once and store the result directly (the original code
    # redundantly re-parsed the already-parsed dtype on insertion)
    super().__setitem__(name, _check_dtype(parse_into_dtype(dtype)))  # type: ignore[assignment]

def base_types(self) -> dict[str, DataTypeClass]:
    """
    Return a dictionary of column names and the fundamental/root type class.

    Examples
    --------
    >>> s = pl.Schema(
    ...     {
    ...         "x": pl.Float64(),
    ...         "y": pl.List(pl.Int32),
    ...         "z": pl.Struct([pl.Field("a", pl.Int8), pl.Field("b", pl.Boolean)]),
    ...     }
    ... )
    >>> s.base_types()
    {'x': Float64, 'y': List, 'z': Struct}
    """
    result: dict[str, DataTypeClass] = {}
    for name, tp in self.items():
        result[name] = tp.base_type()
    return result

def names(self) -> list[str]:
    """Get the column names of the schema."""
    # iterating a dict yields its keys in insertion order
    return list(self)
Expand All @@ -81,7 +133,7 @@ def len(self) -> int:

def to_python(self) -> dict[str, type]:
"""
Return Schema as a dictionary of column names and their Python types.
Return a dictionary of column names and Python types.
Examples
--------
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def schema(self) -> Schema:
return Schema({})

schema = self._s.dtype().to_schema()
return Schema(schema)
return Schema(schema, check_dtypes=False)

def unnest(self) -> DataFrame:
"""
Expand Down
5 changes: 2 additions & 3 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_init_dict() -> None:
data={"dt": dates, "dtm": datetimes},
schema=coldefs,
)
assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime}
assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")}
assert df.rows() == list(zip(py_dates, py_datetimes))

# Overriding dict column names/types
Expand Down Expand Up @@ -251,7 +251,7 @@ class TradeNT(NamedTuple):
)
assert df.schema == {
"ts": pl.Datetime("ms"),
"tk": pl.Categorical,
"tk": pl.Categorical(ordering="physical"),
"pc": pl.Decimal(scale=1),
"sz": pl.UInt16,
}
Expand Down Expand Up @@ -284,7 +284,6 @@ class PageView(BaseModel):
models = adapter.validate_json(data_json)

result = pl.DataFrame(models)

expected = pl.DataFrame(
{
"user_id": ["x"],
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_dtype() -> None:
"u": pl.List(pl.UInt64),
"tm": pl.List(pl.Time),
"dt": pl.List(pl.Date),
"dtm": pl.List(pl.Datetime),
"dtm": pl.List(pl.Datetime("us")),
}
assert all(tp.is_nested() for tp in df.dtypes)
assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined]
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None:
assert df.to_dict(as_series=False) == expected

df = pl.DataFrame(schema=[("col", pl.List)])
assert df.schema == {"col": pl.List}
assert df.schema == {"col": pl.List(pl.Null)}
assert df.rows() == []


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame:
# struct column
df = build_struct_df([{"struct_col": {"inner": 1}}])
assert df.columns == ["struct_col"]
assert df.schema == {"struct_col": pl.Struct}
assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})}
assert df["struct_col"].struct.field("inner").to_list() == [1]

# struct in struct
Expand Down
23 changes: 4 additions & 19 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,16 +632,7 @@ def test_asof_join() -> None:
"2016-05-25 13:30:00.072",
"2016-05-25 13:30:00.075",
]
ticker = [
"GOOG",
"MSFT",
"MSFT",
"MSFT",
"GOOG",
"AAPL",
"GOOG",
"MSFT",
]
ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"]
quotes = pl.DataFrame(
{
"dates": pl.Series(dates).str.strptime(pl.Datetime, format=format),
Expand All @@ -656,13 +647,7 @@ def test_asof_join() -> None:
"2016-05-25 13:30:00.048",
"2016-05-25 13:30:00.048",
]
ticker = [
"MSFT",
"MSFT",
"GOOG",
"GOOG",
"AAPL",
]
ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"]
trades = pl.DataFrame(
{
"dates": pl.Series(dates).str.strptime(pl.Datetime, format=format),
Expand All @@ -678,11 +663,11 @@ def test_asof_join() -> None:
out = trades.join_asof(quotes, on="dates", strategy="backward")

assert out.schema == {
"bid": pl.Float64,
"bid_right": pl.Float64,
"dates": pl.Datetime("ms"),
"ticker": pl.String,
"bid": pl.Float64,
"ticker_right": pl.String,
"bid_right": pl.Float64,
}
assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"]
assert (out["dates"].cast(int)).to_list() == [
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/interop/test_from_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_from_pandas() -> None:
"floats_nulls": pl.Float64,
"strings": pl.String,
"strings_nulls": pl.String,
"strings-cat": pl.Categorical,
"strings-cat": pl.Categorical(ordering="physical"),
}
assert out.rows() == [
(False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"),
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_from_dict_struct() -> None:
assert df.shape == (2, 2)
assert df["a"][0] == {"b": 1, "c": 2}
assert df["a"][1] == {"b": 3, "c": 4}
assert df.schema == {"a": pl.Struct, "d": pl.Int64}
assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64}


def test_from_dicts() -> None:
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_dataframe_from_repr() -> None:
assert frame.schema == {
"a": pl.Int64,
"b": pl.Float64,
"c": pl.Categorical,
"c": pl.Categorical(ordering="physical"),
"d": pl.Boolean,
"e": pl.String,
"f": pl.Date,
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/io/test_spreadsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None:
xls = BytesIO()
df.write_excel(xls, position="C5")

schema_overrides = {"datetime": pl.Datetime, "nulls": pl.Boolean}
schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean}
df_compare = df.with_columns(
pl.col(nm).cast(tp) for nm, tp in schema_overrides.items()
)
Expand Down Expand Up @@ -325,7 +325,7 @@ def test_read_mixed_dtype_columns(
"Employee ID": pl.Utf8,
"Employee Name": pl.Utf8,
"Date": pl.Date,
"Details": pl.Categorical,
"Details": pl.Categorical("physical"),
"Asset ID": pl.Utf8,
}

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None:
)
assert df.rows() == []
assert df.shape == (0, 2)
assert df.schema == {"key": pl.Categorical, "val": pl.Float64}
assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64}


def test_group_by_custom_agg_empty_list() -> None:
Expand Down
Loading

0 comments on commit 873d39b

Please sign in to comment.