From 4b9d5440946127a80130622268ce55c130b36947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 8 Apr 2024 13:18:21 +0200 Subject: [PATCH] refactor(api): restrict arbitrary input nesting Previously we allowed arbitrary nesting of input expressions for various API methods like `table.join`, `table.group_by`, etc. While this was backward compatible with `8.0.0` it also exposes additional ambiguity which can be confusing for users and difficult to reason about. This change restricts the nesting of input expressions to a "single level" of nesting with the exception of the first positional argument which can be a list of expressions, but only if there are no more positional arguments. Examples: ```python t.select([t.foo, t.bar]) # OK t.select(t.foo, t.bar) # OK t.select([t.foo], t.bar) # Error t.select(t.foo, name=t.bar) # OK t.select([t.foo], name=t.bar) # OK ``` --- ibis/backends/tests/test_vectorized_udf.py | 3 +- ibis/expr/builders.py | 8 +- ibis/expr/tests/test_newrels.py | 17 ++++ ibis/expr/types/groupby.py | 24 +++-- ibis/expr/types/joins.py | 20 ++-- ibis/expr/types/relations.py | 106 ++++++++++----------- ibis/tests/expr/test_table.py | 94 +++++++++++++++++- 7 files changed, 181 insertions(+), 91 deletions(-) diff --git a/ibis/backends/tests/test_vectorized_udf.py b/ibis/backends/tests/test_vectorized_udf.py index 7b9bff89a85ec..6312cbe329703 100644 --- a/ibis/backends/tests/test_vectorized_udf.py +++ b/ibis/backends/tests/test_vectorized_udf.py @@ -592,8 +592,7 @@ def test_elementwise_udf_named_destruct(udf_alltypes): add_one_struct_udf = create_add_one_struct_udf( result_formatter=lambda v1, v2: (v1, v2) ) - msg = "Duplicate column name 'new_struct' in result set" - with pytest.raises(com.IntegrityError, match=msg): + with pytest.raises(com.InputTypeError, match="Unable to infer datatype"): udf_alltypes.mutate( new_struct=add_one_struct_udf(udf_alltypes["double_col"]).destructure() ) diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index b984599270212..4bd99fabe6478 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -226,7 +226,7 @@ def order_by(self, expr) -> Self: return self.copy(orderings=self.orderings + util.promote_tuple(expr)) def bind(self, table): - from ibis.expr.types.relations import bind + # from ibis.expr.types.relations import bind if table is None: if self._table is None: @@ -234,9 +234,9 @@ def bind(self, table): else: table = self._table.to_expr() - grouping = bind(table, self.groupings) - orderings = bind(table, self.orderings) - return self.copy(groupings=grouping, orderings=orderings) + return self.copy( + groupings=table.bind(self.groupings), orderings=table.bind(self.orderings) + ) class LegacyWindowBuilder(WindowBuilder): diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 22588deb65d8f..a099c9075bb89 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -9,6 +9,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir +import ibis.selectors as s from ibis import _ from ibis.common.annotations import ValidationError from ibis.common.exceptions import IbisInputError, IntegrityError @@ -864,6 +865,22 @@ def test_join_predicate_dereferencing_using_tuple_syntax(): assert j2.op() == expected +def test_join_with_selector_predicate(): + t1 = ibis.table(name="t1", schema={"a": "string", "b": "string"}) + t2 = ibis.table(name="t2", schema={"c": "string", "d": "string"}) + + joined = t1.join(t2, s.of_type("string")) + with join_tables(joined) as (r1, r2): + expected = JoinChain( + first=r1, + rest=[ + JoinLink("inner", r2, [r1.a == r2.c, r1.b == r2.d]), + ], + values={"a": r1.a, "b": r1.b, "c": r2.c, "d": r2.d}, + ) + assert joined.op() == expected + + def test_join_rhs_dereferencing(): t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index 3e18946eebb92..cbcc20a97f059 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -28,10 +28,9 @@ from ibis.common.grounds import Concrete from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.rewrites import rewrite_window_input -from ibis.expr.types.relations import bind if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Sequence @public @@ -65,13 +64,12 @@ def __getattr__(self, attr): def aggregate(self, *metrics, **kwds) -> ir.Table: """Compute aggregates over a group by.""" - return self.table.to_expr().aggregate( - metrics, by=self.groupings, having=self.havings, **kwds - ) + metrics = self.table.to_expr().bind(*metrics, **kwds) + return self.table.to_expr().aggregate(metrics, by=self.groupings, having=self.havings) agg = aggregate - def having(self, *expr: ir.BooleanScalar) -> GroupedTable: + def having(self, *predicates: ir.BooleanScalar) -> GroupedTable: """Add a post-aggregation result filter `expr`. ::: {.callout-warning} @@ -80,8 +78,8 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable: Parameters ---------- - expr - An expression that filters based on an aggregate value. + predicates + Expressions that filters based on an aggregate value. Returns ------- @@ -89,10 +87,10 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable: A grouped table expression """ table = self.table.to_expr() - havings = tuple(bind(table, expr)) + havings = table.bind(*predicates) return self.copy(havings=self.havings + havings) - def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: + def order_by(self, *by: ir.Value) -> GroupedTable: """Sort a grouped table expression by `expr`. Notes @@ -101,7 +99,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: Parameters ---------- - expr + by Expressions to order the results by Returns @@ -110,7 +108,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: A sorted grouped GroupedTable """ table = self.table.to_expr() - orderings = tuple(bind(table, expr)) + orderings = table.bind(*by) return self.copy(orderings=self.orderings + orderings) def mutate( @@ -201,7 +199,7 @@ def _selectables(self, *exprs, **kwexprs): [`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate) """ table = self.table.to_expr() - values = bind(table, (exprs, kwexprs)) + values = table.bind(*exprs, **kwexprs) window = ibis.window(group_by=self.groupings, order_by=self.orderings) return [rewrite_window_input(expr.op(), window).to_expr() for expr in values] diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py index ff081a353a510..3723864cd1334 100644 --- a/ibis/expr/types/joins.py +++ b/ibis/expr/types/joins.py @@ -176,15 +176,10 @@ def prepare_predicates( else: lk = rk = pred - # bind the predicates to the join chain - (left_value,) = bind(left, lk) - (right_value,) = bind(right, rk) - - # dereference the left value to one of the relations in the join chain - left_value = deref_left.dereference(left_value.op()) - right_value = deref_right.dereference(right_value.op()) - - yield comparison(left_value, right_value) + for lhs, rhs in zip(bind(left, lk), bind(right, rk)): + lhs = deref_left.dereference(lhs.op()) + rhs = deref_right.dereference(rhs.op()) + yield comparison(lhs, rhs) def finished(method): @@ -335,8 +330,9 @@ def asof_join( result = self.left_join( filtered, predicates=[left_on == right_on] + predicates ) - values = {**self.op().values, **filtered.op().values} - return result.select(values) + values = {**filtered.op().values, **self.op().values} + + return result.select(**values) chain = self.op() right = right.op() @@ -383,7 +379,7 @@ def cross_join( @functools.wraps(Table.select) def select(self, *args, **kwargs): chain = self.op() - values = bind(self, (args, kwargs)) + values = self.bind(*args, **kwargs) values = unwrap_aliases(values) links = [link.table for link in chain.rest if link.how not in ("semi", "anti")] diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 6678f3e9f6900..2243dab6bb080 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -91,18 +91,11 @@ def f( # noqa: D417 return f -# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting -# nested inputs -def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]: +def bind(table: Table, value) -> Iterator[ir.Value]: """Bind a value to a table expression.""" if isinstance(value, str): # TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg yield ops.Field(table, value).to_expr() - elif isinstance(value, bool): - yield literal(value) - elif int_as_column and isinstance(value, int): - name = table.columns[value] - yield ops.Field(table, name).to_expr() elif isinstance(value, ops.Value): yield value.to_expr() elif isinstance(value, Value): @@ -116,13 +109,6 @@ def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]: yield value.resolve({"_": table}) elif isinstance(value, Selector): yield from value.expand(table) - elif isinstance(value, Mapping): - for k, v in value.items(): - for val in bind(table, v, int_as_column=int_as_column): - yield val.name(k) - elif util.is_iterable(value): - for v in value: - yield from bind(table, v, int_as_column=int_as_column) elif callable(value): yield value(table) else: @@ -135,7 +121,7 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]: for value in values: node = value.op() if node.name in result: - raise com.IntegrityError( + raise com.IbisInputError( f"Duplicate column name {node.name!r} in result set" ) if isinstance(node, ops.Alias): @@ -145,29 +131,6 @@ def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]: return result -def dereference_values( - parents: Iterable[ops.Parents], values: Mapping[str, ops.Value] -) -> Mapping[str, ops.Value]: - """Trace and replace fields from earlier relations in the hierarchy. - - For more details see :class:`ibis.expr.rewrites.DerefMap`. - - Parameters - ---------- - parents - The relations we want the values to point to. - values - The values to dereference. - - Returns - ------- - The same mapping as `values` but with all the dereferenceable fields - replaced with the fields from the parents. - """ - dm = DerefMap.from_targets(parents) - return {k: dm.dereference(v) for k, v in values.items()} - - @public class Table(Expr, _FixedTextJupyterMixin): """An immutable and lazy dataframe. @@ -247,6 +210,40 @@ def _bind_reduction_filter(self, where): return where.resolve(self) + def bind(self, *args, **kwargs): + # allow the first argument to be either a dictionary or a list of values + if len(args) == 1: + if isinstance(args[0], dict): + kwargs = {**args[0], **kwargs} + args = () + else: + args = util.promote_list(args[0]) + + # bind positional arguments + values = [] + for arg in args: + values.extend(bind(self, arg)) + + # bind keyword arguments where each entry can produce only one value + # which is then named with the given key + for key, arg in kwargs.items(): + try: + (value,) = bind(self, arg) + except ValueError: + raise com.IbisInputError( + "Keyword arguments cannot produce more than one value" + ) + values.append(value.name(key)) + + # dereference the values to `self` + dm = DerefMap.from_targets(self.op()) + result = [] + for original in values: + value = dm.dereference(original.op()).to_expr() + value = value.name(original.get_name()) + result.append(value) + return tuple(result) + def as_scalar(self) -> ir.ScalarExpr: """Inform ibis that the table expression should be treated as a scalar. @@ -769,7 +766,12 @@ def __getitem__(self, what): limit, offset = util.slice_to_limit_offset(what, self.count()) return self.limit(limit, offset=offset) - values = tuple(bind(self, what, int_as_column=True)) + args = [ + self.columns[arg] if isinstance(arg, int) else arg + for arg in util.promote_list(what) + ] + values = self.bind(args) + if isinstance(what, (str, int)): assert len(values) == 1 return values[0] @@ -954,7 +956,7 @@ def group_by( from ibis.expr.types.groupby import GroupedTable by = tuple(v for v in by if v is not None) - groups = bind(self, (by, key_exprs)) + groups = self.bind(*by, **key_exprs) return GroupedTable(self, groups) # TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table? @@ -1133,16 +1135,13 @@ def aggregate( node = self.op() - groups = bind(self, by) - metrics = bind(self, (metrics, kwargs)) - having = tuple(bind(self, having)) + groups = self.bind(by) + metrics = self.bind(metrics, **kwargs) + having = self.bind(having) groups = unwrap_aliases(groups) metrics = unwrap_aliases(metrics) - groups = dereference_values(node, groups) - metrics = dereference_values(node, metrics) - # the user doesn't need to specify the metrics used in the having clause # explicitly, we implicitly add them to the metrics list by looking for # any metrics depending on self which are not specified explicitly @@ -1672,9 +1671,8 @@ def order_by( │ 2 │ B │ 6 │ └───────┴────────┴───────┘ """ - keys = bind(self, by) + keys = self.bind(*by) keys = unwrap_aliases(keys) - keys = dereference_values(self.op(), keys) if not keys: raise com.IbisError("At least one sort key must be provided") @@ -1921,7 +1919,7 @@ def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Tab # string and integer inputs are going to be coerced to literals instead # of interpreted as column references like in select node = self.op() - values = bind(self, (exprs, mutations)) + values = self.bind(*exprs, **mutations) values = unwrap_aliases(values) # allow overriding of fields, hence the mutation behavior values = {**node.fields, **values} @@ -2106,9 +2104,8 @@ def select( """ from ibis.expr.rewrites import rewrite_project_input - values = bind(self, (exprs, named_exprs)) + values = self.bind(*exprs, **named_exprs) values = unwrap_aliases(values) - values = dereference_values(self.op(), values) if not values: raise com.IbisTypeError( "You must select at least one column for a valid projection" @@ -2483,9 +2480,8 @@ def filter( from ibis.expr.analysis import flatten_predicates from ibis.expr.rewrites import rewrite_filter_input - preds = bind(self, predicates) + preds = self.bind(*predicates) preds = unwrap_aliases(preds) - preds = dereference_values(self.op(), preds) preds = flatten_predicates(list(preds.values())) preds = list(map(rewrite_filter_input, preds)) if not preds: @@ -2619,7 +2615,7 @@ def dropna( 344 """ if subset is not None: - subset = bind(self, subset) + subset = self.bind(subset) return ops.DropNa(self, how, subset).to_expr() def fillna( diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index c8b2ed0103ce9..bf41b8dc7f349 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -192,7 +192,7 @@ def test_projection_with_exprs(table): def test_projection_duplicate_names(table): - with pytest.raises(com.IntegrityError): + with pytest.raises(com.IbisInputError, match="Duplicate column name 'c'"): table.select([table.c, table.c]) @@ -267,10 +267,10 @@ def test_projection_array_expr(table): assert_equal(result, expected) -@pytest.mark.parametrize("empty", [list(), dict()]) -def test_projection_no_expr(table, empty): - with pytest.raises(com.IbisTypeError, match="must select at least one"): - table.select(empty) +# @pytest.mark.parametrize("empty", [list(), dict()]) +# def test_projection_no_expr(table, empty): +# with pytest.raises(com.IbisTypeError, match="must select at least one"): +# table.select(empty) # FIXME(kszucs): currently bind() flattens the list of expressions, so arbitrary @@ -2077,3 +2077,87 @@ def test_unbind_with_namespace(): assert s.op() == expected.op() assert s.equals(expected) + + +def test_table_bind(): + def eq(left, right): + return all(a.equals(b) for a, b in zip(left, right)) + + t = ibis.table({"a": "int", "b": "string"}, name="t") + + # boolean literals + exprs = t.bind(True, False) + expected = (ibis.literal(True), ibis.literal(False)) + assert eq(exprs, expected) + + # int literals + exprs = t.bind(1, 2) + expected = (ibis.literal(1), ibis.literal(2)) + assert eq(exprs, expected) + + # lambda input + exprs = t.bind(lambda t: t.a, lambda t: t.b) + expected = (t.a, t.b) + assert eq(exprs, expected) + + # deferred input + exprs = t.bind(_.a, _.b) + expected = (t.a, t.b) + assert eq(exprs, expected) + + # single table arg + exprs = t.bind(t) + expected = (t.a, t.b) + assert eq(exprs, expected) + + # single selector arg + exprs = t.bind(s.all()) + expected = (t.a, t.b) + assert eq(exprs, expected) + + # single tuple arg + exprs = t.bind([1, "a"]) + expected = (ibis.literal(1), t.a) + assert eq(exprs, expected) + + # single list arg + exprs = t.bind([1, 2, "b"]) + expected = (ibis.literal(1), ibis.literal(2), t.b) + assert eq(exprs, expected) + + # single list arg with kwargs + exprs = t.bind([1], b=2) + expected = (ibis.literal(1), ibis.literal(2).name("b")) + assert eq(exprs, expected) + + # single dict arg + exprs = t.bind({"c": 1, "d": 2}) + expected = (ibis.literal(1).name("c"), ibis.literal(2).name("d")) + assert eq(exprs, expected) + + # single dict arg with kwargs + exprs = t.bind({"c": 1}, d=2) + expected = (ibis.literal(1).name("c"), ibis.literal(2).name("d")) + assert eq(exprs, expected) + + # single dict arg with overlapping kwargs + exprs = t.bind({"c": 1, "d": 2}, c=2) + expected = (ibis.literal(2).name("c"), ibis.literal(2).name("d")) + assert eq(exprs, expected) + + # kwargs cannot cannot produce more than one value + with pytest.raises(com.IbisInputError): + t.bind(alias=t) + with pytest.raises(com.IbisInputError): + t.bind(alias=s.all()) + + # multiple args + exprs = t.bind(t, ["a", "b"], {"c": 1}, d=2) + expected = ( + t.a, + t.b, + ibis.literal(["a", "b"]), + ibis.literal({"c": 1}), + ibis.literal(2).name("d"), + ) + assert eq(exprs, expected)