diff --git a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java index 591112b..2972f48 100644 --- a/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java +++ b/tests/integration-tests/tests-tlr/src/test/java/edu/kit/kastel/mcse/ardoco/tlr/tests/integration/TraceLinkEvaluationSadSamViaLlmCodeIT.java @@ -24,6 +24,9 @@ import edu.kit.kastel.mcse.ardoco.core.api.output.ArDoCoResult; import edu.kit.kastel.mcse.ardoco.core.tests.eval.CodeProject; +import edu.kit.kastel.mcse.ardoco.metrics.ClassificationMetricsCalculator; +import edu.kit.kastel.mcse.ardoco.metrics.result.AggregationType; +import edu.kit.kastel.mcse.ardoco.metrics.result.SingleClassificationResult; import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LLMArchitecturePrompt; import edu.kit.kastel.mcse.ardoco.tlr.models.informants.LargeLanguageModel; @@ -69,12 +72,14 @@ void evaluateSadCodeTlrIT(CodeProject project, LargeLanguageModel llm) { @AfterAll static void printResults() { logger.info("!!!!!!!!! Results !!!!!!!!!!"); - System.out.println(Arrays.stream(CodeProject.values()).map(Enum::name).collect(Collectors.joining(" & ")) + " \\\\"); + System.out.println(Arrays.stream(CodeProject.values()).map(Enum::name).collect(Collectors.joining(" & ")) + " Macro Avg & Weighted Average" + " \\\\"); for (LargeLanguageModel llm : LargeLanguageModel.values()) { if (llm.isGeneric() && RESULTS.keySet().stream().noneMatch(k -> k.getTwo() == llm)) { continue; } StringBuilder llmResult = new StringBuilder(llm.getHumanReadableName() + " "); + + List> classificationResults = new ArrayList<>(); for (CodeProject project : CodeProject.values()) { if (!RESULTS.containsKey(Tuples.pair(project, llm))) { llmResult.append("&--&--&--"); @@ -83,15 +88,23 @@ static void printResults() { ArDoCoResult result = RESULTS.get(Tuples.pair(project, llm)); // Just some instance .. parameters do not matter .. - var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, null, LLMArchitecturePrompt.CODE_ONLY_V1, null); + var evaluation = new SadSamViaLlmCodeTraceabilityLinkRecoveryEvaluation(true, llm, LLMArchitecturePrompt.DOCUMENTATION_ONLY_V1, null, null); var goldStandard = project.getSadCodeGoldStandard(); goldStandard = TraceabilityLinkRecoveryEvaluation.enrollGoldStandardForCode(goldStandard, result); var evaluationResults = evaluation.calculateEvaluationResults(result, goldStandard); + classificationResults.add(evaluationResults.classificationResult()); llmResult.append(String.format(Locale.ENGLISH, "&%.2f&%.2f&%.2f", evaluationResults.precision(), evaluationResults.recall(), evaluationResults .f1())); } - llmResult.append("&&&&&&\\\\"); // end of line - System.out.println(llmResult); + ClassificationMetricsCalculator classificationMetricsCalculator = ClassificationMetricsCalculator.getInstance(); + var averages = classificationMetricsCalculator.calculateAverages(classificationResults, null); + + var macro = averages.stream().filter(it -> it.getType() == AggregationType.MACRO_AVERAGE).findFirst().orElseThrow(); + var weighted = averages.stream().filter(it -> it.getType() == AggregationType.WEIGHTED_AVERAGE).findFirst().orElseThrow(); + + llmResult.append(String.format(Locale.ENGLISH, "&%.2f&%.2f&%.2f&%.2f&%.2f&%.2f\\\\", macro.getPrecision(), macro.getRecall(), macro.getF1(), + weighted.getPrecision(), weighted.getRecall(), weighted.getF1())); // end of line + System.out.println(llmResult.toString().replace("0.", ".").replace("1.00", "1.0")); } }