NNCF helpers (#2979)
### Changes

- Added helpers module to NNCF.
- Added dataset helper.

### Reason for changes

- Extended data-free usage: calibration data can now be generated synthetically instead of being supplied by the user.

### Related tickets

- 152550

### Tests

- Added `tests/common/test_helpers.py`
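
A minimal sketch of the new helper in use (`model` and `tokenizer` are assumed to be ready `transformers` `PreTrainedModel`/`PreTrainedTokenizerBase` instances; defaults shown explicitly, see `nncf/data/generators.py` below):

```python
import nncf

texts = nncf.data.generate_text_data(model, tokenizer, seq_len=32, dataset_size=128)
```
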
KodiaqQ authored Oct 1, 2024
1 parent 2c8b70c commit ce1fb51
Showing 7 changed files with 409 additions and 0 deletions.
@@ -0,0 +1,34 @@
# Compress TinyLLama model using synthetic data

This example demonstrates how to optimize Large Language Models (LLMs) using the NNCF weight compression API together with synthetic data for the advanced compression algorithms. The example applies 4/8-bit mixed-precision quantization and the Scale Estimation algorithm to the weights of the Linear (fully-connected) layers of the [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) model.
To evaluate the accuracy of the compressed model, we measure the similarity between texts generated by the baseline and compressed models using the [WhoWhatBench](https://github.com/openvinotoolkit/openvino.genai/tree/master/llm_bench/python/who_what_benchmark) library.

The example includes the following steps:

- Prepare the `wikitext` dataset.
- Prepare the `TinyLlama/TinyLlama-1.1B-Chat-v1.0` text-generation model in the OpenVINO representation using [Optimum-Intel](https://huggingface.co/docs/optimum/intel/inference).
- Compress the model weights with the NNCF weight compression algorithm, using Scale Estimation and the `wikitext` dataset.
- Prepare a `synthetic` dataset using the `nncf.data.generate_text_data` method (the sketch after this list shows the core of this step).
- Compress the model weights with the NNCF weight compression algorithm, using Scale Estimation and the `synthetic` dataset.
- Measure the similarity of the two models optimized with the different datasets.
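
For reference, the synthetic-data path boils down to the following sketch (condensed from `main.py`; variable names follow `main.py`, and `transform_func` is the model-specific input-preparation helper defined there):

```python
from functools import partial

import nncf

# Generate calibration texts directly from the model -- no external corpus needed.
synthetic_texts = nncf.data.generate_text_data(hf_model, tokenizer)

# Wrap the texts into an NNCF dataset that converts each sample into model inputs.
calibration_dataset = nncf.Dataset(
    synthetic_texts, partial(transform_func, tokenizer=tokenizer, ov_model=hf_model.model)
)

# 4-bit symmetric weight compression with Scale Estimation on the calibration data.
compressed_model = nncf.compress_weights(
    hf_model.model.clone(),
    dataset=calibration_dataset,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=1.0,
    scale_estimation=True,
)
```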

## Install requirements

To use this example:

- Create a separate Python* environment and activate it: `python3 -m venv nncf_env && source nncf_env/bin/activate`
- Install dependencies:

```bash
pip install -U pip
pip install -r requirements.txt
pip install ../../../../
```

## Run Example

The example is fully automated. Just run the following command in the prepared Python environment:

```bash
python main.py
```
109 changes: 109 additions & 0 deletions examples/llm_compression/openvino/tiny_llama_synthetic_data/main.py
@@ -0,0 +1,109 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import datasets
import numpy as np
import openvino as ov
import torch
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer
from whowhatbench import Evaluator

import nncf

SEED = 0


def transform_func(text, tokenizer, ov_model):
    input_dtypes = {inp.get_any_name(): inp.get_element_type() for inp in ov_model.inputs}
    tokens = tokenizer(text)
    input_ids = np.expand_dims(np.array(tokens["input_ids"]), 0)
    attention_mask = np.expand_dims(np.array(tokens["attention_mask"]), 0)
    # Derive position ids from the attention mask; masked positions get a dummy value of 1.
    position_ids = np.cumsum(attention_mask, axis=1) - 1
    position_ids[attention_mask == 0] = 1
    res = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids.reshape(*attention_mask.shape),
    }

    def gen_pkv(num_heads, head_dim, num_layers):
        # Build empty past-key-values tensors matching the model's KV-cache inputs.
        res = {}
        shape = (1, num_heads, 0, head_dim)
        for i in range(num_layers):
            key_name = f"past_key_values.{i}.key"
            val_name = f"past_key_values.{i}.value"
            res[key_name] = ov.Tensor(shape=shape, type=input_dtypes[key_name])
            res[val_name] = ov.Tensor(shape=shape, type=input_dtypes[val_name])
        return res

    # Shapes match TinyLlama-1.1B: 4 key-value heads, head dim 64, 22 layers.
    res.update(gen_pkv(4, 64, 22))
    return res


def compress_model(model, tokenizer, dataset):
    quantization_dataset = nncf.Dataset(dataset, partial(transform_func, tokenizer=tokenizer, ov_model=model.model))

    optimized_model = nncf.compress_weights(
        model.model.clone(),
        dataset=quantization_dataset,
        mode=nncf.CompressWeightsMode.INT4_SYM,
        ratio=1.0,
        scale_estimation=True,
    )
    return optimized_model


def validate_model(evaluator, hf_model, optimized_model, original_ov_model):
    # Swap in the optimized model, score it, then restore the original model.
    hf_model.model = optimized_model
    hf_model.request = None
    _, all_metrics = evaluator.score(hf_model)
    hf_model.model = original_ov_model
    hf_model.request = None
    return all_metrics["similarity"][0]


def main():
    MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    hf_model = OVModelForCausalLM.from_pretrained(
        MODEL_ID, export=True, load_in_8bit=False, compile=False, stateful=False
    )

    # Keep a clone of the original model for the similarity evaluation.
    original_ov_model = hf_model.model.clone()
    evaluator = Evaluator(hf_model, tokenizer=tokenizer, metrics=("similarity",))

    # Wikitext-based compression
    wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    wikitext_dataset = [d["text"] for d in wikitext_dataset]
    wikitext_optimized_model = compress_model(hf_model, tokenizer, wikitext_dataset)

    # Synthetic-based compression
    # Use a fixed seed so the synthetic data generation is reproducible; re-seed afterwards.
    saved_seed = torch.seed()
    torch.manual_seed(SEED)
    synthetic_dataset = nncf.data.generate_text_data(hf_model, tokenizer)
    torch.manual_seed(saved_seed)
    synthetic_optimized_model = compress_model(hf_model, tokenizer, synthetic_dataset)

    # Similarity comparison between the Wikitext-based & Synthetic-based compressed models
    wikitext_based_similarity = validate_model(evaluator, hf_model, wikitext_optimized_model, original_ov_model)
    print(f"Wikitext-quantized model similarity: {wikitext_based_similarity}")

    synthetic_based_similarity = validate_model(evaluator, hf_model, synthetic_optimized_model, original_ov_model)
    print(f"Synthetic-quantized model similarity: {synthetic_based_similarity}")
    return wikitext_based_similarity, synthetic_based_similarity


if __name__ == "__main__":
    main()
@@ -0,0 +1,9 @@
-c ../../../../constraints.txt
torch
datasets
whowhatbench @ git+https://github.com/openvinotoolkit/openvino.genai.git#subdirectory=llm_bench/python/who_what_benchmark
numpy>=1.23.5
openvino==2024.4
optimum-intel[openvino]>=1.13.0
transformers>=4.35.2
onnx<1.16.2
1 change: 1 addition & 0 deletions nncf/data/__init__.py
@@ -10,3 +10,4 @@
# limitations under the License.

from nncf.data.dataset import Dataset as Dataset
from nncf.data.generators import generate_text_data as generate_text_data
92 changes: 92 additions & 0 deletions nncf/data/generators.py
@@ -0,0 +1,92 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, TypeVar

import nncf
from nncf.common.logging.track_progress import track

BASE_VOCAB_SIZE = 12000

TModel = TypeVar("TModel")
TTokenizer = TypeVar("TTokenizer")


def generate_text_data(
    model: TModel, tokenizer: TTokenizer, seq_len: int = 32, dataset_size: int = 128, unique_tokens_lower_limit: int = 5
) -> List[str]:
    """
    Generates a text dataset based on the model output.

    Since the model is required to be an instance of PreTrainedModel
    and the tokenizer an instance of PreTrainedTokenizerBase,
    the environment must have the `transformers` & `torch` modules installed to run this method.

    :param model: Model instance.
    :param tokenizer: Tokenizer instance.
    :param seq_len: Sequence length for generation.
    :param dataset_size: Size of the generated dataset.
    :param unique_tokens_lower_limit: Minimum number of unique characters a generated sample
        must contain to be kept.
    :return: List of the text data ready to use.
    """

    try:
        import torch
    except ImportError:
        raise nncf.ModuleNotFoundError("torch is required in order to generate text data: `pip install torch`.")

    try:
        from transformers import PreTrainedModel
        from transformers import PreTrainedTokenizerBase
        from transformers.utils import logging

        logging.set_verbosity_error()
    except ImportError:
        raise nncf.ModuleNotFoundError(
            "transformers is required in order to generate text data: `pip install transformers`."
        )

    # Checking against the base classes (rather than PreTrainedModel itself) also accepts
    # wrappers such as optimum's OVModelForCausalLM that share these bases.
    if not isinstance(model, PreTrainedModel.__bases__):
        raise nncf.ValidationError("Model should be an instance of the `transformers.PreTrainedModel`.")

    if not isinstance(tokenizer, PreTrainedTokenizerBase.__bases__):
        raise nncf.ValidationError("Tokenizer should be an instance of the `transformers.PreTrainedTokenizerBase`.")

    generated_data = []

    # Prefer the model's own vocabulary size; fall back to BASE_VOCAB_SIZE otherwise.
    vocab_size_names = ["padded_vocab_size", "vocab_size"]
    vocab_size = BASE_VOCAB_SIZE
    for vocab_size_name in vocab_size_names:
        if hasattr(model.config, vocab_size_name):
            vocab_size = getattr(model.config, vocab_size_name)

    step_num = max(1, vocab_size // dataset_size)
    ids_counter = 0

    with track(total=dataset_size, description="Generating text data") as pbar:
        while len(generated_data) < dataset_size:
            # Creating the input for the pre-generate step
            input_ids = torch.tensor([[ids_counter % vocab_size]]).to(model.device)

            # Collecting data from the pre & post generate steps
            outputs_prep = model.generate(input_ids, do_sample=False, max_length=seq_len // 2)
            outputs_post = model.generate(outputs_prep, do_sample=True, max_length=seq_len + seq_len // 2)
            gen_text = tokenizer.batch_decode(outputs_post[:, outputs_prep.shape[1] :], skip_special_tokens=True)

            # Skip degenerate samples and retry from the next token id.
            if len(set(gen_text[0])) < unique_tokens_lower_limit:
                ids_counter += 1
                continue

            ids_counter += step_num

            pbar.progress.update(pbar.task, advance=1)
            generated_data.extend(gen_text)

    return generated_data
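
A note on the generation strategy in `generate_text_data`: each sample is seeded with a single token id, and `ids_counter` strides through the vocabulary in steps of `vocab_size // dataset_size`, so the seed tokens cover the vocabulary roughly uniformly (for example, with the Llama vocabulary of 32,000 tokens and the default `dataset_size=128`, the stride is 250). Greedy pre-generation (`do_sample=False`) builds a stable prefix for each seed, while the sampled post-generation step (`do_sample=True`) supplies the diversity that ends up in the dataset; only the tokens produced after the prefix are decoded and kept.
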
90 changes: 90 additions & 0 deletions tests/torch/data/ref_generated_data.json
@@ -0,0 +1,90 @@
[
[
863,
297,
50,
718,
570,
806,
681,
226,
964,
350,
686,
780
],
[
227,
878,
47,
780,
473,
242,
347,
799,
157,
123,
121,
362,
992,
477
],
[
84,
815,
503,
200,
906,
261,
80,
211,
76,
27,
638,
157,
123,
121
],
[
468,
468,
939,
829,
769,
769,
353,
425,
712,
687,
686,
780
],
[
438,
932,
664,
39,
932,
54,
536,
641,
33,
433,
926,
711
],
[
371,
836,
836,
895,
308,
811,
64,
690,
737,
218,
242,
897
]
]
