-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree.lua
330 lines (255 loc) · 9.33 KB
/
tree.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
-- extend package.path with the directory containing this .lua file
local filepath = debug.getinfo(1).source:match("@(.*)$")
local dir = string.gsub(filepath, '/[^/]+$', '') .. "/"
package.path = dir .. "/?.lua;" .. package.path
local math = require("math")
local ffi = require("ffi")
local bitop = require("bit")
require("luarocks.loader")
local array = require("ljarray.array")
local helpers = require("ljarray.helpers")
local operator = helpers.operator
local criteria = require("criteria")
local Partition = require("partition")
--module(..., package.seeall) -- export all local functions
local Tree = {}
Tree.__index = Tree

--- Create an untrained decision tree.
-- @param options table with fields:
--   n_classes   (number, required) number of classes; stored as tree.n_classes
--   max_depth   (number, default -1) stored as tree.max_depth
--   criterion   (criterion class with a create(n_classes) constructor,
--                default criteria.Gini)
--   m_try       (number, optional) features tried per split; when nil,
--                Tree.learn picks floor(sqrt(n_features))
--   f_subsample (number, default 0.66) fraction of rows used to grow the tree
-- @return a Tree instance with node storage for 2048 nodes (grown on demand)
Tree.create = function(options)
    -- treat a missing options table like an empty one so the assert below
    -- reports the real problem instead of "attempt to index a nil value"
    options = options or {}
    assert(type(options.n_classes) == "number",
        "Tree: options.n_classes must be set and a number, got " .. type(options.n_classes))
    local tree = setmetatable({}, Tree)
    tree.n_classes = options.n_classes
    tree.max_depth = options.max_depth or -1
    tree.criterion_class = options.criterion or criteria.Gini
    tree.m_try = options.m_try -- default m_try is determined at learning time
    tree.f_subsample = options.f_subsample or 0.66
    -- per-node storage: (left, right) child ids, split feature, threshold
    tree.children = array.create({2048, 2})
    tree.split_feature = array.create({2048})
    tree.split_threshold = array.create({2048})
    tree.node_count = 0
    return tree
end
--- Train the tree on samples X with labels y.
-- Records the feature count, fills in the default m_try
-- (floor(sqrt(n_features))) when none was given, instantiates the criterion
-- and the feature index list, then grows the tree via _build.
-- @param X 2-d array, one row per sample
-- @param y 1-d array of class labels, one per row of X
Tree.learn = function(self, X, y)
    assert(X.ndim == 2)
    assert(y.ndim == 1)
    assert(X.shape[0] == y.shape[0])
    local n_features = X.shape[1]
    self.n_features = n_features
    if not self.m_try then
        self.m_try = math.floor(math.sqrt(n_features))
    end
    self.criterion = self.criterion_class.create(self.n_classes)
    self._feature_indices = array.arange(n_features)
    self:_build(X, y)
end
--- Per-row training class counts of the leaf each row of X falls into.
-- @param X 2-d array whose column count matches the training data
-- @return rows of leaf_class_count selected by each row's leaf node id
Tree._predict_class_counts = function(self, X)
    assert(X.ndim == 2)
    assert(X.shape[1] == self.n_features)
    return self.leaf_class_count:take_slices(self:_predict_leaf_nodes(X), 0)
end
--- Predict a class index for every row of X.
-- Each row gets the class with the largest training count in its leaf;
-- ties resolve to the smallest class index. Classes 0..n_classes are scanned.
-- @param X 2-d array whose column count matches the training data
-- @return int32 array with one predicted class index per row
Tree.predict = function(self, X)
    assert(X.ndim == 2)
    assert(X.shape[1] == self.n_features)
    local counts = self:_predict_class_counts(X)
    local n_rows = X.shape[0]
    local labels = array.create({n_rows}, array.int32)
    for row = 0, n_rows - 1 do
        local winner = 0
        for c = 1, self.n_classes do
            -- strict '>' keeps the earlier (smaller) class on ties
            if counts:get(row, c) > counts:get(row, winner) then
                winner = c
            end
        end
        labels:set(row, winner)
    end
    return labels
end
--- Route every row of X from the root to a leaf and record the leaf id.
-- At each split node the row goes left when its value of the node's split
-- feature is <= the node's threshold, right otherwise; descent stops at a
-- node whose split_feature is -1 (a leaf).
-- @param X 2-d array, one row per sample
-- @return int32 array of leaf node ids, one per row of X
Tree._predict_leaf_nodes = function(self, X)
    local n_rows = X.shape[0]
    local leaf_node_indices = array.create({n_rows}, array.int32)
    local split_feature = self.split_feature
    local split_threshold = self.split_threshold
    local children = self.children
    for i = 0, n_rows - 1 do
        local node = 0
        local f = split_feature:get(node)
        -- A root that is already a leaf (e.g. pure or single-sample training
        -- data) must be returned as-is: descending from it would read
        -- feature -1 and uninitialised child slots.
        while f ~= -1 do
            if X:get(i, f) <= split_threshold:get(node) then
                node = children:get(node, 0) -- left branch
            else
                node = children:get(node, 1) -- right branch
            end
            f = split_feature:get(node)
        end
        leaf_node_indices:set(i, node)
    end
    return leaf_node_indices
