Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llm: Add Anthropic support #377

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions taskweaver/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from taskweaver.llm.sentence_transformer import SentenceTransformerService
from taskweaver.llm.util import ChatMessageType, format_chat_message
from taskweaver.llm.zhipuai import ZhipuAIService
from taskweaver.llm.anthropic import AnthropicService

llm_completion_config_map = {
"openai": OpenAIService,
Expand All @@ -33,6 +34,7 @@
"qwen": QWenService,
"zhipuai": ZhipuAIService,
"groq": GroqService,
"anthropic": AnthropicService,
}

# TODO
Expand Down Expand Up @@ -66,6 +68,8 @@ def __init__(
self._set_completion_service(ZhipuAIService)
elif self.config.api_type == "groq":
self._set_completion_service(GroqService)
elif self.config.api_type == "anthropic": # Add support for Anthropic
self._set_completion_service(AnthropicService)
else:
raise ValueError(f"API type {self.config.api_type} is not supported")

Expand Down
93 changes: 93 additions & 0 deletions taskweaver/llm/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os
from typing import Any, Generator, List, Optional

from injector import inject

from taskweaver.llm.util import ChatMessageType, format_chat_message

from .base import CompletionService, EmbeddingService, LLMServiceConfig

DEFAULT_STOP_TOKEN: List[str] = ["<EOS>"]

class AnthropicServiceConfig(LLMServiceConfig):
def _configure(self) -> None:
shared_api_key = self.llm_module_config.api_key
self.api_key = self._get_str(
"api_key",
os.environ.get("ANTHROPIC_API_KEY", shared_api_key)
)
self.model = self._get_str("model", "claude-3-opus-20240229")
self.max_tokens = self._get_int("max_tokens", 1024)
self.temperature = self._get_float("temperature", 0)
self.top_p = self._get_float("top_p", 1)
self.stop_token = self._get_list("stop_token", DEFAULT_STOP_TOKEN)

class AnthropicService(CompletionService):
client = None

@inject
def __init__(self, config: AnthropicServiceConfig):
self.config = config
if AnthropicService.client is None:
try:
from anthropic import Anthropic
AnthropicService.client = Anthropic(api_key=self.config.api_key)
except Exception :
raise Exception(
"Package anthropic is required for using Anthropic API. Run 'pip install anthropic' to install.",
)

def chat_completion(
self,
messages: List[ChatMessageType],
stream: bool = True,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Generator[ChatMessageType, None, None]:
temperature = temperature if temperature is not None else self.config.temperature
max_tokens = max_tokens if max_tokens is not None else self.config.max_tokens
top_p = top_p if top_p is not None else self.config.top_p
stop = stop if stop is not None else self.config.stop_token

try:
# Extract system message if present
system_message = None
anthropic_messages = []
for msg in messages:
if msg["role"] == "system":
system_message = msg["content"]
else:
anthropic_messages.append({"role": msg["role"], "content": msg["content"]})

# Prepare kwargs for Anthropic API
anthropic_kwargs = {
"model": self.config.model,
"messages": anthropic_messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"stop_sequences": stop,
}

# Add system message if present
if system_message:
anthropic_kwargs["system"] = system_message

if stream:
with self.client.messages.stream(**anthropic_kwargs) as stream:
for response in stream:
if response.type == "content_block_delta":
yield format_chat_message("assistant", response.delta.text)
else:
response = self.client.messages.create(**anthropic_kwargs)
yield format_chat_message("assistant", response.content[0].text)

except Exception as e:
raise Exception(f"Anthropic API request failed: {str(e)}")

# Note: Anthropic doesn't provide a native embedding service.
# If you need embeddings, you might want to use a different service or library for that functionality.

35 changes: 35 additions & 0 deletions website/docs/llms/anthropic.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
---
description: Using LLMs from Anthropic (Claude)
---

# Anthropic (Claude)

1. Create an account on [Anthropic](https://www.anthropic.com/) and get your API key from the [Anthropic Console](https://console.anthropic.com/settings/keys).

2. Add the following to your `taskweaver_config.json` file:

```json showLineNumbers
{
"llm.api_type": "anthropic",
"llm.api_base": "https://api.anthropic.com/v1/messages",
"llm.api_key": "YOUR_API_KEY",
"llm.model": "claude-3-opus"
}
```

:::tip
`llm.model` is the model name you want to use. You can find the list of available Claude models in the [Anthropic API documentation](https://docs.anthropic.com/claude/reference/selecting-a-model).
:::

:::info
Anthropic's Claude API doesn't have a specific `response_format` parameter like OpenAI. If you need structured output, you should instruct Claude to respond in a specific format (e.g., JSON) within your prompts.
:::

:::caution
Anthropic doesn't provide a native embedding service. If you need embeddings, you'll need to configure a different service for that functionality.
:::

3. Start TaskWeaver and chat with TaskWeaver using Claude.
You can refer to the [Quick Start](../quickstart.md) for more details.

Remember to replace `YOUR_API_KEY` with your actual Anthropic API key.