Skip to content

Commit

Permalink
Close orc fetchOrcStatement when ExecuteStatement close
Browse files Browse the repository at this point in the history
  • Loading branch information
lsm1 committed Nov 30, 2023
1 parent d7863c0 commit 86a9db5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class ExecuteStatement(
override def getOperationLog: Option[OperationLog] = Option(operationLog)
override protected def supportProgress: Boolean = true

private var fetchOrcStatement: Option[FetchOrcStatement] = None
override protected def resultSchema: StructType = {
if (result == null || result.schema.isEmpty) {
new StructType().add("Result", "string")
Expand All @@ -65,6 +66,11 @@ class ExecuteStatement(
OperationLog.removeCurrentOperationLog()
}

override def close(): Unit = {
super.close()
fetchOrcStatement.foreach(_.close())
}

protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
resultDF.toLocalIterator().asScala
}
Expand Down Expand Up @@ -174,7 +180,8 @@ class ExecuteStatement(
.option("compression", "zstd").format("orc").save(fileName)
}
info(s"Save result to $fileName")
return new FetchOrcStatement(spark).getIterator(fileName, resultSchema)
fetchOrcStatement = Some(new FetchOrcStatement(spark))
return fetchOrcStatement.get.getIterator(fileName, resultSchema)
}
val internalArray = if (resultMaxRows <= 0) {
info("Execute in full collect mode")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@ import org.apache.spark.sql.execution.datasources.orc.OrcDeserializer
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.KyuubiException
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.SPARK_ENGINE_RUNTIME_VERSION
import org.apache.kyuubi.operation.{FetchIterator, IterableFetchIterator}
import org.apache.kyuubi.util.reflect.DynConstructors

class FetchOrcStatement(spark: SparkSession) {

var orcIter: OrcFileIterator = _
def getIterator(path: String, orcSchema: StructType): FetchIterator[Row] = {
val conf = spark.sparkContext.hadoopConfiguration
val savePath = new Path(path)
Expand All @@ -59,21 +63,42 @@ class FetchOrcStatement(spark: SparkSession) {
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
val deserializer = getOrcDeserializer(orcSchema, colId)
val iter = new OrcFileIterator(list)
val iterRow = iter.map(value =>
orcIter = new OrcFileIterator(list)
val iterRow = orcIter.map(value =>
unsafeProjection(deserializer.deserialize(value)))
.map(value => toRowConverter(value))
new IterableFetchIterator[Row](iterRow.toIterable)
}

def close(): Unit = {
orcIter.close()
}

private def getOrcDeserializer(orcSchema: StructType, colId: Array[Int]): OrcDeserializer = {
try {
val cls = Class.forName("org.apache.spark.sql.execution.datasources.orc.OrcDeserializer")
val constructor = cls.getDeclaredConstructors.apply(0)
if (constructor.getParameterCount == 3) {
constructor.newInstance(new StructType, orcSchema, colId).asInstanceOf[OrcDeserializer]
if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
// https://issues.apache.org/jira/browse/SPARK-34535
DynConstructors.builder()
.impl(
classOf[OrcDeserializer],
classOf[StructType],
classOf[Array[Int]])
.build[OrcDeserializer]()
.newInstance(
orcSchema,
colId)
} else {
constructor.newInstance(orcSchema, colId).asInstanceOf[OrcDeserializer]
DynConstructors.builder()
.impl(
classOf[OrcDeserializer],
classOf[StructType],
classOf[StructType],
classOf[Array[Int]])
.build[OrcDeserializer]()
.newInstance(
new StructType,
orcSchema,
colId)
}
} catch {
case e: Throwable =>
Expand All @@ -84,7 +109,7 @@ class FetchOrcStatement(spark: SparkSession) {

class OrcFileIterator(fileList: ListBuffer[LocatedFileStatus]) extends Iterator[OrcStruct] {

val iters = fileList.map(x => getOrcFileIterator(x))
private val iters = fileList.map(x => getOrcFileIterator(x))

var idx = 0

Expand All @@ -106,6 +131,10 @@ class OrcFileIterator(fileList: ListBuffer[LocatedFileStatus]) extends Iterator[
}
}

def close(): Unit = {
iters.foreach(_.close())
}

private def getOrcFileIterator(file: LocatedFileStatus): RecordReaderIterator[OrcStruct] = {
val orcRecordReader = {
val split =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1912,6 +1912,7 @@ object KyuubiConf {
.doc("The threshold of Spark result save to hdfs file, default value is 200 MB")
.version("1.9.0")
.intConf
.checkValue(_ > 0, "must be positive value")
.createWithDefault(209715200)

val OPERATION_INCREMENTAL_COLLECT: ConfigEntry[Boolean] =
Expand Down

0 comments on commit 86a9db5

Please sign in to comment.