Skip to content

Commit

Permalink
refactor(sql): clean up unnecessary use of explicit visit methods
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 3, 2024
1 parent ebee0e9 commit 2acaa16
Show file tree
Hide file tree
Showing 11 changed files with 15 additions and 38 deletions.
10 changes: 2 additions & 8 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ class SQLGlotCompiler(abc.ABC):
ops.Translate: "translate",
ops.Unnest: "explode",
ops.Uppercase: "upper",
ops.RandomUUID: "uuid",
ops.RandomScalar: "rand",
}

BINARY_INFIX_OPS = {
Expand Down Expand Up @@ -869,14 +871,6 @@ def visit_Floor(self, op, *, arg):
def visit_Round(self, op, *, arg, digits):
return self.cast(self.f.round(arg, digits), op.dtype)

### Random Noise

def visit_RandomScalar(self, op, **kwargs):
return self.f.rand()

def visit_RandomUUID(self, op, **kwargs):
return self.f.uuid()

### Dtype Dysmorphia

def visit_TryCast(self, op, *, arg, to):
Expand Down
6 changes: 1 addition & 5 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "time_from_parts",
ops.TimestampNow: "current_timestamp",
ops.ExtractHost: "net.host",
ops.ArgMin: "min_by",
ops.ArgMax: "max_by",
ops.RandomUUID: "generate_uuid",
}

def to_sqlglot(
Expand Down Expand Up @@ -997,9 +996,6 @@ def visit_CountDistinct(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_RandomUUID(self, op, **kwargs):
return self.f.generate_uuid()

def visit_ExtractFile(self, op, *, arg):
return self._pudf("cw_url_extract_file", arg)

Expand Down
8 changes: 2 additions & 6 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.TimestampNow: "now",
ops.TypeOf: "toTypeName",
ops.Unnest: "arrayJoin",
ops.RandomUUID: "generateUUIDv4",
ops.RandomScalar: "randCanonical",
}

@staticmethod
Expand Down Expand Up @@ -719,12 +721,6 @@ def visit_TimestampRange(self, op, *, start, stop, step):
def visit_RegexSplit(self, op, *, arg, pattern):
return self.f.splitByRegexp(pattern, self.cast(arg, dt.String(nullable=False)))

def visit_RandomScalar(self, op, **kwargs):
return self.f.randCanonical()

def visit_RandomUUID(self, op, **kwargs):
return self.f.generateUUIDv4()

@staticmethod
def _generate_groups(groups):
return groups
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
def visit_ArrayFlatten(self, op, *, arg):
return self.if_(arg.is_(NULL), NULL, self.f.flatten(arg))

def visit_RandomUUID(self, op, **kw):
def visit_RandomUUID(self, op):
return self.f.anon.uuid()


Expand Down
7 changes: 1 addition & 6 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.GeoWithin: "st_within",
ops.GeoX: "st_x",
ops.GeoY: "st_y",
ops.RandomScalar: "random",
}

def to_sqlglot(
Expand Down Expand Up @@ -608,12 +609,6 @@ def visit_StructField(self, op, *, arg, field):
)
return super().visit_StructField(op, arg=arg, field=field)

def visit_RandomScalar(self, op, **kwargs):
return self.f.random()

def visit_RandomUUID(self, op, **kwargs):
return self.f.uuid()

def visit_TypeOf(self, op, *, arg):
return self.f.coalesce(self.f.nullif(self.f.typeof(arg), '"NULL"'), "NULL")

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def visit_CountDistinct(self, op, *, arg, where):
def visit_Xor(self, op, *, left, right):
return sg.and_(sg.or_(left, right), sg.not_(sg.and_(left, right)))

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
return self.f.rand(self.f.utc_to_unix_micros(self.f.utc_timestamp()))

def visit_DayOfWeekIndex(self, op, *, arg):
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.TimestampNow: "sysdatetime",
ops.Min: "min",
ops.Max: "max",
ops.RandomUUID: "newid",
}

NAN = sg.func("double", sge.convert("NaN"))
Expand Down Expand Up @@ -177,10 +178,7 @@ def to_sqlglot(
table_expr = table_expr.mutate(**conversions)
return super().to_sqlglot(table_expr, limit=limit, params=params)

def visit_RandomUUID(self, op, **_):
return self.f.newid()

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
# By default RAND() will generate the same value for all calls within a
# query. The standard way to work around this is to pass in a unique
# value per call, which `CHECKSUM(NEWID())` provides.
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def visit_Log(self, op, *, arg, base):
def visit_IsInf(self, op, *, arg):
return arg.isin(self.POS_INF, self.NEG_INF)

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
# Not using FuncGen here because of dotted function call
return sg.func("dbms_random.value")

Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class PostgresCompiler(SQLGlotCompiler):
ops.MapValues: "avals",
ops.RegexSearch: "regexp_like",
ops.TimeFromHMS: "make_time",
ops.RandomUUID: "gen_random_uuid",
}

def to_sqlglot(
Expand Down Expand Up @@ -179,9 +180,6 @@ def _compile_python_udf(self, udf_node: ops.ScalarUDF):
args=", ".join(argnames),
)

def visit_RandomUUID(self, op, **kwargs):
return self.f.gen_random_uuid()

def visit_Mode(self, op, *, arg, where):
expr = self.f.mode()
expr = sge.WithinGroup(
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def visit_MapLength(self, op, *, arg):
def visit_Log(self, op, *, arg, base):
return self.f.log(base, arg)

def visit_RandomScalar(self, op, **_):
def visit_RandomScalar(self, op):
return self.f.uniform(
self.f.to_double(0.0), self.f.to_double(1.0), self.f.random()
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def visit_Clip(self, op, *, arg, lower, upper):

return arg

def visit_RandomScalar(self, op, **kwargs):
def visit_RandomScalar(self, op):
return 0.5 + self.f.random() / sge.Literal.number(float(-1 << 64))

def visit_Cot(self, op, *, arg):
return 1 / self.f.tan(arg)
return 1.0 / self.f.tan(arg)

def visit_ArgMin(self, *args, **kwargs):
return self._visit_arg_reduction("min", *args, **kwargs)
Expand Down

0 comments on commit 2acaa16

Please sign in to comment.