Skip to content

Commit

Permalink
Finish implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dbobrenko committed Dec 19, 2024
1 parent f82105a commit 27b035a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
10 changes: 8 additions & 2 deletions validator_api/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import json
import random
from typing import AsyncGenerator
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
import httpx
Expand Down Expand Up @@ -41,7 +42,12 @@ async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
logger.exception(f"Error while forwarding response: {e}")


async def stream_response(response, collected_chunks: list[str], body: dict[str, any], uid: int) -> AsyncGenerator[str, None]:
async def stream_response(
response,
collected_chunks: list[str],
body: dict[str, any],
uid: int
) -> AsyncGenerator[str, None]:
chunks_received = False
try:
async for chunk in response:
Expand All @@ -64,7 +70,7 @@ async def stream_response(response, collected_chunks: list[str], body: dict[str,
yield 'data: {"error": "Internal server Error"}\n\n'


async def regular_chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
"""Handle regular chat completion without mixture of miners."""
if uid is None:
uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))
Expand Down
8 changes: 4 additions & 4 deletions validator_api/gpt_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from loguru import logger
from starlette.responses import StreamingResponse

from validator_api import mixture_of_miners
from validator_api.chat_completion import regular_chat_completion
from validator_api.mixture_of_miners import mixture_of_miners
from validator_api.chat_completion import chat_completion

router = APIRouter()


@router.post("/v1/chat/completions")
async def chat_completion(request: Request):
async def completions(request: Request):
"""Main endpoint that handles both regular and mixture of miners chat completion."""
try:
body = await request.json()
Expand All @@ -21,7 +21,7 @@ async def chat_completion(request: Request):
if body.get("mixture", False):
return await mixture_of_miners(body)
else:
return await regular_chat_completion(body)
return await chat_completion(body)

except Exception as e:
logger.exception(f"Error in chat completion: {e}")
Expand Down
6 changes: 3 additions & 3 deletions validator_api/mixture_of_miners.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from loguru import logger

from shared.uids import get_uids
from validator_api.chat_completion import get_response_from_miner, regular_chat_completion
from validator_api.chat_completion import get_response_from_miner, chat_completion


DEFAULT_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query.
Expand Down Expand Up @@ -55,7 +55,7 @@ async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:

# Get multiple miners
miner_uids = get_uids(sampling_mode="top_incentive", k=NUM_MIXTURE_MINERS)
if not miner_uids:
if len(miner_uids) == 0:
raise HTTPException(status_code=503, detail="No available miners found")

# Concurrently collect responses from all miners.
Expand Down Expand Up @@ -87,4 +87,4 @@ async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:

# Get final response using a random top miner.
final_uid = random.choice(get_uids(sampling_mode="top_incentive", k=TOP_INCENTIVE_POOL))
return await regular_chat_completion(final_body, final_uid)
return await chat_completion(final_body, final_uid)

0 comments on commit 27b035a

Please sign in to comment.