Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds arithmetic operators as separate nodes & fixes casts #1522

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions partiql-eval/api/partiql-eval.api
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ public abstract interface class org/partiql/eval/PartiQLStatement$Query : org/pa
public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterable {
public static fun bagValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun boolValue (Z)Lorg/partiql/eval/value/Datum;
public static fun decimal (Ljava/math/BigDecimal;II)Lorg/partiql/eval/value/Datum;
public static fun decimalArbitrary (Ljava/math/BigDecimal;)Lorg/partiql/eval/value/Datum;
public static fun doublePrecision (D)Lorg/partiql/eval/value/Datum;
public fun get (Ljava/lang/String;)Lorg/partiql/eval/value/Datum;
public fun getBigDecimal ()Ljava/math/BigDecimal;
public fun getBigInteger ()Ljava/math/BigInteger;
Expand All @@ -87,6 +90,7 @@ public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterabl
public abstract fun getType ()Lorg/partiql/types/PType;
public static fun int32Value (I)Lorg/partiql/eval/value/Datum;
public static fun int64Value (J)Lorg/partiql/eval/value/Datum;
public static fun intArbitrary (Ljava/math/BigInteger;)Lorg/partiql/eval/value/Datum;
public fun isMissing ()Z
public fun isNull ()Z
public fun iterator ()Ljava/util/Iterator;
Expand All @@ -96,9 +100,12 @@ public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterabl
public static fun nullValue ()Lorg/partiql/eval/value/Datum;
public static fun nullValue (Lorg/partiql/types/PType;)Lorg/partiql/eval/value/Datum;
public static fun of (Lorg/partiql/value/PartiQLValue;)Lorg/partiql/eval/value/Datum;
public static fun real (F)Lorg/partiql/eval/value/Datum;
public static fun sexpValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun smallInt (S)Lorg/partiql/eval/value/Datum;
public static fun stringValue (Ljava/lang/String;)Lorg/partiql/eval/value/Datum;
public static fun structValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun tinyInt (B)Lorg/partiql/eval/value/Datum;
public fun toPartiQLValue ()Lorg/partiql/value/PartiQLValue;
}

