From fb9338657d05646b2cf85d3d8ef715a2c83ddf48 Mon Sep 17 00:00:00 2001
From: juberti
Date: Wed, 11 Sep 2024 20:40:18 -0700
Subject: [PATCH] Add Mistral Large 2 and Nemo

---
 llm_benchmark_suite.py | 17 ++++++++++++++++-
 llm_request.py         |  6 ++++--
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/llm_benchmark_suite.py b/llm_benchmark_suite.py
index 97743c6..354d6be 100644
--- a/llm_benchmark_suite.py
+++ b/llm_benchmark_suite.py
@@ -40,7 +40,6 @@
 LLAMA_3_8B_CHAT_FP4 = "llama-3-8b-chat-fp4"
 MIXTRAL_8X7B_INSTRUCT = "mixtral-8x7b-instruct"
 MIXTRAL_8X7B_INSTRUCT_FP8 = "mixtral-8x7b-instruct-fp8"
-PHI_2 = "phi-2"
 
 
 parser = argparse.ArgumentParser()
@@ -183,6 +182,18 @@ def __init__(self, model: str, display_model: Optional[str] = None):
         )
 
 
+class _MistralLlm(_Llm):
+    """See https://docs.mistral.ai/getting-started/models"""
+
+    def __init__(self, model: str, display_model: Optional[str] = None):
+        super().__init__(
+            model,
+            "mistral.ai/" + (display_model or model),
+            api_key=os.getenv("MISTRAL_API_KEY"),
+            base_url="https://api.mistral.ai/v1",
+        )
+
+
 class _NvidiaLlm(_Llm):
     """See https://build.nvidia.com/explore/discover"""
 
@@ -330,6 +341,9 @@ def _text_models():
         _Llm("gemini-pro"),
         _Llm(GEMINI_1_5_PRO),
         _Llm(GEMINI_1_5_FLASH),
+        # Mistral
+        _MistralLlm("mistral-large-latest", "mistral-large"),
+        _MistralLlm("open-mistral-nemo", "mistral-nemo"),
         # Mistral 8x7b
         _DatabricksLlm("databricks-mixtral-8x7b-instruct", MIXTRAL_8X7B_INSTRUCT),
         _DeepInfraLlm("mistralai/Mixtral-8x7B-Instruct-v0.1", MIXTRAL_8X7B_INSTRUCT),
@@ -484,6 +498,7 @@ def _image_models():
         _FireworksLlm(
             "accounts/fireworks/models/phi-3-vision-128k-instruct", "phi-3-vision"
         ),
+        _MistralLlm("pixtral-latest", "pixtral"),
     ]
 
 
diff --git a/llm_request.py b/llm_request.py
index 31ee185..3216d49 100644
--- a/llm_request.py
+++ b/llm_request.py
@@ -146,7 +146,9 @@ async def run(self, on_token: Optional[Callable[["ApiContext", str], None]] = No
         if not self.metrics.error:
             token_time = end_time - first_token_time
             self.metrics.total_time = end_time - start_time
-            self.metrics.tps = min((self.metrics.output_tokens - 1) / token_time, MAX_TPS)
+            self.metrics.tps = min(
+                (self.metrics.output_tokens - 1) / token_time, MAX_TPS
+            )
             if self.metrics.tps == MAX_TPS:
                 self.metrics.tps = 0.0
         else:
@@ -293,7 +295,7 @@ async def openai_chat(ctx: ApiContext, path: str = "/chat/completions") -> ApiRe
     # Some providers require opt-in for stream stats, but some providers don't like this opt-in.
     # Regardless of opt-in, Azure and ovh.net don't return stream stats at the moment.
     # See https://github.com/Azure/azure-rest-api-specs/issues/25062
-    if not any(p in ctx.name for p in ["azure", "databricks", "fireworks"]):
+    if not any(p in ctx.name for p in ["azure", "databricks", "fireworks", "mistral"]):
         kwargs["stream_options"] = {"include_usage": True}
     data = make_openai_chat_body(ctx, **kwargs)
     return await post(ctx, url, headers, data, openai_chunk_gen)
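
Usage sketch (illustrative only, not applied by the patch): assuming Mistral's endpoint is OpenAI-compatible, the new _MistralLlm wiring amounts to a streaming chat completion POST to https://api.mistral.ai/v1/chat/completions with a Bearer MISTRAL_API_KEY and no "stream_options" opt-in, mirroring the "mistral" exclusion added to openai_chat(). A minimal standalone equivalent using the requests library (an assumption; the suite itself goes through _Llm/ApiContext and its own async HTTP layer):

    # Hypothetical sketch, not part of llm_benchmark_suite.py or llm_request.py.
    import json
    import os

    import requests  # assumed available for this standalone example

    url = "https://api.mistral.ai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.environ['MISTRAL_API_KEY']}",
        "Content-Type": "application/json",
    }
    body = {
        "model": "mistral-large-latest",  # or "open-mistral-nemo"
        "messages": [{"role": "user", "content": "Say hello."}],
        "stream": True,
        # Note: no "stream_options" here, matching the exclusion in openai_chat().
    }

    with requests.post(url, headers=headers, json=body, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            # Server-sent events: each payload line looks like "data: {...}".
            if not line.startswith(b"data: "):
                continue
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":
                break
            chunk = json.loads(payload)
            if chunk.get("choices"):
                print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
    print()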