From 5c2eadcdd5b2fbfcdae454e7149d9438c52e190f Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Wed, 4 Sep 2024 13:51:54 -0500 Subject: [PATCH] fix(sql): properly parenthesize binary ops containing named expressions --- ibis/backends/sql/compilers/base.py | 3 --- .../sql/compilers/bigquery/__init__.py | 14 ++++++++++---- ibis/backends/sql/compilers/clickhouse.py | 14 ++++++++++---- ibis/backends/sql/compilers/duckdb.py | 18 +++++++++++++----- ibis/backends/sql/compilers/postgres.py | 14 ++++++++++---- ibis/backends/sql/compilers/pyspark.py | 14 ++++++++++---- ibis/backends/sql/compilers/snowflake.py | 17 +++++++++++------ ibis/backends/sql/compilers/trino.py | 14 ++++++++++---- ibis/backends/sql/rewrites.py | 9 ++++++++- .../test_subquery_where_location/decompiled.py | 2 +- .../out.sql | 5 +++++ ibis/backends/tests/sql/test_compiler.py | 2 +- ibis/backends/tests/sql/test_sql.py | 6 ++++++ ibis/expr/operations/relations.py | 11 ++++------- ibis/expr/types/relations.py | 6 +++++- 15 files changed, 104 insertions(+), 45 deletions(-) create mode 100644 ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 2b9031548c6e..7b2153ecdff1 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -689,9 +689,6 @@ def visit_Cast(self, op, *, arg, to): def visit_ScalarSubquery(self, op, *, rel): return rel.this.subquery(copy=False) - def visit_Alias(self, op, *, arg, name): - return arg - def visit_Literal(self, op, *, value, dtype): """Compile a literal value. diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index fcb3f3c0d873..b40c85b0fa4a 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1017,7 +1017,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop): return sg.select(column).from_(parent) def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool + self, + op, + *, + parent, + column, + column_name: str, + offset: str | None, + keep_empty: bool, ): quoted = self.quoted @@ -1029,9 +1036,8 @@ def visit_TableUnnest( table = sg.to_identifier(parent.alias_or_name, quoted=quoted) - opname = op.column.name - overlaps_with_parent = opname in op.parent.schema - computed_column = column_alias.as_(opname, quoted=quoted) + overlaps_with_parent = column_name in op.parent.schema + computed_column = column_alias.as_(column_name, quoted=quoted) # replace the existing column if the unnested column hasn't been # renamed diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 479c2eea2eec..a3d7079d3aa6 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -688,7 +688,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop): return sg.select(column).from_(parent) def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool + self, + op, + *, + parent, + column, + column_name: str, + offset: str | None, + keep_empty: bool, ): quoted = self.quoted @@ -700,9 +707,8 @@ def visit_TableUnnest( selcols = [] - opname = op.column.name - overlaps_with_parent = opname in op.parent.schema - computed_column = column_alias.as_(opname, quoted=quoted) + overlaps_with_parent = column_name in op.parent.schema + computed_column = column_alias.as_(column_name, quoted=quoted) if offset is not None: if overlaps_with_parent: diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 1481510ef37e..9d97eac0a3c8 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -609,15 +609,21 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop): return sg.select(column).from_(parent) def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool + self, + op, + *, + parent, + column, + column_name: str, + offset: str | None, + keep_empty: bool, ): quoted = self.quoted column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted) - opname = op.column.name - overlaps_with_parent = opname in op.parent.schema - computed_column = column_alias.as_(opname, quoted=quoted) + overlaps_with_parent = column_name in op.parent.schema + computed_column = column_alias.as_(column_name, quoted=quoted) selcols = [] @@ -627,7 +633,9 @@ def visit_TableUnnest( # TODO: clean this up once WITH ORDINALITY is supported in DuckDB # no need for struct_extract once that's upstream column = self.f.list_zip(column, self.f.range(self.f.len(column))) - extract = self.f.struct_extract(column_alias, 1).as_(opname, quoted=quoted) + extract = self.f.struct_extract(column_alias, 1).as_( + column_name, quoted=quoted + ) if overlaps_with_parent: replace = sge.Column(this=sge.Star(replace=[extract]), table=table) diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index 643b834d5fc7..b224d1d180e1 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -718,7 +718,14 @@ def visit_Hash(self, op, *, arg): ) def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool + self, + op, + *, + parent, + column, + column_name: str, + offset: str | None, + keep_empty: bool, ): quoted = self.quoted @@ -726,10 +733,9 @@ def visit_TableUnnest( parent_alias = parent.alias_or_name - opname = op.column.name parent_schema = op.parent.schema - overlaps_with_parent = opname in parent_schema - computed_column = column_alias.as_(opname, quoted=quoted) + overlaps_with_parent = column_name in parent_schema + computed_column = column_alias.as_(column_name, quoted=quoted) selcols = [] diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py index 1555b8bc9502..36599f8d21e5 100644 --- a/ibis/backends/sql/compilers/pyspark.py +++ b/ibis/backends/sql/compilers/pyspark.py @@ -500,16 +500,22 @@ def visit_HexDigest(self, op, *, arg, how): raise NotImplementedError(f"No available hashing function for {how}") def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool + self, + op, + *, + parent, + column, + column_name: str, + offset: str | None, + keep_empty: bool, ): quoted = self.quoted column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted) - opname = op.column.name parent_schema = op.parent.schema - overlaps_with_parent = opname in parent_schema - computed_column = column_alias.as_(opname, quoted=quoted) + overlaps_with_parent = column_name in parent_schema + computed_column = column_alias.as_(column_name, quoted=quoted) parent_alias = parent.alias_or_name diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index c6e49cbba5f6..207422a12b6f 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -810,7 +810,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop): return sg.select(column).from_(parent) def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool + self, + op, + *, + parent, + column, + column_name: str, + offset: str | None, + keep_empty: bool, ): quoted = self.quoted @@ -825,12 +832,10 @@ def visit_TableUnnest( selcols = [] - opcol = op.column - opname = opcol.name - overlaps_with_parent = opname in op.parent.schema + overlaps_with_parent = column_name in op.parent.schema computed_column = self.cast( - self.f.nullif(column_alias, null_sentinel), opcol.dtype.value_type - ).as_(opname, quoted=quoted) + self.f.nullif(column_alias, null_sentinel), op.column.dtype.value_type + ).as_(column_name, quoted=quoted) if overlaps_with_parent: selcols.append( diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index f67dc0d9f7af..56452a1b4dbf 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -546,16 +546,22 @@ def visit_ToJSONArray(self, op, *, arg): ) def visit_TableUnnest( - self, op, *, parent, column, offset: str | None, keep_empty: bool + self, + op, + *, + parent, + column, + column_name: str, + offset: str | None, + keep_empty: bool, ): quoted = self.quoted column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted) - opname = op.column.name parent_schema = op.parent.schema - overlaps_with_parent = opname in parent_schema - computed_column = column_alias.as_(opname, quoted=quoted) + overlaps_with_parent = column_name in parent_schema + computed_column = column_alias.as_(column_name, quoted=quoted) parent_alias_or_name = parent.alias_or_name diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py index 820d9f890c22..8d2c7253b336 100644 --- a/ibis/backends/sql/rewrites.py +++ b/ibis/backends/sql/rewrites.py @@ -165,7 +165,7 @@ def fill_null_to_select(_, **kwargs): for name in _.parent.schema.names: col = ops.Field(_.parent, name) if (value := mapping.get(name)) is not None: - col = ops.Alias(ops.Coalesce((col, value)), name) + col = ops.Coalesce((col, value)) selections[name] = col return Select(_.parent, selections=selections) @@ -206,6 +206,12 @@ def first_to_firstvalue(_, **kwargs): return _.copy(func=klass(_.func.arg)) +@replace(p.Alias) +def remove_aliases(_, **kwargs): + """Remove all remaining aliases, they're not needed for remaining compilation.""" + return _.arg + + def complexity(node): """Assign a complexity score to a node. @@ -372,6 +378,7 @@ def sqlize( context = {"params": params} result = node.replace( replace_parameter + | remove_aliases | project_to_select | filter_to_select | sort_to_select diff --git a/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py b/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py index aa735e73ec47..76f89e244bef 100644 --- a/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py +++ b/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py @@ -11,7 +11,7 @@ }, ) param = ibis.param("timestamp") -f = alltypes.filter((alltypes.timestamp_col < param.name("my_param"))) +f = alltypes.filter((alltypes.timestamp_col < param)) agg = f.aggregate([f.float_col.sum().name("foo")], by=[f.string_col]) result = agg.foo.count() diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql new file mode 100644 index 000000000000..6ad547995b5c --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql @@ -0,0 +1,5 @@ +SELECT + ( + "t0"."a" + "t0"."b" + ) * "t0"."c" AS "x" +FROM "t" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/sql/test_compiler.py b/ibis/backends/tests/sql/test_compiler.py index bc789dba86ba..dc27463647b8 100644 --- a/ibis/backends/tests/sql/test_compiler.py +++ b/ibis/backends/tests/sql/test_compiler.py @@ -196,7 +196,7 @@ def test_subquery_where_location(snapshot): ], name="alltypes", ) - param = ibis.param("timestamp").name("my_param") + param = ibis.param("timestamp") expr = ( t[["float_col", "timestamp_col", "int_col", "string_col"]][ lambda t: t.timestamp_col < param diff --git a/ibis/backends/tests/sql/test_sql.py b/ibis/backends/tests/sql/test_sql.py index 01fd47220982..f979b864d916 100644 --- a/ibis/backends/tests/sql/test_sql.py +++ b/ibis/backends/tests/sql/test_sql.py @@ -143,6 +143,12 @@ def test_binop_parens(snapshot, opname, dtype, associative): snapshot.assert_match(combined, "out.sql") +def test_binop_with_alias_still_parenthesized(snapshot): + t = ibis.table({"a": "int", "b": "int", "c": "int"}, name="t") + sql = to_sql(((t.a + t.b).name("d") * t.c).name("x")) + snapshot.assert_match(sql, "out.sql") + + @pytest.mark.parametrize( "expr_fn", [ diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 49c288ca651a..b4e8a85d301f 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -498,6 +498,7 @@ class TableUnnest(Relation): parent: Relation column: Value[dt.Array] + column_name: str offset: typing.Union[str, None] keep_empty: bool @@ -507,15 +508,11 @@ def values(self): @attribute def schema(self): - column = self.column - offset = self.offset - base = self.parent.schema.fields.copy() + base[self.column_name] = self.column.dtype.value_type - base[column.name] = column.dtype.value_type - - if offset is not None: - base[offset] = dt.int64 + if self.offset is not None: + base[self.offset] = dt.int64 return Schema(base) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 7204e50ca78d..c6cd6d30041b 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -4890,7 +4890,11 @@ def unnest( """ (column,) = self.bind(column) return ops.TableUnnest( - parent=self, column=column, offset=offset, keep_empty=keep_empty + parent=self, + column=column, + column_name=column.get_name(), + offset=offset, + keep_empty=keep_empty, ).to_expr()