[Doc] Update v0.2.0 release doc (#133)
Co-authored-by: Lu, Fengqing <[email protected]>
Lu Teng and LuFinch committed Dec 1, 2023
1 parent e75bde9 commit 3d6b9e1
Showing 5 changed files with 81 additions and 44 deletions.
31 changes: 17 additions & 14 deletions README.md
@@ -26,27 +26,29 @@ This guide introduces the overview of OpenXLA high level integration structure a

Verified Hardware Platforms:

* Intel® Data Center GPU Max Series, Driver Version: [682](https://dgpu-docs.intel.com/releases/production_682.14_20230804.html)
* Intel® Data Center GPU Max Series, Driver Version: [736](https://dgpu-docs.intel.com/releases/stable_736_25_20231031.html)

* Intel® Data Center GPU Flex Series 170, Driver Version: [682](https://dgpu-docs.intel.com/releases/production_682.14_20230804.html)
* Intel® Data Center GPU Flex Series 170, Driver Version: [736](https://dgpu-docs.intel.com/releases/stable_736_25_20231031.html)

### Software Requirements

* Ubuntu 22.04, Red Hat 8.6/8.8/9.2 (64-bit)
* Ubuntu 22.04 (64-bit)
* Intel® Data Center GPU Flex Series
* Ubuntu 22.04, Red Hat 8.6/8.8/9.2 (64-bit), SUSE Linux Enterprise Server (SLES) 15 SP4
* Ubuntu 22.04, SUSE Linux Enterprise Server (SLES) 15 SP4
* Intel® Data Center GPU Max Series
* Intel® oneAPI Base Toolkit 2023.2
* Intel® oneAPI Base Toolkit 2024.0
* Jax/Jaxlib 0.4.20
* Python 3.9-3.11
* pip 19.0 or later (requires manylinux2014 support)

**NOTE: Since Jax has its own [platform limitation](https://jax.readthedocs.io/en/latest/installation.html#supported-platforms) (Ubuntu 20.04 or later), the effective software requirements are further restricted when working with Jax.**

### Install Intel GPU Drivers

|OS|Intel GPU|Install Intel GPU Driver|
|-|-|-|
|Ubuntu 22.04, Red Hat 8.6/8.8/9.2|Intel® Data Center GPU Flex Series| Refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-flex-series) for the latest driver installation. To install the verified Intel® Data Center GPU Max Series/Intel® Data Center GPU Flex Series driver [682](https://dgpu-docs.intel.com/releases/production_682.14_20230804.html), append the specific version after each component, such as `sudo apt-get install intel-opencl-icd==23.22.26516.25-682~22.04`|
|Ubuntu 22.04, Red Hat 8.6/8.8/9.2, SLES 15 SP4|Intel® Data Center GPU Max Series| Refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-max-series) for the latest driver installation. To install the verified Intel® Data Center GPU Max Series/Intel® Data Center GPU Flex Series driver [682](https://dgpu-docs.intel.com/releases/production_682.14_20230804.html), append the specific version after each component, such as `sudo apt-get install intel-opencl-icd==23.22.26516.25-682~22.04`|
|Ubuntu 22.04 |Intel® Data Center GPU Flex Series| Refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-flex-series) for the latest driver installation. To install the verified Intel® Data Center GPU Max Series/Intel® Data Center GPU Flex Series driver [736](https://dgpu-docs.intel.com/releases/stable_736_25_20231031.html), append the specific version after each component, such as `sudo apt-get install intel-opencl-icd==23.30.26918.50-736~22.04`|
|Ubuntu 22.04, SLES 15 SP4|Intel® Data Center GPU Max Series| Refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/index.html#intel-data-center-gpu-max-series) for the latest driver installation. To install the verified Intel® Data Center GPU Max Series/Intel® Data Center GPU Flex Series driver [736](https://dgpu-docs.intel.com/releases/stable_736_25_20231031.html), append the specific version after each component, such as `sudo apt-get install intel-opencl-icd==23.30.26918.50-736~22.04`|

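After the driver is installed, it helps to confirm the GPU is visible to the compute runtime before moving on. A minimal sketch, assuming an Ubuntu host with the `clinfo` utility installed (`sudo apt install clinfo`):

```bash
# Check which driver package version was actually installed (Ubuntu package shown)
apt list --installed 2>/dev/null | grep intel-opencl-icd

# List the devices the OpenCL runtime exposes; the Flex/Max GPU should appear here
clinfo | grep -i "device name"
```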
### Install oneAPI Base Toolkit Packages

@@ -57,13 +59,13 @@ The following components of the Intel® oneAPI Base Toolkit need to be installed:
* Intel® oneAPI Threading Building Blocks (TBB), dependency of DPC++ Compiler.

```bash
wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/992857b9-624c-45de-9701-f6445d845359/l_BaseKit_p_2023.2.0.49397_offline.sh
sudo sh ./l_BaseKit_p_2023.2.0.49397_offline.sh
wget https://registrationcenter-download.intel.com/akdlm//IRC_NAS/20f4e6a1-6b0b-4752-b8c1-e5eacba10e01/l_BaseKit_p_2024.0.0.49564.sh
# 2 components are necessary: DPC++/C++ Compiler and oneMKL
sudo sh l_BaseKit_p_2024.0.0.49564.sh

# Source OneAPI env
source /opt/intel/oneapi/compiler/2023.2.0/env/vars.sh
source /opt/intel/oneapi/mkl/2023.2.0/env/vars.sh
source /opt/intel/oneapi/tbb/2021.9.0/env/vars.sh
source /opt/intel/oneapi/compiler/2024.0/env/vars.sh
source /opt/intel/oneapi/mkl/2024.0/env/vars.sh
```
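Before building or running anything, it is worth confirming the oneAPI environment is active in the current shell. A quick sanity check, assuming the default `/opt/intel/oneapi` install prefix:

```bash
# The DPC++/C++ compiler should resolve once the compiler vars.sh is sourced
icpx --version

# MKLROOT is exported by oneMKL's vars.sh
echo $MKLROOT

# Alternatively, source all installed components at once
source /opt/intel/oneapi/setvars.sh
```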

### Install Jax and Jaxlib
@@ -82,12 +84,13 @@ pip install --upgrade intel-extension-for-openxla

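Based on the versions pinned in the software requirements above (Jax/Jaxlib 0.4.20), installing the matching Jax packages is expected to look like the following sketch:

```bash
# Pin jax/jaxlib to the release verified with intel-extension-for-openxla v0.2.0
pip install jax==0.4.20 jaxlib==0.4.20
```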
### Install from Source Build

**NOTE: Extra software (GCC 10.0.0 or later) is required when building from source.**
```bash
git clone https://github.com/intel/intel-extension-for-openxla.git
./configure # Choose Yes for all.
bazel build //xla/tools/pip_package:build_pip_package
./bazel-bin/xla/tools/pip_package/build_pip_package ./
pip install intel_extension_for_openxla-0.1.0-cp39-cp39-linux_x86_64.whl
pip install intel_extension_for_openxla-0.2.0-cp39-cp39-linux_x86_64.whl
```
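Whichever install path is used, a quick way to confirm the extension registered the Intel GPU backend is to list JAX's local devices; per the sample output later in this README, Intel GPUs appear as `xpu` devices:

```bash
# Should print xpu entries such as [xpu(id=0), xpu(id=1)] once the plugin is active
python -c "import jax; print(jax.local_devices())"
```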

**Additional Build Option**:
@@ -163,4 +166,4 @@ jax.local_devices(): [xpu(id=0), xpu(id=1)]
sudo apt install plocate
locate libstdc++.so | grep /usr/lib/  # For example, the output of the library path is "/usr/lib/x86_64-linux-gnu/libstdc++.so.6".
sudo ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /usr/lib/gcc/x86_64-linux-gnu/12/libstdc++.so
```
2 changes: 1 addition & 1 deletion configure.py
@@ -35,7 +35,7 @@
# pylint: enable=g-import-not-at-top


_DEFAULT_SYCL_TOOLKIT_PATH = '/opt/intel/oneapi/compiler/latest/linux'
_DEFAULT_SYCL_TOOLKIT_PATH = '/opt/intel/oneapi/compiler/latest'
_DEFAULT_AOT_CONFIG = ''
_DEFAULT_GCC_TOOLCHAIN_PATH = ''
_DEFAULT_GCC_TOOLCHAIN_TARGET = ''
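This default tracks the oneAPI 2024.0 directory layout, which no longer places a `linux` subdirectory under `compiler/latest` (an inference from the toolkit upgrade above). A quick way to confirm the path exists before running `./configure`:

```bash
# Under the 2024.0 layout, bin/ sits directly under compiler/latest
ls /opt/intel/oneapi/compiler/latest/bin
```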
22 changes: 15 additions & 7 deletions example/bert/README.md
@@ -1,7 +1,15 @@
# Quick Start for Fine-Tuning BERT on SQuAD
Fine-tune the BERT model on the SQuAD task with the [Question Answering examples](https://github.com/huggingface/transformers/tree/v4.27.4/examples/flax/question-answering#question-answering-examples).
Fine-tune the BERT model on the SQuAD task with the [Question Answering examples](https://github.com/huggingface/transformers/tree/v4.32.0/examples/flax/question-answering#question-answering-examples).
This example is adapted from [HuggingFace Transformers](https://github.com/huggingface/transformers). See [Backup](#Backup) for modification details.


**IMPORTANT: This example is temporarily unavailable with JAX v0.4.20; it fails with the error below due to a known public issue (https://github.com/huggingface/transformers/issues/27644):**
```
AttributeError: 'ArrayImpl' object has no attribute 'split'
```
**It will be re-enabled once the fix lands in the community.**


## Requirements

### 1. Install intel-extension-for-openxla
@@ -31,7 +39,7 @@ cd -
### Running command
```bash
python run_qa.py \
--model_name_or_path <WORKSPACE>/examples/bert/models \
--model_name_or_path <WORKSPACE>/example/bert/models \
--dataset_name squad \
--do_train \
--per_device_train_batch_size 8 \
@@ -62,10 +70,10 @@ Performance... xxx iter/s
### Backup
```patch
diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py
index 230480428..00d901d76 100644
index a2839539e..a530d8560 100644
--- a/examples/flax/question-answering/run_qa.py
+++ b/examples/flax/question-answering/run_qa.py
@@ -821,7 +821,8 @@ def main():
@@ -846,7 +846,8 @@ def main():

# region Training steps and logging init
train_dataset = processed_raw_datasets["train"]
@@ -75,7 +83,7 @@ index 230480428..00d901d76 100644

# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
@@ -931,11 +932,12 @@ def main():
@@ -957,11 +958,12 @@ def main():
state = replicate(state)

train_time = 0
train_metrics = []

# Create sampling rng
@@ -956,6 +958,13 @@ def main():
@@ -982,6 +984,13 @@ def main():

cur_step = epoch * step_per_epoch + step

if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics
train_metric = unreplicate(train_metric)
@@ -1022,6 +1031,9 @@ def main():
@@ -1048,6 +1057,9 @@ def main():
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
```
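The patch targets a transformers v4.32.0 checkout. One way to apply it, assuming the diff is saved to a file named `bert_example.patch` (a hypothetical filename):

```bash
# From the root of a transformers clone, pinned to the version this example expects
git checkout v4.32.0
git apply bert_example.patch
```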
10 changes: 5 additions & 5 deletions example/bert/requirements.txt
@@ -1,7 +1,7 @@
datasets >= 1.8.0
jax>=0.4.13
jaxlib>=0.4.13
flax>=0.3.5
datasets>=1.8.0
jax==0.4.20
jaxlib==0.4.20
flax>=0.7.0
optax>=0.0.8
transformers==4.27.4
transformers==4.32.0
evaluate>=0.4.1
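To set up the environment with these pins, the usual flow (assuming it is run from `example/bert/`) is:

```bash
# Install the pinned dependencies for the BERT fine-tuning example
pip install -r requirements.txt
```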
60 changes: 43 additions & 17 deletions example/bert/run_qa.py
@@ -25,6 +25,7 @@
import random
import sys
import time
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
@@ -55,13 +56,13 @@
PreTrainedTokenizerFast,
is_tensorboard_available,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, send_example_telemetry


logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.27.0")
check_min_version("4.32.0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset
@@ -155,12 +156,28 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"should only be set to `True` for repositories you trust and in which you have read the code, as it will"
"execute code present on the Hub on your local machine."
)
},
)
@@ -438,6 +455,12 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
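# Migration sketch (not part of the upstream script): HfArgumentParser maps these
# dataclass fields to CLI flags, so a new-style invocation passes e.g. `--token <hf_token>`
# where an older one passed the deprecated `--use_auth_token`.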

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_qa", model_args, data_args, framework="flax")
@@ -462,14 +485,14 @@

# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Retrieve or infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)

# region Load Data
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
@@ -487,7 +510,7 @@
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
# Loading the dataset from local csv or json file.
@@ -507,7 +530,7 @@
data_files=data_files,
field="data",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -520,14 +543,16 @@
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=True,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# endregion

@@ -875,7 +900,8 @@ def write_eval_metric(summary_writer, eval_metrics, step):
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
)
