Skip to content

Commit

Permalink
Make it compile
Browse files Browse the repository at this point in the history
  • Loading branch information
dfuchss committed Aug 20, 2024
1 parent a7bfeda commit 0ff067c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ public String toRow() {
public String toRow(String headerKey, String headerVal) {
return String.format(Locale.ENGLISH, """
%10s & %4s & %4s & %4s & %4s & %4s & %4s & %4s
%10s & %4.2f & %4.2f & %4.2f & %4.2f & %4.2f & %4.2f & %4.2f""", headerKey, "P", "R", "F1", "Acc", "Spec", "Phi", "PhiN", headerVal, precision(),
recall(), f1(), accuracy(), specificity(), phiCoefficient(), phiOverPhiMax());
%10s & %4.2f & %4.2f & %4.2f & %4.2f & %4.2f & %4.2f & %4.2f""", headerKey, "P", "R", "F1", "Acc", "Spec", "Phi", "PhiN", headerVal,
precision(), recall(), f1(), accuracy(), specificity(), phiCoefficient(), phiOverPhiMax());
}

@Override
Expand Down Expand Up @@ -52,7 +52,8 @@ public String getExtendedResultStringWithExpected(ExpectedResults expectedResult
outputBuilder.append(String.format(Locale.ENGLISH, """
\tPrecision:%8.2f (min. expected: %.2f)
\tRecall:%11.2f (min. expected: %.2f)
\tF1:%15.2f (min. expected: %.2f)""", precision(), expectedResults.precision(), recall(), expectedResults.recall(), f1(), expectedResults.f1()));
\tF1:%15.2f (min. expected: %.2f)""", precision(), expectedResults.precision(), recall(), expectedResults.recall(), f1(), expectedResults
.f1()));
outputBuilder.append(String.format(Locale.ENGLISH, """
\tAccuracy:%9.2f (min. expected: %.2f)
Expand All @@ -72,8 +73,8 @@ public String getExplicitResultString() {
\tTN:%15d
\tFN:%15d
\tP:%16d
\tN:%16d""", truePositives().size(), falsePositives().size(), trueNegatives(), falseNegatives().size(), truePositives().size() + falseNegatives().size(),
trueNegatives() + falsePositives().size());
\tN:%16d""", truePositives().size(), falsePositives().size(), trueNegatives(), falseNegatives().size(), truePositives()
.size() + falseNegatives().size(), trueNegatives() + falsePositives().size());
}

public ImmutableList<T> getFound() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,49 +1,64 @@
/* Licensed under MIT 2023-2024. */
package edu.kit.kastel.mcse.ardoco.core.tests.eval.results.calculator;

import java.util.List;

import org.eclipse.collections.api.factory.Sets;
import org.eclipse.collections.api.list.ImmutableList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import edu.kit.kastel.mcse.ardoco.core.tests.eval.results.EvaluationResults;
import edu.kit.kastel.mcse.ardoco.metrics.ClassificationMetricsCalculator;
import edu.kit.kastel.mcse.ardoco.metrics.result.AggregatedClassificationResult;
import edu.kit.kastel.mcse.ardoco.metrics.result.AggregationType;
import edu.kit.kastel.mcse.ardoco.metrics.result.SingleClassificationResult;

/**
* This utility class provides methods to form the average of several {@link EvaluationResults}
*/
public final class ResultCalculatorUtil {
private static final Logger logger = LoggerFactory.getLogger(ResultCalculatorUtil.class);

private ResultCalculatorUtil() {
throw new IllegalAccessError();
}

public static <T> EvaluationResults<T> calculateAverageResults(ImmutableList<EvaluationResults<T>> results) {
var calculator = ClassificationMetricsCalculator.getInstance();
var classifications = results.stream().map(EvaluationResults::classificationResult).toList();
var averages = getAverages(results);
if (averages == null)
return null;

var averages = calculator.calculateAverages(classifications, null);
var macroAverage = averages.stream().filter(it -> it.getType() == AggregationType.MACRO_AVERAGE).findFirst().orElseThrow();

var macroAverageAsSingle = new SingleClassificationResult<T>(Sets.mutable.empty(), Sets.mutable.empty(), Sets.mutable.empty(), null, macroAverage
.getPrecision(), macroAverage.getRecall(), macroAverage.getF1(), macroAverage.getAccuracy(), macroAverage.getSpecificity(), macroAverage
.getPhiCoefficient(), macroAverage.getPhiCoefficientMax(), macroAverage.getPhiOverPhiMax());

return new EvaluationResults<>(macroAverageAsSingle);
return evaluationResults(macroAverage);
}

public static <T> EvaluationResults<T> calculateWeightedAverageResults(ImmutableList<EvaluationResults<T>> results) {
var calculator = ClassificationMetricsCalculator.getInstance();
var classifications = results.stream().map(EvaluationResults::classificationResult).toList();
var averages = getAverages(results);
if (averages == null)
return null;

var averages = calculator.calculateAverages(classifications, null);
var macroAverage = averages.stream().filter(it -> it.getType() == AggregationType.WEIGHTED_AVERAGE).findFirst().orElseThrow();
return evaluationResults(macroAverage);
}

var weightedAverageAsSingle = new SingleClassificationResult<T>(Sets.mutable.empty(), Sets.mutable.empty(), Sets.mutable.empty(), null, macroAverage
.getPrecision(), macroAverage.getRecall(), macroAverage.getF1(), macroAverage.getAccuracy(), macroAverage.getSpecificity(), macroAverage
.getPhiCoefficient(), macroAverage.getPhiCoefficientMax(), macroAverage.getPhiOverPhiMax());
private static <T> EvaluationResults<T> evaluationResults(AggregatedClassificationResult average) {
var weightedAverageAsSingle = new SingleClassificationResult<T>(Sets.mutable.empty(), Sets.mutable.empty(), Sets.mutable.empty(), null, average
.getPrecision(), average.getRecall(), average.getF1(), average.getAccuracy(), average.getSpecificity(), average.getPhiCoefficient(), average
.getPhiCoefficientMax(), average.getPhiOverPhiMax());

return new EvaluationResults<>(weightedAverageAsSingle);
}

private static <T> List<AggregatedClassificationResult> getAverages(ImmutableList<EvaluationResults<T>> results) {
if (results.isEmpty()) {
logger.warn("No results to calculate average from");
return null;
}

var calculator = ClassificationMetricsCalculator.getInstance();
var classifications = results.stream().map(EvaluationResults::classificationResult).toList();

return calculator.calculateAverages(classifications, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,13 @@ private static Pair<MutableList<EvaluationResults<String>>, StringBuilder> inspe

private static void inspectRun(StringBuilder outputBuilder, StringBuilder detailedOutputBuilder, MutableList<EvaluationResults<String>> allResults,
ArDoCoResult arDoCoResult, EvaluationResults<String> result) {
var truePositives = result.truePositives().toList();
var truePositives = result.truePositives();
appendResults(truePositives, detailedOutputBuilder, "True Positives", arDoCoResult, outputBuilder);

var falsePositives = result.falsePositives().toList();
var falsePositives = result.falsePositives();
appendResults(falsePositives, detailedOutputBuilder, "False Positives", arDoCoResult, outputBuilder);

var falseNegatives = result.falseNegatives().toList();
var falseNegatives = result.falseNegatives();
appendResults(falseNegatives, detailedOutputBuilder, "False Negatives", arDoCoResult, outputBuilder);
allResults.add(result);
}
Expand Down

0 comments on commit 0ff067c

Please sign in to comment.