From caaee2a6549d7772f42bf5c1e03525d7a6f136f2 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 19 Dec 2024 12:17:58 +0000
Subject: [PATCH] add tests

---
 .gitignore          |   2 +
 scripts/__init__.py |   0
 scripts/pytest.sh   |   4 ++
 scripts/req.py      |  62 ++++++++++-------
 tests/test_basic.py | 160 ++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 203 insertions(+), 25 deletions(-)
 create mode 100644 scripts/__init__.py
 create mode 100755 scripts/pytest.sh
 create mode 100644 tests/test_basic.py

diff --git a/.gitignore b/.gitignore
index d242cfb..565a70a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,5 @@ build
 tmp
 target
 model.cache
+__pycache__
+*.py[cod]
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/pytest.sh b/scripts/pytest.sh
new file mode 100755
index 0000000..b12400e
--- /dev/null
+++ b/scripts/pytest.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+
+cd "$(dirname "$0")/.."
+PYTHONPATH=$(pwd) pytest -v tests "$@"
diff --git a/scripts/req.py b/scripts/req.py
index c15f3a7..bb5f2f5 100755
--- a/scripts/req.py
+++ b/scripts/req.py
@@ -62,13 +62,16 @@ def send_stream(path: str, payload: dict) -> requests.Response | None:
         return None
 
 
-def joke_msg():
-    return messages(
-        "Ignore the text below.\n"
-        + gen_prompt(PROMPT_SIZE)
-        + "\n\n"
-        + "Now, please tell me a long joke."
-    )
+def joke_msg(prompt_size: int = PROMPT_SIZE):
+    if prompt_size:
+        return messages(
+            "Ignore the text below.\n"
+            + gen_prompt(prompt_size)
+            + "\n\n"
+            + "Now, please tell me a long joke."
+        )
+    else:
+        return messages("Tell me a long joke.")
 
 
 def llg_data():
@@ -83,17 +86,27 @@ def llg_data():
     }
 
 
-def req_data():
-    properties = {}
-    required = []
-    for idx in range(NUM_JOKES):
-        properties[f"joke_{idx}"] = {"type": "string"}
-        properties[f"rating_{idx}"] = {"type": "number"}
-        required.extend([f"joke_{idx}", f"rating_{idx}"])
-    return {
+def req_data(
+    max_tokens: int = MAX_TOKENS,
+    llg: bool = LLG,
+    temperature: float = 0.8,
+    prompt_size: int = PROMPT_SIZE,
+):
+    r = {
         "model": "model",
-        "messages": joke_msg(),
-        ("response_format" if LLG else "ignore_me"): {
+        "messages": joke_msg(prompt_size=prompt_size),
+        # "llg_log_level": "json",
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+    }
+    if llg:
+        properties = {}
+        required = []
+        for idx in range(NUM_JOKES):
+            properties[f"joke_{idx}"] = {"type": "string"}
+            properties[f"rating_{idx}"] = {"type": "number"}
+            required.extend([f"joke_{idx}", f"rating_{idx}"])
+        r["response_format"] = {
             "type": "json_schema",
             "json_schema": {
                 "strict": True,
@@ -104,11 +117,8 @@ def req_data():
                     "required": required,
                 },
             },
-        },
-        # "llg_log_level": "json",
-        "max_tokens": MAX_TOKENS,
-        "temperature": 0.8,
-    }
+        }
+    return r
 
 
 class Results:
@@ -125,7 +135,7 @@ def finalize(self):
         self.completion_tokens = self.usage.get("completion_tokens", 0)
         self.completion_tokens2 = len(self.tbt)
         self.text = "".join(self.text_chunks)
-        print(self.text)
+        # print(self.text)
         if not self.tbt:
             self.avg_tbt = 0
             self.med_tbt = 0
