Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Dec 19, 2024
1 parent 25e12c2 commit caaee2a
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 25 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ build
tmp
target
model.cache
__pycache__
*.py[cod]
Empty file added scripts/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions scripts/pytest.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/sh

# Run the pytest suite in tests/ from the parent of this script's directory,
# with that directory on PYTHONPATH. Any extra command-line arguments are
# forwarded to pytest unchanged.

# Stop here if the directory change fails; otherwise pytest would be started
# from whatever directory the caller happened to be in.
cd "$(dirname "$0")/.." || exit 1
PYTHONPATH=$(pwd) pytest -v tests "$@"
62 changes: 37 additions & 25 deletions scripts/req.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,16 @@ def send_stream(path: str, payload: dict) -> requests.Response | None:
return None


def joke_msg():
return messages(
"Ignore the text below.\n"
+ gen_prompt(PROMPT_SIZE)
+ "\n\n"
+ "Now, please tell me a long joke."
)
def joke_msg(prompt_size: int = PROMPT_SIZE):
    """Return the ``messages(...)`` payload asking for a long joke.

    Args:
        prompt_size: Size passed to ``gen_prompt`` to produce filler text that
            the prompt tells the model to ignore. A falsy value (e.g. ``0``)
            skips the filler and sends a bare joke request instead.
    """
    # Guard clause: no filler requested, so send the short prompt.
    if not prompt_size:
        return messages("Tell me a long joke.")

    filler = gen_prompt(prompt_size)
    prompt = (
        "Ignore the text below.\n"
        + filler
        + "\n\n"
        + "Now, please tell me a long joke."
    )
    return messages(prompt)


def llg_data():
Expand All @@ -83,17 +86,27 @@ def llg_data():
}


def req_data():
properties = {}
required = []
for idx in range(NUM_JOKES):
properties[f"joke_{idx}"] = {"type": "string"}
properties[f"rating_{idx}"] = {"type": "number"}
required.extend([f"joke_{idx}", f"rating_{idx}"])
return {
def req_data(
max_tokens: int = MAX_TOKENS,
llg: bool = LLG,
temperature: float = 0.8,
prompt_size: int = PROMPT_SIZE,
):
r = {
"model": "model",
"messages": joke_msg(),
("response_format" if LLG else "ignore_me"): {
"messages": joke_msg(prompt_size=prompt_size),
# "llg_log_level": "json",
"max_tokens": max_tokens,
"temperature": temperature,
}
if llg:
properties = {}
required = []
for idx in range(NUM_JOKES):
properties[f"joke_{idx}"] = {"type": "string"}
properties[f"rating_{idx}"] = {"type": "number"}
required.extend([f"joke_{idx}", f"rating_{idx}"])
r["response_format"] = {
"type": "json_schema",
"json_schema": {
"strict": True,
Expand All @@ -104,11 +117,8 @@ def req_data():
"required": required,
},
},
},
# "llg_log_level": "json",
"max_tokens": MAX_TOKENS,
"temperature": 0.8,
}
}
return r


