Using multiple messages for inference
Hollyqui committed Jan 7, 2025
1 parent 4654d3a commit 24a25b7
Showing 4 changed files with 11 additions and 93 deletions.
38 changes: 1 addition & 37 deletions poetry.lock

Some generated files are not rendered by default.

61 changes: 8 additions & 53 deletions prompting/datasets/sn13.py
@@ -1,15 +1,9 @@
 import random
 from typing import ClassVar

 import datasets
-import nltk
-from nltk.corpus import wordnet
 from pydantic import model_validator

 from shared.base import BaseDataset, ChatEntry

-nltk.download("wordnet")


 class SN13Dataset(BaseDataset):
     _url: ClassVar[str] = "arrmlet/x_dataset_218"
@@ -42,50 +36,11 @@ def sample(self) -> ChatEntry:
             raise self.exception
         # Randomly select a sample from the dataset.
         sample_idx = random.randint(0, len(self.dataset) - 1)
-        message = self.dataset[sample_idx]["text"]
-        role = ["user"]
-
-        # Augment the messages by modifying words and introducing errors.
-        messages = [self._augment_message(role, message)]
-
-        return ChatEntry(roles=role, messages=messages, organic=False, source=self._url)
-
-    def _augment_message(self, role: str, message: str) -> str:
-        if role == "assistant":
-            return message
-
-        words = message.split()
-        num_words_to_modify = random.randint(1, max(1, int(len(words) * self._chance_word_synonym)))
-        words_to_modify = random.sample(range(len(words)), num_words_to_modify)
-
-        for idx in words_to_modify:
-            synonym = self._get_synonym(words[idx])
-            if synonym:
-                words[idx] = synonym
-
-        message = " ".join(words)
-        message = self._introduce_typos(message)
-        return message
-
-    def _get_synonym(self, word: str) -> str:
-        synonyms = wordnet.synsets(word)
-        if synonyms:
-            # Choose a synonym that is not the word itself.
-            synonym_words = [lemma.name() for lemma in synonyms[0].lemmas() if lemma.name() != word]
-            if synonym_words:
-                return random.choice(synonym_words)
-        return word
-
-    def _introduce_typos(self, message: str) -> str:
-        message = list(message)
-        num_errors = random.randint(0, max(1, int(len(message) * self._chance_char_typo)))
-        for _ in range(num_errors):
-            error_type = random.choice(["remove", "add_space"])
-            error_position = random.randint(0, len(message) - 1)
-
-            if error_type == "remove":
-                message.pop(error_position)
-            elif error_type == "add_space":
-                message.insert(error_position, " ")
-
-        return "".join(message)
+        messages = []
+        roles = []
+        for _ in range(4):
+            if message := self.dataset[sample_idx]["text"]:
+                roles.append(random.choice(["user", "assistant"]))
+                messages.append(message)
+
+        return ChatEntry(roles=roles, messages=messages, organic=False, source=self._url)
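
For readers skimming the change, here is a minimal, self-contained sketch of the new sampling behaviour, using a toy in-memory dataset in place of the Hugging Face dataset. The walrus-check on empty text and the random role choice mirror the diff above; the dataset contents and the final print are illustrative assumptions, not part of the repository.

import random

# Toy stand-in for the dataset loaded by SN13Dataset (assumption: each row
# exposes a "text" field, as the diff above suggests).
dataset = [{"text": "The weather in Lisbon is sunny today."}]
sample_idx = random.randint(0, len(dataset) - 1)

messages: list[str] = []
roles: list[str] = []
for _ in range(4):
    # Same pattern as the new sample(): skip rows whose text is empty.
    if message := dataset[sample_idx]["text"]:
        roles.append(random.choice(["user", "assistant"]))
        messages.append(message)

# The real method wraps these lists in
# ChatEntry(roles=roles, messages=messages, organic=False, source=self._url).
print(list(zip(roles, messages)))

Note that because sample_idx is fixed before the loop, the same text is repeated up to four times, each time with a randomly assigned role.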
4 changes: 2 additions & 2 deletions prompting/tasks/inference.py
@@ -58,8 +58,8 @@ def make_query(self, dataset_entry: ChatEntry) -> str:

     def make_reference(self, dataset_entry: ChatEntry) -> str:
         self.reference = model_manager.generate(
-            messages=self.messages,
-            roles=["user"],
+            messages=dataset_entry.messages,
+            roles=dataset_entry.roles,
             model=self.llm_model,
             seed=self.seed,
             sampling_params=self.sampling_params,
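
To illustrate the effect on make_reference, a hedged sketch with a stub model manager: only the keyword arguments (messages, roles, model, seed, sampling_params) come from the call shown in the diff; the stub class, the sample conversation, and the return format are assumptions made for the example.

from types import SimpleNamespace

# Stub standing in for prompting's model_manager (assumption: only the keyword
# arguments mirror the real generate() call in the diff).
class StubModelManager:
    def generate(self, messages, roles, model, seed, sampling_params):
        return f"[{model}] " + " | ".join(f"{r}: {m}" for r, m in zip(roles, messages))

# A multi-turn dataset entry, as now produced by SN13Dataset.sample().
dataset_entry = SimpleNamespace(
    roles=["user", "assistant", "user"],
    messages=["Hi there", "Hello! How can I help?", "Summarise this commit."],
)

reference = StubModelManager().generate(
    messages=dataset_entry.messages,  # previously self.messages
    roles=dataset_entry.roles,        # previously hard-coded ["user"]
    model="example-llm",              # assumption: any model identifier
    seed=0,
    sampling_params={"temperature": 0.7},
)
print(reference)

The point of the change is visible in the two swapped keyword arguments: the reference is now generated from the whole sampled conversation rather than from a single user message.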
1 change: 0 additions & 1 deletion pyproject.toml
@@ -159,7 +159,6 @@ huggingface-hub = { version = ">=0.25.2", optional = true }
 pandas = { version = ">=2.2.1", optional = true }
 trafilatura = { version = ">=1.12.1", optional = true }
 datasets = { version = ">=3.1.0", optional = true }
-nltk = { version = ">=3.8.1", optional = true }
 primp = { version = "==0.8.1", optional = true }

 [tool.poetry.extras]

