Skip to content

Commit

Permalink
Adds arithmetic operators as separate nodes
Browse files Browse the repository at this point in the history
Fixes casts to invalid types for conformance
  • Loading branch information
johnedquinn committed Jul 24, 2024
1 parent 73db367 commit 12dda57
Show file tree
Hide file tree
Showing 29 changed files with 2,175 additions and 223 deletions.
7 changes: 7 additions & 0 deletions partiql-eval/api/partiql-eval.api
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ public abstract interface class org/partiql/eval/PartiQLStatement$Query : org/pa
public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterable {
public static fun bagValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun boolValue (Z)Lorg/partiql/eval/value/Datum;
public static fun decimal (Ljava/math/BigDecimal;II)Lorg/partiql/eval/value/Datum;
public static fun decimalArbitrary (Ljava/math/BigDecimal;)Lorg/partiql/eval/value/Datum;
public static fun doublePrecision (D)Lorg/partiql/eval/value/Datum;
public fun get (Ljava/lang/String;)Lorg/partiql/eval/value/Datum;
public fun getBigDecimal ()Ljava/math/BigDecimal;
public fun getBigInteger ()Ljava/math/BigInteger;
Expand All @@ -87,6 +90,7 @@ public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterabl
public abstract fun getType ()Lorg/partiql/types/PType;
public static fun int32Value (I)Lorg/partiql/eval/value/Datum;
public static fun int64Value (J)Lorg/partiql/eval/value/Datum;
public static fun intArbitrary (Ljava/math/BigInteger;)Lorg/partiql/eval/value/Datum;
public fun isMissing ()Z
public fun isNull ()Z
public fun iterator ()Ljava/util/Iterator;
Expand All @@ -96,9 +100,12 @@ public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterabl
public static fun nullValue ()Lorg/partiql/eval/value/Datum;
public static fun nullValue (Lorg/partiql/types/PType;)Lorg/partiql/eval/value/Datum;
public static fun of (Lorg/partiql/value/PartiQLValue;)Lorg/partiql/eval/value/Datum;
public static fun real (F)Lorg/partiql/eval/value/Datum;
public static fun sexpValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun smallInt (S)Lorg/partiql/eval/value/Datum;
public static fun stringValue (Ljava/lang/String;)Lorg/partiql/eval/value/Datum;
public static fun structValue (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun tinyInt (B)Lorg/partiql/eval/value/Datum;
public fun toPartiQLValue ()Lorg/partiql/value/PartiQLValue;
}

Expand Down
40 changes: 40 additions & 0 deletions partiql-eval/src/main/java/org/partiql/eval/value/Datum.java
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ default Datum getInsensitive(@NotNull String name) {
@Deprecated
default PartiQLValue toPartiQLValue() {
PType type = this.getType();
if (this.isMissing()) {
return PartiQL.missingValue();
}
switch (type.getKind()) {
case BOOL:
return this.isNull() ? PartiQL.boolValue(null) : PartiQL.boolValue(this.getBoolean());
Expand Down Expand Up @@ -534,6 +537,16 @@ static Datum bagValue(@NotNull Iterable<Datum> values) {
return new DatumCollection(values, PType.typeBag());
}

@NotNull
static Datum tinyInt(byte value) {
return new DatumByte(value, PType.typeTinyInt());
}

@NotNull
static Datum smallInt(short value) {
return new DatumShort(value);
}

@NotNull
static Datum int64Value(long value) {
return new DatumLong(value);
Expand All @@ -544,6 +557,33 @@ static Datum int32Value(int value) {
return new DatumInt(value);
}

@Deprecated
@NotNull
static Datum intArbitrary(@NotNull BigInteger value) {
return new DatumBigInteger(value);
}

@NotNull
static Datum real(float value) {
return new DatumFloat(value);
}

@NotNull
static Datum doublePrecision(double value) {
return new DatumDouble(value);
}

@Deprecated
@NotNull
static Datum decimalArbitrary(@NotNull BigDecimal value) {
return new DatumDecimal(value, PType.typeDecimalArbitrary());
}

@NotNull
static Datum decimal(@NotNull BigDecimal value, int precision, int scale) {
return new DatumDecimal(value, PType.typeDecimal(precision, scale));
}

@NotNull
static Datum boolValue(boolean value) {
return new DatumBoolean(value);
Expand Down
101 changes: 101 additions & 0 deletions partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import org.partiql.eval.internal.operator.rel.RelSort
import org.partiql.eval.internal.operator.rel.RelUnionAll
import org.partiql.eval.internal.operator.rel.RelUnionDistinct
import org.partiql.eval.internal.operator.rel.RelUnpivot
import org.partiql.eval.internal.operator.rex.ExprArithmeticBinary
import org.partiql.eval.internal.operator.rex.ExprArithmeticUnary
import org.partiql.eval.internal.operator.rex.ExprCallDynamic
import org.partiql.eval.internal.operator.rex.ExprCallStatic
import org.partiql.eval.internal.operator.rex.ExprCase
Expand Down Expand Up @@ -279,6 +281,98 @@ internal class Compiler(
}
}

override fun visitRexOpAdd(node: Rex.Op.Add, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.add(lhs, rhs)
}

override fun visitRexOpMultiply(node: Rex.Op.Multiply, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.multiply(lhs, rhs)
}

override fun visitRexOpDivide(node: Rex.Op.Divide, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.divide(lhs, rhs)
}

override fun visitRexOpSubtract(node: Rex.Op.Subtract, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.subtract(lhs, rhs)
}

override fun visitRexOpModulo(node: Rex.Op.Modulo, ctx: PType?): Operator {
val (lhs, rhs) = visitBinaryArithmeticArgs(node.lhs, node.rhs, ctx!!)
val factory = getArithmeticBinaryFactory(ctx)
return factory.modulo(lhs, rhs)
}

override fun visitRexOpNegative(node: Rex.Op.Negative, ctx: PType?): Operator {
val arg = visitRex(node.arg, ctx)
val factory = getArithmeticUnaryFactory(node.arg.type)
return factory.negative(arg)
}

override fun visitRexOpPositive(node: Rex.Op.Positive, ctx: PType?): Operator {
val arg = visitRex(node.arg, ctx)
val factory = getArithmeticUnaryFactory(node.arg.type)
return factory.positive(arg)
}

private fun visitBinaryArithmeticArgs(l: Rex, r: Rex, returnType: PType): Pair<Operator.Expr, Operator.Expr> {
val lhsVisited = visitRex(l, l.type)
val rhsVisited = visitRex(r, r.type)
return when (returnType.kind) {
PType.Kind.TINYINT, PType.Kind.SMALLINT, PType.Kind.INT, PType.Kind.BIGINT, PType.Kind.INT_ARBITRARY,
PType.Kind.REAL, PType.Kind.DOUBLE_PRECISION,
PType.Kind.DECIMAL_ARBITRARY -> lhsVisited.coerce(l.type, returnType) to rhsVisited.coerce(r.type, returnType)
PType.Kind.DECIMAL -> lhsVisited.coerce(l.type, returnType) to rhsVisited.coerce(r.type, PType.typeDecimalArbitrary())
PType.Kind.DYNAMIC -> lhsVisited to rhsVisited
else -> error("Unsupported type: $returnType")
}
}

private fun getArithmeticUnaryFactory(returns: PType): ExprArithmeticUnary.Factory {
return when (returns.kind) {
PType.Kind.TINYINT -> ExprArithmeticUnary.Factory.Byte
PType.Kind.SMALLINT -> ExprArithmeticUnary.Factory.Short
PType.Kind.INT -> ExprArithmeticUnary.Factory.Int
PType.Kind.BIGINT -> ExprArithmeticUnary.Factory.BigInt
PType.Kind.INT_ARBITRARY -> ExprArithmeticUnary.Factory.IntArbitrary
PType.Kind.REAL -> ExprArithmeticUnary.Factory.Float
PType.Kind.DOUBLE_PRECISION -> ExprArithmeticUnary.Factory.Double
PType.Kind.DECIMAL -> ExprArithmeticUnary.Factory.Decimal(returns)
PType.Kind.DECIMAL_ARBITRARY -> ExprArithmeticUnary.Factory.DecimalArbitrary
PType.Kind.DYNAMIC -> ExprArithmeticUnary.Factory.Dynamic
else -> error("Unsupported type: $returns")
}
}

private fun getArithmeticBinaryFactory(returns: PType): ExprArithmeticBinary.Factory {
return when (returns.kind) {
PType.Kind.TINYINT -> ExprArithmeticBinary.Factory.Byte
PType.Kind.SMALLINT -> ExprArithmeticBinary.Factory.Short
PType.Kind.INT -> ExprArithmeticBinary.Factory.Int
PType.Kind.BIGINT -> ExprArithmeticBinary.Factory.BigInt
PType.Kind.INT_ARBITRARY -> ExprArithmeticBinary.Factory.IntArbitrary
PType.Kind.REAL -> ExprArithmeticBinary.Factory.Float
PType.Kind.DOUBLE_PRECISION -> ExprArithmeticBinary.Factory.Double
PType.Kind.DECIMAL -> ExprArithmeticBinary.Factory.Decimal(returns)
PType.Kind.DECIMAL_ARBITRARY -> ExprArithmeticBinary.Factory.DecimalArbitrary
PType.Kind.DYNAMIC -> ExprArithmeticBinary.Factory.Dynamic
else -> error("Unsupported type: $returns")
}
}

private fun Operator.Expr.coerce(input: PType, target: PType): Operator.Expr {
if (input == target) return this
return ExprCast(this, Ref.Cast(input, target, isNullable = true))
}

// REL
override fun visitRel(node: Rel, ctx: PType?): Operator.Relation {
return super.visitRelOp(node.op, ctx) as Operator.Relation
Expand Down Expand Up @@ -431,4 +525,11 @@ internal class Compiler(
}
return item
}

internal companion object {
private fun Operator.Expr.coerce(input: PType, target: PType): Operator.Expr {
if (input == target) return this
return ExprCast(this, Ref.Cast(input, target, isNullable = true))
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package org.partiql.eval.internal.helpers

import org.partiql.types.PType
import org.partiql.types.PType.Kind
import kotlin.math.max
import kotlin.math.min

/**
* This is a mirror copy to [org.partiql.planner.internal.typer.ArithmeticTyper].
*/
internal object ArithmeticTyper {

/**
* Follows SQL-Server for decimals:
* Result Precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1
* Result Scale = max(s1, s2)
* @return null if the operation is not allowed on the input types.
*/
internal fun add(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = max(lDec.scale, rDec.scale) + max(lDec.precision - lDec.scale, rDec.precision - rDec.scale) + 1
val scale = max(lDec.scale, rDec.scale)
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1
* Result Scale = max(s1, s2)
* @return null if the operation is not allowed on the input types.
*/
internal fun subtract(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = max(lhs.scale, rhs.scale) + max(lhs.precision - lhs.scale, rhs.precision - rhs.scale) + 1
val scale = max(lhs.scale, rhs.scale)
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision = p1 - s1 + s2 + max(6, s1 + p2 + 1)
* Result Scale = max(6, s1 + p2 + 1)
* @return null if the operation is not allowed on the input types.
*/
internal fun divide(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = lhs.precision - rhs.scale + lhs.scale + max(6, lhs.scale + rhs.precision + 1)
val scale = max(6, lhs.scale + rhs.precision + 1)
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision = p1 + p2 + 1
* Result Scale = s1 + s2
* @return null if the operation is not allowed on the input types.
*/
internal fun multiply(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = lhs.precision + rhs.precision + 1
val scale = lhs.scale + rhs.scale
precision to scale
}

/**
* Follows SQL-Server for decimals:
* Result Precision: min(p1 - s1, p2 - s2) + max(s1, s2)
* Result Scale: max(s1, s2)
* @return null if the operation is not allowed on the input types.
*/
internal fun modulo(lhs: PType, rhs: PType): PType? = arithmeticBinary(lhs, rhs) { lDec, rDec ->
val precision = min(lhs.precision - lhs.scale, rhs.precision - rhs.scale) + max(lhs.scale, rhs.scale)
val scale = max(lhs.scale, rhs.scale)
precision to scale
}

internal fun negative(arg: PType): PType? = arithmeticUnary(arg)

internal fun positive(arg: PType): PType? = arithmeticUnary(arg)

private fun arithmeticUnary(arg: PType): PType? {
val argMayBeNumber = arg.kind == Kind.DYNAMIC || TypeFamily.NUMBERS.contains(arg.kind)
return when (argMayBeNumber) {
true -> null
false -> arg
}
}

private fun arithmeticBinary(
lhs: PType,
rhs: PType,
handleDecimal: (PType, PType) -> Pair<Int, Int>
): PType? {
val lhsCannotBeNumber = lhs.kind != Kind.DYNAMIC && !TypeFamily.NUMBERS.contains(lhs.kind)
val rhsCannotBeNumber = rhs.kind != Kind.DYNAMIC && !TypeFamily.NUMBERS.contains(rhs.kind)
if (lhsCannotBeNumber || rhsCannotBeNumber) {
return null
}
if (lhs.kind == Kind.DYNAMIC || rhs.kind == Kind.DYNAMIC) {
return PType.typeDynamic()
}
val lhsPrecedence = TypePrecedence[lhs.kind]!!
val rhsPrecedence = TypePrecedence[rhs.kind]!!
val comp = lhsPrecedence.compareTo(rhsPrecedence)
when (comp) {
-1 -> return rhs
1 -> return lhs
0 -> if (lhs.kind != Kind.DECIMAL) {
return lhs
}

else -> error("This shouldn't have occurred.")
}
val (precision, scale) = handleDecimal(lhs, rhs)
// TODO: Check if this is what we want
return when (precision > 38 || scale > 38) {
true -> PType.typeDecimalArbitrary()
false -> PType.typeDecimal(precision, scale)
}
}
}
Loading

0 comments on commit 12dda57

Please sign in to comment.