fix: Update default params and add associated UTs #294

Merged
merged 8 commits into from Mar 13, 2024
20 changes: 4 additions & 16 deletions presets/inference/text-generation/inference_api.py
@@ -123,30 +123,18 @@ def health_check():
return {"status": "Healthy"}

class GenerateKwargs(BaseModel):
max_length: int = 200
max_length: int = 200  # Total length of the input prompt plus max_new_tokens
min_length: int = 0
do_sample: bool = True
do_sample: bool = False
early_stopping: bool = False
num_beams: int = 1
num_beam_groups: int = 1
diversity_penalty: float = 0.0
temperature: float = 1.0
top_k: int = 10
top_k: int = 50
top_p: float = 1
typical_p: float = 1
repetition_penalty: float = 1
length_penalty: float = 1
no_repeat_ngram_size: int = 0
encoder_no_repeat_ngram_size: int = 0
bad_words_ids: Optional[List[int]] = None
num_return_sequences: int = 1
output_scores: bool = False
return_dict_in_generate: bool = False
pad_token_id: Optional[int] = tokenizer.pad_token_id
eos_token_id: Optional[int] = tokenizer.eos_token_id
forced_bos_token_id: Optional[int] = None
forced_eos_token_id: Optional[int] = None
remove_invalid_values: Optional[bool] = None
class Config:
extra = Extra.allow # Allows for additional fields not explicitly defined

@@ -157,7 +145,7 @@ class UnifiedRequestModel(BaseModel):
clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output")
prefix: Optional[str] = Field(None, description="Prefix added to prompt")
handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation")
generate_kwargs: Optional[GenerateKwargs] = Field(None, description="Additional kwargs for generate method")
generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method")

# Field for conversational model
messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model")
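In practical terms, the new defaults mean a request that omits generate_kwargs now decodes greedily (do_sample=False) with top_k=50 and an overall budget of 200 tokens. A minimal sketch of how those defaults materialize, assuming inference_api is importable as the tests below arrange; the field names mirror the diff, and the extra max_new_tokens field is purely illustrative:

from inference_api import GenerateKwargs  # assumption: module on sys.path, as in the tests

defaults = GenerateKwargs()
print(defaults.dict())
# Expected (abridged): {'max_length': 200, 'min_length': 0, 'do_sample': False,
#                       'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, ...}

# Because Config.extra = Extra.allow, undeclared fields pass straight through:
custom = GenerateKwargs(do_sample=True, max_new_tokens=64)  # max_new_tokens is hypothetical here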
107 changes: 107 additions & 0 deletions presets/inference/text-generation/tests/test_inference_api.py
@@ -6,6 +6,7 @@
import pytest
import torch
from fastapi.testclient import TestClient
from transformers import AutoTokenizer

# Get the parent directory of the current file
parent_dir = str(Path(__file__).resolve().parent.parent)
@@ -127,3 +128,109 @@ def test_get_metrics_no_gpus(configured_app):
response = client.get("/metrics")
assert response.status_code == 200
assert response.json()["gpu_info"] == []

def test_default_generation_params(configured_app):
if configured_app.test_config['pipeline'] != 'text-generation':
pytest.skip("Skipping non-text-generation tests")

client = TestClient(configured_app)

request_data = {
"prompt": "Test default params",
"return_full_text": True,
"clean_up_tokenization_spaces": False
# Note: generate_kwargs is not provided, so defaults should be used
}

with patch('inference_api.pipeline') as mock_pipeline:
mock_pipeline.return_value = [{"generated_text": "Mocked response"}] # Mock the response of the pipeline function

response = client.post("/chat", json=request_data)

assert response.status_code == 200
data = response.json()
assert "Result" in data
assert data["Result"] == "Mocked response", "The response content doesn't match the expected mock response"

# Check the default args
_, kwargs = mock_pipeline.call_args
assert kwargs['max_length'] == 200
assert kwargs['min_length'] == 0
assert kwargs['do_sample'] is False
assert kwargs['temperature'] == 1.0
assert kwargs['top_k'] == 50
assert kwargs['top_p'] == 1
assert kwargs['typical_p'] == 1
assert kwargs['repetition_penalty'] == 1
assert kwargs['num_beams'] == 1
assert kwargs['early_stopping'] is False

def test_generation_with_max_length(configured_app):
if configured_app.test_config['pipeline'] != 'text-generation':
pytest.skip("Skipping non-text-generation tests")

client = TestClient(configured_app)
prompt = "This prompt requests a response of a certain minimum length to test the functionality."
avg_res_len = 15
max_length = 40 # Set to lower than default (200) to prevent test hanging

request_data = {
"prompt": prompt,
"return_full_text": True,
"clean_up_tokenization_spaces": False,
"generate_kwargs": {"max_length": max_length}
}

response = client.post("/chat", json=request_data)

assert response.status_code == 200
data = response.json()
print("Response: ", data["Result"])
assert "Result" in data, "The response should contain a 'Result' key"

tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
prompt_tokens = tokenizer.tokenize(prompt)
total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt

prompt_tokens_len = len(prompt_tokens)
max_new_tokens = max_length - prompt_tokens_len
new_tokens = len(total_tokens) - prompt_tokens_len

assert avg_res_len <= new_tokens, f"Ideally the response should generate at least {avg_res_len} new tokens"
assert new_tokens <= max_new_tokens, "Response must not generate more than max new tokens"
assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"

def test_generation_with_min_length(configured_app):
if configured_app.test_config['pipeline'] != 'text-generation':
pytest.skip("Skipping non-text-generation tests")

client = TestClient(configured_app)
prompt = "This prompt requests a response of a certain minimum length to test the functionality."
min_length = 30
max_length = 40

request_data = {
"prompt": prompt,
"return_full_text": True,
"clean_up_tokenization_spaces": False,
"generate_kwargs": {"min_length": min_length, "max_length": max_length}
}

response = client.post("/chat", json=request_data)

assert response.status_code == 200
data = response.json()
assert "Result" in data, "The response should contain a 'Result' key"

tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path'])
prompt_tokens = tokenizer.tokenize(prompt)
total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt

prompt_tokens_len = len(prompt_tokens)

min_new_tokens = min_length - prompt_tokens_len
max_new_tokens = max_length - prompt_tokens_len
new_tokens = len(total_tokens) - prompt_tokens_len

assert min_new_tokens <= new_tokens <= max_new_tokens, "Response should generate at least min_new_tokens and at most max_new_tokens new tokens"
assert len(total_tokens) <= max_length, "Total # of tokens has to be less than or equal to max_length"
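As a usage note, these tests drive the /chat route through FastAPI's in-process TestClient; a hypothetical external client that overrides the new defaults per request could look like the following (the requests library, host, and port are assumptions, while the endpoint and field names come from the tests above):

import requests  # assumption: plain HTTP client instead of the in-process TestClient

payload = {
    "prompt": "Explain what max_length controls.",
    "return_full_text": True,
    "clean_up_tokenization_spaces": False,
    # Any field omitted from generate_kwargs falls back to the GenerateKwargs defaults.
    "generate_kwargs": {"max_length": 80, "do_sample": True, "temperature": 0.7},
}

response = requests.post("http://localhost:5000/chat", json=payload)  # host/port assumed
print(response.json()["Result"])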