The benchmark attempts to run SD_PIPELINE_F16_JAX in float16 by calling `to_fp16` on the model parameters: https://github.com/iree-org/iree-comparative-benchmark/blob/main/common_benchmark_suite/openxla/benchmark/models/jax/stable_diffusion/stable_diffusion_pipeline.py#L47-L51
The only thing this achieves is converting the model weights to float16. The model activations start out as float32, both here https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L257 and elsewhere (e.g. when the time embeddings are generated). Whenever a Flax module is executed with float16 weights and float32 activations (or vice versa) and has no explicit compute dtype, it promotes everything to float32: https://github.com/google/flax/blob/main/flax/linen/linear.py#L189
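A minimal standalone repro of this promotion behavior (not taken from the benchmark, just a `flax.linen.Dense` with default `dtype=None`):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Dense with no explicit `dtype`: the compute dtype is inferred by
# promoting the parameter and input dtypes (jnp.result_type).
layer = nn.Dense(features=4)

params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
# Cast only the weights to float16, which is all `to_fp16` does.
params_f16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.float16), params)

x = jnp.ones((1, 4), dtype=jnp.float32)  # float32 activations
y = layer.apply(params_f16, x)
print(y.dtype)  # float32 -- the float16 weights get promoted back to float32
```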
One way to actually run the pipeline in FP16 is to pass a `dtype` argument in the call here (see the sketch below): https://github.com/iree-org/iree-comparative-benchmark/blob/main/common_benchmark_suite/openxla/benchmark/models/jax/stable_diffusion/stable_diffusion_pipeline.py#L38
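A sketch of what that could look like, assuming the pipeline is loaded via `FlaxStableDiffusionPipeline.from_pretrained` (the checkpoint name below is a placeholder, not necessarily what the benchmark uses):

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Passing `dtype` makes the Flax modules use float16 as their compute
# dtype, so activations stay in float16 instead of being promoted.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint
    dtype=jnp.float16,
)
# The weights still need to be cast as well (what `to_fp16` already does).
params = jax.tree_util.tree_map(lambda p: p.astype(jnp.float16), params)
```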