From 721ba4eae0460559bdcc9c5ccb4721115107a5e9 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Wed, 11 Sep 2024 13:39:26 -0500 Subject: [PATCH] refactor(sql): simplify paren handling for binary ops --- .../test_sql/test_is_parens/notnull/out.sql | 6 +- .../out.sql | 4 +- ibis/backends/sql/compilers/base.py | 99 +++++++++++-------- ibis/backends/sql/compilers/clickhouse.py | 4 +- 4 files changed, 66 insertions(+), 47 deletions(-) diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_is_parens/notnull/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_is_parens/notnull/out.sql index bff317506a92..b3c32d222c9f 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_is_parens/notnull/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_is_parens/notnull/out.sql @@ -2,4 +2,8 @@ SELECT * FROM `table` AS `t0` WHERE - `t0`.`a` IS NOT NULL = `t0`.`b` IS NOT NULL \ No newline at end of file + ( + `t0`.`a` IS NOT NULL + ) = ( + `t0`.`b` IS NOT NULL + ) \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_sql/test_logically_negate_complex_boolean_expr/out.sql b/ibis/backends/impala/tests/snapshots/test_sql/test_logically_negate_complex_boolean_expr/out.sql index f0be225bba22..ccb29293341f 100644 --- a/ibis/backends/impala/tests/snapshots/test_sql/test_logically_negate_complex_boolean_expr/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_sql/test_logically_negate_complex_boolean_expr/out.sql @@ -1,5 +1,7 @@ SELECT NOT ( - `t0`.`a` IN ('foo') AND `t0`.`c` IS NOT NULL + `t0`.`a` IN ('foo') AND ( + `t0`.`c` IS NOT NULL + ) ) AS `tmp` FROM `t` AS `t0` \ No newline at end of file diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index de61268377cc..b4af50aa3ceb 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -374,48 +374,61 @@ class SQLGlotCompiler(abc.ABC): BINARY_INFIX_OPS = { # Numeric - ops.Add: (sge.Add, True), - ops.Subtract: (sge.Sub, False), - ops.Multiply: (sge.Mul, True), - ops.Divide: (sge.Div, False), - ops.Modulus: (sge.Mod, False), - ops.Power: (sge.Pow, False), + ops.Add: sge.Add, + ops.Subtract: sge.Sub, + ops.Multiply: sge.Mul, + ops.Divide: sge.Div, + ops.Modulus: sge.Mod, + ops.Power: sge.Pow, # Comparisons - ops.GreaterEqual: (sge.GTE, False), - ops.Greater: (sge.GT, False), - ops.LessEqual: (sge.LTE, False), - ops.Less: (sge.LT, False), - ops.Equals: (sge.EQ, False), - ops.NotEquals: (sge.NEQ, False), + ops.GreaterEqual: sge.GTE, + ops.Greater: sge.GT, + ops.LessEqual: sge.LTE, + ops.Less: sge.LT, + ops.Equals: sge.EQ, + ops.NotEquals: sge.NEQ, # Logical - ops.And: (sge.And, True), - ops.Or: (sge.Or, True), - ops.Xor: (sge.Xor, True), + ops.And: sge.And, + ops.Or: sge.Or, + ops.Xor: sge.Xor, # Bitwise - ops.BitwiseLeftShift: (sge.BitwiseLeftShift, False), - ops.BitwiseRightShift: (sge.BitwiseRightShift, False), - ops.BitwiseAnd: (sge.BitwiseAnd, True), - ops.BitwiseOr: (sge.BitwiseOr, True), - ops.BitwiseXor: (sge.BitwiseXor, True), + ops.BitwiseLeftShift: sge.BitwiseLeftShift, + ops.BitwiseRightShift: sge.BitwiseRightShift, + ops.BitwiseAnd: sge.BitwiseAnd, + ops.BitwiseOr: sge.BitwiseOr, + ops.BitwiseXor: sge.BitwiseXor, # Date - ops.DateAdd: (sge.Add, True), - ops.DateSub: (sge.Sub, False), - ops.DateDiff: (sge.Sub, False), + ops.DateAdd: sge.Add, + ops.DateSub: sge.Sub, + ops.DateDiff: sge.Sub, # Time - ops.TimeAdd: (sge.Add, True), - ops.TimeSub: (sge.Sub, False), - ops.TimeDiff: (sge.Sub, False), + ops.TimeAdd: sge.Add, + ops.TimeSub: sge.Sub, + ops.TimeDiff: sge.Sub, # Timestamp - ops.TimestampAdd: (sge.Add, True), - ops.TimestampSub: (sge.Sub, False), - ops.TimestampDiff: (sge.Sub, False), + ops.TimestampAdd: sge.Add, + ops.TimestampSub: sge.Sub, + ops.TimestampDiff: sge.Sub, # Interval - ops.IntervalAdd: (sge.Add, True), - ops.IntervalMultiply: (sge.Mul, True), - ops.IntervalSubtract: (sge.Sub, False), + ops.IntervalAdd: sge.Add, + ops.IntervalMultiply: sge.Mul, + ops.IntervalSubtract: sge.Sub, } - NEEDS_PARENS = tuple(BINARY_INFIX_OPS) + (ops.IsNull,) + # A set of SQLGlot classes that may need to be parenthesized + SQLGLOT_NEEDS_PARENS = set(BINARY_INFIX_OPS.values()).union((sge.Is,)) + + # A set of SQLGlot classes that are associative operations + SQLGLOT_ASSOCIATIVE_OPS = { + sge.Add, + sge.Mul, + sge.And, + sge.Or, + sge.Xor, + sge.BitwiseAnd, + sge.BitwiseOr, + sge.BitwiseXor, + } # Constructed dynamically in `__init_subclass__` from their respective # UPPERCASE values to handle inheritance, do not modify directly here. @@ -457,14 +470,14 @@ def impl(self, _, *, _name: str = target_name, **kw): # compiler class. if binops := cls.__dict__.get("BINARY_INFIX_OPS", {}): - def make_binop(sge_cls, associative): + def make_binop(sge_cls): def impl(self, op, *, left, right): - return self.binop(sge_cls, op, left, right, associative=associative) + return self.binop(sge_cls, left, right) return impl - for op, (sge_cls, associative) in binops.items(): - setattr(cls, methodname(op), make_binop(sge_cls, associative)) + for op, sge_cls in binops.items(): + setattr(cls, methodname(op), make_binop(sge_cls)) # unconditionally raise an exception for unsupported operations # @@ -1384,8 +1397,8 @@ def visit_Aggregate(self, op, *, parent, groups, metrics): return sel @classmethod - def _add_parens(cls, op, sg_expr): - if isinstance(op, cls.NEEDS_PARENS): + def _add_parens(cls, sg_expr): + if type(sg_expr) in cls.SQLGLOT_NEEDS_PARENS: return sge.paren(sg_expr, copy=False) return sg_expr @@ -1499,16 +1512,16 @@ def visit_SQLQueryResult(self, op, *, query, schema, source): def visit_RegexExtract(self, op, *, arg, pattern, index): return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect) - def binop(self, sg_expr, op, left, right, *, associative=False): + def binop(self, sg_cls, left, right): # If the op is associative we can skip parenthesizing ops of the same # type if they're on the left, since they would evaluate the same. # SQLGlot has an optimizer for generating long sql chains of the same # op of this form without recursion, by avoiding parenthesis in this # common case we can make use of this optimization to handle large # operator chains. - if not associative or type(op) is not type(op.left): - left = self._add_parens(op.left, left) - return sg_expr(this=left, expression=self._add_parens(op.right, right)) + if not (sg_cls in self.SQLGLOT_ASSOCIATIVE_OPS and type(left) is sg_cls): + left = self._add_parens(left) + return sg_cls(this=left, expression=self._add_parens(right)) def visit_Undefined(self, op, **_): raise com.OperationNotDefinedError( diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index a3d7079d3aa6..ea15150b279b 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -162,11 +162,11 @@ def visit_ArrayRepeat(self, op, *, arg, times): return self.f.arrayFlatten(self.f.arrayMap(func, self.f.range(times))) def visit_ArraySlice(self, op, *, arg, start, stop): - start = self._add_parens(op.start, start) + start = self._add_parens(start) start_correct = self.if_(start < 0, start, start + 1) if stop is not None: - stop = self._add_parens(op.stop, stop) + stop = self._add_parens(stop) length = self.if_( stop < 0,