Skip to content

Commit

Permalink
SN1-362 fix issues with scoring endpoint (#507)
Browse files Browse the repository at this point in the history
Co-authored-by: bkb2135 <[email protected]>
Co-authored-by: richwardle <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2024
1 parent cc16e12 commit 9c1659d
Show file tree
Hide file tree
Showing 15 changed files with 252 additions and 114 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,4 @@ wandb
.vscode
api_keys.json
prompting/api/api_keys.json
weights.csv
139 changes: 102 additions & 37 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import multiprocessing as mp
import time

import torch
from loguru import logger

from prompting.api.api import start_scoring_api
Expand All @@ -20,48 +21,112 @@
from prompting.weight_setting.weight_setter import weight_setter
from shared.profiling import profiler

torch.multiprocessing.set_start_method("spawn", force=True)

NEURON_SAMPLE_SIZE = 100


def create_loop_process(task_queue, scoring_queue, reward_events):
    """Process entry point running the scoring-side loops in one event loop.

    Spawns the profiler, model scheduler, task scorer and weight setter as
    asyncio tasks, then monitors the shared queues forever.

    Args:
        task_queue: manager-proxied list of tasks shared between processes.
        scoring_queue: manager-proxied list of scoring configs to be scored.
        reward_events: manager-proxied list collecting reward events.
    """

    async def spawn_loops(task_queue, scoring_queue, reward_events):
        # Keep strong references to the created tasks: asyncio only holds
        # weak references, so unreferenced tasks may be garbage-collected
        # mid-run. (The originals were also stray tuple-statements due to
        # trailing commas.)
        tasks: list[asyncio.Task] = []
        logger.info("Starting Profiler...")
        tasks.append(asyncio.create_task(profiler.print_stats(), name="Profiler"))
        logger.info("Starting ModelScheduler...")
        tasks.append(asyncio.create_task(model_scheduler.start(scoring_queue), name="ModelScheduler"))
        logger.info("Starting TaskScorer...")
        tasks.append(asyncio.create_task(task_scorer.start(scoring_queue, reward_events), name="TaskScorer"))
        logger.info("Starting WeightSetter...")
        tasks.append(asyncio.create_task(weight_setter.start(reward_events), name="WeightSetter"))

        # Main monitoring loop: log loop cadence and queue sizes every 5s.
        start = time.time()

        logger.info("Starting Main Monitoring Loop...")
        while True:
            await asyncio.sleep(5)
            current_time = time.time()
            time_diff = current_time - start
            start = current_time

            logger.debug(f"Running {time_diff:.2f} seconds")
            logger.debug(f"Number of tasks in Task Queue: {len(task_queue)}")
            logger.debug(f"Number of tasks in Scoring Queue: {len(scoring_queue)}")
            logger.debug(f"Number of tasks in Reward Events: {len(reward_events)}")

    asyncio.run(spawn_loops(task_queue, scoring_queue, reward_events))


def start_api():
    """Process entry point that serves the scoring API in its own event loop."""

    async def _serve():
        await start_scoring_api()
        # Keep the process alive and emit a periodic heartbeat.
        while True:
            await asyncio.sleep(10)
            logger.debug("Running API...")

    asyncio.run(_serve())


def create_task_loop(task_queue, scoring_queue):
    """Process entry point running task creation/sending in one event loop.

    Spawns the availability checker, task sender and task loop as asyncio
    tasks, then idles with a heartbeat.

    Args:
        task_queue: manager-proxied list of tasks shared between processes.
        scoring_queue: manager-proxied list of scoring configs.
    """

    async def start(task_queue, scoring_queue):
        # Keep strong references to the created tasks: asyncio only holds
        # weak references, so unreferenced tasks may be garbage-collected
        # mid-run.
        tasks: list[asyncio.Task] = []
        logger.info("Starting AvailabilityCheckingLoop...")
        tasks.append(asyncio.create_task(availability_checking_loop.start(), name="AvailabilityCheckingLoop"))

        logger.info("Starting TaskSender...")
        tasks.append(asyncio.create_task(task_sender.start(task_queue, scoring_queue), name="TaskSender"))

        logger.info("Starting TaskLoop...")
        tasks.append(asyncio.create_task(task_loop.start(task_queue, scoring_queue), name="TaskLoop"))

        # Heartbeat so the process visibly stays alive in the logs.
        while True:
            await asyncio.sleep(10)
            logger.debug("Running task loop...")

    asyncio.run(start(task_queue, scoring_queue))


async def main():
    """Validator entry point: share queues across worker processes and supervise them.

    NOTE(review): the pasted diff merged the deleted pre-change body with the
    added post-change body; this is the reconstructed post-change version.
    A multiprocessing Manager provides proxied lists shared by the API
    process, the scoring-loop process and the task-loop process.
    """
    with torch.multiprocessing.Manager() as manager:
        # Manager-proxied lists shared across all child processes.
        reward_events = manager.list()
        scoring_queue = manager.list()
        task_queue = manager.list()

        # Track child processes so we can terminate them on shutdown.
        processes: list[mp.Process] = []

        try:
            if shared_settings.DEPLOY_SCORING_API:
                # Use multiprocessing to bypass API blocking issue.
                api_process = mp.Process(target=start_api, name="API_Process")
                api_process.start()
                processes.append(api_process)

            loop_process = mp.Process(
                target=create_loop_process, args=(task_queue, scoring_queue, reward_events), name="LoopProcess"
            )
            task_loop_process = mp.Process(
                target=create_task_loop, args=(task_queue, scoring_queue), name="TaskLoopProcess"
            )
            loop_process.start()
            task_loop_process.start()
            processes.append(loop_process)
            processes.append(task_loop_process)
            GPUInfo.log_gpu_info()

            # Supervisor loop: keep the parent alive while the children run.
            while True:
                await asyncio.sleep(10)
                logger.debug("Running...")

        except Exception as e:
            logger.error(f"Main loop error: {e}")
            raise
        finally:
            # Clean up: terminate any child processes that are still alive.
            for process in processes:
                if process.is_alive():
                    process.terminate()
                    process.join()


# The main function parses the configuration and runs the validator.
Expand Down
2 changes: 1 addition & 1 deletion prompting/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def health():


async def start_scoring_api():
    """Start the scoring API server. Blocks: uvicorn.run() serves until shutdown."""
    # The server listens on plain HTTP, so log "http", not "https".
    logger.info(f"Starting Scoring API on http://0.0.0.0:{shared_settings.SCORING_API_PORT}")
    uvicorn.run(
        "prompting.api.api:app", host="0.0.0.0", port=shared_settings.SCORING_API_PORT, loop="asyncio", reload=False
    )
21 changes: 19 additions & 2 deletions prompting/api/scoring/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from typing import Any

from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from loguru import logger

from prompting.llms.model_zoo import ModelZoo
from prompting.rewards.scoring import task_scorer
Expand All @@ -14,10 +15,25 @@
router = APIRouter()


def validate_scoring_key(api_key: str = Header(...)):
    """FastAPI dependency: reject requests whose API-key header does not match
    the configured SCORING_KEY with a 403."""
    # Header(...) marks the header as required — presumably FastAPI rejects
    # requests missing it before this check runs; verify against FastAPI docs.
    if api_key != shared_settings.SCORING_KEY:
        raise HTTPException(status_code=403, detail="Invalid API key")


@router.post("/scoring")
async def score_response(request: Request): # , api_key_data: dict = Depends(validate_api_key)):
async def score_response(request: Request, api_key_data: dict = Depends(validate_scoring_key)):
model = None
payload: dict[str, Any] = await request.json()
body = payload.get("body")

try:
if body.get("model") is not None:
model = ModelZoo.get_model_by_id(body.get("model"))
except Exception:
logger.warning(
f"Organic request with model {body.get('model')} made but the model cannot be found in model zoo. Skipping scoring."
)
return
uid = int(payload.get("uid"))
chunks = payload.get("chunks")
llm_model = ModelZoo.get_model_by_id(model) if (model := body.get("model")) else None
Expand All @@ -39,3 +55,4 @@ async def score_response(request: Request): # , api_key_data: dict = Depends(va
step=-1,
task_id=str(uuid.uuid4()),
)
logger.info("Organic tas appended to scoring queue")
10 changes: 7 additions & 3 deletions prompting/llms/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from prompting.llms.hf_llm import ReproducibleHF
from prompting.llms.model_zoo import ModelConfig, ModelZoo
from prompting.llms.utils import GPUInfo
from prompting.mutable_globals import scoring_queue
from shared.loop_runner import AsyncLoopRunner
from shared.settings import shared_settings

