fix: add literal conversions for intervalday and intervalyear
Blizzara committed Oct 24, 2024
1 parent bccdaf0 commit 9d56a80
Showing 4 changed files with 64 additions and 16 deletions.
12 changes: 9 additions & 3 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -65,9 +65,15 @@ private class ToSparkType
TimestampType
}

- override def visit(expr: Type.IntervalDay): DayTimeIntervalType
+ override def visit(expr: Type.IntervalDay): DataType = {
+   if (expr.precision() != Util.MICROSECOND_PRECISION) {
+     throw new UnsupportedOperationException(
+       s"Unsupported precision for intervalDay: ${expr.precision()}")
+   }
+   DayTimeIntervalType.DEFAULT
+ }

- override def visit(expr: Type.IntervalYear): YearMonthIntervalType
+ override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT

override def visit(expr: Type.ListType): DataType =
ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
@@ -113,7 +119,7 @@ class ToSubstraitType {
case DateType => Some(creator.DATE)
case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))
case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
- case DayTimeIntervalType.DEFAULT => Some(creator.INTERVAL_DAY)
+ case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION))
case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
case ArrayType(elementType, containsNull) =>
convert(elementType, Seq.empty, containsNull).map(creator.list)
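Note: for orientation, a minimal sketch of the type mapping these hunks pin down. `intervalDay(precision)` and `INTERVAL_YEAR` come straight from the diff; the `TypeCreator.REQUIRED` factory and the demo object are assumptions:

import io.substrait.`type`.TypeCreator

object IntervalTypeMappingDemo extends App {
  val creator = TypeCreator.REQUIRED // assumed: factory for non-nullable types

  // Spark's DayTimeIntervalType.DEFAULT always carries microseconds, so the
  // Substrait side pins the precision to 6 instead of using plain INTERVAL_DAY.
  val dayType = creator.intervalDay(6)

  // Year-month intervals carry no precision, so the plain singleton still fits.
  val yearType = creator.INTERVAL_YEAR

  println(dayType)
  println(yearType)
}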
spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -18,18 +18,20 @@ package io.substrait.spark.expression

import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
import io.substrait.spark.logical.ToLogicalPlan

+ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.DecimalUtil
import io.substrait.utils.Util
- import io.substrait.utils.Util.MICROSECOND_PRECISION
- import org.apache.spark.sql.catalyst.CatalystTypeConverters
- import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+ import io.substrait.utils.Util.SECONDS_PER_DAY

import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}
import scala.math.pow
@@ -86,11 +88,20 @@ class ToSparkExpression(
}

override def visit(expr: SExpression.PrecisionTimestampLiteral): Long = {
-   (expr.value() * pow(10, Util.MICROSECOND_PRECISION - expr.precision())).toLong
+   Util.toMicroseconds(expr.value(), expr.precision())
}

override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Long = {
-   (expr.value() * pow(10, Util.MICROSECOND_PRECISION - expr.precision())).toLong
+   Util.toMicroseconds(expr.value(), expr.precision())
}

+ override def visit(expr: SExpression.IntervalDayLiteral): Long = {
+   (expr.days() * SECONDS_PER_DAY + expr.seconds()) * pow(10, Util.MICROSECOND_PRECISION).toLong +
+     Util.toMicroseconds(expr.subseconds(), expr.precision())
+ }
+
+ override def visit(expr: SExpression.IntervalYearLiteral): Int = {
+   expr.years() * 12 + expr.months()
+ }

override def visit(expr: SExpression.ListLiteral): ArrayData = {
@@ -99,7 +110,8 @@
}

override def visit(expr: SExpression.MapLiteral): MapData = {
- val map = expr.values().asScala.map { case (key, value) => (key.accept(this), value.accept(this)) }
+ val map =
+   expr.values().asScala.map { case (key, value) => (key.accept(this), value.accept(this)) }
CatalystTypeConverters.convertToCatalyst(map).asInstanceOf[MapData]
}

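Note: a self-contained worked example of the interval-day arithmetic above, with a local copy of the Util helper so the numbers can be checked by hand:

object IntervalDayArithmeticDemo extends App {
  val SECONDS_PER_DAY: Long = 24 * 60 * 60
  val MICROSECOND_PRECISION = 6

  // Local copy of Util.toMicroseconds: rescale from `precision` digits to 6.
  def toMicroseconds(value: Long, precision: Int): Long = {
    val factor = MICROSECOND_PRECISION - precision
    if (factor >= 0) value * math.pow(10, factor).toLong
    else value / math.pow(10, -factor).toLong
  }

  // 1 day, 2 seconds, 500 subseconds at precision 3 (i.e. 500 milliseconds).
  val (days, seconds, subseconds, precision) = (1L, 2L, 500L, 3)

  val micros = (days * SECONDS_PER_DAY + seconds) * math.pow(10, MICROSECOND_PRECISION).toLong +
    toMicroseconds(subseconds, precision)

  println(micros) // 86402500000 = (86400 + 2) * 1000000 + 500 * 1000
}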
spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala
@@ -38,10 +38,17 @@ class ToSubstraitLiteral {
scale: Int): SExpression.Literal =
decimal(false, d.toJavaBigDecimal, precision, scale)

- private def sparkArray2Substrait(arrayData: ArrayData, elementType: DataType): SExpression.Literal =
-   list(false, JavaConverters.asJavaIterable(arrayData.array.map(any => apply(Literal(any, elementType)))))
+ private def sparkArray2Substrait(
+     arrayData: ArrayData,
+     elementType: DataType): SExpression.Literal =
+   list(
+     false,
+     JavaConverters.asJavaIterable(arrayData.array.map(any => apply(Literal(any, elementType)))))

- private def sparkMap2Substrait(mapData: MapData, keyType: DataType, valueType: DataType): SExpression.Literal = {
+ private def sparkMap2Substrait(
+     mapData: MapData,
+     keyType: DataType,
+     valueType: DataType): SExpression.Literal = {
val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
@@ -56,8 +63,14 @@
val _fp64: Double => SExpression.Literal = fp64(false, _)
val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
val _date: Int => SExpression.Literal = date(false, _)
- val _timestamp: Long => SExpression.Literal = precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
- val _timestampTz: Long => SExpression.Literal = precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
+ val _timestamp: Long => SExpression.Literal =
+   precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
+ val _timestampTz: Long => SExpression.Literal =
+   precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
+ val _intervalDay: Long => SExpression.Literal = (micros: Long) =>
+   intervalDay(false, 0, 0, micros, Util.MICROSECOND_PRECISION) // Spark's value is total microseconds
+ val _intervalYear: Int => SExpression.Literal = (months: Int) =>
+   intervalYear(false, months / 12, months % 12) // Spark's value is total months
val _string: String => SExpression.Literal = string(false, _)
val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
val _array: (ArrayData, DataType) => SExpression.Literal = sparkArray2Substrait
@@ -80,10 +93,14 @@
case Literal(d: Integer, DateType) => Nonnull._date(d)
case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
+ case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d)
+ case Literal(ym: Int, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym)
case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
- case Literal(a: ArrayData, ArrayType(et, _)) => Nonnull._array(a, et) // TODO: handle containsNull
- case Literal(m: MapData, MapType(keyType, valueType, _)) => Nonnull._map(m, keyType, valueType) // TODO: handle containsNull
+ case Literal(a: ArrayData, ArrayType(et, _)) =>
+   Nonnull._array(a, et) // TODO: handle containsNull
+ case Literal(m: MapData, MapType(keyType, valueType, _)) =>
+   Nonnull._map(m, keyType, valueType) // TODO: handle containsNull
case _ => null
}
)
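Note: the match arms above rely on Spark's internal interval encodings, sketched here (1 day 2 seconds = 86402000000 microseconds; 2 years 3 months = 27 months):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

object SparkIntervalInternalsDemo extends App {
  // Spark stores INTERVAL DAY TO SECOND as a Long count of microseconds...
  val day = Literal(86402000000L, DayTimeIntervalType.DEFAULT)
  // ...and INTERVAL YEAR TO MONTH as an Int count of months.
  val ym = Literal(27, YearMonthIntervalType.DEFAULT)

  // These are exactly the shapes the new match arms destructure.
  day match { case Literal(micros: Long, DayTimeIntervalType.DEFAULT) => println(micros) }
  ym match { case Literal(months: Int, YearMonthIntervalType.DEFAULT) => println(months) }
}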
13 changes: 13 additions & 0 deletions spark/src/main/scala/io/substrait/utils/Util.scala
@@ -21,8 +21,21 @@ import scala.collection.mutable.ArrayBuffer

object Util {

+ val SECONDS_PER_DAY: Long = 24 * 60 * 60
+ val MICROSECOND_PRECISION = 6 // for PrecisionTimestamp(TZ) types

+ def toMicroseconds(value: Long, precision: Int): Long = {
+   // Spark represents most time values as a Long count of microseconds
+   val factor = MICROSECOND_PRECISION - precision
+   if (factor == 0) {
+     value
+   } else if (factor > 0) {
+     value * math.pow(10, factor).toLong
+   } else {
+     value / math.pow(10, -factor).toLong
+   }
+ }

/**
* Compute the cartesian product for n lists.
*
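Note: a small usage sketch of the new helper, assuming the `Util` object above is on the classpath (expected values worked out from the factor arithmetic):

import io.substrait.utils.Util

object ToMicrosecondsDemo extends App {
  println(Util.toMicroseconds(1500L, 3)) // milliseconds: factor 3, 1500 * 1000 = 1500000
  println(Util.toMicroseconds(42L, 6))   // already microseconds: factor 0, unchanged
  println(Util.toMicroseconds(9000L, 9)) // nanoseconds: factor -3, 9000 / 1000 = 9
}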
