Skip to content

Commit

Permalink
refactor(api): restrict arbitrary input nesting
Browse files Browse the repository at this point in the history
Previously we allowed arbitrary nesting of input expressions for various
API methods like `table.join`, `table.group_by`, etc. While this was
backward compatible with `8.0.0` it also exposes additional ambiguity
which can be confusing for users and difficult to reason about.

This change restricts the nesting of input expressions to a
"single level" of nesting with the exception of the first positional
argument which can be a list of expressions, but only if there are no
more positional arguments. Examples:

```python
t.select([t.foo, t.bar])  # OK
t.select(t.foo, t.bar)    # OK
t.select([t.foo], t.bar)  # Error
t.select(t.foo, name=t.bar)  # OK
t.select([t.foo], name=t.bar)  # OK
```
  • Loading branch information
kszucs committed Apr 27, 2024
1 parent 5cb83fc commit 5f317a3
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 98 deletions.
3 changes: 1 addition & 2 deletions ibis/backends/tests/test_vectorized_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,7 @@ def test_elementwise_udf_named_destruct(udf_alltypes):
add_one_struct_udf = create_add_one_struct_udf(
result_formatter=lambda v1, v2: (v1, v2)
)
msg = "Duplicate column name 'new_struct' in result set"
with pytest.raises(com.IntegrityError, match=msg):
with pytest.raises(com.InputTypeError, match="Unable to infer datatype"):
udf_alltypes.mutate(
new_struct=add_one_struct_udf(udf_alltypes["double_col"]).destructure()
)
Expand Down
8 changes: 3 additions & 5 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,15 @@ def order_by(self, expr) -> Self:
return self.copy(orderings=self.orderings + util.promote_tuple(expr))

def bind(self, table):
from ibis.expr.types.relations import bind

if table is None:
if self._table is None:
raise IbisInputError("Cannot bind window frame without a table")
else:
table = self._table.to_expr()

grouping = bind(table, self.groupings)
orderings = bind(table, self.orderings)
return self.copy(groupings=grouping, orderings=orderings)
return self.copy(
groupings=table.bind(self.groupings), orderings=table.bind(self.orderings)
)


class LegacyWindowBuilder(WindowBuilder):
Expand Down
6 changes: 5 additions & 1 deletion ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ def backtrack(cls, value):
yield value, distance
value = value.rel.values.get(value.name)
distance += 1
if value is not None and not value.find(ops.Impure, filter=ops.Value):
if (
value is not None
and value.relations
and not value.find(ops.Impure, filter=ops.Value)
):
yield value, distance

def dereference(self, value):
Expand Down
21 changes: 19 additions & 2 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
import ibis.selectors as s
from ibis import _
from ibis.common.annotations import ValidationError
from ibis.common.exceptions import IbisInputError, IntegrityError
Expand Down Expand Up @@ -499,7 +500,7 @@ def test_subsequent_filter():
assert f2.op() == expected


def test_project_dereferences_literal_expressions():
def test_project_doesnt_dereference_literal_expressions():
one = ibis.literal(1)
two = ibis.literal(2)
four = (one + one) * two
Expand All @@ -516,7 +517,7 @@ def test_project_dereferences_literal_expressions():
)

t2 = t1.select(four)
assert t2.op() == Project(parent=t1, values={four.get_name(): t1.four})
assert t2.op() == Project(parent=t1, values={four.get_name(): four})


def test_project_before_and_after_filter():
Expand Down Expand Up @@ -864,6 +865,22 @@ def test_join_predicate_dereferencing_using_tuple_syntax():
assert j2.op() == expected


def test_join_with_selector_predicate():
t1 = ibis.table(name="t1", schema={"a": "string", "b": "string"})
t2 = ibis.table(name="t2", schema={"c": "string", "d": "string"})

joined = t1.join(t2, s.of_type("string"))
with join_tables(joined) as (r1, r2):
expected = JoinChain(
first=r1,
rest=[
JoinLink("inner", r2, [r1.a == r2.c, r1.b == r2.d]),
],
values={"a": r1.a, "b": r1.b, "c": r2.c, "d": r2.d},
)
assert joined.op() == expected


def test_join_rhs_dereferencing():
t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"})
t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"})
Expand Down
22 changes: 11 additions & 11 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import rewrite_window_input
from ibis.expr.types.relations import bind

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Sequence


@public
Expand Down Expand Up @@ -65,13 +64,14 @@ def __getattr__(self, attr):

def aggregate(self, *metrics, **kwds) -> ir.Table:
"""Compute aggregates over a group by."""
metrics = self.table.to_expr().bind(*metrics, **kwds)
return self.table.to_expr().aggregate(
metrics, by=self.groupings, having=self.havings, **kwds
metrics, by=self.groupings, having=self.havings
)

agg = aggregate

def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
def having(self, *predicates: ir.BooleanScalar) -> GroupedTable:
"""Add a post-aggregation result filter `expr`.
::: {.callout-warning}
Expand All @@ -80,19 +80,19 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
Parameters
----------
expr
An expression that filters based on an aggregate value.
predicates
Expressions that filters based on an aggregate value.
Returns
-------
GroupedTable
A grouped table expression
"""
table = self.table.to_expr()
havings = tuple(bind(table, expr))
havings = table.bind(*predicates)
return self.copy(havings=self.havings + havings)

def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
def order_by(self, *by: ir.Value) -> GroupedTable:
"""Sort a grouped table expression by `expr`.
Notes
Expand All @@ -101,7 +101,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
Parameters
----------
expr
by
Expressions to order the results by
Returns
Expand All @@ -110,7 +110,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
A sorted grouped GroupedTable
"""
table = self.table.to_expr()
orderings = tuple(bind(table, expr))
orderings = table.bind(*by)
return self.copy(orderings=self.orderings + orderings)

def mutate(
Expand Down Expand Up @@ -201,7 +201,7 @@ def _selectables(self, *exprs, **kwexprs):
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
"""
table = self.table.to_expr()
values = bind(table, (exprs, kwexprs))
values = table.bind(*exprs, **kwexprs)
window = ibis.window(group_by=self.groupings, order_by=self.orderings)
return [rewrite_window_input(expr.op(), window).to_expr() for expr in values]

Expand Down
20 changes: 8 additions & 12 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,10 @@ def prepare_predicates(
else:
lk = rk = pred

# bind the predicates to the join chain
(left_value,) = bind(left, lk)
(right_value,) = bind(right, rk)

# dereference the left value to one of the relations in the join chain
left_value = deref_left.dereference(left_value.op())
right_value = deref_right.dereference(right_value.op())

yield comparison(left_value, right_value)
for lhs, rhs in zip(bind(left, lk), bind(right, rk)):
lhs = deref_left.dereference(lhs.op())
rhs = deref_right.dereference(rhs.op())
yield comparison(lhs, rhs)


def finished(method):
Expand Down Expand Up @@ -335,8 +330,9 @@ def asof_join(
result = self.left_join(
filtered, predicates=[left_on == right_on] + predicates
)
values = {**self.op().values, **filtered.op().values}
return result.select(values)
values = {**filtered.op().values, **self.op().values}

return result.select(**values)

chain = self.op()
right = right.op()
Expand Down Expand Up @@ -383,7 +379,7 @@ def cross_join(
@functools.wraps(Table.select)
def select(self, *args, **kwargs):
chain = self.op()
values = bind(self, (args, kwargs))
values = self.bind(*args, **kwargs)
values = unwrap_aliases(values)

links = [link.table for link in chain.rest if link.how not in ("semi", "anti")]
Expand Down
Loading

0 comments on commit 5f317a3

Please sign in to comment.