From 3d3f4f344880c2401f970373889348f6f5ce4150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 20 Jul 2023 14:31:33 +0200 Subject: [PATCH] fix(api): allow scalar window order keys previously the builder API disallowed passing scalar order keys like ibis.NA and ibis.random() --- ibis/backends/dask/execution/util.py | 28 ++++++++++++++++---------- ibis/backends/pandas/execution/util.py | 28 +++++++++++++++----------- ibis/backends/tests/test_window.py | 10 +++++++++ ibis/expr/builders.py | 8 ++++++-- ibis/expr/operations/sortkeys.py | 2 +- ibis/tests/expr/test_operations.py | 11 ++++++++++ ibis/tests/expr/test_window_frames.py | 22 ++++++++++++++++++++ 7 files changed, 83 insertions(+), 26 deletions(-) diff --git a/ibis/backends/dask/execution/util.py b/ibis/backends/dask/execution/util.py index d624fa2f1d44..caf0b43449bd 100644 --- a/ibis/backends/dask/execution/util.py +++ b/ibis/backends/dask/execution/util.py @@ -219,19 +219,25 @@ def compute_sort_key( `execute` the expression and sort by the new derived column. """ name = ibis.util.guid() - if key.name in data: - return name, data[key.name] - if isinstance(key, str): - return key, None + if key.output_shape.is_columnar(): + if key.name in data: + return name, data[key.name] + if isinstance(key, str): + return key, None + else: + if scope is None: + scope = Scope() + scope = scope.merge_scopes( + Scope({t: data}, timecontext) + for t in an.find_immediate_parent_tables(key) + ) + new_column = execute(key, scope=scope, **kwargs) + new_column.name = name + return name, new_column else: - if scope is None: - scope = Scope() - scope = scope.merge_scopes( - Scope({t: data}, timecontext) for t in an.find_immediate_parent_tables(key) + raise NotImplementedError( + "Scalar sort keys are not yet supported in the dask backend" ) - new_column = execute(key, scope=scope, **kwargs) - new_column.name = name - return name, new_column def compute_sorted_frame( diff --git a/ibis/backends/pandas/execution/util.py b/ibis/backends/pandas/execution/util.py index 67859598147f..c6ab518c0011 100644 --- a/ibis/backends/pandas/execution/util.py +++ b/ibis/backends/pandas/execution/util.py @@ -32,20 +32,24 @@ def get_join_suffix_for_op(op: ops.TableColumn, join_op: ops.Join): def compute_sort_key(key, data, timecontext, scope=None, **kwargs): - if isinstance(key, str): - return key, None - elif key.name in data: - return key.name, None + if key.output_shape.is_columnar(): + if key.name in data: + return key.name, None + else: + if scope is None: + scope = Scope() + scope = scope.merge_scopes( + Scope({t: data}, timecontext) + for t in an.find_immediate_parent_tables(key) + ) + new_column = execute(key, scope=scope, **kwargs) + name = ibis.util.guid() + new_column.name = name + return name, new_column else: - if scope is None: - scope = Scope() - scope = scope.merge_scopes( - Scope({t: data}, timecontext) for t in an.find_immediate_parent_tables(key) + raise NotImplementedError( + "Scalar sort keys are not yet supported in the pandas backend" ) - new_column = execute(key, scope=scope, **kwargs) - name = ibis.util.guid() - new_column.name = name - return name, new_column def compute_sorted_frame(df, order_by, group_by=(), timecontext=None, **kwargs): diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index ea4c32435ca4..6f95bd7d1b92 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -571,6 +571,16 @@ def test_simple_ungrouped_unbound_following_window( backend.assert_series_equal(result, expected) +@pytest.mark.broken(["pandas", "dask"], raises=NotImplementedError) +@pytest.mark.notimpl(["datafusion", "polars"], raises=com.OperationNotDefinedError) +def test_simple_ungrouped_window_with_scalar_order_by(backend, alltypes): + t = alltypes[alltypes.double_col < 50].order_by('id') + w = ibis.window(rows=(0, None), order_by=ibis.NA) + expr = t.double_col.sum().over(w).name('double_col') + # hard to reproduce this in pandas, so just test that it actually executes + expr.execute() + + @pytest.mark.parametrize( ("result_fn", "expected_fn", "ordered"), [ diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index 87a7e8b714a1..8a85d77fe532 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -99,12 +99,16 @@ class WindowBuilder(Builder): how = rlz.optional(rlz.isin({'rows', 'range'}), default="rows") start = end = rlz.optional(rlz.option(rlz.range_window_boundary)) - groupings = orderings = rlz.optional( + groupings = rlz.optional( rlz.tuple_of( - rlz.one_of([rlz.column(rlz.any), rlz.instance_of((str, Deferred))]) + rlz.one_of([rlz.instance_of((str, Deferred)), rlz.column(rlz.any)]) ), default=(), ) + orderings = rlz.optional( + rlz.tuple_of(rlz.one_of([rlz.instance_of((str, Deferred)), rlz.any])), + default=(), + ) max_lookback = rlz.optional(rlz.interval) def _maybe_cast_boundary(self, boundary, dtype): diff --git a/ibis/expr/operations/sortkeys.py b/ibis/expr/operations/sortkeys.py index 385d0119313c..4f936e909632 100644 --- a/ibis/expr/operations/sortkeys.py +++ b/ibis/expr/operations/sortkeys.py @@ -16,7 +16,7 @@ class SortKey(Value): ascending = rlz.optional(rlz.bool_, default=True) output_dtype = rlz.dtype_like("expr") - output_shape = rlz.Shape.COLUMNAR + output_shape = rlz.shape_like("expr") @property def name(self) -> str: diff --git a/ibis/tests/expr/test_operations.py b/ibis/tests/expr/test_operations.py index 88d51ec17af5..9600e914f861 100644 --- a/ibis/tests/expr/test_operations.py +++ b/ibis/tests/expr/test_operations.py @@ -267,3 +267,14 @@ def test_expression_class_aliases(): assert ir.AnyValue is ir.Value assert ir.AnyScalar is ir.Scalar assert ir.AnyColumn is ir.Column + + +def test_sortkey_propagates_dtype_and_shape(): + k = ops.SortKey(ibis.literal(1), ascending=True) + assert k.output_dtype == dt.int8 + assert k.output_shape == rlz.Shape.SCALAR + + t = ibis.table([('a', 'int16')], name='t') + k = ops.SortKey(t.a, ascending=True) + assert k.output_dtype == dt.int16 + assert k.output_shape == rlz.Shape.COLUMNAR diff --git a/ibis/tests/expr/test_window_frames.py b/ibis/tests/expr/test_window_frames.py index 8a4f27543ee0..5095a3f453e8 100644 --- a/ibis/tests/expr/test_window_frames.py +++ b/ibis/tests/expr/test_window_frames.py @@ -188,6 +188,28 @@ def test_window_api_supports_value_expressions(alltypes): ) +def test_window_api_supports_scalar_order_by(alltypes): + t = alltypes + + w = ibis.window(order_by=ibis.NA) + assert w.bind(t) == ops.RowsWindowFrame( + table=t, + start=None, + end=None, + group_by=(), + order_by=(ibis.NA.op(),), + ) + + w = ibis.window(order_by=ibis.random()) + assert w.bind(t) == ops.RowsWindowFrame( + table=t, + start=None, + end=None, + group_by=(), + order_by=(ibis.random().op(),), + ) + + def test_window_api_properly_determines_how(): assert ibis.window(between=(None, 5)).how == 'rows' assert ibis.window(between=(1, 3)).how == 'rows'