Skip to content

Commit

Permalink
Add code features
Browse files Browse the repository at this point in the history
  • Loading branch information
dfuchss committed Sep 30, 2024
1 parent 01c611c commit d66383a
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ public ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(String projectName) {
}

public void setUp(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, File outputDir, LargeLanguageModel largeLanguageModel,
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt, aggregationPrompt);
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt.Features codeFeatures,
LLMArchitecturePrompt aggregationPrompt) {
definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt, codeFeatures,
aggregationPrompt);
setOutputDirectory(outputDir);
isSetUp = true;
}

private void definePipeline(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, LargeLanguageModel largeLanguageModel,
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt.Features codeFeatures,
LLMArchitecturePrompt aggregationPrompt) {
ArDoCo arDoCo = this.getArDoCo();
var dataRepository = arDoCo.getDataRepository();

Expand All @@ -52,7 +55,7 @@ private void definePipeline(File inputText, File inputCode, SortedMap<String, St
arDoCo.addPipelineStep(arCoTLModelProviderAgent);

LLMArchitectureProviderAgent llmArchitectureProviderAgent = new LLMArchitectureProviderAgent(dataRepository, largeLanguageModel,
documentationExtractionPrompt, codeExtractionPrompt, aggregationPrompt);
documentationExtractionPrompt, codeExtractionPrompt, codeFeatures, aggregationPrompt);
arDoCo.addPipelineStep(llmArchitectureProviderAgent);

arDoCo.addPipelineStep(TextExtraction.get(additionalConfigs, dataRepository));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
public class LLMArchitectureProviderAgent extends PipelineAgent {

public LLMArchitectureProviderAgent(DataRepository dataRepository, LargeLanguageModel largeLanguageModel,
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt.Features codeFeatures,
LLMArchitecturePrompt aggregationPrompt) {
super(List.of(new LLMArchitectureProviderInformant(dataRepository, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt,
aggregationPrompt)), LLMArchitectureProviderAgent.class.getSimpleName(), dataRepository);
codeFeatures, aggregationPrompt)), LLMArchitectureProviderAgent.class.getSimpleName(), dataRepository);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ public enum LLMArchitecturePrompt {
"""), //
CODE_ONLY_V1(
"""
You get the Packages of a software project. Your task is to summarize the Packages w.r.t. the high-level architecture of the system. Try to identify possible components.
You get the {FEATURES} of a software project. Your task is to summarize the {FEATURES} w.r.t. the high-level architecture of the system. Try to identify possible components.
Packages:
{FEATURES}:
%s
""",
Expand Down Expand Up @@ -49,6 +49,21 @@ public enum LLMArchitecturePrompt {
}

public List<String> getTemplates() {
if (this == CODE_ONLY_V1)
throw new IllegalArgumentException("This method is not supported for this enum value");
return templates;
}

public List<String> getTemplates(Features features) {
return templates.stream().map(it -> it.replace("{FEATURES}", features.toString())).toList();
}

public enum Features {
PACKAGES, PACKAGES_AND_THEIR_CLASSES;

@Override
public String toString() {
return super.toString().charAt(0) + super.toString().toLowerCase().substring(1).replace("_", " ");
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
/* Licensed under MIT 2024. */
package edu.kit.kastel.mcse.ardoco.tlr.models.informants;

import static edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt.Features.PACKAGES;
import static edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt.Features.PACKAGES_AND_THEIR_CLASSES;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -15,6 +18,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import edu.kit.kastel.mcse.ardoco.core.api.models.ArchitectureModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.CodeModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.Entity;
import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.ArchitectureModel;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.CodeModel;
Expand All @@ -35,9 +39,10 @@ public class LLMArchitectureProviderInformant extends Informant {
private final LLMArchitecturePrompt documentationPrompt;
private final LLMArchitecturePrompt codePrompt;
private final LLMArchitecturePrompt aggregationPrompt;
private final LLMArchitecturePrompt.Features codeFeature;

public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLanguageModel largeLanguageModel, LLMArchitecturePrompt documentation,
LLMArchitecturePrompt code, LLMArchitecturePrompt aggregation) {
LLMArchitecturePrompt code, LLMArchitecturePrompt.Features codeFeature, LLMArchitecturePrompt aggregation) {
super(LLMArchitectureProviderInformant.class.getSimpleName(), dataRepository);
String apiKey = System.getenv("OPENAI_API_KEY");
String orgId = System.getenv("OPENAI_ORG_ID");
Expand All @@ -47,13 +52,17 @@ public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLang
this.chatLanguageModel = largeLanguageModel.create();
this.documentationPrompt = documentation;
this.codePrompt = code;
this.codeFeature = codeFeature;
this.aggregationPrompt = aggregation;
if (documentationPrompt == null && codePrompt == null) {
throw new IllegalArgumentException("At least one prompt must be provided");
}
if (documentationPrompt != null && codePrompt != null && aggregationPrompt == null) {
logger.info("Using Similarity Metrics to aggregate the component names");
}
if (codePrompt != null && codeFeature == null) {
throw new IllegalArgumentException("Code prompt requires a code feature");
}
}

@Override
Expand Down Expand Up @@ -121,8 +130,26 @@ private void codeToArchitecture(List<String> componentNames) {
return;
}

var packages = codeModel.getAllPackages().stream().filter(it -> !it.getContent().isEmpty()).toList();
parseComponentsFromAiRequests(componentNames, codePrompt.getTemplates(), String.join("\n", packages.stream().map(this::getPackageName).toList()));
switch (this.codeFeature) {
case PACKAGES -> {
var packages = codeModel.getAllPackages().stream().filter(it -> !it.getContent().isEmpty()).toList();
parseComponentsFromAiRequests(componentNames, codePrompt.getTemplates(PACKAGES), String.join("\n", packages.stream()
.map(this::getPackageName)
.toList()));
}
case PACKAGES_AND_THEIR_CLASSES -> {
var packages = codeModel.getAllPackages().stream().filter(it -> !it.getContent().isEmpty()).toList();

var packagesWithClasses = packages.stream().map(p -> {
var packageName = getPackageName(p);
var classes = p.getContent().stream().flatMap(it -> it.getAllCompilationUnits().stream()).map(Entity::getName).distinct().sorted().toList();
return packageName + " (" + String.join(", ", classes) + ")";
}).toList();

parseComponentsFromAiRequests(componentNames, codePrompt.getTemplates(PACKAGES_AND_THEIR_CLASSES), String.join("\n", packagesWithClasses));
}
}

}

private void parseComponentsFromAiRequests(List<String> componentNames, List<String> templates, String dataForFirstPrompt) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,17 @@ class SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation extends TraceabilityLin
private final LLMArchitecturePrompt documentationExtractionPrompt;
private final LLMArchitecturePrompt codeExtractionPrompt;
private final LLMArchitecturePrompt aggregationPrompt;
private final LLMArchitecturePrompt.Features codeFeatures;

public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile, LargeLanguageModel largeLanguageModel,
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt.Features codeFeatures,
LLMArchitecturePrompt aggregationPrompt) {
super();
this.acmFile = acmFile;
this.largeLanguageModel = largeLanguageModel;
this.documentationExtractionPrompt = documentationExtractionPrompt;
this.codeExtractionPrompt = codeExtractionPrompt;
this.codeFeatures = codeFeatures;
this.aggregationPrompt = aggregationPrompt;
}

Expand Down Expand Up @@ -82,7 +85,7 @@ protected ArDoCoRunner getAndSetupRunner(CodeProject codeProject) {

var runner = new ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(name);
runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt,
aggregationPrompt);
codeFeatures, aggregationPrompt);
return runner;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,14 @@ void evaluateSadCodeTlrIT(CodeProject project, LargeLanguageModel llm) {
LLMArchitecturePrompt codePrompt = null;
LLMArchitecturePrompt aggPrompt = null;

LLMArchitecturePrompt.Features codeFeatures = LLMArchitecturePrompt.Features.PACKAGES;

logger.info("###############################################");
logger.info("Evaluating project {} with LLM '{}'", project, llm);
logger.info("Prompts: {}, {}, {}", docPrompt, codePrompt, aggPrompt);
logger.info("Features: {}", codeFeatures);

var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, docPrompt, codePrompt, aggPrompt);
var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, docPrompt, codePrompt, codeFeatures, aggPrompt);
var result = evaluation.runTraceLinkEvaluation(project);
if (result != null) {
RESULTS.put(Tuples.pair(project, llm), result);
Expand Down Expand Up @@ -91,7 +94,8 @@ static void printResults() {
ArDoCoResult result = RESULTS.get(Tuples.pair(project, llm));

// Just some instance .. parameters do not matter ..
var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, LLMArchitecturePrompt.DOCUMENTATION_ONLY_V1, null, null);
var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, LLMArchitecturePrompt.DOCUMENTATION_ONLY_V1, null, null,
null);
var goldStandard = project.getSadCodeGoldStandard();
goldStandard = TraceabilityLinkRecoveryEvaluation.enrollGoldStandardForCode(goldStandard, result);
var evaluationResults = evaluation.calculateEvaluationResults(result, goldStandard);
Expand Down

0 comments on commit d66383a

Please sign in to comment.