
Issue/699 #745

Closed · wants to merge 22 commits
3 changes: 1 addition & 2 deletions backend/generator_func.py
@@ -59,5 +59,4 @@ def run_engine(
print(f'Error occurred during engine execution: {e}')

finally:
print('Engine execution completed.')
yield None, None
print('Engine execution completed.')
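For context, a minimal sketch of the generator contract that background_task in backend/main.py consumes from run_engine, assuming the yield None, None sentinel is part of what this hunk removes; the yielded values and type hints below are illustrative, not taken from the full source:

from typing import Any, Generator, Optional, Tuple


def run_engine(url: str) -> Generator[Tuple[Optional[str], Optional[Any]], None, None]:
    # Sketch only: each engine step yields a (progress, agent) pair, which
    # background_task forwards over the pipe to the parent process.
    try:
        yield '10%', {'agent': 'example'}  # placeholder values
    except Exception as e:
        print(f'Error occurred during engine execution: {e}')
    finally:
        print('Engine execution completed.')
        # End-of-stream is signalled by background_task sending None over the
        # pipe (see backend/main.py), so no sentinel yield is needed here.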
18 changes: 16 additions & 2 deletions backend/main.py
@@ -40,18 +40,26 @@ def stop_process(user_id: str) -> None:
process.terminate() # Safely terminate the process
process.join() # Ensure cleanup
del active_processes[user_id]
print(f'Process for user {user_id} stopped.')


def background_task(
url: str, child_conn: multiprocessing.connection.Connection
) -> None:
generator = run_engine(url)
try:
# Generate and send results to the parent process
for progress, agent in generator:
child_conn.send((progress, agent))

while True:
if child_conn.poll():
msg = child_conn.recv()
if msg == 'terminate':
break

except Exception as e:
child_conn.send({'type': 'error', 'content': str(e)})

finally:
child_conn.send(None)
child_conn.close()
@@ -164,22 +172,28 @@ async def process_url(request: Request) -> StreamingResponse:
async def stream_response() -> AsyncGenerator[str, None]:
try:
while True:
# Check if the client has disconnected
if await request.is_disconnected():
print(f'Client disconnected for user {user_id}. Cancelling task.')
stop_process(user_id)
break

# Check for new data from the background task
if parent_conn.poll():
result = parent_conn.recv()
if result is None:
print(f'No more data for user {user_id}. Stopping task.')
break

# Stream the formatted output to the client
for formatted_output in format_response(generator_wrapper(result)):
yield formatted_output
else:
await asyncio.sleep(0.1)
finally:
stop_process(user_id) # Ensure process is stopped on function exit
if await request.is_disconnected():
stop_process(user_id)
print(f'Process for user {user_id} stopped due to disconnection.')

# Return the StreamingResponse
return StreamingResponse(stream_response(), media_type='application/json')
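For reference, a hedged sketch of the pipe protocol this file now uses between the worker process and the streaming endpoint. The consume helper below is hypothetical; only the message shapes come from the diff above: (progress, agent) tuples, an error dict of the form {'type': 'error', 'content': ...}, and a final None sentinel.

from multiprocessing.connection import Connection


def consume(parent_conn: Connection) -> None:
    # Hypothetical parent-side consumer mirroring the checks in stream_response.
    while True:
        if not parent_conn.poll(0.1):  # wait briefly for the next message
            continue
        result = parent_conn.recv()
        if result is None:  # sentinel sent from background_task's finally block
            break
        if isinstance(result, dict) and result.get('type') == 'error':
            print(f"Worker error: {result['content']}")
            break
        progress, agent = result  # regular progress update
        print(progress, agent)

In process_url itself, these checks are interleaved with request.is_disconnected() so the stream ends on either client disconnect or worker completion.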
138 changes: 104 additions & 34 deletions research_town/agents/agent.py
@@ -1,8 +1,26 @@
import uuid

from beartype import beartype
from beartype.typing import Dict, List, Literal, Tuple

from ..configs import Config
from ..data import Idea, Insight, MetaReview, Paper, Profile, Proposal, Rebuttal, Review
from ..data import (
Idea,
IdeaBrainstormLog,
Insight,
LiteratureReviewLog,
MetaReview,
MetaReviewWritingLog,
OpenAIPrompt,
Paper,
Profile,
Proposal,
ProposalWritingLog,
Rebuttal,
RebuttalWritingLog,
Review,
ReviewWritingLog,
)
from ..utils.agent_prompter import (
brainstorm_idea_prompting,
discuss_idea_prompting,
@@ -47,10 +65,10 @@ def review_literature(
papers: List[Paper],
contexts: List[str],
config: Config,
) -> Tuple[str, List[str], Insight]:
) -> Tuple[str, List[str], Insight, LiteratureReviewLog]:
serialized_papers = self.serializer.serialize(papers)
serialized_profile = self.serializer.serialize(self.profile)
summary, keywords, valuable_points = review_literature_prompting(
summary, keywords, valuable_points, prompt = review_literature_prompting(
profile=serialized_profile,
papers=serialized_papers,
contexts=contexts,
@@ -63,16 +81,23 @@ def review_literature(
stream=config.param.stream,
)
insight = Insight(content=valuable_points)
return summary, keywords, insight
formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
log_entry = LiteratureReviewLog(
profile_pk=self.profile.pk,
insight_pk=insight.pk,
prompt_pk=formatted_prompt.pk,
)

return summary, keywords, insight, log_entry

@beartype
@member_required
def brainstorm_idea(
self, insights: List[Insight], papers: List[Paper], config: Config
) -> Idea:
) -> Tuple[Idea, IdeaBrainstormLog]:
serialized_insights = self.serializer.serialize(insights)
serialized_papers = self.serializer.serialize(papers)
idea_content = brainstorm_idea_prompting(
idea_content_list, prompt = brainstorm_idea_prompting(
bio=self.profile.bio,
insights=serialized_insights,
papers=serialized_papers,
@@ -83,16 +108,23 @@ def brainstorm_idea(
temperature=config.param.temperature,
top_p=config.param.top_p,
stream=config.param.stream,
)[0]
return Idea(content=idea_content)
)
idea_content = idea_content_list[0]
idea = Idea(content=idea_content)
formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
log_entry = IdeaBrainstormLog(
profile_pk=self.profile.pk, idea_pk=idea.pk, prompt_pk=formatted_prompt.pk
)

return idea, log_entry

@beartype
@member_required
def discuss_idea(
self, ideas: List[Idea], contexts: List[str], config: Config
) -> Idea:
serialized_ideas = self.serializer.serialize(ideas)
idea_summarized = discuss_idea_prompting(
idea_summarized_list, prompt = discuss_idea_prompting(
bio=self.profile.bio,
contexts=contexts,
ideas=serialized_ideas,
@@ -103,14 +135,15 @@ def discuss_idea(
temperature=config.param.temperature,
top_p=config.param.top_p,
stream=config.param.stream,
)[0]
)
idea_summarized = idea_summarized_list[0]
return Idea(content=idea_summarized)

@beartype
@member_required
def write_proposal(
self, idea: Idea, papers: List[Paper], config: Config
) -> Proposal:
) -> Tuple[Proposal, ProposalWritingLog]:
serialized_idea = self.serializer.serialize(idea)
serialized_papers = self.serializer.serialize(papers)

@@ -127,7 +160,7 @@ def write_proposal(
print('write_proposal_strategy not supported, will use default')
prompt_template = config.agent_prompt_template.write_proposal

proposal, q5_result = write_proposal_prompting(
proposal, q5_result, prompt = write_proposal_prompting(
idea=serialized_idea,
papers=serialized_papers,
model_name=self.model_name,
@@ -138,35 +171,48 @@ def write_proposal(
top_p=config.param.top_p,
stream=config.param.stream,
)
return Proposal(

proposal_obj = Proposal(
content=proposal,
q1=q5_result.get('q1', ''),
q2=q5_result.get('q2', ''),
q3=q5_result.get('q3', ''),
q4=q5_result.get('q4', ''),
q5=q5_result.get('q5', ''),
)
formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
log_entry = ProposalWritingLog(
profile_pk=self.profile.pk,
proposal_pk=proposal_obj.pk,
prompt_pk=formatted_prompt.pk,
)

return proposal_obj, log_entry

@beartype
@reviewer_required
def write_review(self, proposal: Proposal, config: Config) -> Review:
def write_review(
self, proposal: Proposal, config: Config
) -> Tuple[Review, ReviewWritingLog]:
serialized_proposal = self.serializer.serialize(proposal)

summary, strength, weakness, ethical_concerns, score = write_review_prompting(
proposal=serialized_proposal,
model_name=self.model_name,
summary_prompt_template=config.agent_prompt_template.write_review_summary,
strength_prompt_template=config.agent_prompt_template.write_review_strength,
weakness_prompt_template=config.agent_prompt_template.write_review_weakness,
ethical_prompt_template=config.agent_prompt_template.write_review_ethical,
score_prompt_template=config.agent_prompt_template.write_review_score,
return_num=config.param.return_num,
max_token_num=config.param.max_token_num,
temperature=config.param.temperature,
top_p=config.param.top_p,
stream=config.param.stream,
summary, strength, weakness, ethical_concerns, score, prompt = (
write_review_prompting(
proposal=serialized_proposal,
model_name=self.model_name,
summary_prompt_template=config.agent_prompt_template.write_review_summary,
strength_prompt_template=config.agent_prompt_template.write_review_strength,
weakness_prompt_template=config.agent_prompt_template.write_review_weakness,
ethical_prompt_template=config.agent_prompt_template.write_review_ethical,
score_prompt_template=config.agent_prompt_template.write_review_score,
return_num=config.param.return_num,
max_token_num=config.param.max_token_num,
temperature=config.param.temperature,
top_p=config.param.top_p,
stream=config.param.stream,
)
)
return Review(
review_obj = Review(
proposal_pk=proposal.pk,
reviewer_pk=self.profile.pk,
summary=summary,
@@ -175,6 +221,14 @@ def write_review(self, proposal: Proposal, config: Config) -> Review:
ethical_concerns=ethical_concerns,
score=score,
)
formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
log_entry = ReviewWritingLog(
profile_pk=self.profile.pk,
review_pk=review_obj.pk,
prompt_pk=formatted_prompt.pk,
)

return review_obj, log_entry

@beartype
@chair_required
@@ -183,11 +237,11 @@ def write_metareview(
proposal: Proposal,
reviews: List[Review],
config: Config,
) -> MetaReview:
) -> Tuple[MetaReview, MetaReviewWritingLog]:
serialized_proposal = self.serializer.serialize(proposal)
serialized_reviews = self.serializer.serialize(reviews)

summary, strength, weakness, ethical_concerns, decision = (
summary, strength, weakness, ethical_concerns, decision, prompt = (
write_metareview_prompting(
proposal=serialized_proposal,
reviews=serialized_reviews,
@@ -205,7 +259,7 @@ def write_metareview(
)
)

return MetaReview(
metareview_obj = MetaReview(
proposal_pk=proposal.pk,
chair_pk=self.profile.pk,
reviewer_pks=[review.reviewer_pk for review in reviews],
@@ -216,6 +270,14 @@ def write_metareview(
ethical_concerns=ethical_concerns,
decision=decision,
)
formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
log_entry = MetaReviewWritingLog(
profile_pk=self.profile.pk,
metareview_pk=metareview_obj.pk,
prompt_pk=formatted_prompt.pk,
)

return metareview_obj, log_entry

@beartype
@leader_required
@@ -224,11 +286,11 @@ def write_rebuttal(
proposal: Proposal,
review: Review,
config: Config,
) -> Rebuttal:
) -> Tuple[Rebuttal, RebuttalWritingLog]:
serialized_proposal = self.serializer.serialize(proposal)
serialized_review = self.serializer.serialize(review)

rebuttal_content, q5_result = write_rebuttal_prompting(
rebuttal_content, q5_result, prompt = write_rebuttal_prompting(
proposal=serialized_proposal,
review=serialized_review,
model_name=self.model_name,
@@ -240,7 +302,7 @@ def write_rebuttal(
stream=config.param.stream,
)

return Rebuttal(
rebuttal_obj = Rebuttal(
proposal_pk=proposal.pk,
reviewer_pk=review.reviewer_pk,
author_pk=self.profile.pk,
@@ -251,3 +313,11 @@
q4=q5_result.get('q4', ''),
q5=q5_result.get('q5', ''),
)
formatted_prompt = OpenAIPrompt(pk=str(uuid.uuid4()), messages=prompt)
log_entry = RebuttalWritingLog(
profile_pk=self.profile.pk,
rebuttal_pk=rebuttal_obj.pk,
prompt_pk=formatted_prompt.pk,
)

return rebuttal_obj, log_entry
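Because every writing method now returns a (result, log) pair, callers have to unpack both values. A hedged caller-side sketch follows; the agent, papers, contexts, and config objects and the logs list are assumed stand-ins for whatever the engine code actually uses, and only the parameters visible in this diff are shown:

# Each method returns its artifact plus a log entry carrying profile_pk,
# the artifact's pk, and a prompt_pk pointing at the recorded OpenAIPrompt.
logs = []

summary, keywords, insight, lit_log = agent.review_literature(
    papers=papers, contexts=contexts, config=config
)
logs.append(lit_log)

idea, idea_log = agent.brainstorm_idea(insights=[insight], papers=papers, config=config)
logs.append(idea_log)

proposal, proposal_log = agent.write_proposal(idea=idea, papers=papers, config=config)
logs.append(proposal_log)

# write_review, write_metareview, and write_rebuttal follow the same pattern,
# but require profiles with reviewer, chair, and leader roles respectively.
# Note that discuss_idea still returns a bare Idea with no log entry.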
4 changes: 4 additions & 0 deletions research_town/data/__init__.py
@@ -6,9 +6,11 @@
Log,
MetaReview,
MetaReviewWritingLog,
OpenAIPrompt,
Paper,
Profile,
Progress,
Prompt,
Proposal,
ProposalWritingLog,
Rebuttal,
@@ -27,6 +29,8 @@
'ReviewWritingLog',
'ExperimentLog',
'Progress',
'Prompt',
'OpenAIPrompt',
'Paper',
'Profile',
'Idea',
12 changes: 11 additions & 1 deletion research_town/data/data.py
@@ -1,5 +1,5 @@
import uuid
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field

@@ -9,6 +9,15 @@ class Data(BaseModel):
project_name: Optional[str] = Field(default=None)


class Prompt(BaseModel):
pk: str = Field(default_factory=lambda: str(uuid.uuid4()))
messages: Optional[Any] = Field(default=None)


class OpenAIPrompt(Prompt):
messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]]


class Profile(Data):
name: str
bio: str
@@ -43,6 +52,7 @@ class Paper(Data):
class Log(Data):
timestep: int = Field(default=0)
profile_pk: str
prompt_pk: Optional[str] = Field(default=None)


class LiteratureReviewLog(Log):
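A minimal sketch of how the new Prompt/OpenAIPrompt models and the optional prompt_pk field on Log fit together. The field values below are illustrative, and the import path assumes both classes are exported from research_town.data as the (truncated) __init__.py diff suggests:

from research_town.data import LiteratureReviewLog, OpenAIPrompt

# Record the exact messages sent to the model, then link the prompt from a
# log entry via the new prompt_pk field.
prompt = OpenAIPrompt(
    messages=[
        {'role': 'system', 'content': 'You are a research assistant.'},
        {'role': 'user', 'content': 'Summarize the key insights of these papers.'},
    ]
)  # pk defaults to a fresh uuid4 string via the Prompt base class

log = LiteratureReviewLog(
    profile_pk='example-profile-pk',  # assumed to exist elsewhere
    insight_pk='example-insight-pk',
    prompt_pk=prompt.pk,  # new Optional[str] field added on Log
)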