Expand Down
40 changes: 40 additions & 0 deletions partiql-eval/src/main/java/org/partiql/eval/value/Datum.java
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ default Datum getInsensitive(@NotNull String name) {
@Deprecated
default PartiQLValue toPartiQLValue() {
PType type = this.getType();
if (this.isMissing()) {
return PartiQL.missingValue();
}
switch (type.getKind()) {
case BOOL:
return this.isNull() ? PartiQL.boolValue(null) : PartiQL.boolValue(this.getBoolean());
Expand Down Expand Up @@ -534,6 +537,16 @@ static Datum bagValue(@NotNull Iterable<Datum> values) {
return new DatumCollection(values, PType.typeBag());
}

@NotNull
static Datum tinyInt(byte value) {
return new DatumByte(value, PType.typeTinyInt());
}

@NotNull
static Datum smallInt(short value) {
return new DatumShort(value);
}

@NotNull
static Datum int64Value(long value) {
return new DatumLong(value);
Expand All @@ -544,6 +557,33 @@ static Datum int32Value(int value) {
return new DatumInt(value);
}

@Deprecated
@NotNull
static Datum intArbitrary(@NotNull BigInteger value) {
return new DatumBigInteger(value);
}

@NotNull
static Datum real(float value) {
return new DatumFloat(value);
}

@NotNull
static Datum doublePrecision(double value) {
return new DatumDouble(value);
}

@Deprecated
@NotNull
static Datum decimalArbitrary(@NotNull BigDecimal value) {
return new DatumDecimal(value, PType.typeDecimalArbitrary());
}

@NotNull
static Datum decimal(@NotNull BigDecimal value, int precision, int scale) {
return new DatumDecimal(value, PType.typeDecimal(precision, scale));
}

@NotNull
static Datum boolValue(boolean value) {
return new DatumBoolean(value);
Expand Down
101 changes: 101 additions & 0 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.partiql.eval.internal.operator.rel.RelSort
import org.partiql.eval.internal.operator.rel.RelUnionAll
import org.partiql.eval.internal.operator.rel.RelUnionDistinct
import org.partiql.eval.internal.operator.rel.RelUnpivot
import org.partiql.eval.internal.operator.rex.ExprArithmeticBinary
import org.partiql.eval.internal.operator.rex.ExprArithmeticUnary
import org.partiql.eval.internal.operator.rex.ExprCallDynamic
import org.partiql.eval.internal.operator.rex.ExprCallStatic
import org.partiql.eval.internal.operator.rex.ExprCase
Expand Down Expand Up @@ -279,6 +281,98 @@ internal class Compiler(
}
}

override fun visitRexOpAdd(node: Rex.Op.Add, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.add(lhs, rhs)
}

override fun visitRexOpMultiply(node: Rex.Op.Multiply, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.multiply(lhs, rhs)
}

override fun visitRexOpDivide(node: Rex.Op.Divide, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.divide(lhs, rhs)
}

override fun visitRexOpSubtract(node: Rex.Op.Subtract, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.subtract(lhs, rhs)
}

override fun visitRexOpModulo(node: Rex.Op.Modulo, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.modulo(lhs, rhs)
}

override fun visitRexOpNegative(node: Rex.Op.Negative, ctx: PType?): Operator {
val arg = visitRex(node.arg, ctx)
val factory = getArithmeticUnaryFactory(node.arg.type)
return factory.negative(arg)
}

override fun visitRexOpPositive(node: Rex.Op.Positive, ctx: PType?): Operator {
val arg = visitRex(node.arg, ctx)
val factory = getArithmeticUnaryFactory(node.arg.type)
return factory.positive(arg)
}

private fun visitBinaryArithmeticArgs(l: Rex, r: Rex, returnType: PType): Pair<Operator.Expr, Operator.Expr> {
val lhsVisited = visitRex(l, l.type)
val rhsVisited = visitRex(r, r.type)
return when (returnType.kind) {
PType.Kind.TINYINT, PType.Kind.SMALLINT, PType.Kind.INT, PType.Kind.BIGINT, PType.Kind.INT_ARBITRARY,
PType.Kind.REAL, PType.Kind.DOUBLE_PRECISION,
PType.Kind.DECIMAL_ARBITRARY -> lhsVisited.coerce(l.type, returnType) to rhsVisited.coerce(r.type, returnType)
PType.Kind.DECIMAL -> lhsVisited.coerce(l.type, returnType) to rhsVisited.coerce(r.type, PType.typeDecimalArbitrary())
PType.Kind.DYNAMIC -> lhsVisited to rhsVisited
else -> error("Unsupported type: $returnType")
}
}

private fun getArithmeticUnaryFactory(returns: PType): ExprArithmeticUnary.Factory {
return when (returns.kind) {
PType.Kind.TINYINT -> ExprArithmeticUnary.Factory.Byte
PType.Kind.SMALLINT -> ExprArithmeticUnary.Factory.Short
PType.Kind.INT -> ExprArithmeticUnary.Factory.Int
PType.Kind.BIGINT -> ExprArithmeticUnary.Factory.BigInt
PType.Kind.INT_ARBITRARY -> ExprArithmeticUnary.Factory.IntArbitrary
PType.Kind.REAL -> ExprArithmeticUnary.Factory.Float
PType.Kind.DOUBLE_PRECISION -> ExprArithmeticUnary.Factory.Double
PType.Kind.DECIMAL -> ExprArithmeticUnary.Factory.Decimal(returns)
PType.Kind.DECIMAL_ARBITRARY -> ExprArithmeticUnary.Factory.DecimalArbitrary
PType.Kind.DYNAMIC -> ExprArithmeticUnary.Factory.Dynamic
else -> error("Unsupported type: $returns")
}
}

private fun getArithmeticBinaryFactory(returns: PType): ExprArithmeticBinary.Factory {
return when (returns.kind) {
PType.Kind.TINYINT -> ExprArithmeticBinary.Factory.Byte
PType.Kind.SMALLINT -> ExprArithmeticBinary.Factory.Short
PType.Kind.INT -> ExprArithmeticBinary.Factory.Int
PType.Kind.BIGINT -> ExprArithmeticBinary.Factory.BigInt
PType.Kind.INT_ARBITRARY -> ExprArithmeticBinary.Factory.IntArbitrary
PType.Kind.REAL -> ExprArithmeticBinary.Factory.Float
PType.Kind.DOUBLE_PRECISION -> ExprArithmeticBinary.Factory.Double
PType.Kind.DECIMAL -> ExprArithmeticBinary.Factory.Decimal(returns)
PType.Kind.DECIMAL_ARBITRARY -> ExprArithmeticBinary.Factory.DecimalArbitrary
PType.Kind.DYNAMIC -> ExprArithmeticBinary.Factory.Dynamic
else -> error("Unsupported type: $returns")
}
}

private fun Operator.Expr.coerce(input: PType, target: PType): Operator.Expr {
if (input == target) return this
return ExprCast(this, Ref.Cast(input, target, isNullable = true))
}

// REL
override fun visitRel(node: Rel, ctx: PType?): Operator.Relation {
return super.visitRelOp(node.op, ctx) as Operator.Relation
Expand Down Expand Up @@ -431,4 +525,11 @@ internal class Compiler(
}
return item
}

internal companion object {
private fun Operator.Expr.coerce(input: PType, target: PType): Operator.Expr {
if (input == target) return this
return ExprCast(this, Ref.Cast(input, target, isNullable = true))
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package org.partiql.eval.internal.helpers

import org.partiql.types.PType
import org.partiql.types.PType.Kind
import kotlin.math.max
import kotlin.math.min

/**
* This is a mirror copy to [org.partiql.planner.internal.typer.ArithmeticTyper].
*/
internal object ArithmeticTyper {

/**
* Follows SQL-Server for decimals:
* Result Precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1
* Result Scale = max(s1, s2)
* @return null if the operation is not allowed on the input types.
*/
internal fun add(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = max(lDec.scale, rDec.scale) + max(lDec.precision - lDec.scale, rDec.precision - rDec.scale) + 1
val scale = max(lDec.scale, rDec.scale)
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1
* Result Scale = max(s1, s2)
* @return null if the operation is not allowed on the input types.
*/
internal fun subtract(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = max(lhs.scale, rhs.scale) + max(lhs.precision - lhs.scale, rhs.precision - rhs.scale) + 1
val scale = max(lhs.scale, rhs.scale)
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision = p1 - s1 + s2 + max(6, s1 + p2 + 1)
* Result Scale = max(6, s1 + p2 + 1)
* @return null if the operation is not allowed on the input types.
*/
internal fun divide(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = lhs.precision - rhs.scale + lhs.scale + max(6, lhs.scale + rhs.precision + 1)
val scale = max(6, lhs.scale + rhs.precision + 1)
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision = p1 + p2 + 1
* Result Scale = s1 + s2
* @return null if the operation is not allowed on the input types.
*/
internal fun multiply(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = lhs.precision + rhs.precision + 1
val scale = lhs.scale + rhs.scale
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision: min(p1 - s1, p2 - s2) + max(s1, s2)
* Result Scale: max(s1, s2)
* @return null if the operation is not allowed on the input types.
*/
internal fun modulo(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = min(lhs.precision - lhs.scale, rhs.precision - rhs.scale) + max(lhs.scale, rhs.scale)
val scale = max(lhs.scale, rhs.scale)
precision to scale
}

internal fun negative(arg: PType): PType? = arithmeticUnary(arg)

internal fun positive(arg: PType): PType? = arithmeticUnary(arg)

private fun arithmeticUnary(arg: PType): PType? {
val argMayBeNumber = arg.kind == Kind.DYNAMIC || TypeFamily.NUMBERS.contains(arg.kind)
return when (argMayBeNumber) {
true -> arg
false -> null
}
}

private fun arithmeticBinary(
lhs: PType,
rhs: PType,
handleDecimal: (PType, PType) -> Pair<Int, Int>
): PType? {
val lhsCannotBeNumber = lhs.kind != Kind.DYNAMIC && !TypeFamily.NUMBERS.contains(lhs.kind)
val rhsCannotBeNumber = rhs.kind != Kind.DYNAMIC && !TypeFamily.NUMBERS.contains(rhs.kind)
if (lhsCannotBeNumber || rhsCannotBeNumber) {
return null
}
if (lhs.kind == Kind.DYNAMIC || rhs.kind == Kind.DYNAMIC) {
return PType.typeDynamic()
}
val lhsPrecedence = TypePrecedence[lhs.kind]!!
val rhsPrecedence = TypePrecedence[rhs.kind]!!
val comp = lhsPrecedence.compareTo(rhsPrecedence)
when (comp) {
-1 -> return rhs
1 -> return lhs
0 -> if (lhs.kind != Kind.DECIMAL) {
return lhs
}

else -> error("This shouldn't have occurred.")
}
val (precision, scale) = handleDecimal(lhs, rhs)
// TODO: Check if this is what we want
return when (precision > 38 || scale > 38) {
true -> PType.typeDecimalArbitrary()
false -> PType.typeDecimal(precision, scale)
}
}
}
Loading
Loading