Skip to content

Commit

Permalink
Provide more parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
dfuchss committed Sep 10, 2024
1 parent c326c3e commit efc79c2
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import edu.kit.kastel.mcse.ardoco.tlr.connectiongenerator.ConnectionGenerator;
import edu.kit.kastel.mcse.ardoco.tlr.models.agents.ArCoTLModelProviderAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.agents.LLMArchitectureProviderAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;
import edu.kit.kastel.mcse.ardoco.tlr.recommendationgenerator.RecommendationGenerator;
import edu.kit.kastel.mcse.ardoco.tlr.text.providers.TextPreprocessingAgent;
Expand All @@ -24,13 +25,15 @@ public ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(String projectName) {
super(projectName);
}

public void setUp(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, File outputDir, LargeLanguageModel largeLanguageModel) {
definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel);
/**
 * Prepares this runner for execution: defines the pipeline from the given inputs and prompts, sets the
 * output directory, and finally marks the runner as set up.
 *
 * @param inputText                     the input text file, forwarded to the pipeline definition
 * @param inputCode                     the input code file, forwarded to the pipeline definition
 * @param additionalConfigs             additional configuration entries, forwarded to the pipeline definition
 * @param outputDir                     the directory registered as output directory
 * @param largeLanguageModel            the LLM selection, forwarded to the pipeline definition
 * @param documentationExtractionPrompt the prompt for documentation-based extraction, forwarded as-is
 * @param codeExtractionPrompt          the prompt for code-based extraction, forwarded as-is
 * @param aggregationPrompt             the prompt for aggregation, forwarded as-is
 */
public void setUp(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, File outputDir,
        LargeLanguageModel largeLanguageModel, LLMArchitecturePrompt documentationExtractionPrompt,
        LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
    // The pipeline is defined before the output directory is registered; this order is kept deliberately.
    this.definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel, documentationExtractionPrompt,
            codeExtractionPrompt, aggregationPrompt);
    this.setOutputDirectory(outputDir);
    this.isSetUp = true;
}

