Skip to content

Commit

Permalink
code refinements and ec2 instance upgrades
Browse files Browse the repository at this point in the history
  • Loading branch information
threnjen committed Dec 11, 2024
1 parent 42a2bf4 commit cb2ad0c
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 31 deletions.
2 changes: 1 addition & 1 deletion aws_terraform_bgg/ec2_instances.tf
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
resource "aws_instance" "weaviate_ec2_instance" {

instance_type = "t2.micro"
instance_type = "t3.medium"
ami = "ami-055e3d4f0bbeb5878"
key_name = "weaviate-ec2"
monitoring = true
Expand Down
9 changes: 0 additions & 9 deletions modules/rag_description_generation/ec2_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,17 @@ def copy_docker_compose_to_instance(self):

response = os.system(command)

print(response)

def start_docker(self):
ssm_client = boto3.client("ssm")

command = "sudo docker compose -f /home/ec2-user/docker-compose.yml up -d"
# command = "ls /home/ec2-user/"

print(f"\nSending the command: {command} to the instance {self.instance_id}")

response = ssm_client.send_command(
InstanceIds=[self.instance_id],
DocumentName="AWS-RunShellScript", # For Linux instances
Parameters={"commands": [command]},
)
command_id = response["Command"]["CommandId"]
print(f"Command ID: {command_id}")

print(f"Waiting for docker containers to start")
time.sleep(5)
Expand All @@ -48,10 +42,7 @@ def start_docker(self):
CommandId=command_id, InstanceId=self.instance_id
)

print(command_invocation_result)

def get_ip_address(self):
print(self.ip_address)
return self.ip_address

def validate_ready_weaviate_instance(self):
Expand Down
41 changes: 24 additions & 17 deletions modules/rag_description_generation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def stop_ec2_instance(self):
ec2_instance.stop_instance()

def compute_game_overall_stats(self, game_df):
game_mean = round(game_df["AvgRating"].describe()["mean"], 2)
overall_mean = round(game_df["AvgRating"].describe()["mean"], 2)
game_std = round(game_df["AvgRating"].describe()["std"], 2)

self.overall_stats["overall_mean"] = game_mean
self.overall_stats["overall_mean"] = overall_mean
self.overall_stats["overall_std"] = game_std
self.overall_stats["two_under"] = round(game_mean - 2 * game_std, 2)
self.overall_stats["one_under"] = round(game_mean - game_std, 2)
self.overall_stats["half_over"] = round(game_mean + 0.5 * game_std, 2)
self.overall_stats["one_over"] = round(game_mean + game_std, 2)
self.overall_stats["two_under"] = round(overall_mean - 2 * game_std, 2)
self.overall_stats["one_under"] = round(overall_mean - game_std, 2)
self.overall_stats["half_over"] = round(overall_mean + 0.5 * game_std, 2)
self.overall_stats["one_over"] = round(overall_mean + game_std, 2)

print(f"Overall mean: {game_mean}")
print(f"Overall mean: {overall_mean}")

def load_reduced_game_df(self):
print(f"\nLoading game data from {GAME_CONFIGS['clean_dfs_directory']}")
Expand Down Expand Up @@ -92,7 +92,7 @@ def merge_game_df_with_user_df(self, game_df_reduced):
)

print(
f"Reducing user ratings to only include games in the reduced game dataframe"
f"Reducing user ratings to only include games in the reduced game dataframe\n"
)
all_games_df = user_df.merge(
game_df_reduced[
Expand All @@ -113,32 +113,37 @@ def load_prompt(self):
open("modules/rag_description_generation/prompt.json").read()
)["gpt4o_mini_generate_prompt_structured"]

def process_single_game(self, game_id, all_games_df):
if not self.dynamodb_client.check_dynamo_db_key(game_id):
def process_single_game(
self, game_id: str, all_games_df: pd.DataFrame, generate_prompt: str
):
if not self.dynamodb_client.check_dynamo_db_key(game_id=game_id):
df, game_name, game_mean = get_single_game_entries(
df=all_games_df, game_id=game_id, sample_pct=0.05
df=all_games_df, game_id=game_id, sample_pct=0.10
)
reviews = df["combined_review"].to_list()
self.weaviate.add_collection_batch(game_id, reviews)
self.weaviate.add_collection_batch(game_id=game_id, reviews=reviews)
current_prompt = self.weaviate.prompt_replacement(
generate_prompt=self.generate_prompt,
current_prompt=generate_prompt,
overall_stats=self.overall_stats,
game_name=game_name,
game_mean=game_mean,
)
print(current_prompt)
summary = self.weaviate.generate_aggregated_review(game_id, current_prompt)
self.dynamodb_client.divide_and_process_generated_summary(
game_id, summary=summary.generated
)
print(f"\n\n{summary.generated}")
print(f"\n{summary.generated}")
self.weaviate.remove_collection_items(game_id=game_id, reviews=reviews)
return

print(f"Game {game_id} already processed")

def rag_description_generation_chain(self):
self.confirm_running_ec2_host()
game_df_reduced = self.load_reduced_game_df()
all_games_df = self.merge_game_df_with_user_df(game_df_reduced)
generate_prompt = self.load_prompt()
print(generate_prompt)

self.weaviate = WeaviateClient(
ip_address=self.ip_address,
Expand All @@ -149,8 +154,10 @@ def rag_description_generation_chain(self):
self.dynamodb_client = DynamoDB()

for game_id in self.game_ids:
print(f"Processing game {game_id}")
self.process_single_game(game_id, all_games_df)
print(f"\nProcessing game {game_id}")
self.process_single_game(game_id, all_games_df, generate_prompt)

self.weaviate.close_client()


if __name__ == "__main__":
Expand Down
13 changes: 9 additions & 4 deletions modules/rag_description_generation/rag_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def connect_weaviate_client_ec2(self) -> weaviate.client:

def prompt_replacement(
self,
generate_prompt: str,
current_prompt: str,
overall_stats: dict[float],
game_name: str,
game_mean: str,
Expand All @@ -52,7 +52,7 @@ def prompt_replacement(
# turn all stats to strings
overall_stats = {k: str(v) for k, v in overall_stats.items()}

current_prompt = generate_prompt.replace("{GAME_NAME_HERE}", game_name)
current_prompt = current_prompt.replace("{GAME_NAME_HERE}", game_name)
current_prompt = current_prompt.replace("{GAME_AVERAGE_HERE}", game_mean)
current_prompt = current_prompt.replace(
"{TWO_UNDER}", overall_stats["two_under"]
Expand Down Expand Up @@ -90,6 +90,8 @@ def add_collection_batch(
else:
batch.add_object(properties=review_item, uuid=uuid)

print(f"Reviews added for game {game_id}")

def remove_collection_items(
self,
game_id: str,
Expand Down Expand Up @@ -126,8 +128,8 @@ def generate_aggregated_review(
def create_weaviate_collection(self):

if self.weaviate_client.collections.exists(self.collection_name):
self.weaviate_client.collections.delete(self.collection_name)
pass
print("Collection already exists for this block")
return

self.weaviate_client.collections.create(
name=self.collection_name,
Expand All @@ -154,6 +156,9 @@ def create_weaviate_collection(self):
],
)

def close_client(self):
self.weaviate_client.close()


def connect_weaviate_client_docker() -> weaviate.client:
client = weaviate.connect_to_local(
Expand Down

0 comments on commit cb2ad0c

Please sign in to comment.