Skip to content

Commit

Permalink
refactor(api): restrict arbitrary input nesting
Browse files Browse the repository at this point in the history
Previously we allowed arbitrary nesting of input expressions for various
API methods like `table.join`, `table.group_by`, etc. While this was
backward compatible with `8.0.0` it also exposes additional ambiguity
which can be confusing for users and difficult to reason about.

This change restricts the nesting of input expressions to a
"single level" of nesting with the exception of the first positional
argument which can be a list of expressions, but only if there are no
more positional arguments. Examples:

```python
t.select([t.foo, t.bar])  # OK
t.select(t.foo, t.bar)    # OK
t.select([t.foo], t.bar)  # Error
t.select(t.foo, name=t.bar)  # OK
t.select([t.foo], name=t.bar)  # OK
```
  • Loading branch information
kszucs committed Apr 25, 2024
1 parent d7a31aa commit 50cdc8e
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 50 deletions.
8 changes: 4 additions & 4 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,17 @@ def order_by(self, expr) -> Self:
return self.copy(orderings=self.orderings + util.promote_tuple(expr))

def bind(self, table):
from ibis.expr.types.relations import bind
# from ibis.expr.types.relations import bind

if table is None:
if self._table is None:
raise IbisInputError("Cannot bind window frame without a table")
else:
table = self._table.to_expr()

grouping = bind(table, self.groupings)
orderings = bind(table, self.orderings)
return self.copy(groupings=grouping, orderings=orderings)
return self.copy(
groupings=table.bind(self.groupings), orderings=table.bind(self.orderings)
)


class LegacyWindowBuilder(WindowBuilder):
Expand Down
24 changes: 11 additions & 13 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import rewrite_window_input
from ibis.expr.types.relations import bind

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Sequence


@public
Expand Down Expand Up @@ -65,13 +64,12 @@ def __getattr__(self, attr):

def aggregate(self, *metrics, **kwds) -> ir.Table:
"""Compute aggregates over a group by."""
return self.table.to_expr().aggregate(
metrics, by=self.groupings, having=self.havings, **kwds
)
metrics = self.table.to_expr().bind(*metrics, **kwds)
return self.table.to_expr().aggregate(metrics, by=self.groupings, having=self.havings)

agg = aggregate

def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
def having(self, *predicates: ir.BooleanScalar) -> GroupedTable:
"""Add a post-aggregation result filter `expr`.
::: {.callout-warning}
Expand All @@ -80,19 +78,19 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
Parameters
----------
expr
An expression that filters based on an aggregate value.
predicates
Expressions that filters based on an aggregate value.
Returns
-------
GroupedTable
A grouped table expression
"""
table = self.table.to_expr()
havings = tuple(bind(table, expr))
havings = table.bind(*predicates)
return self.copy(havings=self.havings + havings)

def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
def order_by(self, *by: ir.Value) -> GroupedTable:
"""Sort a grouped table expression by `expr`.
Notes
Expand All @@ -101,7 +99,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
Parameters
----------
expr
by
Expressions to order the results by
Returns
Expand All @@ -110,7 +108,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
A sorted grouped GroupedTable
"""
table = self.table.to_expr()
orderings = tuple(bind(table, expr))
orderings = table.bind(*by)
return self.copy(orderings=self.orderings + orderings)

