
Add Gemini and Groq function calling
juberti committed Jul 22, 2024
1 parent 4c60bb0 commit 8a78121
Showing 2 changed files with 27 additions and 40 deletions.
6 changes: 4 additions & 2 deletions llm_benchmark_suite.py
@@ -412,9 +412,11 @@ def _tools_models():
_Llm("claude-3-5-sonnet-20240620"),
_Llm("claude-3-sonnet-20240229"),
_Llm("claude-3-haiku-20240307"),
_Llm(GEMINI_1_5_PRO),
_Llm(GEMINI_1_5_FLASH),
_FireworksLlm("accounts/fireworks/models/firefunction-v2", "firefunction-v2"),
# _GroqLlm("llama3-groq-70b-8192-tool-use-preview"),
# _GroqLlm("llama3-groq-8b-8192-tool-use-preview"),
_GroqLlm("llama3-groq-70b-8192-tool-use-preview"),
_GroqLlm("llama3-groq-8b-8192-tool-use-preview"),
]


61 changes: 23 additions & 38 deletions llm_request.py
@@ -250,10 +250,15 @@ async def openai_chunk_gen(response) -> TokenGenerator:
             yield delta_content
         elif delta_tool:
             function = delta_tool[0]["function"]
-            token = function.get("name") or function.get("arguments")
-            if token:
-                yield token.strip()
-        usage = chunk.get("usage")
+            name = function.get("name", "").strip()
+            if name:
+                tokens += 1
+                yield name
+            args = function.get("arguments", "").strip()
+            if args:
+                tokens += 1
+                yield args
+        usage = chunk.get("usage") or chunk.get("x_groq", {}).get("usage")
         if usage:
             num_input_tokens = usage.get("prompt_tokens")
             num_output_tokens = usage.get("completion_tokens")
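For reference, a minimal sketch of the chunk shapes this generator now handles. The values are illustrative: the tool-call delta layout follows the OpenAI streaming format, and Groq reports usage under an x_groq key, hence the fallback.

    # Hypothetical streamed chunk carrying a tool-call delta (OpenAI-style).
    chunk = {
        "choices": [
            {"delta": {"tool_calls": [{"function": {"name": "get_weather", "arguments": ""}}]}}
        ]
    }

    # Hypothetical final Groq chunk: usage arrives under "x_groq" instead of "usage".
    final = {"choices": [], "x_groq": {"usage": {"prompt_tokens": 24, "completion_tokens": 8}}}
    usage = final.get("usage") or final.get("x_groq", {}).get("usage")
    assert usage == {"prompt_tokens": 24, "completion_tokens": 8}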
@@ -267,7 +272,7 @@ async def openai_chat(ctx: ApiContext, path: str = "/chat/completions") -> ApiResult:
kwargs = {"messages": make_openai_messages(ctx)}
if ctx.tools:
kwargs["tools"] = ctx.tools
kwargs["tool_choice"] = "auto"
kwargs["tool_choice"] = "required"
if ctx.peft:
kwargs["peft"] = ctx.peft
data = make_openai_chat_body(ctx, **kwargs)
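Per the OpenAI API, "auto" lets the model choose between a text reply and a tool call, while "required" forces at least one tool call, which is what a function-calling benchmark wants. A sketch of the resulting request body, with an assumed example tool:

    # Illustrative kwargs for a forced tool call (the tool schema is an assumption).
    kwargs = {
        "messages": [{"role": "user", "content": "What is the weather in SF?"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
                },
            }
        ],
        "tool_choice": "required",  # "auto" would permit a plain text answer
    }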
@@ -425,6 +430,7 @@ async def gemini_chat(ctx: ApiContext) -> ApiResult:
     async def chunk_gen(response) -> TokenGenerator:
         tokens = 0
         async for chunk in make_json_chunk_gen(response):
+            print("chunk", chunk)
             candidates = chunk.get("candidates")
             if candidates:
                 content = candidates[0].get("content")
@@ -433,6 +439,14 @@ async def chunk_gen(response) -> TokenGenerator:
if "text" in part:
tokens += 1
yield part["text"]
elif "functionCall" in part:
call = part["functionCall"]
if "name" in call:
tokens += 1
yield call["name"]
if "args" in call:
tokens += 1
yield str(call["args"])
usage = chunk.get("usageMetadata")
if usage:
num_tokens = usage.get("candidatesTokenCount")
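For comparison, a hypothetical Gemini streaming chunk containing a function call. Note that Gemini returns args as a structured dict rather than a JSON string, which is why the code above yields str(call["args"]).

    # Hypothetical chunk from the Gemini streaming API (values illustrative).
    chunk = {
        "candidates": [
            {"content": {"parts": [{"functionCall": {"name": "get_weather", "args": {"city": "SF"}}}]}}
        ],
        "usageMetadata": {"candidatesTokenCount": 7},
    }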
@@ -466,6 +480,10 @@ async def chunk_gen(response) -> TokenGenerator:
             if not ctx.files or ctx.files[0].is_image
         ],
     }
+    if ctx.tools:
+        data["tools"] = [
+            {"function_declarations": [tool["function"] for tool in ctx.tools]}
+        ]
     return await post(ctx, url, headers, data, chunk_gen)
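The mapping unwraps each OpenAI-style tool entry, since Gemini expects bare declarations under function_declarations rather than {"type": "function", "function": {...}} wrappers. A minimal sketch with an assumed tool:

    # Assumed OpenAI-style tool entry, as stored in ctx.tools.
    tool = {
        "type": "function",
        "function": {"name": "get_weather", "description": "Look up current weather"},
    }
    data = {"tools": [{"function_declarations": [tool["function"]]}]}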


@@ -480,38 +498,6 @@ async def cohere_embed(ctx: ApiContext) -> ApiResult:
     return await post(ctx, url, headers, data)


-async def make_fixie_chunk_gen(response) -> TokenGenerator:
-    text = ""
-    async for line in response.content:
-        line = line.decode("utf-8").strip()
-        obj = json.loads(line)
-        curr_turn = obj["turns"][-1]
-        if (
-            curr_turn["role"] == "assistant"
-            and curr_turn["messages"]
-            and "content" in curr_turn["messages"][-1]
-        ):
-            if curr_turn["state"] == "done":
-                break
-            new_text = curr_turn["messages"][-1]["content"]
-            # Sometimes we get a spurious " " message
-            if new_text == " ":
-                continue
-            if new_text.startswith(text):
-                delta = new_text[len(text) :]
-                text = new_text
-                yield delta
-            else:
-                print(f"Warning: got unexpected text: '{new_text}' vs '{text}'")
-
-
-async def fixie_chat(ctx: ApiContext) -> ApiResult:
-    url = f"https://api.fixie.ai/api/v1/agents/{ctx.model}/conversations"
-    headers = make_headers(auth_token=get_api_key(ctx, "FIXIE_API_KEY"))
-    data = {"message": ctx.prompt, "runtimeParameters": {}}
-    return await post(ctx, url, headers, data, make_fixie_chunk_gen)
-
-
 async def fake_chat(ctx: ApiContext) -> ApiResult:
     class FakeResponse(aiohttp.ClientResponse):
         def __init__(self, status, reason):
@@ -613,7 +599,6 @@ def make_context(
             func = openai_chat
             if not args.base_url:
                 provider = "openai"
-        # case _ elif "/" in model return await fixie_chat(ctx)
         case _:
             raise ValueError(f"Unknown model: {model}")
     name = args.display_name or make_display_name(provider, model)