
Faster Inference & Training Roadmap #226

Open
jeromeku opened this issue Mar 7, 2024 · 27 comments
Labels
Discussion Questions or discussions

Comments

@jeromeku

jeromeku commented Mar 7, 2024

@danielhanchen

In the unsloth Gemma intro blogpost, you mention VRAM increase due to larger MLP size in Gemma compared to Llama and Mistral, and show a graph demonstrating decreased memory usage when running unsloth vs. HF and FA2:

  • How does unsloth reduce memory usage?
  • What are the model and runtime configs used to generate the HF vs FA2 vs unsloth graph? Is it inference or training?

Curious what optimizations are leading to memory decrease -- quantization, autograd efficiency, etc.

@danielhanchen
Contributor

@jeromeku I will get to reviewing GPTQ - sorry for the delay!!

  • The VRAM reductions are from Unsloth's optimizations :) ie Triton kernels, eliminating memory copies, FA2, autograd tricks, etc.
  • Oh, training! All runs use 4-bit quantization, gradient accumulation is set to 1, and the batch size and sequence length are varied dynamically so I can measure VRAM usage.

danielhanchen added the Discussion label on Mar 8, 2024
@jeromeku
Author

jeromeku commented Mar 8, 2024

@danielhanchen

Thanks -- would be helpful to have a step-by-step breakdown of where the memory savings are coming from, i.e., an ablation study.

Is there interest in faster inference kernels, or is the focus primarily on the training side?

@danielhanchen
Contributor

@jeromeku For Mistral itself: https://unsloth.ai/blog/mistral-benchmark
[image: VRAM reduction breakdown from the Mistral benchmark post]

Gemma's VRAM reduction should be similar to our breakdown for Mistral.

For Gemma inference - I did make it 2x faster, but it's mainly cobbled together from ideas in vLLM and other packages, and I only spent 1 week on it :) The goal is to merge in GPT Fast and other ideas like EAGLE to make inference even faster :)

@jeromeku
Author

jeromeku commented Mar 8, 2024

@danielhanchen

I'd be interested in contributing on the inference front -- shall we create a priority list of ideas to implement?

@danielhanchen
Contributor

@jeromeku That'll be cool!! :) We can collab either via GitHub or async on our Discord - whatever suits you :)

@jeromeku
Author

jeromeku commented Mar 9, 2024

@danielhanchen

Looking forward to it!

What's top of mind currently? Perhaps we can draw up a roadmap (if one doesn't already exist).

@danielhanchen
Contributor

@jeromeku Oh ye, a roadmap would be nice - we don't actually have one for inference specifically :)

@jeromeku
Author

jeromeku commented Mar 9, 2024

@danielhanchen

You mentioned integrating ideas from GPT Fast and EAGLE - what others did you have in mind?

What's on the roadmap for fine-tuning / training -- architectures, algorithms, etc.? Asking so I know what literature / code to review.

@danielhanchen
Contributor

@jeromeku In terms of inference specifically:

  1. GPT Fast
  2. Speculative Decoding (use a small model to draft tokens, then run the large model in 1 forward pass and check whether the argmax of its logits matches the drafted tokens - see the rough sketch after this list)
  3. EAGLE (Speculative Decoding but only Word2Vec style, ie lm_head -> embeddings)
  4. All quant methods - HQQ, AWQ, Exllama, etc.
  5. vLLM's Paged Attention
  6. Full single-Triton-kernel fusion - ie can we write the whole forward pass in 1 humongous Triton kernel? Very hard, since there are synchronizations which have to be done
  7. Using float8 like FireAttention. cuDNN has a float8 flash attention as well, I think.
  8. Rewriting matrix-vector multiplication in Triton exactly (like what you were trying to do with GPTQ, but matvec rather than matmul)
  9. Torch export

I might have more, but those are from the top of my head.
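
For item 2, here is a rough greedy-acceptance sketch (assuming batch size 1 and HF-style models that return `.logits`; `draft_model`, `target_model`, and `k` are illustrative, not Unsloth code):

```python
import torch

@torch.no_grad()
def speculative_step(draft_model, target_model, input_ids, k=4):
    """Greedy speculative decoding step (sketch, batch size 1).

    The small draft model proposes k tokens autoregressively; the large
    target model then scores the whole proposed sequence in one forward
    pass, and we keep draft tokens only while their argmax matches the
    target model's argmax."""
    # 1) Draft k tokens with the small model (greedy).
    draft_ids = input_ids
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1, :]
        next_tok = logits.argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_tok], dim=-1)

    # 2) One forward pass of the large model over prompt + drafted tokens.
    target_logits = target_model(draft_ids).logits

    # 3) Accept drafted tokens while the target's argmax agrees with them.
    n_prompt = input_ids.shape[1]
    accepted = input_ids
    for i in range(k):
        # The target's prediction for position n_prompt + i comes from the
        # logits at position n_prompt + i - 1.
        target_tok = target_logits[:, n_prompt + i - 1, :].argmax(dim=-1, keepdim=True)
        draft_tok = draft_ids[:, n_prompt + i].unsqueeze(-1)
        if not torch.equal(target_tok, draft_tok):
            # Mismatch: take the target's token instead and stop accepting.
            return torch.cat([accepted, target_tok], dim=-1)
        accepted = torch.cat([accepted, draft_tok], dim=-1)
    return accepted
```

Real implementations also handle sampling (rejection sampling) and KV-cache reuse, which this sketch skips.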

For training / finetuning:

  1. Fast MoE matmul kernel - DeepseekMoE support with fused MoE kernel vllm-project/vllm#2453, but for training - much more complex than inference at batch size 1. Mixtral selects the top 2 experts, which can easily be done in Triton. However, when bsz > 1, we have issues: one has to do dynamic compressed packing, then call torch.bmm. The backward pass is even more problematic, since it requires a reversed packing, then calling torch.bmm, then deconstructing it. A nightmare. (A naive reference for the packing idea is sketched after this list.)
  2. GaLore - extremely fascinating: projecting gradients to a small (rank, rank) matrix, then using SVD to update the projectors. It's not just GaLore that I was fascinated by, but rather LOMO, which does gradient updates dynamically - this can save 20GB of VRAM during pretraining.
  3. 1.58bit - I recently wrote on HN about how 1.58bit lets one skip multiplications, since multiplying by (-1, 0, 1) becomes a simple sign flip, and then the mantissas are added after the exponents are aligned. Using 8-bit floats, 1.58bit uses 2x less space than float8, which makes it possible to cram in 2x the transistors. Writing it in Triton can be more complex.

Just a brain dump!
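
For the packing idea in item 1, here is a naive PyTorch reference (names like `moe_forward_reference` are illustrative; this is the slow baseline a fused Triton kernel would replace):

```python
import torch

def moe_forward_reference(x, expert_weights, topk_idx, topk_weight):
    """Naive top-k MoE forward: pack tokens by expert, then dense matmuls.

    x:              (num_tokens, d_model)   flattened tokens (bsz * seq_len rows)
    expert_weights: (num_experts, d_model, d_ff)
    topk_idx:       (num_tokens, k)         chosen expert ids per token
    topk_weight:    (num_tokens, k)         routing weights per token"""
    num_tokens, k = topk_idx.shape
    num_experts, _, d_ff = expert_weights.shape
    out = x.new_zeros(num_tokens, d_ff)

    # Sort (token, slot) pairs by expert id so each expert owns a
    # contiguous slab of rows -- the "dynamic packing" step.
    flat_expert = topk_idx.reshape(-1)
    flat_token = torch.arange(num_tokens, device=x.device).repeat_interleave(k)
    flat_weight = topk_weight.reshape(-1)
    order = torch.argsort(flat_expert)
    flat_expert, flat_token, flat_weight = flat_expert[order], flat_token[order], flat_weight[order]

    counts = torch.bincount(flat_expert, minlength=num_experts)
    starts = torch.cumsum(counts, 0) - counts
    for e in range(num_experts):
        s, n = starts[e].item(), counts[e].item()
        if n == 0:
            continue
        rows = flat_token[s:s + n]
        w = flat_weight[s:s + n].unsqueeze(-1)
        # Dense matmul on just the tokens routed to expert e, scattered back.
        out.index_add_(0, rows, (x[rows] @ expert_weights[e]) * w)
    return out
```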

@jeromeku
Author

jeromeku commented Mar 9, 2024

@danielhanchen

Love it.

Inference:

  • vLLM paged attention - happy to look into this as well as KV cache quantization.
  • Triton GEMV - seems pretty straightforward -- I can prototype such an implementation in Triton -- I believe torch.compile can already generate such a kernel with the proper Inductor settings (it effectively decomposes to a vectorized mul + add). We can also adapt existing GEMV CUDA kernels for quantized weights.
  • Torch export - can look into it. I've done some work into decoupling Triton kernels from Triton runtime.

Training:

  • I've been playing around with implementing a Cutlass kernel for MoE matmul which could help with larger batch sizes.
  • Galore - top of my list of papers to read
  • 1.58 bit - also been looking into Cutlass for optimizing custom quantized ops.

danielhanchen changed the title from "Gemma VRAM Reduction" to "Faster Inference & Training Roadmap" on Mar 10, 2024
@danielhanchen
Contributor

  • Oh ye KV cache quant is cool! One issue I have with it is that dynamically quantizing the KV cache will cause overhead issues - a super fast quantization method will have to be deployed (see the sketch after this list).
  • Triton GEMV: Ye the kernel is fine to create - one possibility is whether we can fold the GEMVs, layernorms, and everything else into 1 large kernel
  • CUTLASS is good :) Looking forward to it - my main view is we need to use as much Triton as possible for device agnostic purposes :)
  • Ye Galore and 1.58bit :) 1.58bit actually can be very doable in Triton. Galore is very cool.
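
For the KV cache point, a minimal per-token absmax int8 sketch (illustrative only, not any library's API) - the extra quantize/dequantize launches on every decode step are exactly the overhead concern:

```python
import torch

def quantize_kv(kv: torch.Tensor):
    """Per-token absmax int8 quantization of a KV cache slice.

    kv: (batch, heads, seq_len, head_dim) in fp16/bf16.
    Returns int8 values plus scales; dequantization is a single multiply,
    but both steps add kernel launches on the decode path."""
    scale = kv.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.round(kv / scale).to(torch.int8)
    return q, scale.to(kv.dtype)

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor, dtype=torch.float16):
    return q.to(dtype) * scale
```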

@jeromeku
Author

Let me know what I should prioritize.

Also, can you expand more on the Triton GEMV? What kinds of horizontal / vertical fusions should we target?

@danielhanchen
Contributor

Oh so the GEMV itself is generally OK I guess - the issue is the dequant step merged in (ie what you were doing with GPTQ, except it's not matrix-matrix mult but matrix-vector mult). This allows different optimizations - ie is a blocked mm better, or is column-wise or row-wise mv better? It depends on the cache footprint.

But the goal is: can we somehow merge X @ Wq, X @ Wk, X @ Wv together with RoPE and attention and everything into 1 large kernel?
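
A tiny first step of that merge, before any custom kernel, is folding the three projections into one matmul (a sketch; the weight names are illustrative, and `W_qkv` would be pre-concatenated once in practice rather than per call):

```python
import torch

def fused_qkv(x, Wq, Wk, Wv):
    """One matmul instead of three by concatenating the projection weights.

    x:  (1, d_model) single-token activation during decoding
    Wq, Wk, Wv: (d_model, out_dim) projection weights (out dims may differ for GQA)
    A custom kernel would go further and fold RoPE / attention in as well."""
    W_qkv = torch.cat([Wq, Wk, Wv], dim=1)   # (d_model, sum of output dims)
    qkv = x @ W_qkv                          # one GEMV over the fused weight
    return qkv.split([Wq.shape[1], Wk.shape[1], Wv.shape[1]], dim=-1)
```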

@jeromeku
Author

jeromeku commented Mar 10, 2024

If I understand correctly:

  • Separate the dequant step from matmul
  • Fuse as much of the forward pass into a single kernel for Llama, Mistral, and Gemma architectures

Can you point me to the current GEMV implementation? Need a minimal implementation / testbed for benchmarking purposes.

@danielhanchen
Contributor

Oh for inference, your method of fusing the dequant step inside the kernel is actually ideal! For training it's not, since cuBLAS is relatively smart about data movements.

An ideal GEMV kernel (ie vector * matrix) is normally done via a column-wise reduction:
[image: column-wise GEMV reduction]

However, a more optimal procedure is to split the reduction into, say, 4 blocks using atomic_add. It can in fact be, say, reduction columns of 4 but blocks of 24, cycling through them with the modulus function.
[image: split reduction with atomic_add]

A final reduction will need to be made at the end.
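
Here's a rough Triton sketch of that split reduction with atomic_add (block sizes are illustrative and the dequant step is not folded in yet - this is not Unsloth's actual kernel):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def gemv_splitk_kernel(W_ptr, x_ptr, y_ptr, M, N,
                       stride_wm, stride_wn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Grid is (row blocks, column blocks): each program computes a partial
    # dot product over one column chunk and atomically adds it into the
    # output rows -- the split reduction described above.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_m = rm < M
    mask_n = rn < N

    # Load a (BLOCK_M, BLOCK_N) tile of W and the matching slice of x.
    w = tl.load(W_ptr + rm[:, None] * stride_wm + rn[None, :] * stride_wn,
                mask=mask_m[:, None] & mask_n[None, :], other=0.0)
    xv = tl.load(x_ptr + rn, mask=mask_n, other=0.0)

    partial = tl.sum(w.to(tl.float32) * xv.to(tl.float32)[None, :], axis=1)
    # The final reduction across column blocks happens via atomic adds.
    tl.atomic_add(y_ptr + rm, partial, mask=mask_m)


def gemv_splitk(W: torch.Tensor, x: torch.Tensor, BLOCK_M=64, BLOCK_N=256):
    M, N = W.shape
    # Accumulate in fp32 since fp16 atomics are limited on some hardware.
    y = torch.zeros(M, device=W.device, dtype=torch.float32)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gemv_splitk_kernel[grid](W, x, y, M, N, W.stride(0), W.stride(1),
                             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return y.to(W.dtype)
```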

The current GEMV implementation will probably be the one in GPT Fast, although I haven't inspected it myself yet.

The hardest part is folding in the bitsandbytes int4 format, since the blocksize is lopsided (ie not a whole integer multiple), which is a nightmare for cache optimality.

@danielhanchen
Contributor

Another approach people use is row-wise:
[image: row-wise GEMV reduction]

which again can be done in parallel with a reduction as I described above.

@jeromeku
Author

jeromeku commented Mar 10, 2024

@danielhanchen
Ok - so I'm clear on objectives:

  • a reasonable first pass is a Triton kernel that fuses bitsandbytes 4-bit dequant with an efficient GEMV (a plain-PyTorch dequant reference is sketched below)
  • further iterations would then fold in additional prologue / epilogue ops such as positional encodings, activations, etc.
  • the ultimate goal would be fusing as much of the forward pass as possible into a single kernel.
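
As a correctness baseline for the first bullet, here's a plain-PyTorch dequant reference for a generic blockwise 4-bit scheme (note: this is not bitsandbytes' actual NF4 layout, which maps codes through a lookup table; the nibble order and scale layout here are assumptions):

```python
import torch

def dequant_4bit_reference(packed: torch.Tensor, scales: torch.Tensor,
                           blocksize: int = 64) -> torch.Tensor:
    """Reference dequant for a generic blockwise signed int4 scheme.

    packed: (n // 2,) uint8, two 4-bit values per byte (high nibble first - assumed)
    scales: (n // blocksize,) per-block absmax scales, n divisible by blocksize
    NOTE: a stand-in for checking a fused Triton kernel, not bnb's NF4 codebook."""
    lo = (packed & 0x0F).to(torch.int8)
    hi = (packed >> 4).to(torch.int8)
    # Interleave back to the original order and sign-extend from 4 bits.
    vals = torch.stack([hi, lo], dim=1).reshape(-1)
    vals = torch.where(vals >= 8, vals - 16, vals).to(torch.float16)
    # Apply one scale per `blocksize` consecutive values.
    return vals * scales.repeat_interleave(blocksize).to(torch.float16)
```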

@nivibilla

nivibilla commented Mar 10, 2024

For training / finetuning:

@danielhanchen Obligatory request for Multi GPU XD

@danielhanchen
Contributor

@jeromeku Extremely sorry for the delay - yep, sounds right! :) @nivibilla Yep!

@jeromeku
Author

@danielhanchen

Is the issue with the existing bitsandbytes gemv the fact that it's CUDA only?

@danielhanchen
Contributor

@jeromeku Yes, that can be one of the main issues - the other is folding it inside other kernels, ie 1 singular kernel can become too complex to write.

The main issue I still see with 1 kernel (maybe I'm overthinking it) is that every new op requires synchronization, so maybe we should instead rely on torch.compile with CUDA Graphs to reduce the CPU overhead in between.
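
A minimal sketch of that route (the toy module below stands in for the decoder) - `mode="reduce-overhead"` asks Inductor to capture CUDA Graphs around the compiled region, which removes most per-op CPU launch latency:

```python
import torch

# Toy stand-in for a decoder block, just to make the example runnable.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU()).cuda().half()
x = torch.randn(1, 1024, device="cuda", dtype=torch.float16)

compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)

with torch.no_grad():
    out = compiled(x)          # first call compiles and captures the graph
    for _ in range(10):
        out = compiled(x)      # later calls replay the captured CUDA Graph
```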

@jeromeku
Author

I'd imagine there is an optimization spectrum:

  • torch.compile entire graph with appropriate inductor settings to maximize fusion / reduce overhead
  • manually fuse kernels and use torch cudagraph APIs to glue things together

I'll make a quick pass at implementing the bnb dequant GEMV in Triton to see how the performance compares.

CUTLASS also enables some flexibility with bespoke GEMMs and fusions, but is again CUDA-only. Let me know if this is of interest.

@danielhanchen
Contributor

@jeromeku Oh ye, let's try to be device agnostic :)) compile is OK, but I guess handwriting kernels is best :) We can then use CUDA Graphs manually

@jeromeku
Author

@danielhanchen

A few updates:

  • GaLore -- ran some initial experiments to fuse the GaLore Adam update step -- see the PR (a rough sketch of the projected update is below)
  • Is a Triton 4-bit bnb dequant kernel of interest?
  • Going to start working on implementing a fused backward pass for Mixtral.
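
For the GaLore item, here's a rough sketch of the projected Adam step as I understand it (simplified: left projection only, optimizer state reset on each projector refresh; this is not the fused Triton version in the PR):

```python
import torch

class GaLoreAdamSketch:
    """Minimal GaLore-style Adam for one 2D weight (sketch only).

    The gradient G (m x n) is projected to (r x n) with P^T @ G, Adam state
    lives in the small space, and the update is projected back with P."""
    def __init__(self, rank=128, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
                 update_proj_every=200):
        self.rank, self.lr, self.betas, self.eps = rank, lr, betas, eps
        self.update_proj_every = update_proj_every
        self.P = None
        self.m = self.v = None
        self.step_count = 0

    @torch.no_grad()
    def step(self, W: torch.Tensor, G: torch.Tensor):
        self.step_count += 1
        if self.P is None or self.step_count % self.update_proj_every == 1:
            # Refresh the projector from the gradient's top singular vectors.
            U, _, _ = torch.linalg.svd(G.float(), full_matrices=False)
            self.P = U[:, : self.rank].to(G.dtype)                      # (m, r)
            # Resetting state on refresh is a simplification of the paper.
            self.m = torch.zeros(self.rank, G.shape[1], device=G.device, dtype=torch.float32)
            self.v = torch.zeros_like(self.m)

        g = (self.P.t() @ G).float()                                    # (r, n)
        b1, b2 = self.betas
        self.m.mul_(b1).add_(g, alpha=1 - b1)
        self.v.mul_(b2).addcmul_(g, g, value=1 - b2)
        m_hat = self.m / (1 - b1 ** self.step_count)
        v_hat = self.v / (1 - b2 ** self.step_count)
        # Project the low-rank Adam update back to the full weight shape.
        update = self.P @ (m_hat / (v_hat.sqrt() + self.eps)).to(G.dtype)
        W.add_(update, alpha=-self.lr)
```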

@danielhanchen
Contributor

@jeromeku Fantastic work as always!! very very cool on fusing Adam and Galore!! Love this!

Oh on Mixtral - https://github.com/shawntan/scattermoe/tree/main/scattermoe :) Was reading up on this as well :)

On BnB dequant - I'll have a first look at it :) But you're more than welcome to do it if you want :)

@jeromeku
Author

@danielhanchen

  • Are you planning on integrating GaLore into Unsloth? I'm planning on working on an Adam8bit version.
  • I'll take a quick crack at bnb dequant.

@pHaeusler

Really excited about optimized kernels for inference!

Worth looking at https://github.com/zeux/calm - where the forward pass is implemented as a single CUDA kernel.

It uses fp8 rather than int4/int8 quantization.
