Commit
Merge pull request #122 from fixie-ai/juberti/mistral
Add Mistral Large 2 and Nemo
juberti authored Sep 12, 2024
2 parents 30d8fe0 + fb93386 commit a9ff30f
Showing 2 changed files with 20 additions and 3 deletions.
llm_benchmark_suite.py (16 additions, 1 deletion)

@@ -40,7 +40,6 @@
 LLAMA_3_8B_CHAT_FP4 = "llama-3-8b-chat-fp4"
 MIXTRAL_8X7B_INSTRUCT = "mixtral-8x7b-instruct"
 MIXTRAL_8X7B_INSTRUCT_FP8 = "mixtral-8x7b-instruct-fp8"
-PHI_2 = "phi-2"
 
 
 parser = argparse.ArgumentParser()
@@ -183,6 +182,18 @@ def __init__(self, model: str, display_model: Optional[str] = None):
         )
 
 
+class _MistralLlm(_Llm):
+    """See https://docs.mistral.ai/getting-started/models"""
+
+    def __init__(self, model: str, display_model: Optional[str] = None):
+        super().__init__(
+            model,
+            "mistral.ai/" + (display_model or model),
+            api_key=os.getenv("MISTRAL_API_KEY"),
+            base_url="https://api.mistral.ai/v1",
+        )
+
+
 class _NvidiaLlm(_Llm):
     """See https://build.nvidia.com/explore/discover"""
 
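The new _MistralLlm adapter reuses the OpenAI-compatible _Llm base class, pointing it at Mistral's endpoint and prefixing display names with "mistral.ai/". A minimal usage sketch, assuming only that MISTRAL_API_KEY is set in the environment (arguments taken from the diff above):

    # Benchmark mistral-large-latest via https://api.mistral.ai/v1;
    # results are reported under the name "mistral.ai/mistral-large".
    llm = _MistralLlm("mistral-large-latest", "mistral-large")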
@@ -330,6 +341,9 @@ def _text_models():
         _Llm("gemini-pro"),
         _Llm(GEMINI_1_5_PRO),
         _Llm(GEMINI_1_5_FLASH),
+        # Mistral
+        _MistralLlm("mistral-large-latest", "mistral-large"),
+        _MistralLlm("open-mistral-nemo", "mistral-nemo"),
         # Mistral 8x7b
         _DatabricksLlm("databricks-mixtral-8x7b-instruct", MIXTRAL_8X7B_INSTRUCT),
         _DeepInfraLlm("mistralai/Mixtral-8x7B-Instruct-v0.1", MIXTRAL_8X7B_INSTRUCT),
@@ -484,6 +498,7 @@ def _image_models():
         _FireworksLlm(
             "accounts/fireworks/models/phi-3-vision-128k-instruct", "phi-3-vision"
         ),
+        _MistralLlm("pixtral-latest", "pixtral"),
     ]


llm_request.py (4 additions, 2 deletions)
@@ -146,7 +146,9 @@ async def run(self, on_token: Optional[Callable[["ApiContext", str], None]] = None
         if not self.metrics.error:
             token_time = end_time - first_token_time
             self.metrics.total_time = end_time - start_time
-            self.metrics.tps = min((self.metrics.output_tokens - 1) / token_time, MAX_TPS)
+            self.metrics.tps = min(
+                (self.metrics.output_tokens - 1) / token_time, MAX_TPS
+            )
             if self.metrics.tps == MAX_TPS:
                 self.metrics.tps = 0.0
         else:
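The reformatted clamp above is the suite's streaming-throughput metric: timing starts at the first token, so only output_tokens - 1 tokens fall inside token_time, and a result that hits the MAX_TPS cap indicates an implausibly small window (e.g. the provider sent everything in one chunk), so the rate is zeroed rather than reported. A standalone sketch of that logic; MAX_TPS here is an assumed sentinel, since its real value is not shown in this diff:

    MAX_TPS = 1000.0  # assumed sentinel value; the actual constant is not in this diff

    def stream_tps(output_tokens: int, token_time: float) -> float:
        tps = min((output_tokens - 1) / token_time, MAX_TPS)
        return 0.0 if tps == MAX_TPS else tps

    print(stream_tps(101, 2.0))   # 50.0 -- 100 tokens streamed over 2 seconds
    print(stream_tps(101, 1e-9))  # 0.0 -- hit the cap, treated as non-streaming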
@@ -293,7 +295,7 @@ async def openai_chat(ctx: ApiContext, path: str = "/chat/completions") -> ApiResult:
     # Some providers require opt-in for stream stats, but some providers don't like this opt-in.
     # Regardless of opt-in, Azure and ovh.net don't return stream stats at the moment.
     # See https://github.com/Azure/azure-rest-api-specs/issues/25062
-    if not any(p in ctx.name for p in ["azure", "databricks", "fireworks"]):
+    if not any(p in ctx.name for p in ["azure", "databricks", "fireworks", "mistral"]):
         kwargs["stream_options"] = {"include_usage": True}
     data = make_openai_chat_body(ctx, **kwargs)
     return await post(ctx, url, headers, data, openai_chunk_gen)
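stream_options={"include_usage": True} is the OpenAI chat-completions opt-in that appends a final chunk with token counts to a streamed response; Mistral joins the exclusion list because its endpoint doesn't accept the option. A sketch of the request body the opted-in path would produce, with placeholder model and message values (make_openai_chat_body's exact output is not shown here):

    data = {
        "model": "gpt-4o",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
        # sent only for providers not in the exclusion list above
        "stream_options": {"include_usage": True},
    }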