Expand Down Expand Up @@ -158,14 +157,19 @@ def generate(
class AsyncModelScheduler(AsyncLoopRunner):
llm_model_manager: ModelManager
interval: int = 14400
scoring_queue: list | None = None

async def start(self, scoring_queue: list):
    """Attach the shared scoring queue, then start the periodic loop runner."""
    self.scoring_queue = scoring_queue
    return await super().start()

async def initialise_loop(self):
    # Load the always-active models once before the scheduling loop begins.
    model_manager.load_always_active_models()

async def run_step(self):
"""This method is called periodically according to the interval."""
# try to load the model belonging to the oldest task in the queue
selected_model = scoring_queue[0].task.llm_model if scoring_queue else None
selected_model = self.scoring_queue[0].task.llm_model if self.scoring_queue else None
if not selected_model:
selected_model = ModelZoo.get_random(max_ram=self.llm_model_manager.total_ram)
logger.info(f"Loading model {selected_model.llm_model_id} for {self.interval} seconds.")
Expand All @@ -174,7 +178,7 @@ async def run_step(self):
logger.info(f"Model {selected_model.llm_model_id} is already loaded.")
return

logger.debug(f"Active models: {model_manager.active_models.keys()}")
logger.debug(f"Active models: {self.llm_model_manager.active_models.keys()}")
# Load the selected model
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self.llm_model_manager.load_model, selected_model)
Expand Down
22 changes: 13 additions & 9 deletions prompting/rewards/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from loguru import logger
from pydantic import ConfigDict

