Skip to content

Commit

Permalink
fix retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
richwardle committed Jan 7, 2025
1 parent e259a00 commit d082fc3
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# Use spawn (not fork) so CUDA state is not inherited by child processes.
torch.multiprocessing.set_start_method("spawn", force=True)

# Number of neurons sampled per query round.
NEURON_SAMPLE_SIZE = 100  # TODO: Should add this to constants.py


def create_loop_process(task_queue, scoring_queue, reward_events):
Expand Down
14 changes: 7 additions & 7 deletions prompting/datasets/random_website.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ class DDGDataset(BaseDataset):
english_words: list[str] = None

def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[list[dict[str, str]]]]:
    """Search DuckDuckGo for a random multi-word term.

    Retries up to ``retries`` times with a fresh random 4-word query each
    attempt. Returns ``(search_term, results)`` on the first attempt that
    yields any results, else ``(None, None)``.
    """
    # Build the client once, outside the retry loop — only the query
    # changes between attempts.
    ddg = PatchedDDGS(proxy=shared_settings.PROXY_URL, verify=False)
    for _ in range(retries):
        random_words = " ".join(random.sample(ENGLISH_WORDS, 4))
        try:
            results = list(ddg.text(random_words))
            if results:
                return random_words, results
        except Exception as ex:
            # Best-effort: log and retry with a new random query rather
            # than propagating transient search failures.
            logger.error(f"Failed to get search results from DuckDuckGo: {ex}")
    return None, None

@staticmethod
Expand All @@ -44,7 +44,7 @@ def extract_website_content(url: str) -> Optional[str]:
logger.error(f"Failed to extract content from website {url}: {ex}")

def next(self) -> Optional[DDGDatasetEntry]:
search_term, results = self.search_random_term(retries=3)
search_term, results = self.search_random_term(retries=5)
if not results:
return None
website_url = results[0]["href"]
Expand Down
8 changes: 4 additions & 4 deletions prompting/rewards/web_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,24 @@ def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs
continue

dataset_entry = DDGDatasetEntry.model_validate_json(json.loads(reference))
search_term = dataset_entry.search_term
query = dataset_entry.query
reference_content = dataset_entry.website_content

# Similarity between search term and miner's scraped content.
search_response_sim = self._cosine_similarity(content1=search_term, content2=response_content)
search_response_sim = self._cosine_similarity(content1=query, content2=response_content)

# Similarity between search term and relevant section of content.
search_relevant_sim = 0
if response_relevant is not None:
search_relevant_sim = self._cosine_similarity(content1=search_term, content2=response_relevant)
search_relevant_sim = self._cosine_similarity(content1=query, content2=response_relevant)

# If the URL provided in the completion is valid.
valid_url_score = 0
if response_url_scraped is not None:
valid_url_score = self._cosine_similarity(content1=response_content, content2=response_url_scraped)

# Similarity between search term and reference content.
search_reference_sim = self._cosine_similarity(content1=search_term, content2=reference_content)
search_reference_sim = self._cosine_similarity(content1=query, content2=reference_content)
score = (search_response_sim + valid_url_score + search_relevant_sim) / 3
if abs(search_response_sim - search_reference_sim) > _SEARCH_TERM_THRESH:
logger.info(
Expand Down
4 changes: 2 additions & 2 deletions prompting/tasks/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ class WebQuestionAnsweringTask(BaseTextTask):
reference: str | None = None

def make_query(self, dataset_entry: Context):
    """Generate and store a question from the dataset entry's website content."""
    # Entries from the DDG dataset expose `website_content`
    # (the old `.content` attribute is not used here).
    query_prompt = QUERY_PROMPT_TEMPLATE.format(context=dataset_entry.website_content)
    self.query = self.generate_query(messages=[query_prompt])
    return self.query

def make_reference(self, dataset_entry: Context):
    """Generate and store a reference answer for the previously generated query."""
    # Entries from the DDG dataset expose `website_content`
    # (the old `.content` attribute is not used here).
    reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(
        context=dataset_entry.website_content, question=self.query
    )
    self.reference = self.generate_reference(messages=[{"role": "user", "content": reference_prompt}])
    return self.reference
4 changes: 2 additions & 2 deletions prompting/tasks/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def __hash__(self):
class TaskRegistry(BaseModel):
task_configs: ClassVar[list[TaskConfig]] = [
TaskConfig(
task=WikiQuestionAnsweringTask, probability=0.2, datasets=[WikiDataset], reward_model=QARewardConfig
task=WikiQuestionAnsweringTask, probability=0.05, datasets=[WikiDataset], reward_model=QARewardConfig
),
TaskConfig(task=WebQuestionAnsweringTask, probability=0.1, datasets=[DDGDataset], reward_model=QARewardConfig),
TaskConfig(task=WebQuestionAnsweringTask, probability=0.25, datasets=[DDGDataset], reward_model=QARewardConfig),
TaskConfig(
task=InferenceTask,
probability=0.1,
Expand Down
4 changes: 3 additions & 1 deletion prompting/tasks/web_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@ def make_query(self, dataset_entry: DDGDatasetEntry) -> str:
return self.query

def make_reference(self, dataset_entry: DDGDatasetEntry) -> str:
    """Build the task reference: the dataset entry plus the generated query.

    The result is a double-JSON-encoded string so the web-retrieval reward
    model can do ``model_validate_json(json.loads(reference))``.
    """
    # BUG FIX: model_dump_json() returns a *str*, so assigning
    # ref_dict["query"] on it raises TypeError. Dump to a dict instead.
    ref_dict = dataset_entry.model_dump()
    ref_dict["query"] = self.query
    # Encode twice to keep the on-the-wire format the reward side expects
    # (json.loads(reference) must yield a JSON string, not a dict).
    self.reference = json.dumps(json.dumps(ref_dict))
    return self.reference
1 change: 1 addition & 0 deletions shared/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class SharedSettings(BaseSettings):
SUBTENSOR_NETWORK: Optional[str] = Field(None, env="SUBTENSOR_NETWORK")
MAX_ALLOWED_VRAM_GB: int = Field(62, env="MAX_ALLOWED_VRAM_GB")
LLM_MAX_MODEL_LEN: int = Field(4096, env="LLM_MAX_MODEL_LEN")
PROXY_URL: Optional[str] = Field(None, env="PROXY_URL")
LLM_MODEL: str = Field("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", env="LLM_MODEL")
SAMPLING_PARAMS: dict[str, Any] = {
"temperature": 0.7,
Expand Down

0 comments on commit d082fc3

Please sign in to comment.