Micro optimization for softmax_forward_kernel5 #762

Open
wants to merge 6 commits into base: master

Conversation

@insop commented Sep 20, 2024

This branch includes a micro-optimization for softmax_forward_kernel5.

Summary

  • use warpReduceMax in attention_forward.cu with __shfl_down_sync, to be consistent with the other kernels (reduce across all threads in a warp); a minimal sketch of this reduction pattern appears after this list

  • micro optimization for softmax_forward_kernel5

    • Result from ncu ./profile_gpt2cu: compared to the original code, this optimization shows the following improvements (left: original code, right: modified code):

      • Duration: 1.47 ms -> 1.38 ms
      • Compute (SM) [%]: 77.11% -> 78.68%
      • DRAM Throughput [%]: 45.03% -> 47.91%
  • tests done:

    • ./profile_gpt2cu
    • ./attention_forward 4
    • ./attention_forward 5
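
For reference, here is a minimal sketch of the warp-level max reduction pattern referenced in the first bullet, assuming a full 32-thread warp. This is illustrative only, not the exact diff in this PR:

__device__ float warpReduceMax(float val) {
    // tree reduction within the warp; after the loop, lane 0 holds the maximum
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}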

Output from modified code

  • NCU log using A30
make profile_gpt2cu NO_MULTI_GPU=1
ncu ./profile_gpt2cu

  softmax_forward_kernel5(__nv_bfloat16 *, float, const __nv_bfloat16 *, int, int), 2024-Sep-20 01:45:01, Context 1, Stream 16
    Section: GPU Speed Of Light Throughput
    ---------------------------------------------------------------------- --------------- ------------------------------
    DRAM Frequency                                                           cycle/nsecond                           1.21
    SM Frequency                                                             cycle/usecond                         929.76
    Elapsed Cycles                                                                   cycle                        1283575
    Memory [%]                                                                           %                          54.15
    DRAM Throughput                                                                      %                          47.91
    Duration                                                                       msecond                           1.38
    L1/TEX Cache Throughput                                                              %                          54.50
    L2 Cache Throughput                                                                  %                          51.48
    SM Active Cycles                                                                 cycle                     1275362.68
    Compute (SM) [%]                                                                     %                          78.68
    ---------------------------------------------------------------------- --------------- ------------------------------
                                

Output from the original code

  • NCU log using A30
make profile_gpt2cu NO_MULTI_GPU=1
ncu ./profile_gpt2cu

  softmax_forward_kernel5(__nv_bfloat16 *, float, const __nv_bfloat16 *, int, int), 2024-Sep-20 01:49:03, Context 1, Stream 16
    Section: GPU Speed Of Light Throughput
    ---------------------------------------------------------------------- --------------- ------------------------------
    DRAM Frequency                                                           cycle/nsecond                           1.21
    SM Frequency                                                             cycle/usecond                         928.26
    Elapsed Cycles                                                                   cycle                        1366538
    Memory [%]                                                                           %                          45.03
    DRAM Throughput                                                                      %                          45.03
    Duration                                                                       msecond                           1.47
    L1/TEX Cache Throughput                                                              %                          33.10
    L2 Cache Throughput                                                                  %                          48.18
    SM Active Cycles                                                                 cycle                     1358789.59
    Compute (SM) [%]                                                                     %                          77.11
    ---------------------------------------------------------------------- --------------- ------------------------------
                                    

Output from ./attention_forward

nvcc -O3 --use_fast_math -lcublas -lcublasLt attention_forward.cu -o attention_forward
  • testing softmax_forward_kernel4
# ./attention_forward 4
enable_tf32: 1
Using kernel 4
Checking block size 32.
-0.529510 -0.529297
0.889394 0.889160
0.881674 0.881836
0.651789 0.651855
-0.483486 -0.483398
1.000000 1.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
Checking block size 64.
-0.529510 -0.529297
0.889394 0.889160
0.881674 0.881836
0.651789 0.651855
-0.483486 -0.483398
1.000000 1.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
Checking block size 128.
-0.529510 -0.529297
0.889394 0.889160
0.881674 0.881836
0.651789 0.651855
-0.483486 -0.483398
1.000000 1.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
Checking block size 256.
-0.529510 -0.529297
0.889394 0.889160
0.881674 0.881836
0.651789 0.651855
-0.483486 -0.483398
1.000000 1.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
Checking block size 512.
-0.529510 -0.529297
0.889394 0.889160
0.881674 0.881836
0.651789 0.651855
-0.483486 -0.483398
1.000000 1.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
0.000000 0.000000
All results match. Starting benchmarks.

block_size   32 | time 2.794404 ms
block_size   64 | time 2.136679 ms
block_size  128 | time 2.125906 ms
block_size  256 | time 2.128598 ms
block_size  512 | time 2.151445 ms

  • testing softmax_forward_kernel5

# ./attention_forward 5
enable_tf32: 1
Using kernel 5
Checking block size 32.
-0.529510 -0.531250
0.889394 0.890625
0.881674 0.882812
0.651789 0.652344
-0.483486 -0.484375
Checking block size 64.
-0.529510 -0.531250
0.889394 0.890625
0.881674 0.882812
0.651789 0.652344
-0.483486 -0.484375
Checking block size 128.
-0.529510 -0.531250
0.889394 0.890625
0.881674 0.882812
0.651789 0.652344
-0.483486 -0.484375
Checking block size 256.
-0.529510 -0.531250
0.889394 0.890625
0.881674 0.882812
0.651789 0.652344
-0.483486 -0.484375
Checking block size 512.
-0.529510 -0.531250
0.889394 0.890625
0.881674 0.882812
0.651789 0.652344
-0.483486 -0.484375
All results match. Starting benchmarks.

block_size   32 | time 2.016379 ms
block_size   64 | time 1.455155 ms
block_size  128 | time 1.452482 ms
block_size  256 | time 1.450271 ms
block_size  512 | time 1.454224 ms

insop song and others added 4 commits September 16, 2024 16:00
- Micro optimize softmax_forward5
- use __shfl_xor_sync for warpReduceMax so all threads return the max
@insop (Author) commented Sep 22, 2024

@gordicaleksa, @ngc92, @ademeure,

It would be great if you could take a look at this PR when you get a chance.
Thank you.

@ngc92 (Contributor) commented Sep 23, 2024

Could you give a bit more detail about these changes? From a quick look, it seems like you changed a block-wise reduction into just a warp-level reduction. Is that correct?

@insop (Author) commented Sep 24, 2024

could you give a bit more detail about these changes? From a quick look, it seems like you changed a block-wise reduction into just warp-level reduction. Is that correct?

Hi @ngc92

When I profiled softmax_forward_kernel5(), I found that the last part of the function, which scales the row and writes it back to the out buffer, consumes a large portion of the kernel's runtime.

So I looked more closely at that part and found that organizing the memory writes as groups of 4 floats improves memory throughput thanks to better coalesced access.
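
Roughly, the idea looks like the following. This is a hedged sketch only, assuming a float output buffer, 16-byte alignment, and a row length T that is a multiple of 4; write_row_vec4, norm, and lane_id are illustrative names, not the actual code in this PR:

__device__ void write_row_vec4(float* out, const float* inp, float norm, int T, int lane_id) {
    // each thread stores 4 consecutive floats per iteration -> wider, coalesced transactions
    for (int i = lane_id * 4; i < T; i += 32 * 4) {
        float4 v = *reinterpret_cast<const float4*>(inp + i);
        v.x *= norm; v.y *= norm; v.z *= norm; v.w *= norm;
        *reinterpret_cast<float4*>(out + i) = v;  // one 16-byte store instead of four 4-byte stores
    }
}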

@ngc92 (Contributor) commented Sep 24, 2024

@insop
The optimization of kernel 5 looks clear to me; it's the changes to kernel 4 that I'm worried about.

@insop (Author) commented Sep 24, 2024

@insop the optimization of kernel 5 looks clear to me, it's the changes of kernel 4 that I'm worried about.

Hi @ngc92

  • Looking at this change to softmax_forward_4 again, I will remove that part of the change from this PR so that we can review only the softmax_forward_5 optimization.

  • I agree that my change transitions from a block-level reduction to a warp-level reduction, and I may need to consider the implications of this change regardless of the tests passing.

  • The fact that softmax_forward_5 in the main file attention.cuh uses a warp-level reduction for all threads (utilizing __shfl_xor_sync instead of __shfl_down_sync, see the linked code) is what led me to update softmax_forward_4 initially; a short sketch of the difference follows below.

The softmax_forward_4 change is now reverted, PTAL.
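
To illustrate the difference between the two shuffle variants (a hedged sketch; warpReduceMaxAllLanes is just an illustrative name): the xor (butterfly) variant leaves the maximum in every lane, whereas the __shfl_down_sync version leaves it only in lane 0.

__device__ float warpReduceMaxAllLanes(float val) {
    // butterfly exchange: every lane ends up with the warp-wide maximum
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset));
    }
    return val;  // valid in all 32 lanes, no extra broadcast needed
}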

@insop (Author) commented Sep 28, 2024

Hi @ngc92
PTAL,

Thank you
