Skip to content

Commit

Permalink
chore: split up aggregate compilation to avoid copypasta
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 5, 2024
1 parent 578707b commit f534626
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 37 deletions.
48 changes: 30 additions & 18 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,30 +1383,42 @@ def visit_JoinLink(self, op, *, how, table, predicates):
def _generate_groups(groups):
    """Return a lazy iterator of 1-based positions, one per entry in `groups`.

    Each integer from 1 through ``len(groups)`` is passed through
    ``sge.convert``. Only the length of `groups` is used, not its contents.
    """
    # NOTE(review): presumably these serve as positional GROUP BY references
    # (the caller feeds them to `sge.Group`) -- confirm against the select list.
    positions = range(1, len(groups) + 1)
    return map(sge.convert, positions)

def _compile_agg_select(self, op, *, parent, keys, metrics):
    """Build the ``SELECT ... FROM`` part of an aggregation.

    The projection is `keys` followed by `metrics`, each run through
    ``self._cleanup_names``; the result selects from `parent`.

    `op` is not used in this implementation. It is part of the signature so
    that backend compilers can define a method with the same parameters.
    """
    projections = (
        *self._cleanup_names(keys),
        *self._cleanup_names(metrics),
    )
    query = sg.select(*projections, copy=False)
    return query.from_(parent, copy=False)

def _compile_group_by(self, sel, *, groups, grouping_sets, rollups, cubes):
    """Attach a ``GROUP BY`` clause to `sel` and return the result.

    The plain grouping expressions come from ``self._generate_groups`` applied
    to ``groups.values()``. Each entry of `grouping_sets` becomes one
    ``sge.GroupingSets`` node whose members are each wrapped in an
    ``sge.Tuple``; each entry of `rollups` / `cubes` becomes an
    ``sge.Rollup`` / ``sge.Cube`` node.
    """
    positional_keys = list(self._generate_groups(groups.values()))

    grouping_set_nodes = []
    for grouping_set in grouping_sets:
        # One tuple per member of the grouping set.
        members = [sge.Tuple(expressions=member) for member in grouping_set]
        grouping_set_nodes.append(sge.GroupingSets(expressions=members))

    rollup_nodes = [sge.Rollup(expressions=rollup) for rollup in rollups]
    cube_nodes = [sge.Cube(expressions=cube) for cube in cubes]

    group_clause = sge.Group(
        expressions=positional_keys,
        grouping_sets=grouping_set_nodes,
        rollup=rollup_nodes,
        cube=cube_nodes,
    )
    return sel.group_by(group_clause, copy=False)

def visit_Aggregate(
self, op, *, parent, keys, groups, metrics, grouping_sets, rollups, cubes
):
sel = sg.select(
*self._cleanup_names(keys), *self._cleanup_names(metrics), copy=False
).from_(parent, copy=False)
sel = self._compile_agg_select(op, parent=parent, keys=keys, metrics=metrics)

if groups or grouping_sets or rollups or cubes:
expressions = list(self._generate_groups(groups.values()))
group = sge.Group(
expressions=expressions,
grouping_sets=[
sge.GroupingSets(
expressions=[
sge.Tuple(expressions=expressions)
for expressions in grouping_set
]
)
for grouping_set in grouping_sets
],
rollup=[sge.Rollup(expressions=rollup) for rollup in rollups],
cube=[sge.Cube(expressions=cube) for cube in cubes],
sel = self._compile_group_by(
sel,
groups=groups,
grouping_sets=grouping_sets,
rollups=rollups,
cubes=cubes,
)
sel = sel.group_by(group, copy=False)

return sel

Expand Down
21 changes: 2 additions & 19 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,7 @@ def visit_Last(self, op, *, arg, where, order_by, include_null):
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Aggregate(
self, op, *, parent, keys, groups, metrics, grouping_sets, rollups, cubes
):
def _compile_agg_select(self, op, *, parent, keys, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
quoted = self.quoted
metrics = tuple(self._cleanup_names(metrics))
Expand Down Expand Up @@ -481,22 +479,7 @@ def visit_Aggregate(
selections = metrics or (STAR,)
table = parent

sel = sg.select(*selections).from_(table)

if groups or grouping_sets or rollups or cubes:
expressions = list(self._generate_groups(groups.values()))
group = sge.Group(
expressions=expressions,
grouping_sets=[
sge.GroupingSets(expressions=grouping_set)
for grouping_set in grouping_sets
],
rollup=[sge.Rollup(expressions=rollup) for rollup in rollups],
cube=[sge.Cube(expressions=cube) for cube in cubes],
)
sel = sel.group_by(group, copy=False)

return sel
return sg.select(*selections).from_(table)

def visit_StructColumn(self, op, *, names, values):
args = []
Expand Down

0 comments on commit f534626

Please sign in to comment.