Should I use OpenXLA or Intel Extension for Transformers to load the Whisper JAX model? #36
Comments
I suggest starting with the OpenXLA extension, because the Transformers extension has not handled JAX models before. The OpenXLA extension has basic support for naive quantization based on Keras 3, but it has not been verified in all scenarios.
@Zantares thanks! At the moment I have opted to use HPU (Gaudi2). Can OpenXLA compile JAX models on Gaudi2? Can you suggest sample code that loads a JAX model with 4-bit quantization?
Support for Gaudi is still under development because it uses a different low-level software stack... We have added a simple FP8 example in our repo: https://github.com/intel/intel-extension-for-openxla/tree/main/example/fp8, but we haven't verified INT4 yet. Which INT4 model are you looking for? Maybe we can check it ourselves first.
@Zantares The Whisper-JAX model. I'm actually working with Intel (Intel AI Labs), and a former Intel employee on our team suggested that on Intel Gaudi2 the Intel OpenXLA library won't provide an advantage, since JAX already uses JIT. What are your thoughts?
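For context, "already uses JIT" refers to `jax.jit`, which traces a function and lowers it through XLA for whatever backend plugin is installed; a minimal sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit  # traced once, compiled by XLA for the active backend
def mlp(w, x):
    return jnp.tanh(x @ w)

x = jnp.ones((2, 4))
w = jnp.ones((4, 3))

print(mlp(w, x).shape)   # (2, 3)
print(jax.devices())     # whichever devices the installed backend exposes
```

The compilation path is the same on every backend; the debate in this thread is whether a vendor PJRT plugin (like intel-extension-for-openxla) adds device-specific optimizations on top of that, which it does for Intel GPUs but not yet for Gaudi.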
I can't offer many suggestions on Gaudi because support isn't ready... But we have verified that JAX Whisper models (from the Transformers Flax example: https://github.com/huggingface/transformers/tree/main/examples/flax/speech-recognition) run on Intel GPUs (Data Center Max/Flex). For GPUs, OpenXLA can provide some generic optimizations and make applications run faster. For Gaudi, as far as I know, it uses a different low-level software stack and may not provide many advantages.
@Zantares That's a wonderful insight. If I'm not mistaken, OpenXLA can provide an edge for Intel GPUs, right? And the Gaudi stack needed to take advantage of it isn't done yet. I have a feature request: could you integrate OpenXLA library components into Intel Extension for Transformers, so that we can run JAX models on Intel GPUs and maybe Gaudi too (for inference only)? Loading a separate Intel library just for generic optimizations feels like overkill.
This is more of a feature request for Intel Extension for Transformers. Same story as Gaudi: supporting PyTorch is the first priority, then JAX/OpenXLA. You can see that even on the PyTorch side, Gaudi support is still WIP; that's what I mentioned in the previous comments: Gaudi is a totally different thing and can't leverage the current work.
I wanted to load Whisper JAX on an Intel Data Center GPU Max Series card. Should I use the Intel OpenXLA extension or Intel Extension for Transformers? I'm not sure whether OpenXLA supports quantization by default.