Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for OpenRouter and update OpenAIChat #124

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions examples/manifest_openrouter.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"OPENROUTER_API_KEY = \"sk-...\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use ChatOpenAI\n",
"\n",
"Set you `OPENROUTER_API_KEY` environment variable."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"\n",
"openai_chat = ClientConnection(\n",
" client_name=\"openrouter\",\n",
" client_connection=OPENROUTER_API_KEY,\n",
" engine=\"meta-llama/codellama-70b-instruct\"\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[openai_chat])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2020 World Series was played at the Globe Life Field in Arlington, Texas.\n"
]
}
],
"source": [
"# Simple question\n",
"chat_dict = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
" {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
" {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
"]\n",
"print(manifest.run(chat_dict, max_tokens=100))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
2 changes: 2 additions & 0 deletions manifest/clients/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-32k",
"gpt-4-1106-preview",
}
Expand Down
151 changes: 151 additions & 0 deletions manifest/clients/openrouter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""OpenRouter client."""

import copy
import logging
import os
from typing import Any, Dict, Optional

from manifest.clients.client import Client
from manifest.request import LMRequest

logger = logging.getLogger(__name__)


class OpenRouterClient(Client):
    """OpenRouter client.

    Sends OpenAI-style chat requests to the OpenRouter API
    (https://openrouter.ai/api/v1/chat/completions).
    """

    # Params are defined in https://openrouter.ai/docs/parameters
    # Maps manifest param name -> (OpenRouter API param name, default value).
    PARAMS = {
        "engine": ("model", "meta-llama/codellama-70b-instruct"),
        "max_tokens": ("max_tokens", 1000),
        "temperature": ("temperature", 0.1),
        # The API's top-k sampling parameter is named "top_k" ("k" is
        # Cohere's name for it) -- see the parameters doc linked above.
        "top_k": ("top_k", 0),
        "frequency_penalty": ("frequency_penalty", 0.0),
        "presence_penalty": ("presence_penalty", 0.0),
        "stop_sequences": ("stop", None),
    }
    REQUEST_CLS = LMRequest
    NAME = "openrouter"
    IS_CHAT = True

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the OpenRouter server.

        connection_str is used as the API key; if not given, falls back to
        the OPENROUTER_API_KEY environment variable.

        Args:
            connection_str: connection string (the API key).
            client_args: client arguments. Keys present in PARAMS are popped
                off (mutating the caller's dict, as the other clients do) and
                set as attributes on this client.

        Raises:
            ValueError: if no API key could be found.
        """
        self.api_key = connection_str or os.environ.get("OPENROUTER_API_KEY")
        if self.api_key is None:
            raise ValueError(
                "OpenRouter API key not set. Set OPENROUTER_API_KEY environment "
                "variable or pass through `client_connection`."
            )
        self.host = "https://openrouter.ai/api/v1"
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))

    def close(self) -> None:
        """Close the client (no-op: the client holds no open resources)."""

    def get_generation_header(self) -> Dict[str, str]:
        """
        Get generation header.

        Returns:
            header carrying the bearer token for authentication.
        """
        return {
            "Authorization": f"Bearer {self.api_key}",
        }

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return self.host + "/chat/completions"

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return False

    def supports_streaming_inference(self) -> bool:
        """Return whether the client supports streaming inference.

        Override in child client class.
        """
        return True

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        return {"model_name": self.NAME, "engine": getattr(self, "engine")}

    def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess request params.

        Converts the manifest `prompt` into the chat `messages` payload the
        OpenRouter API expects.

        Args:
            request: request params.

        Returns:
            request params with `prompt` replaced by `messages`.

        Raises:
            ValueError: if the prompt is empty or not one of the supported
                shapes (str, list of str, list of role/content dicts).
        """
        # Format for chat model
        request = copy.deepcopy(request)
        prompt = request.pop("prompt")
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list) and prompt and isinstance(prompt[0], str):
            # Each string becomes its own user message.
            messages = [{"role": "user", "content": content} for content in prompt]
        elif isinstance(prompt, list) and prompt and isinstance(prompt[0], dict):
            for pmt_dict in prompt:
                if "role" not in pmt_dict or "content" not in pmt_dict:
                    raise ValueError(
                        "Prompt must be list of dicts with 'role' and 'content' "
                        f"keys. Got {prompt}."
                    )
            messages = prompt
        else:
            # Also reached for an empty list, which previously surfaced as an
            # opaque IndexError from prompt[0].
            raise ValueError(
                "Prompt must be string, list of strings, or list of dicts. "
                f"Got {prompt}"
            )
        request["messages"] = messages
        return super().preprocess_request_params(request)

    def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
        """
        Format response to dict.

        Flattens chat / streaming choices into `{"text": ...}` entries.

        Args:
            response: response
            request: request

        Return:
            response as dict
        """
        new_choices = []
        response = copy.deepcopy(response)
        for message in response["choices"]:
            if "delta" in message:
                # This is a streaming response chunk; role-only deltas
                # (no "content" key) are deliberately skipped.
                if "content" in message["delta"]:
                    new_choices.append({"text": message["delta"]["content"]})
            else:
                new_choices.append({"text": message["message"]["content"]})
        response["choices"] = new_choices
        return super().postprocess_response(response, request)
2 changes: 2 additions & 0 deletions manifest/connections/client_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from manifest.clients.openai import OpenAIClient
from manifest.clients.openai_chat import OpenAIChatClient
from manifest.clients.openai_embedding import OpenAIEmbeddingClient
from manifest.clients.openrouter import OpenRouterClient
from manifest.clients.toma import TOMAClient
from manifest.connections.scheduler import RandomScheduler, RoundRobinScheduler

Expand All @@ -37,6 +38,7 @@
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
OpenRouterClient.NAME: OpenRouterClient,
TOMAClient.NAME: TOMAClient,
}

Expand Down