end
--- Count, for every node, how many rows of X with each label land in it.
-- The table has node_count+1 rows and n_classes+1 columns (int32, zeroed)
-- and is stored as self.leaf_class_count.
-- @param X 2-d array, one row per sample
-- @param y 1-d array of class labels, one per row of X
Tree._setup_leaf_nodes = function(self, X, y)
    assert(X.shape[0] == y.shape[0])
    local counts = array.create({self.node_count + 1, self.n_classes + 1}, array.int32)
    counts:assign(0)
    local leaves = self:_predict_leaf_nodes(X)
    for row = 0, X.shape[0] - 1 do
        local leaf, label = leaves:get(row), y:get(row)
        counts:set(leaf, label, counts:get(leaf, label) + 1)
    end
    self.leaf_class_count = counts
end
--- Grow the subtree for one partition of the training samples.
-- Adds a split node below parent_node and recurses into both child
-- partitions when a valid split exists; otherwise adds a leaf node.
-- @param tree the Tree being built (method receiver)
-- @param parent_node id of the parent node, or -1 when adding the root
-- @param is_left_child true if the new node is its parent's left child
-- @param partition_number id of the sample partition in tree.partition
-- @param depth optional depth of the new node (root = 0, the default);
--        nodes at depth >= tree.max_depth become leaves unless max_depth
--        is negative
Tree._recursive_partition = function(tree, parent_node, is_left_child, partition_number, depth)
    depth = depth or 0
    local max_depth = tree.max_depth
    -- honour options.max_depth; a negative value means "no limit"
    if max_depth ~= nil and max_depth >= 0 and depth >= max_depth then
        tree:_add_leaf_node(parent_node, is_left_child)
        return
    end
    -- only split further if more than 1 sample
    if tree.partition:size(partition_number) > 1 then
        local best_split_feat, best_split_val, best_split_x, best_split_pos = tree:_find_best_split(partition_number)
        if best_split_feat ~= -1 then
            -- found a split: add split node, then grow both halves
            local node = tree:_add_split_node(parent_node, is_left_child, best_split_feat, best_split_x)
            local left_part, right_part = tree.partition:split(partition_number, best_split_feat, best_split_pos)
            tree:_recursive_partition(node, true, left_part, depth + 1)
            tree:_recursive_partition(node, false, right_part, depth + 1)
        else
            -- no split found: diagnostic print when the criterion value
            -- reported for this partition is still positive
            if best_split_val > 0 then
                print("best_split_val", best_split_val, tree.partition:size(partition_number))
            end
            tree:_add_leaf_node(parent_node, is_left_child)
        end
    else
        -- partition with at most one sample: nothing to split
        tree:_add_leaf_node(parent_node, is_left_child)
    end
end
--- Grow the tree from training data X, y.
-- Marks a random subset of rows as in-bag, partitions them recursively,
-- trims node storage to the nodes used and fills the leaf class counts
-- (from all rows of X via _setup_leaf_nodes).
-- @param X 2-d array, one row per sample
-- @param y 1-d array of class labels, one per row of X
Tree._build = function(self, X, y)
    assert(X.ndim == 2)
    assert(y.ndim == 1)
    assert(X.shape[0] == y.shape[0])
    self.node_count = 0
    local crit = self.criterion_class.create(self.n_classes) -- construct criterion
    -- create sample mask: entries 0..floor((n-1)*f_subsample) are in-bag (1),
    -- the rest 0, then the mask is shuffled. The cutoff must be an integer:
    -- fractional loop bounds visit non-integer indices, leave the last slot
    -- unwritten, and overrun the mask when f_subsample > 1.
    local n_samples = y.shape[0]
    local last_in_bag = math.floor((n_samples - 1) * self.f_subsample)
    local subsample_mask = array.create({n_samples}, array.int8)
    for i = 0, n_samples - 1 do
        subsample_mask.data[i] = (i <= last_in_bag) and 1 or 0
    end
    subsample_mask:permute()
    self.partition = Partition.create(X, y, subsample_mask)
    self.X = X
    self.y = y
    self.temp_y = array.create({self.y.shape[0]}, array.int32)
    self.crit = crit
    print("partitioning..")
    self:_recursive_partition(-1, nil, 0)
    print("done")
    -- trim array size to the number of nodes actually created
    self:_resize(self.node_count)
    -- setup leaf nodes with class counts
    print("setting up leaf nodes..")
    self:_setup_leaf_nodes(X, y)
    print("done")
    print("FINISHED")
