Skip to content

Commit

Permalink
Initial version using LLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
dfuchss committed Sep 5, 2024
1 parent dcbb6bf commit 9ecf51d
Show file tree
Hide file tree
Showing 6 changed files with 426 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Licensed under MIT 2023-2024. */
package edu.kit.kastel.mcse.ardoco.tlr.execution;

import java.io.File;
import java.util.SortedMap;

import edu.kit.kastel.mcse.ardoco.core.common.util.CommonUtilities;
import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper;
import edu.kit.kastel.mcse.ardoco.core.execution.ArDoCo;
import edu.kit.kastel.mcse.ardoco.core.execution.runner.ArDoCoRunner;
import edu.kit.kastel.mcse.ardoco.tlr.codetraceability.SadSamCodeTraceabilityLinkRecovery;
import edu.kit.kastel.mcse.ardoco.tlr.codetraceability.SamCodeTraceabilityLinkRecovery;
import edu.kit.kastel.mcse.ardoco.tlr.connectiongenerator.ConnectionGenerator;
import edu.kit.kastel.mcse.ardoco.tlr.models.agents.ArCoTLModelProviderAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.agents.LLMArchitectureProviderAgent;
import edu.kit.kastel.mcse.ardoco.tlr.recommendationgenerator.RecommendationGenerator;
import edu.kit.kastel.mcse.ardoco.tlr.text.providers.TextPreprocessingAgent;
import edu.kit.kastel.mcse.ardoco.tlr.textextraction.TextExtraction;

public class ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery extends ArDoCoRunner {

public ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(String projectName) {
super(projectName);
}

public void setUp(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, File outputDir) {
definePipeline(inputText, inputCode, additionalConfigs);
setOutputDirectory(outputDir);
isSetUp = true;
}

private void definePipeline(File inputText, File inputCode, SortedMap<String, String> additionalConfigs) {
ArDoCo arDoCo = this.getArDoCo();
var dataRepository = arDoCo.getDataRepository();

var text = CommonUtilities.readInputText(inputText);
if (text.isBlank()) {
throw new IllegalArgumentException("Cannot deal with empty input text. Maybe there was an error reading the file.");
}
DataRepositoryHelper.putInputText(dataRepository, text);

arDoCo.addPipelineStep(TextPreprocessingAgent.get(additionalConfigs, dataRepository));

var codeConfiguration = ArCoTLModelProviderAgent.getCodeConfiguration(inputCode);

ArCoTLModelProviderAgent arCoTLModelProviderAgent = ArCoTLModelProviderAgent.getArCoTLModelProviderAgent(dataRepository, additionalConfigs, null,
codeConfiguration);
arDoCo.addPipelineStep(arCoTLModelProviderAgent);

LLMArchitectureProviderAgent llmArchitectureProviderAgent = new LLMArchitectureProviderAgent(dataRepository);
arDoCo.addPipelineStep(llmArchitectureProviderAgent);

arDoCo.addPipelineStep(TextExtraction.get(additionalConfigs, dataRepository));
arDoCo.addPipelineStep(RecommendationGenerator.get(additionalConfigs, dataRepository));
arDoCo.addPipelineStep(ConnectionGenerator.get(additionalConfigs, dataRepository));

arDoCo.addPipelineStep(SamCodeTraceabilityLinkRecovery.get(additionalConfigs, dataRepository));

arDoCo.addPipelineStep(SadSamCodeTraceabilityLinkRecovery.get(additionalConfigs, dataRepository));
}
}
19 changes: 19 additions & 0 deletions stages-tlr/model-provider/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
</parent>
<artifactId>model-provider</artifactId>

<properties>
<langchain4j.version>0.33.0</langchain4j.version>
</properties>

<dependencies>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
Expand All @@ -33,6 +37,21 @@
<artifactId>commons-io</artifactId>
<version>2.15.1</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>io.github.ardoco.core</groupId>
<artifactId>common</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/* Licensed under MIT 2024. */
package edu.kit.kastel.mcse.ardoco.tlr.models.agents;

import java.util.List;

import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.PipelineAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitectureProviderInformant;

