Skip to content

Commit

Permalink
(Refactor) Code Quality improvement - Use Common base handler for `an…
Browse files Browse the repository at this point in the history
…thropic_text/` (BerriAI#7143)

* add anthropic text provider

* add ANTHROPIC_TEXT to LlmProviders

* fix anthropic text implementation

* working anthropic text claude-2

* test_acompletion_claude2_stream

* add param mapping for anthropic text

* fix unused imports

* fix anthropic completion handler.py
  • Loading branch information
ishaan-jaff authored Dec 10, 2024
1 parent 5e016fe commit bdb2082
Show file tree
Hide file tree
Showing 12 changed files with 527 additions and 497 deletions.
3 changes: 2 additions & 1 deletion litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ class LlmProviders(str, Enum):
COHERE_CHAT = "cohere_chat"
CLARIFAI = "clarifai"
ANTHROPIC = "anthropic"
ANTHROPIC_TEXT = "anthropic_text"
REPLICATE = "replicate"
HUGGINGFACE = "huggingface"
TOGETHER_AI = "together_ai"
Expand Down Expand Up @@ -1060,7 +1061,7 @@ class LlmProviders(str, Enum):
AnthropicExperimentalPassThroughConfig,
)
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion import AnthropicTextConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.databricks.chat.transformation import DatabricksConfig
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
from .llms.predibase import PredibaseConfig
Expand Down
1 change: 1 addition & 0 deletions litellm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"cohere_chat",
"clarifai",
"anthropic",
"anthropic_text",
"replicate",
"huggingface",
"together_ai",
Expand Down
46 changes: 44 additions & 2 deletions litellm/litellm_core_utils/get_llm_provider_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,39 @@ def handle_cohere_chat_model_custom_llm_provider(
return model, custom_llm_provider


def handle_anthropic_text_model_custom_llm_provider(
    model: str, custom_llm_provider: Optional[str] = None
) -> Tuple[str, Optional[str]]:
    """
    Route legacy Anthropic text-completion models to the "anthropic_text" provider.

    If user sets model = "anthropic/claude-2" -> use custom_llm_provider = "anthropic_text".

    Args:
        model: model name, optionally prefixed with "provider/".
        custom_llm_provider: explicitly requested provider, if any.

    Returns:
        Tuple of (model, custom_llm_provider) — rewritten to
        (bare model, "anthropic_text") when a legacy text model is detected,
        otherwise passed through unchanged.
    """
    is_text_model = litellm.AnthropicTextConfig._is_anthropic_text_model

    # Case 1: provider explicitly set to "anthropic" and model is a legacy text model.
    if custom_llm_provider == "anthropic" and is_text_model(model):
        return model, "anthropic_text"

    # Case 2: provider encoded in the model string, e.g. "anthropic/claude-2".
    prefix, sep, remainder = model.partition("/")
    if sep and prefix == "anthropic" and is_text_model(remainder):
        return remainder, "anthropic_text"

    return model, custom_llm_provider


def get_llm_provider( # noqa: PLR0915
model: str,
custom_llm_provider: Optional[str] = None,
Expand Down Expand Up @@ -92,6 +125,10 @@ def get_llm_provider( # noqa: PLR0915
model, custom_llm_provider
)

model, custom_llm_provider = handle_anthropic_text_model_custom_llm_provider(
model, custom_llm_provider
)

if custom_llm_provider:
if (
model.split("/")[0] == custom_llm_provider
Expand Down Expand Up @@ -210,7 +247,10 @@ def get_llm_provider( # noqa: PLR0915
custom_llm_provider = "text-completion-openai"
## anthropic
elif model in litellm.anthropic_models:
custom_llm_provider = "anthropic"
if litellm.AnthropicTextConfig._is_anthropic_text_model(model):
custom_llm_provider = "anthropic_text"
else:
custom_llm_provider = "anthropic"
## cohere
elif model in litellm.cohere_models or model in litellm.cohere_embedding_models:
custom_llm_provider = "cohere"
Expand Down Expand Up @@ -531,7 +571,9 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
)
elif custom_llm_provider == "galadriel":
api_base = (
api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1"
api_base
or get_secret("GALADRIEL_API_BASE")
or "https://api.galadriel.com/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
if api_base is not None and not isinstance(api_base, str):
Expand Down
42 changes: 0 additions & 42 deletions litellm/litellm_core_utils/streaming_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,40 +223,6 @@ def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
self.holding_chunk = ""
return hold, curr_chunk

def handle_anthropic_text_chunk(self, chunk):
    """
    Parse one server-sent-event chunk from the legacy Anthropic text API.

    For old anthropic models - claude-1, claude-2.
    Claude-3 is handled from within Anthropic.py VIA ModelResponseIterator()

    Returns a dict with keys "text", "is_finished", "finish_reason".
    Raises ValueError when the line looks like an API error payload.
    """
    # Normalize bytes to str before inspecting the SSE line.
    str_line = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk

    empty_result = {"text": "", "is_finished": False, "finish_reason": None}

    if str_line.startswith("data:"):
        payload = json.loads(str_line[5:])
        if payload.get("type", None) == "completion":
            # stop_reason is only present on the terminal chunk.
            stop_reason = payload.get("stop_reason")
            return {
                "text": payload.get("completion"),
                "is_finished": stop_reason is not None,
                "finish_reason": stop_reason,
            }
        # "data:" line of another event type — nothing to emit.
        return empty_result
    if "error" in str_line:
        raise ValueError(f"Unable to parse response. Original response: {str_line}")
    return empty_result

def handle_predibase_chunk(self, chunk):
try:
if not isinstance(chunk, str):
Expand Down Expand Up @@ -1005,14 +971,6 @@ def chunk_creator(self, chunk): # type: ignore # noqa: PLR0915
setattr(model_response, key, value)

response_obj = anthropic_response_obj
elif (
self.custom_llm_provider
and self.custom_llm_provider == "anthropic_text"
):
response_obj = self.handle_anthropic_text_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
response_obj = self.handle_replicate_chunk(chunk)
completion_obj["content"] = response_obj["text"]
Expand Down
Loading

0 comments on commit bdb2082

Please sign in to comment.