From 646126a7d785cee9283f6558b09f411315139ba2 Mon Sep 17 00:00:00 2001 From: joseph-sentry <136376984+joseph-sentry@users.noreply.github.com> Date: Thu, 8 Feb 2024 15:14:57 -0500 Subject: [PATCH] fix: remove n+1 repo flags query (#260) --- services/report/__init__.py | 15 +++++++++++---- services/test_results.py | 16 ++++++++++++---- tasks/upload.py | 2 ++ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/services/report/__init__.py b/services/report/__init__.py index 51c65c1aa..6796e9571 100644 --- a/services/report/__init__.py +++ b/services/report/__init__.py @@ -178,6 +178,10 @@ def create_report_upload( class ReportService(BaseReportService): metrics_prefix = "services.report" + def __init__(self, current_yaml: UserYaml): + super().__init__(current_yaml) + self.flag_dict = None + def has_initialized_report(self, commit: Commit) -> bool: """Says whether a commit has already initialized its report or not @@ -302,25 +306,28 @@ def _attach_flags_to_upload(self, upload: Upload, flag_names: Sequence[str]): db_session = upload.get_db_session() repoid = upload.report.commit.repoid - existing_flag_dict = self.get_existing_flag_dict(db_session, repoid) + if self.flag_dict is None: + self.fetch_repo_flags(db_session, repoid) + for individual_flag in flag_names: - flag_obj = existing_flag_dict.get(individual_flag, None) + flag_obj = self.flag_dict.get(individual_flag, None) if flag_obj is None: flag_obj = RepositoryFlag( repository_id=repoid, flag_name=individual_flag ) db_session.add(flag_obj) db_session.flush() + self.flag_dict[individual_flag] = flag_obj all_flags.append(flag_obj) upload.flags = all_flags db_session.flush() return all_flags - def get_existing_flag_dict(self, db_session, repoid): + def fetch_repo_flags(self, db_session, repoid): existing_flags_on_repo = ( db_session.query(RepositoryFlag).filter_by(repository_id=repoid).all() ) - return {flag.flag_name: flag for flag in existing_flags_on_repo} + self.flag_dict = {flag.flag_name: flag for flag in existing_flags_on_repo} def build_files( self, report_details: ReportDetails diff --git a/services/test_results.py b/services/test_results.py index 882abf335..41e5bf1b4 100644 --- a/services/test_results.py +++ b/services/test_results.py @@ -3,6 +3,7 @@ from typing import Mapping, Sequence from shared.torngit.exceptions import TorngitClientError +from shared.yaml import UserYaml from sqlalchemy import desc from test_results_parser import Outcome @@ -19,6 +20,10 @@ class TestResultsReportService(BaseReportService): + def __init__(self, current_yaml: UserYaml): + super().__init__(current_yaml) + self.flag_dict = None + async def initialize_and_save_report( self, commit: Commit, report_code: str = None ) -> CommitReport: @@ -69,25 +74,28 @@ def _attach_flags_to_upload(self, upload: Upload, flag_names: Sequence[str]): db_session = upload.get_db_session() repoid = upload.report.commit.repoid - existing_flag_dict = self.get_existing_flag_dict(db_session, repoid) + if self.flag_dict is None: + self.fetch_repo_flags(db_session, repoid) + for individual_flag in flag_names: - flag_obj = existing_flag_dict.get(individual_flag, None) + flag_obj = self.flag_dict.get(individual_flag, None) if flag_obj is None: flag_obj = RepositoryFlag( repository_id=repoid, flag_name=individual_flag ) db_session.add(flag_obj) db_session.flush() + self.flag_dict[individual_flag] = flag_obj all_flags.append(flag_obj) upload.flags = all_flags db_session.flush() return all_flags - def get_existing_flag_dict(self, db_session, repoid): + def fetch_repo_flags(self, db_session, repoid): existing_flags_on_repo = ( db_session.query(RepositoryFlag).filter_by(repository_id=repoid).all() ) - return {flag.flag_name: flag for flag in existing_flags_on_repo} + self.flag_dict = {flag.flag_name: flag for flag in existing_flags_on_repo} def generate_flags_hash(flag_names): diff --git a/tasks/upload.py b/tasks/upload.py index 0afd237ea..7c3d07727 100644 --- a/tasks/upload.py +++ b/tasks/upload.py @@ -482,6 +482,7 @@ async def run_async_within_lock( upload_context.prepare_kwargs_for_retry(kwargs) self.retry(countdown=60, kwargs=kwargs) argument_list = [] + for arguments in upload_context.arguments_list(): normalized_arguments = upload_context.normalize_arguments(commit, arguments) if "upload_id" in normalized_arguments: @@ -492,6 +493,7 @@ async def run_async_within_lock( upload = report_service.create_report_upload( normalized_arguments, commit_report ) + normalized_arguments["upload_pk"] = upload.id_ argument_list.append(normalized_arguments) if argument_list: