diff --git a/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java b/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java index d112565..dc713e7 100644 --- a/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java +++ b/pipeline-tlr/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/execution/ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery.java @@ -13,6 +13,7 @@ import edu.kit.kastel.mcse.ardoco.tlr.connectiongenerator.ConnectionGenerator; import edu.kit.kastel.mcse.ardoco.tlr.models.agents.ArCoTLModelProviderAgent; import edu.kit.kastel.mcse.ardoco.tlr.models.agents.LLMArchitectureProviderAgent; +import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel; import edu.kit.kastel.mcse.ardoco.tlr.recommendationgenerator.RecommendationGenerator; import edu.kit.kastel.mcse.ardoco.tlr.text.providers.TextPreprocessingAgent; import edu.kit.kastel.mcse.ardoco.tlr.textextraction.TextExtraction; @@ -23,13 +24,13 @@ public ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(String projectName) { super(projectName); } - public void setUp(File inputText, File inputCode, SortedMap additionalConfigs, File outputDir) { - definePipeline(inputText, inputCode, additionalConfigs); + public void setUp(File inputText, File inputCode, SortedMap additionalConfigs, File outputDir, LargeLanguageModel largeLanguageModel) { + definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel); setOutputDirectory(outputDir); isSetUp = true; } - private void definePipeline(File inputText, File inputCode, SortedMap additionalConfigs) { + private void definePipeline(File inputText, File inputCode, SortedMap additionalConfigs, LargeLanguageModel largeLanguageModel) { ArDoCo arDoCo = this.getArDoCo(); var dataRepository = arDoCo.getDataRepository(); @@ -47,7 +48,7 @@ private void definePipeline(File inputText, File inputCode, SortedMaplangchain4j-core ${langchain4j.version} + + dev.langchain4j + langchain4j-ollama + ${langchain4j.version} + dev.langchain4j langchain4j-open-ai diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/agents/LLMArchitectureProviderAgent.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/agents/LLMArchitectureProviderAgent.java index 108b31c..8ba54fb 100644 --- a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/agents/LLMArchitectureProviderAgent.java +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/agents/LLMArchitectureProviderAgent.java @@ -6,10 +6,12 @@ import edu.kit.kastel.mcse.ardoco.core.data.DataRepository; import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.PipelineAgent; import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitectureProviderInformant; +import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel; public class LLMArchitectureProviderAgent extends PipelineAgent { - public LLMArchitectureProviderAgent(DataRepository dataRepository) { - super(List.of(new LLMArchitectureProviderInformant(dataRepository)), LLMArchitectureProviderAgent.class.getSimpleName(), dataRepository); + public LLMArchitectureProviderAgent(DataRepository dataRepository, LargeLanguageModel largeLanguageModel) { + super(List.of(new LLMArchitectureProviderInformant(dataRepository, largeLanguageModel)), LLMArchitectureProviderAgent.class.getSimpleName(), + dataRepository); } } diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java index bac4462..134f91c 100644 --- a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java @@ -11,8 +11,6 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.openai.OpenAiChatModel; -import dev.langchain4j.model.openai.OpenAiChatModelName; import edu.kit.kastel.mcse.ardoco.core.api.models.ArchitectureModelType; import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates; import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.ArchitectureModel; @@ -35,19 +33,14 @@ public class LLMArchitectureProviderInformant extends Informant { """, "Now provide a list that only covers the component names. Omit common prefixes and suffixes in the names."); - public LLMArchitectureProviderInformant(DataRepository dataRepository) { + public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLanguageModel largeLanguageModel) { super(LLMArchitectureProviderInformant.class.getSimpleName(), dataRepository); String apiKey = System.getenv("OPENAI_API_KEY"); String orgId = System.getenv("OPENAI_ORG_ID"); if (apiKey == null || orgId == null) { throw new IllegalArgumentException("OPENAI_API_KEY and OPENAI_ORG_ID must be set as environment variables"); } - this.chatLanguageModel = new OpenAiChatModel.OpenAiChatModelBuilder().modelName(OpenAiChatModelName.GPT_4_O) - .apiKey(apiKey) - .organizationId(orgId) - .seed(422413373) - .temperature(0.0) - .build(); + this.chatLanguageModel = largeLanguageModel.create(); } @Override @@ -95,12 +88,12 @@ else if (line.matches("^\\d+\\.\\s*.*$")) { componentNames.add(line.split("\\.\\s*")[1]); } // Version 3: - **Name** - else if (line.matches("^-\\s*\\*\\*.*\\*\\*$")) { + else if (line.matches("^([-*])\\s*\\*\\*.*\\*\\*$")) { componentNames.add(line.split("\\*\\*")[1]); } // Version 4: - Name - else if (line.matches("^-\\s*.*$")) { - componentNames.add(line.split("-\\s*")[1]); + else if (line.matches("^([-*])\\s*.*$")) { + componentNames.add(line.split("([-*])\\s*")[1]); } else { logger.warn("Could not parse line: {}", line); } diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LargeLanguageModel.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LargeLanguageModel.java new file mode 100644 index 0000000..3c24643 --- /dev/null +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LargeLanguageModel.java @@ -0,0 +1,78 @@ +/* Licensed under MIT 2024. */ +package edu.kit.kastel.mcse.ardoco.tlr.models.informants; + +import java.time.Duration; +import java.util.Map; +import java.util.function.Supplier; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.ollama.OllamaChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import okhttp3.Credentials; + +public enum LargeLanguageModel { + // OPENAI + GPT_4_O_MINI(() -> createOpenAiModel("gpt-4o-mini-2024-07-18")), // + GPT_4_O(() -> createOpenAiModel("gpt-4o-2024-08-06")), // + GPT_4_TURBO(() -> createOpenAiModel("gpt-4-turbo-2024-04-09")), // + GPT_4(() -> createOpenAiModel("gpt-4-0613")), // + GPT_3_5_TURBO(() -> createOpenAiModel("gpt-3.5-turbo-0125")), // + OPENAI_GENERIC(() -> createOpenAiModel(System.getenv("OPENAI_MODEL_NAME"))), // + // OLLAMA + CODELLAMA_13B(() -> createOllamaModel("codellama:13b")), // + CODELLAMA_70B(() -> createOllamaModel("codellama:70b")), // + // + GEMMA_2_27B(() -> createOllamaModel("gemma2:27b")), // + // + LLAMA_3_1_8B(() -> createOllamaModel("llama3.1:8b-instruct-fp16")), // + LLAMA_3_1_70B(() -> createOllamaModel("llama3.1:70b")), // + // + MISTRAL_7B(() -> createOllamaModel("mistral:7b")), // + MISTRAL_NEMO_27B(() -> createOllamaModel("mistral-nemo:12b")), // + MIXTRAL_8_X_22B(() -> createOllamaModel("mixtral:8x22b")), // + // + PHI_3_14B(() -> createOllamaModel("phi3:14b")), // + // + OLLAMA_GENERIC(() -> createOllamaModel(System.getenv("OLLAMA_MODEL_NAME"))); + + private final Supplier creator; + + LargeLanguageModel(Supplier creator) { + this.creator = creator; + } + + public ChatLanguageModel create() { + return creator.get(); + } + + private static final int SEED = 422413373; + + private static ChatLanguageModel createOpenAiModel(String model) { + String apiKey = System.getenv("OPENAI_API_KEY"); + String orgId = System.getenv("OPENAI_ORG_ID"); + if (apiKey == null || orgId == null) { + throw new IllegalArgumentException("OPENAI_API_KEY and OPENAI_ORG_ID must be set as environment variables"); + } + return new OpenAiChatModel.OpenAiChatModelBuilder().modelName(model).apiKey(apiKey).organizationId(orgId).seed(SEED).temperature(0.0).build(); + } + + private static ChatLanguageModel createOllamaModel(String model) { + String ollamaHost = System.getenv("OLLAMA_HOST"); + String ollamaUser = System.getenv("OLLAMA_USER"); + String ollamaPassword = System.getenv("OLLAMA_PASSWORD"); + if (ollamaHost == null) { + throw new IllegalArgumentException("OLLAMA_HOST must be set as environment variable"); + } + + OllamaChatModel.OllamaChatModelBuilder builder = new OllamaChatModel.OllamaChatModelBuilder().modelName(model) + .baseUrl(ollamaHost) + .seed(SEED) + .timeout(Duration.ofMinutes(30)) + .temperature(0.0); + if (ollamaUser != null && ollamaPassword != null) { + builder.customHeaders(Map.of("Authorization", Credentials.basic(ollamaUser, ollamaPassword))); + } + + return builder.build(); + } +} diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java index c3df727..de3ebde 100644 --- a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation.java @@ -22,13 +22,16 @@ import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject; import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.ExpectedResults; import edu.kit.kastel.mcse.ardoco.tlr.execution.ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery; +import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel; class SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation extends TraceabilityLinkRecoveryEvaluation { private final boolean acmFile; + private final LargeLanguageModel largeLanguageModel; - public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile) { + public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile, LargeLanguageModel largeLanguageModel) { super(); this.acmFile = acmFile; + this.largeLanguageModel = largeLanguageModel; } @Override @@ -46,7 +49,7 @@ protected ArDoCoRunner getAndSetupRunner(CodeProject codeProject) { File outputDir = new File(OUTPUT); var runner = new ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(name); - runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir); + runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir, largeLanguageModel); return runner; } diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java index a16227b..66d2a3c 100644 --- a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java @@ -18,6 +18,7 @@ import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject; import edu.kit.kastel.mcse.ardoco.core.tests.eval.GoldStandardProject; import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.EvaluationResults; +import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel; import edu.kit.kastel.mcse.ardoco.tlr.tests.integration.tlrhelper.ModelElementSentenceLink; class TraceLinkEvaluationSadSamViaLlmCodeIT { @@ -40,7 +41,7 @@ static void afterAll() { @ParameterizedTest(name = "{0}") @EnumSource(CodeProject.class) void evaluateSadCodeTlrIT(CodeProject project) { - var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true); + var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, LargeLanguageModel.PHI_3_14B); ArDoCoResult results = evaluation.runTraceLinkEvaluation(project); Assertions.assertNotNull(results); }