From f4289ef5a7b1ab074adee22e04a0a1cba74e0e69 Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Wed, 9 Oct 2024 16:51:10 -0500 Subject: [PATCH] update the code --- .../envs/env_proposal_writing_with_rag.py | 36 ++++++++++--------- .../envs/env_proposal_writing_without_rag.py | 33 +++++++++-------- 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/research_town/envs/env_proposal_writing_with_rag.py b/research_town/envs/env_proposal_writing_with_rag.py index 88283771..0a1468c4 100644 --- a/research_town/envs/env_proposal_writing_with_rag.py +++ b/research_town/envs/env_proposal_writing_with_rag.py @@ -3,7 +3,7 @@ from ..agents import Agent, AgentManager from ..configs import Config -from ..data import Idea, Insight, Progress +from ..data import Idea, Insight, Progress, Proposal from ..dbs import LogDB, PaperDB, ProgressDB from ..utils.sampler import sample_ideas from .env_base import BaseEnv @@ -19,38 +19,40 @@ def __init__( config: Config, agent_manager: AgentManager, ) -> None: - super().__init__( - name=name, - config=config, - ) + super().__init__(name=name, config=config) self.log_db = log_db self.progress_db = progress_db self.paper_db = paper_db self.agent_manager = agent_manager - # self.user_rag = use_rag + self.proposals: List[Proposal] = [] @beartype def on_enter(self, **context: Any) -> None: + # Assign leader and members from context or sample them self.leader = context.get('leader', self.agent_manager.sample_leader()) self.members = context.get('members', self.agent_manager.sample_members()) - # must have contexts otherwise throw error + + if 'contexts' not in context: + raise ValueError("'contexts' is required in the context.") self.contexts = context['contexts'] @beartype def on_exit(self) -> Tuple[str, Dict[str, Any]]: + # Update environment run number and handle limits self.env_run_num += 1 if self.env_run_num > self.config.param.max_env_run_num: - return 'error', {} - else: - return 'start_review', {'proposals': self.proposals, 'leader': self.leader} + return 'error', {} # Return error if max run limit exceeded + return 'start_review', {'proposals': self.proposals, 'leader': self.leader} @beartype def run(self) -> Generator[Tuple[Progress, Agent], None, None]: - # Each member reviews literature insights: List[Insight] = [] - keywords: List[str] = [] + all_keywords: List[str] = [] ideas: List[Idea] = [] + researchers = self.members + [self.leader] + + # Step 1: Researchers review literature and gather insights for researcher in researchers: related_papers = self.paper_db.search_papers( query=';'.join(self.contexts), @@ -64,15 +66,17 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: yield insight, researcher insights.append(insight) - keywords.extend(keywords) + all_keywords.extend(keywords) - keyword = sorted(keywords, key=lambda x: x[1], reverse=True)[0] + # Step 2: Choose the most frequent keyword + top_keyword = sorted(all_keywords, key=lambda x: x[1], reverse=True)[0] + # Step 3: Researchers brainstorm ideas based on their insights and related papers for researcher in researchers: related_papers = self.paper_db.search_papers( query=insight.content, author=researcher.profile.name, - domain=keyword + researcher.profile.domain[0], + domain=top_keyword + researcher.profile.domain[0], num=self.config.param.related_paper_num, ) idea = researcher.brainstorm_idea( @@ -81,7 +85,7 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: yield idea, researcher ideas.append(idea) - self.proposals = [] + # Step 4: Leader summarizes ideas and writes proposals idea_combos = sample_ideas(ideas, self.config.param.proposal_num) for idea_combo in idea_combos: summarized_idea = self.leader.summarize_idea( diff --git a/research_town/envs/env_proposal_writing_without_rag.py b/research_town/envs/env_proposal_writing_without_rag.py index 71fac257..d3aea5fd 100644 --- a/research_town/envs/env_proposal_writing_without_rag.py +++ b/research_town/envs/env_proposal_writing_without_rag.py @@ -3,7 +3,7 @@ from ..agents import Agent, AgentManager from ..configs import Config -from ..data import Idea, Insight, Progress +from ..data import Idea, Insight, Progress, Proposal from ..dbs import LogDB, PaperDB, ProgressDB from ..utils.sampler import sample_ideas from .env_base import BaseEnv @@ -19,37 +19,39 @@ def __init__( config: Config, agent_manager: AgentManager, ) -> None: - super().__init__( - name=name, - config=config, - ) + super().__init__(name=name, config=config) self.log_db = log_db self.progress_db = progress_db self.paper_db = paper_db self.agent_manager = agent_manager - # self.user_rag = use_rag + self.proposals: List[Proposal] = [] @beartype def on_enter(self, **context: Any) -> None: + # Assign leader and members from context or sample them self.leader = context.get('leader', self.agent_manager.sample_leader()) self.members = context.get('members', self.agent_manager.sample_members()) + + if 'contexts' not in context: + raise ValueError("'contexts' is required in the context.") self.contexts = context['contexts'] @beartype def on_exit(self) -> Tuple[str, Dict[str, Any]]: + # Update environment run number and handle limits self.env_run_num += 1 if self.env_run_num > self.config.param.max_env_run_num: - return 'error', {} - else: - return 'start_review', {'proposals': self.proposals, 'leader': self.leader} + return 'error', {} # Return error if max run limit exceeded + return 'start_review', {'proposals': self.proposals, 'leader': self.leader} @beartype def run(self) -> Generator[Tuple[Progress, Agent], None, None]: - # Each member reviews literature insights: List[Insight] = [] - keywords: List[str] = [] ideas: List[Idea] = [] + researchers = self.members + [self.leader] + + # Step 1: Researchers review literature and gather insights for researcher in researchers: summary, keywords, insight = researcher.review_literature( contexts=self.contexts, @@ -58,17 +60,14 @@ def run(self) -> Generator[Tuple[Progress, Agent], None, None]: yield insight, researcher insights.append(insight) - keywords.extend(keywords) - - keywords = sorted(keywords, key=lambda x: x[1], reverse=True) + # Step 3: Researchers brainstorm ideas based on their insights for researcher in researchers: idea = researcher.brainstorm_idea(insights=insights, config=self.config) - ideas.append(idea) - yield idea, researcher + ideas.append(idea) - self.proposals = [] + # Step 4: Leader summarizes ideas and writes proposals idea_combos = sample_ideas(ideas, self.config.param.proposal_num) for idea_combo in idea_combos: summarized_idea = self.leader.summarize_idea(