diff --git a/src/deepsparse/evaluation/cli.py b/src/deepsparse/evaluation/cli.py index be56027e99..674859cace 100644 --- a/src/deepsparse/evaluation/cli.py +++ b/src/deepsparse/evaluation/cli.py @@ -75,7 +75,7 @@ from src.deepsparse.evaluation.integrations import ( # noqa: F401 try_import_llm_evaluation_harness, ) -from src.deepsparse.evaluation.results import Result, save_evaluations +from src.deepsparse.evaluation.results import Result, save_result from src.deepsparse.evaluation.utils import args_to_dict, get_save_path from src.deepsparse.pipeline import DEEPSPARSE_ENGINE, ORT_ENGINE, TORCHSCRIPT_ENGINE @@ -210,7 +210,7 @@ def main( **integration_args, ) - _LOGGER.info(f"Evaluation done. Results:\n{result}") + _LOGGER.info(f"Evaluation done. Result:\n{result.formatted}") save_path = get_save_path( save_path=save_path, @@ -219,8 +219,8 @@ def main( ) if save_path: _LOGGER.info(f"Saving the evaluation results to {save_path}") - save_evaluations( - evaluations=result.formatted, + save_result( + result=result, save_path=save_path, save_format=type_serialization, ) diff --git a/src/deepsparse/evaluation/integrations/llm_evaluation_harness.py b/src/deepsparse/evaluation/integrations/llm_evaluation_harness.py index 8c311204cd..193f828659 100644 --- a/src/deepsparse/evaluation/integrations/llm_evaluation_harness.py +++ b/src/deepsparse/evaluation/integrations/llm_evaluation_harness.py @@ -101,13 +101,29 @@ def integration_eval( results_raw = evaluator.simple_evaluate(**evaluator_input.dict()) results = Result( - raw=dict(output=results_raw, input=evaluator_input), + raw=dict(output=results_raw, input=filter_evaluator_input(evaluator_input)), formatted=format_raw_results(results_raw), ) return results +def filter_evaluator_input( + evaluator_input: "EvaluatorInputSchema", +) -> Dict[str, Any]: # noqa: F821 + """ + Filter the evaluator input to remove the model field. + The model field is a complex object that cannot be serialized. + + :param evaluator_input: the evaluator input to filter + :return: the filtered evaluator input + """ + evaluator = evaluator_input.dict() + del evaluator["model"] + + return evaluator + + def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]: """ Format the raw results from llm_evaluation_harness into a list of diff --git a/src/deepsparse/evaluation/results.py b/src/deepsparse/evaluation/results.py index ad220d1791..8740247dde 100644 --- a/src/deepsparse/evaluation/results.py +++ b/src/deepsparse/evaluation/results.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -from collections import OrderedDict from typing import Any, List, Optional import yaml @@ -28,7 +26,7 @@ "EvalSample", "Evaluation", "Result", - "save_evaluations", + "save_result", ] @@ -63,68 +61,43 @@ class Evaluation(BaseModel): class Result(BaseModel): formatted: List[Evaluation] = Field( - description="Evaluation results represented in the unified, structured format" + description="Evaluation result represented in the unified, structured format" ) raw: Any = Field( - description="Evaluation results represented in the raw format " + description="Evaluation result represented in the raw format " "(characteristic for the specific evaluation integration)" ) - def __str__(self): - """ - The string representation of the Result object is - the formatted evaluation results serialized in JSON. 
- """ - return save_evaluations(self.formatted, save_format="json", save_path=None) - -def save_evaluations( - evaluations: List[Evaluation], save_format: str = "json", save_path: str = None +def save_result( + result: Result, + save_path: str, + save_format: str = "json", ): """ Saves a list of Evaluation objects to a file in the specified format. - :param evaluations: List of Evaluation objects to save + :param result: Result object to save :param save_format: Format to save the evaluations in. :param save_path: Path to save the evaluations to. - If None, the evaluations will not be saved. :return: The serialized evaluations """ - # serialize the evaluations - evaluations: List[Evaluation] = prep_for_serialization(evaluations) - # convert to ordered dicts to preserve order - evaluations: List[OrderedDict] = evaluations_to_dicts(evaluations) + # prepare the Result object for serialization + result: Result = prep_for_serialization(result) if save_format == "json": - return _save_to_json(evaluations, save_path) + _save_to_json(result, save_path) elif save_format == "yaml": - return _save_to_yaml(evaluations, save_path) + _save_to_yaml(result, save_path) else: NotImplementedError("Currently only json and yaml formats are supported") -def _save_to_json(evaluations: List[OrderedDict], save_path: Optional[str]) -> str: - data = json.dumps(evaluations, indent=4) - if save_path: - _save(data, save_path, expected_ext=".json") - return data - - -def _save_to_yaml(evaluations: List[OrderedDict], save_path: Optional[str]) -> str: - # required to properly process OrderedDicts - yaml.add_representer( - OrderedDict, - lambda dumper, data: dumper.represent_mapping( - "tag:yaml.org,2002:map", data.items() - ), - ) - data = yaml.dump(evaluations, default_flow_style=False) - if save_path: - _save(data, save_path, expected_ext=".yaml") - return data +def _save_to_json(result: Result, save_path: str): + _save(result.json(), save_path, expected_ext=".json") -def evaluations_to_dicts(evaluations: List[Evaluation]): - return [OrderedDict(**evaluation.dict()) for evaluation in evaluations] +def _save_to_yaml(result: Result, save_path: str): + _save(yaml.dump(result.dict()), save_path, expected_ext=".yaml") def _save(data: str, save_path: str, expected_ext: str): diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index c7338322f2..20d08a5f3b 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -834,6 +834,11 @@ def engine_forward( generated_tokens.append(token) generated_logits.append(logits) + if session.total_num_processed_tokens >= session.capacity: + # if the kv cache is full, stop generation + finished_reason.append(FinishReason.CAPACITY) + break + if ( token == self.tokenizer.eos_token_id and not self.force_max_tokens diff --git a/src/deepsparse/version.py b/src/deepsparse/version.py index 8cf2c09834..848f460af3 100644 --- a/src/deepsparse/version.py +++ b/src/deepsparse/version.py @@ -39,7 +39,7 @@ from deepsparse.generated_version import is_enterprise, is_release, splash, version except Exception: # otherwise, fall back to version info in this file - version = "1.6.0" + version = "1.7.0" is_release = False is_enterprise = False splash = ( diff --git a/src/deepsparse/yolov8/annotate.py b/src/deepsparse/yolov8/annotate.py index 3140311d1e..22329481ca 100644 --- a/src/deepsparse/yolov8/annotate.py +++ b/src/deepsparse/yolov8/annotate.py @@ -13,7 +13,7 
@@ # limitations under the License. """ -Usage: deepsparse.object_detection.annotate [OPTIONS] +Usage: deepsparse.yolov8.annotate [OPTIONS] Annotation Script for YOLOv8 with DeepSparse @@ -21,14 +21,16 @@ --model_filepath, --model-filepath TEXT Path/SparseZoo stub to the model file to be used for annotation - --source TEXT File path to an image or directory of image + --subtask TEXT A subtask to run the YOLOv8 model on. + Defaults to 'detection' + --source TEXT File path to image or directory of .jpg files, a .mp4 video, or an integer (i.e. 0) for webcam [required] --engine [deepsparse|onnxruntime|torch] Inference engine backend to run on. Choices are 'deepsparse', 'onnxruntime', and 'torch'. Default is 'deepsparse' - --model_input_image_shape, --model-input-shape INTEGER... + --model_input_image_shape, --model-input-image-shape INTEGER... Image shape to override model with for inference, must be two integers --num_cores, --num-cores INTEGER @@ -51,8 +53,8 @@ then it will be ignored --no_save, --no-save Set flag when source is from webcam to not save results.Not supported for non-webcam - sources [default: False] - --help Show this message and exit + sources + --help Show this message and exit. ####### Examples: diff --git a/tests/deepsparse/evaluation/test_results.py b/tests/deepsparse/evaluation/test_results.py index efd5c3cf7c..33274e9a69 100644 --- a/tests/deepsparse/evaluation/test_results.py +++ b/tests/deepsparse/evaluation/test_results.py @@ -23,7 +23,8 @@ EvalSample, Evaluation, Metric, - save_evaluations, + Result, + save_result, ) @@ -56,124 +57,22 @@ def evaluations(): @pytest.fixture() -def evaluations_json(): - return """[ - { - "task": "task_1", - "dataset": { - "type": "type_1", - "name": "name_1", - "config": "config_1", - "split": "split_1" - }, - "metrics": [ - { - "name": "metric_name_1", - "value": 1.0 - } - ], - "samples": [ - { - "input": [ - [ - 5 - ] - ], - "output": 5 - } - ] - }, - { - "task": "task_2", - "dataset": { - "type": "type_2", - "name": "name_2", - "config": "config_2", - "split": "split_2" - }, - "metrics": [ - { - "name": "metric_name_2", - "value": 2.0 - }, - { - "name": "metric_name_3", - "value": 3.0 - } - ], - "samples": [ - { - "input": [ - [ - 10.0 - ] - ], - "output": 10.0 - }, - { - "input": [ - [ - 20.0 - ] - ], - "output": 20.0 - } - ] - } -]""" # noqa: E501 +def result(evaluations): + return Result(formatted=evaluations, raw="dummy_raw_evaluation") -@pytest.fixture() -def evaluations_yaml(): - return """- task: task_1 - dataset: - config: config_1 - name: name_1 - split: split_1 - type: type_1 - metrics: - - name: metric_name_1 - value: 1.0 - samples: - - input: - - - 5 - output: 5 -- task: task_2 - dataset: - config: config_2 - name: name_2 - split: split_2 - type: type_2 - metrics: - - name: metric_name_2 - value: 2.0 - - name: metric_name_3 - value: 3.0 - samples: - - input: - - - 10.0 - output: 10.0 - - input: - - - 20.0 - output: 20.0 -""" - - -def test_serialize_evaluation_json(tmp_path, evaluations, evaluations_json): +def test_serialize_result_json(tmp_path, result): path_to_file = tmp_path / "result.json" - evaluations_serialized = save_evaluations( - evaluations=evaluations, save_format="json", save_path=path_to_file.as_posix() - ) + save_result(result=result, save_format="json", save_path=path_to_file.as_posix()) + with open(path_to_file.as_posix(), "r") as f: - assert json.load(f) - assert evaluations_serialized == evaluations_json + reloaded_results = json.load(f) + assert reloaded_results == result.dict() -def 
test_serialize_evaluation_yaml(tmp_path, evaluations, evaluations_yaml): +def test_serialize_result_yaml(tmp_path, result): path_to_file = tmp_path / "result.yaml" - evaluations_serialized = save_evaluations( - evaluations=evaluations, save_format="yaml", save_path=path_to_file.as_posix() - ) + save_result(result=result, save_format="yaml", save_path=path_to_file.as_posix()) with open(path_to_file.as_posix(), "r") as f: - assert yaml.safe_load(f) - assert evaluations_serialized == evaluations_yaml + reloaded_results = yaml.safe_load(f) + assert reloaded_results == result.dict() diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index ad9526d54a..ba2a52c40e 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -55,6 +55,118 @@ def test_run_same_prompt_multiple_times(pipeline, prompt): ) +def _test_stop_inference_kv_cache_full( + pipeline, + prompt, + max_new_tokens, + expected_finished_reason, + expected_generated_tokens_length=None, +): + out = pipeline(prompt=prompt, max_new_tokens=max_new_tokens) + kv_cache_state = out.kv_cache_state[0] + finished_reason = out.generations[0].finished_reason + generated_text = out.generations[0].text + assert finished_reason == expected_finished_reason + assert len(pipeline.tokenizer(generated_text)["input_ids"]) == ( + expected_generated_tokens_length or max_new_tokens + ) + return kv_cache_state + + +def test_stop_inference_kv_cache_full(prompt): + # Tests the proper behavior of the kv cache around the + # scenario when the kv cache becomes full during the inference + + # We set the sequence length to a small value to assert that + # the kv cache buffer fills up quickly + sequence_length = 32 + # We set the prompt sequence length to 1 to assert that the + # inference will run until the kv cache is full. 
If the
+    # `prompt_sequence_length` is larger than 1, it is very likely
+    # that the inference will stop before the kv cache is full
+    # (as the `prompt_sequence_length` reduces the number of
+    # tokens that are generated in the first iteration)
+    prompt_sequence_length = 1
+
+    pipeline = Pipeline.create(
+        task="text_generation",
+        model_path="hf:mgoin/TinyStories-1M-deepsparse",
+        engine_type="onnxruntime",
+        sequence_length=sequence_length,
+        force_max_tokens=True,
+        prompt_sequence_length=prompt_sequence_length,
+    )
+    pipeline._debug = True
+
+    prompt_length = len(pipeline.tokenizer(prompt)["input_ids"])
+
+    cache_capacity = sequence_length - prompt_sequence_length
+    # we need to subtract 1 to account for the initial generated token during the
+    # prompt inference
+    cache_capacity -= 1
+
+    # max_new_tokens so that there is still one more "free" space in the kv cache
+    # (we can still do autoregressive inference)
+    max_new_tokens_minus_one = cache_capacity - prompt_length - 1
+    # max_new_tokens so that the kv cache is full
+    # (so we can still do one last correct autoregressive
+    # inference in the next iteration)
+    max_new_tokens = cache_capacity - prompt_length
+    # max_new_tokens so that the kv cache has already removed the last entry
+    # (so we can no longer do autoregressive inference in the next iteration)
+    max_new_tokens_plus_one = cache_capacity - prompt_length + 1
+    # max_new_tokens so that the kv cache would remove the last two entries
+    # (but it will not; the inference terminates early and produces
+    # the same result as max_new_tokens_plus_one)
+    max_new_tokens_plus_two = cache_capacity - prompt_length + 2
+
+    kv_cache_state_full_minus_one = _test_stop_inference_kv_cache_full(
+        pipeline,
+        prompt,
+        max_new_tokens_minus_one,
+        expected_finished_reason="max_new_tokens",
+    )
+    kv_cache_state_full = _test_stop_inference_kv_cache_full(
+        pipeline, prompt, max_new_tokens, expected_finished_reason="max_new_tokens"
+    )
+    kv_cache_state_full_plus_one = _test_stop_inference_kv_cache_full(
+        pipeline, prompt, max_new_tokens_plus_one, expected_finished_reason="capacity"
+    )
+    kv_cache_state_full_plus_two = _test_stop_inference_kv_cache_full(
+        pipeline,
+        prompt,
+        max_new_tokens_plus_two,
+        expected_generated_tokens_length=max_new_tokens_plus_one,
+        expected_finished_reason="capacity",
+    )
+    """
+    Check the following structure of the kv cache:
+    minus_one | full | plus_one | plus_two
+    --------------------------------------
+    [- 0 -] | [row A] | [row B] | [row B]
+    [row A] | [row B] | [row C] | [row C]
+    [row B] | [row C] | [row D] | [row D]
+    ... | ... | ... | ...
+    """
+    # check for the "free" space in the kv cache
+    assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0
+    # check for the row A
+    assert numpy.array_equal(
+        kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 1, :],
+        kv_cache_state_full["past_key_values.0.key"][:, :, 0, :],
+    )
+    # check for the row B
+    assert numpy.array_equal(
+        kv_cache_state_full["past_key_values.0.key"][:, :, 1, :],
+        kv_cache_state_full_plus_one["past_key_values.0.key"][:, :, 0, :],
+    )
+    # check equality between plus_one and plus_two
+    assert numpy.array_equal(
+        kv_cache_state_full_plus_one["past_key_values.0.key"],
+        kv_cache_state_full_plus_two["past_key_values.0.key"],
+    )
+
+
 def test_run_multiple_prompts_in_parallel(pipeline, prompt):
     # Test the scenario, where multiple prompts are run in parallel
     # Same two prompts should produce the same output