
[FEA] Explore how to best deal with large numbers of aggregations in the short term #8141

Closed
revans2 opened this issue Apr 18, 2023 · 9 comments
Labels
performance (A performance related task/issue), reliability (Features to improve reliability or bugs that severely impact the reliability of the plugin)

Comments

@revans2
Collaborator

revans2 commented Apr 18, 2023

Is your feature request related to a problem? Please describe.
The current aggregation code can run into a lot of performance issues when:

  1. The output is large
  2. There are a huge number of aggregations being done.

Currently the algorithm is to read in each input batch and aggregate it to an intermediate result. After that we try to merge those batches together into a single output batch. If we are unable to merge those batches into one that is smaller than the target batch size we fall back to doing a sort based aggregation (which is really a sort based merge aggregation). We sort all of the intermediate batches by the grouping keys. Then we split the batches on key boundaries and finally do a final merge pass of each of those batches and output them.
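In sketch form, that flow looks something like this (a minimal, runnable Python sketch using plain dicts of {group_key: running_sum} in place of GPU batches; the helper names are illustrative and not the plugin's actual code):

# partial_agg stands in for the per-batch hash aggregation, merge for the
# merge pass, and target_groups for the target output batch size.
def partial_agg(rows):
    """Hash-aggregate one input batch of (key, value) rows."""
    out = {}
    for key, value in rows:
        out[key] = out.get(key, 0) + value
    return out

def merge(parts):
    """Merge several partially aggregated batches into one."""
    out = {}
    for part in parts:
        for key, value in part.items():
            out[key] = out.get(key, 0) + value
    return out

def aggregate(input_batches, target_groups):
    # 1. Aggregate each input batch to an intermediate result.
    intermediate = [partial_agg(batch) for batch in input_batches]
    # 2. Try to merge the intermediate batches into a single output batch.
    merged = merge(intermediate)
    if len(merged) <= target_groups:
        return [merged]
    # 3. Sort fallback: order the grouping keys, split on key boundaries, and
    #    do a final merge pass over each key range (the real sort is not
    #    stable, which is why first/last are a concern).
    keys = sorted({k for part in intermediate for k in part})
    out = []
    for i in range(0, len(keys), target_groups):
        key_range = set(keys[i:i + target_groups])
        out.append(merge([{k: v for k, v in part.items() if k in key_range}
                          for part in intermediate]))
    return out

For example, aggregate([[("a", 1), ("b", 2)], [("a", 3)]], target_groups=1) produces [{'a': 4}, {'b': 2}], one output batch per key range.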

Technically this might mess with some aggregations, like first and last, where we need to maintain the input order, because the sort is not stable.

Within each aggregation step there is a pre-process step, the actual aggregation, and a post-process step. The pre-process step can take a lot of memory because in some cases, like with decimal128, we have to expand the output into something that is larger than the input to have enough space to detect overflow conditions.
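As a rough illustration of why the intermediate data gets wider than the input (this only shows plain Spark decimal widening; the plugin's overflow-detection layout for decimal128 needs even more working space on top of this):

from decimal import Decimal
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum as sql_sum
from pyspark.sql.types import DecimalType, StructField, StructType

spark = SparkSession.builder.appName("decimal-widening").getOrCreate()
schema = StructType([StructField("a", DecimalType(20, 2))])
df = spark.createDataFrame([(Decimal("1.00"),)], schema)

# Spark widens SUM over decimal(p, s) to decimal(min(38, p + 10), s), so the
# aggregation buffer is already bigger than the input column before the
# plugin adds its own overflow-detection columns.
df.agg(sql_sum("a")).printSchema()   # sum(a): decimal(30,2)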

When we have hundreds of aggregations to do, what is the fastest way to do them? A hash aggregation is likely very fast, but building a hash table with hundreds of intermediate results in it is going to cause memory pressure problems and potentially performance problems as well.

We should explore if a sort based aggregation ever wins when there are large numbers of aggregations being done. If not, then we should look at removing the sort and instead doing a sub-partitioning algorithm similar to what we have done in join recently. If it does, then we should look at finding a good way to do a stable sort for first/last, and at detecting when the input is already sorted so we can tell CUDF this and avoid needing to cache the data before processing it.

In the long term we might need to work with CUDF to find better ways of optimizing for this use case.

@revans2 added the feature request, ? - Needs Triage, performance, and reliability labels on Apr 18, 2023
@sameerz removed the feature request label on Apr 18, 2023
@mattahrens added the feature request label and removed the ? - Needs Triage label on Apr 18, 2023
@jbrennan333
Collaborator

For an initial test of this, I created a SQL query against the store_sales table from NDS. I renamed the decimal columns to a, b, c, ... to simplify things, and then built the query in pyspark like this:

import itertools

# num_funcs, num_mul_cols, and sort_col are parameters set earlier in the script
sqlString = f"SELECT {sort_col}"
vals = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm']
perms = itertools.islice(itertools.permutations(vals, num_mul_cols), num_funcs)
for elem in perms:
    el = list(elem)
    tag = ''.join(el)
    sqlString += ",\n"
    mul = ' * '.join(el)
    sqlString += f'SUM({mul}) AS sum_{tag}'

sqlString += "\nFROM data\n"
sqlString += f"GROUP BY {sort_col}"

For example, a query with num_funcs=10 and sort_col=ss_store_sk looks like this:

SELECT ss_store_sk,
    SUM(a * b * c * d) AS sum_abcd,
    SUM(a * b * c * e) AS sum_abce,
    SUM(a * b * c * f) AS sum_abcf,
    SUM(a * b * c * g) AS sum_abcg,
    SUM(a * b * c * h) AS sum_abch,
    SUM(a * b * c * i) AS sum_abci,
    SUM(a * b * c * j) AS sum_abcj,
    SUM(a * b * c * k) AS sum_abck,
    SUM(a * b * c * l) AS sum_abcl,
    SUM(a * b * c * m) AS sum_abcm
FROM data
GROUP BY ss_store_sk

The store_sales table is partitioned on ss_sold_date_sk, so for presorted runs I used sort_col=ss_sold_date_sk and for unsorted runs I used sort_col=ss_store_sk.

I then ran this locally on an RTX4000 at scale 100 and on an 8-node A100 cluster at scale 3000. To compare the sort fallback merge vs the default hashed merge, I ran it with the original (current) code, and with the call to tryMergeAggregatedBatches() commented out so we always use the sort fallback. Note that to get this to run without OOMs, I had to set spark.rapids.sql.batchSizeBytes=32mb on the 8-node cluster, and 1mb on my local desktop.

Here are the results for runs on the 8-node a100 cluster:

| Number of SUMs | orig-sorted | fallback-sorted | orig-unsorted | fallback-unsorted |
| --- | --- | --- | --- | --- |
| 10 | 47.71 | 52.81 | 48.73 | 53.45 |
| 20 | 74.05 | 83.41 | 75.83 | 84.10 |
| 40 | 128.20 | 145.87 | 137.41 | 174.68 |
| 80 | 246.62 | 272.69 | 246.12 | 344.88 |
| 160 | 492.80 | 531.85 | 535.28 | 697.74 |
| 320 | 977.15 | 1,117.47 | 1,008.52 | 1,373.64 |

(chart: big-agg timing results)

@revans2, what do you recommend for next steps?

> We should explore if a sort based aggregation ever wins when there are large numbers of aggregations being done. If not, then we should look at removing the sort and instead doing a sub-partitioning algorithm similar to what we have done in join recently.

Should I start working on a sub-partitioning algorithm, or are there additional tests or other optimizations worth trying first?

@revans2
Collaborator Author

revans2 commented May 12, 2023

A few things that pop out to me.

The first one is the fact that you had to set the batch size in bytes so small to even make this work. I am guessing that this has to do with the project that happens before we do the aggregation. We need small input data because the output data is going to be some number of times larger than the input data.

The second thing is the key being used for the aggregations. The cardinalities of ss_store_sk and ss_sold_date_sk don't appear to scale at the same rate as each other. At the smaller scale factors the cardinalities are not even close to each other. This can have a big impact on performance and memory usage, not just the time it takes to do the aggregation. I would like to see us use the same key in all tests, unless we are explicitly changing to a different key to see the impact of cardinality. Also I would like the key to not be one that we have partitioned by or sorted by previously.
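For reference, the two cardinalities are easy to compare up front (a quick pyspark check, assuming df is the store_sales DataFrame from the benchmark script):

from pyspark.sql.functions import countDistinct

# Compare the number of distinct values for the candidate grouping keys.
df.select(
    countDistinct("ss_store_sk").alias("distinct_stores"),
    countDistinct("ss_sold_date_sk").alias("distinct_sold_dates"),
).show()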

I have a couple ideas in my head about how we might be able to process the data to reduce memory usage while doing these really big aggregates. I would love it if we could come up with some experiments that might point us in the right direction before having to write 10 different changes and see which one is best, but I am not 100% sure how to do that.

The current code does the following.

  1. pre-process (I think this is where we are hitting memory issues right now)
  2. initial-agg (output of each input batch is saved away)
  3. check for sort fallback (if the output of the initial-agg pass is larger than a single target batch, then we sort the initial-agg data before merging)
  4. merge-pass
  5. post-processing

I don't think I have described the steps perfectly, but hopefully it is good enough.

First off, the pre-process step is likely to make the output batch size much larger than the input batch size. This is because things like average require a SUM and a COUNT, decimal SUMs need multiple output columns to detect overflow, and I might be doing lots of aggregates on a single column, i.e. SUM(a), AVG(a), COUNT(a), MIN(a), MAX(a), ... in addition to a regular project that might increase the number of columns.

So to deal with this we need to either be more efficient with memory or partition the data. Being more efficient with the memory is good, but it cannot solve the problem 100%. So we have to look at splitting the data in some way.

Our data is two dimensional (columns and rows) so we can split it two ways. We can keep all of the columns and split the number of rows in a batch. This is a good generic solution for project everywhere, so I think it is something we should look at first. Really it comes down to estimating the output size of the project based on the output schema. If all of the data is fixed-width types then we are done; if some are variable width we can estimate their sizes. The problem with doing a row-wise split is that at some point we can have so many columns that we end up processing only a row or a few at a time, which is going to be horrible for performance.
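Back-of-the-envelope, the row-wise split would be sized from the estimated bytes per output row (a hypothetical sketch; the widths below are illustrative, not what the plugin actually computes):

# Estimate how many input rows we can project at a time so the output stays
# under a target size. Column widths are per-row byte estimates: exact for
# fixed-width types, estimated for variable-width ones.
def rows_per_slice(output_column_widths, target_bytes):
    bytes_per_row = sum(output_column_widths)
    return max(1, target_bytes // bytes_per_row)

# e.g. 320 decimal128 aggregation buffers (16 bytes each) with a 512 MiB target
print(rows_per_slice([16] * 320, 512 * 1024 * 1024))   # ~104,857 rows per slice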

We could also split the data by aggregations, i.e. do half of the aggregations in one pass, and then half of the aggregations in another pass. The problem with this is twofold. First, we might have to keep all of the input columns around while we process half of the aggregations, and second, we have to have a way to line up the rows after the aggregations. This is hard with a hash-aggregate because it can decide to reorder rows. So to make this work we would have to sort the data ourselves and then tell CUDF to use a sort based aggregation instead.
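A sketch of what splitting by aggregations could look like (hypothetical code; plain Python lists stand in for columns, and the up-front sort is what keeps the per-pass results aligned):

from itertools import groupby
from operator import itemgetter

def agg_in_passes(rows, aggs, passes=2):
    """rows: (key, value) tuples; aggs: functions over a list of values."""
    rows = sorted(rows, key=itemgetter(0))   # sort by the grouping key first
    chunk = (len(aggs) + passes - 1) // passes
    results = None
    for i in range(0, len(aggs), chunk):
        # Each pass computes a slice of the aggregations over the same sorted
        # data, so every pass emits groups in the same order and the partial
        # result columns can simply be laid side by side.
        pass_result = [[agg([v for _, v in grp]) for agg in aggs[i:i + chunk]]
                       for _, grp in groupby(rows, key=itemgetter(0))]
        results = pass_result if results is None else [
            left + right for left, right in zip(results, pass_result)]
    return results

For example, agg_in_passes([("a", 1), ("a", 2), ("b", 3)], [sum, min, max, len], passes=2) returns [[3, 1, 2, 2], [3, 3, 3, 1]].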

The second part of the memory issue is the aggregations themselves (pre and merge). We know that sort is slower than hash partitioning, so falling back to a hash re-partition instead of a sort would be interesting. But it is potentially a lot of work, and if we ever want to do a column-wise split it is not going to help, because we would have to sort the data anyway.

So the main questions I have are really about how many aggregations we have to run before we start to run out of memory and have to start thinking about splitting the data by aggregations instead of by rows. First, with a batch size of 512 MiB, at what point does the current code start to run out of memory (16-GiB GPU and/or larger GPU)? Second, for the 300 aggregations case, how large are the output batches when we tune the input batches to not crash (number of rows, columns, and MiB)?

That should hopefully tell us if we have to support splitting the aggregations for all likely cases or not.

@jbrennan333
Collaborator

Thanks @revans! This helps a lot.

> The first one is the fact that you had to set the batch size in bytes so small to even make this work

Yes. I was initially only running with >300 functions, so I tuned it down to avoid the OOMs, but I was concerned that these batch sizes might be too small for a meaningful test.

> The second thing is the key being used for the aggregations. The cardinalities of ss_store_sk and ss_sold_date_sk don't appear to scale at the same rate as each other. At the smaller scale factors the cardinalities are not even close to each other. This can have a big impact on performance and memory usage, not just the time it takes to do the aggregation. I would like to see us use the same key in all tests, unless we are explicitly changing to a different key to see the impact of cardinality. Also I would like the key to not be one that we have partitioned by or sorted by previously.

I was actually thinking of this as two separate tests, one with GROUP BY ss_sold_date_sk, which matches the parquet partitioning, and one with GROUP BY ss_store_sk. I didn't actually sort the data in either case, so I should not have used the terms sorted/unsorted. And I agree that I should not have put them all in one chart, because it doesn't make sense to compare across the two separate keys. Because the fallback sort was not significantly worse in the case where the GROUP BY key matches the parquet partitioning, I wanted to try it with a different key to see if there was more of an impact. I will focus on using just one key for future testing.

> So the main questions I have are really about how many aggregations we have to run before we start to run out of memory and have to start thinking about splitting the data by aggregations instead of by rows. First, with a batch size of 512 MiB, at what point does the current code start to run out of memory (16-GiB GPU and/or larger GPU)? Second, for the 300 aggregations case, how large are the output batches when we tune the input batches to not crash (number of rows, columns, and MiB)?

I will work on answering these questions. Thanks!

@jbrennan333
Collaborator

I did some testing on the 8-node a100 cluster to determine when we start seeing OOMs with a fixed configuration.
I started it with:

export SPARK_MASTER_URL=spark://$SPARK_MASTER_HOSTNAME:7077

export SPARK_RAPIDS_PLUGIN_DIR=/opt/sparkRapidsPlugin
export SPARK_RAPIDS_PLUGIN_JAR=$SPARK_RAPIDS_PLUGIN_DIR/rapids-4-spark_2.12-23.06.0-SNAPSHOT-cuda11-jimb-orig.jar
export TZ=utc
NUM_FUNCS=${1:-10}
${SPARK_HOME}/bin/spark-submit\
 --master ${SPARK_MASTER_URL}\
 --conf spark.locality.wait=0 \
 --conf spark.plugins=com.nvidia.spark.SQLPlugin \
 --conf spark.sql.adaptive.enabled=true \
 --conf spark.sql.files.maxPartitionBytes=2gb \
 --conf spark.driver.maxResultSize=2GB \
 --conf spark.driver.memory=50G \
 --conf spark.driver.extraClassPath=$SPARK_RAPIDS_PLUGIN_JAR \
 --conf spark.executor.instances=8 \
 --conf spark.executor.cores=16 \
 --conf spark.executor.memory=16G \
 --conf spark.executor.resource.gpu.amount=1 \
 --conf spark.executor.extraClassPath=$SPARK_RAPIDS_PLUGIN_JAR \
 --conf spark.task.resource.gpu.amount=0.0625 \
 --conf spark.rapids.memory.host.spillStorageSize=32G \
 --conf spark.rapids.memory.pinnedPool.size=8G \
 --conf spark.rapids.sql.batchSizeBytes=512mb\
 --conf spark.rapids.sql.concurrentGpuTasks=4 \
 --conf spark.rapids.sql.incompatibleOps.enabled=true \
 --conf spark.rapids.sql.explain=ALL\
 --conf spark.rapids.sql.metrics.level=DEBUG\
 --conf spark.shuffle.manager=com.nvidia.spark.rapids.spark321.RapidsShuffleManager\
 --conf spark.rapids.shuffle.multiThreaded.writer.threads=32\
 --conf spark.rapids.shuffle.multiThreaded.reader.threads=32\
 --conf spark.rapids.shuffle.mode=MULTITHREADED\
 --jars $SPARK_RAPIDS_PLUGIN_JAR\
 big-agg.py ${NUM_FUNCS}

I was varying NUM_FUNCS, which defines how many aggregation functions are included (see the script below).
Each agg function is of the form SUM(a * b * c * d), where a different permutation of columns is used for each.

I started hitting OOM errors with NUM_FUNCS=103.

Output from the previous run shows:

output-102.txt:JTB: funcs: 102 multiply-columns: 4 sort-column: ss_store_sk time: 625.6179025173187

On my desktop, with an RTX4000 with 16GB, I started hitting OOMs at 52 funcs.

The next step is to add some debug logging to capture more information about the input/output batch sizes.

Here is the full script I was using for these tests:

import sys
import time
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum
import itertools

# Number of SUM(...) aggregation functions to generate; optional command-line override.
num_funcs = 10
if len(sys.argv) == 2:
    num_funcs = int(sys.argv[1])
num_mul_cols = 4          # columns multiplied together inside each SUM
sort_col = 'ss_store_sk'  # grouping key
#sort_col = 'ss_sold_date_sk'
spark = SparkSession.builder.appName('big-agg').getOrCreate()

#df = spark.read.parquet("/opt/data/nds2/parquet_sf100/store_sales")
df = spark.read.parquet("hdfs:///data/nds2.0/parquet_sf3k_decimal/store_sales")
df.printSchema()
newNames = [
	'ss_sold_time_sk', 'ss_item_sk', 'ss_customer_sk', 'ss_cdemo_sk',
	'ss_hdemo_sk', 'ss_addr_sk', 'ss_store_sk', 'ss_promo_sk', 'ss_ticket_number',
	'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
	'ss_sold_date_sk',
]

newDf = df.toDF(*newNames)
newDf.printSchema()
newDf.createOrReplaceTempView("data")

sqlString = f"SELECT {sort_col}"
vals = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm' ]
perms = itertools.islice(itertools.permutations(vals, num_mul_cols), num_funcs)
for elem in perms:
	el = list(elem)
	tag = ''.join(el)
	sqlString += ",\n"
	mul = ' * '.join(el)
	sqlString += f'SUM({mul}) AS sum_{tag}'

sqlString += "\nFROM data\n"
sqlString += f"GROUP BY {sort_col}"
aggDf = spark.sql(sqlString)
start_time = time.time()
aggDf.write.mode('overwrite').parquet("/data/jimb/big-agg-output")
elapsed = time.time() - start_time
print(f"JTB: funcs: {num_funcs} multiply-columns: {num_mul_cols} sort-column: {sort_col} time: {elapsed}\n")

@revans2
Collaborator Author

revans2 commented May 24, 2023

103 aggregations at 40 GiB feels really low. I decided to just get an idea of how large the output is compared to the input batch, but I messed up when setting the batch size and instead set spark.sql.files.maxPartitionBytes to 512 MiB. I think that this shows the possibility of doing a split on the preprocess project to try and get a target output size.

With a partition size of 512 MiB I was able to get 160 aggregations to succeed on a 48 GiB a6000 with a concurrency of 4 and a batch size of 1 GiB, but the output of the pre-process step was up to 17.2 GiB. Technically that is larger than what our estimates indicate would work on a 48 GiB GPU; we would need about 70 GiB for the aggregation step to pass with our 4x input size estimate.

But it also let me get a good estimate on what the maxPartitionBytes would need to be for us to get 1 GiB of output after doing the preprocess stage. When I set it to 30 MiB I was able to successfully test it at 512 aggregations, and I estimate that I could have done 1600 aggregations before we ran out of memory again. That would have equated to over 512 aggregations on a 16 GiB GPU. That ended up splitting the input data into 1,000,000 row batches, because that is what the GPU uses as the default row group size when writing out parquet. I am sure we could have gone even smaller to get even more aggregations. I think that this shows that we have a good first step with just splitting by rows, and implementing a hash based fallback for aggregations instead of a sort based fallback for aggregations.
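As a sketch, the hash based fallback would bucket the intermediate results by a hash of the grouping key instead of sorting them, so each bucket owns a disjoint set of keys and can be merge-aggregated on its own (hypothetical code; dicts of {group_key: partial_sum} stand in for GPU batches):

def hash_repartition_merge(intermediate_batches, num_buckets):
    buckets = [[] for _ in range(num_buckets)]
    for part in intermediate_batches:
        split = [{} for _ in range(num_buckets)]
        for key, value in part.items():
            split[hash(key) % num_buckets][key] = value
        for bucket, piece in zip(buckets, split):
            bucket.append(piece)
    # Every partial result for a given key lands in the same bucket, so one
    # merge pass per bucket finishes the aggregation without a global sort.
    merged = []
    for pieces in buckets:
        out = {}
        for piece in pieces:
            for key, value in piece.items():
                out[key] = out.get(key, 0) + value
        merged.append(out)
    return merged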

That said, we should also file a follow-on issue to really dig into what the fastest way to do these aggregations would be. Is it ever faster to sort the data and do the aggregations a few at a time? If so, how many should we do at a time? How many aggregations make it worth taking the sorting hit, etc.?

On the happy side, I did test the performance against my desktop CPU, and even with spilling etc. we are 17.2x faster at 512 aggregations vs a 12-core CPU with 80 GiB of heap. At 10 aggregations we were about 17-18x faster too, so it is good to see that we at least scale similarly to the CPU.

@mattahrens
Collaborator

@revans2 to file follow-on experiment issues.

@revans2
Collaborator Author

revans2 commented May 24, 2023

I filed #8382 as one follow-on issue to this. There will be more, but it might take a while to get to filing them.

@revans2
Collaborator Author

revans2 commented May 24, 2023

I filed #8390 and #8391 as two other follow-on issues. I still need to file at least one more follow-on issue, but that is to explore splitting the data by column instead of by row, just to see which is faster.

@revans2
Collaborator Author

revans2 commented May 25, 2023

The final issue I filed is #8398, to explore sorting and splitting the data by aggregations instead of by row.

@revans2 closed this as completed on May 25, 2023
@sameerz removed the feature request label on May 31, 2023