@@ -248,6 +258,8 @@ def one_round():
 """
 
 
+
+
 def main():
     # random.seed(0)
     parser = argparse.ArgumentParser()
@@ -268,7 +280,6 @@ def main():
         one_round()
         return
 
 
-
    if args.max_threads > 0:
        thr = 1
@@ -322,4 +333,5 @@ def csv_line(lst):
             print(r.error)
 
 
-main()
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_basic.py b/tests/test_basic.py
new file mode 100644
index 0000000..8eb3075
--- /dev/null
+++ b/tests/test_basic.py
@@ -0,0 +1,160 @@
+import pytest
+import json
+import sys
+
+sys.path.append(".")
+from scripts.req import send_one, send_one_stream, req_data, messages
+
+json_tag = "<|python_tag|>"
+n = 10
+
+
+def joke_data():
+    d = req_data(temperature=1.2, prompt_size=0, llg=False, max_tokens=50)
+    d["n"] = n
+    return d
+
+
+def send_and_check(d):
+    resp = send_one(d)
+    assert resp is not None
+    assert len(resp["choices"]) == n
+    return resp
+
+
+def test_simple_chat():
+    resp = send_and_check(joke_data())
+    jokes = [c["message"]["content"] for c in resp["choices"]]
+    assert len(set(jokes)) > 1
+
+
+def test_stream_chat():
+    resp = send_one_stream(joke_data())
+    assert len(resp) == n
+    jokes = [r.text for r in resp]
+    assert len(set(jokes)) > 1
+
+
+def test_resp_format_json_object():
+    d = {
+        "model": "model",
+        "messages": messages("Please tell me a one line joke."),
+        "max_tokens": 50,
+        "temperature": 0.8,
+        "n": n,
+        "response_format": {"type": "json_object"},
+    }
+    resp = send_and_check(d)
+
+    num_ok = 0
+    for c in resp["choices"]:
+        content: str = c["message"]["content"]
+        if c["finish_reason"] == "stop":
+            d = json.loads(content)
+            num_ok += 1
+
+    assert num_ok > 0
+
+
+def schema_data(msg: str, schema: dict):
+    return {
+        "model": "model",
+        "messages": messages(msg),
+        "max_tokens": 50,
+        "temperature": 0.8,
+        "n": n,
+        "response_format": {
+            "type": "json_schema",
+            "json_schema": {
+                "strict": True,
+                "schema": schema,
+            },
+        },
+    }
+
+
+def send_schema(msg: str, schema: dict):
+    d = schema_data(msg, schema)
+    return send_and_check(d)
+
+
+def test_resp_format_json_schema():
+    resp = send_schema(
+        "Please tell me a one line joke.",
+        {
+            "type": "object",
+            "properties": {
+                "joke": {"type": "string"},
+                "rating": {"type": "number"},
+            },
+            "additionalProperties": False,
+            "required": ["joke", "rating"],
+        },
+    )
+
+    num_ok = 0
+    for c in resp["choices"]:
+        content: str = c["message"]["content"]
+        if c["finish_reason"] == "stop":
+            j = json.loads(content)
+            assert "joke" in j
+            assert "rating" in j
+            num_ok += 1
+
+    assert num_ok > 0
+
+
+def test_resp_string_schema():
+    resp = send_schema(
+        "How much is 1+1?",
+        {
+            "$schema": "http://json-schema.org/draft-06/schema#",
+            "type": "string",
+            "enum": ["1", "2-two", "3 three", None],
+        },
+    )
+    for c in resp["choices"]:
+        content: str = c["message"]["content"]
+        assert content == '"2-two"'
+
+
+@pytest.mark.parametrize("strict", [True, False])
+@pytest.mark.parametrize("weather", [True, False])
+def test_tools(strict: bool, weather: bool):
+    d = {
+        "model": "model",
+        "messages": messages(
+            "What is the weather in London?" if weather else "How much is 2 + 2?"
+        ),
+        "max_tokens": 50,
+        "temperature": 0.8,
+        "n": n,
+        "tool_choice": "required" if weather else "auto",
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "weather",
+                    "description": "Get the weather for a ",
+                    "strict": strict,
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"location": {"type": "string"}},
+                        "additionalProperties": False,
+                        "required": ["location"],
+                    },
+                },
+            }
+        ],
+    }
+    resp = send_and_check(d)
+    for c in resp["choices"]:
+        content: str = c["message"]["content"]
+        if weather:
+            assert '"function"' in content
+            assert content.startswith(json_tag)
+            j = json.loads(content[len(json_tag) :])
+            assert j["name"] == "weather"
+            assert "London" in j["parameters"]["location"]
+        else:
+            assert '"function"' not in content