Intel® Extension for OpenXLA* 0.2.0 Release

@Zantares released this 01 Dec 07:22 · 14 commits to r0.2.0 since this release

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official OpenXLA on Intel GPUs. It is based on the PJRT plugin mechanism, which lets JAX models run seamlessly on the Intel® Data Center GPU Max Series. This release contains the following major features:

  • Upgraded JAX to v0.4.20.

  • Experimental support for JAX native distributed scale-up collectives based on jax.pmap (see the sketch after this list).

  • Continued optimization of common kernels, and optimized GEMM kernels with Intel® Xe Templates for Linear Algebra. Three inference models (Stable Diffusion, GPT-J, FLAN-T5) have been verified on a single Intel® Data Center GPU Max Series device and added to the examples.
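
For reference, a minimal sketch of the pmap-based collectives mentioned above; the axis name and input shapes here are illustrative, not prescribed by this release:

```python
import jax
import jax.numpy as jnp

# Query how many local devices the PJRT plugin exposes; the experimental
# collectives expect one of the supported counts (see Known Caveats below).
ndev = jax.local_device_count()

# All-reduce across devices: psum over the pmap axis "d".
allreduce = jax.pmap(lambda x: jax.lax.psum(x, axis_name="d"), axis_name="d")

x = jnp.arange(ndev, dtype=jnp.float32)  # one scalar shard per device
print(allreduce(x))  # every device's output holds the full sum
```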

Known Caveats

  • The device count for the experimentally supported collectives is restricted to 2/4/6/8/10/12 on a single node.

  • XLA_ENABLE_MULTIPLE_STREAM=1 should be set when using JAX parallelization on multiple devices without collectives. It adds synchronization between devices to avoid possible accuracy issues (see the sketch after this list).

  • MHA=0 should be set to disable MHA fusion in training. MHA fusion is not yet supported in training and will cause a runtime error like the one below:

```
ir_emission_utils.cc:109] Check failed: lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)) == rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))
```
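
A minimal sketch of applying both workarounds from a Python entry point. That the variables must be set before JAX is imported is a conservative assumption; exporting them in the shell before launch works equally well:

```python
import os

# Both variables come from the caveats above; set them before JAX and the
# plugin initialize so the runtime picks them up.
os.environ["XLA_ENABLE_MULTIPLE_STREAM"] = "1"  # multi-device runs without collectives
os.environ["MHA"] = "0"  # disable MHA fusion when training

import jax  # imported after the variables are set so they take effect

print(jax.local_device_count())
```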

Breaking Changes

  • The previous JAX v0.4.13 is no longer supported. Please follow the JAX change log to update your application if you encounter version errors.

  • GCC 10.0.0 or newer is required when building from source. Please refer to the installation guide for more details.

Documents