From 84d5f2f26526cd21c3c14fce7bf3a7d8d5dbc319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Fuch=C3=9F?= Date: Mon, 23 Sep 2024 13:14:49 +0200 Subject: [PATCH] Disable some models --- .../LLMArchitectureProviderInformant.java | 38 ++++++++++++++++--- .../models/informants/LargeLanguageModel.java | 10 ++--- ...TraceLinkEvaluationSadSamViaLlmCodeIT.java | 5 +-- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java index 343328b..67a3f12 100644 --- a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LLMArchitectureProviderInformant.java @@ -2,10 +2,12 @@ package edu.kit.kastel.mcse.ardoco.tlr.models.informants; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.TreeSet; import java.util.stream.Collectors; +import java.util.stream.Stream; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; @@ -19,7 +21,10 @@ import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureComponent; import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureItem; import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.code.CodePackage; +import edu.kit.kastel.mcse.ardoco.core.common.util.CommonTextToolsConfig; import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper; +import edu.kit.kastel.mcse.ardoco.core.common.util.wordsim.WordSimUtils; +import edu.kit.kastel.mcse.ardoco.core.common.util.wordsim.measures.levenshtein.LevenshteinMeasure; import edu.kit.kastel.mcse.ardoco.core.data.DataRepository; import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.Informant; @@ -47,22 +52,30 @@ public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLang throw new IllegalArgumentException("At least one prompt must be provided"); } if (documentationPrompt != null && codePrompt != null && aggregationPrompt == null) { - throw new IllegalArgumentException("Aggregation prompt must be provided when both documentation and code prompts are provided"); + logger.info("Using Similarity Metrics to aggregate the component names"); } } @Override protected void process() { - List componentNames = new ArrayList<>(); + List componentNamesDocumentation = new ArrayList<>(); + List componentNamesCode = new ArrayList<>(); if (documentationPrompt != null) - documentationToArchitecture(componentNames); + documentationToArchitecture(componentNamesDocumentation); if (codePrompt != null) - codeToArchitecture(componentNames); + codeToArchitecture(componentNamesCode); + + List componentNames = new ArrayList<>(); if (aggregationPrompt != null) { - var aggregation = chatLanguageModel.generate(aggregationPrompt.getTemplates().getFirst().formatted(String.join("\n", componentNames))); - componentNames = new ArrayList<>(); + var allComponentNames = Stream.concat(componentNamesDocumentation.stream(), componentNamesCode.stream()).toList(); + var aggregation = chatLanguageModel.generate(aggregationPrompt.getTemplates().getFirst().formatted(String.join("\n", allComponentNames))); parseComponentNames(aggregation, componentNames); + } else if (documentationPrompt != null && codePrompt != null) { + componentNames = mergeViaSimilarity(componentNamesDocumentation, componentNamesCode); + } else { + // If only one prompt is provided, use the component names from that prompt + componentNames = Stream.concat(componentNamesDocumentation.stream(), componentNamesCode.stream()).toList(); } // Remove any not letter characters @@ -77,6 +90,19 @@ protected void process() { buildModel(componentNames); } + private static List mergeViaSimilarity(List componentNamesDocumentation, List componentNamesCode) { + WordSimUtils simUtils = new WordSimUtils(); + simUtils.setMeasures(Collections.singletonList(new LevenshteinMeasure(CommonTextToolsConfig.LEVENSHTEIN_MIN_LENGTH, + CommonTextToolsConfig.LEVENSHTEIN_MAX_DISTANCE, 0.8))); + List componentNames = new ArrayList<>(); + for (String componentName : Stream.concat(componentNamesDocumentation.stream(), componentNamesCode.stream()).toList()) { + if (componentNames.stream().noneMatch(it -> simUtils.areWordsSimilar(it, componentName))) { + componentNames.add(componentName); + } + } + return componentNames; + } + private void documentationToArchitecture(List componentNames) { var inputText = DataRepositoryHelper.getInputText(dataRepository); parseComponentsFromAiRequests(componentNames, documentationPrompt.getTemplates(), inputText); diff --git a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LargeLanguageModel.java b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LargeLanguageModel.java index a98437e..da62b3f 100644 --- a/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LargeLanguageModel.java +++ b/stages-tlr/model-provider/src/main/java/edu/kit/kastel/mcse/ardoco/tlr/models/informants/LargeLanguageModel.java @@ -22,18 +22,18 @@ public enum LargeLanguageModel { CODELLAMA_13B("Codellama 13b", () -> createOllamaModel("codellama:13b")), // CODELLAMA_70B("Codellama 70b", () -> createOllamaModel("codellama:70b")), // // - GEMMA_2_27B("Gemma2 27b", () -> createOllamaModel("gemma2:27b")), // + // GEMMA_2_27B("Gemma2 27b", () -> createOllamaModel("gemma2:27b")), // // // QWEN_2_72B("Qwen2 72b", () -> createOllamaModel("qwen2:72b")), // // LLAMA_3_1_8B("Llama3.1 8b", () -> createOllamaModel("llama3.1:8b-instruct-fp16")), // LLAMA_3_1_70B("Llama3.1 70b", () -> createOllamaModel("llama3.1:70b")), // // - MISTRAL_7B("Mistral 7b", () -> createOllamaModel("mistral:7b")), // - MISTRAL_NEMO_27B("Mistral Nemo 12b", () -> createOllamaModel("mistral-nemo:12b")), // - MIXTRAL_8_X_22B("Mixtral 8x22b", () -> createOllamaModel("mixtral:8x22b")), // + // MISTRAL_7B("Mistral 7b", () -> createOllamaModel("mistral:7b")), // + // MISTRAL_NEMO_27B("Mistral Nemo 12b", () -> createOllamaModel("mistral-nemo:12b")), // + // MIXTRAL_8_X_22B("Mixtral 8x22b", () -> createOllamaModel("mixtral:8x22b")), // // - PHI_3_14B("Phi3 14b", () -> createOllamaModel("phi3:14b")), // + // PHI_3_14B("Phi3 14b", () -> createOllamaModel("phi3:14b")), // // OLLAMA_GENERIC(System.getenv("OLLAMA_MODEL_NAME"), () -> createOllamaModel(System.getenv("OLLAMA_MODEL_NAME"))); diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java index 7a4f160..9829c2c 100644 --- a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java @@ -47,12 +47,11 @@ static void afterAll() { @ParameterizedTest(name = "{0} ({1})") @MethodSource("llmsXprojects") void evaluateSadCodeTlrIT(CodeProject project, LargeLanguageModel llm) { + Assumptions.assumeTrue(System.getenv("CI") == null); + if (llm.isGeneric()) { Assumptions.abort("Generic LLM is disabled"); } - if (llm.isOpenAi()) { - Assumptions.abort("Model is disabled"); - } logger.info("###############################################"); logger.info("Evaluating project {} with LLM '{}'", project, llm);