# Add Mixture-of-Miners endpoint (#472)
## Changes

- Add Mixture-of-Miners endpoint (see the example request below).
- Split chat-completion processing and the mixture-of-miners path into two separate functions.
- Add a system prompt to the inference task.
dbobrenko authored Dec 20, 2024
1 parent 9c1659d commit b0d8e85
Showing 4 changed files with 203 additions and 90 deletions.
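The new behavior is opt-in: clients set a `mixture` flag in the request body of the existing `/v1/chat/completions` endpoint (see `validator_api/gpt_endpoints.py` below). A minimal client sketch, assuming a hypothetical local deployment at `localhost:8000`:

```python
import requests  # assumed client-side dependency, not part of this commit

# Hypothetical deployment URL; the endpoint path and "mixture" flag come from the diff below.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "mixture": True,  # route the request through mixture_of_miners()
        "stream": False,
    },
)
print(response.json())
```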
2 changes: 1 addition & 1 deletion prompting/tasks/base_task.py
@@ -90,7 +90,7 @@ def generate_query(
"""Generates a query to be used for generating the challenge"""
logger.info("🤖 Generating query...")
llm_messages = [LLMMessage(role="system", content=self.query_system_prompt)] if self.query_system_prompt else []
llm_messages += [LLMMessage(role="user", content=message) for message in messages]
llm_messages.extend([LLMMessage(role="user", content=message) for message in messages])

self.query = LLMWrapper.chat_complete(messages=LLMMessages(*llm_messages))

104 changes: 104 additions & 0 deletions validator_api/chat_completion.py
@@ -0,0 +1,104 @@
import asyncio
import json
import random
from typing import AsyncGenerator

import httpx
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from shared.epistula import make_openai_query
from shared.settings import shared_settings
from shared.uids import get_uids


async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
    uid = int(uid)  # sometimes uid is type np.uint64
    logger.info(f"Forwarding response to scoring with body: {body}")
    if not shared_settings.SCORE_ORGANICS:  # Allow disabling of scoring by default
        return

    if body.get("task") != "InferenceTask":
        logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
        return
    url = f"http://{shared_settings.VALIDATOR_API}/scoring"
    payload = {"body": body, "chunks": chunks, "uid": uid}
    try:
        timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url, json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"}
            )
            if response.status_code == 200:
                logger.info(f"Forwarding response completed with status {response.status_code}")
            else:
                logger.exception(
                    f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
                )
    except Exception as e:
        logger.error(f"Tried to forward response to {url} with payload {payload}")
        logger.exception(f"Error while forwarding response: {e}")


async def stream_response(
    response, collected_chunks: list[str], body: dict[str, any], uid: int
) -> AsyncGenerator[str, None]:
    chunks_received = False
    try:
        async for chunk in response:
            chunks_received = True
            collected_chunks.append(chunk.choices[0].delta.content)
            yield f"data: {json.dumps(chunk.model_dump())}\n\n"

        if not chunks_received:
            logger.error("Stream is empty: No chunks were received")
            yield 'data: {"error": "502 - Response is empty"}\n\n'
        yield "data: [DONE]\n\n"

        # Forward the collected chunks after streaming is complete.
        asyncio.create_task(forward_response(uid=uid, body=body, chunks=collected_chunks))
    except asyncio.CancelledError:
        logger.info("Client disconnected, streaming cancelled")
        raise
    except Exception as e:
        logger.exception(f"Error during streaming: {e}")
        yield 'data: {"error": "Internal server Error"}\n\n'


async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
    """Handle regular chat completion without mixture of miners."""
    if uid is None:
        uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))

    if uid is None:
        logger.error("No available miner found")
        raise HTTPException(status_code=503, detail="No available miner found")

    logger.debug(f"Querying uid {uid}")
    STREAM = body.get("stream", False)

    collected_chunks: list[str] = []

    logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}")
    response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM)

    if STREAM:
        return StreamingResponse(
            stream_response(response, collected_chunks, body, uid),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
    else:
        asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1]))
        return response[0]


async def get_response_from_miner(body: dict[str, any], uid: int) -> tuple:
    """Get response from a single miner."""
    return await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=False)
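When `stream` is set, `stream_response` above emits OpenAI-style server-sent events terminated by `data: [DONE]`. A minimal sketch of a client consuming that stream, reusing the `httpx` dependency already imported above (the URL is a placeholder assumption):

```python
import json

import httpx


async def read_stream(url: str, body: dict) -> str:
    """Collect the streamed completion text from the SSE endpoint above."""
    parts: list[str] = []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=body) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                # Each event is a dumped OpenAI-style ChatCompletionChunk.
                delta = chunk["choices"][0]["delta"].get("content")
                if delta:
                    parts.append(delta)
    return "".join(parts)
```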
99 changes: 10 additions & 89 deletions validator_api/gpt_endpoints.py
@@ -1,107 +1,28 @@
-import asyncio
-import json
 import random
 
-import httpx
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, Request
 from loguru import logger
 from starlette.responses import StreamingResponse
 
-from shared.epistula import make_openai_query
-from shared.settings import shared_settings
-from shared.uids import get_uids
+from validator_api.chat_completion import chat_completion
+from validator_api.mixture_of_miners import mixture_of_miners
 
 router = APIRouter()
 
 
-async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
-    uid = int(uid)  # sometimes uid is type np.uint64
-    logger.info(f"Forwarding response to scoring with body: {body}")
-    if not shared_settings.SCORE_ORGANICS:  # Allow disabling of scoring by default
-        return
-
-    if body.get("task") != "InferenceTask":
-        logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
-        return
-    url = f"http://{shared_settings.VALIDATOR_API}/scoring"
-    payload = {"body": body, "chunks": chunks, "uid": uid}
-    try:
-        timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            response = await client.post(
-                url, json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"}
-            )
-            if response.status_code == 200:
-                logger.info(f"Forwarding response completed with status {response.status_code}")
-            else:
-                logger.exception(
-                    f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
-                )
-    except Exception as e:
-        logger.error(f"Tried to forward response to {url} with payload {payload}")
-        logger.exception(f"Error while forwarding response: {e}")
-
-
 @router.post("/v1/chat/completions")
-async def chat_completion(request: Request):  # , cbackground_tasks: BackgroundTasks):
+async def completions(request: Request):
+    """Main endpoint that handles both regular and mixture of miners chat completion."""
     try:
         body = await request.json()
         body["seed"] = int(body.get("seed") or random.randint(0, 1000000))
-        STREAM = body.get("stream") or False
-        logger.debug(f"Streaming: {STREAM}")
-        # Get random miner from top 100 incentive.
-        uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))
-        # uid = get_available_miner(task=body.get("task"), model=body.get("model"))
-        if uid is None:
-            logger.error("No available miner found")
-            raise HTTPException(status_code=503, detail="No available miner found")
-        logger.debug(f"Querying uid {uid}")
-
-        collected_chunks: list[str] = []
-
-        # Create a wrapper for the streaming response
-        async def stream_with_error_handling():
-            chunks_received = False
-            try:
-                async for chunk in response:
-                    chunks_received = True
-                    collected_chunks.append(chunk.choices[0].delta.content)
-                    yield f"data: {json.dumps(chunk.model_dump())}\n\n"
-
-                if not chunks_received:
-                    logger.error("Stream is empty: No chunks were received")
-                    yield 'data: {"error": "502 - Response is empty"}\n\n'
-                yield "data: [DONE]\n\n"
-
-                # Once the stream is done, forward the collected chunks
-                asyncio.create_task(forward_response(uid=uid, body=body, chunks=collected_chunks))
-                # background_tasks.add_task(forward_response, uid=uid, body=body, chunks=collected_chunks)
-            except asyncio.CancelledError:
-                logger.info("Client disconnected, streaming cancelled")
-                raise
-            except Exception as e:
-                logger.exception(f"Error during streaming: {e}")
-                yield 'data: {"error": "Internal server Error"}\n\n'
-
-        logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}")
-        response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM)
-
-        if STREAM:
-            return StreamingResponse(
-                stream_with_error_handling(),
-                media_type="text/event-stream",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                },
-            )
+        # Choose between regular completion and mixture of miners.
+        if body.get("mixture", False):
+            return await mixture_of_miners(body)
         else:
-            logger.info("Forwarding response to scoring...")
-            asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1]))
-            return response[0]
+            return await chat_completion(body)
 
     except Exception as e:
-        logger.exception(f"Error setting up streaming: {e}")
+        logger.exception(f"Error in chat completion: {e}")
         return StreamingResponse(content="Internal Server Error", status_code=500)
88 changes: 88 additions & 0 deletions validator_api/mixture_of_miners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import asyncio
import copy
import random

from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from shared.uids import get_uids
from validator_api.chat_completion import chat_completion, get_response_from_miner

DEFAULT_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query.
Your task is to synthesize these responses into a single, high-quality and concise response.
It is crucial to follow the provided instructions or examples in the given prompt if any, and ensure the answer is in correct and expected format.
Critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect.
Your response should not simply replicate the given answers but should offer a refined and accurate reply to the instruction.
Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models:"""

TASK_SYSTEM_PROMPT = {
    None: DEFAULT_SYSTEM_PROMPT,
    # Add more task-specific system prompts here.
}

NUM_MIXTURE_MINERS = 5
TOP_INCENTIVE_POOL = 100


async def get_miner_response(body: dict, uid: str) -> tuple | None:
    """Get response from a single miner with error handling."""
    try:
        return await get_response_from_miner(body, uid)
    except Exception as e:
        logger.error(f"Error getting response from miner {uid}: {e}")
        return None


async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:
    """Handle chat completion with mixture of miners approach.

    Based on Mixture-of-Agents Enhances Large Language Model Capabilities, 2024, Wang et al.:
    https://arxiv.org/abs/2406.04692

    Args:
        body: Query parameters:
            messages: User prompt.
            stream: If True, stream the response.
            model: Optional model used for inference, SharedSettings.LLM_MODEL is used by default.
            task: Optional task, see prompting/tasks/task_registry.py, InferenceTask is used by default.
    """
    body_first_step = copy.deepcopy(body)
    body_first_step["stream"] = False

    # Get multiple miners.
    miner_uids = get_uids(sampling_mode="top_incentive", k=NUM_MIXTURE_MINERS)
    if len(miner_uids) == 0:
        raise HTTPException(status_code=503, detail="No available miners found")

    # Concurrently collect responses from all miners.
    miner_tasks = [get_miner_response(body_first_step, uid) for uid in miner_uids]
    responses = await asyncio.gather(*miner_tasks)

    # Filter out None responses (failed requests).
    valid_responses = [r for r in responses if r is not None]

    if not valid_responses:
        raise HTTPException(status_code=503, detail="Failed to get responses from miners")

    # Extract completions from the responses.
    completions = [response[1][0] for response in valid_responses if response and len(response) > 1]

    task_name = body.get("task")
    system_prompt = TASK_SYSTEM_PROMPT.get(task_name, DEFAULT_SYSTEM_PROMPT)

    # Aggregate responses into one system prompt.
    agg_system_prompt = system_prompt + "\n" + "\n".join([f"{i+1}. {comp}" for i, comp in enumerate(completions)])

    # Prepare new messages with the aggregated system prompt.
    new_messages = [{"role": "system", "content": agg_system_prompt}]
    new_messages.extend([msg for msg in body["messages"] if msg["role"] != "system"])

    # Update the body with the new messages.
    final_body = copy.deepcopy(body)
    final_body["messages"] = new_messages

    # Get final response using a random top miner.
    final_uid = random.choice(get_uids(sampling_mode="top_incentive", k=TOP_INCENTIVE_POOL))
    return await chat_completion(final_body, final_uid)
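For intuition, the aggregation step above folds the first-round completions into a single numbered system prompt for the final miner. A toy illustration (the completions are made up):

```python
# Toy illustration of the aggregation above; the completions are made up.
completions = ["Paris.", "The capital of France is Paris."]
agg_system_prompt = DEFAULT_SYSTEM_PROMPT + "\n" + "\n".join(
    f"{i+1}. {comp}" for i, comp in enumerate(completions)
)
# The final miner then sees:
#   ...Responses from models:
#   1. Paris.
#   2. The capital of France is Paris.
```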
