Sn1 360 move qa task to web dataset #524

Merged

merged 12 commits on Jan 9, 2025
17 changes: 7 additions & 10 deletions README.md
@@ -49,24 +49,21 @@ Subnet one utilizes the concept of "Tasks" to control the behavior of miners. Va
### 1. **QA (Question Answering)**
The miner receives a question about a specific section from a Wikipedia page. The miner must then find the original context in the specified section and use it to return an accurate answer. References are generated using the validator's privileged knowledge of the context, and miner completions are scored based on similarity metrics.

### 2. **Summarization**
Similar to QA, but the miner uses the entire Wikipedia page instead of a specific section. The miner reads the whole page, summarizes it, and provides a concise answer.

### 3. **DateQA**
The miner receives a question about an event from Wikipedia. The miner must search through Wikipedia for the relevant event and return the correct answer based on the findings. References are again generated with the validator's knowledge of the context, and similarity metrics are used to score miner completions.

### 4. **Inference**
### 2. **Inference**
A question is given with some pre-seeded information and a random seed. The miner must perform an inference based on this information to provide the correct answer. Completions are scored based on similarity metrics.

### 5. **MultiChoice**
### 3. **MultiChoice**
The miner is presented with a question from Wikipedia along with four possible answers (A, B, C, or D). The miner must search Wikipedia and return the correct answer by selecting one of the given options. Miner completions are scored by regex matching.
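Since MultiChoice completions are scored by regex matching, here is a minimal sketch of what such a check could look like. The pattern below is an assumption for illustration; the actual reward model's regex is not shown in this diff:

```python
import re

# Assumed pattern: extract a lone option letter such as "B" or "(B)".
# Illustrative only; the subnet's actual regex may differ.
_CHOICE_RE = re.compile(r"\b([ABCD])\b")

def multichoice_score(completion: str, correct_option: str) -> float:
    """Return 1.0 if the completion names the correct option, else 0.0."""
    match = _CHOICE_RE.search(completion.upper())
    return 1.0 if match and match.group(1) == correct_option else 0.0

print(multichoice_score("The answer is (B).", "B"))  # 1.0
```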

### 6. **Programming**
### 5. **Programming**
The miner receives a code snippet that is incomplete. The task is to complete the code snippet to perform its intended function. The validator generates a reference using its internal LLM, and the miner is scored based on its similarity to this reference.

### 7. **Web Retrieval**
### 6. **Web Retrieval**
The miner is given a question based on a random web page and must return a scraped website that contains the answer. This requires searching the web to locate the most accurate and reliable source to provide the answer. The miner is scored based on the embedding similarity between the answer it returns and the original website that the validator generated the reference from.

### 7. **Multistep Reasoning**
The miner is given a complex problem that requires multiple steps to solve. Each step builds upon the previous one, and the miner must provide intermediate results before arriving at the final answer. The validator generates a reference solution using its internal LLM, and the miner is scored based on the accuracy and coherence of the intermediate and final results.
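Most of the tasks above score miners by embedding similarity against the validator's reference. A minimal sketch of that primitive, assuming a sentence-transformers style model (the model name and helper function are illustrative assumptions, not the subnet's actual reward code):

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedding model

def similarity_score(reference: str, completion: str) -> float:
    """Cosine similarity between the validator's reference and a miner completion."""
    ref_vec, comp_vec = model.encode([reference, completion])
    return float(np.dot(ref_vec, comp_vec) / (np.linalg.norm(ref_vec) * np.linalg.norm(comp_vec)))

print(similarity_score("Paris is the capital of France.", "France's capital is Paris."))
```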

# API Documentation

For detailed information on the available API endpoints, request/response formats, and usage examples, please refer to the [API Documentation](./api/API.md).
1 change: 0 additions & 1 deletion docs/SN1_validation.md
@@ -31,7 +31,6 @@ More tooling will be included in future releases.
# Tasks
The validation process supports an ever-growing number of tasks. Tasks drive agent behaviour based on specific goals, such as:
- Question answering
- Summarization
- Code debugging
- Mathematics
and more.
2 changes: 1 addition & 1 deletion neurons/validator.py
@@ -22,7 +22,7 @@

torch.multiprocessing.set_start_method("spawn", force=True)

NEURON_SAMPLE_SIZE = 100
NEURON_SAMPLE_SIZE = 100 # TODO: Should add this to constants.py


def create_loop_process(task_queue, scoring_queue, reward_events):
14 changes: 7 additions & 7 deletions prompting/datasets/random_website.py
@@ -23,15 +23,15 @@ class DDGDataset(BaseDataset):
english_words: list[str] = None

def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[list[dict[str, str]]]]:
try:
ddg = PatchedDDGS(proxy=shared_settings.PROXY_URL, verify=False)
for _ in range(retries):
random_words = " ".join(random.sample(ENGLISH_WORDS, 3))
ddg = PatchedDDGS(proxy=shared_settings.PROXY_URL, verify=False)
for _ in range(retries):
random_words = " ".join(random.sample(ENGLISH_WORDS, 3))
try:
results = list(ddg.text(random_words))
if results:
return random_words, results
except Exception as ex:
logger.error(f"Failed to get search results from DuckDuckGo: {ex}")
except Exception as ex:
logger.error(f"Failed to get search results from DuckDuckGo: {ex}")
return None, None

@staticmethod
@@ -44,7 +44,7 @@ def extract_website_content(url: str) -> Optional[str]:
logger.error(f"Failed to extract content from website {url}: {ex}")

def next(self) -> Optional[DDGDatasetEntry]:
search_term, results = self.search_random_term(retries=3)
search_term, results = self.search_random_term(retries=5)
if not results:
return None
website_url = results[0]["href"]
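The refactor above moves the PatchedDDGS construction out of the try block and scopes the try to the ddg.text() call, so a single failed search no longer consumes all remaining retries. A usage sketch (hypothetical; assumes the package is importable and shared_settings.PROXY_URL is configured):

```python
# Hypothetical usage; assumes prompting is installed and
# shared_settings.PROXY_URL points at a working proxy.
from prompting.datasets.random_website import DDGDataset

dataset = DDGDataset()
entry = dataset.next()  # retries the DuckDuckGo search internally, now 5 times
if entry is None:
    print("All search attempts failed; the caller should skip this round.")
```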
8 changes: 4 additions & 4 deletions prompting/rewards/web_retrieval.py
@@ -60,24 +60,24 @@ def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs
continue

dataset_entry = DDGDatasetEntry.model_validate(json.loads(reference))  # reference is now a plain JSON object, so validate the parsed dict
search_term = dataset_entry.search_term
query = dataset_entry.query
reference_content = dataset_entry.website_content

# Similarity between search term and miner's scraped content.
search_response_sim = self._cosine_similarity(content1=search_term, content2=response_content)
search_response_sim = self._cosine_similarity(content1=query, content2=response_content)

# Similarity between search term and relevant section of content.
search_relevant_sim = 0
if response_relevant is not None:
search_relevant_sim = self._cosine_similarity(content1=search_term, content2=response_relevant)
search_relevant_sim = self._cosine_similarity(content1=query, content2=response_relevant)

# If the URL provided in the completion is valid.
valid_url_score = 0
if response_url_scraped is not None:
valid_url_score = self._cosine_similarity(content1=response_content, content2=response_url_scraped)

# Similarity between search term and reference content.
search_reference_sim = self._cosine_similarity(content1=search_term, content2=reference_content)
search_reference_sim = self._cosine_similarity(content1=query, content2=reference_content)
score = (search_response_sim + valid_url_score + search_relevant_sim) / 3
if abs(search_response_sim - search_reference_sim) > _SEARCH_TERM_THRESH:
logger.info(
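The score averages three cosine similarities: query vs. the miner's scraped content, query vs. the relevant section of that content, and the miner's content vs. a re-scrape of its URL. A completion whose query similarity strays from the reference's by more than _SEARCH_TERM_THRESH is flagged. A toy recomputation with invented values:

```python
# Toy recomputation of the composite web-retrieval score.
# All similarity values are invented for illustration.
_SEARCH_TERM_THRESH = 0.3  # assumed value; the real constant is defined above this hunk

search_response_sim = 0.82   # query vs. miner's scraped content
search_relevant_sim = 0.74   # query vs. relevant section of that content
valid_url_score = 0.90       # miner's content vs. re-scrape of the miner's URL
search_reference_sim = 0.79  # query vs. the validator's reference content

score = (search_response_sim + valid_url_score + search_relevant_sim) / 3
if abs(search_response_sim - search_reference_sim) > _SEARCH_TERM_THRESH:
    score = 0.0  # assumed penalty; the diff logs such completions (handling is truncated here)
print(round(score, 3))  # 0.82
```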
46 changes: 0 additions & 46 deletions prompting/tasks/date_qa.py

This file was deleted.

4 changes: 2 additions & 2 deletions prompting/tasks/multi_step_reasoning.py
@@ -9,7 +9,7 @@
from prompting.llms.apis.llm_wrapper import LLMWrapper
from prompting.rewards.relevance import RelevanceRewardModel
from prompting.rewards.reward import BaseRewardConfig, BaseRewardModel
from prompting.tasks.qa import QuestionAnsweringTask
from prompting.tasks.qa import WikiQuestionAnsweringTask
from shared.base import Context
from shared.timer import Timer

@@ -174,7 +174,7 @@ class MultiStepReasoningRewardConfig(BaseRewardConfig):
]


class MultiStepReasoningTask(QuestionAnsweringTask):
class MultiStepReasoningTask(WikiQuestionAnsweringTask):
"""QuestionAnsweringTasks must be initialised with an LLM pipeline to generate query and reference plus
context from a dataset to base the query on"""

26 changes: 24 additions & 2 deletions prompting/tasks/qa.py
@@ -47,11 +47,11 @@ class QARewardConfig(BaseRewardConfig):
penalty_definition: ClassVar[list[BaseRewardModel]] = [RougeRewardModel(weight=0.5)]


class QuestionAnsweringTask(BaseTextTask):
class WikiQuestionAnsweringTask(BaseTextTask):
"""QuestionAnsweringTasks must be initialised with an LLM pipeline to generate query and reference plus
context from a dataset to base the query on"""

name: ClassVar[str] = "qa"
name: ClassVar[str] = "wiki_qa"
query_system_prompt: ClassVar[str] = QUERY_SYSTEM_PROMPT
reference_system_prompt: ClassVar[str] = REFERENCE_SYSTEM_PROMPT
augmentation_system_prompt: ClassVar[str] = ""
@@ -67,3 +67,25 @@ def make_reference(self, dataset_entry: Context):
reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(context=dataset_entry.content, question=self.query)
self.reference = self.generate_reference(messages=[{"role": "user", "content": reference_prompt}])
return self.reference


class WebQuestionAnsweringTask(BaseTextTask):
"""QuestionAnsweringTasks must be initialised with an LLM pipeline to generate query and reference plus
context from a dataset to base the query on"""

name: ClassVar[str] = "web_qa"
query_system_prompt: ClassVar[str] = QUERY_SYSTEM_PROMPT
reference_system_prompt: ClassVar[str] = REFERENCE_SYSTEM_PROMPT
augmentation_system_prompt: ClassVar[str] = ""
query: str | None = None
reference: str | None = None

def make_query(self, dataset_entry: Context):
query_prompt = QUERY_PROMPT_TEMPLATE.format(context=dataset_entry.website_content)
self.query = self.generate_query(messages=[query_prompt])
return self.query

def make_reference(self, dataset_entry: Context):
reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(context=dataset_entry.website_content, question=self.query)
self.reference = self.generate_reference(messages=[{"role": "user", "content": reference_prompt}])
return self.reference
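The two classes differ only in which dataset field feeds the prompt templates: WikiQuestionAnsweringTask reads dataset_entry.content, while WebQuestionAnsweringTask reads dataset_entry.website_content. A hypothetical driver sketch (the stub entry and the presence of an initialised LLM pipeline are assumptions):

```python
# Hypothetical driver; assumes an initialised LLM pipeline as the
# docstring requires. The stub entry stands in for a DDGDataset result.
from types import SimpleNamespace

from prompting.tasks.qa import WebQuestionAnsweringTask

entry = SimpleNamespace(website_content="Example page text about solar panels.")
task = WebQuestionAnsweringTask()
query = task.make_query(dataset_entry=entry)          # LLM turns page text into a question
reference = task.make_reference(dataset_entry=entry)  # LLM answers it from the same text
```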
57 changes: 0 additions & 57 deletions prompting/tasks/summarization.py

This file was deleted.

26 changes: 8 additions & 18 deletions prompting/tasks/task_registry.py
@@ -8,21 +8,17 @@
from prompting.datasets.huggingface_github import HuggingFaceGithubDataset
from prompting.datasets.random_website import DDGDataset
from prompting.datasets.sn13 import SN13Dataset
from prompting.datasets.wiki import WikiDataset, WikiDateDataset
from prompting.datasets.wiki import WikiDataset
from prompting.rewards.reward import BaseRewardConfig
from prompting.tasks.base_task import BaseTextTask
from prompting.tasks.date_qa import DateQARewardConfig, DateQuestionAnsweringTask
from prompting.tasks.inference import InferenceRewardConfig, InferenceTask
from prompting.tasks.multi_choice import MultiChoiceRewardConfig, MultiChoiceTask
from prompting.tasks.multi_step_reasoning import MultiStepReasoningRewardConfig, MultiStepReasoningTask
from prompting.tasks.programming_task import ProgrammingRewardConfig, ProgrammingTask
from prompting.tasks.qa import QARewardConfig, QuestionAnsweringTask
from prompting.tasks.summarization import SummarizationRewardConfig, SummarizationTask
from prompting.tasks.qa import QARewardConfig, WebQuestionAnsweringTask, WikiQuestionAnsweringTask
from prompting.tasks.web_retrieval import WebRetrievalRewardConfig, WebRetrievalTask
from shared.base import BaseDataset

# from prompting.tasks.


class TaskConfig(BaseModel):
task: type[BaseTextTask]
@@ -38,37 +34,31 @@ def __hash__(self):

class TaskRegistry(BaseModel):
task_configs: ClassVar[list[TaskConfig]] = [
TaskConfig(task=QuestionAnsweringTask, probability=0.15, datasets=[WikiDataset], reward_model=QARewardConfig),
TaskConfig(
task=SummarizationTask, probability=0.1, datasets=[WikiDataset], reward_model=SummarizationRewardConfig
),
TaskConfig(
task=DateQuestionAnsweringTask,
probability=0.1,
datasets=[WikiDateDataset],
reward_model=DateQARewardConfig,
task=WikiQuestionAnsweringTask, probability=0.05, datasets=[WikiDataset], reward_model=QARewardConfig
),
TaskConfig(task=WebQuestionAnsweringTask, probability=0.25, datasets=[DDGDataset], reward_model=QARewardConfig),
TaskConfig(
task=InferenceTask,
probability=0.16,
probability=0.1,
datasets=[SN13Dataset],
reward_model=InferenceRewardConfig,
),
TaskConfig(
task=MultiChoiceTask,
probability=0.26,
probability=0.2,
datasets=[WikiDataset],
reward_model=MultiChoiceRewardConfig,
),
TaskConfig(
task=ProgrammingTask,
probability=0.1,
probability=0.2,
datasets=[HuggingFaceGithubDataset],
reward_model=ProgrammingRewardConfig,
),
TaskConfig(
task=WebRetrievalTask,
probability=0.03,
probability=0.1,
datasets=[DDGDataset],
reward_model=WebRetrievalRewardConfig,
),
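After this change the probabilities visible in the hunk sum to 0.9 (0.05 + 0.25 + 0.1 + 0.2 + 0.2 + 0.1), with the remainder presumably allocated to the MultiStepReasoningTask entry truncated below. A quick sanity check one could run against the registry:

```python
# Sanity check: the weighted task probabilities should not exceed 1.0.
from prompting.tasks.task_registry import TaskRegistry

total = sum(cfg.probability for cfg in TaskRegistry.task_configs)
assert total <= 1.0, f"task probabilities sum to {total}"
print(f"total task probability: {total}")
```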
4 changes: 3 additions & 1 deletion prompting/tasks/web_retrieval.py
@@ -43,5 +43,7 @@ def make_query(self, dataset_entry: DDGDatasetEntry) -> str:
return self.query

def make_reference(self, dataset_entry: DDGDatasetEntry) -> str:
self.reference = json.dumps(dataset_entry.model_dump_json())
ref_dict = json.loads(dataset_entry.model_dump_json())  # model_dump_json() returns a str; parse it before adding keys
ref_dict["query"] = self.query
self.reference = json.dumps(ref_dict)
return self.reference
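With the query folded into the reference payload, the reward side can recover both the query and the page content from one JSON object. A round-trip sketch of the fixed serialisation (field values invented; only query and website_content are taken from this diff):

```python
import json

# What make_reference now produces: the serialised entry plus the query.
ref_dict = {
    "website_content": "Example page text.",       # field shown in the reward hunk above
    "query": "What does the example page cover?",  # added by make_reference
}
reference = json.dumps(ref_dict)

# What the reward model does with it (mirrors the parsing fixed above).
parsed = json.loads(reference)
assert parsed["query"] == "What does the example page cover?"
```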
1 change: 1 addition & 0 deletions shared/settings.py
@@ -113,6 +113,7 @@ class SharedSettings(BaseSettings):
SUBTENSOR_NETWORK: Optional[str] = Field(None, env="SUBTENSOR_NETWORK")
MAX_ALLOWED_VRAM_GB: int = Field(62, env="MAX_ALLOWED_VRAM_GB")
LLM_MAX_MODEL_LEN: int = Field(4096, env="LLM_MAX_MODEL_LEN")
PROXY_URL: Optional[str] = Field(None, env="PROXY_URL")
LLM_MODEL: str = Field("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", env="LLM_MODEL")
SAMPLING_PARAMS: dict[str, Any] = {
"temperature": 0.7,
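Like the neighbouring fields, PROXY_URL is optional and read from the environment when SharedSettings is constructed. A hedged example (the address is a placeholder, and the other required settings are assumed to be resolvable):

```python
# Hedged example: PROXY_URL is picked up from the environment.
# The address is a placeholder; other required settings are assumed available.
import os

os.environ["PROXY_URL"] = "http://localhost:8080"

from shared.settings import SharedSettings

settings = SharedSettings()
print(settings.PROXY_URL)  # http://localhost:8080
```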
2 changes: 1 addition & 1 deletion validator_api/API_docs.md
@@ -103,7 +103,7 @@ bash run_api.sh

**Parameters:**

- **task** (query, optional): The type of task (e.g., `QuestionAnsweringTask`, `SummarizationTask`, etc.).
- **task** (query, optional): The type of task (e.g., `QuestionAnsweringTask`, `ProgrammingTask`, etc.).
- **model** (query, optional): The specific model (string).
- **k** (query, optional): The maximum number of results to return (integer, default: 10).