end
--- Find the best split of a partition over a random subset of features.
-- Shuffles the feature indices and tries the first m_try of them; if none
-- of those yields a valid split, keeps trying further features until one
-- does or all features are exhausted.
-- @param partition_number id of the partition in self.partition
-- @return best feature (-1 if none), best criterion value seen, split
--         threshold, and split position relative to the partition start
Tree._find_best_split = function(self, partition_number)
    assert(self._feature_indices.strides[0] == 1)
    self._feature_indices:permute()
    local n_features = self.X.shape[1]
    -- never try more features than exist: indexing past the end of
    -- _feature_indices would read out-of-bounds memory as a feature id
    local m_try = math.min(self.m_try, n_features)
    local best_split_val = 1e6
    local best_split_pos = 0
    local best_split_feat = -1
    local best_split_x = 0
    local i = 0
    while i < m_try or (i < n_features and best_split_feat == -1) do
        local f = self._feature_indices.data[i] -- feature index to try
        local split_pos, split_val, split_x = self:_find_best_split_for_feat(partition_number, f, self.crit)
        if split_val < best_split_val then
            -- only a valid split position (not -1) may become the chosen
            -- split, but the best value is tracked either way
            if split_pos ~= -1 then
                best_split_pos = split_pos
                best_split_x = split_x
                best_split_feat = f
            end
            best_split_val = split_val
        end
        i = i + 1
    end
    return best_split_feat, best_split_val, best_split_x, best_split_pos
end
--- Find the split position of one feature that minimises the criterion.
-- Sorts the partition by feature f, then scans left to right, moving one
-- sample at a time to the left side and evaluating the criterion; splits
-- are only considered between distinct feature values.
-- @param partition id of the partition in self.partition
-- @param f feature (column) index
-- @param crit criterion object (init / move / eval)
-- @return split position relative to the partition start (-1 if no split
--         beats the criterion value of the initial state), the best
--         criterion value, and the threshold midway between the two
--         neighbouring feature values (0 when no valid split exists)
Tree._find_best_split_for_feat = function(self, partition, f, crit)
    local argsort = self.partition:sort(partition, f)
    local start, stop = self.partition:range(partition)
    -- assumes partition:sort leaves the sorted values of feature f in
    -- partition.values[start, stop) (the assert below checks the order)
    local x = self.partition.values
    local y = self.temp_y
    local labels = self.partition.y
    for i = start, stop - 1 do
        y.data[i] = labels.data[argsort.data[i]]
    end
    crit:init(y, start, stop)
    local best_split_pos = -1 -- position of split, -1 means invalid
    local best_split_val = crit:eval() -- value of criterion that is minimized
    for i = start, stop - 2 do
        crit:move(1)
        assert(x.data[i] <= x.data[i + 1])
        -- splits can only happen between different feature values
        if x.data[i + 1] ~= x.data[i] then
            local crit_val = crit:eval()
            if crit_val < best_split_val then
                best_split_pos = i - start
                best_split_val = crit_val
            end
        end
    end
    -- only compute a threshold for a valid split: with best_split_pos == -1
    -- the old formula read x.data[start - 1], which is out of bounds at 0
    local best_split_x = 0
    if best_split_pos ~= -1 then
        local left = start + best_split_pos
        best_split_x = (x.data[left] + x.data[left + 1]) / 2
    end
    return best_split_pos, best_split_val, best_split_x
end
--- Append a node and link it under its parent.
-- Storage is doubled when fewer than 2 free slots would remain.
-- @param parent id of the parent node, or -1 for the root (no link made)
-- @param is_left_child truthy to link as the parent's left child (column 0),
--        otherwise as its right child (column 1)
-- @param feature split feature index (-1 marks a leaf)
-- @param threshold split threshold
-- @return id of the new node
Tree._add_split_node = function(self, parent, is_left_child, feature, threshold)
    assert(parent >= -1)
    assert(parent <= self.node_count)
    local new_id = self.node_count
    self.node_count = new_id + 1
    if self.node_count >= self.children.shape[0] - 2 then
        self:_resize(self.node_count * 2)
    end
    assert(self.children.shape[0] > new_id)
    assert(self.split_feature.shape[0] > new_id)
    assert(self.split_threshold.shape[0] > new_id)
    if parent >= 0 then
        -- 0 is truthy in Lua, so this and/or form is safe
        local side = is_left_child and 0 or 1
        self.children:set(parent, side, new_id)
    end
    self.split_feature:set(new_id, feature)
    self.split_threshold:set(new_id, threshold)
    return new_id
end
--- Append a leaf node under parent.
-- Leaves are ordinary nodes whose split_feature is -1 and threshold is 0.
-- @return id of the new leaf
Tree._add_leaf_node = function(self, parent, is_left_child)
    local leaf_marker = -1
    return self:_add_split_node(parent, is_left_child, leaf_marker, 0)
end
--- Resize all per-node storage arrays to hold `size` nodes.
-- children keeps two columns (left, right); the others one value per node.
Tree._resize = function(self, size)
    local children = self.children
    children:resize({size, 2})
    local features = self.split_feature
    features:resize({size})
    local thresholds = self.split_threshold
    thresholds:resize({size})
end

return Tree