diff --git a/neurons/validator.py b/neurons/validator.py
index 9d7e6305..a8735229 100644
--- a/neurons/validator.py
+++ b/neurons/validator.py
@@ -22,7 +22,7 @@
 
 torch.multiprocessing.set_start_method("spawn", force=True)
 
-NEURON_SAMPLE_SIZE = 100
+NEURON_SAMPLE_SIZE = 100  # TODO: Should add this to constants.py
 
 
 def create_loop_process(task_queue, scoring_queue, reward_events):
diff --git a/prompting/datasets/random_website.py b/prompting/datasets/random_website.py
index 3058812a..4061e889 100644
--- a/prompting/datasets/random_website.py
+++ b/prompting/datasets/random_website.py
@@ -23,15 +23,15 @@ class DDGDataset(BaseDataset):
     english_words: list[str] = None
 
     def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[list[dict[str, str]]]]:
-        try:
-            ddg = PatchedDDGS(proxy=shared_settings.PROXY_URL, verify=False)
-            for _ in range(retries):
-                random_words = " ".join(random.sample(ENGLISH_WORDS, 5))
+        ddg = PatchedDDGS(proxy=shared_settings.PROXY_URL, verify=False)
+        for _ in range(retries):
+            random_words = " ".join(random.sample(ENGLISH_WORDS, 4))
+            try:
                 results = list(ddg.text(random_words))
                 if results:
                     return random_words, results
-        except Exception as ex:
-            logger.error(f"Failed to get search results from DuckDuckGo: {ex}")
+            except Exception as ex:
+                logger.error(f"Failed to get search results from DuckDuckGo: {ex}")
         return None, None
 
     @staticmethod
@@ -44,7 +44,7 @@ def extract_website_content(url: str) -> Optional[str]:
             logger.error(f"Failed to extract content from website {url}: {ex}")
 
     def next(self) -> Optional[DDGDatasetEntry]:
-        search_term, results = self.search_random_term(retries=3)
+        search_term, results = self.search_random_term(retries=5)
         if not results:
             return None
         website_url = results[0]["href"]
diff --git a/prompting/rewards/web_retrieval.py b/prompting/rewards/web_retrieval.py
index 592ef9d2..7c258082 100644
--- a/prompting/rewards/web_retrieval.py
+++ b/prompting/rewards/web_retrieval.py
@@ -60,16 +60,16 @@ def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs
                 continue
 
             dataset_entry = DDGDatasetEntry.model_validate_json(json.loads(reference))
-            search_term = dataset_entry.search_term
+            query = dataset_entry.query
             reference_content = dataset_entry.website_content
 
             # Similarity between search term and miner's scraped content.
-            search_response_sim = self._cosine_similarity(content1=search_term, content2=response_content)
+            search_response_sim = self._cosine_similarity(content1=query, content2=response_content)
 
             # Similarity between search term and relevant section of content.
             search_relevant_sim = 0
             if response_relevant is not None:
-                search_relevant_sim = self._cosine_similarity(content1=search_term, content2=response_relevant)
+                search_relevant_sim = self._cosine_similarity(content1=query, content2=response_relevant)
 
             # If the URL provided in the completion is valid.
             valid_url_score = 0
@@ -77,7 +77,7 @@ def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs
                 valid_url_score = self._cosine_similarity(content1=response_content, content2=response_url_scraped)
 
             # Similarity between search term and reference content.
-            search_reference_sim = self._cosine_similarity(content1=search_term, content2=reference_content)
+            search_reference_sim = self._cosine_similarity(content1=query, content2=reference_content)
             score = (search_response_sim + valid_url_score + search_relevant_sim) / 3
             if abs(search_response_sim - search_reference_sim) > _SEARCH_TERM_THRESH:
                 logger.info(
diff --git a/prompting/tasks/qa.py b/prompting/tasks/qa.py
index fdc3d167..314121d8 100644
--- a/prompting/tasks/qa.py
+++ b/prompting/tasks/qa.py
@@ -81,11 +81,11 @@ class WebQuestionAnsweringTask(BaseTextTask):
     reference: str | None = None
 
     def make_query(self, dataset_entry: Context):
-        query_prompt = QUERY_PROMPT_TEMPLATE.format(context=dataset_entry.content)
+        query_prompt = QUERY_PROMPT_TEMPLATE.format(context=dataset_entry.website_content)
         self.query = self.generate_query(messages=[query_prompt])
         return self.query
 
     def make_reference(self, dataset_entry: Context):
-        reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(context=dataset_entry.content, question=self.query)
+        reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(context=dataset_entry.website_content, question=self.query)
         self.reference = self.generate_reference(messages=[{"role": "user", "content": reference_prompt}])
         return self.reference
diff --git a/prompting/tasks/task_registry.py b/prompting/tasks/task_registry.py
index d26dbd06..8dee0ee2 100644
--- a/prompting/tasks/task_registry.py
+++ b/prompting/tasks/task_registry.py
@@ -35,9 +35,9 @@ def __hash__(self):
 class TaskRegistry(BaseModel):
     task_configs: ClassVar[list[TaskConfig]] = [
         TaskConfig(
-            task=WikiQuestionAnsweringTask, probability=0.2, datasets=[WikiDataset], reward_model=QARewardConfig
+            task=WikiQuestionAnsweringTask, probability=0.05, datasets=[WikiDataset], reward_model=QARewardConfig
         ),
-        TaskConfig(task=WebQuestionAnsweringTask, probability=0.1, datasets=[DDGDataset], reward_model=QARewardConfig),
+        TaskConfig(task=WebQuestionAnsweringTask, probability=0.25, datasets=[DDGDataset], reward_model=QARewardConfig),
         TaskConfig(
             task=InferenceTask,
             probability=0.1,
diff --git a/prompting/tasks/web_retrieval.py b/prompting/tasks/web_retrieval.py
index f595a0ba..7d5f4b19 100644
--- a/prompting/tasks/web_retrieval.py
+++ b/prompting/tasks/web_retrieval.py
@@ -43,5 +43,9 @@ def make_query(self, dataset_entry: DDGDatasetEntry) -> str:
         return self.query
 
     def make_reference(self, dataset_entry: DDGDatasetEntry) -> str:
-        self.reference = json.dumps(dataset_entry.model_dump_json())
+        # model_dump_json() returns a str; parse it to a dict before adding the query.
+        # Double-encode: the reward parses via model_validate_json(json.loads(reference)).
+        ref_dict = json.loads(dataset_entry.model_dump_json())
+        ref_dict["query"] = self.query
+        self.reference = json.dumps(json.dumps(ref_dict))
         return self.reference
diff --git a/shared/settings.py b/shared/settings.py
index 1be9d5ea..e4df024f 100644
--- a/shared/settings.py
+++ b/shared/settings.py
@@ -113,6 +113,7 @@ class SharedSettings(BaseSettings):
     SUBTENSOR_NETWORK: Optional[str] = Field(None, env="SUBTENSOR_NETWORK")
     MAX_ALLOWED_VRAM_GB: int = Field(62, env="MAX_ALLOWED_VRAM_GB")
     LLM_MAX_MODEL_LEN: int = Field(4096, env="LLM_MAX_MODEL_LEN")
+    PROXY_URL: Optional[str] = Field(None, env="PROXY_URL")
     LLM_MODEL: str = Field("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", env="LLM_MODEL")
     SAMPLING_PARAMS: dict[str, Any] = {
         "temperature": 0.7,
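
Note on the reference round-trip in make_reference: the reward hunk above leaves the
parsing line untouched (DDGDatasetEntry.model_validate_json(json.loads(reference))),
so the reference must remain a JSON-encoded JSON string; hence the double json.dumps.
Below is a minimal, self-contained sketch of the round-trip. The stand-in
DDGDatasetEntry is an assumption: its fields are inferred from usage in this diff,
and the query field in particular is inferred from dataset_entry.query in the reward hunk.

    import json
    from typing import Optional

    from pydantic import BaseModel


    class DDGDatasetEntry(BaseModel):
        # Stand-in for the model in prompting/datasets/random_website.py (assumed fields).
        search_term: str
        website_url: str
        website_content: str
        query: Optional[str] = None  # assumed: the reward reads dataset_entry.query

    # Task side: attach the generated query, then encode twice.
    entry = DDGDatasetEntry(
        search_term="random sample words",
        website_url="https://example.com",
        website_content="page text...",
    )
    ref_dict = json.loads(entry.model_dump_json())
    ref_dict["query"] = "example generated question"
    reference = json.dumps(json.dumps(ref_dict))

    # Reward side: one json.loads strips the outer layer, leaving a JSON string
    # that model_validate_json accepts.
    parsed = DDGDatasetEntry.model_validate_json(json.loads(reference))
    assert parsed.query == "example generated question"

An alternative would be to single-encode (self.reference = json.dumps(ref_dict)) and
switch the reward side to DDGDatasetEntry.model_validate(json.loads(reference)); the
double encoding is kept here only because the parsing line is unchanged in this diff.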