From 5338421db717af4a43459e7f7024cf9858328707 Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Tue, 12 Mar 2024 15:13:51 -0700
Subject: [PATCH 1/7] fix: default params

Default params need to be applied when generate_kwargs is omitted from the
request.

Signed-off-by: Ishaan Sehgal
---
 presets/inference/text-generation/inference_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py
index 73c7b5095..ddcf2780f 100644
--- a/presets/inference/text-generation/inference_api.py
+++ b/presets/inference/text-generation/inference_api.py
@@ -157,7 +157,7 @@ class UnifiedRequestModel(BaseModel):
     clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output")
     prefix: Optional[str] = Field(None, description="Prefix added to prompt")
     handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation")
-    generate_kwargs: Optional[GenerateKwargs] = Field(None, description="Additional kwargs for generate method")
+    generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method")
 
     # Field for conversational model
     messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model")

From d8a060d38fcc02377e52b896c26ac1cb16a442c0 Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Tue, 12 Mar 2024 16:03:54 -0700
Subject: [PATCH 2/7] fix: correct generation defaults

Trim GenerateKwargs to the fields we override and align do_sample and top_k
with the Hugging Face generation defaults.

Signed-off-by: Ishaan Sehgal
---
 .../inference/text-generation/inference_api.py | 18 +++---------------
 1 file changed, 3 insertions(+), 15 deletions(-)

diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py
index ddcf2780f..bf739844d 100644
--- a/presets/inference/text-generation/inference_api.py
+++ b/presets/inference/text-generation/inference_api.py
@@ -123,30 +123,18 @@ def health_check():
     return {"status": "Healthy"}
 
 class GenerateKwargs(BaseModel):
-    max_length: int = 200
+    max_length: int = 200  # Length of input prompt + max_new_tokens
     min_length: int = 0
-    do_sample: bool = True
+    do_sample: bool = False
     early_stopping: bool = False
     num_beams: int = 1
-    num_beam_groups: int = 1
-    diversity_penalty: float = 0.0
     temperature: float = 1.0
-    top_k: int = 10
+    top_k: int = 50
     top_p: float = 1
     typical_p: float = 1
     repetition_penalty: float = 1
-    length_penalty: float = 1
-    no_repeat_ngram_size: int = 0
-    encoder_no_repeat_ngram_size: int = 0
-    bad_words_ids: Optional[List[int]] = None
-    num_return_sequences: int = 1
-    output_scores: bool = False
-    return_dict_in_generate: bool = False
     pad_token_id: Optional[int] = tokenizer.pad_token_id
     eos_token_id: Optional[int] = tokenizer.eos_token_id
-    forced_bos_token_id: Optional[int] = None
-    forced_eos_token_id: Optional[int] = None
-    remove_invalid_values: Optional[bool] = None
 
     class Config:
         extra = Extra.allow  # Allows for additional fields not explicitly defined
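The fix in PATCH 1 works because Pydantic only populates a nested model's defaults when the field is actually instantiated; with a plain `Field(None)`, a request that omits `generate_kwargs` leaves the field as `None` and the server-side defaults never reach the pipeline. Below is a minimal, self-contained sketch of the difference (the `GenerateKwargs` here is a simplified stand-in for the repo's full model, not the actual code):

```python
from typing import Optional
from pydantic import BaseModel, Field

class GenerateKwargs(BaseModel):
    # Simplified stand-in for the model patched above
    max_length: int = 200
    do_sample: bool = False
    top_k: int = 50

class RequestBefore(BaseModel):
    # Old behavior: omitting generate_kwargs leaves the field None,
    # so none of the GenerateKwargs defaults are ever applied.
    generate_kwargs: Optional[GenerateKwargs] = Field(None)

class RequestAfter(BaseModel):
    # Patched behavior: a fresh GenerateKwargs (with its defaults) is
    # built for every request that omits the field.
    generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs)

print(RequestBefore().generate_kwargs)  # -> None
print(RequestAfter().generate_kwargs)   # -> max_length=200 do_sample=False top_k=50
```

`default_factory` is used rather than a shared default instance so that each request gets its own mutable copy of the kwargs.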
From 25d128c41ee1a68d07e53803232420b3404357bb Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Tue, 12 Mar 2024 16:16:42 -0700
Subject: [PATCH 3/7] feat: add UT for defaults

Signed-off-by: Ishaan Sehgal
---
 .../tests/test_inference_api.py | 36 +++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py
index ff9866bbb..df7628c7b 100644
--- a/presets/inference/text-generation/tests/test_inference_api.py
+++ b/presets/inference/text-generation/tests/test_inference_api.py
@@ -127,3 +127,39 @@ def test_get_metrics_no_gpus(configured_app):
     response = client.get("/metrics")
     assert response.status_code == 200
     assert response.json()["gpu_info"] == []
+
+def test_default_generation_params(configured_app):
+    if configured_app.test_config['pipeline'] != 'text-generation':
+        pytest.skip("Skipping non-text-generation tests")
+
+    client = TestClient(configured_app)
+
+    request_data = {
+        "prompt": "Test default params",
+        "return_full_text": True,
+        "clean_up_tokenization_spaces": False
+        # Note: generate_kwargs is not provided, so defaults should be used
+    }
+
+    with patch('inference_api.pipeline') as mock_pipeline:
+        mock_pipeline.return_value = [{"generated_text": "Mocked response"}]  # Mock the response of the pipeline function
+
+        response = client.post("/chat", json=request_data)
+
+        assert response.status_code == 200
+        data = response.json()
+        assert "Result" in data
+        assert len(data["Result"]) > 0
+
+        # Check the default args
+        _, kwargs = mock_pipeline.call_args
+        assert kwargs['max_length'] == 200
+        assert kwargs['min_length'] == 0
+        assert kwargs['do_sample'] is True
+        assert kwargs['temperature'] == 1.0
+        assert kwargs['top_k'] == 50
+        assert kwargs['top_p'] == 1
+        assert kwargs['typical_p'] == 1
+        assert kwargs['repetition_penalty'] == 1
+        assert kwargs['num_beams'] == 1
+        assert kwargs['early_stopping'] is False

From 18d88b6ebc6267eca8bbb03ff7f7f3096bc5c4d5 Mon Sep 17 00:00:00 2001
From: ishaansehgal99
Date: Tue, 12 Mar 2024 16:20:54 -0700
Subject: [PATCH 4/7] fix: update default tests

---
 presets/inference/text-generation/tests/test_inference_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py
index df7628c7b..797618427 100644
--- a/presets/inference/text-generation/tests/test_inference_api.py
+++ b/presets/inference/text-generation/tests/test_inference_api.py
@@ -162,4 +162,4 @@ def test_default_generation_params(configured_app):
         assert kwargs['typical_p'] == 1
         assert kwargs['repetition_penalty'] == 1
         assert kwargs['num_beams'] == 1
-        assert kwargs['early_stopping'] is False
+        assert kwargs['early_stopping'] is False
\ No newline at end of file

From 174cacc32cab0b23ddd2d7e336610e91d494755e Mon Sep 17 00:00:00 2001
From: ishaansehgal99
Date: Tue, 12 Mar 2024 16:29:37 -0700
Subject: [PATCH 5/7] fix: expected do_sample param

---
 presets/inference/text-generation/tests/test_inference_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py
index 797618427..dbf6c494e 100644
--- a/presets/inference/text-generation/tests/test_inference_api.py
+++ b/presets/inference/text-generation/tests/test_inference_api.py
@@ -155,7 +155,7 @@ def test_default_generation_params(configured_app):
         _, kwargs = mock_pipeline.call_args
         assert kwargs['max_length'] == 200
         assert kwargs['min_length'] == 0
-        assert kwargs['do_sample'] is True
+        assert kwargs['do_sample'] is False
         assert kwargs['temperature'] == 1.0
         assert kwargs['top_k'] == 50
         assert kwargs['top_p'] == 1
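The test added in PATCH 3 (and corrected in PATCHES 4-5) leans on `unittest.mock`'s call recording: `call_args` holds the most recent call as an `(args, kwargs)` pair, so the test can assert on the exact generation kwargs the endpoint forwarded to the pipeline. A self-contained sketch of the pattern, outside the FastAPI app (the names and kwargs below are illustrative, not the repo's code):

```python
from unittest.mock import MagicMock

# Stand-in for the patched inference_api.pipeline
pipeline = MagicMock(return_value=[{"generated_text": "Mocked response"}])

# Roughly what the /chat handler does once the defaults are applied
result = pipeline("Test default params", max_length=200, do_sample=False, top_k=50)

# call_args records the most recent call as an (args, kwargs) pair,
# which is exactly what the test unpacks to check each default.
args, kwargs = pipeline.call_args
assert args == ("Test default params",)
assert kwargs["do_sample"] is False and kwargs["top_k"] == 50
print(result[0]["generated_text"])  # -> Mocked response
```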
From b12530e3096cd5ae7ddc2b29dfc738413c76ea9d Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Tue, 12 Mar 2024 16:30:39 -0700
Subject: [PATCH 6/7] Update test_inference_api.py

Restore the trailing newline at the end of the file.

Signed-off-by: Ishaan Sehgal
---
 presets/inference/text-generation/tests/test_inference_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py
index dbf6c494e..1484d4786 100644
--- a/presets/inference/text-generation/tests/test_inference_api.py
+++ b/presets/inference/text-generation/tests/test_inference_api.py
@@ -162,4 +162,4 @@ def test_default_generation_params(configured_app):
         assert kwargs['typical_p'] == 1
         assert kwargs['repetition_penalty'] == 1
         assert kwargs['num_beams'] == 1
-        assert kwargs['early_stopping'] is False
\ No newline at end of file
+        assert kwargs['early_stopping'] is False

From 940cda199fdef4bec5f52b54bc0c881537be69c7 Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Tue, 12 Mar 2024 18:59:18 -0700
Subject: [PATCH 7/7] fix: Add UTs for length generation

Signed-off-by: Ishaan Sehgal
---
 .../tests/test_inference_api.py | 73 ++++++++++++++++++-
 1 file changed, 72 insertions(+), 1 deletion(-)

diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py
index 1484d4786..d6506b08b 100644
--- a/presets/inference/text-generation/tests/test_inference_api.py
+++ b/presets/inference/text-generation/tests/test_inference_api.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 from fastapi.testclient import TestClient
+from transformers import AutoTokenizer
 
 # Get the parent directory of the current file
 parent_dir = str(Path(__file__).resolve().parent.parent)
@@ -149,7 +150,7 @@ def test_default_generation_params(configured_app):
         assert response.status_code == 200
         data = response.json()
         assert "Result" in data
-        assert len(data["Result"]) > 0
+        assert data["Result"] == "Mocked response", "The response content doesn't match the expected mock response"
 
         # Check the default args
         _, kwargs = mock_pipeline.call_args
@@ -163,3 +164,73 @@ def test_default_generation_params(configured_app):
         assert kwargs['repetition_penalty'] == 1
         assert kwargs['num_beams'] == 1
         assert kwargs['early_stopping'] is False
+
+def test_generation_with_max_length(configured_app):
+    if configured_app.test_config['pipeline'] != 'text-generation':
+        pytest.skip("Skipping non-text-generation tests")
+
+    client = TestClient(configured_app)
+    prompt = "This prompt requests a response of a certain minimum length to test the functionality."
+    avg_res_len = 15
+    max_length = 40  # Set lower than the default (200) to prevent the test from hanging
+
+    request_data = {
+        "prompt": prompt,
+        "return_full_text": True,
+        "clean_up_tokenization_spaces": False,
+        "generate_kwargs": {"max_length": max_length}
+    }
+
+    response = client.post("/chat", json=request_data)
+
+    assert response.status_code == 200
+    data = response.json()
+    assert "Result" in data, "The response should contain a 'Result' key"
+    print("Response: ", data["Result"])
+
+    tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
+    prompt_tokens = tokenizer.tokenize(prompt)
+    total_tokens = tokenizer.tokenize(data["Result"])  # data["Result"] includes the input prompt
+
+    prompt_tokens_len = len(prompt_tokens)
+    max_new_tokens = max_length - prompt_tokens_len
+    new_tokens = len(total_tokens) - prompt_tokens_len
+
+    assert avg_res_len <= new_tokens, f"Ideally the response should generate at least {avg_res_len} new tokens"
+    assert new_tokens <= max_new_tokens, "Response must not generate more than max new tokens"
+    assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"
+
+def test_generation_with_min_length(configured_app):
+    if configured_app.test_config['pipeline'] != 'text-generation':
+        pytest.skip("Skipping non-text-generation tests")
+
+    client = TestClient(configured_app)
+    prompt = "This prompt requests a response of a certain minimum length to test the functionality."
+    min_length = 30
+    max_length = 40
+
+    request_data = {
+        "prompt": prompt,
+        "return_full_text": True,
+        "clean_up_tokenization_spaces": False,
+        "generate_kwargs": {"min_length": min_length, "max_length": max_length}
+    }
+
+    response = client.post("/chat", json=request_data)
+
+    assert response.status_code == 200
+    data = response.json()
+    assert "Result" in data, "The response should contain a 'Result' key"
+
+    tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
+    prompt_tokens = tokenizer.tokenize(prompt)
+    total_tokens = tokenizer.tokenize(data["Result"])  # data["Result"] includes the input prompt
+
+    prompt_tokens_len = len(prompt_tokens)
+
+    min_new_tokens = min_length - prompt_tokens_len
+    max_new_tokens = max_length - prompt_tokens_len
+    new_tokens = len(total_tokens) - prompt_tokens_len
+
+    assert min_new_tokens <= new_tokens <= max_new_tokens, "Response should generate at least min_new_tokens and at most max_new_tokens new tokens"
+    assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"
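The two length tests in PATCH 7 rest on the same token accounting: with `return_full_text=True` the result string still contains the prompt, so newly generated tokens are measured by subtracting the prompt's token count, while `max_length` and `min_length` bound the combined prompt-plus-generation length. A small worked sketch of that arithmetic (the token counts below are hypothetical, not measured from the real model):

```python
# Hypothetical token counts for illustration only.
prompt_tokens_len = 18            # e.g. len(tokenizer.tokenize(prompt))
max_length = 40                   # cap on prompt + generated tokens
min_length = 30                   # floor on prompt + generated tokens

max_new_tokens = max_length - prompt_tokens_len   # 22: remaining generation budget
min_new_tokens = min_length - prompt_tokens_len   # 12: minimum new tokens required

total_tokens_len = 35             # e.g. tokens in data["Result"] (prompt + output)
new_tokens = total_tokens_len - prompt_tokens_len # 17 newly generated tokens

assert min_new_tokens <= new_tokens <= max_new_tokens  # 12 <= 17 <= 22
assert total_tokens_len <= max_length                  # 35 <= 40
```

Note that this accounting only holds while the prompt is shorter than `min_length`/`max_length`; a prompt longer than `max_length` would leave no generation budget at all, which is why the tests pick a short fixed prompt.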