Skip to content

Commit

Permalink
Define Language Models
Browse files Browse the repository at this point in the history
  • Loading branch information
dfuchss committed Sep 6, 2024
1 parent 9f6f62b commit c326c3e
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import edu.kit.kastel.mcse.ardoco.tlr.connectiongenerator.ConnectionGenerator;
import edu.kit.kastel.mcse.ardoco.tlr.models.agents.ArCoTLModelProviderAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.agents.LLMArchitectureProviderAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;
import edu.kit.kastel.mcse.ardoco.tlr.recommendationgenerator.RecommendationGenerator;
import edu.kit.kastel.mcse.ardoco.tlr.text.providers.TextPreprocessingAgent;
import edu.kit.kastel.mcse.ardoco.tlr.textextraction.TextExtraction;
Expand All @@ -23,13 +24,13 @@ public ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(String projectName) {
super(projectName);
}

public void setUp(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, File outputDir) {
definePipeline(inputText, inputCode, additionalConfigs);
public void setUp(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, File outputDir, LargeLanguageModel largeLanguageModel) {
definePipeline(inputText, inputCode, additionalConfigs, largeLanguageModel);
setOutputDirectory(outputDir);
isSetUp = true;
}

private void definePipeline(File inputText, File inputCode, SortedMap<String, String> additionalConfigs) {
private void definePipeline(File inputText, File inputCode, SortedMap<String, String> additionalConfigs, LargeLanguageModel largeLanguageModel) {
ArDoCo arDoCo = this.getArDoCo();
var dataRepository = arDoCo.getDataRepository();

Expand All @@ -47,7 +48,7 @@ private void definePipeline(File inputText, File inputCode, SortedMap<String, St
codeConfiguration);
arDoCo.addPipelineStep(arCoTLModelProviderAgent);

LLMArchitectureProviderAgent llmArchitectureProviderAgent = new LLMArchitectureProviderAgent(dataRepository);
LLMArchitectureProviderAgent llmArchitectureProviderAgent = new LLMArchitectureProviderAgent(dataRepository, largeLanguageModel);
arDoCo.addPipelineStep(llmArchitectureProviderAgent);

arDoCo.addPipelineStep(TextExtraction.get(additionalConfigs, dataRepository));
Expand Down
5 changes: 5 additions & 0 deletions stages-tlr/model-provider/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
<artifactId>langchain4j-core</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import edu.kit.kastel.mcse.ardoco.core.data.DataRepository;
import edu.kit.kastel.mcse.ardoco.core.pipeline.agent.PipelineAgent;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitectureProviderInformant;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;

public class LLMArchitectureProviderAgent extends PipelineAgent {

public LLMArchitectureProviderAgent(DataRepository dataRepository) {
super(List.of(new LLMArchitectureProviderInformant(dataRepository)), LLMArchitectureProviderAgent.class.getSimpleName(), dataRepository);
public LLMArchitectureProviderAgent(DataRepository dataRepository, LargeLanguageModel largeLanguageModel) {
super(List.of(new LLMArchitectureProviderInformant(dataRepository, largeLanguageModel)), LLMArchitectureProviderAgent.class.getSimpleName(),
dataRepository);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModelName;
import edu.kit.kastel.mcse.ardoco.core.api.models.ArchitectureModelType;
import edu.kit.kastel.mcse.ardoco.core.api.models.ModelStates;
import edu.kit.kastel.mcse.ardoco.core.api.models.arcotl.ArchitectureModel;
Expand All @@ -35,19 +33,14 @@ public class LLMArchitectureProviderInformant extends Informant {
""",
"Now provide a list that only covers the component names. Omit common prefixes and suffixes in the names.");

public LLMArchitectureProviderInformant(DataRepository dataRepository) {
public LLMArchitectureProviderInformant(DataRepository dataRepository, LargeLanguageModel largeLanguageModel) {
super(LLMArchitectureProviderInformant.class.getSimpleName(), dataRepository);
String apiKey = System.getenv("OPENAI_API_KEY");
String orgId = System.getenv("OPENAI_ORG_ID");
if (apiKey == null || orgId == null) {
throw new IllegalArgumentException("OPENAI_API_KEY and OPENAI_ORG_ID must be set as environment variables");
}
this.chatLanguageModel = new OpenAiChatModel.OpenAiChatModelBuilder().modelName(OpenAiChatModelName.GPT_4_O)
.apiKey(apiKey)
.organizationId(orgId)
.seed(422413373)
.temperature(0.0)
.build();
this.chatLanguageModel = largeLanguageModel.create();
}

@Override
Expand Down Expand Up @@ -95,12 +88,12 @@ else if (line.matches("^\\d+\\.\\s*.*$")) {
componentNames.add(line.split("\\.\\s*")[1]);
}
// Version 3: - **Name**
else if (line.matches("^-\\s*\\*\\*.*\\*\\*$")) {
else if (line.matches("^([-*])\\s*\\*\\*.*\\*\\*$")) {
componentNames.add(line.split("\\*\\*")[1]);
}
// Version 4: - Name
else if (line.matches("^-\\s*.*$")) {
componentNames.add(line.split("-\\s*")[1]);
else if (line.matches("^([-*])\\s*.*$")) {
componentNames.add(line.split("([-*])\\s*")[1]);
} else {
logger.warn("Could not parse line: {}", line);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Licensed under MIT 2024. */
package edu.kit.kastel.mcse.ardoco.tlr.models.informants;

import java.time.Duration;
import java.util.Map;
import java.util.function.Supplier;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import okhttp3.Credentials;

public enum LargeLanguageModel {
// OPENAI
GPT_4_O_MINI(() -> createOpenAiModel("gpt-4o-mini-2024-07-18")), //
GPT_4_O(() -> createOpenAiModel("gpt-4o-2024-08-06")), //
GPT_4_TURBO(() -> createOpenAiModel("gpt-4-turbo-2024-04-09")), //
GPT_4(() -> createOpenAiModel("gpt-4-0613")), //
GPT_3_5_TURBO(() -> createOpenAiModel("gpt-3.5-turbo-0125")), //
OPENAI_GENERIC(() -> createOpenAiModel(System.getenv("OPENAI_MODEL_NAME"))), //
// OLLAMA
CODELLAMA_13B(() -> createOllamaModel("codellama:13b")), //
CODELLAMA_70B(() -> createOllamaModel("codellama:70b")), //
//
GEMMA_2_27B(() -> createOllamaModel("gemma2:27b")), //
//
LLAMA_3_1_8B(() -> createOllamaModel("llama3.1:8b-instruct-fp16")), //
LLAMA_3_1_70B(() -> createOllamaModel("llama3.1:70b")), //
//
MISTRAL_7B(() -> createOllamaModel("mistral:7b")), //
MISTRAL_NEMO_27B(() -> createOllamaModel("mistral-nemo:12b")), //
MIXTRAL_8_X_22B(() -> createOllamaModel("mixtral:8x22b")), //
//
PHI_3_14B(() -> createOllamaModel("phi3:14b")), //
//
OLLAMA_GENERIC(() -> createOllamaModel(System.getenv("OLLAMA_MODEL_NAME")));

private final Supplier<ChatLanguageModel> creator;

LargeLanguageModel(Supplier<ChatLanguageModel> creator) {
this.creator = creator;
}

public ChatLanguageModel create() {
return creator.get();
}

private static final int SEED = 422413373;

private static ChatLanguageModel createOpenAiModel(String model) {
String apiKey = System.getenv("OPENAI_API_KEY");
String orgId = System.getenv("OPENAI_ORG_ID");
if (apiKey == null || orgId == null) {
throw new IllegalArgumentException("OPENAI_API_KEY and OPENAI_ORG_ID must be set as environment variables");
}
return new OpenAiChatModel.OpenAiChatModelBuilder().modelName(model).apiKey(apiKey).organizationId(orgId).seed(SEED).temperature(0.0).build();
}

private static ChatLanguageModel createOllamaModel(String model) {
String ollamaHost = System.getenv("OLLAMA_HOST");
String ollamaUser = System.getenv("OLLAMA_USER");
String ollamaPassword = System.getenv("OLLAMA_PASSWORD");
if (ollamaHost == null) {
throw new IllegalArgumentException("OLLAMA_HOST must be set as environment variable");
}

OllamaChatModel.OllamaChatModelBuilder builder = new OllamaChatModel.OllamaChatModelBuilder().modelName(model)
.baseUrl(ollamaHost)
.seed(SEED)
.timeout(Duration.ofMinutes(30))
.temperature(0.0);
if (ollamaUser != null && ollamaPassword != null) {
builder.customHeaders(Map.of("Authorization", Credentials.basic(ollamaUser, ollamaPassword)));
}

return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.ExpectedResults;
import edu.kit.kastel.mcse.ardoco.tlr.execution.ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;

class SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation extends TraceabilityLinkRecoveryEvaluation<CodeProject> {
private final boolean acmFile;
private final LargeLanguageModel largeLanguageModel;

public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile) {
public SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(boolean acmFile, LargeLanguageModel largeLanguageModel) {
super();
this.acmFile = acmFile;
this.largeLanguageModel = largeLanguageModel;
}

@Override
Expand All @@ -46,7 +49,7 @@ protected ArDoCoRunner getAndSetupRunner(CodeProject codeProject) {
File outputDir = new File(OUTPUT);

var runner = new ArDoCoForSadSamViaLlmCodeTraceabilityLinkRecovery(name);
runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir);
runner.setUp(textInput, inputCode, additionalConfigsMap, outputDir, largeLanguageModel);
return runner;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.GoldStandardProject;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.EvaluationResults;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;
import edu.kit.kastel.mcse.ardoco.tlr.tests.integration.tlrhelper.ModelElementSentenceLink;

class TraceLinkEvaluationSadSamViaLlmCodeIT {
Expand All @@ -40,7 +41,7 @@ static void afterAll() {
@ParameterizedTest(name = "{0}")
@EnumSource(CodeProject.class)
void evaluateSadCodeTlrIT(CodeProject project) {
var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true);
var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, LargeLanguageModel.PHI_3_14B);
ArDoCoResult results = evaluation.runTraceLinkEvaluation(project);
Assertions.assertNotNull(results);
}
Expand Down

0 comments on commit c326c3e

Please sign in to comment.