From 9be18e407468f48e8cc03c59e417f63020a9c377 Mon Sep 17 00:00:00 2001
From: JayGhiya
Date: Fri, 6 Sep 2024 18:52:08 +0530
Subject: [PATCH] feat: perf improvement through parallelisation of classes

---
 .../unoplat_code_confluence/__main__.py      |  23 ++-
 .../dspy_function_summary.py                 |   2 +-
 .../unoplat_code_confluence/loguru.json      |   2 +-
 .../summary_parser/codebase_summary.py       | 189 ++++++++++--------
 4 files changed, 115 insertions(+), 101 deletions(-)

diff --git a/unoplat-code-confluence/unoplat_code_confluence/__main__.py b/unoplat-code-confluence/unoplat_code_confluence/__main__.py
index 3abf30e..a3bb840 100644
--- a/unoplat-code-confluence/unoplat_code_confluence/__main__.py
+++ b/unoplat-code-confluence/unoplat_code_confluence/__main__.py
@@ -1,4 +1,5 @@
 import argparse
+import asyncio
 import os
 from loguru import logger
 import datetime
@@ -21,7 +22,7 @@ from packaging import version
 
 
 
-def start_pipeline():
+async def start_pipeline():
     parser = argparse.ArgumentParser(description="Codebase Parser CLI")
     parser.add_argument("--config", help="Path to configuration file for unoplat utility", default=os.getcwd() + '/default_config.json', type=str)
     args = parser.parse_args()
@@ -37,17 +38,17 @@ def start_pipeline():
 
     # logger.configure(handlers=logging_config["handlers"])
 
-    get_codebase_metadata(json_configuration_data,iload_json,iparse_json,isummariser)
+    await get_codebase_metadata(json_configuration_data,iload_json,iparse_json,isummariser)
 
-def get_codebase_metadata(json_configuration_data,iload_json,iparse_json,isummariser):
+async def get_codebase_metadata(json_configuration_data,iload_json,iparse_json,isummariser):
     # Collect necessary inputs from the user to set up the codebase indexing
     app_config = AppConfig(**json_configuration_data)
 
     logger.configure(handlers=app_config.handlers)
 
     # Button to submit the indexing
-    start_parsing(
+    await start_parsing(
         app_config,
         iload_json,
         iparse_json,
@@ -55,13 +56,13 @@ def get_codebase_metadata(json_configuration_data,iload_json,iparse_json,isummar
     )
 
-def ensure_jar_downloaded(github_token, arcguard_cli_repo, local_download_directory):
+async def ensure_jar_downloaded(github_token, arcguard_cli_repo, local_download_directory):
 
     jar_path = Downloader.download_latest_jar(arcguard_cli_repo, local_download_directory, github_token)
 
     return jar_path
 
-def get_extension(programming_language: str):
+async def get_extension(programming_language: str):
     #TODO: convert this to enum based check
     if programming_language == "java":
         return "java"
@@ -70,13 +71,13 @@ def get_extension(programming_language: str):
     else:
         raise ValueError(f"Unsupported programming language: {programming_language}")
 
-def start_parsing(app_config: AppConfig, iload_json: JsonLoader, iparse_json: JsonParser, isummariser: MarkdownSummariser):
+async def start_parsing(app_config: AppConfig, iload_json: JsonLoader, iparse_json: JsonParser, isummariser: MarkdownSummariser):
 
     # Log the start of the parsing process
     logger.info("Starting parsing process...")
 
     # Ensure the JAR is downloaded or use the existing one
-    jar_path = ensure_jar_downloaded(app_config.api_tokens["github_token"],app_config.repo.download_url,app_config.repo.download_directory)
+    jar_path = await ensure_jar_downloaded(app_config.api_tokens["github_token"],app_config.repo.download_url,app_config.repo.download_directory)
 
     logger.info(f"Local Workspace URL: {app_config.local_workspace_path}")
     logger.info(f"Programming Language: {app_config.programming_language}")
@@ -84,7 +85,7 @@ def start_parsing(app_config: AppConfig, iload_json: JsonLoader, iparse_json: Js
     logger.info(f"Codebase Name: {app_config.codebase_name}")
 
     # based on programming_language convert to extension
-    extension = get_extension(app_config.programming_language)
+    extension = await get_extension(app_config.programming_language)
 
     # Initialize the ArchGuard handler with the collected parameters.
     archguard_handler = ArchGuardHandler(
@@ -120,7 +121,7 @@ def start_parsing(app_config: AppConfig, iload_json: JsonLoader, iparse_json: Js
 
     codebase_summary = CodebaseSummaryParser(unoplat_codebase,dspy_function_pipeline_summary, dspy_class_pipeline_summary,dspy_package_pipeline_summary,dspy_codebase_pipeline_summary,app_config)
 
-    unoplat_codebase_summary: DspyUnoplatCodebaseSummary = codebase_summary.parse_codebase()
+    unoplat_codebase_summary: DspyUnoplatCodebaseSummary = await codebase_summary.parse_codebase()
 
     # now write to a markdown dspy unoplat codebase summary
 
@@ -136,5 +137,5 @@ def start_parsing(app_config: AppConfig, iload_json: JsonLoader, iparse_json: Js
 
     warnings.filterwarnings("ignore", category=DeprecationWarning, module='pydantic.*')
 
-    start_pipeline()
+    asyncio.run(start_pipeline())
\ No newline at end of file
diff --git a/unoplat-code-confluence/unoplat_code_confluence/dspy_function_summary.py b/unoplat-code-confluence/unoplat_code_confluence/dspy_function_summary.py
index d1e0c9a..03b7c62 100644
--- a/unoplat-code-confluence/unoplat_code_confluence/dspy_function_summary.py
+++ b/unoplat-code-confluence/unoplat_code_confluence/dspy_function_summary.py
@@ -37,7 +37,7 @@ def __init__(self):
         self.generate_function_objective = dspy.TypedPredictor(CodeConfluenceFunctionObjectiveSignature)
 
     def forward(self, function_metadata: DspyUnoplatFunctionSubset, class_metadata: DspyUnoplatNodeSubset):
-        logger.debug(f"Generating function summary for {function_metadata.name}")
+        logger.debug(f"Generating function summary for {function_metadata.name} present in class {class_metadata.node_name}")
         class_subset = str(class_metadata.model_dump_json())
         function_subset = str(function_metadata.model_dump_json())
 
diff --git a/unoplat-code-confluence/unoplat_code_confluence/loguru.json b/unoplat-code-confluence/unoplat_code_confluence/loguru.json
index e5588d3..3c329f9 100644
--- a/unoplat-code-confluence/unoplat_code_confluence/loguru.json
+++ b/unoplat-code-confluence/unoplat_code_confluence/loguru.json
@@ -2,7 +2,7 @@
     "handlers": [
       {
         "sink": "./app.log",
-        "format": "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} - {message}",
+        "format": "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} | {thread.name} - {message}",
         "rotation": "10 MB",
         "retention": "10 days",
         "level": "INFO"
diff --git a/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py b/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py
index a6ba544..f829534 100644
--- a/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py
+++ b/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py
@@ -1,5 +1,7 @@
 import asyncio
 from collections import deque
+from itertools import cycle
+import sys
 from typing import Dict, List
 from unoplat_code_confluence.configuration.external_config import AppConfig
 from unoplat_code_confluence.data_models.chapi_unoplat_codebase import UnoplatCodebase
@@ -40,64 +42,77 @@ def init_dspy_lm(self,llm_config: dict,parallisation: int):
         match llm_provider:
             case "openai":
                 openai_provider = dspy.OpenAI(**llm_config["openai"])
-                dspy.configure(lm=openai_provider,experimental=True)
-                if parallisation > 1:
-                    self.provider_list = [dspy.OpenAI(**llm_config["openai"]) for _ in range(parallisation)]
+                dspy.configure(lm=openai_provider, experimental=True)
+                self.provider_list = [openai_provider]
+                if parallisation and parallisation > 1:
+                    self.provider_list.extend([dspy.OpenAI(**llm_config["openai"]) for _ in range(parallisation - 1)])
             case "together":
                 together_provider = dspy.Together(**llm_config["together"])
-                dspy.configure(lm=together_provider,experimental=True)
-                if parallisation > 1:
-                    self.provider_list = [dspy.Together(**llm_config["together"]) for _ in range(parallisation)]
+                dspy.configure(lm=together_provider, experimental=True)
+                self.provider_list = [together_provider]
+                if parallisation and parallisation > 1:
+                    self.provider_list.extend([dspy.Together(**llm_config["together"]) for _ in range(parallisation - 1)])
+
             case "anyscale":
                 anyscale_provider = dspy.Anyscale(**llm_config["anyscale"])
-                dspy.configure(lm=anyscale_provider,experimental=True)
-                if parallisation > 1:
-                    self.provider_list = [dspy.Anyscale(**llm_config["anyscale"]) for _ in range(parallisation)]
+                dspy.configure(lm=anyscale_provider, experimental=True)
+                self.provider_list = [anyscale_provider]
+                if parallisation and parallisation > 1:
+                    self.provider_list.extend([dspy.Anyscale(**llm_config["anyscale"]) for _ in range(parallisation - 1)])
+
             case "awsanthropic":
                 awsanthropic_provider = dspy.AWSAnthropic(**llm_config["awsanthropic"])
-                dspy.configure(lm=awsanthropic_provider,experimental=True)
-                if parallisation > 1:
-                    self.provider_list = [dspy.AWSAnthropic(**llm_config["awsanthropic"]) for _ in range(parallisation)]
+                dspy.configure(lm=awsanthropic_provider, experimental=True)
+                self.provider_list = [awsanthropic_provider]
+                if parallisation and parallisation > 1:
+                    self.provider_list.extend([dspy.AWSAnthropic(**llm_config["awsanthropic"]) for _ in range(parallisation - 1)])
+
             case "ollama":
                 ollama_provider = dspy.OllamaLocal(**llm_config["ollama"])
-                dspy.configure(lm=ollama_provider,experimental=True)
-                if parallisation > 1:
-                    self.provider_list = [dspy.OllamaLocal(**llm_config["ollama"]) for _ in range(parallisation)]
+                dspy.configure(lm=ollama_provider, experimental=True)
+                self.provider_list = [ollama_provider]
+                if parallisation and parallisation > 1:
+                    self.provider_list.extend([dspy.OllamaLocal(**llm_config["ollama"]) for _ in range(parallisation - 1)])
+
             case "cohere":
                 cohere_provider = dspy.Cohere(**llm_config["cohere"])
-                dspy.configure(lm=cohere_provider,experimental=True)
-                if parallisation > 1:
-                    self.provider_list = [dspy.Cohere(**llm_config["cohere"]) for _ in range(parallisation)]
+                dspy.configure(lm=cohere_provider, experimental=True)
+                self.provider_list = [cohere_provider]
+                if parallisation and parallisation > 1:
+                    self.provider_list.extend([dspy.Cohere(**llm_config["cohere"]) for _ in range(parallisation - 1)])
             case _:
                 raise ValueError(f"Invalid LLM provider: {llm_provider}")
 
         return self.provider_list
 
-    def parse_codebase(self) -> DspyUnoplatCodebaseSummary:
+    async def parse_codebase(self) -> DspyUnoplatCodebaseSummary:
-
-
         unoplat_codebase_summary = DspyUnoplatCodebaseSummary()
         root_packages: Dict[str,UnoplatPackage] = self.codebase.packages
-        root_package_summaries = self.process_packages(root_packages,self.provider_list)
-
+        root_package_summaries = await self.process_packages(root_packages)
         try:
             dspy_codebase_summary = self.dspy_pipeline_codebase(package_objective_dict=root_package_summaries)
         except Exception as e:
             logger.error(f"Error generating codebase summary: {e}")
             logger.exception("Traceback:")
-
+            sys.exit(1)
+
         unoplat_codebase_summary.codebase_summary = dspy_codebase_summary.summary
         unoplat_codebase_summary.codebase_objective = dspy_codebase_summary.answer
         unoplat_codebase_summary.codebase_package = root_package_summaries
+        json_unoplat_codebase_summary = unoplat_codebase_summary.model_dump_json()
+        # write to file
+        with open("unoplat_codebase_summary_dspy_2.json", "w") as f:
+            f.write(json_unoplat_codebase_summary)
+
         # write to md file
         #todo: pydantic out to a file of unoplat codebase summary
         return unoplat_codebase_summary
 
-    def count_total_packages(self, packages: Dict[str, UnoplatPackage],provider_list: List[dspy.LM]) -> int:
+    async def count_total_packages(self, packages: Dict[str, UnoplatPackage]) -> int:
         total = 0
         stack = list(packages.values())
         while stack:
@@ -106,13 +121,13 @@ def count_total_packages(self, packages: Dict[str, UnoplatPackage],provider_list
             stack.extend(package.sub_packages.values())
         return total
 
-    def process_packages(self, packages: Dict[str,UnoplatPackage]) -> Dict[str,DspyUnoplatPackageSummary]:
+    async def process_packages(self, packages: Dict[str,UnoplatPackage]) -> Dict[str,DspyUnoplatPackageSummary]:
         package_summaries: Dict[str, DspyUnoplatPackageSummary] = {}
         stack = deque([(name, package, True) for name, package in packages.items()])
         processed = set()
         memo = {}
 
-        total_packages = self.count_total_packages(packages)
+        total_packages = await self.count_total_packages(packages)
 
         pman = ProgressManager(backend='rich')
 
@@ -149,7 +164,7 @@ def process_packages(self, packages: Dict[str,UnoplatPackage]) -> Dict[str,DspyU
                 if package_name in memo:
                     package_summary = memo[package_name]
                 else:
-                    class_summaries = self.process_classes(package.node_subsets,package_name,pman=pman,provider_list=self.provider_list)
+                    class_summaries = await self.process_classes_async(package.node_subsets,package_name,pman=pman,provider_list=self.provider_list)
                     for sub_name in package.sub_packages:
                         if sub_name in memo:
                             logger.debug("Sub package {} already processed, adding to sub_package_summaries",sub_name)
@@ -157,7 +172,7 @@ def process_packages(self, packages: Dict[str,UnoplatPackage]) -> Dict[str,DspyU
 
                     try:
                         logger.debug("Generating package summary for {}",package_name)
-                        package_summary = self.dspy_pipeline_package(
+                        package_summary = self.dspy_pipeline_package(
                             package_name=package_name,
                             class_objective_list=class_summaries,
                             sub_package_summaries=sub_package_summaries
@@ -185,80 +200,78 @@ def process_packages(self, packages: Dict[str,UnoplatPackage]) -> Dict[str,DspyU
 
         return package_summaries
 
+    async def process_batch(self, batch: List[DspyUnoplatNodeSubset], package_name: str, pman: ProgressManager, lm_cycle: cycle) -> List[DspyUnoplatNodeSummary]:
+        tasks = []
+        async with asyncio.TaskGroup() as tg:
+            for node in batch:
+
+                try:
+                    lm = next(lm_cycle)
+                    task = tg.create_task(self.process_single_class_wrapper(node, package_name, pman, lm))
+                    tasks.append(task)
+                except Exception as e:
+                    logger.error(f"Error creating task for {node.node_name}: {e}")
+                    logger.exception("Traceback:")
+        return await self.collect_batch_results(tasks)
+
+    async def collect_batch_results(self, tasks: List[asyncio.Task]) -> List[DspyUnoplatNodeSummary]:
+        batch_results = []
+        for task in tasks:
+            try:
+                result = await task
+                if result is not None:
+                    batch_results.append(result)
+            except Exception as e:
+                logger.error(f"Error collecting batch result: {e}")
+                logger.exception("Traceback:")
+        return batch_results
+
+
+    async def process_classes_async(self, classes: List[DspyUnoplatNodeSubset],package_name: str,pman: ProgressManager,provider_list: List[dspy.LM]) -> List[DspyUnoplatNodeSummary]:
-        class_summaries = []
-
-        async with asyncio.TaskGroup() as tg:
-            for node,lm in zip(classes,provider_list):
-                task = tg.create_task(self.process_single_class(node, package_name, pman, lm))
-
-        for task in tg.tasks:
-            return class_summaries
+        class_summaries = []
+        concurrency = len(provider_list)
+        lm_cycle = cycle(provider_list)
+
+        for i in range(0, len(classes), concurrency):
+            batch = classes[i:i+concurrency]
+            batch_summaries = await self.process_batch(batch, package_name, pman, lm_cycle)
+            class_summaries.extend(batch_summaries)
+
+        return class_summaries
+
+    async def process_single_class_wrapper(self, node: DspyUnoplatNodeSubset, package_name: str, pman: ProgressManager, lm: dspy.LM) -> DspyUnoplatNodeSummary:
+        return await asyncio.to_thread(self.process_single_class, node, package_name, pman, lm)
+
-    async def process_single_class(self, node: DspyUnoplatNodeSubset, package_name: str, pman: ProgressManager, provider_list: List[dspy.LM]) -> DspyUnoplatNodeSummary:
+    def process_single_class(self, node: DspyUnoplatNodeSubset, package_name: str, pman: ProgressManager, lm: dspy.LM) -> DspyUnoplatNodeSummary:
         try:
-            function_summaries = await self.process_functions(node.functions, node, pman)
-            class_summary = await self.dspy_pipeline_class(class_metadata=node, function_objective_summary=function_summaries).answer
-            return class_summary
+            with dspy.context(lm=lm):
+                function_summaries = self.process_functions(node.functions, node, pman)
+                class_summary = self.dspy_pipeline_class(class_metadata=node, function_objective_summary=function_summaries).answer
+                return class_summary
         except Exception as e:
-            logger.error(f"Error generating class summary for {node}: {e}")
+            logger.error(f"Error processing class {node.node_name}: {e}")
             logger.exception("Traceback:")
             return None
 
-    async def process_functions(self, functions: List[DspyUnoplatFunctionSubset], node: DspyUnoplatNodeSubset, pman: ProgressManager) -> List[DspyUnoplatFunctionSummary]:
+
+    def process_functions(self, functions: List[DspyUnoplatFunctionSubset], node: DspyUnoplatNodeSubset, pman: ProgressManager) -> List[DspyUnoplatFunctionSummary]:
         function_summaries = []
         for function in functions:
-            try:
-                function_summary = await self.dspy_pipeline_function(function_metadata=function, class_metadata=node).answer
-                function_summaries.append(DspyUnoplatFunctionSummary(FunctionName=function.name, FunctionSummary=function_summary))
-            except Exception as e:
-                logger.error(f"Error generating function summary for {function.name}: {e}")
-                logger.exception("Traceback:")
-        return function_summaries
-
-    def process_classes(self, classes: List[DspyUnoplatNodeSubset],package_name: str,pman: ProgressManager,provider_list: List[dspy.LM]) -> List[DspyUnoplatNodeSummary]:
-        class_summaries: List[DspyUnoplatNodeSummary] = []
-
-        class_prog = pman.progiter(iterable = classes, desc=f"Processing classes of {package_name}", verbose=2,total=len(classes))
-
-        for node in class_prog:
-            function_summaries = self.process_functions(node.functions,node,pman=pman)
-
-            try:
-                class_summary = self.dspy_pipeline_class(class_metadata=node, function_objective_summary=function_summaries).answer
-                class_summaries.append(class_summary)
-            except Exception as e:
-                logger.error(f"Error generating class summary for {node}: {e}")
-                logger.exception("Traceback:")
-
-
-        return class_summaries
-
-
-    def process_functions(self,functions: List[DspyUnoplatFunctionSubset],node: DspyUnoplatNodeSubset,pman: ProgressManager) -> List[DspyUnoplatFunctionSummary]:
-        function_summaries: List[DspyUnoplatFunctionSummary] = []
-
-
-
-        function_prog = pman.progiter(iterable =functions, desc=f"Processing functions of {node.node_name}", verbose=2,total=len(functions))
-
-
-        for function in function_prog:
-            if function.name is not None:
+            if function.name:
                 try:
-                    function_summary = self.dspy_pipeline_function(function_metadata=function,class_metadata=node).answer
-                    dspyUnoplatFunctionSummary: DspyUnoplatFunctionSummary = DspyUnoplatFunctionSummary(FunctionName=function.name,FunctionSummary=function_summary)
-                    function_summaries.append(dspyUnoplatFunctionSummary)
-                    function_prog.update(1)
+                    function_summary = self.dspy_pipeline_function(function_metadata=function, class_metadata=node).answer
+                    function_summaries.append(DspyUnoplatFunctionSummary(FunctionName=function.name, FunctionSummary=function_summary))
                 except Exception as e:
                     logger.error(f"Error generating function summary for {function.name}: {e}")
                     logger.exception("Traceback:")
-
-
+        return function_summaries
+
+
-        return function_summaries
+
-
+
\ No newline at end of file
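
Note on the concurrency model introduced above: `process_classes_async` slices a package's classes into batches no larger than the provider pool, assigns an LM client to each class round-robin via `itertools.cycle`, and offloads the blocking dspy calls to worker threads with `asyncio.to_thread` inside an `asyncio.TaskGroup` (Python 3.11+). Below is a minimal, self-contained sketch of that pattern; `FakeClient` and `summarise_class` are hypothetical stand-ins for the dspy providers and pipelines, not code from this repository.

```python
# Standalone sketch (not part of the patch) of the batch + round-robin + to_thread pattern.
import asyncio
import time
from dataclasses import dataclass
from itertools import cycle
from typing import List


@dataclass
class FakeClient:
    name: str


def summarise_class(class_name: str, client: FakeClient) -> str:
    # Blocking work (an LLM call in the real code); runs inside a worker thread.
    time.sleep(0.1)
    return f"{class_name} summarised by {client.name}"


async def process_batch(batch: List[str], clients: cycle) -> List[str]:
    tasks = []
    async with asyncio.TaskGroup() as tg:
        for class_name in batch:
            client = next(clients)  # round-robin client assignment
            tasks.append(tg.create_task(asyncio.to_thread(summarise_class, class_name, client)))
    # The TaskGroup has finished all tasks once the `async with` block exits.
    return [t.result() for t in tasks]


async def process_classes(classes: List[str], clients: List[FakeClient]) -> List[str]:
    results: List[str] = []
    concurrency = len(clients)
    client_cycle = cycle(clients)
    # Slice the class list into batches no larger than the client pool,
    # mirroring process_classes_async in the patch.
    for i in range(0, len(classes), concurrency):
        results.extend(await process_batch(classes[i:i + concurrency], client_cycle))
    return results


if __name__ == "__main__":
    pool = [FakeClient(f"lm-{i}") for i in range(3)]
    print(asyncio.run(process_classes([f"Class{i}" for i in range(7)], pool)))
```

Because the underlying dspy calls stay synchronous, the parallelism comes from threads rather than the event loop; in the patch each worker pins its own provider with `with dspy.context(lm=lm)` in `process_single_class`, so concurrent threads do not share a single client.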
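
The `loguru.json` change adds `{thread.name}` to the handler format so records emitted from the `asyncio.to_thread` workers can be attributed to their thread. A rough equivalent of that handler configured in code (values copied from `loguru.json`; the sink path is illustrative only):

```python
from loguru import logger

# Mirrors the handler in loguru.json: a rotating file sink whose format
# includes the originating thread's name for each record.
logger.add(
    "./app.log",
    format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {name}:{function}:{line} | {thread.name} - {message}",
    rotation="10 MB",
    retention="10 days",
    level="INFO",
)

logger.info("logged with the originating thread's name")
```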