
FP6 dtype! #208

Open
NicolasMejiaPetit opened this issue Apr 28, 2024 · 31 comments · Fixed by #223 or #358
Labels
enhancement New feature or request

Comments

@NicolasMejiaPetit

NicolasMejiaPetit commented Apr 28, 2024

🚀 The feature, motivation and pitch

https://arxiv.org/abs/2401.14112

I think you guys are really going to like this.
The deepspeed developers introduce FP6 datatype on cards without fp8 support, while maintaining full tensor core suppourt using a kernel they created called tc-fpX. Tests were done on a a100! And they achieved 1.69x-2.65x inference performance! And I assume this can be transferred over to training (with the exception of possibly the KV cache, and embedding module). This is really exiting, this will breathe new life into the rapidly aging a100 due to the introduction of the h100’s fp8.

It was merged into deepspeed in this commit:
microsoft/DeepSpeed@ccfdb84

Getting this pushed into PyTorch as a dtype would be a major win. These are the benefits FP6 provides:

[Image: FP6 benefits, from the paper]

Alternatives

These kernels shouldn't be limited to the A100; they could theoretically work on any card with uint8_t and FP16 support. That said, the kernels were only written for the A100, so without modification they might only work on Ampere cards.

Additional context

The TC-FPx kernel essentially takes 4 FP16 values, quantizes them to FP6 (with some placeholder bits), and then packs them into an array of 3x uint8_t, as shown in the figures from the paper below. A rough bit-packing sketch follows the figures.

[Images: FP6 quantization and 3-byte packing diagrams from the paper]
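For intuition only, here is a minimal Python sketch of the 4-values-into-3-bytes idea. It assumes the FP6 codes have already been produced as 6-bit integers stored in a uint8 tensor, and it does not reproduce the actual TC-FPx bit layout, which is re-arranged for tensor-core-friendly memory access.

import torch

def pack_four_fp6_codes(fp6_codes: torch.Tensor) -> torch.Tensor:
    # fp6_codes: uint8 tensor with length a multiple of 4, each element
    # holding one 6-bit FP6 code in its low bits (hypothetical layout).
    c = fp6_codes.reshape(-1, 4).to(torch.int32)
    # Concatenate four 6-bit codes into one 24-bit word...
    word = (c[:, 0] << 18) | (c[:, 1] << 12) | (c[:, 2] << 6) | c[:, 3]
    # ...then split the 24 bits across three bytes.
    packed = torch.stack([(word >> 16) & 0xFF, (word >> 8) & 0xFF, word & 0xFF], dim=-1)
    return packed.to(torch.uint8).flatten()

With 4 values in 3 bytes, each weight costs exactly 6 bits, versus 16 bits for FP16.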

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

@JiHa-Kim

Seconded!

@vkuzo
Contributor

vkuzo commented May 2, 2024

This is great, and the inference e2e integration looks like a good candidate for addition to https://github.com/pytorch/ao. Let us know if you are interested in contributing!

As far as an fp6 dtype in PT core goes, check out https://dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833 for the current thinking on adding new dtypes. We do expect fp6 to get silicon support in the future, so it would be a good candidate to add when that silicon support is closer. We don't actually need an fp6 dtype in core to enable w6a16 as implemented in the code linked to this issue.
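To make the w6a16 point concrete, here is a minimal sketch of serving 6-bit weights with 16-bit activations without any core dtype: the packed weights simply live in a plain uint8 buffer. pack_to_fp6 and fp6_weight_linear are hypothetical stand-ins for the packing routine and CUDA kernel from the linked code, not real APIs.

import torch
import torch.nn as nn

class W6A16Linear(nn.Module):
    # Sketch only: pack_to_fp6 / fp6_weight_linear are assumed helpers wrapping
    # the weight-packing routine and the FP6-weight x FP16-activation kernel.
    def __init__(self, linear: nn.Linear):
        super().__init__()
        packed, scales = pack_to_fp6(linear.weight.detach().half())
        self.register_buffer("packed_weight", packed)  # uint8 storage, no fp6 dtype needed
        self.register_buffer("scales", scales)
        self.bias = linear.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = fp6_weight_linear(x.half(), self.packed_weight, self.scales)
        return out if self.bias is None else out + self.bias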

@msaroufim msaroufim added the enhancement New feature or request label May 7, 2024
@msaroufim
Member

Keeping this open because we still need to do the subclass work and the end-to-end integration.

@gau-nernst
Collaborator

gau-nernst commented May 21, 2024

Tracker:

@cpuhrsch
Contributor

Just a nit on "User-friendly API (either Tensor subclass or FP6Linear module)": you can implement an FP6Linear module using a Tensor-subclass-based fp6 dtype. Just call self.weight = nn.Parameter(to_fp6(self.weight)) within the __init__ of your nn.Linear replacement. The FP6Linear module is then one way of injecting that code into the model. It seems like a very popular way of doing that, so it's reasonable to provide it as a primitive. I'm mainly pointing out that you needn't duplicate work by doing both :) You can then also make it easier for people to add coverage, as shown in our toy example (a small FP6Linear sketch along these lines follows the example):

import torch
import torchao.dtypes.nf4tensor
from torchao.dtypes.nf4tensor import to_nf4

@torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default])
def gelu(func, *args, **kwargs):
    # The torch dispatch convention is to pass all args and kwargs via the
    # args input.
    # args[0] here corresponds to the original *args
    # args[1] here corresponds to the original **kwargs
    # We're getting the first argument of the original args
    inp = args[0][0]
    # This is a very inefficient way to implement it
    return to_nf4(torch.nn.functional.gelu(inp.to(torch.float32)), inp.block_size, inp.scaler_block_size)

# `a` and `a_nf4` are defined earlier in the toy example
print(f"gelu(a): {torch.nn.functional.gelu(a)}")
print(f"gelu(a_nf4): {torch.nn.functional.gelu(a_nf4)}")

@gau-nernst
Collaborator

Thank you for your feedback. Those are just a few suggested approaches discussed with @msaroufim; we haven't decided on the final API for FP6 yet.

Of course, if we have an FP6 subclass, we don't need FP6Linear anymore. But implementing a subclass is harder, and almost all ops except F.linear do not make sense for FP6, because in FP6-LLM the weight is split and re-arranged in a specific way to optimize global memory access for tensor cores.

I tried implementing an FP6 subclass in #223 (and removed it in the end). Even implementing dispatch for aten.linear feels finicky, because it seems PyTorch will dispatch aten.mm (or aten.addmm) instead, so I have to store a transposed flag and set and check it correctly before calling the FP6-linear kernel (the CUDA kernel only works with A @ W.T, i.e. a Linear layer); see the dispatch sketch after this comment. To support other ops, we would need to (1) re-arrange the weight into natural order and (2) dequantize to FP32/FP16/BF16 (and then convert back to FP6). That would be too expensive.

So I think implementing a custom FP6Linear layer would be easier, since we don't need to guarantee anything about the weight, i.e. the weight itself is an internal implementation detail.

Just some of my thoughts from working on this. Once #248 is merged, I will work on adapting the weight-splitting logic (currently it's a CPU-only C++ extension). Note that the original code does not have weight un-splitting logic.
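To illustrate why the transposed-flag bookkeeping comes up at all, here is a small, hypothetical wrapper subclass (not the #223 code) that only logs what F.linear actually dispatches: the subclass never sees aten.linear, only aten.t followed by aten.mm (or aten.addmm when there is a bias), so an FP6 subclass has to recognize those ops and track whether its packed weight is currently "transposed".

import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    # Wrapper subclass that forwards every op to a plain inner tensor and
    # prints which aten op reached __torch_dispatch__.
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(cls, inner.shape, dtype=inner.dtype, device=inner.device)

    def __init__(self, inner):
        self.inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print("dispatched:", func)
        unwrap = lambda x: x.inner if isinstance(x, LoggingTensor) else x
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        rewrap = lambda x: LoggingTensor(x) if isinstance(x, torch.Tensor) else x
        return tree_map(rewrap, out)

w = LoggingTensor(torch.randn(8, 16))
x = torch.randn(2, 16)
F.linear(x, w)  # typically prints aten.t.default followed by aten.mm.default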

@gau-nernst
Collaborator

gau-nernst commented Jun 1, 2024

Just to update people here on the progress: we have added a user API for FP6-LLM.

from torchao.quantization.fp6_llm import convert_fp6_llm

convert_fp6_llm(model)  # convert model in-place, replacing nn.Linear modules with Fp6LlmLinear

Everything should work (in eager mode). Some local end-to-end testing by me and @Iron-Bound shows that it works as expected. We will probably close this issue once we have an LLM eval in this repo for uniform evaluation across quantization methods (there is also a small difference in how we handle FP16->FP6 quantization compared to the released code, so I want to make sure this difference is not significant).

Some known limitations:

  • The kernel is for FP16 activations and FP6_E3M2 weights. If your model is BF16, it should still work, but you will spend some small overhead converting BF16 <-> FP16. (Perhaps we can implement a BF16 version in a future PR? Not sure how much work is required - we would only need to change the weight dequant logic and call the correct tensor core instruction?)
  • When tested with gpt-fast, torch.compile does not work for an FP6-LLM end-to-end model (it does work for small test cases though - we have CI for that). Need to debug this.
    • UPDATE: adding torch._inductor.config.triton.cudagraph_trees = False fixes the issue (see the snippet after this list).
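A sketch of how that workaround slots into an end-to-end setup; `model` is assumed to be an existing FP16 model, and the compile mode is only an example.

import torch
from torchao.quantization.fp6_llm import convert_fp6_llm

# Workaround from the bullet above: disable cudagraph trees before compiling
# an end-to-end FP6-LLM model (small test cases compile fine without this).
torch._inductor.config.triton.cudagraph_trees = False

convert_fp6_llm(model)  # replace nn.Linear with Fp6LlmLinear in-place
model = torch.compile(model, mode="reduce-overhead")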

Data from gpt-fast for meta-llama/Llama-2-7b-chat-hf on 4070Ti SUPER

name                          tokens/s
BF16 baseline (w/ compile)    46.72
int8 (w/ compile)             87.95
FP6-LLM (w/ compile)          107.11

hellaswag eval (from https://github.com/EleutherAI/lm-evaluation-harness) for meta-llama/Llama-2-7b-chat-hf (credits to @Iron-Bound)

name        acc_norm
baseline    75.50
FP6-LLM     75.36

@supriyar
Contributor

supriyar commented Jun 5, 2024

@gau-nernst thanks for working on this! Very excited to see the initial results.

torch.compile does not work for an FP6-LLM end-to-end model (it does work for small test cases though - we have CI for that). Need to debug this.
UPDATE: adding torch._inductor.config.triton.cudagraph_trees = False fixes the issue.

Do you have new perf benchmark results with the updated setting and torch.compile + FP6?

@cpuhrsch
Contributor

cpuhrsch commented Jun 5, 2024

@gau-nernst - If you have the error trace due to torch._inductor.config.triton.cudagraph_trees = True, we might be able to fix this bug. I don't think setting this to False will necessarily prevent the use of CUDA graphs, but if it does we should fix it, because not using CUDA graphs will likely slow things down a lot.

@gau-nernst
Collaborator

@supriyar I updated the results in the table.

@cpuhrsch There is no error, but the outputs become NaN. I added FP6-LLM integration to gpt-fast in my branch. You can take a look to find the issue: https://github.com/gau-nernst/gpt-fast/blob/d1304ed1032f2eabc909fa0259a8d87750abade7/generate.py#L326-L327

  • I'm guessing it has to do with the in-place modification of the KV cache. When I comment out L71 and L326-327 in generate.py, there is an index-out-of-bounds error: _call_with_frames_removed: block: [0,0,0], thread: [95,0,0] Assertion index out of bounds: 0 <= tmp4 < 32000 failed. (Note that if I only comment out L326-327, compile works but the outputs are NaN.)

@cpuhrsch
Contributor

cpuhrsch commented Jun 6, 2024

@gau-nernst - Hm, strange. Do you mind attaching the code you generate from TORCH_LOGS='output_code' as a gist?

@gau-nernst
Collaborator

@cpuhrsch I put the generated code both without and with torch._inductor.config.triton.cudagraph_trees = False in the same gist below.
https://gist.github.com/gau-nernst/cde24dabe000f11991030609fc497a80

(Didn't realize GitHub won't collapse the view. You need to scroll to the middle of the page to see the other file)

@cpuhrsch
Contributor

cpuhrsch commented Jun 6, 2024

@gau-nernst - Looks like there's no real difference between the two

% diff ~/Downloads/nan_outputs.py ~/Downloads/normal_outputs.py
0a1
> # this is set: torch._inductor.config.triton.cudagraph_trees = False

paging @eellison if he has some time - Before further investigation: is there a common reason why cudagraph_trees could cause NaN for gpt-fast under this novel FP6?

@gau-nernst
Collaborator

@cpuhrsch I just double-checked to make sure I didn't upload the same file by mistake. Indeed, the generated code is identical. It even writes to the same file /tmp/torchinductor_ubuntu/kr/ckrsfl6uftxbnzbgtttmf33opjqsbspylnkf2iaehq5u656uejth.py (I guess this is some kind of hash of the generated code?)

@eellison

eellison commented Jun 6, 2024

@gau-nernst would you mind running with TORCH_LOGS="cudagraphs"? Would you also give the full repro command?

cc @BoyuanFeng

@gau-nernst
Collaborator

@eellison when I run with TORCH_LOGS="cudagraphs", I don't see anything being output. Perhaps I'm missing something.

Anyway, to reproduce, you can

  1. Go to my branch: https://github.com/gau-nernst/gpt-fast/tree/fp6_llm
  2. Comment out these lines: https://github.com/gau-nernst/gpt-fast/blob/d1304ed1032f2eabc909fa0259a8d87750abade7/generate.py#L326-L327
  3. Run with python generate.py --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --prompt "Hello, my name is" --compile --fp6_llm (after running the prepare model step i.e. ./scripts/prepare.sh meta-llama/Llama-2-7b-chat-hf)

@eellison

With pytorch/pytorch#125264 patched, you would get an error here with actionable feedback. cc @isuruf

@Vezora-AI

Vezora-AI commented Jun 15, 2024

I'm getting a bunch of compilation errors in the FP6 C++ code while compiling the wheels. Is there a plan for a new set of pip wheels? The last wheels (for both nightly and stable) were created in late May, before FP6 was added.

Error.txt

I tried to compile with both GCC and MSVC and got the same error on both.

@msaroufim
Member

@Vezora-AI please try pip install torchao torch==2.3.1 --extra-index-url https://download.pytorch.org/whl/test/cu121 --force-reinstall

Otherwise you can try building from source using pip install .

If you're building on Windows you might need the patch from #305. If you can get it working, it'll make me more confident to merge the code, since I don't have Windows CI set up in ao yet.

@Vezora-AI

Awesome, thank you! I'll try the #305 changes. Do I need to use torch 2.3.1, or can I stay on 2.3.0? There is no xformers build currently available for 2.3.1 from what I can see (xformers issue).

@msaroufim
Member

In this case you'll need to be on a specific version of PyTorch, unfortunately. I'm open to ideas for how to make this better, though.

@Vezora-AI

Vezora-AI commented Jun 16, 2024

In this case you'll need to be on a specific version of PyTorch, unfortunately. I'm open to ideas for how to make this better, though.

I was able to compile it! Strangely enough, it failed to compile with the x86 native build tools, but it worked using MSVC as the compiler inside of Git Bash. (I re-cloned the repository, checked out #305, and set the GCX compiler to MSVC, and it compiled.)

To avoid having to downgrade, I am going to build xformers from source and cross my fingers that 2.3.1 doesn't bring any breaking changes.
gitbash.txt

Also, I can no longer find it, but apparently Visual Studio Build Tools 17.10 brings breaking changes to MSVC, while anything under it is fine, so avoid updating, to anyone whom this may concern.
Actually, here is one describing the same issue, but it's not the same issue I came across.

@Vezora-AI

@msaroufim I was easily able to compile xformers, and pip didn't throw any dependency warnings at me about the torch versions. However, with the changes made, it reverted back to the old API (since it went back to the old branch with the #305 changes), so I couldn't use the nn.Linear replacement. I was, however, able to use the 'from torchao.quantization.fp6_llm import convert_fp6_llm' API successfully. Can the changes be merged into the latest branch, if there is no impact on other platforms?

@msaroufim
Member

msaroufim commented Jun 17, 2024

I suggest you merge the changes from #305 into main instead of just checking out #305 - lemme know if that works. We'll need to add Windows CI on our end before merging this.

@Vezora-AI

Vezora-AI commented Jun 17, 2024

@msaroufim gotcha! I got it merged and installed, so pip now shows AO version 3.0 instead of 2.0. I built all the remaining dependencies, but I get 5 errors during test_fp6.py. This is due to Triton and codecache.py not being able to locate a Windows C++ compiler. I changed the location of the Linux Triton path (in _triton.py) to my Windows Triton, which runs on a version of LLVM built for MSVC, and I got a bunch of compilation errors. (My Triton install can run all the Unsloth Triton kernels with no modifications, so I know it can work; I just don't know exactly how to integrate it, since it's not using the standard 'import triton' API.) I also tried to modify codecache.py by directly linking it to MSVC's cl.exe path, but then I got unexpected-keyword errors. I know I'm on the right path (below are the errors I got after I linked codecache to the MSVC compiler). I'm gonna take a look at this again tomorrow.
errormessage1.txt

@gau-nernst
Collaborator

Triton does not support Windows, so any torch.compile() on CUDA will error out.

You can try this to check if the FP6-LLM kernel works for you: https://github.com/pytorch/ao/tree/main/torchao/prototype/fp6_llm

@Vezora-AI

Vezora-AI commented Jun 17, 2024

Triton does not support Windows, so any torch.compile() on CUDA will error out.

You can try this to check if the FP6-LLM kernel works for you: https://github.com/pytorch/ao/tree/main/torchao/prototype/fp6_llm

Gotcha, will do! I do know Triton technically isn't supported on Windows, but it kind of is. (I don't want to deviate too much off-topic, but I wanted to explain why I was trying to fix torch.compile.) I have Triton installed, and it runs and compiles kernels and everything. The wheel was created by wkpark here: https://github.com/wkpark/triton. I haven't looked into the CI file, but I used his wheels and his LLVM library compiled for MSVC, and it runs well enough to compile and run all of Unsloth's kernels. However, I'll stick with using it without torch.compile.

Regardless, some of the errors I got I don't think were related to torch.compile: the terminal output for some of the tests that failed in test/prototype/test_fp6_llm.py read "could not find C++ compiler". Then, when I did link PyTorch's codecache.py to the compiler, I got an unexpected-keyword error from MSVC (which I believe might be as simple to fix as adding some arguments and removing any unused ones). Side note: I did also test the convert_fp6_llm API on Windows, and it does work.

@msaroufim msaroufim reopened this Jun 17, 2024
@gau-nernst
Collaborator

@Vezora-AI the latest main branch has support for Windows builds now. Can you check whether you can install it?

To make sure the kernel is compiled correctly, run the following command from the root of the repo

python benchmarks/benchmark_fp6_llm.py

Expected outputs are shown in #396

To use it for your model, you can look at https://github.com/pytorch/ao/tree/main/torchao/prototype/fp6_llm

@msaroufim
Member

Btw @Vezora-AI, if you're using the fp6 kernels for a real-world use case, we'd love to hear from you - we hang out on discord.gg/cudamode in the #torchao channel.

@gau-nernst
Collaborator

gau-nernst commented Jun 25, 2024

@cpuhrsch @eellison regarding NaN outputs when torch._inductor.config.triton.cudagraph_trees = True, would you happen to know if using the default CUDA stream might be the problem?

fp6_linear_kernel(0, // Using default stream here.

Additionally, what is the recommended CUDA stream to use for a custom CUDA kernel? Is the default 0 OK, or is something like https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html better?

UPDATE: I can confirm that using auto stream = at::cuda::getCurrentCUDAStream(); fixes the NaN outputs problem with torch._inductor.config.triton.cudagraph_trees = True (following tinygemm https://github.com/pytorch/pytorch/blob/bbdeff76fc1205d13358ee6c147a4d8930b562ee/aten/src/ATen/native/cuda/int4mm.cu#L928).

@msaroufim Perhaps we can also update the Custom CUDA extension README about using at::cuda::getCurrentCUDAStream()? Someone can confirm if it is the recommended way.

@cpuhrsch
Contributor

@gau-nernst - Yes, indeed you should get the current CUDA stream in case you're in a different context. getCurrentCUDAStream should be the correct API here to get the current CUDA stream from the global context.
