Litellm dev 12 07 2024 (BerriAI#7086)

* fix(main.py): support passing max retries to azure/openai embedding integrations

Fixes BerriAI#7003
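
A minimal sketch of the new knob from the SDK side (assumes an `OPENAI_API_KEY` in the environment; the model name is illustrative):

```python
import litellm

# max_retries is now forwarded to the underlying OpenAI/Azure embedding client
# instead of being hard-coded to 2 inside the integration.
response = litellm.embedding(
    model="text-embedding-3-small",
    input=["hello world"],
    max_retries=5,
)
print(len(response.data[0]["embedding"]))
```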

* feat(team_endpoints.py): allow updating team model aliases

Closes BerriAI#6956
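
For illustration only (not part of this commit), updating a team's aliases through the proxy could look like the sketch below; the proxy URL, master key, team id, and alias mapping are placeholders:

```python
import requests

PROXY_BASE = "http://0.0.0.0:4000"  # assumed local proxy
MASTER_KEY = "sk-1234"              # placeholder admin key

# Send the new model_aliases field to /team/update for an existing team.
resp = requests.post(
    f"{PROXY_BASE}/team/update",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={
        "team_id": "my-team-id",                          # placeholder team id
        "model_aliases": {"gpt-4o": "team-gpt-4o-pool"},  # requested name -> team model group (illustrative)
    },
)
resp.raise_for_status()
print(resp.json())
```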

* feat(router.py): allow specifying model id as fallback - skips any cooldown check

Allows a default model to be used if all models in the group are in cooldown

s/o @micahjsmith

* docs(reliability.md): add fallback to specific model to docs
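
A hedged SDK-side sketch of the same feature (the proxy-config version is in the docs diff below); it assumes `OPENAI_API_KEY`/Azure credentials are set and mirrors the doc example's deployment names:

```python
import os

from litellm import Router

# Two deployments share the "gpt-4" group; the first carries an explicit model id
# so it can be targeted directly as a fallback, skipping any cooldown check.
router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {"model": "openai/gpt-4"},
            "model_info": {"id": "my-specific-model-id"},
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_key": os.getenv("AZURE_API_KEY"),
            },
        },
    ],
    fallbacks=[{"gpt-4": ["my-specific-model-id"]}],
)

response = router.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "ping"}],
    mock_testing_fallbacks=True,  # force the fallback path for testing
)
print(response._hidden_params.get("model_id"))  # should point at the fallback deployment
```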

* fix(utils.py): new 'is_prompt_caching_valid_prompt' helper util

Allows users to check whether messages/tools qualify for prompt caching

Related issue: BerriAI#6784
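
A hypothetical usage sketch; the import path and signature are assumptions based on the description above, so check the helper's actual definition in `litellm/utils.py`:

```python
from litellm.utils import is_prompt_caching_valid_prompt  # assumed import path

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Long, reusable system instructions ... " * 200,
                "cache_control": {"type": "ephemeral"},  # Anthropic-style cache marker
            }
        ],
    },
    {"role": "user", "content": "What does the policy say about refunds?"},
]

# Assumed signature: returns True when the messages/tools qualify for prompt caching.
if is_prompt_caching_valid_prompt(
    model="anthropic/claude-3-5-sonnet-20240620", messages=messages
):
    print("prompt qualifies for prompt caching")
```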

* feat(router.py): store model id for prompt caching valid prompt

Allows routing to that model id on subsequent requests

* fix(router.py): only cache if prompt is valid prompt caching prompt

prevents storing unnecessary items in cache

* feat(router.py): support routing prompt caching enabled models to previous deployments

Closes BerriAI#6784
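
A sketch of the intended flow, assuming two Anthropic deployments in one group, an `ANTHROPIC_API_KEY` in the environment, and standard `cache_control` markers (deployment ids are illustrative):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
            "model_info": {"id": "claude-deployment-1"},
        },
        {
            "model_name": "claude",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
            "model_info": {"id": "claude-deployment-2"},
        },
    ],
)

cached_messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Very long shared context ... " * 300,
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {"role": "user", "content": "Summarize the shared context."},
]

# The first call warms the prompt cache on whichever deployment is picked;
# with this change, a follow-up request with the same cache-eligible prefix
# should be routed back to that same deployment so the cache hit is reused.
first = router.completion(model="claude", messages=cached_messages)
second = router.completion(model="claude", messages=cached_messages)
```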

* test: fix linting errors

* feat(databricks/): convert BaseModel to dict and exclude None values

Allows passing pydantic messages to Databricks

* fix(utils.py): ensure all chat completion messages are dicts
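
A standalone illustration of the conversion now applied before sending messages to Databricks (the `Message` model here is a stand-in, not the repo's own type):

```python
from typing import Optional

from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str
    name: Optional[str] = None  # unset fields should not reach the provider


msg = Message(role="user", content="hello")

# exclude_none=True drops the unset `name` key, so the payload sent to the
# provider only contains fields that were actually populated.
print(msg.model_dump())                   # {'role': 'user', 'content': 'hello', 'name': None}
print(msg.model_dump(exclude_none=True))  # {'role': 'user', 'content': 'hello'}
```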

* (feat) Track `custom_llm_provider` in LiteLLMSpendLogs (BerriAI#7081)

* add custom_llm_provider to SpendLogsPayload

* add custom_llm_provider to SpendLogs

* add custom llm provider to SpendLogs payload

* test_spend_logs_payload

* Add MLflow to the side bar (BerriAI#7031)

Signed-off-by: B-Step62 <[email protected]>

* (bug fix) SpendLogs DB updates: catch all possible DB errors when retrying (BerriAI#7082)

* catch DB_CONNECTION_ERROR_TYPES

* fix DB retry mechanism for SpendLog updates

* use DB_CONNECTION_ERROR_TYPES in auth checks

* fix exp back off for writing SpendLogs

* use _raise_failed_update_spend_exception to ensure errors print as NON blocking

* test_update_spend_logs_multiple_batches_with_failure
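
The behaviour under test is plain exponential backoff around the SpendLogs write; a generic sketch of that pattern (not the proxy's actual helper):

```python
import asyncio
import random


async def write_with_backoff(write_once, max_retries: int = 3, base_delay: float = 0.5):
    """Retry an async DB write with exponential backoff; re-raise after the last attempt."""
    for attempt in range(max_retries + 1):
        try:
            return await write_once()
        except Exception:
            if attempt == max_retries:
                raise  # caller surfaces this as a non-blocking error
            # exponential backoff with a little jitter
            await asyncio.sleep(base_delay * (2**attempt) + random.random() * 0.1)
```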

* (Feat) Add StructuredOutputs support for Fireworks.AI (BerriAI#7085)

* fix(model cost map): set "supports_response_schema": true for Fireworks AI

* fix supports_response_schema

* fix mapping of OpenAI params for Fireworks AI

* test_map_response_format

* test_map_response_format
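
A hedged end-to-end sketch of what the new support enables, assuming `FIREWORKS_AI_API_KEY` is set and that the chosen model supports structured outputs (the model name and schema are illustrative):

```python
import litellm

response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",
    messages=[{"role": "user", "content": "Extract: Jason is 25 years old."}],
    # OpenAI-style JSON schema response_format, now mapped through for Fireworks AI.
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "person",
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"},
                },
                "required": ["name", "age"],
            },
        },
    },
)
print(response.choices[0].message.content)
```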

* added deepinfra/Meta-Llama-3.1-405B-Instruct (BerriAI#7084)

* bump: version 1.53.9 → 1.54.0

* fix deepinfra

* litellm db fixes LiteLLM_UserTable (BerriAI#7089)

* ci/cd queue new release

* fix llama-3.3-70b-versatile

* refactor - use consistent file naming convention `AI21/` -> `ai21`  (BerriAI#7090)

* fix refactor - use consistent file naming convention

* ci/cd run again

* fix naming structure

* fix use consistent naming (BerriAI#7092)

---------

Signed-off-by: B-Step62 <[email protected]>
Co-authored-by: Ishaan Jaff <[email protected]>
Co-authored-by: Yuki Watanabe <[email protected]>
Co-authored-by: ali sayyah <[email protected]>
4 people authored Dec 8, 2024
1 parent 36e99eb commit 0c0498d
Showing 24 changed files with 840 additions and 193 deletions.
59 changes: 59 additions & 0 deletions docs/my-website/docs/proxy/reliability.md
@@ -315,6 +315,64 @@ litellm_settings:
cooldown_time: 30 # how long to cooldown model if fails/min > allowed_fails
```
### Fallback to Specific Model ID
If all models in a group are in cooldown (e.g. rate limited), LiteLLM will fall back to the model with the specified model ID.
This skips the cooldown check for the fallback model.
1. Specify the model ID in `model_info`
```yaml
model_list:
- model_name: gpt-4
litellm_params:
model: openai/gpt-4
model_info:
id: my-specific-model-id # 👈 KEY CHANGE
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
- model_name: anthropic-claude
litellm_params:
model: anthropic/claude-3-opus-20240229
api_key: os.environ/ANTHROPIC_API_KEY
```

**Note:** This will only fall back to the model with the specified model ID. If you want to fall back to another model group, set `fallbacks=[{"gpt-4": ["anthropic-claude"]}]`.

2. Set fallbacks in config

```yaml
litellm_settings:
fallbacks: [{"gpt-4": ["my-specific-model-id"]}]
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "ping"
}
],
"mock_testing_fallbacks": true
}'
```

Validate it works by checking the response header `x-litellm-model-id`:

```bash
x-litellm-model-id: my-specific-model-id
```

### Test Fallbacks!

Check if your fallbacks are working as expected.
@@ -337,6 +395,7 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
'
```


#### **Content Policy Fallbacks**
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
3 changes: 2 additions & 1 deletion docs/my-website/docs/routing.md
@@ -1130,7 +1130,7 @@ router_settings:

If a call fails after num_retries, fall back to another model group.

### Quick Start
#### Quick Start

```python
from litellm import Router
@@ -1366,6 +1366,7 @@ litellm --config /path/to/config.yaml
</TabItem>
</Tabs>


### Caching

In production, we recommend using a Redis cache. For quickly testing things locally, we also support simple in-memory caching.
1 change: 1 addition & 0 deletions litellm/__init__.py
@@ -21,6 +21,7 @@
DEFAULT_BATCH_SIZE,
DEFAULT_FLUSH_INTERVAL_SECONDS,
ROUTER_MAX_FALLBACKS,
DEFAULT_MAX_RETRIES,
)
from litellm.types.guardrails import GuardrailItem
from litellm.proxy._types import (
1 change: 1 addition & 0 deletions litellm/constants.py
@@ -1,3 +1,4 @@
ROUTER_MAX_FALLBACKS = 5
DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2
3 changes: 2 additions & 1 deletion litellm/llms/OpenAI/openai.py
@@ -1217,12 +1217,13 @@ def embedding( # type: ignore
api_base: Optional[str] = None,
client=None,
aembedding=None,
max_retries: Optional[int] = None,
) -> litellm.EmbeddingResponse:
super().embedding()
try:
model = model
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
## LOGGING
3 changes: 2 additions & 1 deletion litellm/llms/azure/azure.py
@@ -871,6 +871,7 @@ def embedding(
optional_params: dict,
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
max_retries: Optional[int] = None,
client=None,
aembedding=None,
) -> litellm.EmbeddingResponse:
@@ -879,7 +880,7 @@
self._client_session = self.create_client_session()
try:
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
1 change: 1 addition & 0 deletions litellm/llms/azure_ai/embed/handler.py
@@ -219,6 +219,7 @@ def embedding(
api_base: Optional[str] = None,
client=None,
aembedding=None,
max_retries: Optional[int] = None,
) -> litellm.EmbeddingResponse:
"""
- Separate image url from text
2 changes: 1 addition & 1 deletion litellm/llms/databricks/chat/transformation.py
@@ -134,7 +134,7 @@ def _transform_messages(
new_messages = []
for idx, message in enumerate(messages):
if isinstance(message, BaseModel):
_message = message.model_dump()
_message = message.model_dump(exclude_none=True)
else:
_message = message
new_messages.append(_message)
7 changes: 5 additions & 2 deletions litellm/main.py
@@ -77,7 +77,7 @@
read_config_args,
supports_httpx_timeout,
token_counter,
validate_chat_completion_user_messages,
validate_chat_completion_messages,
)

from ._logging import verbose_logger
@@ -931,7 +931,7 @@ def completion( # type: ignore # noqa: PLR0915
) # support region-based pricing for bedrock

### VALIDATE USER MESSAGES ###
validate_chat_completion_user_messages(messages=messages)
messages = validate_chat_completion_messages(messages=messages)

### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
@@ -3274,6 +3274,7 @@ def embedding( # noqa: PLR0915
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
max_retries = kwargs.get("max_retries", None)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
@@ -3422,6 +3423,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
max_retries=max_retries,
)
elif (
model in litellm.open_ai_embedding_models
@@ -3466,6 +3468,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
max_retries=max_retries,
)
elif custom_llm_provider == "databricks":
api_base = (
4 changes: 3 additions & 1 deletion litellm/proxy/_types.py
@@ -723,7 +723,7 @@ class KeyRequest(LiteLLMBase):


class LiteLLM_ModelTable(LiteLLMBase):
model_aliases: Optional[str] = None # json dump the dict
model_aliases: Optional[Union[str, dict]] = None # json dump the dict
created_by: str
updated_by: str

@@ -981,6 +981,7 @@ class UpdateTeamRequest(LiteLLMBase):
blocked: Optional[bool] = None
budget_duration: Optional[str] = None
tags: Optional[list] = None
model_aliases: Optional[dict] = None


class ResetTeamBudgetRequest(LiteLLMBase):
@@ -1059,6 +1060,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
litellm_model_table: Optional[LiteLLM_ModelTable] = None

model_config = ConfigDict(protected_namespaces=())

65 changes: 58 additions & 7 deletions litellm/proxy/management_endpoints/team_endpoints.py
@@ -293,8 +293,13 @@ async def new_team( # noqa: PLR0915
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at

team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore
data=complete_team_data.json(exclude_none=True), table_name="team"
complete_team_data_dict = complete_team_data.model_dump(exclude_none=True)
complete_team_data_dict = prisma_client.jsonify_team_object(
db_data=complete_team_data_dict
)
team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create(
data=complete_team_data_dict,
include={"litellm_model_table": True}, # type: ignore
)

## ADD TEAM ID TO USER TABLE ##
@@ -340,6 +345,37 @@ async def new_team( # noqa: PLR0915
return team_row.dict()


async def _update_model_table(
data: UpdateTeamRequest,
model_id: Optional[str],
prisma_client: PrismaClient,
user_api_key_dict: UserAPIKeyAuth,
litellm_proxy_admin_name: str,
) -> Optional[str]:
"""
Upsert model table and return the model id
"""
## UPSERT MODEL TABLE
_model_id = model_id
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
litellm_modeltable = LiteLLM_ModelTable(
model_aliases=json.dumps(data.model_aliases),
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
model_dict = await prisma_client.db.litellm_modeltable.upsert(
where={"id": model_id},
data={
"update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
"create": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore
},
) # type: ignore

_model_id = model_dict.id

return _model_id


@router.post(
"/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@@ -370,6 +406,7 @@ async def update_team(
- blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id.
- tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
- organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`.
- model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias)
Example - update team TPM Limit
@@ -446,11 +483,25 @@ async def update_team(
else:
updated_kv["metadata"] = {"tags": _tags}

updated_kv = prisma_client.jsonify_object(data=updated_kv)
team_row: Optional[
LiteLLM_TeamTable
] = await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id}, data=updated_kv # type: ignore
if "model_aliases" in updated_kv:
updated_kv.pop("model_aliases")
_model_id = await _update_model_table(
data=data,
model_id=existing_team_row.model_id,
prisma_client=prisma_client,
user_api_key_dict=user_api_key_dict,
litellm_proxy_admin_name=litellm_proxy_admin_name,
)
if _model_id is not None:
updated_kv["model_id"] = _model_id

updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv)
team_row: Optional[LiteLLM_TeamTable] = (
await prisma_client.db.litellm_teamtable.update(
where={"team_id": data.team_id},
data=updated_kv,
include={"litellm_model_table": True}, # type: ignore
)
)

if team_row is None or team_row.team_id is None: