Skip to content

Commit

Permalink
Disable some models
Browse files Browse the repository at this point in the history
  • Loading branch information
dfuchss committed Sep 23, 2024
1 parent 11c0d64 commit 84d5f2f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
package edu.kit.kastel.mcse.ardoco.tlr.models.informants;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand All @@ -19,7 +21,10 @@
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureComponent;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.architecture.ArchitectureItem;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.code.CodePackage;
import edu.kit.kastel.mcse.ardoco.core.common.util.CommonTextToolsConfig;
import edu.kit.kastel.mcse.ardoco.core.common.util.DataRepositoryHelper;
import edu.kit.kastel.mcse.ardoco.core.common.util.wordsim.WordSimUtils;
import edu.kit.kastel.mcse.ardoco.core.common.util.wordsim.measures.levenshtein.LevenshteinMeasure;
import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.Informant;

Expand Down Expand Up @@ -47,22 +52,30 @@ public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLang
throw new IllegalArgumentException("At least one prompt must be provided");
}
if (documentationPrompt != null && codePrompt != null && aggregationPrompt == null) {
throw new IllegalArgumentException("Aggregation prompt must be provided when both documentation and code prompts are provided");
logger.info("Using Similarity Metrics to aggregate the component names");
}
}

@Override
protected void process() {
List<String> componentNames = new ArrayList<>();
List<String> componentNamesDocumentation = new ArrayList<>();
List<String> componentNamesCode = new ArrayList<>();
if (documentationPrompt != null)
documentationToArchitecture(componentNames);
documentationToArchitecture(componentNamesDocumentation);
if (codePrompt != null)
codeToArchitecture(componentNames);
codeToArchitecture(componentNamesCode);

List<String> componentNames = new ArrayList<>();

if (aggregationPrompt != null) {
var aggregation = chatLanguageModel.generate(aggregationPrompt.getTemplates().getFirst().formatted(String.join("\n", componentNames)));
componentNames = new ArrayList<>();
var allComponentNames = Stream.concat(componentNamesDocumentation.stream(), componentNamesCode.stream()).toList();
var aggregation = chatLanguageModel.generate(aggregationPrompt.getTemplates().getFirst().formatted(String.join("\n", allComponentNames)));
parseComponentNames(aggregation, componentNames);
} else if (documentationPrompt != null && codePrompt != null) {
componentNames = mergeViaSimilarity(componentNamesDocumentation, componentNamesCode);
} else {
// If only one prompt is provided, use the component names from that prompt
componentNames = Stream.concat(componentNamesDocumentation.stream(), componentNamesCode.stream()).toList();
}

// Remove any not letter characters
Expand All @@ -77,6 +90,19 @@ protected void process() {
buildModel(componentNames);
}

/**
 * Merges the component names obtained from the documentation and from the code into one list,
 * dropping names that are considered similar to a name that has already been kept.
 * <p>
 * Names are visited in order: all documentation names first, then all code names. A name is kept
 * only if no previously kept name is reported as similar by {@link WordSimUtils#areWordsSimilar},
 * so the first occurrence of each group of similar names wins.
 * <p>
 * Similarity is decided solely by a {@link LevenshteinMeasure} built from the common text tool
 * configuration and the value {@code 0.8} (presumably the similarity threshold; confirm against
 * the {@code LevenshteinMeasure} constructor).
 *
 * @param componentNamesDocumentation component names extracted from the documentation
 * @param componentNamesCode          component names extracted from the code
 * @return a new mutable list with the merged, de-duplicated component names
 */
private static List<String> mergeViaSimilarity(List<String> componentNamesDocumentation, List<String> componentNamesCode) {
    WordSimUtils similarity = new WordSimUtils();
    var levenshtein = new LevenshteinMeasure(CommonTextToolsConfig.LEVENSHTEIN_MIN_LENGTH, CommonTextToolsConfig.LEVENSHTEIN_MAX_DISTANCE, 0.8);
    similarity.setMeasures(Collections.singletonList(levenshtein));

    // Documentation names come first so that they take precedence over similar code names.
    List<String> candidates = new ArrayList<>(componentNamesDocumentation);
    candidates.addAll(componentNamesCode);

    List<String> merged = new ArrayList<>();
    for (String candidate : candidates) {
        boolean alreadyCovered = false;
        for (String kept : merged) {
            if (similarity.areWordsSimilar(kept, candidate)) {
                alreadyCovered = true;
                break;
            }
        }
        if (!alreadyCovered) {
            merged.add(candidate);
        }
    }
    return merged;
}

private void documentationToArchitecture(List<String> componentNames) {
var inputText = DataRepositoryHelper.getInputText(dataRepository);
parseComponentsFromAiRequests(componentNames, documentationPrompt.getTemplates(), inputText);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ public enum LargeLanguageModel {
CODELLAMA_13B("Codellama 13b", () -> createOllamaModel("codellama:13b")), //
CODELLAMA_70B("Codellama 70b", () -> createOllamaModel("codellama:70b")), //
//
GEMMA_2_27B("Gemma2 27b", () -> createOllamaModel("gemma2:27b")), //
// GEMMA_2_27B("Gemma2 27b", () -> createOllamaModel("gemma2:27b")), //
//
// QWEN_2_72B("Qwen2 72b", () -> createOllamaModel("qwen2:72b")), //
//
LLAMA_3_1_8B("Llama3.1 8b", () -> createOllamaModel("llama3.1:8b-instruct-fp16")), //
LLAMA_3_1_70B("Llama3.1 70b", () -> createOllamaModel("llama3.1:70b")), //
//
MISTRAL_7B("Mistral 7b", () -> createOllamaModel("mistral:7b")), //
MISTRAL_NEMO_27B("Mistral Nemo 12b", () -> createOllamaModel("mistral-nemo:12b")), //
MIXTRAL_8_X_22B("Mixtral 8x22b", () -> createOllamaModel("mixtral:8x22b")), //
// MISTRAL_7B("Mistral 7b", () -> createOllamaModel("mistral:7b")), //
// MISTRAL_NEMO_27B("Mistral Nemo 12b", () -> createOllamaModel("mistral-nemo:12b")), //
// MIXTRAL_8_X_22B("Mixtral 8x22b", () -> createOllamaModel("mixtral:8x22b")), //
//
PHI_3_14B("Phi3 14b", () -> createOllamaModel("phi3:14b")), //
// PHI_3_14B("Phi3 14b", () -> createOllamaModel("phi3:14b")), //
//
OLLAMA_GENERIC(System.getenv("OLLAMA_MODEL_NAME"), () -> createOllamaModel(System.getenv("OLLAMA_MODEL_NAME")));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,11 @@ static void afterAll() {
@ParameterizedTest(name = "{0} ({1})")
@MethodSource("llmsXprojects")
void evaluateSadCodeTlrIT(CodeProject project, LargeLanguageModel llm) {
Assumptions.assumeTrue(System.getenv("CI") == null);

if (llm.isGeneric()) {
Assumptions.abort("Generic LLM is disabled");
}
if (llm.isOpenAi()) {
Assumptions.abort("Model is disabled");
}

logger.info("###############################################");
logger.info("Evaluating project {} with LLM '{}'", project, llm);
Expand Down

0 comments on commit 84d5f2f

Please sign in to comment.