Skip to content

Commit

Permalink
Add demo video
Browse files Browse the repository at this point in the history
  • Loading branch information
s-jse authored Aug 23, 2024
1 parent 4e8f06b commit 4cd4eae
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
<br>
</p>



https://github.com/user-attachments/assets/982e8733-f7a7-468d-940c-5c96f411f527



<!-- <hr /> -->


Expand Down Expand Up @@ -325,4 +331,4 @@ If you have used code or data from this repository, please cite the following pa
url = "https://aclanthology.org/2024.findings-acl.96",
pages = "1663--1678",
}
```
```
23 changes: 15 additions & 8 deletions benchmark/user_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def user_simulation_chain(user_engine: str, user_temperature: float, language: s
)


async def simulate_dialog(dialogue_inputs, args) -> list[DialogueTurn]:
async def simulate_dialogue(dialogue_inputs, args) -> list[DialogueTurn]:
"""
Simulate one dialog
Simulate one dialogue
"""
user_character = random.choice(user_characteristics)
chatbot, dialogue_state = create_chain(args)
Expand Down Expand Up @@ -128,7 +128,7 @@ def repeat_dialogue_inputs(dialogue_inputs, target_num_dialogues):
return dialogue_inputs


def main(args):
async def main(args):
topics = []
if args.mode == "topic":
dialogue_inputs = []
Expand Down Expand Up @@ -162,9 +162,10 @@ def main(args):
raise ValueError("Unknown mode: %s" % args.mode)

all_dialogues = []
for di in tqdm(dialogue_inputs, desc="Dialogues"):
dialogue_state = asyncio.run(simulate_dialog(di, args=args))
all_dialogues.append(dialogue_state)
for i in tqdm(range(0, len(dialogue_inputs), args.batch_size), desc="Dialogue Batches"):
batch = dialogue_inputs[i:i+args.batch_size]
batch_results = await asyncio.gather(*[simulate_dialogue(di, args=args) for di in batch])
all_dialogues.extend(batch_results)

make_parent_directories(args.output_file)
with open(args.output_file, "w") as output_file:
Expand Down Expand Up @@ -237,7 +238,13 @@ def main(args):
"--num_turns",
type=int,
required=True,
help="The number of turns in each dialog",
help="The number of turns in each dialogue",
)
parser.add_argument(
"--batch_size",
type=int,
default=10,
help="The number of dialogues to simulate in parallel",
)
parser.add_argument(
"--language",
Expand All @@ -259,4 +266,4 @@ def main(args):
level=logging.INFO, format=" %(name)s : %(levelname)-8s : %(message)s"
)

main(args)
asyncio.run(main(args))
2 changes: 1 addition & 1 deletion pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def extract_year(title, content):
content = title + " | " + content
years = []
year_pattern = r"\d{4}"
year_duration_pattern = r"\b\d{4}[--–]\d{2}\b"
year_duration_pattern = r"\b\d{4}[-–]\d{2}\b"
year_to_pattern = r"\b\d{4} to \d{4}\b"
# extract "1990 to 1998" before spacy because spacy would split it to 1990 and 1998
re_year_tos = re.findall(year_to_pattern, content)
Expand Down
Binary file added public/demo video.mp4
Binary file not shown.

0 comments on commit 4cd4eae

Please sign in to comment.