Add Mixture-of-Miners endpoint #472

Merged on Dec 20, 2024 (51 commits)
Commits
4457d4e
Initial upload
Nov 19, 2024
62ae30c
Get everything working
Nov 19, 2024
8e25f33
SN1-331: Adding initial draft for endpoints
Hollyqui Nov 20, 2024
ee5351f
SN1-331: Adding API keys
Hollyqui Nov 22, 2024
4ebe74d
Adding test miner ids
Hollyqui Nov 22, 2024
20e5376
Adding tasks to scoring queue
Hollyqui Nov 24, 2024
37a8874
Enabling non-streaming response + bug fixes
Hollyqui Nov 25, 2024
a09cd9b
Making model loading non-blocking
Hollyqui Nov 25, 2024
566dc77
Protecting endpoints with API key
Hollyqui Nov 26, 2024
9d810ce
Improving error messages + improving API key saving
Hollyqui Nov 26, 2024
685290a
Signing epistula properly for recipient
Hollyqui Nov 26, 2024
7296fe7
Passing task type
Hollyqui Nov 26, 2024
622427a
Merge branch 'main' into kalei/api-working-branch
bkb2135 Nov 26, 2024
120a90a
Move streaming of miners into query_miners function
bkb2135 Nov 26, 2024
7200833
WIP: Add system prompt
dbobrenko Nov 27, 2024
0b77548
Merge branch 'SN1-331-create-api-on-validator' into feature/SN1-329-m…
dbobrenko Nov 27, 2024
6cf7aa5
Use query_miners in api
bkb2135 Nov 27, 2024
fed8964
Merge branch 'kalei/api-working-branch' into feature/SN1-329-moa-endp…
dbobrenko Nov 27, 2024
07620fd
Fix syntax errors
bkb2135 Nov 27, 2024
21c2236
Manually dump models
bkb2135 Nov 27, 2024
ba900a2
Use autoawq 0.2.0
richwardle Nov 27, 2024
26d1db1
Support delta or message in sn19 response
richwardle Nov 27, 2024
0389d1b
Remove Unecessary Line
richwardle Nov 27, 2024
351c14c
Formatting
bkb2135 Nov 27, 2024
b132d13
Merge pull request #467 from macrocosm-os/hotfix/support-multiple-sn1…
bkb2135 Nov 27, 2024
0f6bfd7
Update pyproject.toml
bkb2135 Nov 27, 2024
fd186a9
Merge remote-tracking branch 'origin/release/v2.13.2' into kalei/api-…
bkb2135 Nov 27, 2024
9a037cc
Add test_api to scripts
bkb2135 Nov 28, 2024
6ff5600
SN1-328: Finish MoM, add sampling params
dbobrenko Nov 28, 2024
72fed85
SN1-327: Clean up, link system prompt ticket
dbobrenko Nov 28, 2024
fd476ce
Fix syntax
dbobrenko Nov 28, 2024
0b37518
Update api_keys.json
bkb2135 Nov 28, 2024
9485e56
Update prompting/api/gpt_endpoints/api.py
bkb2135 Nov 28, 2024
f176821
Add keys example
dbobrenko Nov 28, 2024
1bf3996
Push Working Changes
richwardle Nov 28, 2024
6bab37e
Add Optional Api Deployment
bkb2135 Nov 28, 2024
bb115cf
Fixing formatting
Hollyqui Dec 2, 2024
09e4103
sort: fix import formatting
richwardle Dec 2, 2024
2852e41
Merge with API branch
dbobrenko Dec 2, 2024
e2965fb
Fix synapse system prompt
dbobrenko Dec 2, 2024
87d8430
Merge branch 'staging' into feature/SN1-329-moa-endpoint
dbobrenko Dec 18, 2024
a7c53c8
WIP: Move MoA to new API
dbobrenko Dec 18, 2024
871fa48
Merge branch 'staging' into feature/SN1-329-moa-endpoint
dbobrenko Dec 18, 2024
34949f1
Add MoA code
dbobrenko Dec 18, 2024
edf1bfe
WIP: Finish MoA
dbobrenko Dec 19, 2024
f82105a
Merge branch 'staging' into feature/SN1-329-moa-endpoint
dbobrenko Dec 19, 2024
27b035a
Finish implementation
dbobrenko Dec 19, 2024
f75d28c
Clean up code
dbobrenko Dec 19, 2024
bc48fb9
Run pre-commit hook
dbobrenko Dec 19, 2024
e53a5db
Merge branch 'staging' into feature/SN1-329-moa-endpoint
dbobrenko Dec 20, 2024
1445e1f
Merge with staging
dbobrenko Dec 20, 2024
2 changes: 1 addition & 1 deletion prompting/tasks/base_task.py
@@ -90,7 +90,7 @@ def generate_query(
"""Generates a query to be used for generating the challenge"""
logger.info("🤖 Generating query...")
llm_messages = [LLMMessage(role="system", content=self.query_system_prompt)] if self.query_system_prompt else []
-        llm_messages += [LLMMessage(role="user", content=message) for message in messages]
+        llm_messages.extend([LLMMessage(role="user", content=message) for message in messages])

self.query = LLMWrapper.chat_complete(messages=LLMMessages(*llm_messages))

104 changes: 104 additions & 0 deletions validator_api/chat_completion.py
@@ -0,0 +1,104 @@
import asyncio
import json
import random
from typing import AsyncGenerator

import httpx
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from shared.epistula import make_openai_query
from shared.settings import shared_settings
from shared.uids import get_uids


async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
if not shared_settings.SCORE_ORGANICS: # Allow disabling of scoring by default
return

# if body.get("task") != "InferenceTask":
# logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
# return
url = f"http://{shared_settings.VALIDATOR_API}/scoring"
payload = {"body": body, "chunks": chunks, "uid": uid}
# headers = {
# "Authorization": f"Bearer {shared_settings.SCORING_KEY}", #Add API key in Authorization header
# "Content-Type": "application/json",
# }
try:
timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=payload) # , headers=headers)
if response.status_code == 200:
logger.info(f"Forwarding response completed with status {response.status_code}")

else:
logger.exception(
f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
)

except Exception as e:
logger.error(f"Tried to forward response to {url} with payload {payload}")
logger.exception(f"Error while forwarding response: {e}")


async def stream_response(
response, collected_chunks: list[str], body: dict[str, any], uid: int
) -> AsyncGenerator[str, None]:
chunks_received = False
try:
async for chunk in response:
chunks_received = True
collected_chunks.append(chunk.choices[0].delta.content)
yield f"data: {json.dumps(chunk.model_dump())}\n\n"

if not chunks_received:
logger.error("Stream is empty: No chunks were received")
yield 'data: {"error": "502 - Response is empty"}\n\n'
yield "data: [DONE]\n\n"

# Forward the collected chunks after streaming is complete
asyncio.create_task(forward_response(uid=uid, body=body, chunks=collected_chunks))
except asyncio.CancelledError:
logger.info("Client disconnected, streaming cancelled")
raise
except Exception as e:
logger.exception(f"Error during streaming: {e}")
yield 'data: {"error": "Internal server Error"}\n\n'


async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
"""Handle regular chat completion without mixture of miners."""
if uid is None:
uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))

if uid is None:
logger.error("No available miner found")
raise HTTPException(status_code=503, detail="No available miner found")

logger.debug(f"Querying uid {uid}")
STREAM = body.get("stream", False)

collected_chunks: list[str] = []

logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}")
response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM)

if STREAM:
return StreamingResponse(
stream_response(response, collected_chunks, body, uid),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
else:
asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1]))
return response[0]


async def get_response_from_miner(body: dict[str, any], uid: int) -> tuple:
"""Get response from a single miner."""
return await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=False)
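Side note (not part of the diff): a minimal client-side sketch of how the streaming path above could be consumed. The endpoint emits OpenAI-style chunks as "data: ..." server-sent events and closes with "data: [DONE]". The base URL is an assumption, and the sketch omits whatever API-key header the deployed validator API requires.

import asyncio
import json

import httpx

API_URL = "http://localhost:8000/v1/chat/completions"  # hypothetical address; adjust to the deployed validator API


async def main():
    body = {
        "messages": [{"role": "user", "content": "Explain the metagraph in one paragraph."}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        # Stream the SSE response produced by stream_response() above.
        async with client.stream("POST", API_URL, json=body) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line.removeprefix("data: ")
                if payload == "[DONE]":
                    break
                chunk = json.loads(payload)
                if "choices" not in chunk:  # error payloads emitted by the server
                    print(chunk)
                    break
                # Each payload is a model_dump() of an OpenAI-style ChatCompletionChunk.
                print(chunk["choices"][0]["delta"].get("content") or "", end="", flush=True)


asyncio.run(main())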
98 changes: 10 additions & 88 deletions validator_api/gpt_endpoints.py
@@ -1,106 +1,28 @@
-import asyncio
-import json
import random

-import httpx
-from fastapi import APIRouter, HTTPException, Request
+from fastapi import APIRouter, Request
from loguru import logger
from starlette.responses import StreamingResponse

-from shared.epistula import make_openai_query
-from shared.settings import shared_settings
-from shared.uids import get_uids
+from validator_api.chat_completion import chat_completion
+from validator_api.mixture_of_miners import mixture_of_miners

router = APIRouter()


-async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
-    if not shared_settings.SCORE_ORGANICS: # Allow disabling of scoring by default
-        return
-
-    # if body.get("task") != "InferenceTask":
-    #     logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
-    #     return
-    url = f"http://{shared_settings.VALIDATOR_API}/scoring"
-    payload = {"body": body, "chunks": chunks, "uid": uid}
-    # headers = {
-    #     "Authorization": f"Bearer {shared_settings.SCORING_KEY}", #Add API key in Authorization header
-    #     "Content-Type": "application/json",
-    # }
-    try:
-        timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            response = await client.post(url, json=payload) # , headers=headers)
-            if response.status_code == 200:
-                logger.info(f"Forwarding response completed with status {response.status_code}")
-
-            else:
-                logger.exception(
-                    f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
-                )
-
-    except Exception as e:
-        logger.error(f"Tried to forward response to {url} with payload {payload}")
-        logger.exception(f"Error while forwarding response: {e}")
-

@router.post("/v1/chat/completions")
-async def chat_completion(request: Request): # , cbackground_tasks: BackgroundTasks):
+async def completions(request: Request):
+    """Main endpoint that handles both regular and mixture of miners chat completion."""
    try:
        body = await request.json()
        body["seed"] = int(body.get("seed") or random.randint(0, 1000000))
-        STREAM = body.get("stream") or False
-        logger.debug(f"Streaming: {STREAM}")
-        # Get random miner from top 100 incentive.
-        uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))
-        # uid = get_available_miner(task=body.get("task"), model=body.get("model"))
-        if uid is None:
-            logger.error("No available miner found")
-            raise HTTPException(status_code=503, detail="No available miner found")
-        logger.debug(f"Querying uid {uid}")
-
-        collected_chunks: list[str] = []
-
-        # Create a wrapper for the streaming response
-        async def stream_with_error_handling():
-            chunks_received = False
-            try:
-                async for chunk in response:
-                    chunks_received = True
-                    collected_chunks.append(chunk.choices[0].delta.content)
-                    yield f"data: {json.dumps(chunk.model_dump())}\n\n"
-
-                if not chunks_received:
-                    logger.error("Stream is empty: No chunks were received")
-                    yield 'data: {"error": "502 - Response is empty"}\n\n'
-                yield "data: [DONE]\n\n"
-
-                # Once the stream is done, forward the collected chunks
-                asyncio.create_task(forward_response(uid=uid, body=body, chunks=collected_chunks))
-                # background_tasks.add_task(forward_response, uid=uid, body=body, chunks=collected_chunks)
-            except asyncio.CancelledError:
-                logger.info("Client disconnected, streaming cancelled")
-                raise
-            except Exception as e:
-                logger.exception(f"Error during streaming: {e}")
-                yield 'data: {"error": "Internal server Error"}\n\n'
-
-        logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}")
-        response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM)
-
-        if STREAM:
-            return StreamingResponse(
-                stream_with_error_handling(),
-                media_type="text/event-stream",
-                headers={
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                },
-            )
+        # Choose between regular completion and mixture of miners.
+        if body.get("mixture", False):
+            return await mixture_of_miners(body)
        else:
-            asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1]))
-            return response[0]
+            return await chat_completion(body)

    except Exception as e:
-        logger.exception(f"Error setting up streaming: {e}")
+        logger.exception(f"Error in chat completion: {e}")
        return StreamingResponse(content="Internal Server Error", status_code=500)
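Usage note (assumed base URL, not part of the diff): the same endpoint now serves both modes, and the optional "mixture" flag in the request body is the only switch between them.

import httpx

body = {
    "messages": [{"role": "user", "content": "Summarise the attention mechanism."}],
    "stream": False,
    "mixture": True,  # route through mixture_of_miners(); omit or set False for the single-miner path
    "seed": 42,       # optional; completions() fills in a random seed when omitted
}
# Hypothetical deployment address; the validator API host and port depend on local settings.
response = httpx.post("http://localhost:8000/v1/chat/completions", json=body, timeout=120)
print(response.json())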
88 changes: 88 additions & 0 deletions validator_api/mixture_of_miners.py
@@ -0,0 +1,88 @@
import asyncio
import copy
import random

from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from shared.uids import get_uids
from validator_api.chat_completion import chat_completion, get_response_from_miner

DEFAULT_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query.
Your task is to synthesize these responses into a single, high-quality and concise response.
It is crucial to follow the provided instructions or examples in the given prompt if any, and ensure the answer is in correct and expected format.
Critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect.
Your response should not simply replicate the given answers but should offer a refined and accurate reply to the instruction.
Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models:"""

TASK_SYSTEM_PROMPT = {
None: DEFAULT_SYSTEM_PROMPT,
# Add more task-specific system prompts here.
}

NUM_MIXTURE_MINERS = 5
TOP_INCENTIVE_POOL = 100


async def get_miner_response(body: dict, uid: str) -> tuple | None:
"""Get response from a single miner with error handling."""
try:
return await get_response_from_miner(body, uid)
except Exception as e:
logger.error(f"Error getting response from miner {uid}: {e}")
return None


async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:
"""Handle chat completion with mixture of miners approach.

Based on Mixture-of-Agents Enhances Large Language Model Capabilities, 2024, Wang et al.:
https://arxiv.org/abs/2406.04692

Args:
body: Query parameters:
messages: User prompt.
stream: If True, stream the response.
model: Optional model used for inference, SharedSettings.LLM_MODEL is used by default.
task: Optional task, see prompting/tasks/task_registry.py, InferenceTask is used by default.
"""
body_first_step = copy.deepcopy(body)
body_first_step["stream"] = False

# Get multiple miners
miner_uids = get_uids(sampling_mode="top_incentive", k=NUM_MIXTURE_MINERS)
if len(miner_uids) == 0:
raise HTTPException(status_code=503, detail="No available miners found")

# Concurrently collect responses from all miners.
miner_tasks = [get_miner_response(body_first_step, uid) for uid in miner_uids]
responses = await asyncio.gather(*miner_tasks)

# Filter out None responses (failed requests).
valid_responses = [r for r in responses if r is not None]

if not valid_responses:
raise HTTPException(status_code=503, detail="Failed to get responses from miners")

# Extract completions from the responses.
completions = [response[1][0] for response in valid_responses if response and len(response) > 1]

task_name = body.get("task")
system_prompt = TASK_SYSTEM_PROMPT.get(task_name, DEFAULT_SYSTEM_PROMPT)

# Aggregate responses into one system prompt.
agg_system_prompt = system_prompt + "\n" + "\n".join([f"{i+1}. {comp}" for i, comp in enumerate(completions)])

# Prepare new messages with the aggregated system prompt.
new_messages = [{"role": "system", "content": agg_system_prompt}]
new_messages.extend([msg for msg in body["messages"] if msg["role"] != "system"])

# Update the body with the new messages.
final_body = copy.deepcopy(body)
final_body["messages"] = new_messages

# Get final response using a random top miner.
final_uid = random.choice(get_uids(sampling_mode="top_incentive", k=TOP_INCENTIVE_POOL))
return await chat_completion(final_body, final_uid)
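A small worked illustration of the aggregation step above, using dummy completions rather than real miner output: the miner answers are numbered and appended to the task's system prompt, and the original conversation is then replayed behind that synthesized system message before the final single-miner call.

from validator_api.mixture_of_miners import DEFAULT_SYSTEM_PROMPT  # assumes the repo is on PYTHONPATH

body = {"messages": [{"role": "user", "content": "What is the capital of France?"}], "stream": True}
completions = [
    "Paris is the capital of France.",
    "The capital of France is Paris, on the Seine.",
]

# Mirrors mixture_of_miners(): fold the numbered completions into the system prompt.
agg_system_prompt = DEFAULT_SYSTEM_PROMPT + "\n" + "\n".join(f"{i+1}. {comp}" for i, comp in enumerate(completions))
new_messages = [{"role": "system", "content": agg_system_prompt}]
new_messages.extend(m for m in body["messages"] if m["role"] != "system")
# new_messages now holds the aggregated system prompt followed by the user's original question;
# mixture_of_miners() sends this to a single top-incentive miner for the final answer.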