Skip to content

Commit

Permalink
refactor(api): restrict arbitrary input nesting
Browse files Browse the repository at this point in the history
Previously, we allowed arbitrary nesting of input expressions for various
API methods like `table.join`, `table.group_by`, etc. While this was
backward compatible with `8.0.0`, it also exposed additional ambiguity,
which can be confusing for users and difficult to reason about.

This change restricts input expressions to a single level of nesting.
The one exception is the first positional argument, which may be a list
of expressions, but only if no other positional arguments are given.
Examples:

```python
t.select([t.foo, t.bar])  # OK
t.select(t.foo, t.bar)    # OK
t.select([t.foo], t.bar)  # Error
t.select(t.foo, name=t.bar)  # OK
t.select([t.foo], name=t.bar)  # OK
```
  • Loading branch information
kszucs committed Apr 9, 2024
1 parent 423a733 commit 264c880
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 48 deletions.
8 changes: 4 additions & 4 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,17 @@ def order_by(self, expr) -> Self:
return self.copy(orderings=self.orderings + util.promote_tuple(expr))

def bind(self, table):
from ibis.expr.types.relations import bind
# from ibis.expr.types.relations import bind

if table is None:
if self._table is None:
raise IbisInputError("Cannot bind window frame without a table")
else:
table = self._table.to_expr()

grouping = bind(table, self.groupings)
orderings = bind(table, self.orderings)
return self.copy(groupings=grouping, orderings=orderings)
return self.copy(
groupings=table.bind(self.groupings), orderings=table.bind(self.orderings)
)


class LegacyWindowBuilder(WindowBuilder):
Expand Down
19 changes: 9 additions & 10 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import rewrite_window_input
from ibis.expr.types.relations import bind

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Sequence


@public
Expand Down Expand Up @@ -71,7 +70,7 @@ def aggregate(self, metrics=(), **kwds) -> ir.Table:

agg = aggregate

def having(self, expr: ir.BooleanScalar) -> GroupedTable:
def having(self, *predicates: ir.BooleanScalar) -> GroupedTable:
"""Add a post-aggregation result filter `expr`.
::: {.callout-warning}
Expand All @@ -80,19 +79,19 @@ def having(self, expr: ir.BooleanScalar) -> GroupedTable:
Parameters
----------
expr
An expression that filters based on an aggregate value.
predicates
Expressions that filter based on an aggregate value.
Returns
-------
GroupedTable
A grouped table expression
"""
table = self.table.to_expr()
havings = tuple(bind(table, expr))
havings = table.bind(*predicates)
return self.copy(havings=self.havings + havings)

def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
def order_by(self, *by: ir.Value) -> GroupedTable:
"""Sort a grouped table expression by `expr`.
Notes
Expand All @@ -101,7 +100,7 @@ def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
Parameters
----------
expr
by
Expressions to order the results by
Returns
Expand All @@ -110,7 +109,7 @@ def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
A sorted grouped GroupedTable
"""
table = self.table.to_expr()
orderings = tuple(bind(table, expr))
orderings = table.bind(*by)
return self.copy(orderings=self.orderings + orderings)

