Adding web retrieval endpoint (#528)

Hollyqui authored Jan 3, 2025
1 parent 667e264 commit 4654d3a
Showing 8 changed files with 146 additions and 54 deletions.
72 changes: 53 additions & 19 deletions prompting/api/scoring/api.py
@@ -4,9 +4,11 @@
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from loguru import logger

from prompting.datasets.random_website import DDGDatasetEntry
from prompting.llms.model_zoo import ModelZoo
from prompting.rewards.scoring import task_scorer
from prompting.tasks.inference import InferenceTask
from prompting.tasks.web_retrieval import WebRetrievalTask
from shared.base import DatasetEntry
from shared.dendrite import DendriteResponseEvent
from shared.epistula import SynapseStreamResult
@@ -37,22 +39,54 @@ async def score_response(request: Request, api_key_data: dict = Depends(validate
uid = int(payload.get("uid"))
chunks = payload.get("chunks")
llm_model = ModelZoo.get_model_by_id(model) if (model := body.get("model")) else None
task_scorer.add_to_queue(
task=InferenceTask(
messages=[msg["content"] for msg in body.get("messages")],
llm_model=llm_model,
llm_model_id=body.get("model"),
seed=int(body.get("seed", 0)),
sampling_params=body.get("sampling_params", {}),
),
response=DendriteResponseEvent(
uids=[uid],
stream_results=[SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None])],
timeout=shared_settings.NEURON_TIMEOUT,
),
dataset_entry=DatasetEntry(),
block=shared_settings.METAGRAPH.block,
step=-1,
task_id=str(uuid.uuid4()),
)
logger.info("Organic tas appended to scoring queue")
task = body.get("task")
if task == "InferenceTask":
logger.info(f"Received Organic InferenceTask with body: {body}")
task_scorer.add_to_queue(
task=InferenceTask(
messages=[msg["content"] for msg in body.get("messages")],
llm_model=llm_model,
llm_model_id=body.get("model"),
seed=int(body.get("seed", 0)),
sampling_params=body.get("sampling_params", {}),
),
response=DendriteResponseEvent(
uids=[uid],
stream_results=[
SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None])
],
timeout=shared_settings.NEURON_TIMEOUT,
),
dataset_entry=DatasetEntry(),
block=shared_settings.METAGRAPH.block,
step=-1,
task_id=str(uuid.uuid4()),
)
elif task == "WebRetrievalTask":
logger.info(f"Received Organic WebRetrievalTask with body: {body}")
try:
search_term = body.get("messages")[0].get("content")
except Exception as ex:
logger.error(f"Failed to get search term from messages: {ex}, can't score WebRetrievalTask")
return

task_scorer.add_to_queue(
task=WebRetrievalTask(
messages=[msg["content"] for msg in body.get("messages")],
seed=int(body.get("seed", 0)),
sampling_params=body.get("sampling_params", {}),
query=search_term,
),
response=DendriteResponseEvent(
uids=[uid],
stream_results=[
SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None])
],
timeout=shared_settings.NEURON_TIMEOUT,
),
dataset_entry=DDGDatasetEntry(search_term=search_term),
block=shared_settings.METAGRAPH.block,
step=-1,
task_id=str(uuid.uuid4()),
)
logger.info("Organic task appended to scoring queue")
4 changes: 2 additions & 2 deletions prompting/datasets/random_website.py
@@ -15,8 +15,8 @@

class DDGDatasetEntry(DatasetEntry):
search_term: str
website_url: str
website_content: str
website_url: str = None
website_content: str = None


class DDGDataset(BaseDataset):
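With website_url and website_content now defaulting to None, an entry can be built from the search term alone, which is how the scoring endpoint above records organic web-retrieval traffic; a minimal illustration:

entry = DDGDatasetEntry(search_term="what is bittensor")
assert entry.website_url is None and entry.website_content is None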
2 changes: 1 addition & 1 deletion shared/dendrite.py
@@ -35,9 +35,9 @@ def model_dump(self):

class DendriteResponseEvent(BaseModel):
uids: np.ndarray | list[float]
axons: list[str]
timeout: float
stream_results: list[SynapseStreamResult]
axons: list[str] = []
completions: list[str] = []
status_messages: list[str] = []
status_codes: list[int] = []
2 changes: 1 addition & 1 deletion shared/epistula.py
@@ -111,7 +111,7 @@ async def merged_stream(responses: list[AsyncGenerator]):
logger.error(f"Error while streaming: {e}")


async def query_miners(uids, body: dict[str, Any]):
async def query_miners(uids, body: dict[str, Any]) -> list[SynapseStreamResult]:
try:
tasks = []
for uid in uids:
3 changes: 2 additions & 1 deletion shared/uids.py
@@ -114,7 +114,8 @@ def get_top_incentive_uids(k: int, vpermit_tao_limit: int) -> np.ndarray:
# Extract the top uids.
top_k_uids = [uid for uid, incentive in uid_incentive_pairs_sorted[:k]]

return np.array(top_k_uids).astype(int)
return list(np.array(top_k_uids).astype(int))
# return [int(k) for k in top_k_uids]


def get_uids(
31 changes: 1 addition & 30 deletions validator_api/chat_completion.py
@@ -3,43 +3,14 @@
import random
from typing import AsyncGenerator, List, Optional

import httpx
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger

from shared.epistula import make_openai_query
from shared.settings import shared_settings
from shared.uids import get_uids


async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
uid = int(uid)
logger.info(f"Forwarding response from uid {uid} to scoring with body: {body} and chunks: {chunks}")
if not shared_settings.SCORE_ORGANICS:
return

if body.get("task") != "InferenceTask":
logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}")
return

url = f"http://{shared_settings.VALIDATOR_API}/scoring"
payload = {"body": body, "chunks": chunks, "uid": uid}
try:
timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url, json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"}
)
if response.status_code == 200:
logger.info(f"Forwarding response completed with status {response.status_code}")
else:
logger.exception(
f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
)
except Exception as e:
logger.error(f"Tried to forward response to {url} with payload {payload}")
logger.exception(f"Error while forwarding response: {e}")
from validator_api.utils import forward_response


async def stream_from_first_response(
51 changes: 51 additions & 0 deletions validator_api/gpt_endpoints.py
@@ -1,12 +1,18 @@
import asyncio
import json
import random

import numpy as np
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from loguru import logger
from starlette.responses import StreamingResponse

from shared.epistula import SynapseStreamResult, query_miners
from shared.settings import shared_settings
from shared.uids import get_uids
from validator_api.chat_completion import chat_completion
from validator_api.mixture_of_miners import mixture_of_miners
from validator_api.utils import forward_response

router = APIRouter()

@@ -37,3 +43,48 @@ async def completions(request: Request, api_key: str = Depends(validate_api_key)
except Exception as e:
logger.exception(f"Error in chat completion: {e}")
return StreamingResponse(content="Internal Server Error", status_code=500)


@router.post("/web_retrieval")
async def web_retrieval(search_query: str, n_miners: int = 10, uids: list[int] = None):
uids = list(get_uids(sampling_mode="random", k=n_miners))
logger.debug(f"🔍 Querying uids: {uids}")
if len(uids) == 0:
logger.warning("No available miners. This should already have been caught earlier.")
return

body = {
"seed": random.randint(0, 1_000_000),
"sampling_parameters": shared_settings.SAMPLING_PARAMS,
"task": "WebRetrievalTask",
"messages": [
{"role": "user", "content": search_query},
],
}
stream_results = await query_miners(uids, body)
results = [
"".join(res.accumulated_chunks)
for res in stream_results
if isinstance(res, SynapseStreamResult) and res.accumulated_chunks
]
distinct_results = list(np.unique(results))
logger.info(
f"🔍 Collected responses from {len(stream_results)} miners. {len(results)} responded successfully with a total of {len(distinct_results)} distinct results"
)
loaded_results = []
for result in distinct_results:
try:
loaded_results.append(json.loads(result))
logger.info(f"🔍 Result: {result}")
except Exception:
logger.error(f"🔍 Result: {result}")
if len(loaded_results) == 0:
raise HTTPException(status_code=500, detail="No miner responded successfully")

for uid, res in zip(uids, stream_results):
asyncio.create_task(
forward_response(
uid=uid, body=body, chunks=res.accumulated_chunks if res and res.accumulated_chunks else []
)
)
return loaded_results
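A hedged usage sketch for the new endpoint, assuming the router is served at http://localhost:8000 and that FastAPI binds search_query and n_miners as query parameters; the host, port, and any API-key handling are assumptions rather than part of this diff:

import httpx

async def search_web(search_query: str, n_miners: int = 10) -> list[dict]:
    # POST with the endpoint's scalar arguments passed as query parameters
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            "http://localhost:8000/web_retrieval",
            params={"search_query": search_query, "n_miners": n_miners},
        )
        response.raise_for_status()
        return response.json()  # the JSON-decoded miner results (loaded_results)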
35 changes: 35 additions & 0 deletions validator_api/utils.py
@@ -0,0 +1,35 @@
import httpx
from loguru import logger

from shared.settings import shared_settings


# TODO: Modify this so that all the forwarded responses are sent in a single request. This is both more efficient but
# also means that on the validator side all responses are scored at once, speeding up the scoring process.
async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
uid = int(uid)
logger.info(f"Forwarding response from uid {uid} to scoring with body: {body} and chunks: {chunks}")
if not shared_settings.SCORE_ORGANICS:
return

if body.get("task") != "InferenceTask" and body.get("task") != "WebRetrievalTask":
logger.debug(f"Skipping forwarding for non- inference/web retrieval task: {body.get('task')}")
return

url = f"http://{shared_settings.VALIDATOR_API}/scoring"
payload = {"body": body, "chunks": chunks, "uid": uid}
try:
timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url, json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"}
)
if response.status_code == 200:
logger.info(f"Forwarding response completed with status {response.status_code}")
else:
logger.exception(
f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
)
except Exception as e:
logger.error(f"Tried to forward response to {url} with payload {payload}")
logger.exception(f"Error while forwarding response: {e}")

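One possible shape of the batched forwarding described in the TODO above, assuming a hypothetical /scoring_batch endpoint and a simple list-of-payloads body, neither of which exists in this commit:

async def forward_responses(items: list[dict]) -> None:
    # items: [{"body": ..., "chunks": ..., "uid": ...}, ...] collected from a single request
    if not shared_settings.SCORE_ORGANICS or not items:
        return
    url = f"http://{shared_settings.VALIDATOR_API}/scoring_batch"  # hypothetical endpoint
    timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url,
                json={"items": items},
                headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"},
            )
            if response.status_code != 200:
                logger.error(f"Batched forwarding failed with status {response.status_code}")
    except Exception as e:
        logger.exception(f"Error while forwarding batched responses: {e}")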