Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sql): clean up unnecessary use of explicit visit methods #10419

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ class SQLGlotCompiler(abc.ABC):
ops.Translate: "translate",
ops.Unnest: "explode",
ops.Uppercase: "upper",
ops.RandomUUID: "uuid",
ops.RandomScalar: "rand",
}

BINARY_INFIX_OPS = {
Expand Down Expand Up @@ -869,14 +871,6 @@ def visit_Floor(self, op, *, arg):
def visit_Round(self, op, *, arg, digits):
return self.cast(self.f.round(arg, digits), op.dtype)

### Random Noise

def visit_RandomScalar(self, op, **kwargs):
return self.f.rand()

def visit_RandomUUID(self, op, **kwargs):
return self.f.uuid()

### Dtype Dysmorphia

def visit_TryCast(self, op, *, arg, to):
Expand Down
6 changes: 1 addition & 5 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "time_from_parts",
ops.TimestampNow: "current_timestamp",
ops.ExtractHost: "net.host",
ops.ArgMin: "min_by",
ops.ArgMax: "max_by",
ops.RandomUUID: "generate_uuid",
}

def to_sqlglot(
Expand Down Expand Up @@ -997,9 +996,6 @@ def visit_CountDistinct(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_RandomUUID(self, op, **kwargs):
return self.f.generate_uuid()

def visit_ExtractFile(self, op, *, arg):
return self._pudf("cw_url_extract_file", arg)

Expand Down
8 changes: 2 additions & 6 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.TimestampNow: "now",
ops.TypeOf: "toTypeName",
ops.Unnest: "arrayJoin",
ops.RandomUUID: "generateUUIDv4",
ops.RandomScalar: "randCanonical",
}

@staticmethod
Expand Down Expand Up @@ -719,12 +721,6 @@ def visit_TimestampRange(self, op, *, start, stop, step):
def visit_RegexSplit(self, op, *, arg, pattern):
return self.f.splitByRegexp(pattern, self.cast(arg, dt.String(nullable=False)))

def visit_RandomScalar(self, op, **kwargs):
return self.f.randCanonical()

def visit_RandomUUID(self, op, **kwargs):
return self.f.generateUUIDv4()

@staticmethod
def _generate_groups(groups):
return groups
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
def visit_ArrayFlatten(self, op, *, arg):
return self.if_(arg.is_(NULL), NULL, self.f.flatten(arg))

def visit_RandomUUID(self, op, **kw):
def visit_RandomUUID(self, op):
return self.f.anon.uuid()


Expand Down
7 changes: 1 addition & 6 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.GeoWithin: "st_within",
ops.GeoX: "st_x",
ops.GeoY: "st_y",
ops.RandomScalar: "random",
}

def to_sqlglot(
Expand Down Expand Up @@ -602,12 +603,6 @@ def visit_StructField(self, op, *, arg, field):
)
return super().visit_StructField(op, arg=arg, field=field)

def visit_RandomScalar(self, op, **kwargs):
return self.f.random()

def visit_RandomUUID(self, op, **kwargs):
return self.f.uuid()

def visit_TypeOf(self, op, *, arg):
return self.f.coalesce(self.f.nullif(self.f.typeof(arg), '"NULL"'), "NULL")

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def visit_CountDistinct(self, op, *, arg, where):
def visit_Xor(self, op, *, left, right):
return sg.and_(sg.or_(left, right), sg.not_(sg.and_(left, right)))

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
return self.f.rand(self.f.utc_to_unix_micros(self.f.utc_timestamp()))

def visit_DayOfWeekIndex(self, op, *, arg):
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.TimestampNow: "sysdatetime",
ops.Min: "min",
ops.Max: "max",
ops.RandomUUID: "newid",
}

NAN = sg.func("double", sge.convert("NaN"))
Expand Down Expand Up @@ -177,10 +178,7 @@ def to_sqlglot(
table_expr = table_expr.mutate(**conversions)
return super().to_sqlglot(table_expr, limit=limit, params=params)

def visit_RandomUUID(self, op, **_):
return self.f.newid()

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
# By default RAND() will generate the same value for all calls within a
# query. The standard way to work around this is to pass in a unique
# value per call, which `CHECKSUM(NEWID())` provides.
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def visit_Log(self, op, *, arg, base):
def visit_IsInf(self, op, *, arg):
return arg.isin(self.POS_INF, self.NEG_INF)

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
# Not using FuncGen here because of dotted function call
return sg.func("dbms_random.value")

Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class PostgresCompiler(SQLGlotCompiler):
ops.MapValues: "avals",
ops.RegexSearch: "regexp_like",
ops.TimeFromHMS: "make_time",
ops.RandomUUID: "gen_random_uuid",
}

def to_sqlglot(
Expand Down Expand Up @@ -179,9 +180,6 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF):
args=", ".join(argnames),
)

def visit_RandomUUID(self, op, **kwargs):
return self.f.gen_random_uuid()

def visit_Mode(self, op, *, arg, where):
expr = self.f.mode()
expr = sge.WithinGroup(
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ class PySparkCompiler(SQLGlotCompiler):
}

SIMPLE_OPS = {
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayDistinct: "array_distinct",
ops.ArrayFlatten: "flatten",
ops.ArrayIntersect: "array_intersect",
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def visit_MapLength(self, op, *, arg):
def visit_Log(self, op, *, arg, base):
return self.f.log(base, arg)

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
return self.f.uniform(
self.f.to_double(0.0), self.f.to_double(1.0), self.f.random()
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def visit_Clip(self, op, *, arg, lower, upper):

return arg

def visit_RandomScalar(self, op, **kwargs):
def visit_RandomScalar(self, op):
return 0.5 + self.f.random() / sge.Literal.number(float(-1 << 64))

def visit_Cot(self, op, *, arg):
return 1 / self.f.tan(arg)
return 1.0 / self.f.tan(arg)

def visit_ArgMin(self, *args, **kwargs):
return self._visit_arg_reduction("min", *args, **kwargs)
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ class TrinoCompiler(SQLGlotCompiler):

SIMPLE_OPS = {
ops.Arbitrary: "any_value",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.Pi: "pi",
ops.E: "e",
ops.RegexReplace: "regexp_replace",
Expand Down
Loading