public class LLMArchitectureProviderAgent extends PipelineAgent {

public LLMArchitectureProviderAgent(DataRepository dataRepository) {
super(List.of(new LLMArchitectureProviderInformant(dataRepository)), LLMArchitectureProviderAgent.class.getSimpleName(), dataRepository);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/* Licensed under MIT 2024. */
package edu.kit.kastel.mcse.ardoco.tlr.models.informants;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.TreeSet;
import java.util.stream.Collectors;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import edu.kit.kastel.mcse.ardoco.core.api.models.ArchitectureModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.CodeModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.ArchitectureModel;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.CodeModel;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureComponent;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureItem;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.code.CodePackage;
import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper;
import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.Informant;

public class LLMArchitectureProviderInformant extends Informant {
private static final String MODEL_STATES_DATA = "ModelStatesData";

private final ChatLanguageModel chatLanguageModel;

private static final List<String> TEMPLATES_DOC_TO_ARCHITECTURE = List.of(
"""
Your task is to identify the high-level components based on a software architecture. In a first step, you shall elaborate on the following documentation:
%s
""",
"Now provide a list that only covers the component names. Omit common prefixes and suffixes in the names.");

private static final List<String> TEMPLATES_CODE_TO_ARCHITECTURE = List.of(
"""
You get the package names of a software project. Your task is to summarize the packages w.r.t. the architecture of the system. Try to identify possible components.
Packages:
%s
""",
"""
You get a summarization and a suggestion of the components of a software project.
Identify the possible component names and list them. Only list the component names. If you don't know what the component is about, omit it.
Summarization:
%s
""");

private static final String AGGREGATION_PROMPT = """
You get a list of possible component names. Your task is to aggregate the list and remove duplicates.
Also filter out component names, that are very generic. Do not repeat what you filtered out.
Possible component names:
%s
""";

public LLMArchitectureProviderInformant(DataRepository dataRepository) {
super(LLMArchitectureProviderInformant.class.getSimpleName(), dataRepository);
String apiKey = System.getenv("OPENAI_API_KEY");
String orgId = System.getenv("OPENAI_ORG_ID");
if (apiKey == null || orgId == null) {
throw new IllegalArgumentException("OPENAI_API_KEY and OPENAI_ORG_ID must be set as environment variables");
}
this.chatLanguageModel = new OpenAiChatModel.OpenAiChatModelBuilder().modelName(OpenAiChatModelName.GPT_4_O)
.apiKey(apiKey)
.organizationId(orgId)
.seed(422413373)
.temperature(0.0)
.build();
}

@Override
protected void process() {
List<String> componentNames = new ArrayList<>();
documentationToArchitecture(componentNames);
// codeToArchitecture(componentNames);
// Remove any not letter characters
componentNames = componentNames.stream().map(it -> it.replaceAll("[^a-zA-Z -_]", "").trim()).filter(it -> !it.isBlank()).distinct().sorted().toList();

/*
var aggregation = chatLanguageModel.generate(AGGREGATION_PROMPT.formatted(String.join("\n", componentNames)));
componentNames = new ArrayList<>();
parseComponentNames(aggregation, componentNames);
*/
logger.info("Component names:\n{}", String.join("\n", componentNames));

buildModel(componentNames);

}

private void codeToArchitecture(List<String> componentNames) {
var models = DataRepositoryHelper.getModelStatesData(dataRepository);
CodeModel codeModel = (CodeModel) models.getModel(CodeModelType.CODE_MODEL.getModelId());
if (codeModel == null) {
logger.warn("Code model not found");
return;
}

var packages = codeModel.getAllPackages().stream().filter(it -> it.getContent().size() > 1).toList();

List<String> responses = new ArrayList<>();
responses.add(String.join("\n", packages.stream().map(it -> getPackageName(it)).toList()));
for (String template : TEMPLATES_CODE_TO_ARCHITECTURE) {
var filledTemplate = template.formatted(responses.getLast());
var response = chatLanguageModel.generate(filledTemplate);
logger.info("Response: {}", response);
responses.add(response);
}
parseComponentNames(responses.getLast(), componentNames);

}

private String getPackageName(CodePackage codePackage) {
List<String> packageName = new ArrayList<>();
packageName.add(codePackage.getName());
var parent = codePackage.getParent();
while (parent != null) {
packageName.add(parent.getName());
parent = parent.getParent();
}
packageName = packageName.reversed();
return String.join(".", packageName);
}

private void documentationToArchitecture(List<String> componentNames) {
var inputText = DataRepositoryHelper.getInputText(dataRepository);
String startMessage = TEMPLATES_DOC_TO_ARCHITECTURE.getFirst().formatted(inputText);
List<ChatMessage> messages = new ArrayList<>();
messages.add(UserMessage.from(startMessage));

var initialResponse = chatLanguageModel.generate(messages).content();
messages.add(initialResponse);
logger.info("Initial Response: {}", initialResponse.text());

for (String nextMessage : TEMPLATES_DOC_TO_ARCHITECTURE.stream().skip(1).toList()) {
messages.add(UserMessage.from(nextMessage));
var response = chatLanguageModel.generate(messages).content();
logger.info("Response: {}", response.text());
messages.add(response);
}

parseComponentNames(((AiMessage) messages.getLast()).text(), componentNames);
}

private void parseComponentNames(String response, List<String> componentNames) {
for (String line : response.split("\n")) {
if (line.isBlank()) {
continue;
}
line = line.trim();
// Version 1: 1. **Name** or 2. **Name**
if (line.matches("^\\d+\\.\\s*\\*\\*.*\\*\\*$")) {
componentNames.add(line.split("\\*\\*")[1]);
}
// Version 2: 1. Name or 2. Name
else if (line.matches("^\\d+\\.\\s*.*$")) {
componentNames.add(line.split("\\.\\s*")[1]);
}
// Version 3: - **Name**
else if (line.matches("^-\\s*\\*\\*.*\\*\\*$")) {
componentNames.add(line.split("\\*\\*")[1]);
}
// Version 4: - Name
else if (line.matches("^-\\s*.*$")) {
componentNames.add(line.split("-\\s*")[1]);
} else {
logger.warn("Could not parse line: {}", line);
}
}
}

private void buildModel(List<String> componentNames) {
List<ArchitectureItem> componentList = componentNames.stream()
.map(it -> new ArchitectureComponent(it.replace("Component", "").trim(), it, new TreeSet<>(), new TreeSet<>(), new TreeSet<>(), "Component"))
.collect(Collectors.toList());
ArchitectureModel am = new ArchitectureModel(componentList);
Optional<ModelStates> modelStatesOptional = dataRepository.getData(MODEL_STATES_DATA, ModelStates.class);
var modelStates = modelStatesOptional.orElseGet(ModelStates::new);

modelStates.addModel(ArchitectureModelType.PCM.getModelId(), am);

if (modelStatesOptional.isEmpty()) {
dataRepository.addData(MODEL_STATES_DATA, modelStates);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Licensed under MIT 2023-2024. */
package edu.kit.kastel.mcse.ardoco.tlr.tests.integration;

import static edu.kit.kastel.mcse.ardoco.tlr.tests.integration.TraceLinkEvaluationIT.OUTPUT;

import java.io.File;
import java.util.SortedMap;
import java.util.TreeMap;

import org.eclipse.collections.api.factory.Lists;
import org.eclipse.collections.api.list.ImmutableList;

import edu.kit.kastel.mcse.ardoco.core.api.models.CodeModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.Model;
import edu.kit.kastel.mcse.ardoco.core.api.output.ArDoCoResult;
import edu.kit.kastel.mcse.ardoco.core.api.text.Text;
import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper;
import edu.kit.kastel.mcse.ardoco.core.common.util.TraceLinkUtilities;
import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.execution.runner.ArDoCoRunner;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.ExpectedResults;
import edu.kit.kastel.mcse.ardoco.tlr.execution.ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery;

class SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation extends TraceabilityLinkRecoveryEvaluation<CodeProject> {
private final boolean acmFile;

public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile) {
super();
this.acmFile = acmFile;
}

@Override
protected boolean resultHasRequiredData(ArDoCoResult arDoCoResult) {
var traceLinks = arDoCoResult.getSadCodeTraceLinks();
return !traceLinks.isEmpty();
}

@Override
protected ArDoCoRunner getAndSetupRunner(CodeProject codeProject) {
String name = codeProject.name().toLowerCase();
File textInput = codeProject.getTextFile();
File inputCode = getInputCode(codeProject, acmFile);
SortedMap<String, String> additionalConfigsMap = new TreeMap<>();
File outputDir = new File(OUTPUT);

var runner = new ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(name);
runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir);
return runner;
}

@Override
protected ImmutableList<String> createTraceLinkStringList(ArDoCoResult arDoCoResult) {
var traceLinks = arDoCoResult.getSadCodeTraceLinks();

return TraceLinkUtilities.getSadCodeTraceLinksAsStringList(Lists.immutable.ofAll(traceLinks));
}

@Override
protected ImmutableList<String> getGoldStandard(CodeProject codeProject) {
return codeProject.getSadCodeGoldStandard();
}

@Override
protected ImmutableList<String> enrollGoldStandard(ImmutableList<String> goldStandard, ArDoCoResult result) {
return enrollGoldStandardForCode(goldStandard, result);
}

@Override
protected ExpectedResults getExpectedResults(CodeProject codeProject) {
return codeProject.getExpectedResultsForSadSamCode();
}

@Override
protected int getConfusionMatrixSum(ArDoCoResult arDoCoResult) {
DataRepository dataRepository = arDoCoResult.dataRepository();

Text text = DataRepositoryHelper.getAnnotatedText(dataRepository);
int sentences = text.getSentences().size();

ModelStates modelStatesData = DataRepositoryHelper.getModelStatesData(dataRepository);
Model codeModel = modelStatesData.getModel(CodeModelType.CODE_MODEL.getModelId());
var codeModelEndpoints = codeModel.getEndpoints().size();

return sentences * codeModelEndpoints;
}
}
Loading

0 comments on commit 9ecf51d

Please sign in to comment.