-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathperceplearn3.py
124 lines (106 loc) · 4.32 KB
/
perceplearn3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import sys;
import math;
import json;
import string;
import random;
from collections import defaultdict;
# Module-level training state, filled by constructTokens().
# Each row of sentenceFeatureClassList is:
#   [doc_id, class1_label (+1/-1), class2_label (+1/-1), word->count map]
sentenceFeatureClassList = [];
# Set of every non-stopword token seen in the training data.
vocabSet=set();
def readFile(fileName):
    """Open *fileName* for reading as UTF-8 and return the open file object.

    The caller is responsible for closing the returned handle.
    """
    return open(fileName, encoding="utf8")
def saveToFile(fileName, class1WeightMap, biasClass1, class2WeightMap, biasClass2):
    """Persist two trained models to *fileName* as four JSON lines.

    Line 1: class-1 weight map, line 2: class-1 bias,
    line 3: class-2 weight map, line 4: class-2 bias (no trailing newline,
    matching the format the companion classifier script reads back).
    """
    # 'with' guarantees the handle is closed even if a write raises;
    # the original leaked the descriptor on any exception before close().
    with open(fileName, "w") as f:
        f.write(json.dumps(class1WeightMap))
        f.write("\n")
        f.write(json.dumps(biasClass1))
        f.write("\n")
        f.write(json.dumps(class2WeightMap))
        f.write("\n")
        f.write(json.dumps(biasClass2))
# Translation table built once at import time instead of on every call:
# str.maketrans is loop-invariant here, since the same table (delete all
# ASCII punctuation) is used for every line of the training file.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)

def removePunctuation(line):
    """Return *line* with every ASCII punctuation character removed."""
    return line.translate(_PUNCT_TABLE)
def constructTokens(inputFileObj):
    """Parse the training file into sentenceFeatureClassList and vocabSet.

    Each line is expected as: "<id> <Fake|True> <Pos|Neg> <words...>".
    For every line a row [id, class1, class2, wordCountMap] is appended to
    the module-level sentenceFeatureClassList, where class1 is -1 for
    "Fake" (+1 otherwise) and class2 is -1 for "Neg" (+1 otherwise).
    Lowercased tokens not in the module-level stopList are counted in
    wordCountMap and added to vocabSet.
    """
    for line in inputFileObj:
        doc = removePunctuation(line).rstrip().split(" ")
        # Guard: blank or malformed lines previously raised IndexError.
        if len(doc) < 3:
            continue
        # Labels are encoded as +/-1 so perceptron updates can use y directly.
        class1 = -1 if doc[1] == "Fake" else 1
        class2 = -1 if doc[2] == "Neg" else 1
        wordCountMap = defaultdict(int)
        for word in doc[3:]:
            word = word.lower()
            if word in stopList:  # stopList is loaded at module level
                continue
            vocabSet.add(word)
            wordCountMap[word] += 1
        sentenceFeatureClassList.append([doc[0], class1, class2, wordCountMap])
def trainAvgPerceptron(classNo, epochs=21):
    """Train an averaged perceptron for the label stored at doc[classNo].

    Uses the standard averaging trick (Daume, CIML ch. 4): alongside the
    working weights w it accumulates a cache u of counter-scaled updates,
    so the averaged weights are recoverable at the end as w - u/c without
    storing every intermediate weight vector.

    classNo: index of the +/-1 label inside each row of the module-level
        sentenceFeatureClassList (1 = Fake/True task, 2 = Pos/Neg task).
    epochs: passes over the shuffled data (default 21, previously
        hard-coded).

    Returns (averagedWeightMap, averagedBias).
    """
    classWeightMap = defaultdict(int)
    cachedWeightMap = defaultdict(int)
    bias = 0
    cachedBias = 0
    c = 1  # example counter: the "timestamp" used by the averaging trick
    for iteration in range(epochs):
        random.shuffle(sentenceFeatureClassList)
        for doc in sentenceFeatureClassList:
            activation = bias
            for feature, value in doc[3].items():
                activation += value * classWeightMap.get(feature, 0)
            y = doc[classNo]
            if y * activation <= 0:  # mistake (or zero margin): update
                for feature, value in doc[3].items():
                    classWeightMap[feature] = classWeightMap.get(feature, 0) + y * value
                    # BUG FIX: the cache must *accumulate* counter-scaled
                    # updates (u += y*c*x). It was previously overwritten
                    # with w + y*x each time, which corrupts the averaged
                    # weights. The bias cache below already followed the
                    # correct y*c scheme.
                    cachedWeightMap[feature] += y * c * value
                bias += y
                cachedBias += y * c
            c += 1
    # Recover the average: w_avg = w - u/c (same for the bias).
    for key in classWeightMap:
        classWeightMap[key] -= cachedWeightMap[key] / c
    return classWeightMap, bias - cachedBias / c
def trainVanillaPerceptron(classNo, epochs=21):
    """Train a plain (non-averaged) perceptron for the label at doc[classNo].

    classNo: index of the +/-1 label inside each row of the module-level
        sentenceFeatureClassList (1 = Fake/True task, 2 = Pos/Neg task).
    epochs: passes over the shuffled data (default 21, previously
        hard-coded; now a parameter for consistency and tuning).

    Returns (weightMap, bias).
    """
    classWeightMap = defaultdict(int)
    bias = 0
    for iteration in range(epochs):
        random.shuffle(sentenceFeatureClassList)
        for doc in sentenceFeatureClassList:
            # activation = w . x + b over the sparse word-count features
            activation = bias
            for feature, value in doc[3].items():
                activation += value * classWeightMap.get(feature, 0)
            y = doc[classNo]
            if y * activation <= 0:  # mistake (or zero margin): update
                for feature, value in doc[3].items():
                    classWeightMap[feature] = classWeightMap.get(feature, 0) + y * value
                bias += y
    return classWeightMap, bias
# --- Script driver -------------------------------------------------------
# argv[1]: training file, one document per line:
#   "<id> <Fake|True> <Pos|Neg> <words...>"
inputFileObj = readFile(sys.argv[1]);
# Stopword list: one word per line; filtered out during tokenisation.
stopList = [line.rstrip() for line in open("input/stop-words.txt", encoding="utf8")];
constructTokens(inputFileObj);
# Train one binary classifier per task: classNo 1 = Fake/True, 2 = Pos/Neg.
class1WeightMap,biasClass1=trainVanillaPerceptron(1);
class2WeightMap,biasClass2=trainVanillaPerceptron(2);
class1AvgWeightMap,biasAvgClass1=trainAvgPerceptron(1);
class2AvgWeightMap,biasAvgClass2=trainAvgPerceptron(2);
# Each model file holds 4 JSON lines: weights1, bias1, weights2, bias2.
saveToFile("averagedmodel.txt",class1AvgWeightMap,biasAvgClass1,class2AvgWeightMap,biasAvgClass2);
saveToFile("vanillamodel.txt",class1WeightMap,biasClass1,class2WeightMap,biasClass2);