From 3c9d4a4b380e30a753501e2d724b657e9f02909a Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Mon, 30 Sep 2024 20:32:21 -0500 Subject: [PATCH 01/19] add mutlitthread and pipe to monitor active task --- backend/generator_func.py | 90 ++++++++-------- backend/main.py | 215 +++++++++++++++++++++++++------------- 2 files changed, 188 insertions(+), 117 deletions(-) diff --git a/backend/generator_func.py b/backend/generator_func.py index 3c272565..5ad20843 100644 --- a/backend/generator_func.py +++ b/backend/generator_func.py @@ -12,46 +12,52 @@ def run_engine( url: str, ) -> Generator[Tuple[Optional[Progress], Optional[Agent]], None, None]: - intro = get_paper_introduction(url) - - if not intro: + try: + intro = get_paper_introduction(url) + if not intro: + yield None, None + return + + config_file_path = '../configs' + profile_file_path = '../examples/profiles' + paper_file_path = '../examples/papers' + + config = Config(config_file_path) + profile_db = ProfileDB() + paper_db = PaperDB() + + if os.path.exists(paper_file_path) and os.path.exists(profile_file_path): + profile_db.load_from_json(profile_file_path, with_embed=True) + paper_db.load_from_json(paper_file_path, with_embed=True) + else: + raise FileNotFoundError('Profile and paper databases not found.') + + log_db = LogDB() + progress_db = ProgressDB() + + engine = Engine( + project_name='research_town_demo', + profile_db=profile_db, + paper_db=paper_db, + progress_db=progress_db, + log_db=log_db, + config=config, + ) + + engine.start(contexts=[intro]) + + while engine.curr_env.name != 'end': + run_result = engine.curr_env.run() + + if run_result: + for progress, agent in run_result: + yield progress, agent + engine.time_step += 1 + + engine.transition() + except Exception as e: + print(f'Error occurred during engine execution: {e}') + + finally: + print('Engine execution completed.') yield None, None - return - - config_file_path = '../configs' - profile_file_path = '../examples/profiles' - paper_file_path = '../examples/papers' - - config = Config(config_file_path) - profile_db = ProfileDB() - paper_db = PaperDB() - - if os.path.exists(paper_file_path) and os.path.exists(profile_file_path): - profile_db.load_from_json(profile_file_path, with_embed=True) - paper_db.load_from_json(paper_file_path, with_embed=True) - else: - raise FileNotFoundError('Profile and paper databases not found.') - - log_db = LogDB() - progress_db = ProgressDB() - - engine = Engine( - project_name='research_town_demo', - profile_db=profile_db, - paper_db=paper_db, - progress_db=progress_db, - log_db=log_db, - config=config, - ) - - engine.start(contexts=[intro]) - - while engine.curr_env.name != 'end': - run_result = engine.curr_env.run() - - if run_result: - for progress, agent in run_result: - yield progress, agent - engine.time_step += 1 - - engine.transition() diff --git a/backend/main.py b/backend/main.py index c4acdaab..0dba0da6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,7 @@ +import asyncio import json +import multiprocessing +import uuid from typing import Generator, Optional, Tuple from fastapi import FastAPI, Request, Response @@ -28,8 +31,105 @@ allow_headers=['*'], ) +active_processes = {} -@app.post('/process') # type: ignore + +def stop_process(user_id): + if user_id in active_processes: + process = active_processes[user_id] + process.terminate() # Safely terminate the process + process.join() # Ensure cleanup + del active_processes[user_id] + print(f'Process for user {user_id} 
stopped.') + + +def background_task(url: str, child_conn): + generator = run_engine(url) + try: + for progress, agent in generator: + child_conn.send((progress, agent)) + except Exception as e: + child_conn.send({'type': 'error', 'content': str(e)}) + finally: + child_conn.send(None) + child_conn.close() + + +def format_response( + generator: Generator[Tuple[Optional[Progress], Optional[Agent]], None, None], +) -> Generator[str, None, None]: + for progress, agent in generator: + item = {} + if progress is None or agent is None: + item = { + 'type': 'error', + 'content': 'Failed to collect complete paper content from the link.', + } + elif isinstance(progress, Insight): + item = {'type': 'insight', 'content': progress.content} + elif isinstance(progress, Idea): + item = {'type': 'idea', 'content': progress.content} + elif isinstance(progress, Proposal): + item = { + 'type': 'proposal', + 'q1': progress.q1 or '', + 'q2': progress.q2 or '', + 'q3': progress.q3 or '', + 'q4': progress.q4 or '', + 'q5': progress.q5 or '', + } + elif isinstance(progress, Review): + item = { + 'type': 'review', + 'summary': progress.summary or '', + 'strength': progress.strength or '', + 'weakness': progress.weakness or '', + 'ethical_concerns': progress.ethical_concerns or '', + 'score': str(progress.score) if progress.score else '-1', + } + elif isinstance(progress, Rebuttal): + item = { + 'type': 'rebuttal', + 'q1': progress.q1 or '', + 'q2': progress.q2 or '', + 'q3': progress.q3 or '', + 'q4': progress.q4 or '', + 'q5': progress.q5 or '', + } + elif isinstance(progress, MetaReview): + item = { + 'type': 'metareview', + 'summary': progress.summary or '', + 'strength': progress.strength or '', + 'weakness': progress.weakness or '', + 'ethical_concerns': progress.ethical_concerns or '', + 'decision': 'accept' if progress.decision else 'reject', + } + else: + item = {'type': 'error', 'content': 'Unrecognized progress type'} + + if agent: + item['agent_name'] = agent.profile.name + if agent.profile.domain is not None: + if len(agent.profile.domain) > 1: + item['agent_domain'] = agent.profile.domain[0].lower() + else: + item['agent_domain'] = 'computer science' + + if agent.role == 'chair': + item['agent_role'] = 'chair' + elif agent.role == 'reviewer': + item['agent_role'] = 'reviewer' + elif agent.role == 'leader': + item['agent_role'] = 'leader' + elif agent.role == 'member': + item['agent_role'] = 'member' + else: + item['agent_role'] = 'none' + yield json.dumps(item) + '\n' + + +@app.post('/process') # This remains unchanged as per your request async def process_url(request: Request) -> Response: # Get URL from the request body data = await request.json() @@ -39,80 +139,45 @@ async def process_url(request: Request) -> Response: if not url: return JSONResponse({'error': 'URL is required'}, status_code=400) - # Helper function to process the generator output - def format_response( - generator: Generator[Tuple[Optional[Progress], Optional[Agent]], None, None], - ) -> Generator[str, None, None]: - for progress, agent in generator: - item = {} - if progress is None or agent is None: - item = { - 'type': 'error', - 'content': 'Failed to collect complete paper content from the link.', - } - elif isinstance(progress, Insight): - item = {'type': 'insight', 'content': progress.content} - elif isinstance(progress, Idea): - item = {'type': 'idea', 'content': progress.content} - elif isinstance(progress, Proposal): - item = { - 'type': 'proposal', - 'q1': progress.q1 or '', - 'q2': progress.q2 or '', - 'q3': progress.q3 
or '', - 'q4': progress.q4 or '', - 'q5': progress.q5 or '', - } - elif isinstance(progress, Review): - item = { - 'type': 'review', - 'summary': progress.summary or '', - 'strength': progress.strength or '', - 'weakness': progress.weakness or '', - 'ethical_concerns': progress.ethical_concerns or '', - 'score': str(progress.score) if progress.score else '-1', - } - elif isinstance(progress, Rebuttal): - item = { - 'type': 'rebuttal', - 'q1': progress.q1 or '', - 'q2': progress.q2 or '', - 'q3': progress.q3 or '', - 'q4': progress.q4 or '', - 'q5': progress.q5 or '', - } - elif isinstance(progress, MetaReview): - item = { - 'type': 'metareview', - 'summary': progress.summary or '', - 'strength': progress.strength or '', - 'weakness': progress.weakness or '', - 'ethical_concerns': progress.ethical_concerns or '', - 'decision': 'accept' if progress.decision else 'reject', - } - else: - item = {'type': 'error', 'content': 'Unrecognized progress type'} + # Generate a unique user ID for the task + user_id = str(uuid.uuid4()) - if agent: - item['agent_name'] = agent.profile.name - if agent.profile.domain is not None: - if len(agent.profile.domain) > 1: - item['agent_domain'] = agent.profile.domain[0].lower() - else: - item['agent_domain'] = 'computer science' - - if agent.role == 'chair': - item['agent_role'] = 'chair' - elif agent.role == 'reviewer': - item['agent_role'] = 'reviewer' - elif agent.role == 'leader': - item['agent_role'] = 'leader' - elif agent.role == 'member': - item['agent_role'] = 'member' + # Create a multiprocessing Pipe for communication + parent_conn, child_conn = multiprocessing.Pipe() + + # Start the background task as a separate process + process = multiprocessing.Process(target=background_task, args=(url, child_conn)) + process.start() + + # Store the process for later tracking + active_processes[user_id] = process + print(f'Task for user {user_id} started.') + + async def stream_response(): + try: + # Fetch results from the pipe and format them + while True: + if await request.is_disconnected(): + print(f'Client disconnected for user {user_id}. 
Cancelling task.') + stop_process(user_id) # Stop the background process + break + + if parent_conn.poll(): + result = parent_conn.recv() + if result is None: + break # End of data + + # Now pass the result into the format_response function + # Yield formatted response to client + print('reuslt') + print(result) + for formatted_output in format_response(iter([result])): + yield formatted_output else: - item['agent_role'] = 'none' - yield json.dumps(item) + '\n' + await asyncio.sleep(0.1) # Avoid busy-waiting + finally: + print('TRUE') + stop_process(user_id) # Ensure process is stopped on function exit - # Run the engine and stream the results back - generator = run_engine(url) - return StreamingResponse(format_response(generator), media_type='application/json') + # Return the StreamingResponse + return StreamingResponse(stream_response(), media_type='application/json') From 91397ae18a1a93da9f005722e879e68e7d1809af Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:53:23 -0500 Subject: [PATCH 02/19] fix def format --- backend/main.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0dba0da6..9ba4b443 100644 --- a/backend/main.py +++ b/backend/main.py @@ -31,10 +31,10 @@ allow_headers=['*'], ) -active_processes = {} +active_processes: dict[str, multiprocessing.Process] = {} -def stop_process(user_id): +def stop_process(user_id: str) -> None: if user_id in active_processes: process = active_processes[user_id] process.terminate() # Safely terminate the process @@ -43,7 +43,9 @@ def stop_process(user_id): print(f'Process for user {user_id} stopped.') -def background_task(url: str, child_conn): +def background_task( + url: str, child_conn: multiprocessing.connection.Connection +) -> None: generator = run_engine(url) try: for progress, agent in generator: @@ -129,7 +131,7 @@ def format_response( yield json.dumps(item) + '\n' -@app.post('/process') # This remains unchanged as per your request +@app.post('/process') async def process_url(request: Request) -> Response: # Get URL from the request body data = await request.json() @@ -159,22 +161,18 @@ async def stream_response(): while True: if await request.is_disconnected(): print(f'Client disconnected for user {user_id}. 
Cancelling task.') - stop_process(user_id) # Stop the background process + stop_process(user_id) break if parent_conn.poll(): result = parent_conn.recv() if result is None: - break # End of data + break - # Now pass the result into the format_response function - # Yield formatted response to client - print('reuslt') - print(result) for formatted_output in format_response(iter([result])): yield formatted_output else: - await asyncio.sleep(0.1) # Avoid busy-waiting + await asyncio.sleep(0.1) finally: print('TRUE') stop_process(user_id) # Ensure process is stopped on function exit From 35b2dee52e55b589b21604f5ab4b0afd6ff84bd2 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:18:16 -0500 Subject: [PATCH 03/19] Merge branch --- backend/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/backend/main.py b/backend/main.py index 9ba4b443..caaaf4f4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -4,7 +4,7 @@ import uuid from typing import Generator, Optional, Tuple -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from generator_func import run_engine @@ -132,7 +132,7 @@ def format_response( @app.post('/process') -async def process_url(request: Request) -> Response: +async def process_url(request: Request) -> StreamingResponse: # Get URL from the request body data = await request.json() url = data.get('url') @@ -155,7 +155,7 @@ async def process_url(request: Request) -> Response: active_processes[user_id] = process print(f'Task for user {user_id} started.') - async def stream_response(): + async def stream_response() -> Generator[str, None, None]: try: # Fetch results from the pipe and format them while True: @@ -174,7 +174,6 @@ async def stream_response(): else: await asyncio.sleep(0.1) finally: - print('TRUE') stop_process(user_id) # Ensure process is stopped on function exit # Return the StreamingResponse From 3b8eaf8eabb314a0676b9902f5de9ddd01470983 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:23:08 -0500 Subject: [PATCH 04/19] format fix --- backend/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index caaaf4f4..0b08ece4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -2,7 +2,7 @@ import json import multiprocessing import uuid -from typing import Generator, Optional, Tuple +from typing import AsyncGenerator, Generator, Optional, Tuple from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -155,7 +155,7 @@ async def process_url(request: Request) -> StreamingResponse: active_processes[user_id] = process print(f'Task for user {user_id} started.') - async def stream_response() -> Generator[str, None, None]: + async def stream_response() -> AsyncGenerator[str, None]: try: # Fetch results from the pipe and format them while True: From af478c793f4dd0edfb5d510ef7ed451474d6cabf Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:36:13 -0500 Subject: [PATCH 05/19] merge --- backend/main.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index 0b08ece4..59b3b600 100644 --- a/backend/main.py +++ b/backend/main.py @@ -57,6 +57,12 @@ def background_task( child_conn.close() +def 
generator_wrapper( + result: Tuple[Optional[Progress], Optional[Agent]], +) -> Generator[Tuple[Optional[Progress], Optional[Agent]], None, None]: + yield result + + def format_response( generator: Generator[Tuple[Optional[Progress], Optional[Agent]], None, None], ) -> Generator[str, None, None]: @@ -169,7 +175,7 @@ async def stream_response() -> AsyncGenerator[str, None]: if result is None: break - for formatted_output in format_response(iter([result])): + for formatted_output in format_response(generator_wrapper(result)): yield formatted_output else: await asyncio.sleep(0.1) From 0efa9ee685d4fc201ca7a927ed45252a58711131 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:43:18 -0500 Subject: [PATCH 06/19] merge --- backend/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/main.py b/backend/main.py index 59b3b600..8fb82f4f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -137,7 +137,7 @@ def format_response( yield json.dumps(item) + '\n' -@app.post('/process') +@app.post('/process') # type: ignore async def process_url(request: Request) -> StreamingResponse: # Get URL from the request body data = await request.json() @@ -163,7 +163,6 @@ async def process_url(request: Request) -> StreamingResponse: async def stream_response() -> AsyncGenerator[str, None]: try: - # Fetch results from the pipe and format them while True: if await request.is_disconnected(): print(f'Client disconnected for user {user_id}. Cancelling task.') From 76de6928769c3fbbd6ed853fef1451fd7b96c752 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sat, 5 Oct 2024 12:49:27 -0500 Subject: [PATCH 07/19] save format prompt (#699) --- backend/generator_func.py | 3 +- backend/main.py | 36 +++++++++++++++++++++-- research_town/utils/agent_prompter.py | 16 +++++++++- research_town/utils/prompt_constructor.py | 14 ++++++++- 4 files changed, 63 insertions(+), 6 deletions(-) diff --git a/backend/generator_func.py b/backend/generator_func.py index 5ad20843..f0ab7883 100644 --- a/backend/generator_func.py +++ b/backend/generator_func.py @@ -59,5 +59,4 @@ def run_engine( print(f'Error occurred during engine execution: {e}') finally: - print('Engine execution completed.') - yield None, None + ('Engine execution completed.') diff --git a/backend/main.py b/backend/main.py index 8fb82f4f..4704ab99 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,6 +1,7 @@ import asyncio import json import multiprocessing +import os import uuid from typing import AsyncGenerator, Generator, Optional, Tuple @@ -40,7 +41,6 @@ def stop_process(user_id: str) -> None: process.terminate() # Safely terminate the process process.join() # Ensure cleanup del active_processes[user_id] - print(f'Process for user {user_id} stopped.') def background_task( @@ -48,12 +48,25 @@ def background_task( ) -> None: generator = run_engine(url) try: + # Generate and send results to the parent process for progress, agent in generator: child_conn.send((progress, agent)) + + print( + 'Generation complete. Background task is now idle and waiting for manual termination.' 
+ ) + while True: + if child_conn.poll(): + msg = child_conn.recv() + if msg == 'terminate': + break + except Exception as e: child_conn.send({'type': 'error', 'content': str(e)}) + finally: child_conn.send(None) + print('Finish Generation') child_conn.close() @@ -63,6 +76,17 @@ def generator_wrapper( yield result +def clean_prompt_data() -> None: + directory = os.path.join('..', 'data', 'prompt_data') + print(f'Cleaning prompt data in {directory}') + if os.path.exists(directory): + for filename in os.listdir(directory): + file_path = os.path.join(directory, filename) + os.remove(file_path) + else: + os.makedirs(directory) + + def format_response( generator: Generator[Tuple[Optional[Progress], Optional[Agent]], None, None], ) -> Generator[str, None, None]: @@ -139,6 +163,8 @@ def format_response( @app.post('/process') # type: ignore async def process_url(request: Request) -> StreamingResponse: + clean_prompt_data() + # Get URL from the request body data = await request.json() url = data.get('url') @@ -164,22 +190,28 @@ async def process_url(request: Request) -> StreamingResponse: async def stream_response() -> AsyncGenerator[str, None]: try: while True: + # Check if the client has disconnected if await request.is_disconnected(): print(f'Client disconnected for user {user_id}. Cancelling task.') stop_process(user_id) break + # Check for new data from the background task if parent_conn.poll(): result = parent_conn.recv() if result is None: + print(f'No more data for user {user_id}. Stopping task.') break + # Stream the formatted output to the client for formatted_output in format_response(generator_wrapper(result)): yield formatted_output else: await asyncio.sleep(0.1) finally: - stop_process(user_id) # Ensure process is stopped on function exit + if await request.is_disconnected(): + stop_process(user_id) + print(f'Process for user {user_id} stopped due to disconnection.') # Return the StreamingResponse return StreamingResponse(stream_response(), media_type='application/json') diff --git a/research_town/utils/agent_prompter.py b/research_town/utils/agent_prompter.py index a1e9f666..53edf1c9 100644 --- a/research_town/utils/agent_prompter.py +++ b/research_town/utils/agent_prompter.py @@ -4,7 +4,7 @@ from beartype.typing import Dict, List, Optional, Tuple, Union from .model_prompting import model_prompting -from .prompt_constructor import openai_format_prompt_construct +from .prompt_constructor import openai_format_prompt_construct, save_prompt_to_json from .string_mapper import ( map_idea_list_to_str, map_idea_to_str, @@ -36,6 +36,7 @@ def review_literature_prompting( 'papers': papers_str, } messages = openai_format_prompt_construct(prompt_template, template_input) + save_prompt_to_json('review_literature', messages[0]['content']) insight = model_prompting( model_name, @@ -84,6 +85,7 @@ def brainstorm_idea_prompting( papers_str = map_paper_list_to_str(papers) template_input = {'bio': bio, 'insights': insights_str, 'papers': papers_str} messages = openai_format_prompt_construct(prompt_template, template_input) + save_prompt_to_json('brainstorm_idea', messages[0]['content']) return model_prompting( model_name, messages, @@ -111,6 +113,7 @@ def discuss_idea_prompting( ideas_str = map_idea_list_to_str(ideas) template_input = {'bio': bio, 'ideas': ideas_str, 'contexts': contexts} messages = openai_format_prompt_construct(prompt_template, template_input) + save_prompt_to_json('discuss_idea', messages[0]['content']) return model_prompting( model_name, messages, @@ -138,6 +141,7 @@ def 
write_proposal_prompting( papers_str = map_paper_list_to_str(papers) template_input = {'idea': idea_str, 'papers': papers_str} messages = openai_format_prompt_construct(prompt_template, template_input) + save_prompt_to_json('write_proposal', messages[0]['content']) proposal = model_prompting( model_name, messages, @@ -194,14 +198,17 @@ def write_review_prompting( strength_messages = openai_format_prompt_construct( strength_prompt_template, strength_template_input ) + save_prompt_to_json('review_strength', strength_messages[0]['content']) weakness_template_input = {'proposal': proposal_str, 'summary': summary} weakness_messages = openai_format_prompt_construct( weakness_prompt_template, weakness_template_input ) + save_prompt_to_json('review_weakness', weakness_messages[0]['content']) ethical_template_input = {'proposal': proposal_str, 'summary': summary} ethical_messages = openai_format_prompt_construct( ethical_prompt_template, ethical_template_input ) + save_prompt_to_json('review_ethical', ethical_messages[0]['content']) strength = model_prompting( model_name, @@ -241,6 +248,7 @@ def write_review_prompting( score_messages = openai_format_prompt_construct( score_prompt_template, score_template_input ) + save_prompt_to_json('review_score', score_messages[0]['content']) score_str = ( model_prompting( model_name, @@ -286,6 +294,7 @@ def write_metareview_prompting( summary_messages = openai_format_prompt_construct( summary_prompt_template, summary_template_input ) + save_prompt_to_json('metareview_summary', summary_messages[0]['content']) summary = model_prompting( model_name, summary_messages, @@ -314,12 +323,15 @@ def write_metareview_prompting( strength_messages = openai_format_prompt_construct( strength_prompt_template, strength_template_input ) + save_prompt_to_json('metareview_strength', strength_messages[0]['content']) weakness_messages = openai_format_prompt_construct( weakness_prompt_template, weakness_template_input ) + save_prompt_to_json('metareview_weakness', weakness_messages[0]['content']) ethical_messages = openai_format_prompt_construct( ethical_prompt_template, ethical_template_input ) + save_prompt_to_json('metareview_ethical', ethical_messages[0]['content']) strength = model_prompting( model_name, @@ -360,6 +372,7 @@ def write_metareview_prompting( decision_messages = openai_format_prompt_construct( decision_prompt_template, decision_template_input ) + save_prompt_to_json('metareview_decision', decision_messages[0]['content']) decision_str = model_prompting( model_name, decision_messages, @@ -390,6 +403,7 @@ def write_rebuttal_prompting( review_str = map_review_to_str(review) template_input = {'proposal': proposal_str, 'review': review_str} messages = openai_format_prompt_construct(prompt_template, template_input) + save_prompt_to_json('write_rebuttal', messages[0]['content']) rebuttal = model_prompting( model_name, messages, diff --git a/research_town/utils/prompt_constructor.py b/research_town/utils/prompt_constructor.py index 37d97549..6fb78f8b 100644 --- a/research_town/utils/prompt_constructor.py +++ b/research_town/utils/prompt_constructor.py @@ -1,3 +1,5 @@ +import json +import os from typing import Any, Dict, List, Union @@ -5,7 +7,6 @@ def openai_format_prompt_construct( template: Dict[str, Union[str, List[str]]], input_data: Dict[str, Any] ) -> List[Dict[str, str]]: messages = [] - if 'sys_prompt' in template: sys_prompt = template['sys_prompt'] assert isinstance(sys_prompt, str) @@ -26,3 +27,14 @@ def openai_format_prompt_construct( messages.append({'role': 
'user', 'content': query}) return messages + + +def save_prompt_to_json( + template_name: str, messages: List[Dict[str, Union[str, List[str]]]] +) -> None: + file_name = f'{template_name}.log' + directory_path = os.path.join('..', 'data', 'prompt_data') + file_path = os.path.join(directory_path, file_name) + + with open(file_path, 'w') as log_file: + json.dump(messages, log_file, indent=4) From 4c707768526c728bb6242a1fdcaf3f7cd05c9273 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sat, 5 Oct 2024 12:54:45 -0500 Subject: [PATCH 08/19] save format prompt --- backend/main.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/backend/main.py b/backend/main.py index 4704ab99..e01eda4f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -52,9 +52,6 @@ def background_task( for progress, agent in generator: child_conn.send((progress, agent)) - print( - 'Generation complete. Background task is now idle and waiting for manual termination.' - ) while True: if child_conn.poll(): msg = child_conn.recv() From 79b6618ce703c13b6e89b28ebe2df645c4a598b3 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sat, 5 Oct 2024 13:12:43 -0500 Subject: [PATCH 09/19] make sure format prompt dir exists --- backend/main.py | 5 ++--- research_town/utils/prompt_constructor.py | 7 +++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/backend/main.py b/backend/main.py index e01eda4f..1739ad74 100644 --- a/backend/main.py +++ b/backend/main.py @@ -63,7 +63,6 @@ def background_task( finally: child_conn.send(None) - print('Finish Generation') child_conn.close() @@ -75,11 +74,11 @@ def generator_wrapper( def clean_prompt_data() -> None: directory = os.path.join('..', 'data', 'prompt_data') - print(f'Cleaning prompt data in {directory}') if os.path.exists(directory): for filename in os.listdir(directory): file_path = os.path.join(directory, filename) - os.remove(file_path) + with open(file_path, 'w'): + pass else: os.makedirs(directory) diff --git a/research_town/utils/prompt_constructor.py b/research_town/utils/prompt_constructor.py index 6fb78f8b..484e4f36 100644 --- a/research_town/utils/prompt_constructor.py +++ b/research_town/utils/prompt_constructor.py @@ -1,4 +1,3 @@ -import json import os from typing import Any, Dict, List, Union @@ -34,7 +33,11 @@ def save_prompt_to_json( ) -> None: file_name = f'{template_name}.log' directory_path = os.path.join('..', 'data', 'prompt_data') + + if not os.path.exists(directory_path): + os.makedirs(directory_path) + file_path = os.path.join(directory_path, file_name) with open(file_path, 'w') as log_file: - json.dump(messages, log_file, indent=4) + log_file.write(str(messages)) From ca0b4383c9cf8508462910b8c4988d3a53298c18 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sat, 5 Oct 2024 13:21:03 -0500 Subject: [PATCH 10/19] minor fix --- research_town/utils/agent_prompter.py | 30 +++++++++++------------ research_town/utils/prompt_constructor.py | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/research_town/utils/agent_prompter.py b/research_town/utils/agent_prompter.py index 53edf1c9..7b401fab 100644 --- a/research_town/utils/agent_prompter.py +++ b/research_town/utils/agent_prompter.py @@ -4,7 +4,7 @@ from beartype.typing import Dict, List, Optional, Tuple, Union from .model_prompting import model_prompting -from .prompt_constructor import openai_format_prompt_construct, save_prompt_to_json +from .prompt_constructor 
import openai_format_prompt_construct, save_prompt_to_log from .string_mapper import ( map_idea_list_to_str, map_idea_to_str, @@ -36,7 +36,7 @@ def review_literature_prompting( 'papers': papers_str, } messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_json('review_literature', messages[0]['content']) + save_prompt_to_log('review_literature', messages) insight = model_prompting( model_name, @@ -85,7 +85,7 @@ def brainstorm_idea_prompting( papers_str = map_paper_list_to_str(papers) template_input = {'bio': bio, 'insights': insights_str, 'papers': papers_str} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_json('brainstorm_idea', messages[0]['content']) + save_prompt_to_log('brainstorm_idea', messages) return model_prompting( model_name, messages, @@ -113,7 +113,7 @@ def discuss_idea_prompting( ideas_str = map_idea_list_to_str(ideas) template_input = {'bio': bio, 'ideas': ideas_str, 'contexts': contexts} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_json('discuss_idea', messages[0]['content']) + save_prompt_to_log('discuss_idea', messages) return model_prompting( model_name, messages, @@ -141,7 +141,7 @@ def write_proposal_prompting( papers_str = map_paper_list_to_str(papers) template_input = {'idea': idea_str, 'papers': papers_str} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_json('write_proposal', messages[0]['content']) + save_prompt_to_log('write_proposal', messages) proposal = model_prompting( model_name, messages, @@ -198,17 +198,17 @@ def write_review_prompting( strength_messages = openai_format_prompt_construct( strength_prompt_template, strength_template_input ) - save_prompt_to_json('review_strength', strength_messages[0]['content']) + save_prompt_to_log('review_strength', strength_messages) weakness_template_input = {'proposal': proposal_str, 'summary': summary} weakness_messages = openai_format_prompt_construct( weakness_prompt_template, weakness_template_input ) - save_prompt_to_json('review_weakness', weakness_messages[0]['content']) + save_prompt_to_log('review_weakness', weakness_messages) ethical_template_input = {'proposal': proposal_str, 'summary': summary} ethical_messages = openai_format_prompt_construct( ethical_prompt_template, ethical_template_input ) - save_prompt_to_json('review_ethical', ethical_messages[0]['content']) + save_prompt_to_log('review_ethical', ethical_messages) strength = model_prompting( model_name, @@ -248,7 +248,7 @@ def write_review_prompting( score_messages = openai_format_prompt_construct( score_prompt_template, score_template_input ) - save_prompt_to_json('review_score', score_messages[0]['content']) + save_prompt_to_log('review_score', score_messages) score_str = ( model_prompting( model_name, @@ -294,7 +294,7 @@ def write_metareview_prompting( summary_messages = openai_format_prompt_construct( summary_prompt_template, summary_template_input ) - save_prompt_to_json('metareview_summary', summary_messages[0]['content']) + save_prompt_to_log('metareview_summary', summary_messages) summary = model_prompting( model_name, summary_messages, @@ -323,15 +323,15 @@ def write_metareview_prompting( strength_messages = openai_format_prompt_construct( strength_prompt_template, strength_template_input ) - save_prompt_to_json('metareview_strength', strength_messages[0]['content']) + save_prompt_to_log('metareview_strength', strength_messages) weakness_messages = 
openai_format_prompt_construct( weakness_prompt_template, weakness_template_input ) - save_prompt_to_json('metareview_weakness', weakness_messages[0]['content']) + save_prompt_to_log('metareview_weakness', weakness_messages) ethical_messages = openai_format_prompt_construct( ethical_prompt_template, ethical_template_input ) - save_prompt_to_json('metareview_ethical', ethical_messages[0]['content']) + save_prompt_to_log('metareview_ethical', ethical_messages) strength = model_prompting( model_name, @@ -372,7 +372,7 @@ def write_metareview_prompting( decision_messages = openai_format_prompt_construct( decision_prompt_template, decision_template_input ) - save_prompt_to_json('metareview_decision', decision_messages[0]['content']) + save_prompt_to_log('metareview_decision', decision_messages) decision_str = model_prompting( model_name, decision_messages, @@ -403,7 +403,7 @@ def write_rebuttal_prompting( review_str = map_review_to_str(review) template_input = {'proposal': proposal_str, 'review': review_str} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_json('write_rebuttal', messages[0]['content']) + save_prompt_to_log('write_rebuttal', messages) rebuttal = model_prompting( model_name, messages, diff --git a/research_town/utils/prompt_constructor.py b/research_town/utils/prompt_constructor.py index 484e4f36..fbd1c656 100644 --- a/research_town/utils/prompt_constructor.py +++ b/research_town/utils/prompt_constructor.py @@ -28,7 +28,7 @@ def openai_format_prompt_construct( return messages -def save_prompt_to_json( +def save_prompt_to_log( template_name: str, messages: List[Dict[str, Union[str, List[str]]]] ) -> None: file_name = f'{template_name}.log' From 5b39d8e4085b135f59d4fd4842aa252342ac6cf2 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sat, 5 Oct 2024 13:29:35 -0500 Subject: [PATCH 11/19] minor fix --- backend/main.py | 3 +-- research_town/utils/prompt_constructor.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/main.py b/backend/main.py index 1739ad74..5811db48 100644 --- a/backend/main.py +++ b/backend/main.py @@ -77,8 +77,7 @@ def clean_prompt_data() -> None: if os.path.exists(directory): for filename in os.listdir(directory): file_path = os.path.join(directory, filename) - with open(file_path, 'w'): - pass + os.remove(file_path) else: os.makedirs(directory) diff --git a/research_town/utils/prompt_constructor.py b/research_town/utils/prompt_constructor.py index fbd1c656..3fb11bcb 100644 --- a/research_town/utils/prompt_constructor.py +++ b/research_town/utils/prompt_constructor.py @@ -29,7 +29,7 @@ def openai_format_prompt_construct( def save_prompt_to_log( - template_name: str, messages: List[Dict[str, Union[str, List[str]]]] + template_name: str, messages: list[dict[str, str | list[str]]] ) -> None: file_name = f'{template_name}.log' directory_path = os.path.join('..', 'data', 'prompt_data') From ce1b6c2cf6bad00020c69126f67ce01df4c71f9f Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sat, 5 Oct 2024 13:33:53 -0500 Subject: [PATCH 12/19] minor fix --- research_town/utils/prompt_constructor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/research_town/utils/prompt_constructor.py b/research_town/utils/prompt_constructor.py index 3fb11bcb..5159ef93 100644 --- a/research_town/utils/prompt_constructor.py +++ b/research_town/utils/prompt_constructor.py @@ -28,9 +28,7 @@ def openai_format_prompt_construct( 
return messages -def save_prompt_to_log( - template_name: str, messages: list[dict[str, str | list[str]]] -) -> None: +def save_prompt_to_log(template_name: str, messages: list[dict[str, str]]) -> None: file_name = f'{template_name}.log' directory_path = os.path.join('..', 'data', 'prompt_data') From 3530c959c8b1619c50fa936504848c5b71a72dcd Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sun, 6 Oct 2024 20:07:13 -0500 Subject: [PATCH 13/19] store prompt to corresponding log --- backend/main.py | 13 --- research_town/agents/agent.py | 119 +++++++++++++++++----- research_town/agents/agent_manager.py | 3 +- research_town/data/data.py | 5 +- research_town/utils/agent_prompter.py | 64 ++++++------ research_town/utils/prompt_constructor.py | 14 --- 6 files changed, 133 insertions(+), 85 deletions(-) diff --git a/backend/main.py b/backend/main.py index 5811db48..900eea32 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,7 +1,6 @@ import asyncio import json import multiprocessing -import os import uuid from typing import AsyncGenerator, Generator, Optional, Tuple @@ -72,16 +71,6 @@ def generator_wrapper( yield result -def clean_prompt_data() -> None: - directory = os.path.join('..', 'data', 'prompt_data') - if os.path.exists(directory): - for filename in os.listdir(directory): - file_path = os.path.join(directory, filename) - os.remove(file_path) - else: - os.makedirs(directory) - - def format_response( generator: Generator[Tuple[Optional[Progress], Optional[Agent]], None, None], ) -> Generator[str, None, None]: @@ -158,8 +147,6 @@ def format_response( @app.post('/process') # type: ignore async def process_url(request: Request) -> StreamingResponse: - clean_prompt_data() - # Get URL from the request body data = await request.json() url = data.get('url') diff --git a/research_town/agents/agent.py b/research_town/agents/agent.py index 171c5511..c3894317 100644 --- a/research_town/agents/agent.py +++ b/research_town/agents/agent.py @@ -2,7 +2,23 @@ from beartype.typing import Dict, List, Literal, Tuple from ..configs import Config -from ..data import Idea, Insight, MetaReview, Paper, Profile, Proposal, Rebuttal, Review +from ..data import ( + Idea, + IdeaBrainstormLog, + Insight, + LiteratureReviewLog, + MetaReview, + MetaReviewWritingLog, + Paper, + Profile, + Proposal, + ProposalWritingLog, + Rebuttal, + RebuttalWritingLog, + Review, + ReviewWritingLog, +) +from ..dbs import LogDB from ..utils.agent_prompter import ( brainstorm_idea_prompting, discuss_idea_prompting, @@ -28,12 +44,14 @@ def __init__( self, profile: Profile, model_name: str, + log_db: LogDB, role: Role = None, ) -> None: self.profile: Profile = profile self.memory: Dict[str, str] = {} self.role: Role = role self.model_name: str = model_name + self.log_db: LogDB = log_db self.serializer = Serializer() @beartype @@ -50,7 +68,7 @@ def review_literature( ) -> Tuple[str, List[str], Insight]: serialized_papers = self.serializer.serialize(papers) serialized_profile = self.serializer.serialize(self.profile) - summary, keywords, valuable_points = review_literature_prompting( + summary, keywords, valuable_points, prompt = review_literature_prompting( profile=serialized_profile, papers=serialized_papers, contexts=contexts, @@ -63,6 +81,12 @@ def review_literature( stream=config.param.stream, ) insight = Insight(content=valuable_points) + log_entry = LiteratureReviewLog( + profile_pk=self.profile.pk, + insight_pk=insight.pk, + prompt=prompt, # Store the formatted prompt + ) + 
self.log_db.add(log_entry) return summary, keywords, insight @beartype @@ -72,7 +96,7 @@ def brainstorm_idea( ) -> Idea: serialized_insights = self.serializer.serialize(insights) serialized_papers = self.serializer.serialize(papers) - idea_content = brainstorm_idea_prompting( + idea_content, prompt = brainstorm_idea_prompting( bio=self.profile.bio, insights=serialized_insights, papers=serialized_papers, @@ -83,8 +107,15 @@ def brainstorm_idea( temperature=config.param.temperature, top_p=config.param.top_p, stream=config.param.stream, - )[0] - return Idea(content=idea_content) + ) + idea_content = idea_content[0] + idea = Idea(content=idea_content) + + log_entry = IdeaBrainstormLog( + profile_pk=self.profile.pk, idea_pk=idea.pk, prompt=prompt + ) + self.log_db.add(log_entry) + return idea @beartype @member_required @@ -92,7 +123,7 @@ def discuss_idea( self, ideas: List[Idea], contexts: List[str], config: Config ) -> Idea: serialized_ideas = self.serializer.serialize(ideas) - idea_summarized = discuss_idea_prompting( + idea_summarized, prompt = discuss_idea_prompting( bio=self.profile.bio, contexts=contexts, ideas=serialized_ideas, @@ -103,7 +134,8 @@ def discuss_idea( temperature=config.param.temperature, top_p=config.param.top_p, stream=config.param.stream, - )[0] + ) + idea_summarized = idea_summarized[0] return Idea(content=idea_summarized) @beartype @@ -127,7 +159,7 @@ def write_proposal( print('write_proposal_strategy not supported, will use default') prompt_template = config.agent_prompt_template.write_proposal - proposal, q5_result = write_proposal_prompting( + proposal, q5_result, prompt = write_proposal_prompting( idea=serialized_idea, papers=serialized_papers, model_name=self.model_name, @@ -138,7 +170,8 @@ def write_proposal( top_p=config.param.top_p, stream=config.param.stream, ) - return Proposal( + + proposal_obj = Proposal( content=proposal, q1=q5_result.get('q1', ''), q2=q5_result.get('q2', ''), @@ -147,26 +180,37 @@ def write_proposal( q5=q5_result.get('q5', ''), ) + log_entry = ProposalWritingLog( + profile_pk=self.profile.pk, + proposal_pk=proposal_obj.pk, + prompt=prompt, # Store the formatted prompt + ) + self.log_db.add(log_entry) # Add the log entry to LogDB + + return proposal_obj + @beartype @reviewer_required def write_review(self, proposal: Proposal, config: Config) -> Review: serialized_proposal = self.serializer.serialize(proposal) - summary, strength, weakness, ethical_concerns, score = write_review_prompting( - proposal=serialized_proposal, - model_name=self.model_name, - summary_prompt_template=config.agent_prompt_template.write_review_summary, - strength_prompt_template=config.agent_prompt_template.write_review_strength, - weakness_prompt_template=config.agent_prompt_template.write_review_weakness, - ethical_prompt_template=config.agent_prompt_template.write_review_ethical, - score_prompt_template=config.agent_prompt_template.write_review_score, - return_num=config.param.return_num, - max_token_num=config.param.max_token_num, - temperature=config.param.temperature, - top_p=config.param.top_p, - stream=config.param.stream, + summary, strength, weakness, ethical_concerns, score, prompt = ( + write_review_prompting( + proposal=serialized_proposal, + model_name=self.model_name, + summary_prompt_template=config.agent_prompt_template.write_review_summary, + strength_prompt_template=config.agent_prompt_template.write_review_strength, + weakness_prompt_template=config.agent_prompt_template.write_review_weakness, + 
ethical_prompt_template=config.agent_prompt_template.write_review_ethical, + score_prompt_template=config.agent_prompt_template.write_review_score, + return_num=config.param.return_num, + max_token_num=config.param.max_token_num, + temperature=config.param.temperature, + top_p=config.param.top_p, + stream=config.param.stream, + ) ) - return Review( + review_obj = Review( proposal_pk=proposal.pk, reviewer_pk=self.profile.pk, summary=summary, @@ -176,6 +220,13 @@ def write_review(self, proposal: Proposal, config: Config) -> Review: score=score, ) + log_entry = ReviewWritingLog( + profile_pk=self.profile.pk, review_pk=review_obj.pk, prompt=prompt + ) + self.log_db.add(log_entry) + + return review_obj + @beartype @chair_required def write_metareview( @@ -187,7 +238,7 @@ def write_metareview( serialized_proposal = self.serializer.serialize(proposal) serialized_reviews = self.serializer.serialize(reviews) - summary, strength, weakness, ethical_concerns, decision = ( + summary, strength, weakness, ethical_concerns, decision, prompt = ( write_metareview_prompting( proposal=serialized_proposal, reviews=serialized_reviews, @@ -205,7 +256,7 @@ def write_metareview( ) ) - return MetaReview( + metareview_obj = MetaReview( proposal_pk=proposal.pk, chair_pk=self.profile.pk, reviewer_pks=[review.reviewer_pk for review in reviews], @@ -217,6 +268,13 @@ def write_metareview( decision=decision, ) + log_entry = MetaReviewWritingLog( + profile_pk=self.profile.pk, metareview_pk=metareview_obj.pk, prompt=prompt + ) + self.log_db.add(log_entry) + + return metareview_obj + @beartype @leader_required def write_rebuttal( @@ -228,7 +286,7 @@ def write_rebuttal( serialized_proposal = self.serializer.serialize(proposal) serialized_review = self.serializer.serialize(review) - rebuttal_content, q5_result = write_rebuttal_prompting( + rebuttal_content, q5_result, prompt = write_rebuttal_prompting( proposal=serialized_proposal, review=serialized_review, model_name=self.model_name, @@ -240,7 +298,7 @@ def write_rebuttal( stream=config.param.stream, ) - return Rebuttal( + rebuttal_obj = Rebuttal( proposal_pk=proposal.pk, reviewer_pk=review.reviewer_pk, author_pk=self.profile.pk, @@ -251,3 +309,10 @@ def write_rebuttal( q4=q5_result.get('q4', ''), q5=q5_result.get('q5', ''), ) + + log_entry = RebuttalWritingLog( + profile_pk=self.profile.pk, rebuttal_pk=rebuttal_obj.pk, prompt=prompt + ) + self.log_db.add(log_entry) + + return rebuttal_obj diff --git a/research_town/agents/agent_manager.py b/research_town/agents/agent_manager.py index 6d69d4e5..3572a937 100644 --- a/research_town/agents/agent_manager.py +++ b/research_town/agents/agent_manager.py @@ -2,7 +2,7 @@ from ..configs import Config from ..data import Profile, Proposal -from ..dbs import ProfileDB +from ..dbs import LogDB, ProfileDB from .agent import Agent Role = Literal['reviewer', 'leader', 'member', 'chair'] @@ -17,6 +17,7 @@ def create_agent(self, profile: Profile, role: Role) -> Agent: return Agent( profile=profile, role=role, + log_db=LogDB(), model_name=self.config.param.base_llm, ) diff --git a/research_town/data/data.py b/research_town/data/data.py index a00f2852..5fc3a2e8 100644 --- a/research_town/data/data.py +++ b/research_town/data/data.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, ConfigDict, Field @@ -43,6 +43,9 @@ class Paper(Data): class Log(Data): timestep: int = Field(default=0) profile_pk: str + prompt: Optional[Union[List[Dict[str, 
str]], List[List[Dict[str, str]]]]] = Field( + default=None + ) class LiteratureReviewLog(Log): diff --git a/research_town/utils/agent_prompter.py b/research_town/utils/agent_prompter.py index 7b401fab..a899eb68 100644 --- a/research_town/utils/agent_prompter.py +++ b/research_town/utils/agent_prompter.py @@ -4,7 +4,7 @@ from beartype.typing import Dict, List, Optional, Tuple, Union from .model_prompting import model_prompting -from .prompt_constructor import openai_format_prompt_construct, save_prompt_to_log +from .prompt_constructor import openai_format_prompt_construct from .string_mapper import ( map_idea_list_to_str, map_idea_to_str, @@ -28,7 +28,7 @@ def review_literature_prompting( temperature: Optional[float] = 0.0, top_p: Optional[float] = None, stream: Optional[bool] = None, -) -> Tuple[str, List[str], str]: +) -> Tuple[str, List[str], str, list[dict[str, str]]]: papers_str = map_paper_list_to_str(papers) template_input = { 'bio': profile['bio'], @@ -36,7 +36,8 @@ def review_literature_prompting( 'papers': papers_str, } messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_log('review_literature', messages) + + formatted_prompt = messages insight = model_prompting( model_name, @@ -65,7 +66,7 @@ def review_literature_prompting( valuable_points = ( valuable_points_match.group(1).strip() if valuable_points_match else '' ) - return summary, keywords, valuable_points + return summary, keywords, valuable_points, formatted_prompt @beartype @@ -80,12 +81,12 @@ def brainstorm_idea_prompting( temperature: Optional[float] = 0.0, top_p: Optional[float] = None, stream: Optional[bool] = None, -) -> List[str]: +) -> Tuple[List[str], List[Dict[str, str]]]: insights_str = map_insight_list_to_str(insights) papers_str = map_paper_list_to_str(papers) template_input = {'bio': bio, 'insights': insights_str, 'papers': papers_str} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_log('brainstorm_idea', messages) + formatted_prompt = messages return model_prompting( model_name, messages, @@ -94,7 +95,7 @@ def brainstorm_idea_prompting( temperature=temperature, top_p=top_p, stream=stream, - ) + ), formatted_prompt @beartype @@ -109,11 +110,12 @@ def discuss_idea_prompting( temperature: Optional[float] = 0.0, top_p: Optional[float] = None, stream: Optional[bool] = None, -) -> List[str]: +) -> Tuple[List[str], List[Dict[str, str]]]: ideas_str = map_idea_list_to_str(ideas) template_input = {'bio': bio, 'ideas': ideas_str, 'contexts': contexts} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_log('discuss_idea', messages) + formatted_prompt = messages + return model_prompting( model_name, messages, @@ -122,7 +124,7 @@ def discuss_idea_prompting( temperature=temperature, top_p=top_p, stream=stream, - ) + ), formatted_prompt @beartype @@ -136,12 +138,13 @@ def write_proposal_prompting( temperature: Optional[float] = 0.0, top_p: Optional[float] = None, stream: Optional[bool] = None, -) -> Tuple[str, Dict[str, str]]: +) -> Tuple[str, Dict[str, str], List[Dict[str, str]]]: idea_str = map_idea_to_str(idea) papers_str = map_paper_list_to_str(papers) template_input = {'idea': idea_str, 'papers': papers_str} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_log('write_proposal', messages) + + formatted_prompt = messages proposal = model_prompting( model_name, messages, @@ -161,7 +164,7 @@ def write_proposal_prompting( answer = match[1].strip() 
q5_result[question_number] = answer - return proposal, q5_result + return proposal, q5_result, formatted_prompt @beartype @@ -178,7 +181,8 @@ def write_review_prompting( temperature: Optional[float] = 0.0, top_p: Optional[float] = None, stream: Optional[bool] = None, -) -> Tuple[str, str, str, str, int]: +) -> Tuple[str, str, str, str, int, List[List[Dict[str, str]]]]: + formatted_prompts: List[List[Dict[str, str]]] = [] proposal_str = map_proposal_to_str(proposal) summary_template_input = {'proposal': proposal_str} summary_messages = openai_format_prompt_construct( @@ -198,17 +202,17 @@ def write_review_prompting( strength_messages = openai_format_prompt_construct( strength_prompt_template, strength_template_input ) - save_prompt_to_log('review_strength', strength_messages) + formatted_prompts.append(strength_messages) weakness_template_input = {'proposal': proposal_str, 'summary': summary} weakness_messages = openai_format_prompt_construct( weakness_prompt_template, weakness_template_input ) - save_prompt_to_log('review_weakness', weakness_messages) + formatted_prompts.append(weakness_messages) ethical_template_input = {'proposal': proposal_str, 'summary': summary} ethical_messages = openai_format_prompt_construct( ethical_prompt_template, ethical_template_input ) - save_prompt_to_log('review_ethical', ethical_messages) + formatted_prompts.append(ethical_messages) strength = model_prompting( model_name, @@ -248,7 +252,8 @@ def write_review_prompting( score_messages = openai_format_prompt_construct( score_prompt_template, score_template_input ) - save_prompt_to_log('review_score', score_messages) + + formatted_prompts.append(score_messages) score_str = ( model_prompting( model_name, @@ -266,7 +271,7 @@ def write_review_prompting( ) score = int(score_str[0]) if score_str[0].isdigit() else 0 - return summary, strength, weakness, ethical_concerns, score + return summary, strength, weakness, ethical_concerns, score, formatted_prompts @beartype @@ -284,7 +289,8 @@ def write_metareview_prompting( temperature: Optional[float] = 0.0, top_p: Optional[float] = None, stream: Optional[bool] = None, -) -> Tuple[str, str, str, str, bool]: +) -> Tuple[str, str, str, str, bool, List[List[Dict[str, str]]]]: + formatted_prompts: List[List[Dict[str, str]]] = [] proposal_str = map_proposal_to_str(proposal) reviews_str = map_review_list_to_str(reviews) summary_template_input = { @@ -294,7 +300,7 @@ def write_metareview_prompting( summary_messages = openai_format_prompt_construct( summary_prompt_template, summary_template_input ) - save_prompt_to_log('metareview_summary', summary_messages) + formatted_prompts.append(summary_messages) summary = model_prompting( model_name, summary_messages, @@ -323,15 +329,15 @@ def write_metareview_prompting( strength_messages = openai_format_prompt_construct( strength_prompt_template, strength_template_input ) - save_prompt_to_log('metareview_strength', strength_messages) + formatted_prompts.append(strength_messages) weakness_messages = openai_format_prompt_construct( weakness_prompt_template, weakness_template_input ) - save_prompt_to_log('metareview_weakness', weakness_messages) + formatted_prompts.append(weakness_messages) ethical_messages = openai_format_prompt_construct( ethical_prompt_template, ethical_template_input ) - save_prompt_to_log('metareview_ethical', ethical_messages) + formatted_prompts.append(ethical_messages) strength = model_prompting( model_name, @@ -372,7 +378,7 @@ def write_metareview_prompting( decision_messages = openai_format_prompt_construct( 
decision_prompt_template, decision_template_input ) - save_prompt_to_log('metareview_decision', decision_messages) + formatted_prompts.append(decision_messages) decision_str = model_prompting( model_name, decision_messages, @@ -384,7 +390,7 @@ def write_metareview_prompting( ) decision = 'accept' in decision_str[0].lower() - return summary, strength, weakness, ethical_concerns, decision + return summary, strength, weakness, ethical_concerns, decision, formatted_prompts @beartype @@ -398,12 +404,12 @@ def write_rebuttal_prompting( temperature: Optional[float] = 0.0, top_p: Optional[float] = None, stream: Optional[bool] = None, -) -> Tuple[str, Dict[str, str]]: +) -> Tuple[str, Dict[str, str], List[Dict[str, str]]]: proposal_str = map_proposal_to_str(proposal) review_str = map_review_to_str(review) template_input = {'proposal': proposal_str, 'review': review_str} messages = openai_format_prompt_construct(prompt_template, template_input) - save_prompt_to_log('write_rebuttal', messages) + formatted_prompt = messages rebuttal = model_prompting( model_name, messages, @@ -423,4 +429,4 @@ def write_rebuttal_prompting( answer = match[1].strip() q5_result[question_number] = answer - return rebuttal, q5_result + return rebuttal, q5_result, formatted_prompt diff --git a/research_town/utils/prompt_constructor.py b/research_town/utils/prompt_constructor.py index 5159ef93..c0072e03 100644 --- a/research_town/utils/prompt_constructor.py +++ b/research_town/utils/prompt_constructor.py @@ -1,4 +1,3 @@ -import os from typing import Any, Dict, List, Union @@ -26,16 +25,3 @@ def openai_format_prompt_construct( messages.append({'role': 'user', 'content': query}) return messages - - -def save_prompt_to_log(template_name: str, messages: list[dict[str, str]]) -> None: - file_name = f'{template_name}.log' - directory_path = os.path.join('..', 'data', 'prompt_data') - - if not os.path.exists(directory_path): - os.makedirs(directory_path) - - file_path = os.path.join(directory_path, file_name) - - with open(file_path, 'w') as log_file: - log_file.write(str(messages)) From 3db1bb06462b994d5a26b47ad6b531a129a02aeb Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sun, 6 Oct 2024 20:24:04 -0500 Subject: [PATCH 14/19] minor fix --- research_town/agents/agent.py | 8 ++++---- tests/agents/test_agents.py | 19 ++++++++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/research_town/agents/agent.py b/research_town/agents/agent.py index c3894317..92ba7f50 100644 --- a/research_town/agents/agent.py +++ b/research_town/agents/agent.py @@ -96,7 +96,7 @@ def brainstorm_idea( ) -> Idea: serialized_insights = self.serializer.serialize(insights) serialized_papers = self.serializer.serialize(papers) - idea_content, prompt = brainstorm_idea_prompting( + idea_content_list, prompt = brainstorm_idea_prompting( bio=self.profile.bio, insights=serialized_insights, papers=serialized_papers, @@ -108,7 +108,7 @@ def brainstorm_idea( top_p=config.param.top_p, stream=config.param.stream, ) - idea_content = idea_content[0] + idea_content = idea_content_list[0] idea = Idea(content=idea_content) log_entry = IdeaBrainstormLog( @@ -123,7 +123,7 @@ def discuss_idea( self, ideas: List[Idea], contexts: List[str], config: Config ) -> Idea: serialized_ideas = self.serializer.serialize(ideas) - idea_summarized, prompt = discuss_idea_prompting( + idea_summarized_list, prompt = discuss_idea_prompting( bio=self.profile.bio, contexts=contexts, ideas=serialized_ideas, @@ -135,7 +135,7 @@ def 
discuss_idea( top_p=config.param.top_p, stream=config.param.stream, ) - idea_summarized = idea_summarized[0] + idea_summarized = idea_summarized_list[0] return Idea(content=idea_summarized) @beartype diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 33b60794..ce91ce24 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -21,9 +21,11 @@ def test_review_literature( mock_model_prompting: MagicMock, ) -> None: mock_model_prompting.side_effect = mock_prompting + mock_log_db = MagicMock() agent = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='leader', ) _, _, research_insight = agent.review_literature( @@ -43,10 +45,11 @@ def test_brainstorm_idea( mock_model_prompting: MagicMock, ) -> None: mock_model_prompting.side_effect = mock_prompting - + mock_log_db = MagicMock() agent = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='leader', ) research_idea = agent.brainstorm_idea( @@ -62,10 +65,11 @@ def test_brainstorm_idea( @patch('research_town.utils.agent_prompter.model_prompting') def test_write_proposal(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - + mock_log_db = MagicMock() agent = Agent( profile=profile_B, model_name='gpt-4o-mini', + log_db=mock_log_db, role='leader', ) paper = agent.write_proposal( @@ -81,10 +85,11 @@ def test_write_proposal(mock_model_prompting: MagicMock) -> None: @patch('research_town.utils.agent_prompter.model_prompting') def test_write_review(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - + mock_log_db = MagicMock() agent = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='reviewer', ) review = agent.write_review( @@ -101,15 +106,17 @@ def test_write_review(mock_model_prompting: MagicMock) -> None: @patch('research_town.utils.agent_prompter.model_prompting') def test_write_metareview(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - + mock_log_db = MagicMock() agent_reviewer = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='reviewer', ) agent_chair = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='chair', ) review = agent_reviewer.write_review( @@ -132,15 +139,17 @@ def test_write_metareview(mock_model_prompting: MagicMock) -> None: @patch('research_town.utils.agent_prompter.model_prompting') def test_write_rebuttal(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - + mock_log_db = MagicMock() agent_reviewer = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='reviewer', ) agent_leader = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='leader', ) review = agent_reviewer.write_review( From fe696b5ee8b49edf658877a0f20fbd3f84f29222 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Sun, 6 Oct 2024 20:29:54 -0500 Subject: [PATCH 15/19] minor fix --- tests/utils/test_serializer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/utils/test_serializer.py b/tests/utils/test_serializer.py index f15b3651..5312d41b 100644 --- a/tests/utils/test_serializer.py +++ b/tests/utils/test_serializer.py @@ -1,12 +1,17 @@ +from unittest.mock import MagicMock + from research_town.agents.agent import Agent from research_town.utils.serializer import Serializer from tests.constants.data_constants 
import profile_A def test_serializer() -> None: + mock_log_db = MagicMock() + agent = Agent( profile=profile_A, model_name='gpt-4o-mini', + log_db=mock_log_db, role='leader', ) agent_serialized = Serializer.serialize(agent) From c7a07e14396edb28f0787b58868e78e47ca6c0a7 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Mon, 7 Oct 2024 02:29:52 -0500 Subject: [PATCH 16/19] make log add happen in env --- research_town/agents/agent.py | 69 ++++++++++++---------- research_town/agents/agent_manager.py | 3 +- research_town/data/__init__.py | 4 ++ research_town/data/data.py | 13 +++- research_town/envs/env_proposal_writing.py | 9 ++- research_town/envs/env_review_writing.py | 6 +- tests/agents/test_agents.py | 14 ----- tests/utils/test_serializer.py | 5 -- 8 files changed, 62 insertions(+), 61 deletions(-) diff --git a/research_town/agents/agent.py b/research_town/agents/agent.py index 92ba7f50..4c91df75 100644 --- a/research_town/agents/agent.py +++ b/research_town/agents/agent.py @@ -1,3 +1,5 @@ +import uuid + from beartype import beartype from beartype.typing import Dict, List, Literal, Tuple @@ -9,6 +11,7 @@ LiteratureReviewLog, MetaReview, MetaReviewWritingLog, + OpenAIPrompt, Paper, Profile, Proposal, @@ -18,7 +21,6 @@ Review, ReviewWritingLog, ) -from ..dbs import LogDB from ..utils.agent_prompter import ( brainstorm_idea_prompting, discuss_idea_prompting, @@ -44,14 +46,12 @@ def __init__( self, profile: Profile, model_name: str, - log_db: LogDB, role: Role = None, ) -> None: self.profile: Profile = profile self.memory: Dict[str, str] = {} self.role: Role = role self.model_name: str = model_name - self.log_db: LogDB = log_db self.serializer = Serializer() @beartype @@ -65,7 +65,7 @@ def review_literature( papers: List[Paper], contexts: List[str], config: Config, - ) -> Tuple[str, List[str], Insight]: + ) -> Tuple[str, List[str], Insight, LiteratureReviewLog]: serialized_papers = self.serializer.serialize(papers) serialized_profile = self.serializer.serialize(self.profile) summary, keywords, valuable_points, prompt = review_literature_prompting( @@ -81,19 +81,20 @@ def review_literature( stream=config.param.stream, ) insight = Insight(content=valuable_points) + formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt) log_entry = LiteratureReviewLog( profile_pk=self.profile.pk, insight_pk=insight.pk, - prompt=prompt, # Store the formatted prompt + prompt_pk=formatted_prompt.pk, ) - self.log_db.add(log_entry) - return summary, keywords, insight + + return summary, keywords, insight, log_entry @beartype @member_required def brainstorm_idea( self, insights: List[Insight], papers: List[Paper], config: Config - ) -> Idea: + ) -> Tuple[Idea, IdeaBrainstormLog]: serialized_insights = self.serializer.serialize(insights) serialized_papers = self.serializer.serialize(papers) idea_content_list, prompt = brainstorm_idea_prompting( @@ -110,12 +111,12 @@ def brainstorm_idea( ) idea_content = idea_content_list[0] idea = Idea(content=idea_content) - + formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt) log_entry = IdeaBrainstormLog( - profile_pk=self.profile.pk, idea_pk=idea.pk, prompt=prompt + profile_pk=self.profile.pk, idea_pk=idea.pk, prompt_pk=formatted_prompt.pk ) - self.log_db.add(log_entry) - return idea + + return idea, log_entry @beartype @member_required @@ -142,7 +143,7 @@ def discuss_idea( @member_required def write_proposal( self, idea: Idea, papers: List[Paper], config: Config - ) -> Proposal: + ) -> Tuple[Proposal, 
ProposalWritingLog]: serialized_idea = self.serializer.serialize(idea) serialized_papers = self.serializer.serialize(papers) @@ -179,19 +180,20 @@ def write_proposal( q4=q5_result.get('q4', ''), q5=q5_result.get('q5', ''), ) - + formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt) log_entry = ProposalWritingLog( profile_pk=self.profile.pk, proposal_pk=proposal_obj.pk, - prompt=prompt, # Store the formatted prompt + prompt_pk=formatted_prompt.pk, ) - self.log_db.add(log_entry) # Add the log entry to LogDB - return proposal_obj + return proposal_obj, log_entry @beartype @reviewer_required - def write_review(self, proposal: Proposal, config: Config) -> Review: + def write_review( + self, proposal: Proposal, config: Config + ) -> Tuple[Review, ReviewWritingLog]: serialized_proposal = self.serializer.serialize(proposal) summary, strength, weakness, ethical_concerns, score, prompt = ( @@ -219,13 +221,14 @@ def write_review(self, proposal: Proposal, config: Config) -> Review: ethical_concerns=ethical_concerns, score=score, ) - + formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt) log_entry = ReviewWritingLog( - profile_pk=self.profile.pk, review_pk=review_obj.pk, prompt=prompt + profile_pk=self.profile.pk, + review_pk=review_obj.pk, + prompt_pk=formatted_prompt.pk, ) - self.log_db.add(log_entry) - return review_obj + return review_obj, log_entry @beartype @chair_required @@ -234,7 +237,7 @@ def write_metareview( proposal: Proposal, reviews: List[Review], config: Config, - ) -> MetaReview: + ) -> Tuple[MetaReview, MetaReviewWritingLog]: serialized_proposal = self.serializer.serialize(proposal) serialized_reviews = self.serializer.serialize(reviews) @@ -267,13 +270,14 @@ def write_metareview( ethical_concerns=ethical_concerns, decision=decision, ) - + formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt) log_entry = MetaReviewWritingLog( - profile_pk=self.profile.pk, metareview_pk=metareview_obj.pk, prompt=prompt + profile_pk=self.profile.pk, + metareview_pk=metareview_obj.pk, + prompt_pk=formatted_prompt.pk, ) - self.log_db.add(log_entry) - return metareview_obj + return metareview_obj, log_entry @beartype @leader_required @@ -282,7 +286,7 @@ def write_rebuttal( proposal: Proposal, review: Review, config: Config, - ) -> Rebuttal: + ) -> Tuple[Rebuttal, RebuttalWritingLog]: serialized_proposal = self.serializer.serialize(proposal) serialized_review = self.serializer.serialize(review) @@ -309,10 +313,11 @@ def write_rebuttal( q4=q5_result.get('q4', ''), q5=q5_result.get('q5', ''), ) - + formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt) log_entry = RebuttalWritingLog( - profile_pk=self.profile.pk, rebuttal_pk=rebuttal_obj.pk, prompt=prompt + profile_pk=self.profile.pk, + rebuttal_pk=rebuttal_obj.pk, + prompt_pk=formatted_prompt.pk, ) - self.log_db.add(log_entry) - return rebuttal_obj + return rebuttal_obj, log_entry diff --git a/research_town/agents/agent_manager.py b/research_town/agents/agent_manager.py index 3572a937..6d69d4e5 100644 --- a/research_town/agents/agent_manager.py +++ b/research_town/agents/agent_manager.py @@ -2,7 +2,7 @@ from ..configs import Config from ..data import Profile, Proposal -from ..dbs import LogDB, ProfileDB +from ..dbs import ProfileDB from .agent import Agent Role = Literal['reviewer', 'leader', 'member', 'chair'] @@ -17,7 +17,6 @@ def create_agent(self, profile: Profile, role: Role) -> Agent: return Agent( profile=profile, role=role, - log_db=LogDB(), model_name=self.config.param.base_llm, ) 
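For readers following the agent.py refactor above, the load-bearing change is that every agent method now wraps its formatted prompt in an OpenAIPrompt record and returns a log entry alongside its result, instead of writing to a LogDB the agent owns. The sketch below restates that shape in isolation; the dataclasses are simplified stand-ins for illustration, not the repository's actual models, and brainstorm_idea here is a hypothetical free function rather than the Agent method.

import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class OpenAIPrompt:
    # Stand-in for the prompt record: it carries its own primary key so log
    # entries can reference the prompt by pk instead of embedding the messages.
    messages: List[Dict[str, str]]
    pk: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class Idea:
    content: str
    pk: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class IdeaBrainstormLog:
    profile_pk: str
    idea_pk: str
    prompt_pk: str


def brainstorm_idea(
    profile_pk: str, prompt: List[Dict[str, str]]
) -> Tuple[Idea, IdeaBrainstormLog]:
    # The agent-side code builds the result plus a log entry that points at the
    # prompt by primary key, but it does not persist either object itself.
    idea = Idea(content='...model output...')
    formatted_prompt = OpenAIPrompt(messages=prompt)
    log_entry = IdeaBrainstormLog(
        profile_pk=profile_pk,
        idea_pk=idea.pk,
        prompt_pk=formatted_prompt.pk,
    )
    return idea, log_entry

Keeping persistence out of the agent is what lets the hunks above drop the log_db constructor argument and lets the test changes later in this series delete the mocked LogDB fixtures.
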
diff --git a/research_town/data/__init__.py b/research_town/data/__init__.py index 685803b6..a535cf50 100644 --- a/research_town/data/__init__.py +++ b/research_town/data/__init__.py @@ -6,9 +6,11 @@ Log, MetaReview, MetaReviewWritingLog, + OpenAIPrompt, Paper, Profile, Progress, + Prompt, Proposal, ProposalWritingLog, Rebuttal, @@ -27,6 +29,8 @@ 'ReviewWritingLog', 'ExperimentLog', 'Progress', + 'Prompt', + 'OpenAIPrompt', 'Paper', 'Profile', 'Idea', diff --git a/research_town/data/data.py b/research_town/data/data.py index 5fc3a2e8..0842404d 100644 --- a/research_town/data/data.py +++ b/research_town/data/data.py @@ -9,6 +9,15 @@ class Data(BaseModel): project_name: Optional[str] = Field(default=None) +class Prompt(BaseModel): + pk: str = Field(default_factory=lambda: str(uuid.uuid4())) + messages: Optional[Any] = Field(default=None) + + +class OpenAIPrompt(Prompt): + messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]] + + class Profile(Data): name: str bio: str @@ -43,9 +52,7 @@ class Paper(Data): class Log(Data): timestep: int = Field(default=0) profile_pk: str - prompt: Optional[Union[List[Dict[str, str]], List[List[Dict[str, str]]]]] = Field( - default=None - ) + prompt_pk: Optional[str] = Field(default=None) class LiteratureReviewLog(Log): diff --git a/research_town/envs/env_proposal_writing.py b/research_town/envs/env_proposal_writing.py index 37a70263..ee9dc72c 100644 --- a/research_town/envs/env_proposal_writing.py +++ b/research_town/envs/env_proposal_writing.py @@ -53,11 +53,12 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: query=';'.join(self.contexts), num=2, ) - summary, keywords, insight = member.review_literature( + summary, keywords, insight, log_entry = member.review_literature( papers=related_papers, contexts=self.contexts, config=self.config, ) + self.log_db.add(log_entry) yield insight, member insights.append(insight) keywords.extend(keywords) @@ -73,10 +74,11 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: else keywords[0], num=7, ) - idea = member.brainstorm_idea( + idea, log_entry = member.brainstorm_idea( papers=related_papers, insights=insights, config=self.config ) ideas.append(idea) + self.log_db.add(log_entry) yield idea, member # Leader discusses ideas @@ -94,11 +96,12 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: else None, num=2, ) - proposal = self.leader.write_proposal( + proposal, log_entry = self.leader.write_proposal( idea=summarized_idea, papers=related_papers, config=self.config, ) + self.log_db.add(log_entry) yield proposal, self.leader self.proposal = proposal # Store the proposal for use in on_exit diff --git a/research_town/envs/env_review_writing.py b/research_town/envs/env_review_writing.py index cab26f90..cb927af5 100644 --- a/research_town/envs/env_review_writing.py +++ b/research_town/envs/env_review_writing.py @@ -53,21 +53,23 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: # Review Writing self.reviews: List[Review] = [] for reviewer in self.reviewers: - review = reviewer.write_review( + review, log_entry = reviewer.write_review( proposal=self.proposal, config=self.config, ) self.reviews.append(review) + self.log_db.add(log_entry) yield review, reviewer # Rebuttal Submitting self.rebuttals: List[Rebuttal] = [] for review in self.reviews: - rebuttal = self.leader.write_rebuttal( + rebuttal, log_entry = self.leader.write_rebuttal( proposal=self.proposal, review=review, config=self.config, ) + self.log_db.add(log_entry) self.rebuttals.append(rebuttal) 
yield rebuttal, self.leader diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index ce91ce24..5472818d 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -21,11 +21,9 @@ def test_review_literature( mock_model_prompting: MagicMock, ) -> None: mock_model_prompting.side_effect = mock_prompting - mock_log_db = MagicMock() agent = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='leader', ) _, _, research_insight = agent.review_literature( @@ -45,11 +43,9 @@ def test_brainstorm_idea( mock_model_prompting: MagicMock, ) -> None: mock_model_prompting.side_effect = mock_prompting - mock_log_db = MagicMock() agent = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='leader', ) research_idea = agent.brainstorm_idea( @@ -65,11 +61,9 @@ def test_brainstorm_idea( @patch('research_town.utils.agent_prompter.model_prompting') def test_write_proposal(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - mock_log_db = MagicMock() agent = Agent( profile=profile_B, model_name='gpt-4o-mini', - log_db=mock_log_db, role='leader', ) paper = agent.write_proposal( @@ -85,11 +79,9 @@ def test_write_proposal(mock_model_prompting: MagicMock) -> None: @patch('research_town.utils.agent_prompter.model_prompting') def test_write_review(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - mock_log_db = MagicMock() agent = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='reviewer', ) review = agent.write_review( @@ -106,17 +98,14 @@ def test_write_review(mock_model_prompting: MagicMock) -> None: @patch('research_town.utils.agent_prompter.model_prompting') def test_write_metareview(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - mock_log_db = MagicMock() agent_reviewer = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='reviewer', ) agent_chair = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='chair', ) review = agent_reviewer.write_review( @@ -139,17 +128,14 @@ def test_write_metareview(mock_model_prompting: MagicMock) -> None: @patch('research_town.utils.agent_prompter.model_prompting') def test_write_rebuttal(mock_model_prompting: MagicMock) -> None: mock_model_prompting.side_effect = mock_prompting - mock_log_db = MagicMock() agent_reviewer = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='reviewer', ) agent_leader = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='leader', ) review = agent_reviewer.write_review( diff --git a/tests/utils/test_serializer.py b/tests/utils/test_serializer.py index 5312d41b..f15b3651 100644 --- a/tests/utils/test_serializer.py +++ b/tests/utils/test_serializer.py @@ -1,17 +1,12 @@ -from unittest.mock import MagicMock - from research_town.agents.agent import Agent from research_town.utils.serializer import Serializer from tests.constants.data_constants import profile_A def test_serializer() -> None: - mock_log_db = MagicMock() - agent = Agent( profile=profile_A, model_name='gpt-4o-mini', - log_db=mock_log_db, role='leader', ) agent_serialized = Serializer.serialize(agent) From a892c7765a5c6b0e62a19917d6637353e273b509 Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Mon, 7 Oct 2024 02:36:38 -0500 Subject: [PATCH 17/19] minor fix --- tests/agents/test_agents.py | 16 
++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 5472818d..f258b189 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -26,7 +26,7 @@ def test_review_literature( model_name='gpt-4o-mini', role='leader', ) - _, _, research_insight = agent.review_literature( + _, _, research_insight, _ = agent.review_literature( papers=[paper_A, paper_B], contexts=[ "Much of the world's most valued data is stored in relational databases and data warehouses, where the data is organized into many tables connected by primary-foreign key relations. However, building machine learning models using this data is both challenging and time consuming. The core problem is that no machine learning method is capable of learning on multiple tables interconnected by primary-foreign key relations. Current methods can only learn from a single table, so the data must first be manually joined and aggregated into a single training table, the process known as feature engineering. Feature engineering is slow, error prone and leads to suboptimal models. Here we introduce an end-to-end deep representation learning approach to directly learn on data laid out across multiple tables. We name our approach Relational Deep Learning (RDL). The core idea is to view relational databases as a temporal, heterogeneous graph, with a node for each row in each table, and edges specified by primary-foreign key links. Message Passing Graph Neural Networks can then automatically learn across the graph to extract representations that leverage all input data, without any manual feature engineering. Relational Deep Learning leads to more accurate models that can be built much faster. To facilitate research in this area, we develop RelBench, a set of benchmark datasets and an implementation of Relational Deep Learning. The data covers a wide spectrum, from discussions on Stack Exchange to book reviews on the Amazon Product Catalog. Overall, we define a new research area that generalizes graph machine learning and broadens its applicability to a wide set of AI use cases." 
@@ -48,7 +48,7 @@ def test_brainstorm_idea( model_name='gpt-4o-mini', role='leader', ) - research_idea = agent.brainstorm_idea( + research_idea, _ = agent.brainstorm_idea( insights=[research_insight_A, research_insight_B], papers=[paper_A, paper_B], config=example_config, @@ -66,7 +66,7 @@ def test_write_proposal(mock_model_prompting: MagicMock) -> None: model_name='gpt-4o-mini', role='leader', ) - paper = agent.write_proposal( + paper, _ = agent.write_proposal( idea=research_idea_A, papers=[paper_A, paper_B], config=example_config, @@ -84,7 +84,7 @@ def test_write_review(mock_model_prompting: MagicMock) -> None: model_name='gpt-4o-mini', role='reviewer', ) - review = agent.write_review( + review, _ = agent.write_review( proposal=research_proposal_A, config=example_config, ) @@ -108,11 +108,11 @@ def test_write_metareview(mock_model_prompting: MagicMock) -> None: model_name='gpt-4o-mini', role='chair', ) - review = agent_reviewer.write_review( + review, _ = agent_reviewer.write_review( proposal=research_proposal_A, config=example_config, ) - metareview = agent_chair.write_metareview( + metareview, _ = agent_chair.write_metareview( proposal=research_proposal_A, reviews=[review], config=example_config, @@ -138,11 +138,11 @@ def test_write_rebuttal(mock_model_prompting: MagicMock) -> None: model_name='gpt-4o-mini', role='leader', ) - review = agent_reviewer.write_review( + review, _ = agent_reviewer.write_review( proposal=research_proposal_A, config=example_config, ) - rebuttal = agent_leader.write_rebuttal( + rebuttal, _ = agent_leader.write_rebuttal( proposal=research_proposal_A, review=review, config=example_config, From 36d81f9339c71f254032a58bd9eb6feeb0891e1a Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Mon, 7 Oct 2024 02:51:52 -0500 Subject: [PATCH 18/19] minor fix --- research_town/envs/env_review_writing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/research_town/envs/env_review_writing.py b/research_town/envs/env_review_writing.py index cb927af5..2cd82318 100644 --- a/research_town/envs/env_review_writing.py +++ b/research_town/envs/env_review_writing.py @@ -74,11 +74,12 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: yield rebuttal, self.leader # Paper Meta Reviewing - metareview = self.chair.write_metareview( + metareview, log_entry = self.chair.write_metareview( proposal=self.proposal, reviews=self.reviews, config=self.config, ) + self.log_db.add(log_entry) yield metareview, self.chair self.metareview = metareview From b919850f5f5f6439b757a143748e294b729ca3ef Mon Sep 17 00:00:00 2001 From: keyangds <107345948+keyangds@users.noreply.github.com> Date: Mon, 7 Oct 2024 02:59:03 -0500 Subject: [PATCH 19/19] minor fix --- research_town/envs/env_proposal_writing.py | 8 ++++---- research_town/envs/env_review_writing.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/research_town/envs/env_proposal_writing.py b/research_town/envs/env_proposal_writing.py index ee9dc72c..730fc7aa 100644 --- a/research_town/envs/env_proposal_writing.py +++ b/research_town/envs/env_proposal_writing.py @@ -74,11 +74,11 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: else keywords[0], num=7, ) - idea, log_entry = member.brainstorm_idea( + idea, brainstorm_log_entry = member.brainstorm_idea( papers=related_papers, insights=insights, config=self.config ) ideas.append(idea) - self.log_db.add(log_entry) + self.log_db.add(brainstorm_log_entry) yield idea, member # Leader 
discusses ideas @@ -96,12 +96,12 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: else None, num=2, ) - proposal, log_entry = self.leader.write_proposal( + proposal, proposal_log_entry = self.leader.write_proposal( idea=summarized_idea, papers=related_papers, config=self.config, ) - self.log_db.add(log_entry) + self.log_db.add(proposal_log_entry) yield proposal, self.leader self.proposal = proposal # Store the proposal for use in on_exit diff --git a/research_town/envs/env_review_writing.py b/research_town/envs/env_review_writing.py index 2cd82318..21099fba 100644 --- a/research_town/envs/env_review_writing.py +++ b/research_town/envs/env_review_writing.py @@ -53,33 +53,33 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: # Review Writing self.reviews: List[Review] = [] for reviewer in self.reviewers: - review, log_entry = reviewer.write_review( + review, review_log_entry = reviewer.write_review( proposal=self.proposal, config=self.config, ) self.reviews.append(review) - self.log_db.add(log_entry) + self.log_db.add(review_log_entry) yield review, reviewer # Rebuttal Submitting self.rebuttals: List[Rebuttal] = [] for review in self.reviews: - rebuttal, log_entry = self.leader.write_rebuttal( + rebuttal, rebuttal_log_entry = self.leader.write_rebuttal( proposal=self.proposal, review=review, config=self.config, ) - self.log_db.add(log_entry) + self.log_db.add(rebuttal_log_entry) self.rebuttals.append(rebuttal) yield rebuttal, self.leader # Paper Meta Reviewing - metareview, log_entry = self.chair.write_metareview( + metareview, metareview_log_entry = self.chair.write_metareview( proposal=self.proposal, reviews=self.reviews, config=self.config, ) - self.log_db.add(log_entry) + self.log_db.add(metareview_log_entry) yield metareview, self.chair self.metareview = metareview
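
Taken together, the last few patches settle on a clear division of labour: agents return (progress, log entry) pairs, and the environments are the only components that write to the log database. The sketch below shows that environment-side loop on its own, using simplified stand-ins for LogDB and the reviewer agents; it illustrates the control flow under those assumptions and is not the repository's actual ReviewWritingEnv.

from dataclasses import dataclass, field
from typing import Any, Generator, List, Tuple


@dataclass
class LogDB:
    # Stand-in for the real log database: just collects entries in memory.
    entries: List[Any] = field(default_factory=list)

    def add(self, entry: Any) -> None:
        self.entries.append(entry)


def run_review_phase(
    reviewers: List[Any], proposal: Any, config: Any, log_db: LogDB
) -> Generator[Tuple[Any, Any], None, None]:
    # Each reviewer hands back its review together with a log entry; the
    # environment persists the entry and streams the progress to the caller.
    for reviewer in reviewers:
        review, review_log_entry = reviewer.write_review(
            proposal=proposal, config=config
        )
        log_db.add(review_log_entry)
        yield review, reviewer

Because the agents never touch the database, the tests only need to unpack the extra return value, which is exactly what the test_agents.py changes in this series do.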