From d66383acc6d6dcd214452de21b53297ee5e98e4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Fuch=C3=9F?= Date: Mon, 30 Sep 2024 16:57:00 +0200 Subject: [PATCH] Add code features --- ...SamViaLlmCodeTraceabilityLinkRecovery.java | 11 ++++--- .../agents/LLMArchitectureProviderAgent.java | 5 +-- .../informants/LLMArchitecturePrompt.java | 19 +++++++++-- .../LLMArchitectureProviderInformant.java | 33 +++++++++++++++++-- ...odeTraceabilityLinkRecoveryEvaluation.java | 7 ++-- ...TraceLinkEvaluationSadSamViaLlmCodeIT.java | 8 +++-- 6 files changed, 68 insertions(+), 15 deletions(-) diff --git a/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java b/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java index e6b49d1..f473250 100644 --- a/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java +++ b/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java @@ -26,14 +26,17 @@ public ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(String projectName) { } public void setUp(File inputText, File inputCode, SortedMap additionalConfigs, File outputDir, LargeLanguageModel largeLanguageModel, - LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) { - definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt, aggregationPrompt); + LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt.Features codeFeatures, + LLMArchitecturePrompt aggregationPrompt) { + definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt, codeFeatures, + aggregationPrompt); setOutputDirectory(outputDir); isSetUp = true; } private void definePipeline(File inputText, File inputCode, SortedMap additionalConfigs, LargeLanguageModel largeLanguageModel, - LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) { + LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt.Features codeFeatures, + LLMArchitecturePrompt aggregationPrompt) { ArDoCo arDoCo = this.getArDoCo(); var dataRepository = arDoCo.getDataRepository(); @@ -52,7 +55,7 @@ private void definePipeline(File inputText, File inputCode, SortedMap getTemplates() { + if (this == CODE_ONLY_V1) + throw new IllegalArgumentException("This method is not supported for this enum value"); return templates; } + + public List getTemplates(Features features) { + return templates.stream().map(it -> it.replace("{FEATURES}", features.toString())).toList(); + } + + public enum Features { + PACKAGES, PACKAGES_AND_THEIR_CLASSES; + + @Override + public String toString() { + return super.toString().charAt(0) + super.toString().toLowerCase().substring(1).replace("_", " "); + } + } } diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java index 6930059..ac02bb6 100644 --- a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java @@ -1,6 +1,9 @@ /* Licensed under MIT 2024. */ package edu.kit.kastel.mcse.ardoco.tlr.models.informants; +import static edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt.Features.PACKAGES; +import static edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt.Features.PACKAGES_AND_THEIR_CLASSES; + import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -15,6 +18,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import edu.kit.kastel.mcse.ardoco.core.api.models.ArchitectureModelType; import edu.kit.kastel.mcse.ardoco.core.api.models.CodeModelType; +import edu.kit.kastel.mcse.ardoco.core.api.models.Entity; import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates; import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.ArchitectureModel; import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.CodeModel; @@ -35,9 +39,10 @@ public class LLMArchitectureProviderInformant extends Informant { private final LLMArchitecturePrompt documentationPrompt; private final LLMArchitecturePrompt codePrompt; private final LLMArchitecturePrompt aggregationPrompt; + private final LLMArchitecturePrompt.Features codeFeature; public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLanguageModel largeLanguageModel, LLMArchitecturePrompt documentation, - LLMArchitecturePrompt code, LLMArchitecturePrompt aggregation) { + LLMArchitecturePrompt code, LLMArchitecturePrompt.Features codeFeature, LLMArchitecturePrompt aggregation) { super(LLMArchitectureProviderInformant.class.getSimpleName(), dataRepository); String apiKey = System.getenv("OPENAI_API_KEY"); String orgId = System.getenv("OPENAI_ORG_ID"); @@ -47,6 +52,7 @@ public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLang this.chatLanguageModel = largeLanguageModel.create(); this.documentationPrompt = documentation; this.codePrompt = code; + this.codeFeature = codeFeature; this.aggregationPrompt = aggregation; if (documentationPrompt == null && codePrompt == null) { throw new IllegalArgumentException("At least one prompt must be provided"); @@ -54,6 +60,9 @@ public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLang if (documentationPrompt != null && codePrompt != null && aggregationPrompt == null) { logger.info("Using Similarity Metrics to aggregate the component names"); } + if (codePrompt != null && codeFeature == null) { + throw new IllegalArgumentException("Code prompt requires a code feature"); + } } @Override @@ -121,8 +130,26 @@ private void codeToArchitecture(List componentNames) { return; } - var packages = codeModel.getAllPackages().stream().filter(it -> !it.getContent().isEmpty()).toList(); - parseComponentsFromAiRequests(componentNames, codePrompt.getTemplates(), String.join("\n", packages.stream().map(this::getPackageName).toList())); + switch (this.codeFeature) { + case PACKAGES -> { + var packages = codeModel.getAllPackages().stream().filter(it -> !it.getContent().isEmpty()).toList(); + parseComponentsFromAiRequests(componentNames, codePrompt.getTemplates(PACKAGES), String.join("\n", packages.stream() + .map(this::getPackageName) + .toList())); + } + case PACKAGES_AND_THEIR_CLASSES -> { + var packages = codeModel.getAllPackages().stream().filter(it -> !it.getContent().isEmpty()).toList(); + + var packagesWithClasses = packages.stream().map(p -> { + var packageName = getPackageName(p); + var classes = p.getContent().stream().flatMap(it -> it.getAllCompilationUnits().stream()).map(Entity::getName).distinct().sorted().toList(); + return packageName + " (" + String.join(", ", classes) + ")"; + }).toList(); + + parseComponentsFromAiRequests(componentNames, codePrompt.getTemplates(PACKAGES_AND_THEIR_CLASSES), String.join("\n", packagesWithClasses)); + } + } + } private void parseComponentsFromAiRequests(List componentNames, List templates, String dataForFirstPrompt) { diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java index 4b2aaed..a4d65be 100644 --- a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java @@ -36,14 +36,17 @@ class SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation extends TraceabilityLin private final LLMArchitecturePrompt documentationExtractionPrompt; private final LLMArchitecturePrompt codeExtractionPrompt; private final LLMArchitecturePrompt aggregationPrompt; + private final LLMArchitecturePrompt.Features codeFeatures; public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile, LargeLanguageModel largeLanguageModel, - LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) { + LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt.Features codeFeatures, + LLMArchitecturePrompt aggregationPrompt) { super(); this.acmFile = acmFile; this.largeLanguageModel = largeLanguageModel; this.documentationExtractionPrompt = documentationExtractionPrompt; this.codeExtractionPrompt = codeExtractionPrompt; + this.codeFeatures = codeFeatures; this.aggregationPrompt = aggregationPrompt; } @@ -82,7 +85,7 @@ protected ArDoCoRunner getAndSetupRunner(CodeProject codeProject) { var runner = new ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(name); runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt, - aggregationPrompt); + codeFeatures, aggregationPrompt); return runner; } diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java index 4861072..8d2749a 100644 --- a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java @@ -57,11 +57,14 @@ void evaluateSadCodeTlrIT(CodeProject project, LargeLanguageModel llm) { LLMArchitecturePrompt codePrompt = null; LLMArchitecturePrompt aggPrompt = null; + LLMArchitecturePrompt.Features codeFeatures = LLMArchitecturePrompt.Features.PACKAGES; + logger.info("###############################################"); logger.info("Evaluating project {} with LLM '{}'", project, llm); logger.info("Prompts: {}, {}, {}", docPrompt, codePrompt, aggPrompt); + logger.info("Features: {}", codeFeatures); - var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, docPrompt, codePrompt, aggPrompt); + var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, docPrompt, codePrompt, codeFeatures, aggPrompt); var result = evaluation.runTraceLinkEvaluation(project); if (result != null) { RESULTS.put(Tuples.pair(project, llm), result); @@ -91,7 +94,8 @@ static void printResults() { ArDoCoResult result = RESULTS.get(Tuples.pair(project, llm)); // Just some instance .. parameters do not matter .. - var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, LLMArchitecturePrompt.DOCUMENTATION_ONLY_V1, null, null); + var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, LLMArchitecturePrompt.DOCUMENTATION_ONLY_V1, null, null, + null); var goldStandard = project.getSadCodeGoldStandard(); goldStandard = TraceabilityLinkRecoveryEvaluation.enrollGoldStandardForCode(goldStandard, result); var evaluationResults = evaluation.calculateEvaluationResults(result, goldStandard);