From 0c0498dd6007342aa7eef78c7acfd0d9acbb8577 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sun, 8 Dec 2024 00:30:33 -0800 Subject: [PATCH] Litellm dev 12 07 2024 (#7086) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(main.py): support passing max retries to azure/openai embedding integrations Fixes https://github.com/BerriAI/litellm/issues/7003 * feat(team_endpoints.py): allow updating team model aliases Closes https://github.com/BerriAI/litellm/issues/6956 * feat(router.py): allow specifying model id as fallback - skips any cooldown check Allows a default model to be checked if all models in cooldown s/o @micahjsmith * docs(reliability.md): add fallback to specific model to docs * fix(utils.py): new 'is_prompt_caching_valid_prompt' helper util Allows user to identify if messages/tools have prompt caching Related issue: https://github.com/BerriAI/litellm/issues/6784 * feat(router.py): store model id for prompt caching valid prompt Allows routing to that model id on subsequent requests * fix(router.py): only cache if prompt is valid prompt caching prompt prevents storing unnecessary items in cache * feat(router.py): support routing prompt caching enabled models to previous deployments Closes https://github.com/BerriAI/litellm/issues/6784 * test: fix linting errors * feat(databricks/): convert basemodel to dict and exclude none values allow passing pydantic message to databricks * fix(utils.py): ensure all chat completion messages are dict * (feat) Track `custom_llm_provider` in LiteLLMSpendLogs (#7081) * add custom_llm_provider to SpendLogsPayload * add custom_llm_provider to SpendLogs * add custom llm provider to SpendLogs payload * test_spend_logs_payload * Add MLflow to the side bar (#7031) Signed-off-by: B-Step62 * (bug fix) SpendLogs update DB catch all possible DB errors for retrying (#7082) * catch DB_CONNECTION_ERROR_TYPES * fix DB retry mechanism for SpendLog updates * use DB_CONNECTION_ERROR_TYPES in auth checks * fix exp back off for writing SpendLogs * use _raise_failed_update_spend_exception to ensure errors print as NON blocking * test_update_spend_logs_multiple_batches_with_failure * (Feat) Add StructuredOutputs support for Fireworks.AI (#7085) * fix model cost map fireworks ai "supports_response_schema": true, * fix supports_response_schema * fix map openai params fireworks ai * test_map_response_format * test_map_response_format * added deepinfra/Meta-Llama-3.1-405B-Instruct (#7084) * bump: version 1.53.9 → 1.54.0 * fix deepinfra * litellm db fixes LiteLLM_UserTable (#7089) * ci/cd queue new release * fix llama-3.3-70b-versatile * refactor - use consistent file naming convention `AI21/` -> `ai21` (#7090) * fix refactor - use consistent file naming convention * ci/cd run again * fix naming structure * fix use consistent naming (#7092) --------- Signed-off-by: B-Step62 Co-authored-by: Ishaan Jaff Co-authored-by: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com> Co-authored-by: ali sayyah --- docs/my-website/docs/proxy/reliability.md | 59 ++++++ docs/my-website/docs/routing.md | 3 +- litellm/__init__.py | 1 + litellm/constants.py | 1 + litellm/llms/OpenAI/openai.py | 3 +- litellm/llms/azure/azure.py | 3 +- litellm/llms/azure_ai/embed/handler.py | 1 + .../llms/databricks/chat/transformation.py | 2 +- litellm/main.py | 7 +- litellm/proxy/_types.py | 4 +- .../management_endpoints/team_endpoints.py | 65 +++++- litellm/proxy/utils.py | 176 +--------------- litellm/router.py | 59 +++++- 
.../router_utils/fallback_event_handlers.py | 1 + litellm/router_utils/prompt_caching_cache.py | 193 ++++++++++++++++++ litellm/utils.py | 51 +++++ .../base_embedding_unit_tests.py | 67 ++++++ tests/llm_translation/base_llm_unit_tests.py | 10 + tests/llm_translation/test_azure_openai.py | 13 ++ .../test_anthropic_prompt_caching.py | 176 +++++++++++++++- tests/local_testing/test_router.py | 18 ++ tests/local_testing/test_router_cooldowns.py | 31 +++ tests/local_testing/test_router_fallbacks.py | 23 +++ .../test_router_prompt_caching.py | 66 ++++++ 24 files changed, 840 insertions(+), 193 deletions(-) create mode 100644 litellm/router_utils/prompt_caching_cache.py create mode 100644 tests/llm_translation/base_embedding_unit_tests.py create mode 100644 tests/router_unit_tests/test_router_prompt_caching.py diff --git a/docs/my-website/docs/proxy/reliability.md b/docs/my-website/docs/proxy/reliability.md index 1e6d0e26c2ee..9f1c1c8bb959 100644 --- a/docs/my-website/docs/proxy/reliability.md +++ b/docs/my-website/docs/proxy/reliability.md @@ -315,6 +315,64 @@ litellm_settings: cooldown_time: 30 # how long to cooldown model if fails/min > allowed_fails ``` +### Fallback to Specific Model ID + +If all models in a group are in cooldown (e.g. rate limited), LiteLLM will fallback to the model with the specific model ID. + +This skips any cooldown check for the fallback model. + +1. Specify the model ID in `model_info` +```yaml +model_list: + - model_name: gpt-4 + litellm_params: + model: openai/gpt-4 + model_info: + id: my-specific-model-id # 👈 KEY CHANGE + - model_name: gpt-4 + litellm_params: + model: azure/chatgpt-v-2 + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + - model_name: anthropic-claude + litellm_params: + model: anthropic/claude-3-opus-20240229 + api_key: os.environ/ANTHROPIC_API_KEY +``` + +**Note:** This will only fallback to the model with the specific model ID. If you want to fallback to another model group, you can set `fallbacks=[{"gpt-4": ["anthropic-claude"]}]` + +2. Set fallbacks in config + +```yaml +litellm_settings: + fallbacks: [{"gpt-4": ["my-specific-model-id"]}] +``` + +3. Test it! + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-D '{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "ping" + } + ], + "mock_testing_fallbacks": true +}' +``` + +Validate it works, by checking the response header `x-litellm-model-id` + +```bash +x-litellm-model-id: my-specific-model-id +``` + ### Test Fallbacks! Check if your fallbacks are working as expected. @@ -337,6 +395,7 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \ ' ``` + #### **Content Policy Fallbacks** ```bash curl -X POST 'http://0.0.0.0:4000/chat/completions' \ diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index 87fad7437e47..c4b633a976fe 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -1130,7 +1130,7 @@ router_settings: If a call fails after num_retries, fall back to another model group. -### Quick Start +#### Quick Start ```python from litellm import Router @@ -1366,6 +1366,7 @@ litellm --config /path/to/config.yaml + ### Caching In production, we recommend using a Redis cache. For quickly testing things locally, we also support simple in-memory caching. 
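The reliability docs above show the proxy-config path to a specific-model-ID fallback. As a rough companion sketch (not part of the patch itself), the same behaviour can be exercised directly through the Python `Router`, reusing the `model_info.id` and `fallbacks` shapes that this patch's new tests rely on; the deployment names, ids, and environment variables below are placeholders.

```python
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {"model": "openai/gpt-4"},
            # this id is what the fallback list points at
            "model_info": {"id": "my-specific-model-id"},
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_base": os.environ.get("AZURE_API_BASE"),
                "api_key": os.environ.get("AZURE_API_KEY"),
            },
        },
    ],
    # falling back to a specific model id skips the cooldown check for that deployment
    fallbacks=[{"gpt-4": ["my-specific-model-id"]}],
)

response = router.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "ping"}],
    mock_testing_fallbacks=True,  # force the fallback path, as in the curl example above
)

# the deployment that actually served the request
print(response._hidden_params["model_id"])
```

Because the fallback target is a model id rather than a model group, the router resolves it directly via `get_deployment(model_id=...)` and does not consult the cooldown list for it.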
diff --git a/litellm/__init__.py b/litellm/__init__.py index 542ab89c9152..4b872b01430b 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -21,6 +21,7 @@ DEFAULT_BATCH_SIZE, DEFAULT_FLUSH_INTERVAL_SECONDS, ROUTER_MAX_FALLBACKS, + DEFAULT_MAX_RETRIES, ) from litellm.types.guardrails import GuardrailItem from litellm.proxy._types import ( diff --git a/litellm/constants.py b/litellm/constants.py index 331a0a630c7e..97dc6c7348bd 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1,3 +1,4 @@ ROUTER_MAX_FALLBACKS = 5 DEFAULT_BATCH_SIZE = 512 DEFAULT_FLUSH_INTERVAL_SECONDS = 5 +DEFAULT_MAX_RETRIES = 2 diff --git a/litellm/llms/OpenAI/openai.py b/litellm/llms/OpenAI/openai.py index 057340b51306..66ce75701886 100644 --- a/litellm/llms/OpenAI/openai.py +++ b/litellm/llms/OpenAI/openai.py @@ -1217,12 +1217,13 @@ def embedding( # type: ignore api_base: Optional[str] = None, client=None, aembedding=None, + max_retries: Optional[int] = None, ) -> litellm.EmbeddingResponse: super().embedding() try: model = model data = {"model": model, "input": input, **optional_params} - max_retries = data.pop("max_retries", 2) + max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES if not isinstance(max_retries, int): raise OpenAIError(status_code=422, message="max retries must be an int") ## LOGGING diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 24303ef2fe5b..fb2dfbc9f1e3 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -871,6 +871,7 @@ def embedding( optional_params: dict, api_key: Optional[str] = None, azure_ad_token: Optional[str] = None, + max_retries: Optional[int] = None, client=None, aembedding=None, ) -> litellm.EmbeddingResponse: @@ -879,7 +880,7 @@ def embedding( self._client_session = self.create_client_session() try: data = {"model": model, "input": input, **optional_params} - max_retries = data.pop("max_retries", 2) + max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES if not isinstance(max_retries, int): raise AzureOpenAIError( status_code=422, message="max retries must be an int" diff --git a/litellm/llms/azure_ai/embed/handler.py b/litellm/llms/azure_ai/embed/handler.py index 2946a84dd4b0..933be7eb8108 100644 --- a/litellm/llms/azure_ai/embed/handler.py +++ b/litellm/llms/azure_ai/embed/handler.py @@ -219,6 +219,7 @@ def embedding( api_base: Optional[str] = None, client=None, aembedding=None, + max_retries: Optional[int] = None, ) -> litellm.EmbeddingResponse: """ - Separate image url from text diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py index 009e5e189466..6b362c3662e9 100644 --- a/litellm/llms/databricks/chat/transformation.py +++ b/litellm/llms/databricks/chat/transformation.py @@ -134,7 +134,7 @@ def _transform_messages( new_messages = [] for idx, message in enumerate(messages): if isinstance(message, BaseModel): - _message = message.model_dump() + _message = message.model_dump(exclude_none=True) else: _message = message new_messages.append(_message) diff --git a/litellm/main.py b/litellm/main.py index 29ee90154688..f574b9339c9b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -77,7 +77,7 @@ read_config_args, supports_httpx_timeout, token_counter, - validate_chat_completion_user_messages, + validate_chat_completion_messages, ) from ._logging import verbose_logger @@ -931,7 +931,7 @@ def completion( # type: ignore # noqa: PLR0915 ) # support region-based pricing for bedrock ### VALIDATE USER MESSAGES ### - 
validate_chat_completion_user_messages(messages=messages) + messages = validate_chat_completion_messages(messages=messages) ### TIMEOUT LOGIC ### timeout = timeout or kwargs.get("request_timeout", 600) or 600 @@ -3274,6 +3274,7 @@ def embedding( # noqa: PLR0915 client = kwargs.pop("client", None) rpm = kwargs.pop("rpm", None) tpm = kwargs.pop("tpm", None) + max_retries = kwargs.get("max_retries", None) litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore cooldown_time = kwargs.get("cooldown_time", None) mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore @@ -3422,6 +3423,7 @@ def embedding( # noqa: PLR0915 optional_params=optional_params, client=client, aembedding=aembedding, + max_retries=max_retries, ) elif ( model in litellm.open_ai_embedding_models @@ -3466,6 +3468,7 @@ def embedding( # noqa: PLR0915 optional_params=optional_params, client=client, aembedding=aembedding, + max_retries=max_retries, ) elif custom_llm_provider == "databricks": api_base = ( diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 991ab98c8546..cdc583b4c175 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -723,7 +723,7 @@ class KeyRequest(LiteLLMBase): class LiteLLM_ModelTable(LiteLLMBase): - model_aliases: Optional[str] = None # json dump the dict + model_aliases: Optional[Union[str, dict]] = None # json dump the dict created_by: str updated_by: str @@ -981,6 +981,7 @@ class UpdateTeamRequest(LiteLLMBase): blocked: Optional[bool] = None budget_duration: Optional[str] = None tags: Optional[list] = None + model_aliases: Optional[dict] = None class ResetTeamBudgetRequest(LiteLLMBase): @@ -1059,6 +1060,7 @@ class LiteLLM_TeamTable(TeamBase): budget_duration: Optional[str] = None budget_reset_at: Optional[datetime] = None model_id: Optional[int] = None + litellm_model_table: Optional[LiteLLM_ModelTable] = None model_config = ConfigDict(protected_namespaces=()) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index cb28bd1fbe7c..356461a213eb 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -293,8 +293,13 @@ async def new_team( # noqa: PLR0915 reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) complete_team_data.budget_reset_at = reset_at - team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore - data=complete_team_data.json(exclude_none=True), table_name="team" + complete_team_data_dict = complete_team_data.model_dump(exclude_none=True) + complete_team_data_dict = prisma_client.jsonify_team_object( + db_data=complete_team_data_dict + ) + team_row: LiteLLM_TeamTable = await prisma_client.db.litellm_teamtable.create( + data=complete_team_data_dict, + include={"litellm_model_table": True}, # type: ignore ) ## ADD TEAM ID TO USER TABLE ## @@ -340,6 +345,37 @@ async def new_team( # noqa: PLR0915 return team_row.dict() +async def _update_model_table( + data: UpdateTeamRequest, + model_id: Optional[str], + prisma_client: PrismaClient, + user_api_key_dict: UserAPIKeyAuth, + litellm_proxy_admin_name: str, +) -> Optional[str]: + """ + Upsert model table and return the model id + """ + ## UPSERT MODEL TABLE + _model_id = model_id + if data.model_aliases is not None and isinstance(data.model_aliases, dict): + litellm_modeltable = LiteLLM_ModelTable( + model_aliases=json.dumps(data.model_aliases), + 
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + model_dict = await prisma_client.db.litellm_modeltable.upsert( + where={"id": model_id}, + data={ + "update": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore + "create": {**litellm_modeltable.json(exclude_none=True)}, # type: ignore + }, + ) # type: ignore + + _model_id = model_dict.id + + return _model_id + + @router.post( "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] ) @@ -370,6 +406,7 @@ async def update_team( - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing). - organization_id: Optional[str] - The organization id of the team. Default is None. Create via `/organization/new`. + - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias) Example - update team TPM Limit @@ -446,11 +483,25 @@ async def update_team( else: updated_kv["metadata"] = {"tags": _tags} - updated_kv = prisma_client.jsonify_object(data=updated_kv) - team_row: Optional[ - LiteLLM_TeamTable - ] = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, data=updated_kv # type: ignore + if "model_aliases" in updated_kv: + updated_kv.pop("model_aliases") + _model_id = await _update_model_table( + data=data, + model_id=existing_team_row.model_id, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + ) + if _model_id is not None: + updated_kv["model_id"] = _model_id + + updated_kv = prisma_client.jsonify_team_object(db_data=updated_kv) + team_row: Optional[LiteLLM_TeamTable] = ( + await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data=updated_kv, + include={"litellm_model_table": True}, # type: ignore + ) ) if team_row is None or team_row.team_id is None: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index abfaf60f02db..b5f26cb12666 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1207,173 +1207,6 @@ async def check_view_exists(self): except Exception: raise - - # try: - # # Try to select one row from the view - # await self.db.query_raw( - # """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""" - # ) - # print("LiteLLM_VerificationTokenView Exists!") # noqa - # except Exception as e: - # If an error occurs, the view does not exist, so create it - - # try: - # await self.db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""") - # print("MonthlyGlobalSpend Exists!") # noqa - # except Exception as e: - # sql_query = """ - # CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS - # SELECT - # DATE("startTime") AS date, - # SUM("spend") AS spend - # FROM - # "LiteLLM_SpendLogs" - # WHERE - # "startTime" >= (CURRENT_DATE - INTERVAL '30 days') - # GROUP BY - # DATE("startTime"); - # """ - # await self.db.execute_raw(query=sql_query) - - # print("MonthlyGlobalSpend Created!") # noqa - - # try: - # await self.db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""") - # print("Last30dKeysBySpend Exists!") # noqa - # except Exception as e: - # sql_query = """ - # CREATE OR REPLACE VIEW 
"Last30dKeysBySpend" AS - # SELECT - # L."api_key", - # V."key_alias", - # V."key_name", - # SUM(L."spend") AS total_spend - # FROM - # "LiteLLM_SpendLogs" L - # LEFT JOIN - # "LiteLLM_VerificationToken" V - # ON - # L."api_key" = V."token" - # WHERE - # L."startTime" >= (CURRENT_DATE - INTERVAL '30 days') - # GROUP BY - # L."api_key", V."key_alias", V."key_name" - # ORDER BY - # total_spend DESC; - # """ - # await self.db.execute_raw(query=sql_query) - - # print("Last30dKeysBySpend Created!") # noqa - - # try: - # await self.db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""") - # print("Last30dModelsBySpend Exists!") # noqa - # except Exception as e: - # sql_query = """ - # CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS - # SELECT - # "model", - # SUM("spend") AS total_spend - # FROM - # "LiteLLM_SpendLogs" - # WHERE - # "startTime" >= (CURRENT_DATE - INTERVAL '30 days') - # AND "model" != '' - # GROUP BY - # "model" - # ORDER BY - # total_spend DESC; - # """ - # await self.db.execute_raw(query=sql_query) - - # print("Last30dModelsBySpend Created!") # noqa - # try: - # await self.db.query_raw( - # """SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""" - # ) - # print("MonthlyGlobalSpendPerKey Exists!") # noqa - # except Exception as e: - # sql_query = """ - # CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS - # SELECT - # DATE("startTime") AS date, - # SUM("spend") AS spend, - # api_key as api_key - # FROM - # "LiteLLM_SpendLogs" - # WHERE - # "startTime" >= (CURRENT_DATE - INTERVAL '30 days') - # GROUP BY - # DATE("startTime"), - # api_key; - # """ - # await self.db.execute_raw(query=sql_query) - - # print("MonthlyGlobalSpendPerKey Created!") # noqa - # try: - # await self.db.query_raw( - # """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1""" - # ) - # print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa - # except Exception as e: - # sql_query = """ - # CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS - # SELECT - # DATE("startTime") AS date, - # SUM("spend") AS spend, - # api_key as api_key, - # "user" as "user" - # FROM - # "LiteLLM_SpendLogs" - # WHERE - # "startTime" >= (CURRENT_DATE - INTERVAL '20 days') - # GROUP BY - # DATE("startTime"), - # "user", - # api_key; - # """ - # await self.db.execute_raw(query=sql_query) - - # print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa - - # try: - # await self.db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""") - # print("DailyTagSpend Exists!") # noqa - # except Exception as e: - # sql_query = """ - # CREATE OR REPLACE VIEW DailyTagSpend AS - # SELECT - # jsonb_array_elements_text(request_tags) AS individual_request_tag, - # DATE(s."startTime") AS spend_date, - # COUNT(*) AS log_count, - # SUM(spend) AS total_spend - # FROM "LiteLLM_SpendLogs" s - # GROUP BY individual_request_tag, DATE(s."startTime"); - # """ - # await self.db.execute_raw(query=sql_query) - - # print("DailyTagSpend Created!") # noqa - - # try: - # await self.db.query_raw( - # """SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""" - # ) - # print("Last30dTopEndUsersSpend Exists!") # noqa - # except Exception as e: - # sql_query = """ - # CREATE VIEW "Last30dTopEndUsersSpend" AS - # SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend - # FROM "LiteLLM_SpendLogs" - # WHERE end_user <> '' AND end_user <> user - # AND "startTime" >= CURRENT_DATE - INTERVAL '30 days' - # GROUP BY end_user - # ORDER BY total_spend DESC - # LIMIT 100; - # """ - # await self.db.execute_raw(query=sql_query) - - # 
print("Last30dTopEndUsersSpend Created!") # noqa - return @log_db_metrics @@ -1784,6 +1617,14 @@ async def get_data( # noqa: PLR0915 ) raise e + def jsonify_team_object(self, db_data: dict): + db_data = self.jsonify_object(data=db_data) + if db_data.get("members_with_roles", None) is not None and isinstance( + db_data["members_with_roles"], list + ): + db_data["members_with_roles"] = json.dumps(db_data["members_with_roles"]) + return db_data + # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, @@ -2348,7 +2189,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: module_name = value instance_name = None try: - print_verbose(f"value: {value}") # Split the path by dots to separate module from instance parts = value.split(".") diff --git a/litellm/router.py b/litellm/router.py index c9cfab555685..2f333bf6b38b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -36,6 +36,7 @@ Tuple, TypedDict, Union, + cast, ) import httpx @@ -96,6 +97,7 @@ ) from litellm.scheduler import FlowItem, Scheduler from litellm.types.llms.openai import ( + AllMessageValues, Assistant, AssistantToolParam, AsyncCursorPage, @@ -149,10 +151,12 @@ get_llm_provider, get_secret, get_utc_datetime, + is_prompt_caching_valid_prompt, is_region_allowed, ) from .router_utils.pattern_match_deployments import PatternMatchRouter +from .router_utils.prompt_caching_cache import PromptCachingCache if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -737,7 +741,9 @@ def _completion( model_client = potential_model_client ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit) - self.routing_strategy_pre_call_checks(deployment=deployment) + ## only run if model group given, not model id + if model not in self.get_model_ids(): + self.routing_strategy_pre_call_checks(deployment=deployment) response = litellm.completion( **{ @@ -2787,8 +2793,10 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 *args, **input_kwargs, ) + return response except Exception as new_exception: + traceback.print_exc() parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) verbose_router_logger.error( "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format( @@ -3376,6 +3384,29 @@ async def deployment_callback_on_success( deployment_id=id, ) + ## PROMPT CACHING + prompt_cache = PromptCachingCache( + cache=self.cache, + ) + if ( + standard_logging_object["messages"] is not None + and isinstance(standard_logging_object["messages"], list) + and deployment_name is not None + and isinstance(deployment_name, str) + ): + valid_prompt = is_prompt_caching_valid_prompt( + messages=standard_logging_object["messages"], # type: ignore + tools=None, + model=deployment_name, + custom_llm_provider=None, + ) + if valid_prompt: + await prompt_cache.async_add_model_id( + model_id=id, + messages=standard_logging_object["messages"], # type: ignore + tools=None, + ) + return tpm_key except Exception as e: @@ -5190,7 +5221,6 @@ def _common_checks_available_deployment( - List, if multiple models chosen - Dict, if specific model chosen """ - # check if aliases set on litellm model alias map if specific_deployment is True: return model, self._get_deployment_by_litellm_model(model=model) @@ -5302,13 +5332,6 @@ async def async_get_available_deployment( cooldown_deployments=cooldown_deployments, ) - # filter 
pre-call checks - _allowed_model_region = ( - request_kwargs.get("allowed_model_region") - if request_kwargs is not None - else None - ) - if self.enable_pre_call_checks and messages is not None: healthy_deployments = self._pre_call_checks( model=model, @@ -5317,6 +5340,24 @@ async def async_get_available_deployment( request_kwargs=request_kwargs, ) + if messages is not None and is_prompt_caching_valid_prompt( + messages=cast(List[AllMessageValues], messages), + model=model, + custom_llm_provider=None, + ): + prompt_cache = PromptCachingCache( + cache=self.cache, + ) + healthy_deployment = ( + await prompt_cache.async_get_prompt_caching_deployment( + router=self, + messages=cast(List[AllMessageValues], messages), + tools=None, + ) + ) + if healthy_deployment is not None: + return healthy_deployment + # check if user wants to do tag based routing healthy_deployments = await get_deployments_for_tag( # type: ignore llm_router_instance=self, diff --git a/litellm/router_utils/fallback_event_handlers.py b/litellm/router_utils/fallback_event_handlers.py index 5d027e59717f..41c3080e9a07 100644 --- a/litellm/router_utils/fallback_event_handlers.py +++ b/litellm/router_utils/fallback_event_handlers.py @@ -49,6 +49,7 @@ async def run_async_fallback( raise original_exception error_from_fallbacks = original_exception + for mg in fallback_model_group: if mg == original_model_group: continue diff --git a/litellm/router_utils/prompt_caching_cache.py b/litellm/router_utils/prompt_caching_cache.py new file mode 100644 index 000000000000..d1861dc7c87a --- /dev/null +++ b/litellm/router_utils/prompt_caching_cache.py @@ -0,0 +1,193 @@ +""" +Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called. +""" + +import hashlib +import json +import time +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict + +import litellm +from litellm import verbose_logger +from litellm.caching.caching import Cache, DualCache +from litellm.caching.in_memory_cache import InMemoryCache +from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + from litellm.router import Router + + litellm_router = Router + Span = _Span +else: + Span = Any + litellm_router = Any + + +class PromptCachingCacheValue(TypedDict): + model_id: str + + +class PromptCachingCache: + def __init__(self, cache: DualCache): + self.cache = cache + self.in_memory_cache = InMemoryCache() + + @staticmethod + def serialize_object(obj: Any) -> Any: + """Helper function to serialize Pydantic objects, dictionaries, or fallback to string.""" + if hasattr(obj, "dict"): + # If the object is a Pydantic model, use its `dict()` method + return obj.dict() + elif isinstance(obj, dict): + # If the object is a dictionary, serialize it with sorted keys + return json.dumps( + obj, sort_keys=True, separators=(",", ":") + ) # Standardize serialization + + elif isinstance(obj, list): + # Serialize lists by ensuring each element is handled properly + return [PromptCachingCache.serialize_object(item) for item in obj] + elif isinstance(obj, (int, float, bool)): + return obj # Keep primitive types as-is + return str(obj) + + @staticmethod + def get_prompt_caching_cache_key( + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> Optional[str]: + if messages is None and tools is None: + return None + # Use serialize_object for consistent and stable serialization + data_to_hash = {} + if 
messages is not None: + serialized_messages = PromptCachingCache.serialize_object(messages) + data_to_hash["messages"] = serialized_messages + if tools is not None: + serialized_tools = PromptCachingCache.serialize_object(tools) + data_to_hash["tools"] = serialized_tools + + # Combine serialized data into a single string + data_to_hash_str = json.dumps( + data_to_hash, + sort_keys=True, + separators=(",", ":"), + ) + + # Create a hash of the serialized data for a stable cache key + hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest() + return f"deployment:{hashed_data}:prompt_caching" + + def add_model_id( + self, + model_id: str, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> None: + if messages is None and tools is None: + return None + + cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools) + self.cache.set_cache( + cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300 + ) + return None + + async def async_add_model_id( + self, + model_id: str, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> None: + if messages is None and tools is None: + return None + + cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools) + await self.cache.async_set_cache( + cache_key, + PromptCachingCacheValue(model_id=model_id), + ttl=300, # store for 5 minutes + ) + return None + + async def async_get_model_id( + self, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> Optional[PromptCachingCacheValue]: + """ + if messages is not none + - check full messages + - check messages[:-1] + - check messages[:-2] + - check messages[:-3] + + use self.cache.async_batch_get_cache(keys=potential_cache_keys]) + """ + if messages is None and tools is None: + return None + + # Generate potential cache keys by slicing messages + + potential_cache_keys = [] + + if messages is not None: + full_cache_key = PromptCachingCache.get_prompt_caching_cache_key( + messages, tools + ) + potential_cache_keys.append(full_cache_key) + + # Check progressively shorter message slices + for i in range(1, min(4, len(messages))): + partial_messages = messages[:-i] + partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key( + partial_messages, tools + ) + potential_cache_keys.append(partial_cache_key) + + # Perform batch cache lookup + cache_results = await self.cache.async_batch_get_cache( + keys=potential_cache_keys + ) + + if cache_results is None: + return None + + # Return the first non-None cache result + for result in cache_results: + if result is not None: + return result + + return None + + def get_model_id( + self, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> Optional[PromptCachingCacheValue]: + if messages is None and tools is None: + return None + + cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools) + return self.cache.get_cache(cache_key) + + async def async_get_prompt_caching_deployment( + self, + router: litellm_router, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]], + ) -> Optional[dict]: + model_id_dict = await self.async_get_model_id( + messages=messages, + tools=tools, + ) + + if model_id_dict is not None: + healthy_deployment_pydantic_obj = router.get_deployment( + model_id=model_id_dict["model_id"] + ) + if healthy_deployment_pydantic_obj is not None: + return 
healthy_deployment_pydantic_obj.model_dump(exclude_none=True) + return None diff --git a/litellm/utils.py b/litellm/utils.py index 2f8225e9f965..bd36e211cff7 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6151,6 +6151,38 @@ def add_dummy_tool(custom_llm_provider: str) -> List[ChatCompletionToolParam]: ) +def convert_to_dict(message: Union[BaseModel, dict]) -> dict: + """ + Converts a message to a dictionary if it's a Pydantic model. + + Args: + message: The message, which may be a Pydantic model or a dictionary. + + Returns: + dict: The converted message. + """ + if isinstance(message, BaseModel): + return message.model_dump(exclude_none=True) + elif isinstance(message, dict): + return message + else: + raise TypeError( + f"Invalid message type: {type(message)}. Expected dict or Pydantic model." + ) + + +def validate_chat_completion_messages(messages: List[AllMessageValues]): + """ + Ensures all messages are valid OpenAI chat completion messages. + """ + # 1. convert all messages to dict + messages = [ + cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages + ] + # 2. validate user messages + return validate_chat_completion_user_messages(messages=messages) + + def validate_chat_completion_user_messages(messages: List[AllMessageValues]): """ Ensures all user messages are valid OpenAI chat completion messages. @@ -6229,3 +6261,22 @@ def get_end_user_id_for_cost_tracking( ): return None return proxy_server_request.get("body", {}).get("user", None) + + +def is_prompt_caching_valid_prompt( + model: str, + messages: Optional[List[AllMessageValues]], + tools: Optional[List[ChatCompletionToolParam]] = None, + custom_llm_provider: Optional[str] = None, +) -> bool: + """ + Returns true if the prompt is valid for prompt caching. + + OpenAI + Anthropic providers have a minimum token count of 1024 for prompt caching. + """ + if messages is None and tools is None: + return False + if custom_llm_provider is not None and not model.startswith(custom_llm_provider): + model = custom_llm_provider + "/" + model + token_count = token_counter(messages=messages, tools=tools, model=model) + return token_count >= 1024 diff --git a/tests/llm_translation/base_embedding_unit_tests.py b/tests/llm_translation/base_embedding_unit_tests.py new file mode 100644 index 000000000000..94edeccdf364 --- /dev/null +++ b/tests/llm_translation/base_embedding_unit_tests.py @@ -0,0 +1,67 @@ +import asyncio +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, Mock, patch +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from litellm.exceptions import BadRequestError +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import ( + CustomStreamWrapper, + get_supported_openai_params, + get_optional_params, + get_optional_params_embeddings, +) + +# test_example.py +from abc import ABC, abstractmethod + + +class BaseLLMEmbeddingTest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. 
+ """ + + @abstractmethod + def get_base_embedding_call_args(self) -> dict: + """Must return the base embedding call args""" + pass + + @abstractmethod + def get_custom_llm_provider(self) -> litellm.LlmProviders: + """Must return the custom llm provider""" + pass + + @pytest.mark.asyncio() + @pytest.mark.parametrize("sync_mode", [True, False]) + async def test_basic_embedding(self, sync_mode): + litellm.set_verbose = True + embedding_call_args = self.get_base_embedding_call_args() + if sync_mode is True: + response = litellm.embedding( + **embedding_call_args, + input=["hello", "world"], + ) + + print("embedding response: ", response) + else: + response = await litellm.aembedding( + **embedding_call_args, + input=["hello", "world"], + ) + + print("async embedding response: ", response) + + def test_embedding_optional_params_max_retries(self): + embedding_call_args = self.get_base_embedding_call_args() + optional_params = get_optional_params_embeddings( + **embedding_call_args, max_retries=20 + ) + assert optional_params["max_retries"] == 20 diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index 480fbce2fb51..19b382347db5 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -82,6 +82,16 @@ def test_content_list_handling(self): # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None assert response.choices[0].message.content is not None + def test_pydantic_model_input(self): + litellm.set_verbose = True + + from litellm import completion, Message + + base_completion_call_args = self.get_base_completion_call_args() + messages = [Message(content="Hello, how are you?", role="user")] + + completion(**base_completion_call_args, messages=messages) + @pytest.mark.parametrize("image_url", ["str", "dict"]) def test_pdf_handling(self, pdf_messages, image_url): from litellm.utils import supports_pdf_input diff --git a/tests/llm_translation/test_azure_openai.py b/tests/llm_translation/test_azure_openai.py index 9bf16a5cf5b4..431fd4347a9c 100644 --- a/tests/llm_translation/test_azure_openai.py +++ b/tests/llm_translation/test_azure_openai.py @@ -8,6 +8,7 @@ import pytest from litellm.llms.azure.common_utils import process_azure_headers from httpx import Headers +from base_embedding_unit_tests import BaseLLMEmbeddingTest def test_process_azure_headers_empty(): @@ -188,3 +189,15 @@ def test_process_azure_endpoint_url(api_base, model, expected_endpoint): } result = azure_chat_completion.create_azure_base_url(**input_args) assert result == expected_endpoint, "Unexpected endpoint" + + +class TestAzureEmbedding(BaseLLMEmbeddingTest): + def get_base_embedding_call_args(self) -> dict: + return { + "model": "azure/azure-embedding-model", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + } + + def get_custom_llm_provider(self) -> litellm.LlmProviders: + return litellm.LlmProviders.AZURE diff --git a/tests/local_testing/test_anthropic_prompt_caching.py b/tests/local_testing/test_anthropic_prompt_caching.py index 829f5699b2da..cdee19384709 100644 --- a/tests/local_testing/test_anthropic_prompt_caching.py +++ b/tests/local_testing/test_anthropic_prompt_caching.py @@ -161,6 +161,49 @@ def return_val(): ) +@pytest.fixture +def anthropic_messages(): + return [ + # System Message + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" * 400, + "cache_control": 
{"type": "ephemeral"}, + } + ], + }, + # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo", + }, + # The final turn is marked with cache-control, for continuing in followups. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + ] + + @pytest.mark.asyncio() async def test_anthropic_api_prompt_caching_basic(): litellm.set_verbose = True @@ -227,8 +270,6 @@ async def test_anthropic_api_prompt_caching_basic(): @pytest.mark.asyncio() async def test_anthropic_api_prompt_caching_with_content_str(): - from litellm.llms.prompt_templates.factory import anthropic_messages_pt - system_message = [ { "role": "system", @@ -546,3 +587,134 @@ def return_val(): mock_post.assert_called_once_with( expected_url, json=expected_json, headers=expected_headers, timeout=600.0 ) + + +def test_is_prompt_caching_enabled(anthropic_messages): + assert litellm.utils.is_prompt_caching_valid_prompt( + messages=anthropic_messages, + tools=None, + custom_llm_provider="anthropic", + model="anthropic/claude-3-5-sonnet-20240620", + ) + + +@pytest.mark.parametrize( + "messages, expected_model_id", + [("anthropic_messages", True), ("normal_messages", False)], +) +@pytest.mark.asyncio() +async def test_router_prompt_caching_model_stored( + messages, expected_model_id, anthropic_messages +): + """ + If a model is called with prompt caching supported, then the model id should be stored in the router cache. + """ + import asyncio + from litellm.router import Router + from litellm.router_utils.prompt_caching_cache import PromptCachingCache + + router = Router( + model_list=[ + { + "model_name": "claude-model", + "litellm_params": { + "model": "anthropic/claude-3-5-sonnet-20240620", + "api_key": os.environ.get("ANTHROPIC_API_KEY"), + }, + "model_info": {"id": "1234"}, + } + ] + ) + + if messages == "anthropic_messages": + _messages = anthropic_messages + else: + _messages = [{"role": "user", "content": "Hello"}] + + await router.acompletion( + model="claude-model", + messages=_messages, + mock_response="The sky is blue.", + ) + await asyncio.sleep(1) + cache = PromptCachingCache( + cache=router.cache, + ) + + cached_model_id = cache.get_model_id(messages=_messages, tools=None) + + if expected_model_id: + assert cached_model_id["model_id"] == "1234" + else: + assert cached_model_id is None + + +@pytest.mark.asyncio() +async def test_router_with_prompt_caching(anthropic_messages): + """ + if prompt caching supported model called with prompt caching valid prompt, + then 2nd call should go to the same model. 
+ """ + from litellm.router import Router + import asyncio + from litellm.router_utils.prompt_caching_cache import PromptCachingCache + + router = Router( + model_list=[ + { + "model_name": "claude-model", + "litellm_params": { + "model": "anthropic/claude-3-5-sonnet-20240620", + "api_key": os.environ.get("ANTHROPIC_API_KEY"), + }, + }, + { + "model_name": "claude-model", + "litellm_params": { + "model": "anthropic.claude-3-5-sonnet-20241022-v2:0", + }, + }, + ] + ) + + response = await router.acompletion( + messages=anthropic_messages, + model="claude-model", + mock_response="The sky is blue.", + ) + print("response=", response) + + initial_model_id = response._hidden_params["model_id"] + + await asyncio.sleep(1) + cache = PromptCachingCache( + cache=router.cache, + ) + + cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None) + + prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key( + messages=anthropic_messages, tools=None + ) + print(f"prompt_caching_cache_key: {prompt_caching_cache_key}") + assert cached_model_id["model_id"] == initial_model_id + + new_messages = anthropic_messages + [ + {"role": "user", "content": "What is the weather in SF?"} + ] + + pc_deployment = await cache.async_get_prompt_caching_deployment( + router=router, + messages=new_messages, + tools=None, + ) + assert pc_deployment is not None + + response = await router.acompletion( + messages=new_messages, + model="claude-model", + mock_response="The sky is blue.", + ) + print("response=", response) + + assert response._hidden_params["model_id"] == initial_model_id diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index d3db083f68a1..7b84f454085a 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2697,3 +2697,21 @@ def test_model_group_alias(hidden): # assert int(response_headers["x-ratelimit-remaining-requests"]) > 0 # assert response_headers["x-ratelimit-limit-tokens"] == 100500 # assert int(response_headers["x-ratelimit-remaining-tokens"]) > 0 + + +def test_router_completion_with_model_id(): + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo"}, + "model_info": {"id": "123"}, + } + ] + ) + + with patch.object( + router, "routing_strategy_pre_call_checks" + ) as mock_pre_call_checks: + router.completion(model="123", messages=[{"role": "user", "content": "hi"}]) + mock_pre_call_checks.assert_not_called() diff --git a/tests/local_testing/test_router_cooldowns.py b/tests/local_testing/test_router_cooldowns.py index 774b36e2abbc..8c907af297e9 100644 --- a/tests/local_testing/test_router_cooldowns.py +++ b/tests/local_testing/test_router_cooldowns.py @@ -560,3 +560,34 @@ async def test_high_traffic_cooldowns_one_rate_limited_deployment(): 1. 
_set_cooldown_deployments() will cooldown a deployment after it fails 50% requests """ + + +def test_router_fallbacks_with_cooldowns_and_model_id(): + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1}, + "model_info": { + "id": "123", + }, + } + ], + routing_strategy="usage-based-routing-v2", + fallbacks=[{"gpt-3.5-turbo": ["123"]}], + ) + + ## trigger ratelimit + try: + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="litellm.RateLimitError", + ) + except litellm.RateLimitError: + pass + + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + ) diff --git a/tests/local_testing/test_router_fallbacks.py b/tests/local_testing/test_router_fallbacks.py index 3c97506916f5..e555b51d726f 100644 --- a/tests/local_testing/test_router_fallbacks.py +++ b/tests/local_testing/test_router_fallbacks.py @@ -1498,3 +1498,26 @@ async def test_router_disable_fallbacks_dynamically(): print(e) mock_client.assert_not_called() + + +def test_router_fallbacks_with_model_id(): + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1}, + "model_info": { + "id": "123", + }, + } + ], + routing_strategy="usage-based-routing-v2", + fallbacks=[{"gpt-3.5-turbo": ["123"]}], + ) + + ## test model id fallback works + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_testing_fallbacks=True, + ) diff --git a/tests/router_unit_tests/test_router_prompt_caching.py b/tests/router_unit_tests/test_router_prompt_caching.py new file mode 100644 index 000000000000..ca62f1e80d26 --- /dev/null +++ b/tests/router_unit_tests/test_router_prompt_caching.py @@ -0,0 +1,66 @@ +import sys +import os +import traceback +from dotenv import load_dotenv +from fastapi import Request +from datetime import datetime + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from litellm import Router +import pytest +import litellm +from unittest.mock import patch, MagicMock, AsyncMock +from create_mock_standard_logging_payload import create_standard_logging_payload +from litellm.types.utils import StandardLoggingPayload +import unittest +from pydantic import BaseModel +from litellm.router_utils.prompt_caching_cache import PromptCachingCache + + +class ExampleModel(BaseModel): + field1: str + field2: int + + +def test_serialize_pydantic_object(): + model = ExampleModel(field1="value", field2=42) + serialized = PromptCachingCache.serialize_object(model) + assert serialized == {"field1": "value", "field2": 42} + + +def test_serialize_dict(): + obj = {"b": 2, "a": 1} + serialized = PromptCachingCache.serialize_object(obj) + assert serialized == '{"a":1,"b":2}' # JSON string with sorted keys + + +def test_serialize_nested_dict(): + obj = {"z": {"b": 2, "a": 1}, "x": [1, 2, {"c": 3}]} + serialized = PromptCachingCache.serialize_object(obj) + expected = '{"x":[1,2,{"c":3}],"z":{"a":1,"b":2}}' # JSON string with sorted keys + assert serialized == expected + + +def test_serialize_list(): + obj = ["item1", {"a": 1, "b": 2}, 42] + serialized = PromptCachingCache.serialize_object(obj) + expected = ["item1", '{"a":1,"b":2}', 42] + assert serialized == expected + + +def test_serialize_fallback(): + obj = 12345 # Simple non-serializable object + serialized = PromptCachingCache.serialize_object(obj) + assert serialized == 12345 + + +def 
test_serialize_non_serializable(): + class CustomClass: + def __str__(self): + return "custom_object" + + obj = CustomClass() + serialized = PromptCachingCache.serialize_object(obj) + assert serialized == "custom_object" # Fallback to string conversion
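
The serialization tests above pin down the stable hashing behaviour that the router-side prompt-caching cache depends on. As a rough end-to-end sketch (illustrative only, combining the helpers added in this patch), checking whether a prompt qualifies for prompt caching and deriving its cache key looks roughly like this; the oversized system message is just padding to clear the ~1024-token minimum that OpenAI and Anthropic require.

```python
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
from litellm.utils import is_prompt_caching_valid_prompt

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                # padding so the prompt crosses the 1024-token prompt-caching threshold
                "text": "Here is the full text of a complex legal agreement" * 400,
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {
        "role": "user",
        "content": "What are the key terms and conditions in this agreement?",
    },
]

# True only when messages/tools meet the provider's minimum token count
print(
    is_prompt_caching_valid_prompt(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=messages,
        tools=None,
        custom_llm_provider="anthropic",
    )
)

# Deterministic key: identical messages/tools always hash to the same value,
# so follow-up turns can be routed back to the deployment holding the cache.
print(PromptCachingCache.get_prompt_caching_cache_key(messages=messages, tools=None))
```

On a cache hit, `async_get_prompt_caching_deployment` returns the stored deployment dict and the router reuses that model id for the follow-up request, as exercised in `test_router_with_prompt_caching` above.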