Faster Inference & Training Roadmap #226
@jeromeku I will get to reviewing GPTQ - sorry on the delay!!
Thanks -- would be helpful to have a step-by-step breakdown of where the memory savings are coming from, i.e., an ablation study. Is there interest in faster inference kernels, or is the focus primarily on the training side?
@jeromeku For Mistral itself: https://unsloth.ai/blog/mistral-benchmark Gemma's VRAM reduction should be similar to our breakdown for Mistral. For inference for Gemma - I did make it 2x faster, but it's mainly cobbling together ideas from vLLM and other packages, so I only spent 1 week on it :) The goal is to merge GPT Fast and other ideas like EAGLE to make inference faster :)
I'd be interested in contributing on the inference front -- let's create a priority list of ideas for implementation?
@jeromeku That'll be cool!! :) We can collab either via Github or async on our Discord - whatever suits you :)
Looking forward to it! What's top of mind currently? Perhaps we can draw up a roadmap (if one doesn't already exist).
@jeromeku Oh ye a roadmap would be nice - don't actually have one for inference specifically :)
You mentioned integrating ideas from GPT Fast and EAGLE - what others did you have in mind? What's on the roadmap for fine-tuning / training -- architectures, algorithms, etc.? Asking so I know what literature / code to review.
@jeromeku In terms of inference specifically:
I might have more, but those are off the top of my head. For training / finetuning:
Just a brain dump!
Love it. Inference:
Training:
Let me know what I should prioritize. Also, can you expand more on Triton GEMV? What kind of horizontal / vertical fusions to target? |
Oh so GEMV is generally OK I guess - the issue is the dequant step merged in (ie what you were doing with GPTQ, except it's not matrix-matrix mult but matrix-vector mult). This allows different optimizations - ie is blocked mm better, or is column-wise or row-wise mv better? It depends on the cache footprint. But the goal is: can we somehow merge
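To make the fused dequant + GEMV idea concrete, here is a CPU reference (NumPy, purely illustrative shapes and names) of what such a kernel computes: int8 weights with per-row scales are dequantized one column block at a time and immediately consumed by the dot product, so the full fp32 weight matrix is never materialized. In a real Triton kernel each block would live in registers/shared memory; the blocking here just mirrors that access pattern.

```python
import numpy as np

def fused_dequant_gemv(w_int8, scales, x, block=64):
    """Computes y = (w_int8 * scales[:, None]) @ x without ever building
    the full dequantized matrix -- dequantize a tile, use it, discard it."""
    n_rows, n_cols = w_int8.shape
    y = np.zeros(n_rows, dtype=np.float32)
    for c0 in range(0, n_cols, block):
        c1 = min(c0 + block, n_cols)
        # dequantize only this column block (per-row scales)
        w_blk = w_int8[:, c0:c1].astype(np.float32) * scales[:, None]
        y += w_blk @ x[c0:c1]
    return y

rng = np.random.default_rng(0)
w = rng.integers(-127, 128, size=(128, 256), dtype=np.int8)
s = rng.random(128, dtype=np.float32) * 0.01
x = rng.random(256, dtype=np.float32)

y_fused = fused_dequant_gemv(w, s, x)
y_ref = (w.astype(np.float32) * s[:, None]) @ x   # unfused: full dequant first
print(np.allclose(y_fused, y_ref, atol=1e-4))
```

The fused version touches each quantized weight exactly once and keeps the memory footprint at one tile, which is the point of merging dequant into the GEMV rather than running it as a separate pass.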
If I understand correctly:
Can you point me to the current GEMV implementation? Need a minimal implementation / testbed for benchmarking purposes. |
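A throwaway CPU testbed for this might look like the sketch below (illustrative sizes; on GPU one would time with `triton.testing.do_bench` or CUDA events instead, but the shape is the same: fixed inputs, checked output, best-of-repeats timing):

```python
import timeit
import numpy as np

def bench(fn, *args, repeats=5, number=10):
    """Best-of-repeats wall-clock timing, seconds per call."""
    times = timeit.repeat(lambda: fn(*args), repeat=repeats, number=number)
    return min(times) / number

def gemv_unfused(w_int8, scales, x):
    """Baseline: fully dequantize, then matvec."""
    return (w_int8.astype(np.float32) * scales[:, None]) @ x

rng = np.random.default_rng(1)
w = rng.integers(-127, 128, size=(256, 512), dtype=np.int8)
s = rng.random(256, dtype=np.float32)
x = rng.random(512, dtype=np.float32)

t = bench(gemv_unfused, w, s, x)
y = gemv_unfused(w, s, x)
print(f"unfused dequant+GEMV: {t * 1e6:.1f} us/call, y.shape={y.shape}")
```

Candidate fused variants can then be dropped in next to `gemv_unfused` and compared against it for both correctness and time.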
@danielhanchen
@danielhanchen Obligatory request for Multi GPU XD |
@jeromeku Extremely sorry on the delay - yep sounds right! :) @nivibilla Yep! |
Is the issue with the existing
@jeromeku Yes, that can be one of the main issues - the other is folding it inside other kernels, ie say 1 singular kernel can become too complex to do. The main issue I still see with 1 kernel (so maybe I'm overthinking) is that every new op requires synchronization, so maybe we should rather rely on
I'd imagine there is an optimization spectrum:
Will make a quick pass at implementing. Cutlass also enables some flexibility with bespoke
@jeromeku Oh ye let's try to be device agnostic :)) compile is OK, but I guess handwriting is best :) We can then use CUDAGraphs manually
A few updates:
@jeromeku Fantastic work as always!! very very cool on fusing Adam and Galore!! Love this! Oh on Mixtral - https://github.com/shawntan/scattermoe/tree/main/scattermoe :) Was reading up on this as well :) On BnB dequant - I'll have a look first at it :) But you're more than happy to do it if you want :) |
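For readers following the "fusing Adam and GaLore" thread: the sketch below is a CPU reference (NumPy) of what that combination computes. The gradient is projected into a low-rank space, Adam's moment updates run on the small `r x n` state, and the step is projected back; a fused kernel would do all of this in one pass without materializing intermediates. The SVD-based projection and all hyperparameters here are illustrative, not unsloth's actual implementation.

```python
import numpy as np

def galore_adam_step(w, grad, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One GaLore-style Adam step: Adam state lives in the projected
    (r, n) space, so optimizer memory is r/m of full Adam."""
    P = state["P"]                      # (m, r) projection, refreshed rarely
    r_grad = P.T @ grad                 # project gradient: (r, n)
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * r_grad
    state["v"] = b2 * state["v"] + (1 - b2) * r_grad ** 2
    m_hat = state["m"] / (1 - b1 ** state["t"])      # bias correction
    v_hat = state["v"] / (1 - b2 ** state["t"])
    update = P @ (m_hat / (np.sqrt(v_hat) + eps))    # back to (m, n)
    return w - lr * update

rng = np.random.default_rng(2)
m, n, r = 64, 32, 4
w = rng.standard_normal((m, n)).astype(np.float32)
g = rng.standard_normal((m, n)).astype(np.float32)
U, _, _ = np.linalg.svd(g, full_matrices=False)      # top-r left singular vecs
state = {"P": U[:, :r], "m": np.zeros((r, n)), "v": np.zeros((r, n)), "t": 0}

w_new = galore_adam_step(w, g, state)
print(w_new.shape)  # weight stays (64, 32); Adam state is only (4, 32)
```

Fusing this means the projection, moment update, and unprojection become a single kernel launch per layer instead of several, which is where the memory and launch-overhead savings come from.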
Really excited about optimized kernels for inference! Worth looking at https://github.com/zeux/calm, where the forward pass is implemented as a single CUDA kernel. It uses fp8 rather than int4/8 quantization.
@danielhanchen
In the unsloth Gemma intro blogpost, you mention a VRAM increase due to the larger `MLP` size in `Gemma` compared to `Llama` and `Mistral`, and show a graph demonstrating decreased memory usage when running `unsloth` vs. `HF` and `FA2`:
- The `HF` vs `FA2` vs `unsloth` graph? Is it inference or training?
- Curious what optimizations are leading to the memory decrease -- quantization, autograd efficiency, etc.