private void definePipeline(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, LargeLanguageModel largeLanguageModel) {
private void definePipeline(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, LargeLanguageModel largeLanguageModel,
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
ArDoCo arDoCo = this.getArDoCo();
var dataRepository = arDoCo.getDataRepository();

Expand All @@ -48,7 +51,8 @@ private void definePipeline(File inputText, File inputCode, SortedMap<String, St
codeConfiguration);
arDoCo.addPipelineStep(arCoTLModelProviderAgent);

LLMArchitectureProviderAgent llmArchitectureProviderAgent = new LLMArchitectureProviderAgent(dataRepository, largeLanguageModel);
LLMArchitectureProviderAgent llmArchitectureProviderAgent = new LLMArchitectureProviderAgent(dataRepository, largeLanguageModel,
documentationExtractionPrompt, codeExtractionPrompt, aggregationPrompt);
arDoCo.addPipelineStep(llmArchitectureProviderAgent);

arDoCo.addPipelineStep(TextExtraction.get(additionalConfigs, dataRepository));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.PipelineAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitectureProviderInformant;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;

public class LLMArchitectureProviderAgent extends PipelineAgent {

public LLMArchitectureProviderAgent(DataRepository dataRepository, LargeLanguageModel largeLanguageModel) {
super(List.of(new LLMArchitectureProviderInformant(dataRepository, largeLanguageModel)), LLMArchitectureProviderAgent.class.getSimpleName(),
dataRepository);
/**
 * Creates the agent with exactly one informant, a {@link LLMArchitectureProviderInformant} built from the given
 * arguments. The agent is registered under its simple class name.
 *
 * @param dataRepository                the data repository passed to both the informant and the parent agent
 * @param largeLanguageModel            the LLM selection handed to the informant
 * @param documentationExtractionPrompt the documentation extraction prompt handed to the informant
 * @param codeExtractionPrompt          the code extraction prompt handed to the informant
 * @param aggregationPrompt             the aggregation prompt handed to the informant
 */
public LLMArchitectureProviderAgent(DataRepository dataRepository, LargeLanguageModel largeLanguageModel,
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
super(List.of(new LLMArchitectureProviderInformant(dataRepository, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt,
aggregationPrompt)), LLMArchitectureProviderAgent.class.getSimpleName(), dataRepository);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Licensed under MIT 2024. */
package edu.kit.kastel.mcse.ardoco.tlr.models.informants;

import java.util.List;

/**
 * Catalogue of prompt template sequences. Each constant holds one or more templates in the order in which
 * they were declared; templates containing {@code %s} are meant to be filled via {@link String#formatted}.
 */
public enum LLMArchitecturePrompt {
    /** Two templates: one embedding the documentation via {@code %s}, one follow-up asking for the name list. */
    DOCUMENTATION_ONLY_V1(
            """
            Your task is to identify the high-level components based on a software architecture. In a first step, you shall elaborate on the following documentation:
            %s
            """,
            "Now provide a list that only covers the component names. Omit common prefixes and suffixes in the names."), //

    /** Two templates: one embedding the package names via {@code %s}, one with a {@code %s} slot for a summarization. */
    CODE_ONLY_V1(
            """
            You get the package names of a software project. Your task is to summarize the packages w.r.t. the architecture of the system. Try to identify possible components.
            Packages:
            %s
            """,
            """
            You get a summarization and a suggestion of the components of a software project.
            Identify the possible component names and list them. Only list the component names. If you don't know what the component is about, omit it.
            Summarization:
            %s
            """), //

    /** Single template embedding a list of candidate component names via {@code %s}. */
    AGGREGATION_V1("""
            You get a list of possible component names. Your task is to aggregate the list and remove duplicates.
            Also filter out component names, that are very generic. Do not repeat what you filtered out.
            Possible component names:
            %s
            """);

    /** Unmodifiable list of templates in declaration order. */
    private final List<String> orderedTemplates;

    LLMArchitecturePrompt(String... promptTemplates) {
        this.orderedTemplates = List.of(promptTemplates);
    }

    /**
     * Returns the templates of this prompt.
     *
     * @return an unmodifiable list of the templates, in declaration order
     */
    public List<String> getTemplates() {
        return this.orderedTemplates;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import edu.kit.kastel.mcse.ardoco.core.api.models.ArchitectureModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.CodeModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.ArchitectureModel;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.CodeModel;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureComponent;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureItem;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.code.CodePackage;
import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper;
import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.Informant;
Expand All @@ -24,46 +27,82 @@ public class LLMArchitectureProviderInformant extends Informant {
private static final String MODEL_STATES_DATA = "ModelStatesData";

// Chat model used for the LLM requests of this informant; created in the constructor via largeLanguageModel.create().
private final ChatLanguageModel chatLanguageModel;
// Prompt for the documentation-based extraction; may be null, in which case process() skips that step.
private final LLMArchitecturePrompt documentationPrompt;
// Prompt for the code-based extraction; may be null, in which case process() skips that step.
private final LLMArchitecturePrompt codePrompt;
// Prompt for aggregating the collected names; may be null unless both other prompts are set (enforced in the constructor).
private final LLMArchitecturePrompt aggregationPrompt;

private static final List<String> TEMPLATES_DOC_TO_ARCHITECTURE = List.of(
"""
Your task is to identify the high-level components based on a software architecture. In a first step, you shall elaborate on the following documentation:
%s
""",
"Now provide a list that only covers the component names. Omit common prefixes and suffixes in the names.");

public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLanguageModel largeLanguageModel) {
/**
 * Creates the informant, instantiates the chat model, and validates the prompt combination.
 * <p>
 * The environment variables {@code OPENAI_API_KEY} and {@code OPENAI_ORG_ID} are required unconditionally,
 * i.e. independent of which {@link LargeLanguageModel} is passed in.
 *
 * @param dataRepository     the data repository this informant works on
 * @param largeLanguageModel the LLM selection used to create the chat model
 * @param documentation      prompt for documentation-based extraction, or {@code null} to skip it
 * @param code               prompt for code-based extraction, or {@code null} to skip it
 * @param aggregation        prompt for aggregating names; required when both other prompts are given
 * @throws IllegalArgumentException if one of the environment variables is missing, if neither an extraction
 *                                  prompt is given, or if both extraction prompts are given without an
 *                                  aggregation prompt
 */
public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLanguageModel largeLanguageModel,
        LLMArchitecturePrompt documentation, LLMArchitecturePrompt code, LLMArchitecturePrompt aggregation) {
    super(LLMArchitectureProviderInformant.class.getSimpleName(), dataRepository);

    final String apiKey = System.getenv("OPENAI_API_KEY");
    final String orgId = System.getenv("OPENAI_ORG_ID");
    final boolean credentialsPresent = apiKey != null && orgId != null;
    if (!credentialsPresent) {
        throw new IllegalArgumentException("OPENAI_API_KEY and OPENAI_ORG_ID must be set as environment variables");
    }

    // The chat model is created before the prompt combination is validated.
    this.chatLanguageModel = largeLanguageModel.create();
    this.documentationPrompt = documentation;
    this.codePrompt = code;
    this.aggregationPrompt = aggregation;

    final boolean hasDocumentation = documentation != null;
    final boolean hasCode = code != null;
    if (!hasDocumentation && !hasCode) {
        throw new IllegalArgumentException("At least one prompt must be provided");
    }
    if (hasDocumentation && hasCode && aggregation == null) {
        throw new IllegalArgumentException("Aggregation prompt must be provided when both documentation and code prompts are provided");
    }
}

/**
 * Collects candidate component names from the documentation and/or the code model, optionally lets the chat
 * model aggregate them, normalizes the names, logs them, and passes them to {@code buildModel}.
 * <p>
 * NOTE(review): this span is taken from a rendered diff without +/- markers. The unconditional
 * {@code documentationToArchitecture(componentNames);} call and the single-line {@code componentNames = ...}
 * statement appear to be the pre-change versions of the statements directly following them - confirm
 * against the repository before treating every line as live code.
 */
@Override
protected void process() {
List<String> componentNames = new ArrayList<>();
documentationToArchitecture(componentNames);
// Each extraction step only runs when its prompt was configured.
if (documentationPrompt != null)
documentationToArchitecture(componentNames);
if (codePrompt != null)
codeToArchitecture(componentNames);

// Aggregation uses only the first template of the aggregation prompt, filled with the newline-joined names;
// the collected names are then replaced by whatever is parsed from the model's answer.
if (aggregationPrompt != null) {
var aggregation = chatLanguageModel.generate(aggregationPrompt.getTemplates().getFirst().formatted(String.join("\n", componentNames)));
componentNames = new ArrayList<>();
parseComponentNames(aggregation, componentNames);
}

// Remove unwanted characters from the names, then drop blanks, de-duplicate, and sort.
// In the single-line variant, " -_" inside the character class is a range (space through underscore), not three literals.
componentNames = componentNames.stream().map(it -> it.replaceAll("[^a-zA-Z -_]", "").trim()).filter(it -> !it.isBlank()).distinct().sorted().toList();
// The multi-line variant escapes the hyphen, so only ASCII letters, space, '-' and '_' are kept;
// it additionally collapses runs of whitespace into a single space before trimming.
componentNames = componentNames.stream()
.map(it -> it.replaceAll("[^a-zA-Z \\-_]", "").replaceAll("\\s+", " ").trim())
.filter(it -> !it.isBlank())
.distinct()
.sorted()
.toList();
logger.info("Component names:\n{}", String.join("\n", componentNames));
buildModel(componentNames);
}

/**
 * Extracts component names from the input text: reads the input text from the data repository and runs the
 * templates of the documentation prompt against it; parsed names are added to {@code componentNames}.
 * <p>
 * NOTE(review): the {@code startMessage} line references {@code TEMPLATES_DOC_TO_ARCHITECTURE} and its value
 * is not used in this method; in this rendered diff it looks like a pre-change line - confirm.
 *
 * @param componentNames the list that receives the parsed component names
 */
private void documentationToArchitecture(List<String> componentNames) {
var inputText = DataRepositoryHelper.getInputText(dataRepository);
String startMessage = TEMPLATES_DOC_TO_ARCHITECTURE.getFirst().formatted(inputText);
parseComponentsFromAiRequests(componentNames, documentationPrompt.getTemplates(), inputText);
}

/**
 * Extracts component names from the code model: looks up the code model in the model states, keeps only
 * packages whose content holds more than one element, and runs the templates of the code prompt against
 * the newline-separated, dot-qualified package names. Logs a warning and does nothing if no code model exists.
 *
 * @param componentNames the list that receives the parsed component names
 */
private void codeToArchitecture(List<String> componentNames) {
    var modelStates = DataRepositoryHelper.getModelStatesData(dataRepository);
    var codeModel = (CodeModel) modelStates.getModel(CodeModelType.CODE_MODEL.getModelId());
    if (codeModel == null) {
        logger.warn("Code model not found");
        return;
    }

    List<String> qualifiedPackageNames = codeModel.getAllPackages()
            .stream()
            .filter(codePackage -> codePackage.getContent().size() > 1)
            .map(this::getPackageName)
            .toList();
    String packageListing = String.join("\n", qualifiedPackageNames);
    parseComponentsFromAiRequests(componentNames, codePrompt.getTemplates(), packageListing);
}

private void parseComponentsFromAiRequests(List<String> componentNames, List<String> templates, String dataForFirstPrompt) {
String startMessage = templates.getFirst().formatted(dataForFirstPrompt);
List<ChatMessage> messages = new ArrayList<>();
messages.add(UserMessage.from(startMessage));

var initialResponse = chatLanguageModel.generate(messages).content();
messages.add(initialResponse);
logger.info("Initial Response: {}", initialResponse.text());

for (String nextMessage : TEMPLATES_DOC_TO_ARCHITECTURE.stream().skip(1).toList()) {
for (String nextMessage : templates.stream().skip(1).toList()) {
messages.add(UserMessage.from(nextMessage));
var response = chatLanguageModel.generate(messages).content();
logger.info("Response: {}", response.text());
Expand All @@ -73,6 +112,18 @@ private void documentationToArchitecture(List<String> componentNames) {
parseComponentNames(((AiMessage) messages.getLast()).text(), componentNames);
}

/**
 * Builds the dot-separated name of a package by walking up its parent chain; the outermost ancestor comes
 * first and the given package's own name comes last.
 *
 * @param codePackage the package to name
 * @return the names of all ancestors and of the package itself, joined with {@code "."}
 */
private String getPackageName(CodePackage codePackage) {
    List<String> segments = new ArrayList<>();
    segments.add(codePackage.getName());
    // Prepending each ancestor while climbing yields root-first order without a final reversal.
    for (var ancestor = codePackage.getParent(); ancestor != null; ancestor = ancestor.getParent()) {
        segments.addFirst(ancestor.getName());
    }
    return String.join(".", segments);
}

private void parseComponentNames(String response, List<String> componentNames) {
for (String line : response.split("\n")) {
if (line.isBlank()) {
Expand All @@ -94,6 +145,10 @@ else if (line.matches("^([-*])\\s*\\*\\*.*\\*\\*$")) {
// Version 4: - Name
else if (line.matches("^([-*])\\s*.*$")) {
componentNames.add(line.split("([-*])\\s*")[1]);
}
// Version 5: 1. Name (NotImportant) or 2. Name (SomeString)
else if (line.matches("^\\d+\\.\\s*.*\\s*\\(.*\\)$")) {
componentNames.add(line.split("\\.\\s*")[1].split("\\s*\\(.*\\)")[0]);
} else {
logger.warn("Could not parse line: {}", line);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ public enum LargeLanguageModel {
CODELLAMA_13B(() -> createOllamaModel("codellama:13b")), //
CODELLAMA_70B(() -> createOllamaModel("codellama:70b")), //
//
// DEEPSEEK_CODER_V2_16B(() -> createOllamaModel("deepseek-coder-v2:16b")), //
//
GEMMA_2_27B(() -> createOllamaModel("gemma2:27b")), //
//
QWEN_2_72B(() -> createOllamaModel("qwen2:72b")), //
//
LLAMA_3_1_8B(() -> createOllamaModel("llama3.1:8b-instruct-fp16")), //
LLAMA_3_1_70B(() -> createOllamaModel("llama3.1:70b")), //
//
Expand All @@ -45,6 +49,14 @@ public ChatLanguageModel create() {
return creator.get();
}

/**
 * Tells whether this constant is a generic model, decided purely by its name.
 *
 * @return {@code true} if the constant's name ends with {@code _GENERIC}
 */
public boolean isGeneric() {
    final String constantName = name();
    return constantName.endsWith("_GENERIC");
}

/**
 * Tells whether this constant is an OpenAI model, decided purely by its name.
 *
 * @return {@code true} if the constant's name starts with {@code GPT_}
 */
public boolean isOpenAi() {
    final String constantName = name();
    return constantName.startsWith("GPT_");
}

private static final int SEED = 422413373;

private static ChatLanguageModel createOpenAiModel(String model) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,32 @@
import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.execution.runner.ArDoCoRunner;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.EvaluationResults;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.ExpectedResults;
import edu.kit.kastel.mcse.ardoco.tlr.execution.ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;

class SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation extends TraceabilityLinkRecoveryEvaluation<CodeProject> {
// Flag supplied via the constructor; its use is not visible in this excerpt.
private final boolean acmFile;
// LLM selection handed to the runner's setUp call.
private final LargeLanguageModel largeLanguageModel;
// Prompts handed unchanged to the runner's setUp call.
private final LLMArchitecturePrompt documentationExtractionPrompt;
private final LLMArchitecturePrompt codeExtractionPrompt;
private final LLMArchitecturePrompt aggregationPrompt;

public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile, LargeLanguageModel largeLanguageModel) {
/**
 * Creates the evaluation and stores all arguments in the corresponding fields without validation.
 *
 * @param acmFile                       flag stored as-is (its effect is not visible in this excerpt)
 * @param largeLanguageModel            the LLM selection to evaluate with
 * @param documentationExtractionPrompt the documentation extraction prompt to evaluate with
 * @param codeExtractionPrompt          the code extraction prompt to evaluate with
 * @param aggregationPrompt             the aggregation prompt to evaluate with
 */
public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile, LargeLanguageModel largeLanguageModel,
LLMArchitecturePrompt documentationExtractionPrompt, LLMArchitecturePrompt codeExtractionPrompt, LLMArchitecturePrompt aggregationPrompt) {
super();
this.acmFile = acmFile;
this.largeLanguageModel = largeLanguageModel;
this.documentationExtractionPrompt = documentationExtractionPrompt;
this.codeExtractionPrompt = codeExtractionPrompt;
this.aggregationPrompt = aggregationPrompt;
}

/**
 * Intentionally empty override: no comparison between actual and expected results is performed here, so
 * that (per the inline note) all results can be inspected.
 *
 * @param results         the evaluation results (ignored)
 * @param expectedResults the expected results (ignored)
 */
@Override
protected void compareResults(EvaluationResults<String> results, ExpectedResults expectedResults) {
// Disable Asserts. We want to see all results.
}

@Override
Expand All @@ -49,7 +63,8 @@ protected ArDoCoRunner getAndSetupRunner(CodeProject codeProject) {
File outputDir = new File(OUTPUT);

var runner = new ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(name);
runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir, largeLanguageModel);
runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir, largeLanguageModel, documentationExtractionPrompt, codeExtractionPrompt,
aggregationPrompt);
return runner;
}

Expand Down
Loading

0 comments on commit efc79c2

Please sign in to comment.