Skip to content

Commit

Permalink
Calc averages
Browse files (browse the repository at this point in the history)
  • Loading branch information
dfuchss committed Sep 25, 2024
1 parent f61fd8c commit cbfca5f
Showing 1 changed file with 17 additions and 4 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,6 +24,9 @@

import edu.kit.kastel.mcse.ardoco.core.api.output.ArDoCoResult;
import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject;
import edu.kit.kastel.mcse.ardoco.metrics.ClassificationMetricsCalculator;
import edu.kit.kastel.mcse.ardoco.metrics.result.AggregationType;
import edu.kit.kastel.mcse.ardoco.metrics.result.SingleClassificationResult;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt;
import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel;

Expand Down Expand Up @@ -69,12 +72,14 @@ void evaluateSadCodeTlrIT(CodeProject project, LargeLanguageModel llm) {
@AfterAll
static void printResults() {
logger.info("!!!!!!!!! Results !!!!!!!!!!");
System.out.println(Arrays.stream(CodeProject.values()).map(Enum::name).collect(Collectors.joining(" & ")) + " \\\\");
System.out.println(Arrays.stream(CodeProject.values()).map(Enum::name).collect(Collectors.joining(" & ")) + " Macro Avg & Weighted Average" + " \\\\");
for (LargeLanguageModel llm : LargeLanguageModel.values()) {
if (llm.isGeneric() && RESULTS.keySet().stream().noneMatch(k -> k.getTwo() == llm)) {
continue;
}
StringBuilder llmResult = new StringBuilder(llm.getHumanReadableName() + " ");

List<SingleClassificationResult<String>> classificationResults = new ArrayList<>();
for (CodeProject project : CodeProject.values()) {
if (!RESULTS.containsKey(Tuples.pair(project, llm))) {
llmResult.append("&--&--&--");
Expand All @@ -83,15 +88,23 @@ static void printResults() {
ArDoCoResult result = RESULTS.get(Tuples.pair(project, llm));

// Just some instance .. parameters do not matter ..
var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, null, LLMArchitecturePrompt.CODE_ONLY_V1, null);
var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, LLMArchitecturePrompt.DOCUMENTATION_ONLY_V1, null, null);
var goldStandard = project.getSadCodeGoldStandard();
goldStandard = TraceabilityLinkRecoveryEvaluation.enrollGoldStandardForCode(goldStandard, result);
var evaluationResults = evaluation.calculateEvaluationResults(result, goldStandard);
classificationResults.add(evaluationResults.classificationResult());
llmResult.append(String.format(Locale.ENGLISH, "&%.2f&%.2f&%.2f", evaluationResults.precision(), evaluationResults.recall(), evaluationResults
.f1()));
}
llmResult.append("&&&&&&\\\\"); // end of line
System.out.println(llmResult);
ClassificationMetricsCalculator classificationMetricsCalculator = ClassificationMetricsCalculator.getInstance();
var averages = classificationMetricsCalculator.calculateAverages(classificationResults, null);

var macro = averages.stream().filter(it -> it.getType() == AggregationType.MACRO_AVERAGE).findFirst().orElseThrow();
var weighted = averages.stream().filter(it -> it.getType() == AggregationType.WEIGHTED_AVERAGE).findFirst().orElseThrow();

llmResult.append(String.format(Locale.ENGLISH, "&%.2f&%.2f&%.2f&%.2f&%.2f&%.2f\\\\", macro.getPrecision(), macro.getRecall(), macro.getF1(),
weighted.getPrecision(), weighted.getRecall(), weighted.getF1())); // end of line
System.out.println(llmResult.toString().replace("0.", ".").replace("1.00", "1.0"));
}
}

Expand Down

0 comments on commit cbfca5f

Please sign in to comment.