From f5d21a9cf06e946ac8dcad5a82c7d7d6fd2fe013 Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Mon, 1 Jul 2024 17:14:09 +0800 Subject: [PATCH 1/9] workable version without tests Signed-off-by: Hongbin Ma (Mahone) --- .../scala/com/nvidia/spark/rapids/Arm.scala | 16 +- .../spark/rapids/GpuAggregateExec.scala | 426 +++++++++++++----- .../com/nvidia/spark/rapids/GpuExec.scala | 2 + .../com/nvidia/spark/rapids/RapidsConf.scala | 9 + 4 files changed, 327 insertions(+), 126 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index 926f770a683..b0cd798c179 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.ControlThrowable import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -134,6 +134,20 @@ object Arm extends ArmScalaSpecificImpl { } } + /** Executes the provided code block, closing the resources only if an exception occurs */ + def closeOnExcept[T <: AutoCloseable, V](r: ListBuffer[T])(block: ListBuffer[T] => V): V = { + try { + block(r) + } catch { + case t: ControlThrowable => + // Don't close for these cases.. + throw t + case t: Throwable => + r.safeClose(t) + throw t + } + } + /** Executes the provided code block, closing the resources only if an exception occurs */ def closeOnExcept[T <: AutoCloseable, V](r: mutable.Queue[T])(block: mutable.Queue[T] => V): V = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index 7e6a1056d01..b28101f3442 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -16,11 +16,9 @@ package com.nvidia.spark.rapids -import java.util - import scala.annotation.tailrec -import scala.collection.JavaConverters.collectionAsScalaIterableConverter import scala.collection.mutable +import scala.collection.mutable.ListBuffer import ai.rapids.cudf import ai.rapids.cudf.{NvtxColor, NvtxRange} @@ -46,11 +44,11 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter} -import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil} +import org.apache.spark.sql.rapids.execution.{GpuBatchSubPartitioner, GpuShuffleMeta, TrampolineUtil} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -object AggregateUtils { +object AggregateUtils extends Logging { private val aggs = List("min", "max", "avg", "sum", "count", "first", "last") @@ -86,9 +84,10 @@ object AggregateUtils { /** * Computes a target input batch size based on the assumption that computation can consume up to * 4X the configured batch size. 
- * @param confTargetSize user-configured maximum desired batch size - * @param inputTypes input batch schema - * @param outputTypes output batch schema + * + * @param confTargetSize user-configured maximum desired batch size + * @param inputTypes input batch schema + * @param outputTypes output batch schema * @param isReductionOnly true if this is a reduction-only aggregation without grouping * @return maximum target batch size to keep computation under the 4X configured batch limit */ @@ -99,6 +98,7 @@ object AggregateUtils { isReductionOnly: Boolean): Long = { def typesToSize(types: Seq[DataType]): Long = types.map(GpuBatchUtils.estimateGpuMemory(_, nullable = false, rowCount = 1)).sum + val inputRowSize = typesToSize(inputTypes) val outputRowSize = typesToSize(outputTypes) // The cudf hash table implementation allocates four 32-bit integers per input row. @@ -124,6 +124,129 @@ object AggregateUtils { // Finally compute the input target batching size taking into account the cudf row limits Math.min(inputRowSize * maxRows, Int.MaxValue) } + + + /** + * Concatenate batches together and perform a merge aggregation on the result. The input batches + * will be closed as part of this operation. + * + * @param batches batches to concatenate and merge aggregate + * @return lazy spillable batch which has NOT been marked spillable + */ + private def concatenateAndMerge( + batches: mutable.Buffer[SpillableColumnarBatch], + metrics: GpuHashAggregateMetrics, + concatAndMergeHelper: AggHelper): SpillableColumnarBatch = { + // TODO: concatenateAndMerge (and calling code) could output a sequence + // of batches for the partial aggregate case. This would be done in case + // a retry failed a certain number of times. + val concatBatch = withResource(batches) { _ => + val concatSpillable = concatenateBatches(metrics, batches.toSeq) + withResource(concatSpillable) { + _.getColumnarBatch() + } + } + computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper) + } + + /** + * Perform a single pass over the aggregated batches attempting to merge adjacent batches. + * + * @return true if at least one merge operation occurred + */ + private def mergePass( + aggregatedBatches: mutable.Buffer[SpillableColumnarBatch], + targetMergeBatchSize: Long, + helper: AggHelper, + metrics: GpuHashAggregateMetrics + ): Boolean = { + val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty + var wasBatchMerged = false + // Current size in bytes of the batches targeted for the next concatenation + var concatSize: Long = 0L + var batchesLeftInPass = aggregatedBatches.size + + while (batchesLeftInPass > 0) { + closeOnExcept(batchesToConcat) { _ => + var isConcatSearchFinished = false + // Old batches are picked up at the front of the queue and freshly merged batches are + // appended to the back of the queue. Although tempting to allow the pass to "wrap around" + // and pick up batches freshly merged in this pass, it's avoided to prevent changing the + // order of aggregated batches. 
+ while (batchesLeftInPass > 0 && !isConcatSearchFinished) { + val candidate = aggregatedBatches.head + val potentialSize = concatSize + candidate.sizeInBytes + isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize + if (!isConcatSearchFinished) { + batchesLeftInPass -= 1 + batchesToConcat += aggregatedBatches.remove(0) + concatSize = potentialSize + } + } + } + + val mergedBatch = if (batchesToConcat.length > 1) { + wasBatchMerged = true + concatenateAndMerge(batchesToConcat, metrics, helper) + } else { + // Unable to find a neighboring buffer to produce a valid merge in this pass, + // so simply put this buffer back on the queue for other passes. + batchesToConcat.remove(0) + } + + // Add the merged batch to the end of the aggregated batch queue. Only a single pass over + // the batches is being performed due to the batch count check above, so the single-pass + // loop will terminate before picking up this new batch. + aggregatedBatches += mergedBatch + batchesToConcat.clear() + concatSize = 0 + } + + wasBatchMerged + } + + + /** + * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only + * one batch or merging adjacent batches would exceed the target batch size. + */ + def tryMergeAggregatedBatches( + aggregatedBatches: mutable.Buffer[SpillableColumnarBatch], + isReductionOnly: Boolean, + metrics: GpuHashAggregateMetrics, + targetMergeBatchSize: Long, + helper: AggHelper + ): Unit = { + while (aggregatedBatches.size > 1) { + val concatTime = metrics.concatTime + val opTime = metrics.opTime + withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime, + opTime)) { _ => + // continue merging as long as some batches are able to be combined + if (!mergePass(aggregatedBatches, targetMergeBatchSize, helper, metrics)) + if (aggregatedBatches.size > 1 && isReductionOnly) { + // We were unable to merge the aggregated batches within the target batch size limit, + // which means normally we would fallback to a sort-based approach. However for + // reduction-only aggregation there are no keys to use for a sort. The only way this + // can work is if all batches are merged. This will exceed the target batch size limit, + // but at this point it is either risk an OOM/cudf error and potentially work or + // not work at all. + logWarning(s"Unable to merge reduction-only aggregated batches within " + + s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " + + s"${aggregatedBatches.size} batches beyond limit") + withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat => + aggregatedBatches.foreach(b => batchesToConcat += b) + aggregatedBatches.clear() + val batch = concatenateAndMerge(batchesToConcat, metrics, helper) + // batch does not need to be marked spillable since it is the last and only batch + // and will be immediately retrieved on the next() call. 
+ aggregatedBatches += batch + } + } + return + } + } + } } /** Utility class to hold all of the metrics related to hash aggregation */ @@ -135,6 +258,7 @@ case class GpuHashAggregateMetrics( computeAggTime: GpuMetric, concatTime: GpuMetric, sortTime: GpuMetric, + repartitionTime: GpuMetric, numAggOps: GpuMetric, numPreSplits: GpuMetric, singlePassTasks: GpuMetric, @@ -711,6 +835,8 @@ object GpuAggFinalPassIterator { * @param useTieredProject user-specified option to enable tiered projections * @param allowNonFullyAggregatedOutput if allowed to skip third pass Agg * @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value + * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback + * for oversize agg * @param localInputRowsCount metric to track the number of input rows processed locally */ class GpuMergeAggregateIterator( @@ -726,15 +852,17 @@ class GpuMergeAggregateIterator( useTieredProject: Boolean, allowNonFullyAggregatedOutput: Boolean, skipAggPassReductionRatio: Double, + aggFallbackAlgorithm: String, localInputRowsCount: LocalGpuMetric) extends Iterator[ColumnarBatch] with AutoCloseable with Logging { private[this] val isReductionOnly = groupingExpressions.isEmpty private[this] val targetMergeBatchSize = computeTargetMergeBatchSize(configuredTargetBatchSize) - private[this] val aggregatedBatches = new util.ArrayDeque[SpillableColumnarBatch] + private[this] val aggregatedBatches = ListBuffer.empty[SpillableColumnarBatch] private[this] var outOfCoreIter: Option[GpuOutOfCoreSortIterator] = None + private[this] var repartitionIter: Option[RepartitionAggregateIterator] = None /** Iterator for fetching aggregated batches either if: - * 1. a sort-based fallback has occurred + * 1. a sort-based/repartition-based fallback has occurred * 2. 
skip third pass agg has occurred **/ private[this] var fallbackIter: Option[Iterator[ColumnarBatch]] = None @@ -752,7 +880,7 @@ class GpuMergeAggregateIterator( override def hasNext: Boolean = { fallbackIter.map(_.hasNext).getOrElse { // reductions produce a result even if the input is empty - hasReductionOnlyBatch || !aggregatedBatches.isEmpty || firstPassIter.hasNext + hasReductionOnlyBatch || aggregatedBatches.nonEmpty || firstPassIter.hasNext } } @@ -769,9 +897,11 @@ class GpuMergeAggregateIterator( if (isReductionOnly || skipAggPassReductionRatio * localInputRowsCount.value >= rowsAfterFirstPassAgg) { // second pass agg - tryMergeAggregatedBatches() + AggregateUtils.tryMergeAggregatedBatches( + aggregatedBatches, isReductionOnly, + metrics, targetMergeBatchSize, concatAndMergeHelper) - val rowsAfterSecondPassAgg = aggregatedBatches.asScala.foldLeft(0L) { + val rowsAfterSecondPassAgg = aggregatedBatches.foldLeft(0L) { (totalRows, batch) => totalRows + batch.numRows() } shouldSkipThirdPassAgg = @@ -784,7 +914,7 @@ class GpuMergeAggregateIterator( } } - if (aggregatedBatches.size() > 1) { + if (aggregatedBatches.size > 1) { // Unable to merge to a single output, so must fall back if (allowNonFullyAggregatedOutput && shouldSkipThirdPassAgg) { // skip third pass agg, return the aggregated batches directly @@ -792,17 +922,23 @@ class GpuMergeAggregateIterator( s"${skipAggPassReductionRatio * 100}% of " + s"rows after first pass, skip the third pass agg") fallbackIter = Some(new Iterator[ColumnarBatch] { - override def hasNext: Boolean = !aggregatedBatches.isEmpty + override def hasNext: Boolean = aggregatedBatches.nonEmpty override def next(): ColumnarBatch = { - withResource(aggregatedBatches.pop()) { spillableBatch => + withResource(aggregatedBatches.remove(0)) { spillableBatch => spillableBatch.getColumnarBatch() } } }) } else { // fallback to sort agg, this is the third pass agg - fallbackIter = Some(buildSortFallbackIterator()) + aggFallbackAlgorithm.toLowerCase match { + case "repartition" => + fallbackIter = Some(buildRepartitionFallbackIterator()) + case "sort" => fallbackIter = Some(buildSortFallbackIterator()) + case _ => throw new IllegalArgumentException( + s"Unsupported aggregation fallback algorithm: $aggFallbackAlgorithm") + } } fallbackIter.get.next() } else if (aggregatedBatches.isEmpty) { @@ -815,7 +951,7 @@ class GpuMergeAggregateIterator( } else { // this will be the last batch hasReductionOnlyBatch = false - withResource(aggregatedBatches.pop()) { spillableBatch => + withResource(aggregatedBatches.remove(0)) { spillableBatch => spillableBatch.getColumnarBatch() } } @@ -823,10 +959,12 @@ class GpuMergeAggregateIterator( } override def close(): Unit = { - aggregatedBatches.forEach(_.safeClose()) + aggregatedBatches.foreach(_.safeClose()) aggregatedBatches.clear() outOfCoreIter.foreach(_.close()) outOfCoreIter = None + repartitionIter.foreach(_.close()) + repartitionIter = None fallbackIter = None hasReductionOnlyBatch = false } @@ -843,133 +981,161 @@ class GpuMergeAggregateIterator( while (firstPassIter.hasNext) { val batch = firstPassIter.next() rowsAfter += batch.numRows() - aggregatedBatches.add(batch) + aggregatedBatches += batch } rowsAfter } - /** - * Attempt to merge adjacent batches in the aggregatedBatches queue until either there is only - * one batch or merging adjacent batches would exceed the target batch size. 
- */ - private def tryMergeAggregatedBatches(): Unit = { - while (aggregatedBatches.size() > 1) { - val concatTime = metrics.concatTime - val opTime = metrics.opTime - withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime, - opTime)) { _ => - // continue merging as long as some batches are able to be combined - if (!mergePass()) { - if (aggregatedBatches.size() > 1 && isReductionOnly) { - // We were unable to merge the aggregated batches within the target batch size limit, - // which means normally we would fallback to a sort-based approach. However for - // reduction-only aggregation there are no keys to use for a sort. The only way this - // can work is if all batches are merged. This will exceed the target batch size limit, - // but at this point it is either risk an OOM/cudf error and potentially work or - // not work at all. - logWarning(s"Unable to merge reduction-only aggregated batches within " + - s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " + - s"${aggregatedBatches.size()} batches beyond limit") - withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat => - aggregatedBatches.forEach(b => batchesToConcat += b) - aggregatedBatches.clear() - val batch = concatenateAndMerge(batchesToConcat) - // batch does not need to be marked spillable since it is the last and only batch - // and will be immediately retrieved on the next() call. - aggregatedBatches.add(batch) - } - } - return + private lazy val concatAndMergeHelper = + new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions, + forceMerge = true, useTieredProject = useTieredProject) + + private def cbIteratorStealingFromBuffer(input: ListBuffer[SpillableColumnarBatch]) = { + val aggregatedBatchIter = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = input.nonEmpty + + override def next(): ColumnarBatch = { + withResource(input.remove(0)) { spillable => + spillable.getColumnarBatch() } } } + aggregatedBatchIter } - /** - * Perform a single pass over the aggregated batches attempting to merge adjacent batches. - * @return true if at least one merge operation occurred - */ - private def mergePass(): Boolean = { - val batchesToConcat: mutable.ArrayBuffer[SpillableColumnarBatch] = mutable.ArrayBuffer.empty - var wasBatchMerged = false - // Current size in bytes of the batches targeted for the next concatenation - var concatSize: Long = 0L - var batchesLeftInPass = aggregatedBatches.size() + private case class RepartitionAggregateIterator( + inputBatches: ListBuffer[SpillableColumnarBatch], + hashKeys: Seq[GpuExpression], + targetSize: Long, + opTime: GpuMetric, + repartitionTime: GpuMetric) extends Iterator[ColumnarBatch] + with AutoCloseable { - while (batchesLeftInPass > 0) { - closeOnExcept(batchesToConcat) { _ => - var isConcatSearchFinished = false - // Old batches are picked up at the front of the queue and freshly merged batches are - // appended to the back of the queue. Although tempting to allow the pass to "wrap around" - // and pick up batches freshly merged in this pass, it's avoided to prevent changing the - // order of aggregated batches. 
- while (batchesLeftInPass > 0 && !isConcatSearchFinished) { - val candidate = aggregatedBatches.getFirst - val potentialSize = concatSize + candidate.sizeInBytes - isConcatSearchFinished = concatSize > 0 && potentialSize > targetMergeBatchSize - if (!isConcatSearchFinished) { - batchesLeftInPass -= 1 - batchesToConcat += aggregatedBatches.removeFirst() - concatSize = potentialSize + case class AggregatePartition(batches: ListBuffer[SpillableColumnarBatch], seed: Int) + extends AutoCloseable { + override def close(): Unit = { + batches.safeClose() + } + + def totalRows(): Long = batches.map(_.numRows()).sum + + def totalSize(): Long = batches.map(_.sizeInBytes).sum + + def split(): ListBuffer[AggregatePartition] = { + withResource(new NvtxWithMetrics("agg repartition", NvtxColor.CYAN, repartitionTime)) { _ => + if (seed > hashSeed + 20) { + throw new IllegalStateException("At most repartition 3 times for a partition") + } + val totalSize = batches.map(_.sizeInBytes).sum + val newSeed = seed + 10 + val iter = cbIteratorStealingFromBuffer(batches) + withResource(new GpuBatchSubPartitioner( + iter, hashKeys, computeNumPartitions(totalSize), newSeed, "aggRepartition")) { + partitioner => + closeOnExcept(ListBuffer.empty[AggregatePartition]) { partitions => + preparePartitions(newSeed, partitioner, partitions) + partitions + } } } } + } - val mergedBatch = if (batchesToConcat.length > 1) { - wasBatchMerged = true - concatenateAndMerge(batchesToConcat) - } else { - // Unable to find a neighboring buffer to produce a valid merge in this pass, - // so simply put this buffer back on the queue for other passes. - batchesToConcat.remove(0) + private def preparePartitions( + newSeed: Int, + partitioner: GpuBatchSubPartitioner, + partitions: ListBuffer[AggregatePartition]): Unit = { + (0 until partitioner.partitionsCount).foreach { id => + val buffer = ListBuffer.empty[SpillableColumnarBatch] + buffer ++= partitioner.releaseBatchesByPartition(id) + val newPart = AggregatePartition.apply(buffer, newSeed) + if (newPart.totalRows() > 0) { + partitions += newPart + } else { + newPart.safeClose() + } } + } - // Add the merged batch to the end of the aggregated batch queue. Only a single pass over - // the batches is being performed due to the batch count check above, so the single-pass - // loop will terminate before picking up this new batch. - aggregatedBatches.addLast(mergedBatch) - batchesToConcat.clear() - concatSize = 0 + private[this] def computeNumPartitions(totalSize: Long): Int = { + Math.floorDiv(totalSize, targetMergeBatchSize).toInt + 1 } - wasBatchMerged - } + private val hashSeed = 100 + private val aggPartitions = ListBuffer.empty[AggregatePartition] + private val deferredAggPartitions = ListBuffer.empty[AggregatePartition] + deferredAggPartitions += AggregatePartition.apply(inputBatches, hashSeed) - private lazy val concatAndMergeHelper = - new AggHelper(inputAttributes, groupingExpressions, aggregateExpressions, - forceMerge = true, useTieredProject = useTieredProject) + override def hasNext: Boolean = aggPartitions.nonEmpty || deferredAggPartitions.nonEmpty - /** - * Concatenate batches together and perform a merge aggregation on the result. The input batches - * will be closed as part of this operation. 
- * @param batches batches to concatenate and merge aggregate - * @return lazy spillable batch which has NOT been marked spillable - */ - private def concatenateAndMerge( - batches: mutable.ArrayBuffer[SpillableColumnarBatch]): SpillableColumnarBatch = { - // TODO: concatenateAndMerge (and calling code) could output a sequence - // of batches for the partial aggregate case. This would be done in case - // a retry failed a certain number of times. - val concatBatch = withResource(batches) { _ => - val concatSpillable = concatenateBatches(metrics, batches.toSeq) - withResource(concatSpillable) { _.getColumnarBatch() } + override def next(): ColumnarBatch = { + withResource(new NvtxWithMetrics("RepartitionAggregateIterator.next", + NvtxColor.BLUE, opTime)) { _ => + if (aggPartitions.isEmpty && deferredAggPartitions.nonEmpty) { + val headDeferredPartition = deferredAggPartitions.remove(0) + withResource(headDeferredPartition) { _ => + aggPartitions ++= headDeferredPartition.split() + } + return next() + } + + val headPartition = aggPartitions.remove(0) + if (headPartition.totalSize() > targetMergeBatchSize) { + deferredAggPartitions += headPartition + return next() + } + + withResource(headPartition) { _ => + val batchSizeBeforeMerge = headPartition.batches.size + AggregateUtils.tryMergeAggregatedBatches( + headPartition.batches, isReductionOnly, metrics, + targetMergeBatchSize, concatAndMergeHelper) + if (headPartition.batches.size != 1) { + throw new IllegalStateException( + "Expected a single batch after tryMergeAggregatedBatches, but got " + + s"${headPartition.batches.size} batches. Before merge, there were " + + s"$batchSizeBeforeMerge batches.") + } + headPartition.batches.head.getColumnarBatch() + } + } + } + + override def close(): Unit = { + aggPartitions.foreach(_.safeClose()) + deferredAggPartitions.foreach(_.safeClose()) } - computeAggregateAndClose(metrics, concatBatch, concatAndMergeHelper) } + /** Build an iterator that uses a sort-based approach to merge aggregated batches together. */ - private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = { - logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size()} batches") + private def buildRepartitionFallbackIterator(): Iterator[ColumnarBatch] = { + logInfo(s"Falling back to repartition-based aggregation with " + + s"${aggregatedBatches.size} batches") metrics.numTasksFallBacked += 1 - val aggregatedBatchIter = new Iterator[ColumnarBatch] { - override def hasNext: Boolean = !aggregatedBatches.isEmpty - override def next(): ColumnarBatch = { - withResource(aggregatedBatches.removeFirst()) { spillable => - spillable.getColumnarBatch() - } - } - } + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val aggBufferAttributes = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + + val hashKeys: Seq[GpuExpression] = + GpuBindReferences.bindGpuReferences(groupingAttributes, aggBufferAttributes.toSeq) + + + repartitionIter = Some(RepartitionAggregateIterator( + aggregatedBatches, + hashKeys, + targetMergeBatchSize, + opTime = metrics.opTime, + repartitionTime = metrics.repartitionTime)) + repartitionIter.get + } + + /** Build an iterator that uses a sort-based approach to merge aggregated batches together. 
*/ + private def buildSortFallbackIterator(): Iterator[ColumnarBatch] = { + logInfo(s"Falling back to sort-based aggregation with ${aggregatedBatches.size} batches") + metrics.numTasksFallBacked += 1 + val aggregatedBatchIter = cbIteratorStealingFromBuffer(aggregatedBatches) if (isReductionOnly) { // Normally this should never happen because `tryMergeAggregatedBatches` should have done @@ -1332,7 +1498,8 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( conf.forceSinglePassPartialSortAgg, allowSinglePassAgg, allowNonFullyAggregatedOutput, - conf.skipAggPassReductionRatio) + conf.skipAggPassReductionRatio, + conf.aggFallbackAlgorithm) } } @@ -1420,7 +1587,8 @@ abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: BaseAggrega false, false, false, - 1) + 1, + conf.aggFallbackAlgorithm) } else { super.convertToGpu() } @@ -1773,6 +1941,8 @@ object GpuHashAggregateExecBase { * (can omit non fully aggregated data for non-final * stage of aggregation) * @param skipAggPassReductionRatio skip if the ratio of rows after a pass is bigger than this value + * @param aggFallbackAlgorithm use sort-based fallback or repartition-based fallback for + * oversize agg */ case class GpuHashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -1787,7 +1957,8 @@ case class GpuHashAggregateExec( forceSinglePassAgg: Boolean, allowSinglePassAgg: Boolean, allowNonFullyAggregatedOutput: Boolean, - skipAggPassReductionRatio: Double + skipAggPassReductionRatio: Double, + aggFallbackAlgorithm: String ) extends ShimUnaryExecNode with GpuExec { // lifted directly from `BaseAggregateExec.inputAttributes`, edited comment. @@ -1809,6 +1980,7 @@ case class GpuHashAggregateExec( AGG_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_AGG_TIME), CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME), SORT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_SORT_TIME), + REPARTITION_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_REPARTITION_TIME), "NUM_AGGS" -> createMetric(DEBUG_LEVEL, "num agg operations"), "NUM_PRE_SPLITS" -> createMetric(DEBUG_LEVEL, "num pre splits"), "NUM_TASKS_SINGLE_PASS" -> createMetric(MODERATE_LEVEL, "number of single pass tasks"), @@ -1839,6 +2011,7 @@ case class GpuHashAggregateExec( computeAggTime = gpuLongMetric(AGG_TIME), concatTime = gpuLongMetric(CONCAT_TIME), sortTime = gpuLongMetric(SORT_TIME), + repartitionTime = gpuLongMetric(REPARTITION_TIME), numAggOps = gpuLongMetric("NUM_AGGS"), numPreSplits = gpuLongMetric("NUM_PRE_SPLITS"), singlePassTasks = gpuLongMetric("NUM_TASKS_SINGLE_PASS"), @@ -1873,7 +2046,8 @@ case class GpuHashAggregateExec( boundGroupExprs, aggregateExprs, aggregateAttrs, resultExprs, modeInfo, localEstimatedPreProcessGrowth, alreadySorted, expectedOrdering, postBoundReferences, targetBatchSize, aggMetrics, useTieredProject, - localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio) + localForcePre, localAllowPre, allowNonFullyAggregatedOutput, skipAggPassReductionRatio, + aggFallbackAlgorithm) } } @@ -1991,7 +2165,8 @@ class DynamicGpuPartialSortAggregateIterator( forceSinglePassAgg: Boolean, allowSinglePassAgg: Boolean, allowNonFullyAggregatedOutput: Boolean, - skipAggPassReductionRatio: Double + skipAggPassReductionRatio: Double, + aggFallbackAlgorithm: String ) extends Iterator[ColumnarBatch] { private var aggIter: Option[Iterator[ColumnarBatch]] = None private[this] val isReductionOnly = boundGroupExprs.outputTypes.isEmpty @@ -2092,6 +2267,7 @@ class 
DynamicGpuPartialSortAggregateIterator( useTiered, allowNonFullyAggregatedOutput, skipAggPassReductionRatio, + aggFallbackAlgorithm, localInputRowsMetrics) GpuAggFinalPassIterator.makeIter(mergeIter, postBoundReferences, metrics) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index d83f20113b2..1cbf899c04d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -61,6 +61,7 @@ object GpuMetric extends Logging { val COLLECT_TIME = "collectTime" val CONCAT_TIME = "concatTime" val SORT_TIME = "sortTime" + val REPARTITION_TIME = "repartitionTime" val AGG_TIME = "computeAggTime" val JOIN_TIME = "joinTime" val FILTER_TIME = "filterTime" @@ -95,6 +96,7 @@ object GpuMetric extends Logging { val DESCRIPTION_COLLECT_TIME = "collect batch time" val DESCRIPTION_CONCAT_TIME = "concat batch time" val DESCRIPTION_SORT_TIME = "sort time" + val DESCRIPTION_REPARTITION_TIME = "repartition time spent in agg" val DESCRIPTION_AGG_TIME = "aggregation time" val DESCRIPTION_JOIN_TIME = "join time" val DESCRIPTION_FILTER_TIME = "filter time" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index aad4f05b334..46c2806140e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1517,6 +1517,13 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .checkValue(v => v >= 0 && v <= 1, "The ratio value must be in [0, 1].") .createWithDefault(1.0) + val FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG = conf("spark.rapids.sql.agg.fallbackAlgorithm") + .doc("When agg cannot be done in a single pass, use sort-based fallback or " + + "repartition-based fallback.") + .stringConf + .checkValues(Set("sort", "repartition")) + .createWithDefault("sort") + val FORCE_SINGLE_PASS_PARTIAL_SORT_AGG: ConfEntryWithDefault[Boolean] = conf("spark.rapids.sql.agg.forceSinglePassPartialSort") .doc("Force a single pass partial sort agg to happen in all cases that it could, " + @@ -3079,6 +3086,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val skipAggPassReductionRatio: Double = get(SKIP_AGG_PASS_REDUCTION_RATIO) + lazy val aggFallbackAlgorithm: String = get(FALLBACK_ALGORITHM_FOR_OVERSIZE_AGG) + lazy val isRegExpEnabled: Boolean = get(ENABLE_REGEXP) lazy val maxRegExpStateMemory: Long = { From 10b7d20d2bb4675805b821b3810ab194671c9005 Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Mon, 1 Jul 2024 17:43:59 +0800 Subject: [PATCH 2/9] doc Signed-off-by: Hongbin Ma (Mahone) --- docs/additional-functionality/advanced_configs.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 033e332b99c..7be166ed5de 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -60,6 +60,7 @@ Name | Description | Default Value | Applicable at spark.rapids.shuffle.ucx.activeMessages.forceRndv|Set to true to force 'rndv' mode for all UCX Active Messages. This should only be required with UCX 1.10.x. 
UCX 1.11.x deployments should set to false.|false|Startup spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null|Startup spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true|Startup +spark.rapids.sql.agg.fallbackAlgorithm|When agg cannot be done in a single pass, use sort-based fallback or repartition-based fallback.|sort|Runtime spark.rapids.sql.agg.skipAggPassReductionRatio|In non-final aggregation stages, if the previous pass has a row reduction ratio greater than this value, the next aggregation pass will be skipped.Setting this to 1 essentially disables this feature.|1.0|Runtime spark.rapids.sql.allowMultipleJars|Allow multiple rapids-4-spark, spark-rapids-jni, and cudf jars on the classpath. Spark will take the first one it finds, so the version may not be expected. Possisble values are ALWAYS: allow all jars, SAME_REVISION: only allow jars with the same revision, NEVER: do not allow multiple jars at all.|SAME_REVISION|Startup spark.rapids.sql.castDecimalToFloat.enabled|Casting from decimal to floating point types on the GPU returns results that have tiny difference compared to results returned from CPU.|true|Runtime From 4451c5467392176adf3571645ea9d8546d4e3ee4 Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Mon, 1 Jul 2024 19:05:08 +0800 Subject: [PATCH 3/9] fix scala 2.13 Signed-off-by: Hongbin Ma (Mahone) --- sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index b0cd798c179..96254b9f38d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -43,7 +43,8 @@ object Arm extends ArmScalaSpecificImpl { } /** Executes the provided code block and then closes the sequence of resources */ - def withResource[T <: AutoCloseable, V](r: Seq[T])(block: Seq[T] => V): V = { + def withResource[T <: AutoCloseable, V](r: scala.collection.Seq[T]) + (block: scala.collection.Seq[T] => V): V = { try { block(r) } finally { From 4da57979fc2d8cd3e086cf013656c6910e9c34d2 Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Mon, 1 Jul 2024 20:33:18 +0800 Subject: [PATCH 4/9] fix compile Signed-off-by: Hongbin Ma (Mahone) --- sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala | 4 ++-- .../main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index 96254b9f38d..de75381d1d1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -43,8 +43,8 @@ object Arm extends ArmScalaSpecificImpl { } /** Executes the provided code block and then closes the sequence of resources */ - def withResource[T <: AutoCloseable, V](r: scala.collection.Seq[T]) - (block: scala.collection.Seq[T] => V): V = { + def withResource[T <: AutoCloseable, V](r: Seq[T]) + (block: Seq[T] => V): V = { try { block(r) } finally { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index 
b28101f3442..8553db848c7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -134,7 +134,7 @@ object AggregateUtils extends Logging { * @return lazy spillable batch which has NOT been marked spillable */ private def concatenateAndMerge( - batches: mutable.Buffer[SpillableColumnarBatch], + batches: mutable.ArrayBuffer[SpillableColumnarBatch], metrics: GpuHashAggregateMetrics, concatAndMergeHelper: AggHelper): SpillableColumnarBatch = { // TODO: concatenateAndMerge (and calling code) could output a sequence From e803c36c40856debd14bb212bbae81bd8256ca62 Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Mon, 1 Jul 2024 23:20:29 +0800 Subject: [PATCH 5/9] fix it Signed-off-by: Hongbin Ma (Mahone) --- .../scala/com/nvidia/spark/rapids/GpuAggregateExec.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index 8553db848c7..252b9e8a95b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -223,7 +223,7 @@ object AggregateUtils extends Logging { withResource(new NvtxWithMetrics("agg merge pass", NvtxColor.BLUE, concatTime, opTime)) { _ => // continue merging as long as some batches are able to be combined - if (!mergePass(aggregatedBatches, targetMergeBatchSize, helper, metrics)) + if (!mergePass(aggregatedBatches, targetMergeBatchSize, helper, metrics)) { if (aggregatedBatches.size > 1 && isReductionOnly) { // We were unable to merge the aggregated batches within the target batch size limit, // which means normally we would fallback to a sort-based approach. 
However for @@ -243,7 +243,8 @@ object AggregateUtils extends Logging { aggregatedBatches += batch } } - return + return + } } } } From 0b50434faba9ca526cfbfea560fd2e50058e7bcd Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Tue, 2 Jul 2024 17:05:19 +0800 Subject: [PATCH 6/9] enable it Signed-off-by: Hongbin Ma (Mahone) --- .../src/main/python/hash_aggregate_test.py | 7 +++- .../scala/com/nvidia/spark/rapids/Arm.scala | 3 +- .../spark/rapids/GpuAggregateExec.scala | 32 ++++++++++++------- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index d1cd70aa43c..d749960020e 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -38,6 +38,9 @@ _float_conf_skipagg = copy_and_update(_float_smallbatch_conf, {'spark.rapids.sql.agg.skipAggPassReductionRatio': '0'}) +_float_conf_repartition_fallback = copy_and_update(_float_smallbatch_conf, + {'spark.rapids.sql.agg.fallbackAlgorithm': 'repartition'}) + _float_conf_partial = copy_and_update(_float_conf, {'spark.rapids.sql.hashAgg.replaceMode': 'partial'}) @@ -225,7 +228,9 @@ def get_params(init_list, marked_params=[]): # Run these tests with in 5 modes, all on the GPU -_confs = [_float_conf, _float_smallbatch_conf, _float_conf_skipagg, _float_conf_final, _float_conf_partial] +_confs = [_float_conf, _float_smallbatch_conf, + _float_conf_skipagg, _float_conf_repartition_fallback, + _float_conf_final, _float_conf_partial] # Pytest marker for list of operators allowed to run on the CPU, # esp. useful in partial and final only modes. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala index de75381d1d1..b0cd798c179 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala @@ -43,8 +43,7 @@ object Arm extends ArmScalaSpecificImpl { } /** Executes the provided code block and then closes the sequence of resources */ - def withResource[T <: AutoCloseable, V](r: Seq[T]) - (block: Seq[T] => V): V = { + def withResource[T <: AutoCloseable, V](r: Seq[T])(block: Seq[T] => V): V = { try { block(r) } finally { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index 252b9e8a95b..82b308d3d27 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -119,7 +119,7 @@ object AggregateUtils extends Logging { } // Calculate the max rows that can be processed during computation within the budget - val maxRows = totalBudget / computationBytesPerRow + val maxRows = Math.max(totalBudget / computationBytesPerRow, 1) // Finally compute the input target batching size taking into account the cudf row limits Math.min(inputRowSize * maxRows, Int.MaxValue) @@ -212,7 +212,7 @@ object AggregateUtils extends Logging { */ def tryMergeAggregatedBatches( aggregatedBatches: mutable.Buffer[SpillableColumnarBatch], - isReductionOnly: Boolean, + forceMergeRegardlessOfOversize: Boolean, metrics: GpuHashAggregateMetrics, targetMergeBatchSize: Long, helper: AggHelper @@ -224,14 +224,19 @@ object AggregateUtils extends Logging { opTime)) { _ => // continue merging as long as some batches are able to be 
combined if (!mergePass(aggregatedBatches, targetMergeBatchSize, helper, metrics)) { - if (aggregatedBatches.size > 1 && isReductionOnly) { + if (aggregatedBatches.size > 1 && forceMergeRegardlessOfOversize) { // We were unable to merge the aggregated batches within the target batch size limit, - // which means normally we would fallback to a sort-based approach. However for - // reduction-only aggregation there are no keys to use for a sort. The only way this - // can work is if all batches are merged. This will exceed the target batch size limit, - // but at this point it is either risk an OOM/cudf error and potentially work or - // not work at all. - logWarning(s"Unable to merge reduction-only aggregated batches within " + + // which means normally we would fallback to a sort-based approach. + // There are two exceptions: + // 1. reduction-only aggregation case, there are no keys to use for a sort. + // The only way this can work is if all batches are merged. This will exceed + // the target batch size limit but at this point it is either risk an OOM/cudf error + // and potentially work or not work at all. + // 2. re-partition agg case where all batches are have only 1 row each (Usually + // this only happens in test cases). Doing more re-partitioning will not help to reduce + // the partition size anymore. In this case we should merge all the batches into one + // regardless of the target size. + logWarning(s"Unable to merge aggregated batches within " + s"target batch limit of $targetMergeBatchSize, attempting to merge remaining " + s"${aggregatedBatches.size} batches beyond limit") withResource(mutable.ArrayBuffer[SpillableColumnarBatch]()) { batchesToConcat => @@ -1022,6 +1027,10 @@ class GpuMergeAggregateIterator( def totalSize(): Long = batches.map(_.sizeInBytes).sum + def isAllBatchesSingleRow: Boolean = { + batches.forall(_.numRows() == 1) + } + def split(): ListBuffer[AggregatePartition] = { withResource(new NvtxWithMetrics("agg repartition", NvtxColor.CYAN, repartitionTime)) { _ => if (seed > hashSeed + 20) { @@ -1081,7 +1090,8 @@ class GpuMergeAggregateIterator( } val headPartition = aggPartitions.remove(0) - if (headPartition.totalSize() > targetMergeBatchSize) { + if (!headPartition.isAllBatchesSingleRow && + headPartition.totalSize() > targetMergeBatchSize) { deferredAggPartitions += headPartition return next() } @@ -1089,7 +1099,7 @@ class GpuMergeAggregateIterator( withResource(headPartition) { _ => val batchSizeBeforeMerge = headPartition.batches.size AggregateUtils.tryMergeAggregatedBatches( - headPartition.batches, isReductionOnly, metrics, + headPartition.batches, isReductionOnly || headPartition.isAllBatchesSingleRow, metrics, targetMergeBatchSize, concatAndMergeHelper) if (headPartition.batches.size != 1) { throw new IllegalStateException( From a000c9bbd4376aa7f68a202c0013c618f2c1258c Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Tue, 2 Jul 2024 17:07:05 +0800 Subject: [PATCH 7/9] metric name Signed-off-by: Hongbin Ma (Mahone) --- sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index 1cbf899c04d..d58d6c11036 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -96,7 +96,7 @@ object GpuMetric extends Logging { val DESCRIPTION_COLLECT_TIME = "collect 
batch time" val DESCRIPTION_CONCAT_TIME = "concat batch time" val DESCRIPTION_SORT_TIME = "sort time" - val DESCRIPTION_REPARTITION_TIME = "repartition time spent in agg" + val DESCRIPTION_REPARTITION_TIME = "repartition time" val DESCRIPTION_AGG_TIME = "aggregation time" val DESCRIPTION_JOIN_TIME = "join time" val DESCRIPTION_FILTER_TIME = "filter time" From 82cacbf254e2f04e1e803e7b501b882990ff2c27 Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Tue, 2 Jul 2024 17:26:46 +0800 Subject: [PATCH 8/9] minor Signed-off-by: Hongbin Ma (Mahone) --- .../spark/rapids/GpuAggregateExec.scala | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index 82b308d3d27..d1c7f3d59c5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -1033,8 +1033,9 @@ class GpuMergeAggregateIterator( def split(): ListBuffer[AggregatePartition] = { withResource(new NvtxWithMetrics("agg repartition", NvtxColor.CYAN, repartitionTime)) { _ => - if (seed > hashSeed + 20) { - throw new IllegalStateException("At most repartition 3 times for a partition") + if (seed >= hashSeed + 100) { + throw new IllegalStateException("repartitioned too many times, please " + + "consider disabling repartition-based fallback for aggregation") } val totalSize = batches.map(_.sizeInBytes).sum val newSeed = seed + 10 @@ -1049,26 +1050,26 @@ class GpuMergeAggregateIterator( } } } - } - private def preparePartitions( - newSeed: Int, - partitioner: GpuBatchSubPartitioner, - partitions: ListBuffer[AggregatePartition]): Unit = { - (0 until partitioner.partitionsCount).foreach { id => - val buffer = ListBuffer.empty[SpillableColumnarBatch] - buffer ++= partitioner.releaseBatchesByPartition(id) - val newPart = AggregatePartition.apply(buffer, newSeed) - if (newPart.totalRows() > 0) { - partitions += newPart - } else { - newPart.safeClose() + private def preparePartitions( + newSeed: Int, + partitioner: GpuBatchSubPartitioner, + partitions: ListBuffer[AggregatePartition]): Unit = { + (0 until partitioner.partitionsCount).foreach { id => + val buffer = ListBuffer.empty[SpillableColumnarBatch] + buffer ++= partitioner.releaseBatchesByPartition(id) + val newPart = AggregatePartition.apply(buffer, newSeed) + if (newPart.totalRows() > 0) { + partitions += newPart + } else { + newPart.safeClose() + } } } - } - private[this] def computeNumPartitions(totalSize: Long): Int = { - Math.floorDiv(totalSize, targetMergeBatchSize).toInt + 1 + private[this] def computeNumPartitions(totalSize: Long): Int = { + Math.floorDiv(totalSize, targetMergeBatchSize).toInt + 1 + } } private val hashSeed = 100 From 4cf4a4566008321f6bc9f600365563daa11614cf Mon Sep 17 00:00:00 2001 From: "Hongbin Ma (Mahone)" Date: Tue, 2 Jul 2024 18:35:27 +0800 Subject: [PATCH 9/9] change seed Signed-off-by: Hongbin Ma (Mahone) --- .../main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index d1c7f3d59c5..1693c3203d7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ 
-1038,7 +1038,7 @@ class GpuMergeAggregateIterator( "consider disabling repartition-based fallback for aggregation") } val totalSize = batches.map(_.sizeInBytes).sum - val newSeed = seed + 10 + val newSeed = seed + 7 val iter = cbIteratorStealingFromBuffer(batches) withResource(new GpuBatchSubPartitioner( iter, hashKeys, computeNumPartitions(totalSize), newSeed, "aggRepartition")) {
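---

For reference, below is a minimal sketch of how the configuration introduced in this series could be exercised from user code. The config key `spark.rapids.sql.agg.fallbackAlgorithm` and its accepted values (`sort`, the default, and `repartition`) come from the patches above; everything else (the session setup, plugin class, and query) is illustrative only and assumes the rapids-4-spark jar is already on the classpath. Note that the fallback path, whichever algorithm is chosen, only engages when the merged aggregate batches cannot fit within the configured target batch size.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch, not part of the patch series: opt into the
// repartition-based fallback for oversize hash aggregations.
object AggFallbackAlgorithmExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("agg-repartition-fallback-example")
      // Assumes the RAPIDS Accelerator plugin jar is on the classpath.
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      // Select the repartition-based fallback added in this series instead of
      // the default sort-based fallback ("sort").
      .config("spark.rapids.sql.agg.fallbackAlgorithm", "repartition")
      .getOrCreate()

    // A high-cardinality group-by; the fallback only triggers when the merged
    // aggregate batches exceed the target batch size at runtime.
    val result = spark.range(0L, 10000000L)
      .selectExpr("id % 1000000 AS k", "id AS v")
      .groupBy("k")
      .sum("v")

    println(result.count())
    spark.stop()
  }
}
```

The same setting can also be supplied on the command line (for example `--conf spark.rapids.sql.agg.fallbackAlgorithm=repartition`), mirroring how the integration tests in patch 6 enable it via `_float_conf_repartition_fallback`.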