Skip to content

Commit

Permalink
reuse timeFormatters when serializing value to Hive string
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Dec 6, 2023
1 parent 9d3bf85 commit f250a4d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,17 @@ import scala.collection.JavaConverters._
import org.apache.hive.service.rpc.thrift._
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.execution.HiveResult.TimeFormatters
import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, TimeFormatters}
import org.apache.spark.sql.types._

import org.apache.kyuubi.util.RowSetUtils._

object RowSet {

private val timeUnrelatedDataTypes: Set[DataType] =
Set(BooleanType, FloatType, BinaryType, StringType)

def toHiveString(valueAndType: (Any, DataType), nested: Boolean = false): String = {
// compatible w/ Spark 3.1 and above
val timeFormatters: TimeFormatters = valueAndType match {
case (_, dataType)
if dataType.isInstanceOf[NumericType] || timeUnrelatedDataTypes.contains(dataType) =>
null
case _ => HiveResult.getTimeFormatters
}
def toHiveString(
valueAndType: (Any, DataType),
nested: Boolean = false,
timeFormatters: TimeFormatters): String = {
HiveResult.toHiveString(valueAndType, nested, timeFormatters)
}

Expand Down Expand Up @@ -80,14 +73,15 @@ object RowSet {
def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
val rowSize = rows.length
val tRows = new java.util.ArrayList[TRow](rowSize)
val timeFormatters = getTimeFormatters
var i = 0
while (i < rowSize) {
val row = rows(i)
val tRow = new TRow()
var j = 0
val columnSize = row.length
while (j < columnSize) {
val columnValue = toTColumnValue(j, row, schema)
val columnValue = toTColumnValue(j, row, schema, timeFormatters)
tRow.addToColVals(columnValue)
j += 1
}
Expand All @@ -100,18 +94,23 @@ object RowSet {
def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
val rowSize = rows.length
val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
val timeFormatters = getTimeFormatters
var i = 0
val columnSize = schema.length
while (i < columnSize) {
val field = schema(i)
val tColumn = toTColumn(rows, i, field.dataType)
val tColumn = toTColumn(rows, i, field.dataType, timeFormatters)
tRowSet.addToColumns(tColumn)
i += 1
}
tRowSet
}

private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
private def toTColumn(
rows: Seq[Row],
ordinal: Int,
typ: DataType,
timeFormatters: TimeFormatters): TColumn = {
val nulls = new java.util.BitSet()
typ match {
case BooleanType =>
Expand Down Expand Up @@ -161,7 +160,7 @@ object RowSet {
while (i < rowSize) {
val row = rows(i)
nulls.set(i, row.isNullAt(ordinal))
values.add(toHiveString(row.get(ordinal) -> typ))
values.add(toHiveString(row.get(ordinal) -> typ, timeFormatters = timeFormatters))
i += 1
}
TColumn.stringVal(new TStringColumn(values, nulls))
Expand Down Expand Up @@ -193,7 +192,8 @@ object RowSet {
private def toTColumnValue(
ordinal: Int,
row: Row,
types: StructType): TColumnValue = {
types: StructType,
timeFormatters: TimeFormatters): TColumnValue = {
types(ordinal).dataType match {
case BooleanType =>
val boolValue = new TBoolValue
Expand Down Expand Up @@ -241,7 +241,9 @@ object RowSet {
case _ =>
val tStrValue = new TStringValue
if (!row.isNullAt(ordinal)) {
tStrValue.setValue(toHiveString(row.get(ordinal) -> types(ordinal).dataType))
tStrValue.setValue(toHiveString(
row.get(ordinal) -> types(ordinal).dataType,
timeFormatters = timeFormatters))
}
TColumnValue.stringVal(tStrValue)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.HiveResult.getTimeFormatters
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
Expand Down Expand Up @@ -103,19 +104,20 @@ object SparkDatasetHelper extends Logging {
timestampAsString: Boolean): DataFrame = {

val quotedCol = (name: String) => col(quoteIfNeeded(name))
val tf = getTimeFormatters

// an udf to call `RowSet.toHiveString` on complex types(struct/array/map) and timestamp type.
val toHiveStringUDF = udf[String, Row, String]((row, schemaDDL) => {
val dt = DataType.fromDDL(schemaDDL)
dt match {
case StructType(Array(StructField(_, st: StructType, _, _))) =>
RowSet.toHiveString((row, st), nested = true)
RowSet.toHiveString((row, st), nested = true, timeFormatters = tf)
case StructType(Array(StructField(_, at: ArrayType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, at), nested = true)
RowSet.toHiveString((row.toSeq.head, at), nested = true, timeFormatters = tf)
case StructType(Array(StructField(_, mt: MapType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, mt), nested = true)
RowSet.toHiveString((row.toSeq.head, mt), nested = true, timeFormatters = tf)
case StructType(Array(StructField(_, tt: TimestampType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, tt), nested = true)
RowSet.toHiveString((row.toSeq.head, tt), nested = true, timeFormatters = tf)
case _ =>
throw new UnsupportedOperationException
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, TimeFormatters}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -97,6 +98,8 @@ class RowSetSuite extends KyuubiFunSuite {

private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))

private val tf: TimeFormatters = getTimeFormatters

test("column based set") {
val tRowSet = RowSet.toColumnBasedSet(rows, schema)
assert(tRowSet.getColumns.size() === schema.size)
Expand Down Expand Up @@ -165,14 +168,18 @@ class RowSetSuite extends KyuubiFunSuite {
dateCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "NULL")
case (b, i) =>
assert(b === RowSet.toHiveString(Date.valueOf(s"2018-11-${i + 1}") -> DateType))
assert(b === RowSet.toHiveString(
Date.valueOf(s"2018-11-${i + 1}") -> DateType,
timeFormatters = tf))
}

val tsCol = cols.next().getStringVal
tsCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b ===
RowSet.toHiveString(Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType))
RowSet.toHiveString(
Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType,
timeFormatters = tf))
}

val binCol = cols.next().getBinaryVal
Expand All @@ -185,14 +192,16 @@ class RowSetSuite extends KyuubiFunSuite {
arrCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === RowSet.toHiveString(
Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType)))
Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType),
timeFormatters = tf))
}

val mapCol = cols.next().getStringVal
mapCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === RowSet.toHiveString(
Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType)))
Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType),
timeFormatters = tf))
}

val intervalCol = cols.next().getStringVal
Expand Down Expand Up @@ -241,7 +250,7 @@ class RowSetSuite extends KyuubiFunSuite {
val r8 = iter.next().getColVals
assert(r8.get(12).getStringVal.getValue === Array.fill(7)(7.7d).mkString("[", ",", "]"))
assert(r8.get(13).getStringVal.getValue ===
RowSet.toHiveString(Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType)))
RowSet.toHiveString(Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType), timeFormatters = tf))

val r9 = iter.next().getColVals
assert(r9.get(14).getStringVal.getValue === new CalendarInterval(8, 8, 8).toString)
Expand Down

0 comments on commit f250a4d

Please sign in to comment.