from prompting import mutable_globals
from prompting.llms.model_manager import model_manager, model_scheduler
from prompting.tasks.base_task import BaseTextTask
from prompting.tasks.task_registry import TaskRegistry
Expand Down Expand Up @@ -33,9 +32,16 @@ class TaskScorer(AsyncLoopRunner):
is_running: bool = False
thread: threading.Thread = None
interval: int = 10
scoring_queue: list | None = None
reward_events: list | None = None

model_config = ConfigDict(arbitrary_types_allowed=True)

async def start(self, scoring_queue, reward_events):
    """Bind the shared queues, then start the loop runner.

    scoring_queue: shared list of ScoringConfig entries awaiting scoring.
    reward_events: shared list that collected reward events are appended to.
    """
    self.scoring_queue = scoring_queue
    self.reward_events = reward_events
    return await super().start()

def add_to_queue(
self,
task: BaseTextTask,
Expand All @@ -45,7 +51,7 @@ def add_to_queue(
step: int,
task_id: str,
) -> None:
mutable_globals.scoring_queue.append(
self.scoring_queue.append(
ScoringConfig(
task=task,
response=response,
Expand All @@ -55,26 +61,24 @@ def add_to_queue(
task_id=task_id,
)
)
logger.debug(
f"SCORING: Added to queue: {task.__class__.__name__}. Queue size: {len(mutable_globals.scoring_queue)}"
)
logger.debug(f"SCORING: Added to queue: {task.__class__.__name__}. Queue size: {len(self.scoring_queue)}")

async def run_step(self) -> RewardLoggingEvent:
await asyncio.sleep(0.1)
# Only score responses for which the model is loaded
scorable = [
scoring_config
for scoring_config in mutable_globals.scoring_queue
for scoring_config in self.scoring_queue
if (scoring_config.task.llm_model in model_manager.active_models.keys())
or (scoring_config.task.llm_model is None)
]
if len(scorable) == 0:
logger.debug("Nothing to score. Skipping scoring step.")
# Run a model_scheduler step to load a new model as there are no more tasks to be scored
if len(mutable_globals.scoring_queue) > 0:
if len(self.scoring_queue) > 0:
await model_scheduler.run_step()
return
mutable_globals.scoring_queue.remove(scorable[0])
self.scoring_queue.remove(scorable[0])
scoring_config: ScoringConfig = scorable.pop(0)

# here we generate the actual reference
Expand All @@ -94,7 +98,7 @@ async def run_step(self) -> RewardLoggingEvent:
model_id=scoring_config.task.llm_model,
task=scoring_config.task,
)
mutable_globals.reward_events.append(reward_events)
self.reward_events.append(reward_events)
logger.debug(
f"REFERENCE: {scoring_config.task.reference}\n\n||||RESPONSES: {scoring_config.response.completions}"
)
Expand Down
17 changes: 12 additions & 5 deletions prompting/tasks/task_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pydantic import ConfigDict

from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.mutable_globals import scoring_queue, task_queue
from prompting.tasks.task_registry import TaskRegistry
from shared.logging import ErrorLoggingEvent, ValidatorLoggingEvent
from shared.loop_runner import AsyncLoopRunner
Expand All @@ -18,14 +17,20 @@ class TaskLoop(AsyncLoopRunner):
is_running: bool = False
thread: threading.Thread = None
interval: int = 10

task_queue: list | None = []
scoring_queue: list | None = []
model_config = ConfigDict(arbitrary_types_allowed=True)

async def start(self, task_queue, scoring_queue):
    """Bind the shared queues, then start the loop runner.

    task_queue: shared list that generated tasks are appended to.
    scoring_queue: shared list used for back-pressure in run_step.
    """
    self.task_queue = task_queue
    self.scoring_queue = scoring_queue
    # Propagate the parent's return value, consistent with the sibling
    # start() overrides (TaskScorer.start, AsyncModelScheduler.start),
    # which do `return await super().start()`.
    return await super().start()

async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
if len(task_queue) > shared_settings.TASK_QUEUE_LENGTH_THRESHOLD:
if len(self.task_queue) > shared_settings.TASK_QUEUE_LENGTH_THRESHOLD:
logger.debug("Task queue is full. Skipping task generation.")
return None
if len(scoring_queue) > shared_settings.SCORING_QUEUE_LENGTH_THRESHOLD:
if len(self.scoring_queue) > shared_settings.SCORING_QUEUE_LENGTH_THRESHOLD:
logger.debug("Scoring queue is full. Skipping task generation.")
return None
await asyncio.sleep(0.1)
Expand Down Expand Up @@ -55,7 +60,9 @@ async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
if not task.query:
logger.debug(f"Generating query for task: {task.__class__.__name__}.")
task.make_query(dataset_entry=dataset_entry)
task_queue.append(task)

logger.debug(f"Appending task: {task.__class__.__name__} to task queue.")
self.task_queue.append(task)
except Exception as ex:
logger.exception(ex)
return None
Expand Down
Loading

0 comments on commit 9c1659d

Please sign in to comment.