def mutate(
Expand Down Expand Up @@ -201,7 +200,7 @@ def _selectables(self, *exprs, **kwexprs):
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
"""
table = self.table.to_expr()
values = bind(table, (exprs, kwexprs))
values = table.bind(*exprs, **kwexprs)
window = ibis.window(group_by=self.groupings, order_by=self.orderings)
return [rewrite_window_input(expr.op(), window).to_expr() for expr in values]

Expand Down
11 changes: 6 additions & 5 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ibis.expr.types.generic import Value
from ibis.expr.types.relations import (
Table,
bind,
dereference_mapping,
unwrap_aliases,
)
Expand Down Expand Up @@ -220,9 +219,11 @@ def prepare_predicates(
else:
lk = rk = pred

# TODO(kszucs): bind can emit multiple predicates, this would allow
# selectors to be used as join keys, use zip()
# bind the predicates to the join chain
(left_value,) = bind(left, lk)
(right_value,) = bind(right, rk)
(left_value,) = left.bind(lk)
(right_value,) = right.bind(rk)

# dereference the left value to one of the relations in the join chain
left_value, right_value = dereference_sides(
Expand Down Expand Up @@ -380,7 +381,7 @@ def asof_join(
filtered, predicates=[left_on == right_on] + predicates
)
values = {**self.op().values, **filtered.op().values}
return result.select(values)
return result.select(**values)

left = self.op()
right = ops.JoinTable(right, index=left.length)
Expand Down Expand Up @@ -425,7 +426,7 @@ def cross_join(
@functools.wraps(Table.select)
def select(self, *args, **kwargs):
chain = self.op()
values = bind(self, (args, kwargs))
values = self.bind(*args, **kwargs)
values = unwrap_aliases(values)

# if there are values referencing fields from the join chain constructed
Expand Down
54 changes: 29 additions & 25 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,11 @@ def f( # noqa: D417
return f


# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
# nested inputs
def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]:
def bind(table: Table, value) -> Iterator[ir.Value]:
"""Bind a value to a table expression."""
if isinstance(value, str):
# TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg
yield ops.Field(table, value).to_expr()
elif isinstance(value, bool):
yield literal(value)
elif int_as_column and isinstance(value, int):
name = table.columns[value]
yield ops.Field(table, name).to_expr()
elif isinstance(value, ops.Value):
yield value.to_expr()
elif isinstance(value, Value):
Expand All @@ -118,13 +111,6 @@ def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]:
yield value.resolve({"_": table})
elif isinstance(value, Selector):
yield from value.expand(table)
elif isinstance(value, Mapping):
for k, v in value.items():
for val in bind(table, v, int_as_column=int_as_column):
yield val.name(k)
elif util.is_iterable(value):
for v in value:
yield from bind(table, v, int_as_column=int_as_column)
elif callable(value):
yield value(table)
else:
Expand Down Expand Up @@ -295,6 +281,19 @@ def _bind_reduction_filter(self, where):

return where.resolve(self)

def bind(self, *args, **kwargs):
    """Bind positional and keyword inputs to this table expression.

    Parameters
    ----------
    args
        Values to bind. When exactly one positional argument is given it
        is passed through `util.promote_list`, so it may itself be a list
        of values; otherwise each argument is bound individually.
    kwargs
        Named values to bind. Each one must bind to exactly one expression
        (anything else fails the single-element unpacking below), which is
        then renamed to the keyword.

    Returns
    -------
    tuple
        The bound positional expressions followed by the renamed keyword
        expressions, in the order they were passed.
    """
    # A lone positional argument is the only place where one extra level
    # of nesting (a list of expressions) is accepted.
    positional = util.promote_list(args[0]) if len(args) == 1 else args

    # A single positional input may bind to several expressions.
    bound = [expr for item in positional for expr in bind(self, item)]

    for name, item in kwargs.items():
        # Keyword inputs have to resolve to exactly one expression.
        (expr,) = bind(self, item)
        bound.append(expr.name(name))

    return tuple(bound)

def as_scalar(self) -> ir.ScalarExpr:
"""Inform ibis that the table expression should be treated as a scalar.
Expand Down Expand Up @@ -785,7 +784,12 @@ def __getitem__(self, what):
limit, offset = util.slice_to_limit_offset(what, self.count())
return self.limit(limit, offset=offset)

values = tuple(bind(self, what, int_as_column=True))
args = [
self.columns[arg] if isinstance(arg, int) else arg
for arg in util.promote_list(what)
]
values = self.bind(args)

if isinstance(what, (str, int)):
assert len(values) == 1
return values[0]
Expand Down Expand Up @@ -970,7 +974,7 @@ def group_by(
from ibis.expr.types.groupby import GroupedTable

by = tuple(v for v in by if v is not None)
groups = bind(self, (by, key_exprs))
groups = self.bind(*by, **key_exprs)
return GroupedTable(self, groups)

# TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table?
Expand Down Expand Up @@ -1149,9 +1153,9 @@ def aggregate(

node = self.op()

groups = bind(self, by)
metrics = bind(self, (metrics, kwargs))
having = bind(self, having)
groups = self.bind(by)
metrics = self.bind(metrics, **kwargs)
having = self.bind(having)

groups = unwrap_aliases(groups)
metrics = unwrap_aliases(metrics)
Expand Down Expand Up @@ -1690,7 +1694,7 @@ def order_by(
│ 2 │ B │ 6 │
└───────┴────────┴───────┘
"""
keys = bind(self, by)
keys = self.bind(*by)
keys = unwrap_aliases(keys)
keys = dereference_values(self.op(), keys)
if not keys:
Expand Down Expand Up @@ -1939,7 +1943,7 @@ def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Tab
# string and integer inputs are going to be coerced to literals instead
# of interpreted as column references like in select
node = self.op()
values = bind(self, (exprs, mutations))
values = self.bind(*exprs, **mutations)
values = unwrap_aliases(values)
# allow overriding of fields, hence the mutation behavior
values = {**node.fields, **values}
Expand Down Expand Up @@ -2124,7 +2128,7 @@ def select(
"""
from ibis.expr.rewrites import rewrite_project_input

values = bind(self, (exprs, named_exprs))
values = self.bind(*exprs, **named_exprs)
values = unwrap_aliases(values)
values = dereference_values(self.op(), values)
if not values:
Expand Down Expand Up @@ -2501,7 +2505,7 @@ def filter(
from ibis.expr.analysis import flatten_predicates
from ibis.expr.rewrites import rewrite_filter_input

preds = bind(self, predicates)
preds = self.bind(*predicates)
preds = unwrap_aliases(preds)
preds = dereference_values(self.op(), preds)
preds = flatten_predicates(list(preds.values()))
Expand Down Expand Up @@ -2637,7 +2641,7 @@ def dropna(
344
"""
if subset is not None:
subset = bind(self, subset)
subset = self.bind(subset)
return ops.DropNa(self, how, subset).to_expr()

def fillna(
Expand Down
8 changes: 4 additions & 4 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,10 @@ def test_projection_array_expr(table):
assert_equal(result, expected)


@pytest.mark.parametrize("empty", [list(), dict()])
def test_projection_no_expr(table, empty):
with pytest.raises(com.IbisTypeError, match="must select at least one"):
table.select(empty)
# @pytest.mark.parametrize("empty", [list(), dict()])
# def test_projection_no_expr(table, empty):
# with pytest.raises(com.IbisTypeError, match="must select at least one"):
# table.select(empty)


# FIXME(kszucs): currently bind() flattens the list of expressions, so arbitrary
Expand Down

0 comments on commit 264c880

Please sign in to comment.