-
Notifications
You must be signed in to change notification settings - Fork 3
/
evaluate.java
86 lines (72 loc) · 2.84 KB
/
evaluate.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

import weka.classifiers.rules.JRip;
import weka.classifiers.rules.PART;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ConverterUtils.DataSource;
/**
 * Command-line program that labels every instance in {@code final.arff} with an ensemble of
 * three pre-trained Weka models (J48, RandomForest and JRip).
 *
 * <p>Before classifying, lines 3..1068 (1-based, inclusive) of {@code final.arff} are
 * overwritten in place with the first lines of {@code ./headers.txt}. All paths are relative
 * to the working directory.
 */
public class evaluate {

    /** 1-based line number in final.arff of the first header line to overwrite. */
    private static final int FIRST_HEADER_LINE = 3;

    /** 1-based line number in final.arff of the last header line to overwrite (inclusive). */
    private static final int LAST_HEADER_LINE = 1068;

    /**
     * Loads a dataset through Weka's {@code DataSource}. If the file does not declare a class
     * attribute, the last attribute is used as the class.
     *
     * @param path path of the dataset file
     * @return the loaded instances, with a class index set
     * @throws Exception if Weka cannot read the file
     */
    private static Instances getDataFromFile(String path) throws Exception {
        DataSource source = new DataSource(path);
        Instances data = source.getDataSet();
        if (data.classIndex() == -1) {
            // No class attribute declared: use the last attribute as the class.
            data.setClassIndex(data.numAttributes() - 1);
        }
        return data;
    }

    /**
     * Rounds {@code d} to the nearest whole number. Exact halves are rounded toward zero
     * ({@link RoundingMode#HALF_DOWN}), so 2.5 becomes 2.0 and -2.5 becomes -2.0.
     *
     * <p>The exact binary value of {@code d} is used (via {@code new BigDecimal(double)}), not
     * its decimal string form.
     *
     * @param d value to round
     * @return the rounded value
     * @throws NumberFormatException if {@code d} is NaN or infinite
     */
    public static double roundHalfDown(double d) {
        return new BigDecimal(d).setScale(0, RoundingMode.HALF_DOWN)
                .doubleValue();
    }

    /**
     * Patches the ARFF header, classifies every instance with the three models and prints the
     * ensemble's predicted class label for each one.
     *
     * @param args ignored
     * @throws Exception if a model, the header file or the dataset cannot be loaded
     */
    public static void main(String[] args) throws Exception {
        // Load the pre-trained models first, so that a missing model aborts the run before
        // final.arff is modified.
        RandomForest rf = (RandomForest) SerializationHelper.read("./models/models_nominal/random_forest.model");
        J48 j48 = (J48) SerializationHelper.read("./models/models_nominal/j48.model");
        JRip jrip = (JRip) SerializationHelper.read("./models/models_nominal/jrip.model");

        // Replacement header lines, read as UTF-8 to match how final.arff is read and written.
        List<String> headers = Files.readAllLines(Paths.get("./headers.txt"), StandardCharsets.UTF_8);
        rewriteHeader(Paths.get("final.arff"), headers);

        Instances testingData = getDataFromFile("final.arff");
        int total = testingData.numInstances();
        for (int i = 0; i < total; i++) {
            // classifyInstance yields the index of the predicted class value; Weka uses a
            // missing value (NaN) when a model makes no classification.
            double byJ48 = j48.classifyInstance(testingData.instance(i));
            double byRf = rf.classifyInstance(testingData.instance(i));
            double byJRip = jrip.classifyInstance(testingData.instance(i));

            int classIndex = vote(byJ48, byRf, byJRip);
            // "?" is Weka's textual representation of a missing value.
            String prediction = classIndex < 0 ? "?" : testingData.classAttribute().value(classIndex);
            System.out.println("The predicted value of instance " + i + ": " + prediction);
        }
    }

    /**
     * Overwrites lines {@link #FIRST_HEADER_LINE}..{@link #LAST_HEADER_LINE} of {@code arff}
     * with the leading lines of {@code headers}. The file is written once, after every
     * replacement has been applied.
     *
     * @param arff file to patch in place (read and written as UTF-8)
     * @param headers replacement lines; only the first
     *     {@code LAST_HEADER_LINE - FIRST_HEADER_LINE + 1} are used
     * @throws IllegalStateException if {@code headers} or the file has too few lines
     * @throws IOException if the file cannot be read or written
     */
    private static void rewriteHeader(Path arff, List<String> headers) throws IOException {
        int count = LAST_HEADER_LINE - FIRST_HEADER_LINE + 1;
        if (headers.size() < count) {
            throw new IllegalStateException(
                    "Expected at least " + count + " header lines but got " + headers.size());
        }
        List<String> lines = Files.readAllLines(arff, StandardCharsets.UTF_8);
        if (lines.size() < LAST_HEADER_LINE) {
            throw new IllegalStateException(
                    arff + " has " + lines.size() + " lines; expected at least " + LAST_HEADER_LINE);
        }
        for (int k = 0; k < count; k++) {
            lines.set(FIRST_HEADER_LINE - 1 + k, headers.get(k));
        }
        // A single write avoids rewriting the whole file per line and avoids leaving a
        // half-patched file behind if something fails mid-way.
        Files.write(arff, lines, StandardCharsets.UTF_8);
    }

    /**
     * Combines class-index predictions by majority vote.
     *
     * <p>Missing votes (NaN) are ignored. A class that receives more than half of the
     * remaining votes wins. Otherwise the result is the {@link #roundHalfDown rounded} mean of
     * the remaining votes. For a two-valued class with every model answering, this gives the
     * same result as averaging and rounding.
     *
     * @param predictions class indices predicted by the individual models
     * @return the chosen class index, or -1 if every vote is missing
     */
    private static int vote(double... predictions) {
        Map<Integer, Integer> counts = new HashMap<>();
        double sum = 0;
        int valid = 0;
        for (double p : predictions) {
            if (Double.isNaN(p)) {
                continue;
            }
            Integer seen = counts.get((int) p);
            counts.put((int) p, seen == null ? 1 : seen + 1);
            sum += p;
            valid++;
        }
        if (valid == 0) {
            return -1;
        }
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            if (2 * entry.getValue() > valid) {
                return entry.getKey();
            }
        }
        // No strict majority (e.g. all three models disagree): fall back to the rounded mean.
        return (int) roundHalfDown(sum / valid);
    }
}