def mutate(
Expand Down Expand Up @@ -201,7 +199,7 @@ def _selectables(self, *exprs, **kwexprs):
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
"""
table = self.table.to_expr()
values = bind(table, (exprs, kwexprs))
values = table.bind(*exprs, **kwexprs)
window = ibis.window(group_by=self.groupings, order_by=self.orderings)
return [rewrite_window_input(expr.op(), window).to_expr() for expr in values]

Expand Down
10 changes: 6 additions & 4 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,11 @@ def prepare_predicates(
else:
lk = rk = pred

# TODO(kszucs): bind can emit multiple predicates, this would allow
# selectors to be used as join keys, use zip()
# bind the predicates to the join chain
(left_value,) = bind(left, lk)
(right_value,) = bind(right, rk)
(left_value,) = left.bind(lk)
(right_value,) = right.bind(rk)

# dereference the left value to one of the relations in the join chain
left_value = deref_left.dereference(left_value.op())
Expand Down Expand Up @@ -336,7 +338,7 @@ def asof_join(
filtered, predicates=[left_on == right_on] + predicates
)
values = {**self.op().values, **filtered.op().values}
return result.select(values)
return result.select(**values)

chain = self.op()
right = right.op()
Expand Down Expand Up @@ -383,7 +385,7 @@ def cross_join(
@functools.wraps(Table.select)
def select(self, *args, **kwargs):
chain = self.op()
values = bind(self, (args, kwargs))
values = self.bind(*args, **kwargs)
values = unwrap_aliases(values)

links = [link.table for link in chain.rest if link.how not in ("semi", "anti")]
Expand Down
63 changes: 38 additions & 25 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,11 @@ def f( # noqa: D417
return f


# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting
# nested inputs
def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]:
def bind(table: Table, value) -> Iterator[ir.Value]:
"""Bind a value to a table expression."""
if isinstance(value, str):
# TODO(kszucs): perhaps use getattr(table, value) instead for nicer error msg
yield ops.Field(table, value).to_expr()
elif isinstance(value, bool):
yield literal(value)
elif int_as_column and isinstance(value, int):
name = table.columns[value]
yield ops.Field(table, name).to_expr()
elif isinstance(value, ops.Value):
yield value.to_expr()
elif isinstance(value, Value):
Expand All @@ -116,13 +109,6 @@ def bind(table: Table, value: Any, int_as_column=False) -> Iterator[ir.Value]:
yield value.resolve({"_": table})
elif isinstance(value, Selector):
yield from value.expand(table)
elif isinstance(value, Mapping):
for k, v in value.items():
for val in bind(table, v, int_as_column=int_as_column):
yield val.name(k)
elif util.is_iterable(value):
for v in value:
yield from bind(table, v, int_as_column=int_as_column)
elif callable(value):
yield value(table)
else:
Expand Down Expand Up @@ -247,6 +233,28 @@ def _bind_reduction_filter(self, where):

return where.resolve(self)

def bind(self, *args, **kwargs):
if len(args) == 1:
if isinstance(args[0], dict):
kwargs = {**args[0], **kwargs}
args = ()
else:
args = util.promote_list(args[0])

values = []
for arg in args:
values.extend(bind(self, arg))
for key, arg in kwargs.items():
try:
(value,) = bind(self, arg)
except ValueError:
raise com.IbisInputError(
"Keyword arguments cannot produce more than one value"
)
values.append(value.name(key))

return tuple(values)

def as_scalar(self) -> ir.ScalarExpr:
"""Inform ibis that the table expression should be treated as a scalar.
Expand Down Expand Up @@ -769,7 +777,12 @@ def __getitem__(self, what):
limit, offset = util.slice_to_limit_offset(what, self.count())
return self.limit(limit, offset=offset)

values = tuple(bind(self, what, int_as_column=True))
args = [
self.columns[arg] if isinstance(arg, int) else arg
for arg in util.promote_list(what)
]
values = self.bind(args)

if isinstance(what, (str, int)):
assert len(values) == 1
return values[0]
Expand Down Expand Up @@ -954,7 +967,7 @@ def group_by(
from ibis.expr.types.groupby import GroupedTable

by = tuple(v for v in by if v is not None)
groups = bind(self, (by, key_exprs))
groups = self.bind(*by, **key_exprs)
return GroupedTable(self, groups)

# TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table?
Expand Down Expand Up @@ -1133,9 +1146,9 @@ def aggregate(

node = self.op()

groups = bind(self, by)
metrics = bind(self, (metrics, kwargs))
having = tuple(bind(self, having))
groups = self.bind(by)
metrics = self.bind(metrics, **kwargs)
having = self.bind(having)

groups = unwrap_aliases(groups)
metrics = unwrap_aliases(metrics)
Expand Down Expand Up @@ -1672,7 +1685,7 @@ def order_by(
│ 2 │ B │ 6 │
└───────┴────────┴───────┘
"""
keys = bind(self, by)
keys = self.bind(*by)
keys = unwrap_aliases(keys)
keys = dereference_values(self.op(), keys)
if not keys:
Expand Down Expand Up @@ -1921,7 +1934,7 @@ def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Tab
# string and integer inputs are going to be coerced to literals instead
# of interpreted as column references like in select
node = self.op()
values = bind(self, (exprs, mutations))
values = self.bind(*exprs, **mutations)
values = unwrap_aliases(values)
# allow overriding of fields, hence the mutation behavior
values = {**node.fields, **values}
Expand Down Expand Up @@ -2106,7 +2119,7 @@ def select(
"""
from ibis.expr.rewrites import rewrite_project_input

values = bind(self, (exprs, named_exprs))
values = self.bind(*exprs, **named_exprs)
values = unwrap_aliases(values)
values = dereference_values(self.op(), values)
if not values:
Expand Down Expand Up @@ -2483,7 +2496,7 @@ def filter(
from ibis.expr.analysis import flatten_predicates
from ibis.expr.rewrites import rewrite_filter_input

preds = bind(self, predicates)
preds = self.bind(*predicates)
preds = unwrap_aliases(preds)
preds = dereference_values(self.op(), preds)
preds = flatten_predicates(list(preds.values()))
Expand Down Expand Up @@ -2619,7 +2632,7 @@ def dropna(
344
"""
if subset is not None:
subset = bind(self, subset)
subset = self.bind(subset)
return ops.DropNa(self, how, subset).to_expr()

def fillna(
Expand Down
72 changes: 68 additions & 4 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ def test_projection_array_expr(table):
assert_equal(result, expected)


@pytest.mark.parametrize("empty", [list(), dict()])
def test_projection_no_expr(table, empty):
with pytest.raises(com.IbisTypeError, match="must select at least one"):
table.select(empty)
# @pytest.mark.parametrize("empty", [list(), dict()])
# def test_projection_no_expr(table, empty):
# with pytest.raises(com.IbisTypeError, match="must select at least one"):
# table.select(empty)


# FIXME(kszucs): currently bind() flattens the list of expressions, so arbitrary
Expand Down Expand Up @@ -2077,3 +2077,67 @@ def test_unbind_with_namespace():

assert s.op() == expected.op()
assert s.equals(expected)


def test_table_bind():
def eq(left, right):
return all(a.equals(b) for a, b in zip(left, right))

t = ibis.table({"a": "int", "b": "string"}, name="t")

# single table arg
exprs = t.bind(t)
expected = (t.a, t.b)
assert eq(exprs, expected)

# single selector arg
exprs = t.bind(s.all())
expected = (t.a, t.b)
assert eq(exprs, expected)

# single tuple arg
exprs = t.bind([1, "a"])
expected = (ibis.literal(1), t.a)
assert eq(exprs, expected)

# single list arg
exprs = t.bind([1, 2, "b"])
expected = (ibis.literal(1), ibis.literal(2), t.b)
assert eq(exprs, expected)

# single list arg with kwargs
exprs = t.bind([1], b=2)
expected = (ibis.literal(1), ibis.literal(2).name("b"))
assert eq(exprs, expected)

# single dict arg
exprs = t.bind({"c": 1, "d": 2})
expected = (ibis.literal(1).name("c"), ibis.literal(2).name("d"))
assert eq(exprs, expected)

# single dict arg with kwargs
exprs = t.bind({"c": 1}, d=2)
expected = (ibis.literal(1).name("c"), ibis.literal(2).name("d"))
assert eq(exprs, expected)

# single dict arg with overlapping kwargs
exprs = t.bind({"c": 1, "d": 2}, c=2)
expected = (ibis.literal(2).name("c"), ibis.literal(2).name("d"))
assert eq(exprs, expected)

# kwargs cannot cannot produce more than one value
with pytest.raises(com.IbisInputError):
t.bind(alias=t)
with pytest.raises(com.IbisInputError):
t.bind(alias=s.all())

# multiple args
exprs = t.bind(t, ["a", "b"], {"c": 1}, d=2)
expected = (
t.a,
t.b,
ibis.literal(["a", "b"]),
ibis.literal({"c": 1}),
ibis.literal(2).name("d"),
)
assert eq(exprs, expected)

0 comments on commit 50cdc8e

Please sign in to comment.