
SM75 (Turing) support for FP6 kernel #942

Merged
merged 5 commits into pytorch:main from tobiasvanderwerff:fp6-sm75 on Sep 29, 2024

Conversation

tobiasvanderwerff
Contributor

The FP6 CUDA kernel is currently limited to >=SM80, i.e. at least Ampere generation NVIDIA GPUs. However, as mentioned in this issue, the FP6 kernel could theoretically work for any card with uint8_t and fp16 support.

This PR adds FP6 support for SM75 (Turing generation) GPUs. Given the modest support for SM75 GPUs (e.g. the NVIDIA T4) in torchao, I'm hoping this can be a step towards improving that.

Most important changes (the first two are sketched right after this list):

  • Replace asynchronous copy operations with vectorized loads (diff)
  • Replace m16n8k16 Tensor core operation with two m16n8k8 operations (diff)
  • Account for a difference in expected parameters for the ldmatrix operation (diff)
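A rough sketch of the first two changes (an illustration only, not the PR's exact code; the helper names are made up): the SM80-only pieces can be wrapped in thin helpers that fall back to SM75-compatible instructions, with one m16n8k16 Tensor core MMA decomposed into two m16n8k8 MMAs over the two halves of the K dimension, and the asynchronous copy replaced by an ordinary 128-bit load/store through registers.

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Two m16n8k8 MMAs replace one m16n8k16 MMA on SM75. Per the PTX fragment
// layout, {a[0], a[1]} and b[0] hold k = 0..7 while {a[2], a[3]} and b[1]
// hold k = 8..15, so accumulating the second half on top of the first
// reproduces the full k = 16 result.
__device__ __forceinline__ void mma_m16n8k16(float d[4], const uint32_t a[4],
                                             const uint32_t b[2], const float c[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
                 : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]),
                   "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
                 : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
                 : "r"(a[2]), "r"(a[3]), "r"(b[1]),
                   "f"(d[0]), "f"(d[1]), "f"(d[2]), "f"(d[3]));  // accumulate onto the first half
#else
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
                 : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
                   "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
}

// A 16-byte global -> shared copy: asynchronous on SM80+, a plain vectorized
// load + store through registers on SM75. Real cp.async code also needs
// commit_group/wait_group calls, omitted in this sketch.
__device__ __forceinline__ void copy_16B(uint4* smem_dst, const uint4* gmem_src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    unsigned dst = static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" :: "r"(dst), "l"(gmem_src));
#else
    *smem_dst = *gmem_src;  // 128-bit vectorized load and store
#endif
}
```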

Before this, FP6 failed to run on SM75 (specifically, it would error out with "RuntimeError: operator torchao::quant_llm_linear does not exist").

Below are the benchmark and correctness test results (using ao/benchmarks/benchmark_fp6.py). As with the initial FP6 benchmarks run by @gau-nernst, we still see a speedup of at least 2x over FP16 in most cases.

Testing specs:

  • NVIDIA T4 GPU (SM75 generation)
  • Torch version: 2.6.0.dev20240917
| m | k | n | fp6_latency (ms) | fp16_latency (ms) | speedup (d/s) | correct |
|---|---|---|---|---|---|---|
| 1 | 8192 | 8192 | 223.05 | 517.205 | 2.31878 | 1 |
| 1 | 8192 | 10240 | 267.519 | 644.85 | 2.41048 | 1 |
| 1 | 8192 | 57344 | 1409.66 | 3618.91 | 2.56722 | 1 |
| 1 | 28672 | 8192 | 760.036 | 1800.83 | 2.3694 | 1 |
| 2 | 8192 | 8192 | 226.652 | 550.078 | 2.42697 | 1 |
| 2 | 8192 | 10240 | 271.363 | 687.359 | 2.53299 | 1 |
| 2 | 8192 | 57344 | 1442.43 | 3795.44 | 2.63128 | 1 |
| 2 | 28672 | 8192 | 768.993 | 1945.42 | 2.52982 | 1 |
| 4 | 8192 | 8192 | 229.543 | 567.315 | 2.4715 | 1 |
| 4 | 8192 | 10240 | 276.362 | 690.747 | 2.49943 | 1 |
| 4 | 8192 | 57344 | 1457.8 | 3805.93 | 2.61074 | 1 |
| 4 | 28672 | 8192 | 776.344 | 1965.42 | 2.53163 | 1 |
| 8 | 8192 | 8192 | 237.398 | 573 | 2.41367 | 1 |
| 8 | 8192 | 10240 | 284.349 | 810.79 | 2.85139 | 1 |
| 8 | 8192 | 57344 | 1521.18 | 3823.49 | 2.51349 | 1 |
| 8 | 28672 | 8192 | 783.763 | 1988.13 | 2.53665 | 1 |
| 16 | 8192 | 8192 | 273.57 | 582.142 | 2.12795 | 1 |
| 16 | 8192 | 10240 | 334.392 | 815.086 | 2.43751 | 1 |
| 16 | 8192 | 57344 | 1908.2 | 3857.97 | 2.02178 | 1 |
| 16 | 28672 | 8192 | 896.452 | 2002.43 | 2.23373 | 1 |
| 32 | 8192 | 8192 | 449.671 | 659.112 | 1.46576 | 1 |
| 32 | 8192 | 10240 | 548.003 | 821.81 | 1.49965 | 1 |
| 32 | 8192 | 57344 | 3255.59 | 4543.11 | 1.39548 | 1 |
| 32 | 28672 | 8192 | 1531.63 | 2403.52 | 1.56925 | 1 |
| 64 | 8192 | 8192 | 58.3608 | 672.911 | 11.5302 | 0 |
| 64 | 8192 | 10240 | 60.6918 | 887.669 | 14.6258 | 0 |
| 64 | 8192 | 57344 | 433.124 | 4812.82 | 11.1119 | 0 |
| 64 | 28672 | 8192 | 58.4863 | 2505.37 | 42.8369 | 0 |
| 128 | 8192 | 8192 | 58.2865 | 1096.99 | 18.8206 | 0 |
| 128 | 8192 | 10240 | 121.688 | 1202.46 | 9.88147 | 0 |
| 128 | 8192 | 57344 | 767.995 | 6741.57 | 8.77814 | 0 |
| 128 | 28672 | 8192 | 58.5768 | 3579.83 | 61.1134 | 0 |
| 256 | 8192 | 8192 | 194.799 | 1731.37 | 8.888 | 0 |
| 256 | 8192 | 10240 | 212.755 | 2273.48 | 10.6859 | 0 |
| 256 | 8192 | 57344 | 1186.61 | 11964.1 | 10.0825 | 0 |
| 256 | 28672 | 8192 | 195.011 | 5868.9 | 30.0953 | 0 |
| 512 | 8192 | 8192 | 388.129 | 3043.3 | 7.84093 | 0 |
| 512 | 8192 | 10240 | 199.79 | 3544.64 | 17.7418 | 0 |
| 512 | 8192 | 57344 | 51.1464 | 19975.7 | 390.559 | 0 |
| 512 | 28672 | 8192 | 388.241 | 10728.5 | 27.6335 | 0 |

As you can see, the kernel produces correct results for m < 64 but incorrect results for m >= 64. So in that sense this kernel is currently not functional for m >= 64. Note also that the latency is drastically reduced for m >= 64, indicating that the kernel is exiting early for some reason. I did some digging and documented my findings below.

Why does m >= 64 produce incorrect results?

After looking into this, I think the problem is that too much shared memory is being allocated. The A100 has quite a lot more shared memory (164 KB per SM) than the T4, so some assumptions made about shared memory may not hold for the T4.

First, I ran the following script to query the device properties of my T4 GPU:

```cpp
#include <cuda_runtime.h>
#include <iostream>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    std::cout << "Compute Capability: " << prop.major << "." << prop.minor << std::endl;
    std::cout << "Shared memory per block: " << prop.sharedMemPerBlock << " bytes" << std::endl;
    std::cout << "Shared memory per multiprocessor: " << prop.sharedMemPerMultiprocessor << " bytes" << std::endl;
    return 0;
}
```

This outputs the following:

```
Compute Capability: 7.5
Shared memory per block: 49152 bytes
Shared memory per multiprocessor: 65536 bytes
```

So the available shared memory is indeed much smaller than on the A100. At the m >= 64 threshold, the kernel gets called with a TilingConfig where WARP_COL_MMA_TENSORS is set to 8 instead of 4:

```cpp
case 32: Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
```

This leads to an increased size of TILE_N:

```cpp
static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
```

Which in turn leads to an increased shared memory allocation:

```cpp
static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2
static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4
```
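To make the effect concrete, here is the arithmetic with MMA_8 = 8 and BLOCK_COL_WARPS = 1 (both values assumed from the upstream kernel): going from WARP_COL_MMA_TENSORS = 4 to 8 doubles TILE_N, and since both SMEM_SIZE_B_TILE and SMEM_SIZE_C_TILE scale linearly with TILE_N, the shared-memory footprint roughly doubles as well.

```cpp
// Hedged arithmetic sketch; the constants are assumed from the upstream kernel.
constexpr int MMA_8 = 8;
constexpr int BLOCK_COL_WARPS = 1;
constexpr int TILE_N_m32 = MMA_8 * 4 * BLOCK_COL_WARPS;  // WARP_COL_MMA_TENSORS = 4 -> TILE_N = 32
constexpr int TILE_N_m64 = MMA_8 * 8 * BLOCK_COL_WARPS;  // WARP_COL_MMA_TENSORS = 8 -> TILE_N = 64
static_assert(TILE_N_m64 == 2 * TILE_N_m32, "m >= 64 doubles the N tile and thus the B/C tile buffers");
```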

Printing TilingConfig::SMEM_SIZE_C_TILE at runtime shows that when m goes from 32 to 64, it jumps from 38912 to 69362 bytes. Since 69362 bytes is more shared memory than the T4 has per block and even per SM, it seems likely that this is what causes the problem, although, somehow, it does not trigger any kind of error (more on that below). Capping WARP_COL_MMA_TENSORS at 4, i.e. replacing TilingConfig<4, 1, 8> with TilingConfig<4, 1, 4> in these two call sites, here:

```cpp
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
```

and here:

```cpp
case 64: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
```

leads to correct results for all benchmarks:

| m | k | n | fp6_latency (ms) | fp16_latency (ms) | speedup (d/s) | correct |
|---|---|---|---|---|---|---|
| 1 | 8192 | 8192 | 223.036 | 516.966 | 2.31786 | 1 |
| 1 | 8192 | 10240 | 269.426 | 644.646 | 2.39266 | 1 |
| 1 | 8192 | 57344 | 1410.92 | 3648.77 | 2.58609 | 1 |
| 1 | 28672 | 8192 | 759.226 | 1804.14 | 2.37629 | 1 |
| 2 | 8192 | 8192 | 225.865 | 549.892 | 2.4346 | 1 |
| 2 | 8192 | 10240 | 271.288 | 687.137 | 2.53287 | 1 |
| 2 | 8192 | 57344 | 1441.76 | 3795.4 | 2.63247 | 1 |
| 2 | 28672 | 8192 | 765.302 | 1945.28 | 2.54185 | 1 |
| 4 | 8192 | 8192 | 230.63 | 567.303 | 2.45979 | 1 |
| 4 | 8192 | 10240 | 276.171 | 690.602 | 2.50063 | 1 |
| 4 | 8192 | 57344 | 1456.33 | 3805.46 | 2.61305 | 1 |
| 4 | 28672 | 8192 | 774.708 | 1965.41 | 2.53696 | 1 |
| 8 | 8192 | 8192 | 237.608 | 573.233 | 2.41252 | 1 |
| 8 | 8192 | 10240 | 284.422 | 811.645 | 2.85367 | 1 |
| 8 | 8192 | 57344 | 1521.34 | 3823.1 | 2.51298 | 1 |
| 8 | 28672 | 8192 | 786.987 | 1986.32 | 2.52395 | 1 |
| 16 | 8192 | 8192 | 272.092 | 581.971 | 2.13888 | 1 |
| 16 | 8192 | 10240 | 334.622 | 816.67 | 2.44058 | 1 |
| 16 | 8192 | 57344 | 1909.48 | 3857.44 | 2.02015 | 1 |
| 16 | 28672 | 8192 | 890.995 | 1978.75 | 2.22083 | 1 |
| 32 | 8192 | 8192 | 450.208 | 659.402 | 1.46466 | 1 |
| 32 | 8192 | 10240 | 547.148 | 825.829 | 1.50933 | 1 |
| 32 | 8192 | 57344 | 3261.63 | 4540.91 | 1.39222 | 1 |
| 32 | 28672 | 8192 | 1532.26 | 2396.69 | 1.56416 | 1 |
| 64 | 8192 | 8192 | 739.317 | 671.094 | 0.907721 | 1 |
| 64 | 8192 | 10240 | 886.388 | 839.778 | 0.947416 | 1 |
| 64 | 8192 | 57344 | 5192.64 | 4846.74 | 0.933387 | 1 |
| 64 | 28672 | 8192 | 2360.6 | 2637.51 | 1.1173 | 1 |
| 128 | 8192 | 8192 | 1243.43 | 1049.64 | 0.844148 | 1 |
| 128 | 8192 | 10240 | 1598.47 | 1226.33 | 0.767191 | 1 |
| 128 | 8192 | 57344 | 9111.21 | 6696.32 | 0.734954 | 1 |
| 128 | 28672 | 8192 | 4154.84 | 3649.96 | 0.878483 | 1 |
| 256 | 8192 | 8192 | 2493.73 | 1603.46 | 0.642994 | 1 |
| 256 | 8192 | 10240 | 3069.87 | 2045.3 | 0.666251 | 1 |
| 256 | 8192 | 57344 | 17182.4 | 11757.5 | 0.68428 | 1 |
| 256 | 28672 | 8192 | 8275.68 | 5849.96 | 0.706886 | 1 |
| 512 | 8192 | 8192 | 5030.14 | 3091.38 | 0.614572 | 1 |
| 512 | 8192 | 10240 | 5893.7 | 3557.85 | 0.603671 | 1 |
| 512 | 8192 | 57344 | 31538.6 | 19964.7 | 0.633024 | 1 |
| 512 | 28672 | 8192 | 16674 | 11045.6 | 0.662444 | 1 |

But as you can see, performance now becomes quite slow for m >= 64 (slower than FP16). I might look into the details to see whether there are any optimizations to be made for the T4 in this particular case.

I tried creating a guard that checks for sm_75 and m >= 64, but this is causing problems. Specifically, I tried to check __CUDA_ARCH__ in fp6_linear.cu (as done here), but this does not seem to work for some reason. The best explanation I can think of is that __CUDA_ARCH__ is not defined there because it is not device code, but that doesn't explain why the guard does work at the beginning of the file. If I instead put the directive inside the device code, functions like assert don't seem to have any effect (I'm not really sure why).

So currently, the m>=64 case is not handled well (it will lead to incorrect results) and I'd be happy to hear suggestions about how to deal with this case.
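For reference, here is a self-contained sketch of how the over-allocation can surface as an explicit error rather than a silent wrong result (the dummy kernel and the reuse of the 69362-byte figure are purely illustrative, not torchao code):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = threadIdx.x;
    __syncthreads();
    if (threadIdx.x == 0) out[0] = smem[0];
}

int main() {
    int max_optin = 0;
    cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
    printf("Max opt-in shared memory per block: %d bytes\n", max_optin);  // 65536 on a T4

    const int requested = 69362;  // the SMEM_SIZE_C_TILE value observed above
    cudaError_t err = cudaFuncSetAttribute(
        dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, requested);
    printf("cudaFuncSetAttribute: %s\n", cudaGetErrorString(err));

    float* out = nullptr;
    cudaMalloc(&out, sizeof(float));
    dummy_kernel<<<1, 32, requested>>>(out);  // dynamic shared memory request above the limit
    printf("kernel launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}
```

On a T4, both calls should report an error, since the request exceeds the 64 KB opt-in limit; if those return codes are never checked, the kernel simply never runs, which would also be consistent with the drastically reduced latencies in the table above.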

Final thoughts

  • I noticed the splitK parameter for the FP6 kernel is set based on a dictionary that is optimized for the A100 GPU. It might not be optimal for the T4, but I don't have any insight into how much that matters.

Any feedback is very welcome!


pytorch-bot bot commented Sep 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/942

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 56718f9 with merge base 2dea315:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 25, 2024
@gau-nernst
Collaborator

@tobiasvanderwerff Thank you for your contribution! The debugging of the m>=64 case is very insightful.

I was the one who added this FP6-LLM kernel to torchao. I don't have enough CUDA knowledge to fully understand the changes, but they seem reasonable to me. Have you had a chance to discuss this change with the original author? https://github.com/usyd-fsalab/fp6_llm

I will let @msaroufim decide whether we should support / improve the kernel for sm75 at all. Concerns include CI testing and deviation from upstream, as well as whether it's worth the effort given the age of these GPUs.

General comments:

  • Low performance for large M is to be expected. I don't think you need to try to optimize for it
  • Regarding guard for M>=64, if the macro doesn't work, I think you can check with either torch.cuda.get_device_capability() or cudaDeviceProp
  • Regarding the splitK, yeah, I don't really like it either. I wish we had autotuning for CUDA kernels, similar to Triton 😄
  • (Can do later, once everything is finalized) Can you add a summary of your changes at the top of each modified file?

Side question: would you be able to modify this kernel to support BF16? Since all modern models use BF16, it would be useful. I figure it would be a matter of changing the dequant logic + MMA instructions, but again, I'm not too confident in my CUDA skills 😅.

@tobiasvanderwerff
Contributor Author

@gau-nernst let me address your comments/questions one by one.

Have you had the chance to discuss this with the original author about this change?

I did not have contact with the original authors. Looking at their repo and paper, I don't get the impression that they had pre-A100 GPUs in mind when they implemented FP6. My main motivation for writing this port is that I thought it was an interesting challenge to try to make this work on a slightly older GPU, since a GPU-poor person like myself finds an A100 too expensive to rent for extended periods of time :)

I will let @msaroufim decide if we should support / improve the kernel for sm75 at all. Concerns include CI testing + deviation from upstream, as well as not worth the efforts due to old age.

That's understandable. A little while back I asked around in the GPU-mode Discord (formerly CUDA-mode), where @HDCharles mentioned there is potentially demand for supporting more architectures than are currently supported.

Regarding guard for M>=64, if the macro doesn't work, I think you can check with either torch.cuda.get_device_capability() or cudaDeviceProp

That's a good idea -- I will try it out.

(Can do later, once everything is finalized) Can you add a summary of your changes at the top of each modified file?

Err, you mean a comment in the C++/CUDA files? As in, add a comment at the top of each changed file describing how I added SM75 support to it? If so, I could definitely do that.

Side question. Would you be able to modify this kernel to support BF16? Since all modern models use BF16, it would be useful. I figure it would be a matter of changing the dequant logic + MMA instructions, but again, I'm not too confident with my CUDA skills 😅.

I'd love to try that, but unfortunately the T4 does not support BF16 (it requires SM80). So that would require an A100 to work with, which is currently not feasible for me.

@gau-nernst
Collaborator

@tobiasvanderwerff Thank you for your response

I was just checking whether you had asked the original author about this change, since he can probably give more feedback in terms of accuracy and performance. No worries!

Since bad perf for large M is to be expected (the FP6-LLM kernel is actually also slower than FP16 on the A100 for M>=128), I don't think we need a special check for M>=64? Just use __CUDA_ARCH__ to select the tile size accordingly.

Err, you mean a comment in the C++/CUDA files? As in, add a comment at the top of each changed file saying how I added SM75 support to it? If so, I could do definitely do that.

Yes, exactly.

I'd love to try that, but unfortunately the T4 does not support BF16 (it requires SM80). So that would require an A100 to work with, which is currently not feasible for me.

No worries at all! In case you don't know, for sm80 stuff, Lightning AI gives out 22 free hrs/month of L4 time. Some GPU clouds also offer rentals of consumer cards like the 3090 and 4090, which are much cheaper than an A100. Just something to consider.


@msaroufim
Member

msaroufim commented Sep 26, 2024

I like this PR, but I'm a bit conflicted: we won't be running this in CI, so we won't get a signal on whether SM75 keeps working, and if FP6 upstream is still actively maintained, it will make it more difficult for us to stay up to date.

@tobiasvanderwerff if your issue is primarily compute availability, let me know on the server what you're thinking of working on medium-term and we might be able to arrange something.

@gau-nernst
Collaborator

@msaroufim Regarding CI testing, I think we can treat this as "not officially supported", like Windows support. Interested users would have to compile from source, and there's no guarantee that future versions will keep working. Since the change is quite small, I think it's still reasonable.

```diff
@@ -7,7 +7,7 @@
 from tqdm import tqdm


-def benchmark(m: int, k: int, n: int):
+def benchmark(m: int, n: int, k: int):
```
Contributor

Is this change intentional? Usually we talk about matmuls as m, k, n, i.e. an m x k activation and a k x n weight (it seems odd to reverse them, and I'm unsure whether the benchmarks were previously assuming the other ordering).

Collaborator

Looking at this again, I indeed made a mistake. The benchmark results are still correct; it's just that the list of shapes is different.

I took the benchmark shapes from here: https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh - it's a bit confusing since the author uses different variable names...

The code generating the list of shapes (under __name__ == "__main__") is correct (it follows the author), and it calls benchmark(m, n, k). If you think we should benchmark a different set of shapes, that would be fine too!

In summary, this change corrects my previous mistake.

Contributor Author

Yes, it's intentional. The function signature was def benchmark(m: int, k: int, n: int):, but the arguments were passed as (m, n, k), so I thought that was unnecessarily confusing and wanted to change the ordering in either the function call or the function signature. In the function itself, the shapes become m x k for the activation and n x k for the weight.

I see one benchmark example (benchmark_gpu_sparsity, see below) where the ordering is m, k, n, so let me change the function signature back to that ordering for consistency.

```python
def run_gpu_sparse_benchmark(m, k, n, args):
```

Contributor Author

Missed your comment before I posted my own, @gau-nernst. Thanks for clarifying! I actually noticed that m gets passed as n to the actual kernel, which is slightly confusing. If you don't mind, I'll change this for consistency. I don't think it should affect the results, except that m will be swapped with n in the performance table.

Collaborator

@gau-nernst gau-nernst Sep 26, 2024

@tobiasvanderwerff You mean k and n, right? Your current change looks correct. Yeah, it doesn't affect the results; it will just show results for different shapes instead.

Contributor Author

@tobiasvanderwerff tobiasvanderwerff Sep 26, 2024

What I meant is slightly different, @gau-nernst. I'm referring to the fact that the original authors do some odd switching of the shapes in fp6_linear.cu. The arguments that get passed are _in_feats (activations of shape m x k) and _weights (shape n x k), but then they unpack the shapes as M = _weights.size(0), K = _in_feats.size(1), N = _in_feats.size(0) (see below).

```cpp
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
int num_out_channels = _weights.size(0);
TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels);
TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched.
//
int M = num_out_channels;
int K = num_in_channels;
int N = num_in_feats;
```

So even though we pass the arguments to the benchmark function correctly as m, k, n, the names get switched inside the kernel. Anyway, this is mainly confusing when debugging the kernel; it might actually be fine to just leave it as is.

Contributor

@HDCharles HDCharles left a comment

I like the PR overall. The slowdown for large m is not unexpected for kernels focused on weight-only quantization: we expect that for small bsz, your I/O speed dominates (yielding a speedup), and when bsz is very large you asymptotically approach fp16 x fp16 speed, but in an intermediate range the quantization overhead is going to be maximally painful, leading to a significant slowdown.
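A rough bandwidth estimate illustrates the small-bsz end of this: for m much smaller than k and n, the kernel mostly streams the k x n weight matrix, so the best-case memory-bound speedup is about (16 bits per FP16 weight) / (6 bits per FP6 weight) ≈ 2.7x, which lines up with the ~2.3-2.6x measured above for m <= 8.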

I'll accept tentatively since it sounds like there are still some small bits to address.

@tobiasvanderwerff
Contributor Author

tobiasvanderwerff commented Sep 26, 2024

The kernel now handles the m>=64 edge case to produce correct results for all shapes used in the benchmark (see 650ba03).

I couldn't use the __CUDA_ARCH__ macro, so I'm instead checking the CUDA arch at runtime. It unfortunately makes the kernel launch code a bit uglier, since the template parameters of TilingConfig need to be known at compile time. I'm not the most experienced with C++, so there may be a better solution, but this is the best I could come up with to handle this.
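For illustration, a minimal sketch of this kind of runtime dispatch (not the actual code in 650ba03; the is_sm75 helper is made up):

```cuda
// Query the compute capability once at runtime; template parameters still have
// to be fixed at compile time, so both TilingConfig variants are instantiated
// and we simply branch between them at launch.
static bool is_sm75() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major == 7 && prop.minor == 5;
}

// ...inside the N-dimension switch of the launcher:
case 64:
    if (is_sm75())
        Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);
    else
        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K);
    break;
```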

Collaborator

@gau-nernst gau-nernst left a comment

Thanks for addressing my feedback! Everything looks very nice!

I think __CUDA_ARCH__ is not available in host code, so we can't use it outside of CUDA kernels.
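For illustration (a sketch, not torchao code): the macro only exists during nvcc's device compilation pass, so a guard in the host-side launch logic never takes the device branch, while guards inside __global__/__device__ functions work as expected.

```cuda
// __CUDA_ARCH__ is only defined while nvcc compiles the device pass, so the
// host function below always ends up returning -1 in the compiled binary,
// while the kernel sees the real architecture value.
__global__ void arch_probe(int* out) {
#if defined(__CUDA_ARCH__)
    *out = __CUDA_ARCH__;   // taken in device code: 750 when built for sm_75
#endif
}

int host_side_arch() {
#if defined(__CUDA_ARCH__)
    return __CUDA_ARCH__;   // only the (discarded) device pass sees this branch
#else
    return -1;              // the host compiler pass always emits this branch
#endif
}
```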

@tobiasvanderwerff
Contributor Author

tobiasvanderwerff commented Sep 27, 2024

Benchmark results on Llama-2-7b-chat-hf

Tested on Lightning Studio.

Testing specs:

  • T4 GPU
  • Torch version: 2.4.1

float16

python generate.py --compile --precision float16

Average tokens/sec: 19.04
Average Bandwidth: 251.60 GB/s
Peak Memory Usage: 13.90 GB
Model Size: 13.21 GB

fp6

python generate.py --compile --precision float16 --quantization fp6

Average tokens/sec: 40.52
Average Bandwidth: 200.93 GB/s
Peak Memory Usage: 6.61 GB
Model Size: 4.96 GB

@tobiasvanderwerff
Contributor Author

Fp6 eval results

python eval.py --compile --precision float16 --quantization fp6

wikitext:

  • word_perplexity: 12.3653
  • byte_perplexity: 1.6005
  • bits_per_byte: 0.6785

@msaroufim
Member

@gau-nernst feel free to merge this whenever you feel it's ready

@gau-nernst gau-nernst merged commit 96e8fee into pytorch:main Sep 29, 2024
17 checks passed
@tobiasvanderwerff tobiasvanderwerff deleted the fp6-sm75 branch September 29, 2024 06:45
melvinebenezer pushed a commit to melvinebenezer/ao that referenced this pull request Oct 3, 2024
* SM75 support for FP6 kernel

* More consistent argument ordering in benchmark function

* Add a note about SM75 support in the floatx README

* Handle FP6 + SM75 + N>=64 edge case

* Document changes made for FP6 SM75 support