diff --git a/backend/generator_func.py b/backend/generator_func.py
index 5ad20843..f0ab7883 100644
--- a/backend/generator_func.py
+++ b/backend/generator_func.py
@@ -59,5 +59,4 @@ def run_engine(
         print(f'Error occurred during engine execution: {e}')
     finally:
-        print('Engine execution completed.')
-        yield None, None
+        ('Engine execution completed.')
diff --git a/backend/main.py b/backend/main.py
index 8fb82f4f..900eea32 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -40,7 +40,6 @@ def stop_process(user_id: str) -> None:
         process.terminate()  # Safely terminate the process
         process.join()  # Ensure cleanup
         del active_processes[user_id]
-        print(f'Process for user {user_id} stopped.')


 def background_task(
@@ -48,10 +47,19 @@
 ) -> None:
     generator = run_engine(url)
     try:
+        # Generate and send results to the parent process
         for progress, agent in generator:
             child_conn.send((progress, agent))
+
+        while True:
+            if child_conn.poll():
+                msg = child_conn.recv()
+                if msg == 'terminate':
+                    break
+
     except Exception as e:
         child_conn.send({'type': 'error', 'content': str(e)})
+
     finally:
         child_conn.send(None)
         child_conn.close()
@@ -164,22 +172,28 @@ async def process_url(request: Request) -> StreamingResponse:
     async def stream_response() -> AsyncGenerator[str, None]:
         try:
             while True:
+                # Check if the client has disconnected
                 if await request.is_disconnected():
                     print(f'Client disconnected for user {user_id}. Cancelling task.')
                     stop_process(user_id)
                     break
+                # Check for new data from the background task
                 if parent_conn.poll():
                     result = parent_conn.recv()
                     if result is None:
+                        print(f'No more data for user {user_id}. Stopping task.')
                         break
+                    # Stream the formatted output to the client
                     for formatted_output in format_response(generator_wrapper(result)):
                         yield formatted_output
                 else:
                     await asyncio.sleep(0.1)
         finally:
-            stop_process(user_id)  # Ensure process is stopped on function exit
+            if await request.is_disconnected():
+                stop_process(user_id)
+                print(f'Process for user {user_id} stopped due to disconnection.')

     # Return the StreamingResponse
     return StreamingResponse(stream_response(), media_type='application/json')
diff --git a/research_town/agents/agent.py b/research_town/agents/agent.py
index 171c5511..4c91df75 100644
--- a/research_town/agents/agent.py
+++ b/research_town/agents/agent.py
@@ -1,8 +1,26 @@
+import uuid
+
 from beartype import beartype
 from beartype.typing import Dict, List, Literal, Tuple

 from ..configs import Config
-from ..data import Idea, Insight, MetaReview, Paper, Profile, Proposal, Rebuttal, Review
+from ..data import (
+    Idea,
+    IdeaBrainstormLog,
+    Insight,
+    LiteratureReviewLog,
+    MetaReview,
+    MetaReviewWritingLog,
+    OpenAIPrompt,
+    Paper,
+    Profile,
+    Proposal,
+    ProposalWritingLog,
+    Rebuttal,
+    RebuttalWritingLog,
+    Review,
+    ReviewWritingLog,
+)
 from ..utils.agent_prompter import (
     brainstorm_idea_prompting,
     discuss_idea_prompting,
@@ -47,10 +65,10 @@ def review_literature(
         papers: List[Paper],
         contexts: List[str],
         config: Config,
-    ) -> Tuple[str, List[str], Insight]:
+    ) -> Tuple[str, List[str], Insight, LiteratureReviewLog]:
         serialized_papers = self.serializer.serialize(papers)
         serialized_profile = self.serializer.serialize(self.profile)
-        summary, keywords, valuable_points = review_literature_prompting(
+        summary, keywords, valuable_points, prompt = review_literature_prompting(
             profile=serialized_profile,
             papers=serialized_papers,
             contexts=contexts,
@@ -63,16 +81,23 @@ def review_literature(
             stream=config.param.stream,
         )
         insight = Insight(content=valuable_points)
-        return summary, keywords, insight
+        formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
+        log_entry = LiteratureReviewLog(
+            profile_pk=self.profile.pk,
+            insight_pk=insight.pk,
+            prompt_pk=formatted_prompt.pk,
+        )
+
+        return summary, keywords, insight, log_entry

     @beartype
     @member_required
     def brainstorm_idea(
         self, insights: List[Insight], papers: List[Paper], config: Config
-    ) -> Idea:
+    ) -> Tuple[Idea, IdeaBrainstormLog]:
         serialized_insights = self.serializer.serialize(insights)
         serialized_papers = self.serializer.serialize(papers)
-        idea_content = brainstorm_idea_prompting(
+        idea_content_list, prompt = brainstorm_idea_prompting(
             bio=self.profile.bio,
             insights=serialized_insights,
             papers=serialized_papers,
@@ -83,8 +108,15 @@ def brainstorm_idea(
             temperature=config.param.temperature,
             top_p=config.param.top_p,
             stream=config.param.stream,
-        )[0]
-        return Idea(content=idea_content)
+        )
+        idea_content = idea_content_list[0]
+        idea = Idea(content=idea_content)
+        formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
+        log_entry = IdeaBrainstormLog(
+            profile_pk=self.profile.pk, idea_pk=idea.pk, prompt_pk=formatted_prompt.pk
+        )
+
+        return idea, log_entry

     @beartype
     @member_required
@@ -92,7 +124,7 @@ def discuss_idea(
         self, ideas: List[Idea], contexts: List[str], config: Config
     ) -> Idea:
         serialized_ideas = self.serializer.serialize(ideas)
-        idea_summarized = discuss_idea_prompting(
+        idea_summarized_list, prompt = discuss_idea_prompting(
             bio=self.profile.bio,
             contexts=contexts,
             ideas=serialized_ideas,
@@ -103,14 +135,15 @@ def discuss_idea(
             temperature=config.param.temperature,
             top_p=config.param.top_p,
             stream=config.param.stream,
-        )[0]
+        )
+        idea_summarized = idea_summarized_list[0]
         return Idea(content=idea_summarized)

     @beartype
     @member_required
     def write_proposal(
         self, idea: Idea, papers: List[Paper], config: Config
-    ) -> Proposal:
+    ) -> Tuple[Proposal, ProposalWritingLog]:
         serialized_idea = self.serializer.serialize(idea)
         serialized_papers = self.serializer.serialize(papers)

@@ -127,7 +160,7 @@ def write_proposal(
             print('write_proposal_strategy not supported, will use default')
             prompt_template = config.agent_prompt_template.write_proposal

-        proposal, q5_result = write_proposal_prompting(
+        proposal, q5_result, prompt = write_proposal_prompting(
             idea=serialized_idea,
             papers=serialized_papers,
             model_name=self.model_name,
@@ -138,7 +171,8 @@ def write_proposal(
             top_p=config.param.top_p,
             stream=config.param.stream,
         )
-        return Proposal(
+
+        proposal_obj = Proposal(
             content=proposal,
             q1=q5_result.get('q1', ''),
             q2=q5_result.get('q2', ''),
@@ -146,27 +180,39 @@ def write_proposal(
             q4=q5_result.get('q4', ''),
             q5=q5_result.get('q5', ''),
         )
+        formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
+        log_entry = ProposalWritingLog(
+            profile_pk=self.profile.pk,
+            proposal_pk=proposal_obj.pk,
+            prompt_pk=formatted_prompt.pk,
+        )
+
+        return proposal_obj, log_entry

     @beartype
     @reviewer_required
-    def write_review(self, proposal: Proposal, config: Config) -> Review:
+    def write_review(
+        self, proposal: Proposal, config: Config
+    ) -> Tuple[Review, ReviewWritingLog]:
         serialized_proposal = self.serializer.serialize(proposal)

-        summary, strength, weakness, ethical_concerns, score = write_review_prompting(
-            proposal=serialized_proposal,
-            model_name=self.model_name,
-            summary_prompt_template=config.agent_prompt_template.write_review_summary,
-            strength_prompt_template=config.agent_prompt_template.write_review_strength,
-            weakness_prompt_template=config.agent_prompt_template.write_review_weakness,
-            ethical_prompt_template=config.agent_prompt_template.write_review_ethical,
-            score_prompt_template=config.agent_prompt_template.write_review_score,
-            return_num=config.param.return_num,
-            max_token_num=config.param.max_token_num,
-            temperature=config.param.temperature,
-            top_p=config.param.top_p,
-            stream=config.param.stream,
+        summary, strength, weakness, ethical_concerns, score, prompt = (
+            write_review_prompting(
+                proposal=serialized_proposal,
+                model_name=self.model_name,
+                summary_prompt_template=config.agent_prompt_template.write_review_summary,
+                strength_prompt_template=config.agent_prompt_template.write_review_strength,
+                weakness_prompt_template=config.agent_prompt_template.write_review_weakness,
+                ethical_prompt_template=config.agent_prompt_template.write_review_ethical,
+                score_prompt_template=config.agent_prompt_template.write_review_score,
+                return_num=config.param.return_num,
+                max_token_num=config.param.max_token_num,
+                temperature=config.param.temperature,
+                top_p=config.param.top_p,
+                stream=config.param.stream,
+            )
         )
-        return Review(
+        review_obj = Review(
             proposal_pk=proposal.pk,
             reviewer_pk=self.profile.pk,
             summary=summary,
@@ -175,6 +221,14 @@ def write_review(self, proposal: Proposal, config: Config) -> Review:
             ethical_concerns=ethical_concerns,
             score=score,
         )
+        formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
+        log_entry = ReviewWritingLog(
+            profile_pk=self.profile.pk,
+            review_pk=review_obj.pk,
+            prompt_pk=formatted_prompt.pk,
+        )
+
+        return review_obj, log_entry

     @beartype
     @chair_required
@@ -183,11 +237,11 @@ def write_metareview(
         proposal: Proposal,
         reviews: List[Review],
         config: Config,
-    ) -> MetaReview:
+    ) -> Tuple[MetaReview, MetaReviewWritingLog]:
         serialized_proposal = self.serializer.serialize(proposal)
         serialized_reviews = self.serializer.serialize(reviews)

-        summary, strength, weakness, ethical_concerns, decision = (
+        summary, strength, weakness, ethical_concerns, decision, prompt = (
             write_metareview_prompting(
                 proposal=serialized_proposal,
                 reviews=serialized_reviews,
@@ -205,7 +259,7 @@ def write_metareview(
             )
         )

-        return MetaReview(
+        metareview_obj = MetaReview(
             proposal_pk=proposal.pk,
             chair_pk=self.profile.pk,
             reviewer_pks=[review.reviewer_pk for review in reviews],
@@ -216,6 +270,14 @@ def write_metareview(
             ethical_concerns=ethical_concerns,
             decision=decision,
         )
+        formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
+        log_entry = MetaReviewWritingLog(
+            profile_pk=self.profile.pk,
+            metareview_pk=metareview_obj.pk,
+            prompt_pk=formatted_prompt.pk,
+        )
+
+        return metareview_obj, log_entry

     @beartype
     @leader_required
@@ -224,11 +286,11 @@ def write_rebuttal(
         proposal: Proposal,
         review: Review,
         config: Config,
-    ) -> Rebuttal:
+    ) -> Tuple[Rebuttal, RebuttalWritingLog]:
         serialized_proposal = self.serializer.serialize(proposal)
         serialized_review = self.serializer.serialize(review)

-        rebuttal_content, q5_result = write_rebuttal_prompting(
+        rebuttal_content, q5_result, prompt = write_rebuttal_prompting(
             proposal=serialized_proposal,
             review=serialized_review,
             model_name=self.model_name,
@@ -240,7 +302,7 @@ def write_rebuttal(
             stream=config.param.stream,
         )

-        return Rebuttal(
+        rebuttal_obj = Rebuttal(
             proposal_pk=proposal.pk,
             reviewer_pk=review.reviewer_pk,
             author_pk=self.profile.pk,
@@ -251,3 +313,11 @@ def write_rebuttal(
             q4=q5_result.get('q4', ''),
             q5=q5_result.get('q5', ''),
         )
+        formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
+        log_entry = RebuttalWritingLog(
+            profile_pk=self.profile.pk,
+            rebuttal_pk=rebuttal_obj.pk,
+            prompt_pk=formatted_prompt.pk,
+        )
+
+        return rebuttal_obj, log_entry
diff --git a/research_town/data/__init__.py b/research_town/data/__init__.py
index 685803b6..a535cf50 100644
--- a/research_town/data/__init__.py
+++ b/research_town/data/__init__.py
@@ -6,9 +6,11 @@
     Log,
     MetaReview,
     MetaReviewWritingLog,
+    OpenAIPrompt,
     Paper,
     Profile,
     Progress,
+    Prompt,
     Proposal,
     ProposalWritingLog,
     Rebuttal,
@@ -27,6 +29,8 @@
     'ReviewWritingLog',
     'ExperimentLog',
     'Progress',
+    'Prompt',
+    'OpenAIPrompt',
     'Paper',
     'Profile',
     'Idea',
diff --git a/research_town/data/data.py b/research_town/data/data.py
index a00f2852..0842404d 100644
--- a/research_town/data/data.py
+++ b/research_town/data/data.py
@@ -1,5 +1,5 @@
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, ConfigDict, Field

@@ -9,6 +9,15 @@ class Data(BaseModel):
     project_name: Optional[str] = Field(default=None)


+class Prompt(BaseModel):
+    pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    messages: Optional[Any] = Field(default=None)
+
+
+class OpenAIPrompt(Prompt):
+    messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]]
+
+
 class Profile(Data):
     name: str
     bio: str
@@ -43,6 +52,7 @@ class Paper(Data):
 class Log(Data):
     timestep: int = Field(default=0)
     profile_pk: str
+    prompt_pk: Optional[str] = Field(default=None)


 class LiteratureReviewLog(Log):
diff --git a/research_town/envs/env_proposal_writing.py b/research_town/envs/env_proposal_writing.py
index 37a70263..730fc7aa 100644
--- a/research_town/envs/env_proposal_writing.py
+++ b/research_town/envs/env_proposal_writing.py
@@ -53,11 +53,12 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]:
                 query=';'.join(self.contexts),
                 num=2,
             )
-            summary, keywords, insight = member.review_literature(
+            summary, keywords, insight, log_entry = member.review_literature(
                 papers=related_papers,
                 contexts=self.contexts,
                 config=self.config,
             )
+            self.log_db.add(log_entry)
             yield insight, member
             insights.append(insight)
             keywords.extend(keywords)
@@ -73,10 +74,11 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]:
                 else keywords[0],
                 num=7,
             )
-            idea = member.brainstorm_idea(
+            idea, brainstorm_log_entry = member.brainstorm_idea(
                 papers=related_papers, insights=insights, config=self.config
             )
             ideas.append(idea)
+            self.log_db.add(brainstorm_log_entry)
             yield idea, member

         # Leader discusses ideas
@@ -94,11 +96,12 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]:
             else None,
             num=2,
         )
-        proposal = self.leader.write_proposal(
+        proposal, proposal_log_entry = self.leader.write_proposal(
             idea=summarized_idea,
             papers=related_papers,
             config=self.config,
         )
+        self.log_db.add(proposal_log_entry)
         yield proposal, self.leader
         self.proposal = proposal  # Store the proposal for use in on_exit
diff --git a/research_town/envs/env_review_writing.py b/research_town/envs/env_review_writing.py
index cab26f90..21099fba 100644
--- a/research_town/envs/env_review_writing.py
+++ b/research_town/envs/env_review_writing.py
@@ -53,30 +53,33 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]:
         # Review Writing
         self.reviews: List[Review] = []
         for reviewer in self.reviewers:
-            review = reviewer.write_review(
+            review, review_log_entry = reviewer.write_review(
                 proposal=self.proposal,
                 config=self.config,
             )
             self.reviews.append(review)
+            self.log_db.add(review_log_entry)
             yield review, reviewer

         # Rebuttal Submitting
         self.rebuttals: List[Rebuttal] = []
         for review in self.reviews:
-            rebuttal = self.leader.write_rebuttal(
+            rebuttal, rebuttal_log_entry = self.leader.write_rebuttal(
                 proposal=self.proposal,
                 review=review,
                 config=self.config,
             )
+            self.log_db.add(rebuttal_log_entry)
             self.rebuttals.append(rebuttal)
             yield rebuttal, self.leader

         # Paper Meta Reviewing
-        metareview = self.chair.write_metareview(
+        metareview, metareview_log_entry = self.chair.write_metareview(
             proposal=self.proposal,
             reviews=self.reviews,
             config=self.config,
         )
+        self.log_db.add(metareview_log_entry)
         yield metareview, self.chair
         self.metareview = metareview
diff --git a/research_town/utils/agent_prompter.py b/research_town/utils/agent_prompter.py
index a1e9f666..a899eb68 100644
--- a/research_town/utils/agent_prompter.py
+++ b/research_town/utils/agent_prompter.py
@@ -28,7 +28,7 @@ def review_literature_prompting(
     temperature: Optional[float] = 0.0,
     top_p: Optional[float] = None,
     stream: Optional[bool] = None,
-) -> Tuple[str, List[str], str]:
+) -> Tuple[str, List[str], str, list[dict[str, str]]]:
     papers_str = map_paper_list_to_str(papers)
     template_input = {
         'bio': profile['bio'],
@@ -37,6 +37,8 @@ def review_literature_prompting(
     }
     messages = openai_format_prompt_construct(prompt_template, template_input)
+    formatted_prompt = messages
+
     insight = model_prompting(
         model_name,
         messages,
@@ -64,7 +66,7 @@ def review_literature_prompting(
     valuable_points = (
         valuable_points_match.group(1).strip() if valuable_points_match else ''
     )
-    return summary, keywords, valuable_points
+    return summary, keywords, valuable_points, formatted_prompt


 @beartype
@@ -79,11 +81,12 @@ def brainstorm_idea_prompting(
     temperature: Optional[float] = 0.0,
     top_p: Optional[float] = None,
     stream: Optional[bool] = None,
-) -> List[str]:
+) -> Tuple[List[str], List[Dict[str, str]]]:
     insights_str = map_insight_list_to_str(insights)
     papers_str = map_paper_list_to_str(papers)
     template_input = {'bio': bio, 'insights': insights_str, 'papers': papers_str}
     messages = openai_format_prompt_construct(prompt_template, template_input)
+    formatted_prompt = messages
     return model_prompting(
         model_name,
         messages,
@@ -92,7 +95,7 @@ def brainstorm_idea_prompting(
         temperature=temperature,
         top_p=top_p,
         stream=stream,
-    )
+    ), formatted_prompt


 @beartype
@@ -107,10 +110,12 @@ def discuss_idea_prompting(
     temperature: Optional[float] = 0.0,
     top_p: Optional[float] = None,
     stream: Optional[bool] = None,
-) -> List[str]:
+) -> Tuple[List[str], List[Dict[str, str]]]:
     ideas_str = map_idea_list_to_str(ideas)
     template_input = {'bio': bio, 'ideas': ideas_str, 'contexts': contexts}
     messages = openai_format_prompt_construct(prompt_template, template_input)
+    formatted_prompt = messages
+
     return model_prompting(
         model_name,
         messages,
@@ -119,7 +124,7 @@ def discuss_idea_prompting(
         temperature=temperature,
         top_p=top_p,
         stream=stream,
-    )
+    ), formatted_prompt


 @beartype
@@ -133,11 +138,13 @@ def write_proposal_prompting(
     temperature: Optional[float] = 0.0,
     top_p: Optional[float] = None,
     stream: Optional[bool] = None,
-) -> Tuple[str, Dict[str, str]]:
+) -> Tuple[str, Dict[str, str], List[Dict[str, str]]]:
     idea_str = map_idea_to_str(idea)
     papers_str = map_paper_list_to_str(papers)
     template_input = {'idea': idea_str, 'papers': papers_str}
     messages = openai_format_prompt_construct(prompt_template, template_input)
+
+    formatted_prompt = messages
     proposal = model_prompting(
         model_name,
         messages,
@@ -157,7 +164,7 @@ def write_proposal_prompting(
         answer = match[1].strip()
         q5_result[question_number] = answer

-    return proposal, q5_result
+    return proposal, q5_result, formatted_prompt


 @beartype
@@ -174,7 +181,8 @@ def write_review_prompting(
     temperature: Optional[float] = 0.0,
     top_p: Optional[float] = None,
     stream: Optional[bool] = None,
-) -> Tuple[str, str, str, str, int]:
+) -> Tuple[str, str, str, str, int, List[List[Dict[str, str]]]]:
+    formatted_prompts: List[List[Dict[str, str]]] = []
     proposal_str = map_proposal_to_str(proposal)
     summary_template_input = {'proposal': proposal_str}
     summary_messages = openai_format_prompt_construct(
         summary_prompt_template, summary_template_input
     )
@@ -194,14 +202,17 @@ def write_review_prompting(
     strength_messages = openai_format_prompt_construct(
         strength_prompt_template, strength_template_input
     )
+    formatted_prompts.append(strength_messages)
     weakness_template_input = {'proposal': proposal_str, 'summary': summary}
     weakness_messages = openai_format_prompt_construct(
         weakness_prompt_template, weakness_template_input
     )
+    formatted_prompts.append(weakness_messages)
     ethical_template_input = {'proposal': proposal_str, 'summary': summary}
     ethical_messages = openai_format_prompt_construct(
         ethical_prompt_template, ethical_template_input
     )
+    formatted_prompts.append(ethical_messages)

     strength = model_prompting(
         model_name,
@@ -241,6 +252,8 @@ def write_review_prompting(
     score_messages = openai_format_prompt_construct(
         score_prompt_template, score_template_input
     )
+
+    formatted_prompts.append(score_messages)
     score_str = (
         model_prompting(
             model_name,
@@ -258,7 +271,7 @@ def write_review_prompting(
     )
     score = int(score_str[0]) if score_str[0].isdigit() else 0

-    return summary, strength, weakness, ethical_concerns, score
+    return summary, strength, weakness, ethical_concerns, score, formatted_prompts


 @beartype
@@ -276,7 +289,8 @@ def write_metareview_prompting(
     temperature: Optional[float] = 0.0,
     top_p: Optional[float] = None,
     stream: Optional[bool] = None,
-) -> Tuple[str, str, str, str, bool]:
+) -> Tuple[str, str, str, str, bool, List[List[Dict[str, str]]]]:
+    formatted_prompts: List[List[Dict[str, str]]] = []
     proposal_str = map_proposal_to_str(proposal)
     reviews_str = map_review_list_to_str(reviews)
     summary_template_input = {
         'proposal': proposal_str,
@@ -286,6 +300,7 @@
     summary_messages = openai_format_prompt_construct(
         summary_prompt_template, summary_template_input
     )
+    formatted_prompts.append(summary_messages)
     summary = model_prompting(
         model_name,
         summary_messages,
@@ -314,12 +329,15 @@
     strength_messages = openai_format_prompt_construct(
         strength_prompt_template, strength_template_input
     )
+    formatted_prompts.append(strength_messages)
     weakness_messages = openai_format_prompt_construct(
         weakness_prompt_template, weakness_template_input
     )
+    formatted_prompts.append(weakness_messages)
     ethical_messages = openai_format_prompt_construct(
         ethical_prompt_template, ethical_template_input
     )
+    formatted_prompts.append(ethical_messages)

     strength = model_prompting(
         model_name,
@@ -360,6 +378,7 @@
     decision_messages = openai_format_prompt_construct(
         decision_prompt_template, decision_template_input
     )
+    formatted_prompts.append(decision_messages)
     decision_str = model_prompting(
         model_name,
         decision_messages,
@@ -371,7 +390,7 @@
     )
     decision = 'accept' in decision_str[0].lower()

-    return summary, strength, weakness, ethical_concerns, decision
+    return summary, strength, weakness, ethical_concerns, decision, formatted_prompts


 @beartype
@@ -385,11 +404,12 @@ def write_rebuttal_prompting(
     temperature: Optional[float] = 0.0,
     top_p: Optional[float] = None,
     stream: Optional[bool] = None,
-) -> Tuple[str, Dict[str, str]]:
+) -> Tuple[str, Dict[str, str], List[Dict[str, str]]]:
     proposal_str = map_proposal_to_str(proposal)
     review_str = map_review_to_str(review)
     template_input = {'proposal': proposal_str, 'review': review_str}
     messages = openai_format_prompt_construct(prompt_template, template_input)
+    formatted_prompt = messages
     rebuttal = model_prompting(
         model_name,
         messages,
@@ -409,4 +429,4 @@ def write_rebuttal_prompting(
         answer = match[1].strip()
         q5_result[question_number] = answer

-    return rebuttal, q5_result
+    return rebuttal, q5_result, formatted_prompt
diff --git a/research_town/utils/prompt_constructor.py b/research_town/utils/prompt_constructor.py
index 37d97549..c0072e03 100644
--- a/research_town/utils/prompt_constructor.py
+++ b/research_town/utils/prompt_constructor.py
@@ -5,7 +5,6 @@ def openai_format_prompt_construct(
     template: Dict[str, Union[str, List[str]]], input_data: Dict[str, Any]
 ) -> List[Dict[str, str]]:
     messages = []
-
     if 'sys_prompt' in template:
         sys_prompt = template['sys_prompt']
         assert isinstance(sys_prompt, str)
diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py
index 33b60794..f258b189 100644
--- a/tests/agents/test_agents.py
+++ b/tests/agents/test_agents.py
@@ -26,7 +26,7 @@ def test_review_literature(
         model_name='gpt-4o-mini',
         role='leader',
     )
-    _, _, research_insight = agent.review_literature(
+    _, _, research_insight, _ = agent.review_literature(
         papers=[paper_A, paper_B],
         contexts=[
            "Much of the world's most valued data is stored in relational databases and data warehouses, where the data is organized into many tables connected by primary-foreign key relations. However, building machine learning models using this data is both challenging and time consuming. The core problem is that no machine learning method is capable of learning on multiple tables interconnected by primary-foreign key relations. Current methods can only learn from a single table, so the data must first be manually joined and aggregated into a single training table, the process known as feature engineering. Feature engineering is slow, error prone and leads to suboptimal models. Here we introduce an end-to-end deep representation learning approach to directly learn on data laid out across multiple tables. We name our approach Relational Deep Learning (RDL). The core idea is to view relational databases as a temporal, heterogeneous graph, with a node for each row in each table, and edges specified by primary-foreign key links. Message Passing Graph Neural Networks can then automatically learn across the graph to extract representations that leverage all input data, without any manual feature engineering. Relational Deep Learning leads to more accurate models that can be built much faster. To facilitate research in this area, we develop RelBench, a set of benchmark datasets and an implementation of Relational Deep Learning. The data covers a wide spectrum, from discussions on Stack Exchange to book reviews on the Amazon Product Catalog. Overall, we define a new research area that generalizes graph machine learning and broadens its applicability to a wide set of AI use cases."
@@ -43,13 +43,12 @@ def test_brainstorm_idea(
     mock_model_prompting: MagicMock,
 ) -> None:
     mock_model_prompting.side_effect = mock_prompting
-
     agent = Agent(
         profile=profile_A,
         model_name='gpt-4o-mini',
         role='leader',
     )
-    research_idea = agent.brainstorm_idea(
+    research_idea, _ = agent.brainstorm_idea(
         insights=[research_insight_A, research_insight_B],
         papers=[paper_A, paper_B],
         config=example_config,
@@ -62,13 +61,12 @@
 @patch('research_town.utils.agent_prompter.model_prompting')
 def test_write_proposal(mock_model_prompting: MagicMock) -> None:
     mock_model_prompting.side_effect = mock_prompting
-
     agent = Agent(
         profile=profile_B,
         model_name='gpt-4o-mini',
         role='leader',
     )
-    paper = agent.write_proposal(
+    paper, _ = agent.write_proposal(
         idea=research_idea_A,
         papers=[paper_A, paper_B],
         config=example_config,
@@ -81,13 +79,12 @@
 @patch('research_town.utils.agent_prompter.model_prompting')
 def test_write_review(mock_model_prompting: MagicMock) -> None:
     mock_model_prompting.side_effect = mock_prompting
-
     agent = Agent(
         profile=profile_A,
         model_name='gpt-4o-mini',
         role='reviewer',
     )
-    review = agent.write_review(
+    review, _ = agent.write_review(
         proposal=research_proposal_A,
         config=example_config,
     )
@@ -101,7 +98,6 @@
 @patch('research_town.utils.agent_prompter.model_prompting')
 def test_write_metareview(mock_model_prompting: MagicMock) -> None:
     mock_model_prompting.side_effect = mock_prompting
-
     agent_reviewer = Agent(
         profile=profile_A,
         model_name='gpt-4o-mini',
@@ -112,11 +108,11 @@ def test_write_metareview(mock_model_prompting: MagicMock) -> None:
         model_name='gpt-4o-mini',
         role='chair',
     )
-    review = agent_reviewer.write_review(
+    review, _ = agent_reviewer.write_review(
         proposal=research_proposal_A,
         config=example_config,
     )
-    metareview = agent_chair.write_metareview(
+    metareview, _ = agent_chair.write_metareview(
         proposal=research_proposal_A,
         reviews=[review],
         config=example_config,
@@ -132,7 +128,6 @@
 @patch('research_town.utils.agent_prompter.model_prompting')
 def test_write_rebuttal(mock_model_prompting: MagicMock) -> None:
     mock_model_prompting.side_effect = mock_prompting
-
     agent_reviewer = Agent(
         profile=profile_A,
         model_name='gpt-4o-mini',
@@ -143,11 +138,11 @@ def test_write_rebuttal(mock_model_prompting: MagicMock) -> None:
         model_name='gpt-4o-mini',
         role='leader',
     )
-    review = agent_reviewer.write_review(
+    review, _ = agent_reviewer.write_review(
         proposal=research_proposal_A,
         config=example_config,
     )
-    rebuttal = agent_leader.write_rebuttal(
+    rebuttal, _ = agent_leader.write_rebuttal(
         proposal=research_proposal_A,
         review=review,
         config=example_config,