
Releases: intel/intel-extension-for-openxla

Intel® Extension for OpenXLA* 0.4.0 Release

12 Aug 07:21

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official OpenXLA on Intel GPUs. It is based on the PJRT plugin mechanism, which allows JAX models to run seamlessly on Intel® Data Center GPU Max Series and Intel® Data Center GPU Flex Series. This release contains the following major features:

  • JAX Upgrade: Upgrade to v0.4.26 and support jax/jaxlib compatibility, which allows the Extension to work with multiple jax versions (see the version-check sketch after this list). Please refer to <How are jax and jaxlib versioned?> for more details on how jax and jaxlib versions relate.
    intel-extension-for-openxla   jaxlib   jax
    0.4.0                         0.4.26   >= 0.4.26, <= 0.4.27
  • Feature Support:
    • Support Float8 training and inference based on Keras 3.0. A new FP8 example has been added.
    • Continued improvement of JAX native distributed scale-up collectives. A new distributed scale-up inference example, Grok, has been added.
    • Support the FMHA backward fusion on Intel GPU.
  • Bug Fix:
    • Fix a crash in the JAX native multi-process API.
    • Fix an accuracy error in dynamic slice fusion.
    • Fix the known-caveat crash related to binary operations and the SPMD multi-device parallelism API psum_scatter under the same partial annotation.
    • Fix the known-caveat hang caused by a deadlock when working with Toolkit 2024.1.
    • Fix the known-caveat OOM related to the deprecated clear_backends API.
  • Toolkit Support: Support Intel® oneAPI Base Toolkit 2024.2.
  • Driver Support: Support the upgraded Driver LTS release 2350.63.
  • oneDNN Support: Support oneDNN v3.5.1.
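
The version matrix above can be checked quickly at runtime. The snippet below is a minimal sketch: the version numbers are taken from this release, while the way the Intel GPUs show up in jax.devices() is an assumption noted in the comments.

import jax
import jaxlib

print("jax:", jax.__version__)        # expected to satisfy >= 0.4.26, <= 0.4.27
print("jaxlib:", jaxlib.__version__)  # expected 0.4.26
# With intel-extension-for-openxla 0.4.0 installed, jax.devices() should list the
# Intel GPUs exposed through the PJRT plugin (the platform name is an assumption here).
print(jax.devices())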

Known Caveats

  • Some models show a performance regression when working with Toolkit 2024.2. It is recommended to use Toolkit 2024.1 if you encounter performance issues.
  • Multi-process API support is still experimental and may cause hang issues with collectives.
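
For reference, the multi-process API mentioned above is JAX's jax.distributed interface. The sketch below only shows the generic JAX call pattern; the coordinator address and process counts are placeholders, and collectives across processes may still hang as described in the caveat.

import jax

# Generic JAX multi-process initialization; address and counts are placeholders.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",
    num_processes=2,
    process_id=0,   # each process passes its own id
)
print(jax.process_index(), jax.device_count(), jax.local_device_count())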

Breaking changes

  • JAX v0.4.24 is no longer supported by this release. Please follow the JAX change log to update your application if you encounter version errors, or roll back the Extension version if you want to use it with an older JAX version.

Documents

Intel® Extension for OpenXLA* 0.3.0 Release

29 Mar 20:42

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official OpenXLA on Intel GPUs. It is based on the PJRT plugin mechanism, which allows JAX models to run seamlessly on Intel® Data Center GPU Max Series and Intel® Data Center GPU Flex Series. This release contains the following major features:

  • JAX Upgrade: Upgrade version to v0.4.24.
  • Feature Support:
    • Support a custom call registration mechanism through the new OpenXLA C API. This feature makes it possible to interact with third-party software such as mpi4jax (see the sketch after this list).
    • Continue to improve JAX native distributed scale-up collectives. Any number of devices fewer than 16 is now supported in a single node.
    • Experimental support for Intel® Data Center GPU Flex Series.
  • Bug Fix:
    • Fix accuracy issues in the GEMM kernel when it is optimized by Intel® Xe Templates for Linear Algebra (XeTLA).
    • Fix a crash when the input batch size is greater than 65535.
  • Toolkit Support: Support Intel® oneAPI Base Toolkit 2024.1.
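
As an illustration of the third-party interoperability mentioned above, the sketch below follows mpi4jax's documented allreduce pattern. Whether this exact pattern runs unchanged on the Intel plugin is an assumption, and the mpi4jax call signature should be verified against the installed version.

import jax
import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def global_sum(x):
    # mpi4jax registers an XLA custom call for the MPI allreduce and
    # returns the result together with a token used for ordering.
    summed, _token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    return summed

print(global_sum(jnp.ones(4) * comm.Get_rank()))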

Known Caveats

  • The Extension will crash when using binary operations (e.g. Mul, MatMul) and the SPMD multi-device parallelism API psum_scatter under the same partial annotation. Please refer to the JAX unit test test_matmul_reduce_scatter to better understand the error scenario (a sketch of the pattern follows this list).
  • JAX collectives can deadlock and hang the Extension when working with Toolkit 2024.1. It is recommended to use Toolkit 2024.0 if collectives are needed.
  • The clear_backends API does not work and may cause an OOM exception as below when working with Toolkit 2024.0.
terminate called after throwing an instance of 'sycl::_V1::runtime_error'
  what():  Native API failed. Native API returns: -5 (PI_ERROR_OUT_OF_RESOURCES) -5 (PI_ERROR_OUT_OF_RESOURCES)
Fatal Python error: Aborted

Note: The clear_backends API will be deprecated by JAX soon.
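
A minimal sketch of the pattern the first caveat describes (a binary op followed by psum_scatter under the same annotation) is shown below; the shapes and the pmap-based spelling are illustrative and only loosely follow the JAX UT test_matmul_reduce_scatter.

from functools import partial
import jax
import jax.numpy as jnp

n = jax.local_device_count()

@partial(jax.pmap, axis_name="x")
def matmul_reduce_scatter(a, b):
    c = a @ b                                            # binary op (MatMul)
    # Reduce-scatter the product over the same device axis.
    return jax.lax.psum_scatter(c, "x", scatter_dimension=0, tiled=True)

a = jnp.ones((n, 8, 16))
b = jnp.ones((n, 16, 8))
out = matmul_reduce_scatter(a, b)   # per-device shape (8 // n, 8) when n divides 8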

Breaking changes

  • JAX v0.4.20 is no longer supported. Please follow the JAX change log to update your application if you encounter version errors.

Documents

0.2.1

06 Feb 08:49

Bug Fixes and Other Changes

  • Fix the known caveat related to XLA_ENABLE_MULTIPLE_STREAM=1. The accuracy issue is fixed and this environment variable no longer needs to be set.
  • Fix the known caveat related to MHA=0. The crash error is fixed and this environment variable no longer needs to be set.
  • Fix a compatibility issue with the upgraded Driver LTS release 2350.29.
  • Fix a random accuracy issue caused by the AllToAll collective.
  • Upgrade the transformers package used by the examples to 4.36 to fix an open CVE.

Known Caveats

  • The device number is restricted to 2/4/6/8/10/12 for the experimentally supported collectives in a single node.
  • Do not use collectives (e.g. AllReduce) in nested pjit; it may cause random accuracy issues. Please refer to the JAX unit test testAutodiff to better understand the error scenario.

Full Changelog: 0.2.0...0.2.1

Intel® Extension for OpenXLA* 0.2.0 Release

01 Dec 07:22

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official OpenXLA on Intel GPUs. It is based on the PJRT plugin mechanism, which allows JAX models to run seamlessly on Intel® Data Center GPU Max Series. This release contains the following major features:

  • Upgrade JAX version to v0.4.20.

  • Experimental support for JAX native distributed scale-up collectives based on JAX pmap (see the sketch after this list).

  • Continuously optimize common kernels, and optimize GEMM kernels with Intel® Xe Templates for Linear Algebra (XeTLA). Three inference models (Stable Diffusion, GPT-J, FLAN-T5) are verified on a single Intel® Data Center GPU Max Series device and added to the examples.
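
A minimal sketch of the pmap-based collectives mentioned above is shown below; the device count and data are illustrative only and are not taken from the verified models.

from functools import partial
import jax
import jax.numpy as jnp

n = jax.local_device_count()

@partial(jax.pmap, axis_name="devices")
def normalized(x):
    # AllReduce (sum) across all mapped devices, then normalize locally.
    total = jax.lax.psum(x, axis_name="devices")
    return x / total

print(normalized(jnp.arange(1.0, n + 1.0)))   # each device holds one scalar shard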

Known Caveats

  • The device number is restricted to 2/4/6/8/10/12 for the experimentally supported collectives in a single node.

  • XLA_ENABLE_MULTIPLE_STREAM=1 should be set when using JAX parallelization on multiple devices without collectives. It adds synchronization between different devices to avoid possible accuracy issues.

  • MHA=0 should be set to disable MHA fusion in training. MHA fusion is not yet supported in training and will cause a runtime error as below (a sketch of applying both workarounds follows the trace):

ir_emission_utils.cc:109] Check failed: lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)) == rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))
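
A minimal sketch of applying both workarounds from the caveats above is shown below; setting the variables from Python before importing jax is one option (an assumption about your launch setup), and exporting them in the shell before starting the script works equally well.

import os

# Workarounds from the caveats above; set before importing jax so XLA picks them up.
os.environ["XLA_ENABLE_MULTIPLE_STREAM"] = "1"   # multi-device parallelization without collectives
os.environ["MHA"] = "0"                          # disable MHA fusion for training

import jax
print(jax.devices())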

Breaking changes

  • JAX v0.4.13 is no longer supported. Please follow the JAX change log to update your application if you encounter version errors.

  • GCC 10.0.0 or newer is required when building from source. Please refer to the installation guide for more details.

Documents

Intel® Extension for OpenXLA* 0.1.0 Experimental Release

06 Sep 03:08

Major Features

Intel® Extension for OpenXLA* is an Intel-optimized Python package that extends the official OpenXLA on Intel GPUs. It is based on the PJRT plugin mechanism, which allows JAX models to run seamlessly on Intel® Data Center GPU Max Series. The PJRT API simplifies the integration, allowing the Intel XPU plugin to be developed separately and quickly integrated into JAX. This release contains the following major features:

  • Kernel enabling and optimization

Common kernels are enabled with the LLVM/SPIR-V software stack, while Convolution and GEMM are enabled with oneDNN. Stable Diffusion is verified.

Known Issues

  • Limited support for collective ops due to limitations of oneCCL.

Related Blog