SM75 (Turing) support for FP6 kernel #942
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/942
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 56718f9 with merge base 2dea315.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@tobiasvanderwerff Thank you for your contribution! The debugging of the m>=64 case is very insightful. I was the one who added this FP6-LLM kernel to torchao. I don't have enough CUDA knowledge to fully understand the changes, but they seem reasonable to me. Have you had the chance to discuss this change with the original author? https://github.com/usyd-fsalab/fp6_llm I will let @msaroufim decide if we should support / improve the kernel for sm75 at all. Concerns include CI testing + deviation from upstream, as well as whether it is worth the effort given the age of the architecture. General comments:
Side question: would you be able to modify this kernel to support BF16? Since all modern models use BF16, it would be useful. I figure it would be a matter of changing the dequant logic + MMA instructions, but again, I'm not too confident in my CUDA skills 😅.
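For reference, the MMA side of that change might look roughly like the sketch below (untested, and just an assumption on my part, not code from this PR): the register fragments stay the same as the f16 variant, only the type qualifiers in the PTX instruction change, and BF16 requires an FP32 accumulator. The function name is purely illustrative.

```cpp
#include <cstdint>

// Hypothetical BF16 counterpart of an FP16 m16n8k16 MMA with FP32 accumulation.
// Same register fragments as the f16 variant; requires SM80+ (so not the T4).
__device__ __forceinline__ void mma_m16n8k16_bf16_fp32acc(
    uint32_t* c,        // 4 registers holding FP32 accumulators
    const uint32_t* a,  // 4 registers holding 8 BF16 values of A
    const uint32_t* b)  // 2 registers holding 4 BF16 values of B
{
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
        : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
}
```

The dequant logic would also need new bit tricks, since BF16 has an 8-bit exponent and 7-bit mantissa instead of FP16's 5/10 split.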
@gau-nernst let me address your comments/questions one by one.
I did not have contact with the original authors. Looking at their repo and paper, I don't have the impression that they had GPUs before A100 generation in mind when they implemented FP6. My main motivation for writing this port is that I thought it was an interesting challenge to try and make this work on a slightly older GPU for a GPU poor person like myself who finds an A100 too expensive to rent for extended periods of time :)
That's understandable. A little while back I asked around in the GPU-mode discord (formerly CUDA-mode), where @HDCharles mentioned there is potentially a demand for supporting more architectures than what is currently supported.
That's a good idea -- I will try it out.
Err, you mean a comment in the C++/CUDA files? As in, add a comment at the top of each changed file saying how I added SM75 support to it? If so, I could definitely do that.
I'd love to try that, but unfortunately the T4 does not support BF16 (it requires SM80). So that would require an A100 to work with, which is currently not feasible for me.
@tobiasvanderwerff Thank you for your response. I was just checking whether you had asked the original author about this change, since he can probably give more feedback in terms of accuracy and performance. No worries! Since bad perf for large M is to be expected (actually the FP6-LLM kernel is also slower than FP16 on A100 for M>=128), I think we don't need to check for M>=64? Just use
Yes, exactly.
No worries at all! In case you don't know, for sm80 stuff Lightning AI gives out 22 free hours per month of L4 compute. Some GPU clouds also provide rental services for consumer cards like the 3090 and 4090, which are much cheaper than an A100. Just something to consider. Some other pointers:
I like this PR but am a bit conflicted: we won't be running this in CI, so we won't get signal on whether SM75 keeps working, and if fp6 upstream is still actively maintained, it'll make it more difficult for us to stay up to date. @tobiasvanderwerff if your issue is primarily compute availability, let me know on the server what you're thinking of working on medium term and we might be able to arrange something.
@msaroufim Regarding CI testing, I think we can treat this as "not officially supported", like Windows support. So interested users have to compile from source + no guarantee that future versions will continue working. Since the change is quite small, I think it's still somewhat reasonable.
benchmarks/benchmark_fp6.py (Outdated)

```diff
@@ -7,7 +7,7 @@
 from tqdm import tqdm


-def benchmark(m: int, k: int, n: int):
+def benchmark(m: int, n: int, k: int):
```
Is this change intentional? Usually we talk about matmuls as m, k, n, i.e. an m x k activation and a k x n weight (it's odd to reverse them, and I'm unsure if the benchmarks were previously assuming the other ordering).
Looking at this again, I indeed made a mistake. The benchmark results are still correct, it's just that the list of shapes is different.
I took the benchmark shapes from here: https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/run.sh - It's a bit confusing since the author uses different variable names...
The code generating the list of shapes (under `__name__ == "__main__"`) is correct (it follows the author), and it calls `benchmark(m, n, k)`. If you think we should benchmark a different set of shapes, that should be fine too!
In summary, this change corrects my previous mistake.
Yes, it's intentional. The function signature was `def benchmark(m: int, k: int, n: int):`, but the arguments were passed as `(m, n, k)`, so I thought that was unnecessarily confusing and wanted to change the ordering in either the function call or the function signature. In the function itself, the shapes become m x k for the activation and n x k for the weight.
I see one benchmark example (`benchmark_gpu_sparsity`, see below) where the ordering is m, k, n, so let me change the function signature back to that ordering for consistency.

ao/benchmarks/benchmark_gpu_sparsity.py, line 25 in 4b5b5ee: `def run_gpu_sparse_benchmark(m, k, n, args):`
Missed your comment before I posted my own, @gau-nernst. Thanks for clarifying! I actually noticed that `m` gets passed as `n` to the actual kernel, which is slightly confusing. If you don't mind, I'll change this for consistency. I don't think it should affect the results, except that `m` will be swapped with `n` in the performance table.
@tobiasvanderwerff You mean `k` and `n`, right? Your current change looks correct. Yeah, it doesn't affect the results; it will only show results for different shapes instead.
What I meant is slightly different, @gau-nernst. I'm referring to the fact that the original authors do some odd switching of the shapes in `fp6_linear.cu`. The arguments that get passed are `_in_feats` (activations of shape m x k) and `_weights` (shape n x k), but then they unpack the shapes as `M = _weights.size(0)`, `K = _in_feats.size(1)`, `N = _in_feats.size(0)` (see below).
ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu, lines 150 to 158 in 4b5b5ee:

```cpp
int num_in_feats      = _in_feats.size(0);
int num_in_channels   = _in_feats.size(1);
int num_out_channels  = _weights.size(0);
TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels);
TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1));    // Making sure the K dimension is matched.
//
int M = num_out_channels;
int K = num_in_channels;
int N = num_in_feats;
```
So even though we pass the arguments correctly to the benchmark function as m, k, n, the names get switched inside the kernel. Anyway, this is mainly confusing when debugging the kernel, but it might actually be fine to just leave it as is.
I like the PR overall. The slowdown for large m is not unexpected for kernels focusing on weight-only quantization: for small batch sizes your IO speed dominates (yielding a speedup), and when the batch size is very large you'll asymptotically approach fp16 x fp16 speed, but for an intermediate range the quantization overhead is going to be maximally painful, leading to a significant slowdown.
I'll accept tentatively since it sounds like there are still some small bits to address.
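Rough back-of-envelope numbers for the memory-bound end of that argument (assumed 4096x4096 weight; ignores output writes and the compute-bound crossover at large batch, so it is only an upper bound in the IO-dominated regime):

```cpp
#include <cstdio>

int main() {
    const long long K = 4096, N = 4096;            // assumed weight shape (N x K)
    const double bytes_fp16_w = 2.0;               // bytes per FP16 weight element
    const double bytes_fp6_w  = 6.0 / 8.0;         // bytes per FP6 weight element
    for (long long M : {1, 8, 32, 128, 512}) {     // batch (token) dimension
        double act  = 2.0 * M * K;                 // FP16 activation bytes read
        double fp16 = act + bytes_fp16_w * N * K;  // total bytes moved, FP16 weights
        double fp6  = act + bytes_fp6_w  * N * K;  // total bytes moved, FP6 weights
        printf("M=%4lld  ideal memory-bound speedup ~ %.2fx\n", M, fp16 / fp6);
    }
    return 0;
}
```

The bound shrinks as M grows, and once the kernel becomes compute-bound this simple model no longer applies, which is exactly the regime where the quantization overhead shows up.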
The kernel now handles the `m >= 64` edge case. I couldn't use the `__CUDA_ARCH__` macro for this check.
Thanks for addressing my feedback! Everything looks very nice!
I think `__CUDA_ARCH__` is not available in host code, so we can't use it outside of CUDA kernels.
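A tiny sketch (not from this PR) of why that is: the macro only exists during the device-side compilation pass, so a host-side branch on it never sees it defined.

```cpp
#include <cstdio>

__global__ void arch_report_kernel() {
#if defined(__CUDA_ARCH__)
    // Device compilation: defined per target architecture, e.g. 750 on a T4.
    if (threadIdx.x == 0) printf("device: __CUDA_ARCH__ = %d\n", __CUDA_ARCH__);
#endif
}

void arch_report_host() {
#if defined(__CUDA_ARCH__)
    printf("never printed: host compilation does not define __CUDA_ARCH__\n");
#else
    printf("host: __CUDA_ARCH__ is not defined here\n");
#endif
}
```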
Benchmark results on Llama-2-7b-chat-hf

Tested on Lightning Studio. Testing specs:

float16: `python generate.py --compile --precision float16` Average tokens/sec: 19.04

fp6: `python generate.py --compile --precision float16 --quantization fp6` Average tokens/sec: 40.52
FP6 eval results

`python eval.py --compile --precision float16 --quantization fp6`

wikitext:
@gau-nernst feel free to merge this whenever you feel it's ready
* SM75 support for FP6 kernel
* More consistent argument ordering in benchmark function
* Add a note about SM75 support in the floatx README
* Handle FP6 + SM75 + N>=64 edge case
* Document changes made for FP6 SM75 support
The FP6 CUDA kernel is currently limited to >=SM80, i.e. at least Ampere generation NVIDIA GPUs. However, as mentioned in this issue, the FP6 kernel could theoretically work for any card with uint8_t and fp16 support.
This PR adds FP6 support for SM75 (Turing generation) GPUs. Given the modest support for SM75 GPUs (e.g. the NVIDIA T4) in `torchao`, I'm hoping this can be a step towards improving that.

Most important changes:

* Replace the `m16n8k16` Tensor Core operation with two `m16n8k8` operations (diff); see the sketch below
* Changes to the `ldmatrix` operation (diff)
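To give a rough idea of the first change (an illustrative sketch on my part, not the literal diff; the register fragment layout follows the PTX ISA docs, and the function name is made up):

```cpp
#include <cstdint>

// FP16 MMA with FP32 accumulation. SM80+ has the m16n8k16 instruction; on SM75
// the same result is obtained with two m16n8k8 instructions that split the K
// dimension (a[0..1]/b[0] cover k = 0..7, a[2..3]/b[1] cover k = 8..15).
__device__ __forceinline__ void mma_m16n8k16_fp16_fp32acc(
    uint32_t* c, const uint32_t* a, const uint32_t* b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
        : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
#else
    // SM75 fallback: first half of K...
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n"
        : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
        : "r"(a[0]), "r"(a[1]), "r"(b[0]));
    // ...then the second half, accumulating into the same C fragment.
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%0,%1,%2,%3};\n"
        : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
        : "r"(a[2]), "r"(a[3]), "r"(b[1]));
#endif
}
```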
Before this, FP6 failed to run on SM75 (specifically, it would error out with "RuntimeError: operator torchao::quant_llm_linear does not exist").
Below are the benchmark and correctness test results (using `ao/benchmarks/benchmark_fp6.py`). Compared to the initial FP6 results done by @gau-nernst, we still see a speedup of at least 2x compared to FP16 in most cases.

Testing specs:
As you can see, the kernel produces correct results for `m < 64` but incorrect results for `m >= 64`. So in that sense this kernel is currently not functional for `m >= 64`. Note also that the latency is drastically reduced for `m >= 64`, indicating that the kernel is exiting early for some reason. I did some digging and documented my findings below.

Why does `m >= 64` produce incorrect results?

After looking into this, I think the problem is the allocation of too much shared memory. The A100 has quite a lot more shared memory (164 KB per SM) than the T4, so some assumptions made about shared memory may not hold for the T4.
First, I ran the following script to query the device properties of my T4 GPU:
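A minimal version of that kind of query (not necessarily the exact script I used) looks like this:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, /*device=*/0);
    printf("Device: %s (SM %d.%d)\n", prop.name, prop.major, prop.minor);
    printf("Shared memory per block:          %zu bytes\n", prop.sharedMemPerBlock);
    printf("Shared memory per block (opt-in): %zu bytes\n", prop.sharedMemPerBlockOptin);
    printf("Shared memory per SM:             %zu bytes\n", prop.sharedMemPerMultiprocessor);
    return 0;
}
```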
Which outputs the following:
So the available shared memory is indeed highly reduced compared to the A100. At the threshold of `m>=64`, the kernel gets called with a `TilingConfig` where `WARP_COL_MMA_TENSORS` is set to 8, instead of 4:

ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu, lines 101 to 103 in 2dea315

This leads to an increased size of `TILE_N`:

ao/torchao/csrc/cuda/fp6_llm/configs.h, line 56 in 2dea315

Which in turn leads to an increased shared memory allocation:

ao/torchao/csrc/cuda/fp6_llm/configs.h, lines 62 to 63 in 2dea315

Printing `TilingConfig::SMEM_SIZE_C_TILE` during runtime shows that when `m` goes from 32 to 64, `TilingConfig::SMEM_SIZE_C_TILE` goes from 38912 to 69362. Since 69362 bytes is more shared memory than the T4 has per block and per SM, it seems likely that this is causing problems. (But somehow, this does not throw any kind of error.) Setting `WARP_COL_MMA_TENSORS` to a maximum of 4, i.e. setting `TilingConfig<4, 1, 4>` here:

ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu, lines 88 to 94 in 2dea315

and here:

ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu, lines 102 to 108 in 2dea315

leads to correct results for all benchmarks:
But as you can see, performance becomes quite slow now for `m>=64` (slower than FP16). I might look into the details to see if there are any optimizations that can be made for the T4 in this particular case.

I tried creating a guard that checks for sm_75 and `m>=64`, but this is causing problems. Specifically, I tried to check for `__CUDA_ARCH__` in `fp6_linear.cu` (as done here), but this does not seem to work for some reason (the best reason I can think of is that `__CUDA_ARCH__` is not defined because it is not device code, but that doesn't explain why the guard does work at the beginning of the file). However, if I instead try to put the directive inside the device code, it seems that functions like `assert` don't have any effect (I'm not really sure why).

So currently, the `m>=64` case is not handled well (it will lead to incorrect results) and I'd be happy to hear suggestions about how to deal with this case.
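One possible direction (just a sketch, not something this PR implements): since `__CUDA_ARCH__` only exists during device compilation, the check could be done at runtime from the host side using the device properties, e.g. via ATen. The helper name and error message below are made up.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>

// Hypothetical host-side guard: reject the problematic shape on SM75 instead
// of relying on __CUDA_ARCH__, which host code cannot see.
static void check_fp6_sm75_limits(int M_Global) {
    const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    const bool is_sm75 = (prop->major == 7 && prop->minor == 5);
    TORCH_CHECK(!(is_sm75 && M_Global >= 64),
                "FP6 kernel: M >= 64 is not currently supported on SM7.5 (Turing)");
}
```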
Final thoughts

* The `splitK` parameter for the FP6 kernel is set based on a dictionary that is optimized for the A100 GPU. It might not be optimal for the T4, but I don't have any insights into how much it matters.

Any feedback is very welcome!