Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix LORE OOM #46

Merged
merged 3 commits into from
Jul 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -149,30 +149,24 @@ object DumpUtils extends Logging {
class ParquetDumper(private val outputStream: OutputStream, table: Table) extends HostBufferConsumer
with AutoCloseable {
private[this] val tempBuffer = new Array[Byte](128 * 1024)
private[this] val buffers = mutable.Queue[(HostMemoryBuffer, Long)]()

def this(path: String, table: Table) = {
this(new FileOutputStream(path), table)
}

val tableWriter: TableWriter = {
private lazy val tableWriter: TableWriter = {
    // avoid any conversion, just dump the table as-is
val builder = ParquetDumper.parquetWriterOptionsFromTable(ParquetWriterOptions.builder(), table)
.withCompressionType(ParquetDumper.COMPRESS_TYPE)
Table.writeParquetChunked(builder.build(), this)
}

override
def handleBuffer(buffer: HostMemoryBuffer, len: Long): Unit =
buffers += Tuple2(buffer, len)

def writeBufferedData(): Unit = {
ColumnarOutputWriter.writeBufferedData(buffers, tempBuffer, outputStream)
}
override def handleBuffer(buffer: HostMemoryBuffer, len: Long): Unit =
ColumnarOutputWriter.writeBufferedData(mutable.Queue((buffer, len)), tempBuffer,
outputStream)

def writeTable(table: Table): Unit = {
tableWriter.write(table)
writeBufferedData()
}

/**
Expand All @@ -181,7 +175,6 @@ class ParquetDumper(private val outputStream: OutputStream, table: Table) extend
*/
def close(): Unit = {
tableWriter.close()
writeBufferedData()
outputStream.close()
}
}
Expand Down
Loading