From e59e9638e804ec700d3512415092758ae5fcdfbb Mon Sep 17 00:00:00 2001 From: JayGhiya Date: Tue, 6 Aug 2024 13:01:30 +0530 Subject: [PATCH] chore: intermediate non working performance code --- .../configuration/external_config.py | 1 + .../dspy_class_summary.py | 5 +- .../summary_parser/codebase_summary.py | 63 ++++++++++++++++--- 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/unoplat-code-confluence/unoplat_code_confluence/configuration/external_config.py b/unoplat-code-confluence/unoplat_code_confluence/configuration/external_config.py index 8b45216..074b59c 100644 --- a/unoplat-code-confluence/unoplat_code_confluence/configuration/external_config.py +++ b/unoplat-code-confluence/unoplat_code_confluence/configuration/external_config.py @@ -39,6 +39,7 @@ class AppConfig(BaseModel): api_tokens: Dict[str, str] llm_provider_config: Dict[str, Any] handlers: List[Dict[str, Any]] = Field(default_factory=list,alias="logging_handlers") + parallisation: int = 1 @field_validator('programming_language') def check_programming_language(cls, value, info:ValidationInfo): diff --git a/unoplat-code-confluence/unoplat_code_confluence/dspy_class_summary.py b/unoplat-code-confluence/unoplat_code_confluence/dspy_class_summary.py index ee0bb12..dfe6edd 100644 --- a/unoplat-code-confluence/unoplat_code_confluence/dspy_class_summary.py +++ b/unoplat-code-confluence/unoplat_code_confluence/dspy_class_summary.py @@ -31,11 +31,12 @@ def __init__(self): def forward(self, class_metadata: DspyUnoplatNodeSubset, function_objective_summary: List[DspyUnoplatFunctionSummary]): logger.debug(f"Generating class summary for {class_metadata.node_name}") class_summary = "" - + + for function_objective in function_objective_summary: signature_class_summary = self.generate_class_summary(class_existing_summary=class_summary, function_summary=function_objective.function_summary.objective, class_metadata=str(class_metadata.model_dump_json()),hint="Generate the class detailed summary for the class by being concise , factual and grounded.:"+class_metadata.node_name) class_summary = signature_class_summary.final_class_summary - + if class_metadata.node_name is not None: hint="Generate the class objective for the class by being concise and dnt miss on any details.:"+class_metadata.node_name else: diff --git a/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py b/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py index a830f4d..a6ba544 100644 --- a/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py +++ b/unoplat-code-confluence/unoplat_code_confluence/summary_parser/codebase_summary.py @@ -1,3 +1,4 @@ +import asyncio from collections import deque from typing import Dict, List from unoplat_code_confluence.configuration.external_config import AppConfig @@ -29,32 +30,48 @@ def __init__(self, codebase: UnoplatCodebase, dspy_pipeline_function: CodeConflu self.dspy_pipeline_package: CodeConfluencePackageModule = dspy_pipeline_package self.dspy_pipeline_codebase: CodeConfluenceCodebaseModule = dspy_pipeline_codebase #TODO: we will be externalise the different llms that can be used at all dspy pipelines and within dspy pipelines once dspy switches to litellm - self.init_dspy_lm(app_config.llm_provider_config) + self.provider_list =self.init_dspy_lm(app_config.llm_provider_config,app_config.parallisation) - def init_dspy_lm(self,llm_config: dict): + def init_dspy_lm(self,llm_config: dict,parallisation: int): #todo define a switch case llm_provider = next(iter(llm_config.keys())) + self.provider_list: dspy.LM = [] match llm_provider: case "openai": openai_provider = dspy.OpenAI(**llm_config["openai"]) dspy.configure(lm=openai_provider,experimental=True) + if parallisation > 1: + self.provider_list = [dspy.OpenAI(**llm_config["openai"]) for _ in range(parallisation)] + case "together": together_provider = dspy.Together(**llm_config["together"]) dspy.configure(lm=together_provider,experimental=True) + if parallisation > 1: + self.provider_list = [dspy.Together(**llm_config["together"]) for _ in range(parallisation)] case "anyscale": anyscale_provider = dspy.Anyscale(**llm_config["anyscale"]) dspy.configure(lm=anyscale_provider,experimental=True) + if parallisation > 1: + self.provider_list = [dspy.Anyscale(**llm_config["anyscale"]) for _ in range(parallisation)] case "awsanthropic": awsanthropic_provider = dspy.AWSAnthropic(**llm_config["awsanthropic"]) dspy.configure(lm=awsanthropic_provider,experimental=True) + if parallisation > 1: + self.provider_list = [dspy.AWSAnthropic(**llm_config["awsanthropic"]) for _ in range(parallisation)] case "ollama": ollama_provider = dspy.OllamaLocal(**llm_config["ollama"]) dspy.configure(lm=ollama_provider,experimental=True) + if parallisation > 1: + self.provider_list = [dspy.OllamaLocal(**llm_config["ollama"]) for _ in range(parallisation)] case "cohere": cohere_provider = dspy.Cohere(**llm_config["cohere"]) dspy.configure(lm=cohere_provider,experimental=True) - + if parallisation > 1: + self.provider_list = [dspy.Cohere(**llm_config["cohere"]) for _ in range(parallisation)] + case _: + raise ValueError(f"Invalid LLM provider: {llm_provider}") + return self.provider_list def parse_codebase(self) -> DspyUnoplatCodebaseSummary: @@ -64,7 +81,7 @@ def parse_codebase(self) -> DspyUnoplatCodebaseSummary: root_packages: Dict[str,UnoplatPackage] = self.codebase.packages - root_package_summaries = self.process_packages(root_packages) + root_package_summaries = self.process_packages(root_packages,self.provider_list) try: @@ -80,7 +97,7 @@ def parse_codebase(self) -> DspyUnoplatCodebaseSummary: #todo: pydantic out to a file of unoplat codebase summary return unoplat_codebase_summary - def count_total_packages(self, packages: Dict[str, UnoplatPackage]) -> int: + def count_total_packages(self, packages: Dict[str, UnoplatPackage],provider_list: List[dspy.LM]) -> int: total = 0 stack = list(packages.values()) while stack: @@ -132,7 +149,7 @@ def process_packages(self, packages: Dict[str,UnoplatPackage]) -> Dict[str,DspyU if package_name in memo: package_summary = memo[package_name] else: - class_summaries = self.process_classes(package.node_subsets,package_name,pman=pman) + class_summaries = self.process_classes(package.node_subsets,package_name,pman=pman,provider_list=self.provider_list) for sub_name in package.sub_packages: if sub_name in memo: logger.debug("Sub package {} already processed, adding to sub_package_summaries",sub_name) @@ -167,8 +184,40 @@ def process_packages(self, packages: Dict[str,UnoplatPackage]) -> Dict[str,DspyU return package_summaries + + async def process_classes_async(self, classes: List[DspyUnoplatNodeSubset],package_name: str,pman: ProgressManager,provider_list: List[dspy.LM]) -> List[DspyUnoplatNodeSummary]: + class_summaries = [] + + async with asyncio.TaskGroup() as tg: + for node,lm in zip(classes,provider_list): + task = tg.create_task(self.process_single_class(node, package_name, pman, lm)) + + for task in tg.tasks: + return class_summaries + + + async def process_single_class(self, node: DspyUnoplatNodeSubset, package_name: str, pman: ProgressManager, provider_list: List[dspy.LM]) -> DspyUnoplatNodeSummary: + try: + function_summaries = await self.process_functions(node.functions, node, pman) + class_summary = await self.dspy_pipeline_class(class_metadata=node, function_objective_summary=function_summaries).answer + return class_summary + except Exception as e: + logger.error(f"Error generating class summary for {node}: {e}") + logger.exception("Traceback:") + return None + + async def process_functions(self, functions: List[DspyUnoplatFunctionSubset], node: DspyUnoplatNodeSubset, pman: ProgressManager) -> List[DspyUnoplatFunctionSummary]: + function_summaries = [] + for function in functions: + try: + function_summary = await self.dspy_pipeline_function(function_metadata=function, class_metadata=node).answer + function_summaries.append(DspyUnoplatFunctionSummary(FunctionName=function.name, FunctionSummary=function_summary)) + except Exception as e: + logger.error(f"Error generating function summary for {function.name}: {e}") + logger.exception("Traceback:") + return function_summaries - def process_classes(self, classes: List[DspyUnoplatNodeSubset],package_name: str,pman: ProgressManager) -> List[DspyUnoplatNodeSummary]: + def process_classes(self, classes: List[DspyUnoplatNodeSubset],package_name: str,pman: ProgressManager,provider_list: List[dspy.LM]) -> List[DspyUnoplatNodeSummary]: class_summaries: List[DspyUnoplatNodeSummary] = [] class_prog = pman.progiter(iterable = classes, desc=f"Processing classes of {package_name}", verbose=2,total=len(classes))