# Add Mixture-of-Miners endpoint (#472)

## Changes

- Add Mixture-of-Miners endpoint.
- Move chat completions processing and mixture of miners into two separate functions.
- Add system prompt to the inference task.
Showing 4 changed files with 203 additions and 90 deletions.
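As the diffs below show, the new behavior is opt-in per request: if the JSON body contains `"mixture": true`, the endpoint routes the query through the new Mixture-of-Miners flow; otherwise it falls back to the single-miner path. A minimal client sketch (the host/port and the lack of an auth header are assumptions, not part of this diff):

```python
import httpx

VALIDATOR_API = "http://localhost:8000"  # Hypothetical validator API address.

body = {
    "messages": [{"role": "user", "content": "What is the capital of Australia?"}],
    "task": "InferenceTask",
    "stream": False,
    "mixture": True,  # Opt in to the new Mixture-of-Miners path.
}

# POST to the endpoint touched by this PR.
response = httpx.post(f"{VALIDATOR_API}/v1/chat/completions", json=body, timeout=120.0)
print(response.json())
```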
**New file** (+104): shared chat-completion helpers, imported by the other changed files as `validator_api.chat_completion`.

```python
import asyncio
import json
import random
from typing import AsyncGenerator

import httpx
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from shared.epistula import make_openai_query
from shared.settings import shared_settings
from shared.uids import get_uids


async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
    uid = int(uid)  # sometimes uid is type np.uint64
    logger.info(f"Forwarding response to scoring with body: {body}")
    if not shared_settings.SCORE_ORGANICS:  # Allow disabling of scoring by default.
        return

    if body.get("task") != "InferenceTask":
        logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
        return

    url = f"http://{shared_settings.VALIDATOR_API}/scoring"
    payload = {"body": body, "chunks": chunks, "uid": uid}
    try:
        timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url, json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"}
            )
            if response.status_code == 200:
                logger.info(f"Forwarding response completed with status {response.status_code}")
            else:
                logger.error(
                    f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
                )
    except Exception as e:
        logger.error(f"Tried to forward response to {url} with payload {payload}")
        logger.exception(f"Error while forwarding response: {e}")


async def stream_response(
    response, collected_chunks: list[str], body: dict[str, any], uid: int
) -> AsyncGenerator[str, None]:
    chunks_received = False
    try:
        async for chunk in response:
            chunks_received = True
            collected_chunks.append(chunk.choices[0].delta.content)
            yield f"data: {json.dumps(chunk.model_dump())}\n\n"

        if not chunks_received:
            logger.error("Stream is empty: No chunks were received")
            yield 'data: {"error": "502 - Response is empty"}\n\n'
        yield "data: [DONE]\n\n"

        # Forward the collected chunks after streaming is complete.
        asyncio.create_task(forward_response(uid=uid, body=body, chunks=collected_chunks))
    except asyncio.CancelledError:
        logger.info("Client disconnected, streaming cancelled")
        raise
    except Exception as e:
        logger.exception(f"Error during streaming: {e}")
        yield 'data: {"error": "Internal Server Error"}\n\n'


async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
    """Handle regular chat completion without mixture of miners."""
    if uid is None:
        uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))

    if uid is None:
        logger.error("No available miner found")
        raise HTTPException(status_code=503, detail="No available miner found")

    logger.debug(f"Querying uid {uid}")
    STREAM = body.get("stream", False)

    collected_chunks: list[str] = []

    logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}")
    response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM)

    if STREAM:
        return StreamingResponse(
            stream_response(response, collected_chunks, body, uid),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
    else:
        asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1]))
        return response[0]


async def get_response_from_miner(body: dict[str, any], uid: int) -> tuple:
    """Get response from a single miner."""
    return await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=False)
```
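When `stream` is true, `stream_response` yields OpenAI-style chunks as Server-Sent Events (`data: <json>\n\n`, closed by `data: [DONE]\n\n`, with an inline `{"error": ...}` object on failure). A rough consumer sketch using `httpx` (URL and body as in the earlier example; illustrative only, not part of the PR):

```python
import json

import httpx


async def consume_stream(url: str, body: dict) -> str:
    """Reassemble the text streamed by stream_response's SSE events."""
    parts: list[str] = []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=body) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # Skip blank keep-alive lines.
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break
                chunk = json.loads(payload)
                if "error" in chunk:
                    raise RuntimeError(chunk["error"])
                # Chunks mirror OpenAI's streaming schema: choices[0].delta.content.
                content = chunk["choices"][0]["delta"].get("content")
                if content:
                    parts.append(content)
    return "".join(parts)
```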
**Modified** (107 → 28 lines): the `/v1/chat/completions` router becomes a thin dispatcher; the duplicated `forward_response` helper and the inline streaming wrapper now live in the new `chat_completion` module.

```diff
@@ -1,107 +1,28 @@
-import asyncio
-import json
 import random
 
-import httpx
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, Request
 from loguru import logger
 from starlette.responses import StreamingResponse
 
-from shared.epistula import make_openai_query
-from shared.settings import shared_settings
-from shared.uids import get_uids
+from validator_api.chat_completion import chat_completion
+from validator_api.mixture_of_miners import mixture_of_miners
 
 router = APIRouter()
 
 
-async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
-    uid = int(uid)  # sometimes uid is type np.uint64
-    logger.info(f"Forwarding response to scoring with body: {body}")
-    if not shared_settings.SCORE_ORGANICS:  # Allow disabling of scoring by default
-        return
-
-    if body.get("task") != "InferenceTask":
-        logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
-        return
-    url = f"http://{shared_settings.VALIDATOR_API}/scoring"
-    payload = {"body": body, "chunks": chunks, "uid": uid}
-    try:
-        timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            response = await client.post(
-                url, json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"}
-            )
-            if response.status_code == 200:
-                logger.info(f"Forwarding response completed with status {response.status_code}")
-            else:
-                logger.exception(
-                    f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
-                )
-    except Exception as e:
-        logger.error(f"Tried to forward response to {url} with payload {payload}")
-        logger.exception(f"Error while forwarding response: {e}")
-
-
 @router.post("/v1/chat/completions")
-async def chat_completion(request: Request):  # , cbackground_tasks: BackgroundTasks):
+async def completions(request: Request):
+    """Main endpoint that handles both regular and mixture of miners chat completion."""
     try:
         body = await request.json()
         body["seed"] = int(body.get("seed") or random.randint(0, 1000000))
-        STREAM = body.get("stream") or False
-        logger.debug(f"Streaming: {STREAM}")
-        # Get random miner from top 100 incentive.
-        uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))
-        # uid = get_available_miner(task=body.get("task"), model=body.get("model"))
-        if uid is None:
-            logger.error("No available miner found")
-            raise HTTPException(status_code=503, detail="No available miner found")
-        logger.debug(f"Querying uid {uid}")
-
-        collected_chunks: list[str] = []
-
-        # Create a wrapper for the streaming response
-        async def stream_with_error_handling():
-            chunks_received = False
-            try:
-                async for chunk in response:
-                    chunks_received = True
-                    collected_chunks.append(chunk.choices[0].delta.content)
-                    yield f"data: {json.dumps(chunk.model_dump())}\n\n"
-
-                if not chunks_received:
-                    logger.error("Stream is empty: No chunks were received")
-                    yield 'data: {"error": "502 - Response is empty"}\n\n'
-                yield "data: [DONE]\n\n"
-
-                # Once the stream is done, forward the collected chunks
-                asyncio.create_task(forward_response(uid=uid, body=body, chunks=collected_chunks))
-                # background_tasks.add_task(forward_response, uid=uid, body=body, chunks=collected_chunks)
-            except asyncio.CancelledError:
-                logger.info("Client disconnected, streaming cancelled")
-                raise
-            except Exception as e:
-                logger.exception(f"Error during streaming: {e}")
-                yield 'data: {"error": "Internal server Error"}\n\n'
-
-        logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}")
-        response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM)
-
-        if STREAM:
-            return StreamingResponse(
-                stream_with_error_handling(),
-                media_type="text/event-stream",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                },
-            )
+        # Choose between regular completion and mixture of miners.
+        if body.get("mixture", False):
+            return await mixture_of_miners(body)
         else:
-            logger.info("Forwarding response to scoring...")
-            asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1]))
-            return response[0]
+            return await chat_completion(body)
 
     except Exception as e:
-        logger.exception(f"Error setting up streaming: {e}")
+        logger.exception(f"Error in chat completion: {e}")
         return StreamingResponse(content="Internal Server Error", status_code=500)
```
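One detail worth noting in the retained seed defaulting: `body.get("seed") or random.randint(0, 1000000)` re-rolls an explicit client seed of `0`, because `0` is falsy in Python. If a literal zero seed should be honored, an explicit `None` check would do it; a possible variant (not in the PR):

```python
# Preserve an explicit seed of 0 instead of treating it as "unset".
seed = body.get("seed")
body["seed"] = int(seed) if seed is not None else random.randint(0, 1000000)
```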
**New file** (+88): `validator_api.mixture_of_miners` (path inferred from the router's import).

```python
import asyncio
import copy
import random

from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from shared.uids import get_uids
from validator_api.chat_completion import chat_completion, get_response_from_miner

DEFAULT_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query.
Your task is to synthesize these responses into a single, high-quality and concise response.
It is crucial to follow the provided instructions or examples in the given prompt if any, and ensure the answer is in the correct and expected format.
Critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect.
Your response should not simply replicate the given answers but should offer a refined and accurate reply to the instruction.
Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models:"""

TASK_SYSTEM_PROMPT = {
    None: DEFAULT_SYSTEM_PROMPT,
    # Add more task-specific system prompts here.
}

NUM_MIXTURE_MINERS = 5
TOP_INCENTIVE_POOL = 100


async def get_miner_response(body: dict, uid: str) -> tuple | None:
    """Get response from a single miner with error handling."""
    try:
        return await get_response_from_miner(body, uid)
    except Exception as e:
        logger.error(f"Error getting response from miner {uid}: {e}")
        return None


async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:
    """Handle chat completion with the mixture-of-miners approach.

    Based on "Mixture-of-Agents Enhances Large Language Model Capabilities", Wang et al., 2024:
    https://arxiv.org/abs/2406.04692

    Args:
        body: Query parameters:
            messages: User prompt.
            stream: If True, stream the response.
            model: Optional model used for inference; SharedSettings.LLM_MODEL is used by default.
            task: Optional task, see prompting/tasks/task_registry.py; InferenceTask is used by default.
    """
    body_first_step = copy.deepcopy(body)
    body_first_step["stream"] = False

    # Get multiple miners.
    miner_uids = get_uids(sampling_mode="top_incentive", k=NUM_MIXTURE_MINERS)
    if len(miner_uids) == 0:
        raise HTTPException(status_code=503, detail="No available miners found")

    # Concurrently collect responses from all miners.
    miner_tasks = [get_miner_response(body_first_step, uid) for uid in miner_uids]
    responses = await asyncio.gather(*miner_tasks)

    # Filter out None responses (failed requests).
    valid_responses = [r for r in responses if r is not None]
    if not valid_responses:
        raise HTTPException(status_code=503, detail="Failed to get responses from miners")

    # Extract completions from the responses.
    completions = [response[1][0] for response in valid_responses if response and len(response) > 1]

    task_name = body.get("task")
    system_prompt = TASK_SYSTEM_PROMPT.get(task_name, DEFAULT_SYSTEM_PROMPT)

    # Aggregate responses into one system prompt.
    agg_system_prompt = system_prompt + "\n" + "\n".join([f"{i+1}. {comp}" for i, comp in enumerate(completions)])

    # Prepare new messages with the aggregated system prompt.
    new_messages = [{"role": "system", "content": agg_system_prompt}]
    new_messages.extend([msg for msg in body["messages"] if msg["role"] != "system"])

    # Update the body with the new messages.
    final_body = copy.deepcopy(body)
    final_body["messages"] = new_messages

    # Get the final response from a random top-incentive miner.
    final_uid = random.choice(get_uids(sampling_mode="top_incentive", k=TOP_INCENTIVE_POOL))
    return await chat_completion(final_body, final_uid)
```
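To make the aggregation step concrete, here is roughly what the second-step prompt looks like for two toy miner completions (values invented for illustration):

```python
completions = [
    "The capital of Australia is Canberra.",
    "Canberra is Australia's capital, though Sydney is its largest city.",
]

# Enumerate the miner answers under the synthesis instructions.
agg_system_prompt = DEFAULT_SYSTEM_PROMPT + "\n" + "\n".join(
    f"{i+1}. {comp}" for i, comp in enumerate(completions)
)

# Any client-supplied system message is dropped; the aggregated prompt replaces it.
new_messages = [
    {"role": "system", "content": agg_system_prompt},
    {"role": "user", "content": "What is the capital of Australia?"},
]
```

Note that only the first step forces `stream=False`; the final call goes through `chat_completion` with the original body's `stream` flag intact, so a streaming client still receives SSE output from the synthesis step.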