class Results:
Expand All @@ -125,7 +135,7 @@ def finalize(self):
self.completion_tokens = self.usage.get("completion_tokens", 0)
self.completion_tokens2 = len(self.tbt)
self.text = "".join(self.text_chunks)
print(self.text)
# print(self.text)
if not self.tbt:
self.avg_tbt = 0
self.med_tbt = 0
Expand Down Expand Up @@ -248,6 +258,8 @@ def one_round():
"""




def main():
# random.seed(0)
parser = argparse.ArgumentParser()
Expand All @@ -268,7 +280,6 @@ def main():
one_round()
return


if args.max_threads > 0:
thr = 1

Expand Down Expand Up @@ -322,4 +333,5 @@ def csv_line(lst):
print(r.error)


main()
if __name__ == "__main__":
main()
160 changes: 160 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import pytest
import json
import sys

sys.path.append(".")
from scripts.req import send_one, send_one_stream, req_data, messages

json_tag = "<|python_tag|>"
n = 10


def joke_data():
    """Return the request payload shared by the plain chat tests.

    The payload comes from ``req_data`` with no filler prompt
    (``prompt_size=0``), ``llg=False``, a 50-token limit and temperature 1.2,
    and additionally asks for ``n`` completions.
    """
    payload = req_data(
        max_tokens=50,
        temperature=1.2,
        llg=False,
        prompt_size=0,
    )
    payload["n"] = n
    return payload


def send_and_check(d):
    """Send request ``d`` with ``send_one`` and sanity-check the response.

    Asserts that a response was returned and that it carries exactly ``n``
    choices, then hands the response back to the caller.
    """
    response = send_one(d)
    assert response is not None
    choices = response["choices"]
    assert len(choices) == n
    return response


def test_simple_chat():
    """Non-streaming chat: the returned completions must not all be identical."""
    choices = send_and_check(joke_data())["choices"]
    distinct_jokes = {choice["message"]["content"] for choice in choices}
    assert len(distinct_jokes) > 1


def test_stream_chat():
    """Streaming chat: expect ``n`` results whose texts are not all identical."""
    results = send_one_stream(joke_data())
    assert len(results) == n
    distinct_jokes = {result.text for result in results}
    assert len(distinct_jokes) > 1


def test_resp_format_json_object():
    """``response_format: json_object`` must yield parseable JSON.

    Requests ``n`` completions and parses every one that finished with
    ``finish_reason == "stop"``; ``json.loads`` raising fails the test. At
    least one completion must have finished that way.
    """
    request = {
        "model": "model",
        "messages": messages("Please tell me a one line joke."),
        "max_tokens": 50,
        "temperature": 0.8,
        "n": n,
        "response_format": {"type": "json_object"},
    }
    resp = send_and_check(request)

    num_ok = 0
    for c in resp["choices"]:
        content: str = c["message"]["content"]
        # Only completions that stopped on their own are parsed; presumably
        # the others hit max_tokens and may be cut mid-document -- TODO confirm.
        if c["finish_reason"] == "stop":
            # Parse purely for validation. The result is deliberately not
            # bound to a name, so the request dict is never shadowed.
            json.loads(content)
            num_ok += 1

    assert num_ok > 0


def schema_data(msg: str, schema: dict):
    """Build a chat request that constrains output to a JSON schema.

    Args:
        msg: Text passed to ``messages`` to form the request's messages.
        schema: JSON schema placed under ``response_format.json_schema.schema``
            with ``strict`` set to ``True``.

    Returns:
        The request dict, asking for ``n`` completions of at most 50 tokens
        at temperature 0.8.
    """
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "strict": True,
            "schema": schema,
        },
    }
    request = {
        "model": "model",
        "messages": messages(msg),
        "max_tokens": 50,
        "temperature": 0.8,
        "n": n,
        "response_format": response_format,
    }
    return request


def send_schema(msg: str, schema: dict):
    """Send a schema-constrained request for ``msg`` and return the checked response."""
    return send_and_check(schema_data(msg, schema))


def test_resp_format_json_schema():
    """Object schema: finished completions must contain ``joke`` and ``rating``.

    Every completion with ``finish_reason == "stop"`` is parsed as JSON and
    checked for both required keys; at least one completion must have
    finished that way.
    """
    joke_schema = {
        "type": "object",
        "properties": {
            "joke": {"type": "string"},
            "rating": {"type": "number"},
        },
        "additionalProperties": False,
        "required": ["joke", "rating"],
    }
    resp = send_schema("Please tell me a one line joke.", joke_schema)

    completed = 0
    for choice in resp["choices"]:
        text: str = choice["message"]["content"]
        if choice["finish_reason"] != "stop":
            continue
        parsed = json.loads(text)
        for key in ("joke", "rating"):
            assert key in parsed
        completed += 1

    assert completed > 0


def test_resp_string_schema():
    """String-enum schema: every completion must be exactly ``"2-two"``.

    The comparison is against the raw content, so the expected value includes
    the surrounding JSON double quotes.
    """
    enum_schema = {
        "$schema": "http://json-schema.org/draft-06/schema#",
        "type": "string",
        "enum": ["1", "2-two", "3 three", None],
    }
    resp = send_schema("How much is 1+1?", enum_schema)

    expected = '"2-two"'
    for choice in resp["choices"]:
        text: str = choice["message"]["content"]
        assert text == expected


@pytest.mark.parametrize("strict", [True, False])
@pytest.mark.parametrize("weather", [True, False])
def test_tools(strict: bool, weather: bool):
    """Tool calling with a single ``weather`` function tool.

    Args:
        strict: Value sent as the tool's ``strict`` flag.
        weather: When true, asks about London's weather with
            ``tool_choice="required"`` and expects every completion to be
            ``json_tag`` followed by a JSON call to ``weather`` whose location
            mentions London. When false, asks an arithmetic question with
            ``tool_choice="auto"`` and expects no ``"function"`` in the content.
    """
    if weather:
        question, tool_choice = "What is the weather in London?", "required"
    else:
        question, tool_choice = "How much is 2 + 2?", "auto"

    weather_tool = {
        "type": "function",
        "function": {
            "name": "weather",
            "description": "Get the weather for a <location>",
            "strict": strict,
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "additionalProperties": False,
                "required": ["location"],
            },
        },
    }
    request = {
        "model": "model",
        "messages": messages(question),
        "max_tokens": 50,
        "temperature": 0.8,
        "n": n,
        "tool_choice": tool_choice,
        "tools": [weather_tool],
    }

    resp = send_and_check(request)
    for choice in resp["choices"]:
        text: str = choice["message"]["content"]
        if not weather:
            assert '"function"' not in text
            continue
        assert '"function"' in text
        assert text.startswith(json_tag)
        # Everything after the tag is the JSON-encoded tool call.
        call = json.loads(text[len(json_tag) :])
        assert call["name"] == "weather"
        assert "London" in call["parameters"]["location"]

0 comments on commit caaee2a

Please sign in to comment.