Skip to content

Commit

Permalink
SN1-361: Query multiple miners to return fastest response + Scoring b…
Browse files Browse the repository at this point in the history
…ug fix (#521)
  • Loading branch information
Hollyqui authored Jan 3, 2025
1 parent dbeaa29 commit ebb4f85
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 40 deletions.
7 changes: 4 additions & 3 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ async def spawn_loops(task_queue, scoring_queue, reward_events):
asyncio.run(spawn_loops(task_queue, scoring_queue, reward_events))


def start_api():
def start_api(scoring_queue, reward_events):
async def start():
from prompting.api.api import start_scoring_api # noqa: F401

await start_scoring_api()
await start_scoring_api(scoring_queue, reward_events)

while True:
await asyncio.sleep(10)
logger.debug("Running API...")
Expand Down Expand Up @@ -125,7 +126,7 @@ async def main():

if shared_settings.DEPLOY_SCORING_API:
# Use multiprocessing to bypass API blocking issue
api_process = mp.Process(target=start_api, name="API_Process")
api_process = mp.Process(target=start_api, args=(scoring_queue, reward_events), name="API_Process")
api_process.start()
processes.append(api_process)

Expand Down
5 changes: 4 additions & 1 deletion prompting/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from prompting.api.miner_availabilities.api import router as miner_availabilities_router
from prompting.api.scoring.api import router as scoring_router
from prompting.rewards.scoring import task_scorer
from shared.settings import shared_settings

app = FastAPI()
Expand All @@ -17,7 +18,9 @@ def health():
return {"status": "healthy"}


async def start_scoring_api():
async def start_scoring_api(scoring_queue, reward_events):
task_scorer.scoring_queue = scoring_queue
task_scorer.reward_events = reward_events
logger.info(f"Starting Scoring API on https://0.0.0.0:{shared_settings.SCORING_API_PORT}")
uvicorn.run(
"prompting.api.api:app", host="0.0.0.0", port=shared_settings.SCORING_API_PORT, loop="asyncio", reload=False
Expand Down
2 changes: 1 addition & 1 deletion prompting/weight_setting/weight_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def set_weights(
"weights": processed_weights.flatten(),
"raw_weights": str(list(weights.flatten())),
"averaged_weights": str(list(averaged_weights.flatten())),
"block": ttl_get_block(),
"block": ttl_get_block(subtensor=subtensor),
}
)
step_filename = "weights.csv"
Expand Down
1 change: 0 additions & 1 deletion scripts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from loguru import logger

from shared.epistula import query_miners
from shared.settings import shared_settings

"""
This has assumed you have:
Expand Down
2 changes: 1 addition & 1 deletion shared/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def ttl_get_block(subtensor: bt.Subtensor | None = None) -> int:
efficiently reduces the workload on the blockchain interface.
Example:
current_block = ttl_get_block(self)
current_block = ttl_get_block(subtensor=subtensor)
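Note: `subtensor` is the optional bt.Subtensor argument from the signature; in the earlier `ttl_get_block(self)` form, self was the miner or validator instance.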
"""
Expand Down
151 changes: 118 additions & 33 deletions validator_api/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import random
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional

import httpx
from fastapi import HTTPException
Expand All @@ -14,14 +14,15 @@


async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
uid = int(uid) # sometimes uid is type np.uint64
logger.info(f"Forwarding response to scoring with body: {body}")
if not shared_settings.SCORE_ORGANICS: # Allow disabling of scoring by default
uid = int(uid)
logger.info(f"Forwarding response from uid {uid} to scoring with body: {body} and chunks: {chunks}")
if not shared_settings.SCORE_ORGANICS:
return

if body.get("task") != "InferenceTask":
logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
return

url = f"http://{shared_settings.VALIDATOR_API}/scoring"
payload = {"body": body, "chunks": chunks, "uid": uid}
try:
Expand All @@ -32,73 +33,157 @@ async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
)
if response.status_code == 200:
logger.info(f"Forwarding response completed with status {response.status_code}")

else:
logger.exception(
f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
)

except Exception as e:
logger.error(f"Tried to forward response to {url} with payload {payload}")
logger.exception(f"Error while forwarding response: {e}")


async def stream_response(
response, collected_chunks: list[str], body: dict[str, any], uid: int
async def stream_from_first_response(
    responses: List[asyncio.Task], collected_chunks_list: List[List[str]], body: dict[str, any], uids: List[int]
) -> AsyncGenerator[str, None]:
    """Stream the fastest valid miner response to the client as server-sent events.

    Args:
        responses: One in-flight query task per miner. ``responses[i]`` must belong to
            ``uids[i]`` and fill ``collected_chunks_list[i]``.
        collected_chunks_list: One (initially empty) chunk list per miner; the streamed
            miner's chunks are appended to *its own* list.
        body: The original request body, passed on for scoring.
        uids: Miner uids, positionally aligned with ``responses``.

    Yields:
        ``data: ...`` lines: one per chunk of the winning response, an error payload when
        no miner answered (or the stream was empty / failed), and a final ``[DONE]`` marker
        after a stream that started.

    Raises:
        asyncio.CancelledError: Re-raised after cancelling every miner task, which is how a
            client disconnect surfaces here.
    """
    # Snapshot the tasks: asyncio.wait() hands back unordered sets, so the position of each
    # task is remembered explicitly to keep uids and chunk lists attributed correctly.
    tasks = list(responses)
    position = {task: idx for idx, task in enumerate(tasks)}
    outstanding = set(tasks)
    rejected: List[int] = []  # miners that finished without a usable response
    first_valid_response = None
    winner = None
    try:
        # Wait for the first valid response
        while outstanding and first_valid_response is None:
            # Finished tasks leave ``outstanding`` here, so a bad response can never be
            # waited on again (which would spin forever).
            done, outstanding = await asyncio.wait(outstanding, return_when=asyncio.FIRST_COMPLETED)

            for task in sorted(done, key=position.__getitem__):
                try:
                    response = task.result()
                except Exception as e:
                    logger.error(f"Error in miner response: {e}")
                    rejected.append(position[task])
                    continue
                if response and not isinstance(response, Exception):
                    first_valid_response = response
                    winner = position[task]
                    break
                rejected.append(position[task])

        if first_valid_response is None:
            logger.error("No valid response received from any miner")
            yield 'data: {"error": "502 - No valid response received"}\n\n'
            return

        # Stream the first valid response
        winner_chunks = collected_chunks_list[winner]
        chunks_received = False
        async for chunk in first_valid_response:
            chunks_received = True
            winner_chunks.append(chunk.choices[0].delta.content)
            yield f"data: {json.dumps(chunk.model_dump())}\n\n"

        if not chunks_received:
            logger.error("Stream is empty: No chunks were received")
            yield 'data: {"error": "502 - Response is empty"}\n\n'

        yield "data: [DONE]\n\n"

        # Continue collecting remaining responses in background for scoring.
        # collect_remaining_responses() expects index 0 to be the streamed miner and index
        # i + 1 to match the i-th gathered result, so everything is reordered to that
        # layout. Miners already known to have failed go last, without a gathered result,
        # so they are still forwarded (with empty chunks) but not processed twice.
        awaiting = [idx for idx in range(len(tasks)) if idx != winner and idx not in rejected]
        ordered = [winner, *awaiting, *rejected]
        remaining = asyncio.gather(*(tasks[idx] for idx in awaiting), return_exceptions=True)
        asyncio.create_task(
            collect_remaining_responses(
                remaining,
                [collected_chunks_list[idx] for idx in ordered],
                body,
                [uids[idx] for idx in ordered],
            )
        )

    except asyncio.CancelledError:
        logger.info("Client disconnected, streaming cancelled")
        # Cancelling an already finished task is a no-op, so cancel them all.
        for task in tasks:
            task.cancel()
        raise
    except Exception as e:
        logger.exception(f"Error during streaming: {e}")
        yield 'data: {"error": "Internal server Error"}\n\n'


async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse:
"""Handle regular chat completion without mixture of miners."""
if uid is None:
uid = random.choice(get_uids(sampling_mode="top_incentive", k=100))
async def collect_remaining_responses(
    remaining: asyncio.Task, collected_chunks_list: List[List[str]], body: dict[str, any], uids: List[int]
):
    """Collect remaining responses for scoring without blocking the main response.

    Args:
        remaining: Awaitable resolving to the list of not-yet-consumed miner responses
            (exceptions are returned as list items, not raised).
        collected_chunks_list: Chunk lists aligned with ``uids``. Index 0 belongs to the
            response that was already streamed; index ``i + 1`` receives the chunks of
            ``remaining``'s i-th result.
        body: The original request body, forwarded for scoring.
        uids: Miner uids aligned with ``collected_chunks_list``.

    Every ``(uid, chunks)`` pair is forwarded for scoring once collection has finished;
    miners whose response failed are forwarded with whatever chunks were gathered so far.
    """
    try:
        responses = await remaining
        logger.debug(f"responses to forward: {responses}")
        for offset, response in enumerate(responses, start=1):
            uid = uids[offset]
            # gather(return_exceptions=True) may also return CancelledError, which derives
            # from BaseException rather than Exception.
            if isinstance(response, BaseException):
                logger.error(f"Error collecting response from uid {uid}: {response}")
                continue

            try:
                async for chunk in response:
                    collected_chunks_list[offset].append(chunk.choices[0].delta.content)
            except Exception as e:
                # One unusable or broken stream must not stop the other miners from
                # being collected and forwarded.
                logger.error(f"Error collecting response from uid {uid}: {e}")

        for uid, chunks in zip(uids, collected_chunks_list):
            # Forward for scoring
            asyncio.create_task(forward_response(uid, body, chunks))

    except Exception as e:
        logger.exception(f"Error collecting remaining responses: {e}")

logger.debug(f"Querying uid {uid}")
STREAM = body.get("stream", False)

collected_chunks: list[str] = []
async def get_response_from_miner(body: dict[str, any], uid: int) -> tuple:
    """Send ``body`` to the miner ``uid`` as a non-streaming query.

    Calls ``make_openai_query`` with the metagraph and wallet taken from
    ``shared_settings`` and ``stream=False``, and hands its result back unchanged.
    """
    miner_reply = await make_openai_query(
        shared_settings.METAGRAPH,
        shared_settings.WALLET,
        body,
        uid,
        stream=False,
    )
    return miner_reply

logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}")
response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM)

async def chat_completion(
    body: dict[str, any], uids: int | List[int] | None = None, num_miners: int = 3
) -> tuple | StreamingResponse:
    """Handle chat completion with multiple miners in parallel.

    Args:
        body: The request body; ``body["stream"]`` selects streaming mode.
        uids: Miner uid(s) to query. ``None`` samples ``num_miners`` uids at random from
            ``get_uids(sampling_mode="top_incentive", k=100)``; a single uid or a sequence
            of uids is used as given, truncated to ``num_miners`` entries.
        num_miners: Maximum number of miners queried in parallel.

    Returns:
        A ``StreamingResponse`` when streaming is requested, otherwise the first element of
        the first valid miner response.

    Raises:
        HTTPException: 503 when no miners are available to sample from, 502 when no
            queried miner returned a valid (non-streaming) response.
    """
    # Get multiple UIDs if none specified
    if uids is None:
        uids = list(get_uids(sampling_mode="top_incentive", k=100))
        # len() instead of truthiness: the comment in the original notes ``not uids`` raised.
        if len(uids) == 0:
            logger.error("No available miners found")
            raise HTTPException(status_code=503, detail="No available miners found")
        selected_uids = random.sample(uids, min(num_miners, len(uids)))
    else:
        # Accept both a single uid and a sequence of uids; a scalar is not iterable.
        try:
            candidates = list(uids)
        except TypeError:
            candidates = [uids]
        selected_uids = candidates[:num_miners]

    logger.debug(f"Querying uids {selected_uids}")
    STREAM = body.get("stream", False)

    # Initialize chunks collection for each miner
    collected_chunks_list = [[] for _ in selected_uids]

    if STREAM:
        # Create tasks for all miners (same order as selected_uids / collected_chunks_list)
        response_tasks = [
            asyncio.create_task(
                make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=True)
            )
            for uid in selected_uids
        ]

        return StreamingResponse(
            stream_from_first_response(response_tasks, collected_chunks_list, body, selected_uids),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

    # For non-streaming requests, wait for first valid response.
    # Each task is keyed to its uid so scoring never depends on completion order.
    uid_by_task = {asyncio.create_task(get_response_from_miner(body, uid)): uid for uid in selected_uids}
    pending = set(uid_by_task)

    first_valid_response = None
    collected_responses = []  # (uid, response) pairs that completed with a valid result

    while pending and first_valid_response is None:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

        for task in done:
            try:
                response = task.result()
            except Exception as e:
                logger.error(f"Error in miner response: {e}")
                continue
            if response and isinstance(response, tuple):
                if first_valid_response is None:
                    first_valid_response = response
                collected_responses.append((uid_by_task[task], response))

    if first_valid_response is None:
        raise HTTPException(status_code=502, detail="No valid response received")

    # Forward all collected responses for scoring in the background
    for uid, response in collected_responses:
        asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1]))

    # Miners that are still running are scored once they finish, mirroring the streaming
    # path, instead of being left un-awaited.
    for task in pending:
        asyncio.create_task(_forward_late_response(task, uid_by_task[task], body))

    return first_valid_response[0]  # Return only the response object, not the chunks


async def _forward_late_response(task: asyncio.Task, uid: int, body: dict[str, any]) -> None:
    """Wait for a miner task that lost the race and forward its response for scoring."""
    try:
        response = await task
    except Exception as e:
        logger.error(f"Error in miner response: {e}")
        return
    if response and isinstance(response, tuple):
        await forward_response(uid=uid, body=body, chunks=response[1])

0 comments on commit ebb4f85

Please sign in to comment.