diff --git a/validator_api/chat_completion.py b/validator_api/chat_completion.py
index 0d1470e2..337ae45d 100644
--- a/validator_api/chat_completion.py
+++ b/validator_api/chat_completion.py
@@ -2,6 +2,7 @@
 import asyncio
 import json
 import random
+from typing import AsyncGenerator
 from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
 import httpx
@@ -41,7 +42,12 @@ async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
         logger.exception(f"Error while forwarding response: {e}")
 
 
-async def stream_response(response, collected_chunks: list[str], body: dict[str, any], uid: int) -> AsyncGenerator[str, None]:
+async def stream_response(
+    response,
+    collected_chunks: list[str],
+    body: dict[str, any],
+    uid: int
+    ) -> AsyncGenerator[str, None]:
     chunks_received = False
     try:
         async for chunk in response:
@@ -64,7 +70,7 @@ async def stream_response(response, collected_chunks: list[str], body: dict[str,
         yield 'data: {"error": "Internal server Error"}\n\n'
 
 
-async def regular_chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
+async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
     """Handle regular chat completion without mixture of miners."""
     if uid is None:
         uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))
diff --git a/validator_api/gpt_endpoints.py b/validator_api/gpt_endpoints.py
index 9638a89c..0058f738 100644
--- a/validator_api/gpt_endpoints.py
+++ b/validator_api/gpt_endpoints.py
@@ -4,14 +4,14 @@
 from loguru import logger
 from starlette.responses import StreamingResponse
 
-from validator_api import mixture_of_miners
-from validator_api.chat_completion import regular_chat_completion
+from validator_api.mixture_of_miners import mixture_of_miners
+from validator_api.chat_completion import chat_completion
 
 router = APIRouter()
 
 
 @router.post("/v1/chat/completions")
-async def chat_completion(request: Request):
+async def completions(request: Request):
     """Main endpoint that handles both regular and mixture of miners chat completion."""
     try:
         body = await request.json()
@@ -21,7 +21,7 @@ async def chat_completion(request: Request):
         if body.get("mixture", False):
             return await mixture_of_miners(body)
         else:
-            return await regular_chat_completion(body)
+            return await chat_completion(body)
     except Exception as e:
         logger.exception(f"Error in chat completion: {e}")
diff --git a/validator_api/mixture_of_miners.py b/validator_api/mixture_of_miners.py
index b15f6f7d..189bfc56 100644
--- a/validator_api/mixture_of_miners.py
+++ b/validator_api/mixture_of_miners.py
@@ -7,7 +7,7 @@
 from loguru import logger
 
 from shared.uids import get_uids
-from validator_api.chat_completion import get_response_from_miner, regular_chat_completion
+from validator_api.chat_completion import get_response_from_miner, chat_completion
 
 DEFAULT_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query.
@@ -55,7 +55,7 @@ async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:
     # Get multiple miners
     miner_uids = get_uids(sampling_mode="top_incentive", k=NUM_MIXTURE_MINERS)
-    if not miner_uids:
+    if len(miner_uids) == 0:
         raise HTTPException(status_code=503, detail="No available miners found")
 
     # Concurrently collect responses from all miners.
@@ -87,4 +87,4 @@
     # Get final response using a random top miner.
     final_uid = random.choice(get_uids(sampling_mode="top_incentive", k=TOP_INCENTIVE_POOL))
-    return await regular_chat_completion(final_body, final_uid)
+    return await chat_completion(final_body, final_uid)