Skip to content

Commit

Permalink
Adds CAST transpilation to Redshift target (#1223)
Browse files Browse the repository at this point in the history
  • Loading branch information
yliuuuu authored Sep 27, 2023
1 parent b3f8ff6 commit ef73d18
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.partiql.ast.Expr
import org.partiql.ast.Identifier
import org.partiql.ast.builder.AstFactory
import org.partiql.transpiler.ProblemCallback
import org.partiql.transpiler.error
import org.partiql.transpiler.info
import org.partiql.transpiler.sql.SqlArgs
import org.partiql.transpiler.sql.SqlCallFn
Expand Down Expand Up @@ -55,6 +56,64 @@ public class RedshiftCalls(private val onProblem: ProblemCallback) : SqlCalls()
exprVar(id, Expr.Var.Scope.DEFAULT)
}

override fun rewriteCast(type: PartiQLValueType, args: SqlArgs): Expr = Ast.create {
when (type) {
PartiQLValueType.ANY -> {
onProblem.error("PartiQL `ANY` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.INT8 -> {
onProblem.error("PartiQL `INT8` type (1-byte integer) not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.INT -> {
onProblem.error("PartiQL `INT` type (arbitrary precision integer) not supported in Redshift")
// this needs a extra safety renaming because int refers to int4 in redshift.
exprCast(args[0].expr, typeCustom("Arbitrary Precision Integer"))
}
PartiQLValueType.MISSING -> {
onProblem.error("PartiQL `MISSING` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.SYMBOL -> {
onProblem.error("PartiQL `SYMBOL` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.INTERVAL -> {
onProblem.error("PartiQL `INTERVAL` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.BLOB -> {
onProblem.error("PartiQL `BLOB` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.CLOB -> {
onProblem.error("PartiQL `CLOB` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.BAG -> {
onProblem.error("PartiQL `BAG` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.LIST -> {
onProblem.error("PartiQL `LIST` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.SEXP -> {
onProblem.error("PartiQL `SEXP` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.STRUCT -> {
onProblem.error("PartiQL `STRUCT` type not supported in Redshift")
super.rewriteCast(type, args)
}
// using the customer type to rename type
PartiQLValueType.FLOAT32 -> exprCast(args[0].expr, typeCustom("FLOAT4"))
PartiQLValueType.FLOAT64 -> exprCast(args[0].expr, typeCustom("FLOAT8"))
PartiQLValueType.BINARY -> exprCast(args[0].expr, typeCustom("VARBYTE"))
PartiQLValueType.BYTE -> TODO("Mapping to VARBYTE(1), do this after supporting parameterized type")
else -> super.rewriteCast(type, args)
}
/**
* Push the negation down if possible.
* For example : NOT 1 is NULL -> 1 is NOT NULL.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
override fun visitTypeInterval(node: Type.Interval, head: SqlBlock) = head concat type("INTERVAL", node.precision)

// unsupported
override fun visitTypeCustom(node: Type.Custom, head: SqlBlock) = defaultReturn(node, head)
override fun visitTypeCustom(node: Type.Custom, head: SqlBlock) = head concat r(node.name)

// Expressions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import org.partiql.planner.Env
import org.partiql.planner.typer.toNonNullStaticType
import org.partiql.planner.typer.toStaticType
import org.partiql.types.StaticType
import org.partiql.types.TimeType
import org.partiql.types.TimestampType
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.boolValue
Expand Down Expand Up @@ -384,8 +386,47 @@ internal object RexConverter {
TODO("SQL Special Form EXTRACT")
}

override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex {
TODO("SQL Special Form CAST")
// TODO: Ignoring type parameter now
override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex = transform {
val type = node.asType
val arg0 = visitExpr(node.value, ctx)
when(type) {
is Type.NullType -> rex(StaticType.NULL, call("cast_null", arg0))
is Type.Missing -> rex(StaticType.MISSING, call("cast_missing", arg0))
is Type.Bool -> rex(StaticType.BOOL, call("cast_bool", arg0))
is Type.Tinyint -> TODO("Static Type does not have TINYINT type")
is Type.Smallint, is Type.Int2 -> rex(StaticType.INT2, call("cast_int16", arg0))
is Type.Int4 -> rex(StaticType.INT4, call("cast_int32", arg0))
is Type.Bigint, is Type.Int8 -> rex(StaticType.INT8, call("cast_int64", arg0))
is Type.Int -> rex(StaticType.INT, call("cast_int", arg0))
is Type.Real -> TODO("Static Type does not have REAL type")
is Type.Float32 -> TODO("Static Type does not have FLOAT32 type")
is Type.Float64 -> rex(StaticType.FLOAT, call("cast_float64", arg0))
is Type.Decimal -> rex(StaticType.DECIMAL, call("cast_decimal", arg0))
is Type.Numeric -> rex(StaticType.DECIMAL, call("cast_numeric", arg0))
is Type.Char -> rex(StaticType.CHAR, call("cast_char", arg0))
is Type.Varchar -> rex(StaticType.STRING, call("cast_varchar", arg0))
is Type.String -> rex(StaticType.STRING, call("cast_string", arg0))
is Type.Symbol -> rex(StaticType.SYMBOL, call("cast_symbol", arg0))
is Type.Bit -> TODO("Static Type does not have Bit type")
is Type.BitVarying -> TODO("Static Type does not have BitVarying type")
is Type.ByteString -> TODO("Static Type does not have ByteString type")
is Type.Blob -> rex(StaticType.BLOB, call("cast_blob", arg0))
is Type.Clob -> rex(StaticType.CLOB, call("cast_clob", arg0))
is Type.Date -> rex(StaticType.DATE, call("cast_date", arg0))
is Type.Time -> rex(StaticType.TIME, call("cast_time", arg0))
is Type.TimeWithTz -> rex(TimeType(null, true), call("cast_timeWithTz", arg0))
is Type.Timestamp -> TODO("Need to rebase main")
is Type.TimestampWithTz -> rex(StaticType.TIMESTAMP, call("cast_timeWithTz", arg0))
is Type.Interval -> TODO("Static Type does not have Interval type")
is Type.Bag -> rex(StaticType.BAG, call("cast_bag", arg0))
is Type.List -> rex(StaticType.LIST, call("cast_list", arg0))
is Type.Sexp -> rex(StaticType.SEXP, call("cast_sexp", arg0))
is Type.Tuple -> rex(StaticType.STRUCT, call("cast_tuple", arg0))
is Type.Struct -> rex(StaticType.STRUCT, call("cast_struct", arg0))
is Type.Any -> rex(StaticType.ANY, call("cast_any", arg0))
is Type.Custom -> TODO("Custom type not supported ")
}
}

override fun visitExprCanCast(node: Expr.CanCast, ctx: Env): Rex {
Expand Down

0 comments on commit ef73d18

Please sign in to comment.