diff --git a/examples/llm_compression/openvino/tiny_llama_synthetic_data/README.md b/examples/llm_compression/openvino/tiny_llama_synthetic_data/README.md
new file mode 100644
index 00000000000..da4556bfba3
--- /dev/null
+++ b/examples/llm_compression/openvino/tiny_llama_synthetic_data/README.md
@@ -0,0 +1,34 @@
+# Compress TinyLlama model using synthetic data
+
+This example demonstrates how to optimize Large Language Models (LLMs) with the NNCF weight compression API, using synthetic data for the advanced compression algorithms. The example applies 4/8-bit mixed-precision quantization and the Scale Estimation algorithm to the weights of the Linear (fully connected) layers of the [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) model.
+To evaluate the accuracy of the compressed models, we measure the similarity between texts generated by the baseline and compressed models using the [WhoWhatBench](https://github.com/openvinotoolkit/openvino.genai/tree/master/llm_bench/python/who_what_benchmark) library.
+
+The example includes the following steps:
+
+- Prepare the `wikitext` dataset.
+- Prepare the `TinyLlama/TinyLlama-1.1B-Chat-v1.0` text-generation model in OpenVINO representation using [Optimum-Intel](https://huggingface.co/docs/optimum/intel/inference).
+- Compress the weights of the model with the NNCF weight compression algorithm, using Scale Estimation and the `wikitext` dataset.
+- Prepare a `synthetic` dataset using the `nncf.data.generate_text_data` method.
+- Compress the weights of the model with the NNCF weight compression algorithm, using Scale Estimation and the `synthetic` dataset.
+- Measure the similarity of the two models optimized with the different datasets.
+
+## Install requirements
+
+To use this example:
+
+- Create a separate Python* environment and activate it: `python3 -m venv nncf_env && source nncf_env/bin/activate`
+- Install dependencies:
+
+```bash
+pip install -U pip
+pip install -r requirements.txt
+pip install ../../../../
+```
+
+## Run Example
+
+The example is fully automated. Just run the following command in the prepared Python environment:
+
+```bash
+python main.py
+```
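The compression step the README describes boils down to a single `nncf.compress_weights` call. A minimal sketch with the same settings the example script below uses (`ov_model` and `calibration_dataset` are placeholders standing in for the objects `main.py` constructs):

```python
import nncf

# Sketch only: `ov_model` is an openvino.Model exported via Optimum-Intel and
# `calibration_dataset` is an nncf.Dataset built from either wikitext samples
# or synthetically generated text.
compressed_ov_model = nncf.compress_weights(
    ov_model,
    dataset=calibration_dataset,
    mode=nncf.CompressWeightsMode.INT4_SYM,  # 4-bit symmetric weight quantization
    ratio=1.0,                               # apply INT4 to all eligible layers
    scale_estimation=True,                   # enable the Scale Estimation algorithm
)
```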
diff --git a/examples/llm_compression/openvino/tiny_llama_synthetic_data/main.py b/examples/llm_compression/openvino/tiny_llama_synthetic_data/main.py
new file mode 100644
index 00000000000..bfd9dc24106
--- /dev/null
+++ b/examples/llm_compression/openvino/tiny_llama_synthetic_data/main.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import datasets
+import numpy as np
+import openvino as ov
+import torch
+from optimum.intel.openvino import OVModelForCausalLM
+from transformers import AutoTokenizer
+from whowhatbench import Evaluator
+
+import nncf
+
+SEED = 0
+
+
+def transform_func(text, tokenizer, ov_model):
+    input_dtypes = {inp.get_any_name(): inp.get_element_type() for inp in ov_model.inputs}
+    tokens = tokenizer(text)
+    input_ids = np.expand_dims(np.array(tokens["input_ids"]), 0)
+    attention_mask = np.expand_dims(np.array(tokens["attention_mask"]), 0)
+    position_ids = np.cumsum(attention_mask, axis=1) - 1
+    position_ids[attention_mask == 0] = 1
+    res = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids.reshape(*attention_mask.shape),
+    }
+
+    def gen_pkv(num_heads, head_dim, num_layers):
+        res = {}
+        shape = (1, num_heads, 0, head_dim)
+        for i in range(num_layers):
+            key_name = f"past_key_values.{i}.key"
+            val_name = f"past_key_values.{i}.value"
+            res[key_name] = ov.Tensor(shape=shape, type=input_dtypes[key_name])
+            res[val_name] = ov.Tensor(shape=shape, type=input_dtypes[val_name])
+        return res
+
+    res.update(gen_pkv(4, 64, 22))  # TinyLlama-1.1B: 4 KV heads, head_dim=64, 22 decoder layers
+    return res
+
+
+def compress_model(model, tokenizer, dataset):
+    quantization_dataset = nncf.Dataset(dataset, partial(transform_func, tokenizer=tokenizer, ov_model=model.model))
+
+    optimized_model = nncf.compress_weights(
+        model.model.clone(),
+        dataset=quantization_dataset,
+        mode=nncf.CompressWeightsMode.INT4_SYM,
+        ratio=1.0,
+        scale_estimation=True,
+    )
+    return optimized_model
+
+
+def validate_model(evaluator, hf_model, optimized_model, original_ov_model):
+    hf_model.model = optimized_model
+    hf_model.request = None
+    _, all_metrics = evaluator.score(hf_model)
+    hf_model.model = original_ov_model
+    hf_model.request = None
+    return all_metrics["similarity"][0]
+
+
+def main():
+    MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    hf_model = OVModelForCausalLM.from_pretrained(
+        MODEL_ID, export=True, load_in_8bit=False, compile=False, stateful=False
+    )
+
+    original_ov_model = hf_model.model.clone()
+    evaluator = Evaluator(hf_model, tokenizer=tokenizer, metrics=("similarity",))
+
+    # Wikitext-based compression
+    wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+    wikitext_dataset = [d["text"] for d in wikitext_dataset]
+    wikitext_optimized_model = compress_model(hf_model, tokenizer, wikitext_dataset)
+
+    # Synthetic-based compression
+    saved_seed = torch.seed()
+    torch.manual_seed(SEED)
+    synthetic_dataset = nncf.data.generate_text_data(hf_model, tokenizer)
+    torch.manual_seed(saved_seed)
+    synthetic_optimized_model = compress_model(hf_model, tokenizer, synthetic_dataset)
+
+    # Similarity comparison between Wikitext-based & Synthetic-based compressed models
+    wikitext_based_similarity = validate_model(evaluator, hf_model, wikitext_optimized_model, original_ov_model)
+    print(f"Wikitext-quantized model similarity: {wikitext_based_similarity}")
+
+    synthetic_based_similarity = validate_model(evaluator, hf_model, synthetic_optimized_model, original_ov_model)
+    print(f"Synthetic-quantized model similarity: {synthetic_based_similarity}")
+    return wikitext_based_similarity, synthetic_based_similarity
+
+
+if __name__ == "__main__":
+    main()
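The `gen_pkv(4, 64, 22)` call in `main.py` above hard-codes the KV-cache layout of TinyLlama-1.1B. As a hedged sketch (the helper name is illustrative and not part of the example), the same values could be derived from the Hugging Face config of any Llama-style model:

```python
def kv_params_from_config(config):
    """Illustrative helper: read the KV-cache layout from a Llama-style HF config
    instead of hard-coding TinyLlama's values (4 KV heads, head_dim 64, 22 layers)."""
    num_kv_heads = config.num_key_value_heads
    head_dim = config.hidden_size // config.num_attention_heads  # 2048 // 32 = 64 for TinyLlama-1.1B
    num_layers = config.num_hidden_layers
    return num_kv_heads, head_dim, num_layers


# Inside transform_func this would replace the hard-coded call:
# res.update(gen_pkv(*kv_params_from_config(hf_model.config)))
```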
diff --git a/examples/llm_compression/openvino/tiny_llama_synthetic_data/requirements.txt b/examples/llm_compression/openvino/tiny_llama_synthetic_data/requirements.txt
new file mode 100644
index 00000000000..027f94e4a4a
--- /dev/null
+++ b/examples/llm_compression/openvino/tiny_llama_synthetic_data/requirements.txt
@@ -0,0 +1,9 @@
+-c ../../../../constraints.txt
+torch
+datasets
+whowhatbench @ git+https://github.com/openvinotoolkit/openvino.genai.git#subdirectory=llm_bench/python/who_what_benchmark
+numpy>=1.23.5
+openvino==2024.4
+optimum-intel[openvino]>=1.13.0
+transformers>=4.35.2
+onnx<1.16.2
diff --git a/nncf/data/__init__.py b/nncf/data/__init__.py
index f6dde64a3e4..b6acbcef923 100644
--- a/nncf/data/__init__.py
+++ b/nncf/data/__init__.py
@@ -10,3 +10,4 @@
 # limitations under the License.
 
 from nncf.data.dataset import Dataset as Dataset
+from nncf.data.generators import generate_text_data as generate_text_data
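With this export in place, the generator implemented below can be called directly from the public `nncf.data` namespace. A minimal usage sketch with the function's default settings (`seq_len=32`, `dataset_size=128`), reusing the TinyLlama checkpoint from the example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from nncf.data import generate_text_data

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Returns a list of text samples generated by the model itself, which can then
# be tokenized and wrapped into an nncf.Dataset for data-aware weight compression.
texts = generate_text_data(model, tokenizer, seq_len=32, dataset_size=128)
```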
diff --git a/nncf/data/generators.py b/nncf/data/generators.py
new file mode 100644
index 00000000000..b8ca0edc947
--- /dev/null
+++ b/nncf/data/generators.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, TypeVar
+
+import nncf
+from nncf.common.logging.track_progress import track
+
+BASE_VOCAB_SIZE = 12000
+
+TModel = TypeVar("TModel")
+TTokenizer = TypeVar("TTokenizer")
+
+
+def generate_text_data(
+    model: TModel, tokenizer: TTokenizer, seq_len: int = 32, dataset_size: int = 128, unique_tokens_lower_limit: int = 5
+) -> List[str]:
+    """
+    Generates a text dataset based on the model output.
+
+    Since the model is required to be an instance of PreTrainedModel
+    and the tokenizer is required to be an instance of PreTrainedTokenizerBase,
+    the environment must have the `transformers` & `torch` modules installed to run this method.
+
+    :param model: Model instance.
+    :param tokenizer: Tokenizer instance.
+    :param seq_len: Sequence length for generation.
+    :param dataset_size: Number of text samples to generate.
+    :param unique_tokens_lower_limit: Minimum number of unique characters a generated sample must contain to be kept.
+    :return: List of the text data ready to use.
+    """
+
+    try:
+        import torch
+    except ImportError:
+        raise nncf.ModuleNotFoundError("torch is required in order to generate text data: `pip install torch`.")
+
+    try:
+        from transformers import PreTrainedModel
+        from transformers import PreTrainedTokenizerBase
+        from transformers.utils import logging
+
+        logging.set_verbosity_error()
+    except ImportError:
+        raise nncf.ModuleNotFoundError(
+            "transformers is required in order to generate text data: `pip install transformers`."
+        )
+
+    if not isinstance(model, PreTrainedModel.__bases__):
+        raise nncf.ValidationError("Model should be an instance of `transformers.PreTrainedModel`.")
+
+    if not isinstance(tokenizer, PreTrainedTokenizerBase.__bases__):
+        raise nncf.ValidationError("Tokenizer should be an instance of `transformers.PreTrainedTokenizerBase`.")
+
+    generated_data = []
+
+    vocab_size_names = ["padded_vocab_size", "vocab_size"]
+    vocab_size = BASE_VOCAB_SIZE
+    for vocab_size_name in vocab_size_names:
+        if hasattr(model.config, vocab_size_name):
+            vocab_size = getattr(model.config, vocab_size_name)
+
+    step_num = max(1, vocab_size // dataset_size)
+    ids_counter = 0
+
+    with track(total=dataset_size, description="Generating text data") as pbar:
+        while len(generated_data) < dataset_size:
+            # Creating the single-token input for the pre-generate step
+            input_ids = torch.tensor([[ids_counter % vocab_size]]).to(model.device)
+
+            # Collecting data from the pre & post generate steps
+            outputs_prep = model.generate(input_ids, do_sample=False, max_length=seq_len // 2)
+            outputs_post = model.generate(outputs_prep, do_sample=True, max_length=seq_len + seq_len // 2)
+            gen_text = tokenizer.batch_decode(outputs_post[:, outputs_prep.shape[1] :], skip_special_tokens=True)
+
+            if len(set(gen_text[0])) < unique_tokens_lower_limit:
+                ids_counter += 1
+                continue
+
+            ids_counter += step_num
+
+            pbar.progress.update(pbar.task, advance=1)
+            generated_data.extend(gen_text)
+
+    return generated_data
diff --git a/tests/torch/data/ref_generated_data.json b/tests/torch/data/ref_generated_data.json
new file mode 100644
index 00000000000..4f925032c61
--- /dev/null
+++ b/tests/torch/data/ref_generated_data.json
@@ -0,0 +1,90 @@
+[
+    [
+        863,
+        297,
+        50,
+        718,
+        570,
+        806,
+        681,
+        226,
+        964,
+        350,
+        686,
+        780
+    ],
+    [
+        227,
+        878,
+        47,
+        780,
+        473,
+        242,
+        347,
+        799,
+        157,
+        123,
+        121,
+        362,
+        992,
+        477
+    ],
+    [
+        84,
+        815,
+        503,
+        200,
+        906,
+        261,
+        80,
+        211,
+        76,
+        27,
+        638,
+        157,
+        123,
+        121
+    ],
+    [
+        468,
+        468,
+        939,
+        829,
+        769,
+        769,
+        353,
+        425,
+        712,
+        687,
+        686,
+        780
+    ],
+    [
+        438,
+        932,
+        664,
+        39,
+        932,
+        54,
+        536,
+        641,
+        33,
+        433,
+        926,
+        711
+    ],
+    [
+        371,
+        836,
+        836,
+        895,
+        308,
+        811,
+        64,
+        690,
+        737,
+        218,
+        242,
+        897
+    ]
+]
\ No newline at end of file
diff --git a/tests/torch/test_dataset_generators.py b/tests/torch/test_dataset_generators.py
new file mode 100644
index 00000000000..6e888232c5f
--- /dev/null
+++ b/tests/torch/test_dataset_generators.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
+
+import nncf
+from nncf.data import generate_text_data
+from tests.cross_fw.shared.helpers import load_json
+from tests.cross_fw.shared.paths import TEST_ROOT
+from tests.torch.helpers import set_torch_seed
+
+BASE_TEST_MODEL_ID = "hf-internal-testing/tiny-random-gpt2"
+GENERATED_TEXT_REF = TEST_ROOT / "torch" / "data" / "ref_generated_data.json"
+
+
+@pytest.mark.parametrize(
+    "model, tokenizer, usage_error",
+    [
+        [None, None, True],
+        [AutoModelForCausalLM.from_pretrained(BASE_TEST_MODEL_ID), None, True],
+        [None, AutoTokenizer.from_pretrained(BASE_TEST_MODEL_ID), True],
+        [
+            AutoModelForCausalLM.from_pretrained(BASE_TEST_MODEL_ID),
+            AutoTokenizer.from_pretrained(BASE_TEST_MODEL_ID),
+            False,
+        ],
+    ],
+)
+def test_generate_text_data_usage(model, tokenizer, usage_error):
+    try:
+        with set_torch_seed(0):
+            generate_text_data(model, tokenizer, seq_len=2, dataset_size=1)
+    except Exception as e:
+        if usage_error:
+            assert isinstance(e, nncf.ValidationError), "Expected nncf.ValidationError."
+
+
+def test_generate_text_data_functional():
+    seq_len = 12
+    max_seq_len = seq_len + seq_len // 2
+    dataset_size = 6
+
+    model = AutoModelForCausalLM.from_pretrained(BASE_TEST_MODEL_ID)
+    tokenizer = AutoTokenizer.from_pretrained(BASE_TEST_MODEL_ID)
+
+    with set_torch_seed(0):
+        generated_data = generate_text_data(
+            model,
+            tokenizer,
+            seq_len=seq_len,
+            dataset_size=dataset_size,
+        )
+
+    assert len(generated_data) == dataset_size
+    generated_data = [tokenizer.encode(d) for d in generated_data]
+
+    # Uncomment lines below to generate reference for new models.
+    # from tests.cross_fw.shared.helpers import dump_to_json
+    # dump_to_json(GENERATED_TEXT_REF, generated_data)
+
+    reference_data = load_json(GENERATED_TEXT_REF)
+    for ref_data, gen_data in zip(reference_data, generated_data):
+        assert len(gen_data) <= max_seq_len
+        assert ref_data == gen_data