From f534626c2a6c981c2b2a3a7c21993b76687203ac Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 5 Sep 2024 07:32:26 -0400 Subject: [PATCH] chore: split up aggregate compilation to avoid copypasta --- ibis/backends/sql/compilers/base.py | 48 ++++++++++++++--------- ibis/backends/sql/compilers/datafusion.py | 21 +--------- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index d74ded9b208f7..2530a4616eef6 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -1383,30 +1383,42 @@ def visit_JoinLink(self, op, *, how, table, predicates): def _generate_groups(groups): return map(sge.convert, range(1, len(groups) + 1)) + def _compile_agg_select(self, op, *, parent, keys, metrics): + return sg.select( + *self._cleanup_names(keys), *self._cleanup_names(metrics), copy=False + ).from_(parent, copy=False) + + def _compile_group_by(self, sel, *, groups, grouping_sets, rollups, cubes): + expressions = list(self._generate_groups(groups.values())) + group = sge.Group( + expressions=expressions, + grouping_sets=[ + sge.GroupingSets( + expressions=[ + sge.Tuple(expressions=expressions) + for expressions in grouping_set + ] + ) + for grouping_set in grouping_sets + ], + rollup=[sge.Rollup(expressions=rollup) for rollup in rollups], + cube=[sge.Cube(expressions=cube) for cube in cubes], + ) + return sel.group_by(group, copy=False) + def visit_Aggregate( self, op, *, parent, keys, groups, metrics, grouping_sets, rollups, cubes ): - sel = sg.select( - *self._cleanup_names(keys), *self._cleanup_names(metrics), copy=False - ).from_(parent, copy=False) + sel = self._compile_agg_select(op, parent=parent, keys=keys, metrics=metrics) if groups or grouping_sets or rollups or cubes: - expressions = list(self._generate_groups(groups.values())) - group = sge.Group( - expressions=expressions, - grouping_sets=[ - sge.GroupingSets( - expressions=[ - sge.Tuple(expressions=expressions) - for expressions in grouping_set - ] - ) - for grouping_set in grouping_sets - ], - rollup=[sge.Rollup(expressions=rollup) for rollup in rollups], - cube=[sge.Cube(expressions=cube) for cube in cubes], + sel = self._compile_group_by( + sel, + groups=groups, + grouping_sets=grouping_sets, + rollups=rollups, + cubes=cubes, ) - sel = sel.group_by(group, copy=False) return sel diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index 56b3314a4796c..76c87d166a0fa 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -439,9 +439,7 @@ def visit_Last(self, op, *, arg, where, order_by, include_null): where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last_value(arg, where=where, order_by=order_by) - def visit_Aggregate( - self, op, *, parent, keys, groups, metrics, grouping_sets, rollups, cubes - ): + def _compile_agg_select(self, op, *, parent, keys, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" quoted = self.quoted metrics = tuple(self._cleanup_names(metrics)) @@ -481,22 +479,7 @@ def visit_Aggregate( selections = metrics or (STAR,) table = parent - sel = sg.select(*selections).from_(table) - - if groups or grouping_sets or rollups or cubes: - expressions = list(self._generate_groups(groups.values())) - group = sge.Group( - expressions=expressions, - grouping_sets=[ - sge.GroupingSets(expressions=grouping_set) - for grouping_set in grouping_sets - ], - rollup=[sge.Rollup(expressions=rollup) for rollup in rollups], - cube=[sge.Cube(expressions=cube) for cube in cubes], - ) - sel = sel.group_by(group, copy=False) - - return sel + return sg.select(*selections).from_(table) def visit_StructColumn(self, op, *, names, values): args = []