Should I use OpenXLA or Intel Extension for Transformers to load the Whisper JAX model? #36

Open
sleepingcat4 opened this issue Jul 15, 2024 · 7 comments
Labels
question Further information is requested

Comments

@sleepingcat4

I want to load Whisper JAX on an Intel Data Center GPU Max Series device. Should I use the Intel Extension for OpenXLA or the Intel Extension for Transformers? I'm not sure whether OpenXLA supports quantisation by default.

@Zantares

Zantares commented Aug 9, 2024

I suggest starting with this OpenXLA extension because, as far as I know, the Transformers extension has not handled JAX models so far. The OpenXLA extension has basic support for a naive quantization feature based on Keras 3, but it has not been verified in all scenarios.
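For readers unfamiliar with the term, "naive quantization" usually means something like the following: scale a weight tensor so that it fits an 8-bit integer range, round, and keep the scale for dequantizing later. This is a minimal sketch in plain JAX under that assumption. It is not the extension's Keras 3-based code path, and the function names `quantize_int8` and `dequantize` are hypothetical.

```python
import jax
import jax.numpy as jnp


def quantize_int8(w):
    """Symmetric per-tensor quantization: w is approximated by q * scale."""
    # One scale for the whole tensor, chosen so the largest |w| maps to 127.
    scale = jnp.max(jnp.abs(w)) / 127.0
    # Guard against an all-zero tensor (avoids dividing by zero).
    scale = jnp.where(scale == 0, 1.0, scale)
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize(q, scale):
    """Recover a float32 approximation of the original weights."""
    return q.astype(jnp.float32) * scale


# Illustrative data only: a random matrix standing in for a model weight.
key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (64, 64), dtype=jnp.float32)

q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

# Rounding to the nearest integer step bounds the error by half a step.
max_err = float(jnp.max(jnp.abs(w - w_hat)))
print("dtype:", q.dtype, "scale:", float(scale), "max abs error:", max_err)
```

The per-tensor scale keeps the sketch short. Per-channel or per-group scales usually lose less accuracy, at the cost of storing more scale values.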

@Zantares added the "question" (Further information is requested) label on Aug 9, 2024
@sleepingcat4
Author

@Zantares thanks! For now I have opted to use an HPU (Gaudi2). Can OpenXLA compile a JAX model on Gaudi2?

Can you suggest sample code that loads the JAX model with 4-bit quantisation?

@Zantares

> @Zantares thanks! For now I have opted to use an HPU (Gaudi2). Can OpenXLA compile a JAX model on Gaudi2?
>
> Can you suggest sample code that loads the JAX model with 4-bit quantisation?

Support for Gaudi is still under development because it uses a different low-level software stack. We have added a simple FP8 example to our repo: https://github.com/intel/intel-extension-for-openxla/tree/main/example/fp8, but we have not verified INT4 yet. Which INT4 model are you looking for? Maybe we can check it ourselves first.
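To make the INT4 question concrete, here is a rough sketch of what 4-bit weight quantisation can look like numerically in plain JAX: group-wise scales, values rounded to [-7, 7], and two 4-bit values packed into each byte. This is only an illustration of the idea. It is not an API of the OpenXLA extension, it has not been verified on Intel hardware, and the names `quantize_int4`, `dequantize_int4`, and the group size of 32 are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

GROUP = 32  # hypothetical group size; must be even and divide the weight count


def quantize_int4(w, group=GROUP):
    """Group-wise symmetric 4-bit quantization with two values packed per byte."""
    g = w.reshape(-1, group)
    # One scale per group so the largest |w| in the group maps to 7.
    scale = jnp.max(jnp.abs(g), axis=1, keepdims=True) / 7.0
    scale = jnp.where(scale == 0, 1.0, scale)
    q = jnp.clip(jnp.round(g / scale), -7, 7).astype(jnp.int8)
    # Shift to the unsigned range [1, 15] so each value fits in one nibble.
    u = (q + 8).astype(jnp.uint8)
    # Even positions go to the low nibble, odd positions to the high nibble.
    packed = u[:, 0::2] | (u[:, 1::2] << 4)
    return packed, scale


def dequantize_int4(packed, scale, shape):
    """Unpack the nibbles and rescale back to float32."""
    lo = (packed & 0x0F).astype(jnp.int8) - 8
    hi = (packed >> 4).astype(jnp.int8) - 8
    # Interleave low and high nibbles to restore the original element order.
    q = jnp.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)
    return (q.astype(jnp.float32) * scale).reshape(shape)


# Illustrative data only: a random matrix standing in for a model weight.
key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (8, 64), dtype=jnp.float32)

packed, scale = quantize_int4(w)
w_hat = dequantize_int4(packed, scale, w.shape)

max_err = float(jnp.max(jnp.abs(w - w_hat)))
print("packed bytes:", packed.size, "for", w.size, "weights; max abs error:", max_err)
```

Packing only reduces storage. Whether a backend can compute directly on the packed values, rather than dequantizing first, depends on its kernels and is not something this sketch shows.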

@sleepingcat4
Author

@Zantares The Whisper JAX model. I'm actually working with Intel (Intel AI Labs). A former Intel employee on our team suggested that on Gaudi2 the Intel OpenXLA library won't provide an advantage, since it already uses JIT.

What're your thoughts?

@Zantares

I can't offer many suggestions for Gaudi because support isn't ready. But we have verified that the JAX Whisper models (from the Transformers example: https://github.com/huggingface/transformers/tree/main/examples/flax/speech-recognition) run on Intel GPUs (Data Center Max/Flex).

For GPUs, OpenXLA can provide some generic optimizations and make applications run faster. Gaudi, as far as I know, uses a different low-level software stack, so OpenXLA may not provide many advantages there.
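A quick way to see which backend JAX is actually compiling for is to list the visible devices and run a small jitted function. The sketch below uses only standard JAX calls and falls back to CPU when no accelerator plugin is installed. The helper names `describe_backend` and `scaled_softmax` are hypothetical, and what gets listed on an Intel GPU depends on the installed plugin.

```python
import jax
import jax.numpy as jnp


def describe_backend():
    """Report the default backend and the devices JAX can see."""
    return {
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


@jax.jit
def scaled_softmax(x, temperature):
    # jax.jit traces this function once and hands it to XLA, which
    # compiles it for whichever backend is the default.
    return jax.nn.softmax(x / temperature, axis=-1)


info = describe_backend()
print("backend:", info["backend"])
print("devices:", info["devices"])

x = jnp.arange(12.0).reshape(3, 4)
probs = scaled_softmax(x, 2.0)
row_sums = probs.sum(axis=-1)
print("row sums:", row_sums)
```

The same script runs unchanged on any backend, so it helps confirm that a plugin was picked up before trying a full model.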

@sleepingcat4
Author

@Zantares that's a wonderful insight. If I'm not mistaken, OpenXLA can provide an edge on Intel GPUs, right? And the Gaudi stack needed to take advantage of it isn't done yet.

I have a feature request: how about integrating the OpenXLA library into the Intel_extension_for_transformers library? Intel Extension for Transformers aims to standardise loading and inference of HF models on Intel hardware. I have already run HF models in 4-bit using Intel Extension for Transformers, and it was quite a breeze.

Could you integrate the OpenXLA library components so that we can run JAX models on Intel GPUs, and maybe Gaudi too, using Intel Extension for Transformers (for inference only)?

Because I think loading another Intel library just for generic optimizations is overkill.

@Zantares

This is more of an Intel_extension_for_transformers request than an OpenXLA one. Since Intel_extension_for_transformers is an independent third-party module, there isn't much work on the OpenXLA side. You could raise it with the Intel_extension_for_transformers community if you have a strong need for JAX models. I can't decide this myself: the current main direction is PyTorch, so Intel_extension_for_transformers will serve PyTorch first.

It's the same story with Gaudi: PyTorch support is the first priority, then JAX/OpenXLA. Even on the PyTorch side, Gaudi support is still WIP. That's what I meant in my previous comments: Gaudi is a totally different thing and can't leverage the current work.
