From 8a78121e07c8cb5b981f66de2d7f62fcb2120854 Mon Sep 17 00:00:00 2001 From: juberti Date: Mon, 22 Jul 2024 14:05:28 -0700 Subject: [PATCH 1/2] Add Gemini and Groq function calling --- llm_benchmark_suite.py | 6 +++-- llm_request.py | 61 ++++++++++++++++-------------------------- 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/llm_benchmark_suite.py b/llm_benchmark_suite.py index a128f14..92506de 100644 --- a/llm_benchmark_suite.py +++ b/llm_benchmark_suite.py @@ -412,9 +412,11 @@ def _tools_models(): _Llm("claude-3-5-sonnet-20240620"), _Llm("claude-3-sonnet-20240229"), _Llm("claude-3-haiku-20240307"), + _Llm(GEMINI_1_5_PRO), + _Llm(GEMINI_1_5_FLASH), _FireworksLlm("accounts/fireworks/models/firefunction-v2", "firefunction-v2"), - # _GroqLlm("llama3-groq-70b-8192-tool-use-preview"), - # _GroqLlm("llama3-groq-8b-8192-tool-use-preview"), + _GroqLlm("llama3-groq-70b-8192-tool-use-preview"), + _GroqLlm("llama3-groq-8b-8192-tool-use-preview"), ] diff --git a/llm_request.py b/llm_request.py index afb0b63..8268953 100644 --- a/llm_request.py +++ b/llm_request.py @@ -250,10 +250,15 @@ async def openai_chunk_gen(response) -> TokenGenerator: yield delta_content elif delta_tool: function = delta_tool[0]["function"] - token = function.get("name") or function.get("arguments") - if token: - yield token.strip() - usage = chunk.get("usage") + name = function.get("name", "").strip() + if name: + tokens += 1 + yield name + args = function.get("arguments", "").strip() + if args: + tokens += 1 + yield args + usage = chunk.get("usage") or chunk.get("x_groq", {}).get("usage") if usage: num_input_tokens = usage.get("prompt_tokens") num_output_tokens = usage.get("completion_tokens") @@ -267,7 +272,7 @@ async def openai_chat(ctx: ApiContext, path: str = "/chat/completions") -> ApiRe kwargs = {"messages": make_openai_messages(ctx)} if ctx.tools: kwargs["tools"] = ctx.tools - kwargs["tool_choice"] = "auto" + kwargs["tool_choice"] = "required" if ctx.peft: kwargs["peft"] = ctx.peft data = make_openai_chat_body(ctx, **kwargs) @@ -425,6 +430,7 @@ async def gemini_chat(ctx: ApiContext) -> ApiResult: async def chunk_gen(response) -> TokenGenerator: tokens = 0 async for chunk in make_json_chunk_gen(response): + print("chunk", chunk) candidates = chunk.get("candidates") if candidates: content = candidates[0].get("content") @@ -433,6 +439,14 @@ async def chunk_gen(response) -> TokenGenerator: if "text" in part: tokens += 1 yield part["text"] + elif "functionCall" in part: + call = part["functionCall"] + if "name" in call: + tokens += 1 + yield call["name"] + if "args" in call: + tokens += 1 + yield str(call["args"]) usage = chunk.get("usageMetadata") if usage: num_tokens = usage.get("candidatesTokenCount") @@ -466,6 +480,10 @@ async def chunk_gen(response) -> TokenGenerator: if not ctx.files or ctx.files[0].is_image ], } + if ctx.tools: + data["tools"] = ( + [{"function_declarations": [tool["function"] for tool in ctx.tools]}], + ) return await post(ctx, url, headers, data, chunk_gen) @@ -480,38 +498,6 @@ async def cohere_embed(ctx: ApiContext) -> ApiResult: return await post(ctx, url, headers, data) -async def make_fixie_chunk_gen(response) -> TokenGenerator: - text = "" - async for line in response.content: - line = line.decode("utf-8").strip() - obj = json.loads(line) - curr_turn = obj["turns"][-1] - if ( - curr_turn["role"] == "assistant" - and curr_turn["messages"] - and "content" in curr_turn["messages"][-1] - ): - if curr_turn["state"] == "done": - break - new_text = curr_turn["messages"][-1]["content"] - # Sometimes we get a spurious " " message - if new_text == " ": - continue - if new_text.startswith(text): - delta = new_text[len(text) :] - text = new_text - yield delta - else: - print(f"Warning: got unexpected text: '{new_text}' vs '{text}'") - - -async def fixie_chat(ctx: ApiContext) -> ApiResult: - url = f"https://api.fixie.ai/api/v1/agents/{ctx.model}/conversations" - headers = make_headers(auth_token=get_api_key(ctx, "FIXIE_API_KEY")) - data = {"message": ctx.prompt, "runtimeParameters": {}} - return await post(ctx, url, headers, data, make_fixie_chunk_gen) - - async def fake_chat(ctx: ApiContext) -> ApiResult: class FakeResponse(aiohttp.ClientResponse): def __init__(self, status, reason): @@ -613,7 +599,6 @@ def make_context( func = openai_chat if not args.base_url: provider = "openai" - # case _ elif "/" in model return await fixie_chat(ctx) case _: raise ValueError(f"Unknown model: {model}") name = args.display_name or make_display_name(provider, model) From 407e6b5c692ebe8069c02a431e6e42b0c08d3e87 Mon Sep 17 00:00:00 2001 From: juberti Date: Mon, 22 Jul 2024 14:05:48 -0700 Subject: [PATCH 2/2] Update llm_request.py --- llm_request.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llm_request.py b/llm_request.py index 8268953..c4da062 100644 --- a/llm_request.py +++ b/llm_request.py @@ -430,7 +430,6 @@ async def gemini_chat(ctx: ApiContext) -> ApiResult: async def chunk_gen(response) -> TokenGenerator: tokens = 0 async for chunk in make_json_chunk_gen(response): - print("chunk", chunk) candidates = chunk.get("candidates") if candidates: content = candidates[0].get("content")