SD_PIPELINE_FP16_JAX benchmark executes in FP32 #176

Open

ekuznetsov139 opened this issue Apr 20, 2024 · 0 comments
The benchmark attempts to convert SD_PIPELINE_FP16_JAX to FP16 by calling `to_fp16` on the model parameters:
https://github.com/iree-org/iree-comparative-benchmark/blob/main/common_benchmark_suite/openxla/benchmark/models/jax/stable_diffusion/stable_diffusion_pipeline.py#L47-L51
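
For reference, a minimal sketch of that pattern (the model id is illustrative, and the `tree_map` cast shown here is equivalent to what diffusers' `to_fp16` does to the parameter pytree):

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Load the pipeline; with no explicit dtype, the modules compute in float32.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"  # illustrative model id
)

# This casts only the weight pytree to float16; the modules themselves
# still compute in float32, as described below.
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), params)
```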

The only thing this achieves is converting the model weights to float16. The model activations start as float32 (https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L257, and elsewhere, e.g. when time embeddings are generated). Whenever a Flax module is executed with float16 weights and float32 activations (or vice versa), it promotes everything to float32 unless it has an explicit compute dtype: https://github.com/google/flax/blob/main/flax/linen/linear.py#L189
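
A minimal, self-contained repro of this promotion behavior with a plain `nn.Dense` (the layer and shapes here are illustrative, not from the benchmark):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# No explicit dtype: the result dtype follows standard promotion rules.
model = nn.Dense(4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Cast the weights to float16, but keep the input in float32.
fp16_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), params)
out = model.apply(fp16_params, jnp.ones((1, 8), dtype=jnp.float32))

print(out.dtype)  # float32: the fp16 kernel is promoted, not the input demoted
```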

One way to actually run the pipeline in FP16 is to pass a `dtype` in the call here: https://github.com/iree-org/iree-comparative-benchmark/blob/main/common_benchmark_suite/openxla/benchmark/models/jax/stable_diffusion/stable_diffusion_pipeline.py#L38
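
A sketch of that fix, assuming the diffusers Flax API (`from_pretrained` accepts a `dtype` kwarg that sets the modules' compute dtype; the model id is illustrative):

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Passing dtype here makes the Flax submodules compute in float16,
# so fp16 weights no longer get promoted back to float32.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative model id
    dtype=jnp.float16,
)

# The weights still need to be cast separately; dtype only sets the
# compute dtype of the modules, it does not cast the params.
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), params)
```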
