diff --git a/README.md b/README.md index f1c3a59..1436c02 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,12 @@

+ + +https://github.com/user-attachments/assets/982e8733-f7a7-468d-940c-5c96f411f527 + + + @@ -325,4 +331,4 @@ If you have used code or data from this repository, please cite the following pa url = "https://aclanthology.org/2024.findings-acl.96", pages = "1663--1678", } -``` \ No newline at end of file +``` diff --git a/benchmark/user_simulator.py b/benchmark/user_simulator.py index 6e6a2dd..9ecce17 100644 --- a/benchmark/user_simulator.py +++ b/benchmark/user_simulator.py @@ -59,9 +59,9 @@ def user_simulation_chain(user_engine: str, user_temperature: float, language: s ) -async def simulate_dialog(dialogue_inputs, args) -> list[DialogueTurn]: +async def simulate_dialogue(dialogue_inputs, args) -> list[DialogueTurn]: """ - Simulate one dialog + Simulate one dialogue """ user_character = random.choice(user_characteristics) chatbot, dialogue_state = create_chain(args) @@ -128,7 +128,7 @@ def repeat_dialogue_inputs(dialogue_inputs, target_num_dialogues): return dialogue_inputs -def main(args): +async def main(args): topics = [] if args.mode == "topic": dialogue_inputs = [] @@ -162,9 +162,10 @@ def main(args): raise ValueError("Unknown mode: %s" % args.mode) all_dialogues = [] - for di in tqdm(dialogue_inputs, desc="Dialogues"): - dialogue_state = asyncio.run(simulate_dialog(di, args=args)) - all_dialogues.append(dialogue_state) + for i in tqdm(range(0, len(dialogue_inputs), args.batch_size), desc="Dialogue Batches"): + batch = dialogue_inputs[i:i+args.batch_size] + batch_results = await asyncio.gather(*[simulate_dialogue(di, args=args) for di in batch]) + all_dialogues.extend(batch_results) make_parent_directories(args.output_file) with open(args.output_file, "w") as output_file: @@ -237,7 +238,13 @@ def main(args): "--num_turns", type=int, required=True, - help="The number of turns in each dialog", + help="The number of turns in each dialogue", + ) + parser.add_argument( + "--batch_size", + type=int, + default=10, + help="The number of dialogues to simulate in parallel", ) parser.add_argument( "--language", @@ -259,4 +266,4 @@ def main(args): level=logging.INFO, format=" %(name)s : %(levelname)-8s : %(message)s" ) - main(args) + asyncio.run(main(args)) diff --git a/pipelines/utils.py b/pipelines/utils.py index 345b7dc..51b6251 100644 --- a/pipelines/utils.py +++ b/pipelines/utils.py @@ -415,7 +415,7 @@ def extract_year(title, content): content = title + " | " + content years = [] year_pattern = r"\d{4}" - year_duration_pattern = r"\b\d{4}[--–]\d{2}\b" + year_duration_pattern = r"\b\d{4}[-–]\d{2}\b" year_to_pattern = r"\b\d{4} to \d{4}\b" # extract "1990 to 1998" before spacy because spacy would split it to 1990 and 1998 re_year_tos = re.findall(year_to_pattern, content) diff --git a/public/demo video.mp4 b/public/demo video.mp4 new file mode 100644 index 0000000..b0ee985 Binary files /dev/null and b/public/demo video.mp4 differ