diff --git a/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java b/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java new file mode 100644 index 0000000..d112565 --- /dev/null +++ b/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java @@ -0,0 +1,61 @@ +/* Licensed under MIT 2023-2024. */ +package edu.kit.kastel.mcse.ardoco.tlr.execution; + +import java.io.File; +import java.util.SortedMap; + +import edu.kit.kastel.mcse.ardoco.core.common.util.CommonUtilities; +import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper; +import edu.kit.kastel.mcse.ardoco.core.execution.ArDoCo; +import edu.kit.kastel.mcse.ardoco.core.execution.runner.ArDoCoRunner; +import edu.kit.kastel.mcse.ardoco.tlr.codetraceability.SadSamCodeTraceabilityLinkRecovery; +import edu.kit.kastel.mcse.ardoco.tlr.codetraceability.SamCodeTraceabilityLinkRecovery; +import edu.kit.kastel.mcse.ardoco.tlr.connectiongenerator.ConnectionGenerator; +import edu.kit.kastel.mcse.ardoco.tlr.models.agents.ArCoTLModelProviderAgent; +import edu.kit.kastel.mcse.ardoco.tlr.models.agents.LLMArchitectureProviderAgent; +import edu.kit.kastel.mcse.ardoco.tlr.recommendationgenerator.RecommendationGenerator; +import edu.kit.kastel.mcse.ardoco.tlr.text.providers.TextPreprocessingAgent; +import edu.kit.kastel.mcse.ardoco.tlr.textextraction.TextExtraction; + +public class ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery extends ArDoCoRunner { + + public ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(String projectName) { + super(projectName); + } + + public void setUp(File inputText, File inputCode, SortedMap additionalConfigs, File outputDir) { + definePipeline(inputText, inputCode, additionalConfigs); + setOutputDirectory(outputDir); + isSetUp = true; + } + + private void definePipeline(File inputText, File inputCode, SortedMap additionalConfigs) { + ArDoCo arDoCo = this.getArDoCo(); + var dataRepository = arDoCo.getDataRepository(); + + var text = CommonUtilities.readInputText(inputText); + if (text.isBlank()) { + throw new IllegalArgumentException("Cannot deal with empty input text. Maybe there was an error reading the file."); + } + DataRepositoryHelper.putInputText(dataRepository, text); + + arDoCo.addPipelineStep(TextPreprocessingAgent.get(additionalConfigs, dataRepository)); + + var codeConfiguration = ArCoTLModelProviderAgent.getCodeConfiguration(inputCode); + + ArCoTLModelProviderAgent arCoTLModelProviderAgent = ArCoTLModelProviderAgent.getArCoTLModelProviderAgent(dataRepository, additionalConfigs, null, + codeConfiguration); + arDoCo.addPipelineStep(arCoTLModelProviderAgent); + + LLMArchitectureProviderAgent llmArchitectureProviderAgent = new LLMArchitectureProviderAgent(dataRepository); + arDoCo.addPipelineStep(llmArchitectureProviderAgent); + + arDoCo.addPipelineStep(TextExtraction.get(additionalConfigs, dataRepository)); + arDoCo.addPipelineStep(RecommendationGenerator.get(additionalConfigs, dataRepository)); + arDoCo.addPipelineStep(ConnectionGenerator.get(additionalConfigs, dataRepository)); + + arDoCo.addPipelineStep(SamCodeTraceabilityLinkRecovery.get(additionalConfigs, dataRepository)); + + arDoCo.addPipelineStep(SadSamCodeTraceabilityLinkRecovery.get(additionalConfigs, dataRepository)); + } +} diff --git a/stages-tlr/model-provider/pom.xml b/stages-tlr/model-provider/pom.xml index 6f94306..e441ecd 100644 --- a/stages-tlr/model-provider/pom.xml +++ b/stages-tlr/model-provider/pom.xml @@ -9,6 +9,10 @@ model-provider + + 0.33.0 + + com.fasterxml.jackson.core @@ -33,6 +37,21 @@ commons-io 2.15.1 + + dev.langchain4j + langchain4j + ${langchain4j.version} + + + dev.langchain4j + langchain4j-core + ${langchain4j.version} + + + dev.langchain4j + langchain4j-open-ai + ${langchain4j.version} + io.github.ardoco.core common diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/agents/LLMArchitectureProviderAgent.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/agents/LLMArchitectureProviderAgent.java new file mode 100644 index 0000000..108b31c --- /dev/null +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/agents/LLMArchitectureProviderAgent.java @@ -0,0 +1,15 @@ +/* Licensed under MIT 2024. */ +package edu.kit.kastel.mcse.ardoco.tlr.models.agents; + +import java.util.List; + +import edu.kit.kastel.mcse.ardoco.core.data.DataRepository; +import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.PipelineAgent; +import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitectureProviderInformant; + +public class LLMArchitectureProviderAgent extends PipelineAgent { + + public LLMArchitectureProviderAgent(DataRepository dataRepository) { + super(List.of(new LLMArchitectureProviderInformant(dataRepository)), LLMArchitectureProviderAgent.class.getSimpleName(), dataRepository); + } +} diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java new file mode 100644 index 0000000..c663c01 --- /dev/null +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java @@ -0,0 +1,196 @@ +/* Licensed under MIT 2024. */ +package edu.kit.kastel.mcse.ardoco.tlr.models.informants; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.TreeSet; +import java.util.stream.Collectors; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiChatModelName; +import edu.kit.kastel.mcse.ardoco.core.api.models.ArchitectureModelType; +import edu.kit.kastel.mcse.ardoco.core.api.models.CodeModelType; +import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates; +import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.ArchitectureModel; +import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.CodeModel; +import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureComponent; +import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureItem; +import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.code.CodePackage; +import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper; +import edu.kit.kastel.mcse.ardoco.core.data.DataRepository; +import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.Informant; + +public class LLMArchitectureProviderInformant extends Informant { + private static final String MODEL_STATES_DATA = "ModelStatesData"; + + private final ChatLanguageModel chatLanguageModel; + + private static final List TEMPLATES_DOC_TO_ARCHITECTURE = List.of( + """ + Your task is to identify the high-level components based on a software architecture. In a first step, you shall elaborate on the following documentation: + + %s + """, + "Now provide a list that only covers the component names. Omit common prefixes and suffixes in the names."); + + private static final List TEMPLATES_CODE_TO_ARCHITECTURE = List.of( + """ + You get the package names of a software project. Your task is to summarize the packages w.r.t. the architecture of the system. Try to identify possible components. + + Packages: + + %s + """, + """ + You get a summarization and a suggestion of the components of a software project. + Identify the possible component names and list them. Only list the component names. If you don't know what the component is about, omit it. + + Summarization: + + %s + """); + + private static final String AGGREGATION_PROMPT = """ + You get a list of possible component names. Your task is to aggregate the list and remove duplicates. + Also filter out component names, that are very generic. Do not repeat what you filtered out. + + Possible component names: + + %s + """; + + public LLMArchitectureProviderInformant(DataRepository dataRepository) { + super(LLMArchitectureProviderInformant.class.getSimpleName(), dataRepository); + String apiKey = System.getenv("OPENAI_API_KEY"); + String orgId = System.getenv("OPENAI_ORG_ID"); + if (apiKey == null || orgId == null) { + throw new IllegalArgumentException("OPENAI_API_KEY and OPENAI_ORG_ID must be set as environment variables"); + } + this.chatLanguageModel = new OpenAiChatModel.OpenAiChatModelBuilder().modelName(OpenAiChatModelName.GPT_4_O) + .apiKey(apiKey) + .organizationId(orgId) + .seed(422413373) + .temperature(0.0) + .build(); + } + + @Override + protected void process() { + List componentNames = new ArrayList<>(); + documentationToArchitecture(componentNames); + // codeToArchitecture(componentNames); + // Remove any not letter characters + componentNames = componentNames.stream().map(it -> it.replaceAll("[^a-zA-Z -_]", "").trim()).filter(it -> !it.isBlank()).distinct().sorted().toList(); + + /* + var aggregation = chatLanguageModel.generate(AGGREGATION_PROMPT.formatted(String.join("\n", componentNames))); + componentNames = new ArrayList<>(); + parseComponentNames(aggregation, componentNames); + */ + logger.info("Component names:\n{}", String.join("\n", componentNames)); + + buildModel(componentNames); + + } + + private void codeToArchitecture(List componentNames) { + var models = DataRepositoryHelper.getModelStatesData(dataRepository); + CodeModel codeModel = (CodeModel) models.getModel(CodeModelType.CODE_MODEL.getModelId()); + if (codeModel == null) { + logger.warn("Code model not found"); + return; + } + + var packages = codeModel.getAllPackages().stream().filter(it -> it.getContent().size() > 1).toList(); + + List responses = new ArrayList<>(); + responses.add(String.join("\n", packages.stream().map(it -> getPackageName(it)).toList())); + for (String template : TEMPLATES_CODE_TO_ARCHITECTURE) { + var filledTemplate = template.formatted(responses.getLast()); + var response = chatLanguageModel.generate(filledTemplate); + logger.info("Response: {}", response); + responses.add(response); + } + parseComponentNames(responses.getLast(), componentNames); + + } + + private String getPackageName(CodePackage codePackage) { + List packageName = new ArrayList<>(); + packageName.add(codePackage.getName()); + var parent = codePackage.getParent(); + while (parent != null) { + packageName.add(parent.getName()); + parent = parent.getParent(); + } + packageName = packageName.reversed(); + return String.join(".", packageName); + } + + private void documentationToArchitecture(List componentNames) { + var inputText = DataRepositoryHelper.getInputText(dataRepository); + String startMessage = TEMPLATES_DOC_TO_ARCHITECTURE.getFirst().formatted(inputText); + List messages = new ArrayList<>(); + messages.add(UserMessage.from(startMessage)); + + var initialResponse = chatLanguageModel.generate(messages).content(); + messages.add(initialResponse); + logger.info("Initial Response: {}", initialResponse.text()); + + for (String nextMessage : TEMPLATES_DOC_TO_ARCHITECTURE.stream().skip(1).toList()) { + messages.add(UserMessage.from(nextMessage)); + var response = chatLanguageModel.generate(messages).content(); + logger.info("Response: {}", response.text()); + messages.add(response); + } + + parseComponentNames(((AiMessage) messages.getLast()).text(), componentNames); + } + + private void parseComponentNames(String response, List componentNames) { + for (String line : response.split("\n")) { + if (line.isBlank()) { + continue; + } + line = line.trim(); + // Version 1: 1. **Name** or 2. **Name** + if (line.matches("^\\d+\\.\\s*\\*\\*.*\\*\\*$")) { + componentNames.add(line.split("\\*\\*")[1]); + } + // Version 2: 1. Name or 2. Name + else if (line.matches("^\\d+\\.\\s*.*$")) { + componentNames.add(line.split("\\.\\s*")[1]); + } + // Version 3: - **Name** + else if (line.matches("^-\\s*\\*\\*.*\\*\\*$")) { + componentNames.add(line.split("\\*\\*")[1]); + } + // Version 4: - Name + else if (line.matches("^-\\s*.*$")) { + componentNames.add(line.split("-\\s*")[1]); + } else { + logger.warn("Could not parse line: {}", line); + } + } + } + + private void buildModel(List componentNames) { + List componentList = componentNames.stream() + .map(it -> new ArchitectureComponent(it.replace("Component", "").trim(), it, new TreeSet<>(), new TreeSet<>(), new TreeSet<>(), "Component")) + .collect(Collectors.toList()); + ArchitectureModel am = new ArchitectureModel(componentList); + Optional modelStatesOptional = dataRepository.getData(MODEL_STATES_DATA, ModelStates.class); + var modelStates = modelStatesOptional.orElseGet(ModelStates::new); + + modelStates.addModel(ArchitectureModelType.PCM.getModelId(), am); + + if (modelStatesOptional.isEmpty()) { + dataRepository.addData(MODEL_STATES_DATA, modelStates); + } + } +} diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java new file mode 100644 index 0000000..c3df727 --- /dev/null +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java @@ -0,0 +1,88 @@ +/* Licensed under MIT 2023-2024. */ +package edu.kit.kastel.mcse.ardoco.tlr.tests.integration; + +import static edu.kit.kastel.mcse.ardoco.tlr.tests.integration.TraceLinkEvaluationIT.OUTPUT; + +import java.io.File; +import java.util.SortedMap; +import java.util.TreeMap; + +import org.eclipse.collections.api.factory.Lists; +import org.eclipse.collections.api.list.ImmutableList; + +import edu.kit.kastel.mcse.ardoco.core.api.models.CodeModelType; +import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates; +import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.Model; +import edu.kit.kastel.mcse.ardoco.core.api.output.ArDoCoResult; +import edu.kit.kastel.mcse.ardoco.core.api.text.Text; +import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper; +import edu.kit.kastel.mcse.ardoco.core.common.util.TraceLinkUtilities; +import edu.kit.kastel.mcse.ardoco.core.data.DataRepository; +import edu.kit.kastel.mcse.ardoco.core.execution.runner.ArDoCoRunner; +import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject; +import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.ExpectedResults; +import edu.kit.kastel.mcse.ardoco.tlr.execution.ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery; + +class SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation extends TraceabilityLinkRecoveryEvaluation { + private final boolean acmFile; + + public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile) { + super(); + this.acmFile = acmFile; + } + + @Override + protected boolean resultHasRequiredData(ArDoCoResult arDoCoResult) { + var traceLinks = arDoCoResult.getSadCodeTraceLinks(); + return !traceLinks.isEmpty(); + } + + @Override + protected ArDoCoRunner getAndSetupRunner(CodeProject codeProject) { + String name = codeProject.name().toLowerCase(); + File textInput = codeProject.getTextFile(); + File inputCode = getInputCode(codeProject, acmFile); + SortedMap additionalConfigsMap = new TreeMap<>(); + File outputDir = new File(OUTPUT); + + var runner = new ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(name); + runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir); + return runner; + } + + @Override + protected ImmutableList createTraceLinkStringList(ArDoCoResult arDoCoResult) { + var traceLinks = arDoCoResult.getSadCodeTraceLinks(); + + return TraceLinkUtilities.getSadCodeTraceLinksAsStringList(Lists.immutable.ofAll(traceLinks)); + } + + @Override + protected ImmutableList getGoldStandard(CodeProject codeProject) { + return codeProject.getSadCodeGoldStandard(); + } + + @Override + protected ImmutableList enrollGoldStandard(ImmutableList goldStandard, ArDoCoResult result) { + return enrollGoldStandardForCode(goldStandard, result); + } + + @Override + protected ExpectedResults getExpectedResults(CodeProject codeProject) { + return codeProject.getExpectedResultsForSadSamCode(); + } + + @Override + protected int getConfusionMatrixSum(ArDoCoResult arDoCoResult) { + DataRepository dataRepository = arDoCoResult.dataRepository(); + + Text text = DataRepositoryHelper.getAnnotatedText(dataRepository); + int sentences = text.getSentences().size(); + + ModelStates modelStatesData = DataRepositoryHelper.getModelStatesData(dataRepository); + Model codeModel = modelStatesData.getModel(CodeModelType.CODE_MODEL.getModelId()); + var codeModelEndpoints = codeModel.getEndpoints().size(); + + return sentences * codeModelEndpoints; + } +} diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java new file mode 100644 index 0000000..a16227b --- /dev/null +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java @@ -0,0 +1,47 @@ +/* Licensed under MIT 2023-2024. */ +package edu.kit.kastel.mcse.ardoco.tlr.tests.integration; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.eclipse.collections.api.tuple.Pair; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import edu.kit.kastel.mcse.ardoco.core.api.output.ArDoCoResult; +import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject; +import edu.kit.kastel.mcse.ardoco.core.tests.eval.GoldStandardProject; +import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.EvaluationResults; +import edu.kit.kastel.mcse.ardoco.tlr.tests.integration.tlrhelper.ModelElementSentenceLink; + +class TraceLinkEvaluationSadSamViaLlmCodeIT { + protected static final String LOGGING_ARDOCO_CORE = "org.slf4j.simpleLogger.log.edu.kit.kastel.mcse.ardoco.core"; + + protected static final List>> RESULTS = new ArrayList<>(); + protected static final Map DATA_MAP = new LinkedHashMap<>(); + + @BeforeAll + static void beforeAll() { + System.setProperty(LOGGING_ARDOCO_CORE, "info"); + } + + @AfterAll + static void afterAll() { + System.setProperty(LOGGING_ARDOCO_CORE, "error"); + } + + @DisplayName("Evaluate SAD-SAM-via-LLM-Code TLR") + @ParameterizedTest(name = "{0}") + @EnumSource(CodeProject.class) + void evaluateSadCodeTlrIT(CodeProject project) { + var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true); + ArDoCoResult results = evaluation.runTraceLinkEvaluation(project); + Assertions.assertNotNull(results); + } +}