Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(api): restrict arbitrary input nesting #8917

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ibis/backends/tests/test_vectorized_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,7 @@ def test_elementwise_udf_named_destruct(udf_alltypes):
add_one_struct_udf = create_add_one_struct_udf(
result_formatter=lambda v1, v2: (v1, v2)
)
msg = "Duplicate column name 'new_struct' in result set"
with pytest.raises(com.IntegrityError, match=msg):
with pytest.raises(com.InputTypeError, match="Unable to infer datatype"):
udf_alltypes.mutate(
new_struct=add_one_struct_udf(udf_alltypes["double_col"]).destructure()
)
Expand Down
8 changes: 3 additions & 5 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,15 @@ def order_by(self, expr) -> Self:
return self.copy(orderings=self.orderings + util.promote_tuple(expr))

def bind(self, table):
from ibis.expr.types.relations import bind

if table is None:
if self._table is None:
raise IbisInputError("Cannot bind window frame without a table")
else:
table = self._table.to_expr()

grouping = bind(table, self.groupings)
orderings = bind(table, self.orderings)
return self.copy(groupings=grouping, orderings=orderings)
return self.copy(
groupings=table.bind(self.groupings), orderings=table.bind(self.orderings)
)


class LegacyWindowBuilder(WindowBuilder):
Expand Down
6 changes: 5 additions & 1 deletion ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ def backtrack(cls, value):
yield value, distance
value = value.rel.values.get(value.name)
distance += 1
if value is not None and not value.find(ops.Impure, filter=ops.Value):
if (
value is not None
and value.relations
and not value.find(ops.Impure, filter=ops.Value)
):
yield value, distance

def dereference(self, value):
Expand Down
21 changes: 19 additions & 2 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
import ibis.selectors as s
from ibis import _
from ibis.common.annotations import ValidationError
from ibis.common.exceptions import IbisInputError, IntegrityError
Expand Down Expand Up @@ -499,7 +500,7 @@ def test_subsequent_filter():
assert f2.op() == expected


def test_project_dereferences_literal_expressions():
def test_project_doesnt_dereference_literal_expressions():
one = ibis.literal(1)
two = ibis.literal(2)
four = (one + one) * two
Expand All @@ -516,7 +517,7 @@ def test_project_dereferences_literal_expressions():
)

t2 = t1.select(four)
assert t2.op() == Project(parent=t1, values={four.get_name(): t1.four})
assert t2.op() == Project(parent=t1, values={four.get_name(): four})


def test_project_before_and_after_filter():
Expand Down Expand Up @@ -864,6 +865,22 @@ def test_join_predicate_dereferencing_using_tuple_syntax():
assert j2.op() == expected


def test_join_with_selector_predicate():
t1 = ibis.table(name="t1", schema={"a": "string", "b": "string"})
t2 = ibis.table(name="t2", schema={"c": "string", "d": "string"})

joined = t1.join(t2, s.of_type("string"))
with join_tables(joined) as (r1, r2):
expected = JoinChain(
first=r1,
rest=[
JoinLink("inner", r2, [r1.a == r2.c, r1.b == r2.d]),
],
values={"a": r1.a, "b": r1.b, "c": r2.c, "d": r2.d},
)
assert joined.op() == expected


def test_join_rhs_dereferencing():
t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"})
t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"})
Expand Down
22 changes: 11 additions & 11 deletions ibis/expr/types/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import rewrite_window_input
from ibis.expr.types.relations import bind

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Sequence


@public
Expand Down Expand Up @@ -65,13 +64,14 @@ def __getattr__(self, attr):

def aggregate(self, *metrics, **kwds) -> ir.Table:
"""Compute aggregates over a group by."""
metrics = self.table.to_expr().bind(*metrics, **kwds)
return self.table.to_expr().aggregate(
metrics, by=self.groupings, having=self.havings, **kwds
metrics, by=self.groupings, having=self.havings
)

agg = aggregate

def having(self, *expr: ir.BooleanScalar) -> GroupedTable:
def having(self, *predicates: ir.BooleanScalar) -> GroupedTable:
"""Add a post-aggregation result filter `expr`.

::: {.callout-warning}
Expand All @@ -80,19 +80,19 @@ def having(self, *expr: ir.BooleanScalar) -> GroupedTable:

Parameters
----------
expr
An expression that filters based on an aggregate value.
predicates
Expressions that filters based on an aggregate value.

Returns
-------
GroupedTable
A grouped table expression
"""
table = self.table.to_expr()
havings = tuple(bind(table, expr))
havings = table.bind(*predicates)
return self.copy(havings=self.havings + havings)

def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
def order_by(self, *by: ir.Value) -> GroupedTable:
"""Sort a grouped table expression by `expr`.

Notes
Expand All @@ -101,7 +101,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:

Parameters
----------
expr
by
Expressions to order the results by

Returns
Expand All @@ -110,7 +110,7 @@ def order_by(self, *expr: ir.Value | Iterable[ir.Value]) -> GroupedTable:
A sorted grouped GroupedTable
"""
table = self.table.to_expr()
orderings = tuple(bind(table, expr))
orderings = table.bind(*by)
return self.copy(orderings=self.orderings + orderings)

def mutate(
Expand Down Expand Up @@ -201,7 +201,7 @@ def _selectables(self, *exprs, **kwexprs):
[`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate)
"""
table = self.table.to_expr()
values = bind(table, (exprs, kwexprs))
values = table.bind(*exprs, **kwexprs)
window = ibis.window(group_by=self.groupings, order_by=self.orderings)
return [rewrite_window_input(expr.op(), window).to_expr() for expr in values]

Expand Down
20 changes: 8 additions & 12 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,10 @@ def prepare_predicates(
else:
lk = rk = pred

# bind the predicates to the join chain
(left_value,) = bind(left, lk)
(right_value,) = bind(right, rk)

# dereference the left value to one of the relations in the join chain
left_value = deref_left.dereference(left_value.op())
right_value = deref_right.dereference(right_value.op())

yield comparison(left_value, right_value)
for lhs, rhs in zip(bind(left, lk), bind(right, rk)):
lhs = deref_left.dereference(lhs.op())
rhs = deref_right.dereference(rhs.op())
yield comparison(lhs, rhs)


def finished(method):
Expand Down Expand Up @@ -335,8 +330,9 @@ def asof_join(
result = self.left_join(
filtered, predicates=[left_on == right_on] + predicates
)
values = {**self.op().values, **filtered.op().values}
return result.select(values)
values = {**filtered.op().values, **self.op().values}

return result.select(**values)

chain = self.op()
right = right.op()
Expand Down Expand Up @@ -383,7 +379,7 @@ def cross_join(
@functools.wraps(Table.select)
def select(self, *args, **kwargs):
chain = self.op()
values = bind(self, (args, kwargs))
values = self.bind(*args, **kwargs)
values = unwrap_aliases(values)

links = [link.table for link in chain.rest if link.how not in ("semi", "anti")]
Expand Down
Loading
Loading