Synthetic data updates (#2999)
### Changes

- Updated `Usage.md` with a new method description.
- Extended `test_examples` with the example.

### Reason for changes

- On top of #2979.

### Related tickets

- 152550

### Tests

- test_examples/531/ - success
- test_examples_3.10/117/ - success
KodiaqQ authored Oct 7, 2024
1 parent b52bf6c commit 3b2b8c3
Showing 5 changed files with 49 additions and 40 deletions.
Usage.md
@@ -61,6 +61,15 @@ nncf_dataset = nncf.Dataset(data_source, transform_fn)
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset) # model is openvino.Model object
```

- Additionally, a synthetic dataset can be generated with the `nncf.data.generate_text_data` method and used for data-aware weight compression. The method takes a language model (e.g. from `optimum.intel.openvino`) and a tokenizer (e.g. from `transformers`) as input and returns a list of strings generated by the model. Note that dataset generation takes time and depends on various conditions such as the model size, the requested dataset length, and the environment setup. Also, since the dataset is produced from the model's own output, it does not guarantee a significant accuracy improvement after compression. This method is recommended only when a better dataset is not available. Refer to the [example](https://github.com/openvinotoolkit/nncf/tree/develop/examples/llm_compression/openvino/tiny_llama_synthetic_data/) for usage details.

```python
from nncf import compress_weights, CompressWeightsMode, Dataset
from nncf.data import generate_text_data

# Generate synthetic calibration data; dataset_size can be passed to control the number of samples.
synthetic_data = generate_text_data(model, tokenizer)
nncf_dataset = Dataset(synthetic_data, transform_fn)

compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset)
```
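  For illustration, a minimal, hypothetical `transform_fn` for text samples could look like the sketch below; a real OpenVINO LLM usually needs additional inputs (e.g. `position_ids` and past key/value tensors), as shown in the referenced example.

```python
# Hypothetical, simplified transform_fn: tokenize one text sample into model inputs.
# A production version for an OpenVINO LLM also fills position_ids, beam_idx and
# past key/value inputs (see the linked tiny_llama_synthetic_data example).
def transform_fn(text):
    tokens = tokenizer(text, return_tensors="np")
    return {
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    }
```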

- Accuracy of the 4-bit compressed models can also be improved by applying the AWQ, Scale Estimation, GPTQ, or Lora Correction algorithms on top of the data-aware mixed-precision algorithm. These algorithms work by equalizing a subset of weights to minimize the difference between the original precision and the 4-bit precision.
Unlike the others, the Lora Correction algorithm inserts additional Linear layers to reduce quantization noise and further improve accuracy. Inevitably, this approach introduces memory and runtime overheads, but they are negligible, since the inserted weights are much smaller and can be quantized to 8-bit. The AWQ, Scale Estimation (SE), and Lora Correction (LC) algorithms can be used together in any combination: AWQ + SE, AWQ + LC, SE + LC, AWQ + SE + LC. The GPTQ algorithm can be combined with AWQ and Scale Estimation in any combination: AWQ + GPTQ, GPTQ + SE, AWQ + GPTQ + SE. Below are examples demonstrating how to enable the AWQ, Scale Estimation, GPTQ, or Lora Correction algorithms:

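  For instance, a minimal sketch of enabling AWQ together with Scale Estimation, assuming the boolean `awq` and `scale_estimation` arguments of `nncf.compress_weights` (the `gptq` and `lora_correction` flags follow the same pattern), with `model` and `nncf_dataset` prepared as above:

```python
from nncf import compress_weights, CompressWeightsMode

# Sketch: data-aware INT4 compression with AWQ and Scale Estimation enabled on top.
# GPTQ / Lora Correction would be switched on analogously (gptq=True, lora_correction=True)
# in the combinations listed above.
compressed_model = compress_weights(
    model,
    mode=CompressWeightsMode.INT4_SYM,
    ratio=0.8,
    dataset=nncf_dataset,
    awq=True,
    scale_estimation=True,
)
```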
63 changes: 24 additions & 39 deletions examples/llm_compression/openvino/tiny_llama_synthetic_data/main.py
@@ -11,13 +11,11 @@

from functools import partial

import datasets
import numpy as np
import openvino as ov
import torch
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer
from whowhatbench import Evaluator

import nncf

@@ -51,28 +49,6 @@ def gen_pkv(num_heads, head_dim, num_layers):
    return res


def compress_model(model, tokenizer, dataset):
    quantization_dataset = nncf.Dataset(dataset, partial(transform_func, tokenizer=tokenizer, ov_model=model.model))

    optimized_model = nncf.compress_weights(
        model.model.clone(),
        dataset=quantization_dataset,
        mode=nncf.CompressWeightsMode.INT4_SYM,
        ratio=1.0,
        scale_estimation=True,
    )
    return optimized_model


def validate_model(evaluator, hf_model, optimized_model, original_ov_model):
    hf_model.model = optimized_model
    hf_model.request = None
    _, all_metrics = evaluator.score(hf_model)
    hf_model.model = original_ov_model
    hf_model.request = None
    return all_metrics["similarity"][0]


def main():
    MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

@@ -81,28 +57,37 @@ def main():
        MODEL_ID, export=True, load_in_8bit=False, compile=False, stateful=False
    )

    original_ov_model = hf_model.model.clone()
    evaluator = Evaluator(hf_model, tokenizer=tokenizer, metrics=("similarity",))

    # Wikitext-based compression
    wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    wikitext_dataset = [d["text"] for d in wikitext_dataset]
    wikitext_optimized_model = compress_model(hf_model, tokenizer, wikitext_dataset)
    dataset_size = 100

    # Synthetic-based compression
    saved_seed = torch.seed()
    torch.manual_seed(SEED)
    synthetic_dataset = nncf.data.generate_text_data(hf_model, tokenizer)
    synthetic_dataset = nncf.data.generate_text_data(hf_model, tokenizer, dataset_size=dataset_size)
    quantization_dataset = nncf.Dataset(
        synthetic_dataset, partial(transform_func, tokenizer=tokenizer, ov_model=hf_model.model)
    )
    hf_model.request = None
    torch.manual_seed(saved_seed)
    synthetic_optimized_model = compress_model(hf_model, tokenizer, synthetic_dataset)

    # Similarity comparison between Wikitext-based & Synthetic-based compressed models
    wikitext_based_similarity = validate_model(evaluator, hf_model, wikitext_optimized_model, original_ov_model)
    print(f"Wikitext-quantized model similarity: {wikitext_based_similarity}")
    optimized_model = nncf.compress_weights(
        hf_model.model.clone(),
        dataset=quantization_dataset,
        mode=nncf.CompressWeightsMode.INT4_SYM,
        ratio=1.0,
        scale_estimation=True,
    )

    # Verify the model output in comparison to floating-point one
    input_ids = tokenizer("What is Python? ", return_tensors="pt").to(device=hf_model.device)
    max_new_tokens = 100

    hf_model.model = optimized_model
    hf_model.request = None
    opt_output = hf_model.generate(**input_ids, max_new_tokens=max_new_tokens)
    opt_output_text = tokenizer.decode(opt_output[0])

    synthetic_based_similarity = validate_model(evaluator, hf_model, synthetic_optimized_model, original_ov_model)
    print(f"Synthetic-quantized model similarity: {synthetic_based_similarity}")
    return wikitext_based_similarity, synthetic_based_similarity
    print(f"Optimized model output: {opt_output_text}\n")
    return opt_output_text


if __name__ == "__main__":
examples/llm_compression/openvino/tiny_llama_synthetic_data/requirements.txt
@@ -1,7 +1,6 @@
-c ../../../../constraints.txt
torch
datasets
whowhatbench @ git+https://github.com/openvinotoolkit/openvino.genai.git#subdirectory=llm_bench/python/who_what_benchmark
numpy>=1.23.5
openvino==2024.4
optimum-intel[openvino]>=1.13.0
Expand Down
8 changes: 8 additions & 0 deletions tests/cross_fw/examples/example_scope.json
@@ -211,6 +211,14 @@
"group_size": 128
}
},
"llm_compression_synthetic": {
"backend": "openvino",
"requirements": "examples/llm_compression/openvino/tiny_llama_synthetic_data/requirements.txt",
"cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
"accuracy_metrics": {
"word_count": 83
}
},
"quantization_aware_training_torch_anomalib": {
"backend": "torch",
"requirements": "examples/quantization_aware_training/torch/anomalib/requirements.txt",
8 changes: 8 additions & 0 deletions tests/cross_fw/examples/run_example.py
@@ -176,6 +176,14 @@ def llm_tune_params() -> Dict[str, float]:
    return {"awq": bool(awq), "ratio": ratio, "group_size": group_size}


def llm_compression_synthetic() -> Dict[str, float]:
    from examples.llm_compression.openvino.tiny_llama_synthetic_data.main import main as llm_compression_synthetic_main

    result = llm_compression_synthetic_main()

    return {"word_count": len(result.split())}


def quantization_aware_training_torch_resnet18():
    from examples.quantization_aware_training.torch.resnet18.main import main as resnet18_main

