# Trees
**Chapter Status:** This chapter was originally written using the `tree` package. It is currently being rewritten to exclusively use the `rpart` package, which seems more widely suggested and provides better plotting features.
```{r tree_opts, include = FALSE}
knitr::opts_chunk$set(cache = TRUE, autodep = TRUE, fig.align = "center")
```
```{r, message = FALSE, warning = FALSE}
library(tree)
```
In this document, we will use the package `tree` for both classification and regression trees. Note that there are many packages to do this in `R`. [`rpart`](https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf) may be the most common; however, we will use `tree` for simplicity.
## Classification Trees
```{r}
library(ISLR)
```
To understand classification trees, we will use the `Carseats` dataset from the `ISLR` package. We will first convert the response variable `Sales` from a numerical variable to a categorical variable, with `High` for high sales (above 8) and `Low` for low sales (8 or below).
```{r}
data(Carseats)
#?Carseats
str(Carseats)
Carseats$Sales = as.factor(ifelse(Carseats$Sales <= 8, "Low", "High"))
str(Carseats)
```
We first fit an unpruned classification tree using all of the predictors. Details of this process can be found using `?tree` and `?tree.control`.
```{r}
seat_tree = tree(Sales ~ ., data = Carseats)
# seat_tree = tree(Sales ~ ., data = Carseats,
# control = tree.control(nobs = nrow(Carseats), minsize = 10))
summary(seat_tree)
```
We see this tree has 27 terminal nodes and a training misclassification rate of 0.09.
```{r, fig.height = 12, fig.width = 24}
plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Unpruned Classification Tree")
```
Above we plot the tree. Below we output the details of the splits.
```{r}
seat_tree
```
We now test-train split the data so we can evaluate how well our tree is working. We use 200 observations for each.
```{r}
dim(Carseats)
set.seed(2)
seat_idx = sample(1:nrow(Carseats), 200)
seat_trn = Carseats[seat_idx,]
seat_tst = Carseats[-seat_idx,]
```
```{r}
seat_tree = tree(Sales ~ ., data = seat_trn)
```
```{r}
summary(seat_tree)
```
Note that the tree is not using all of the available variables.
```{r}
summary(seat_tree)$used
names(Carseats)[which(!(names(Carseats) %in% summary(seat_tree)$used))]
```
Also notice that this new tree is slightly different from the tree fit to all of the data.
```{r, fig.height=12, fig.width=24}
plot(seat_tree)
text(seat_tree, pretty = 0)
title(main = "Unpruned Classification Tree")
```
When using the `predict()` function on a tree, the default `type` is `"vector"`, which gives predicted probabilities for both classes. We will use `type = "class"` to directly obtain classes. We first fit the tree using the training data (above), then obtain predictions on both the train and test sets, then view the confusion matrix for each.
```{r}
seat_trn_pred = predict(seat_tree, seat_trn, type = "class")
seat_tst_pred = predict(seat_tree, seat_tst, type = "class")
#predict(seat_tree, seat_trn, type = "vector")
#predict(seat_tree, seat_tst, type = "vector")
```
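To make the difference between the two prediction types concrete, here is a minimal sketch of turning class probabilities into class labels by hand. The probability matrix is made up for illustration and is only shaped like the output of `type = "vector"`; sending ties to the first column and the 0.8 cutoff are arbitrary choices for this sketch, not a description of what `predict()` does internally.

```{r}
# hypothetical class probabilities (one row per observation, one column per class)
toy_prob = matrix(c(0.9, 0.1,
                    0.3, 0.7,
                    0.5, 0.5),
                  ncol = 2, byrow = TRUE,
                  dimnames = list(NULL, c("High", "Low")))

# label each row with the class that has the largest probability
# (ties go to the first column, an arbitrary choice for this sketch)
toy_class = factor(colnames(toy_prob)[max.col(toy_prob, ties.method = "first")],
                   levels = colnames(toy_prob))
toy_class

# probabilities also allow a custom cutoff, here 0.8 for "High"
toy_class_strict = factor(ifelse(toy_prob[, "High"] > 0.8, "High", "Low"),
                          levels = colnames(toy_prob))
toy_class_strict
```

Keeping the probabilities is useful when the two kinds of error have different costs, since the cutoff can then be moved away from 0.5.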
```{r}
# train confusion
table(predicted = seat_trn_pred, actual = seat_trn$Sales)
# test confusion
table(predicted = seat_tst_pred, actual = seat_tst$Sales)
```
```{r}
accuracy = function(actual, predicted) {
  mean(actual == predicted)
}
```
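The `accuracy()` helper is tied directly to the confusion matrices: correct predictions sit on the diagonal of the table, so accuracy is the diagonal sum divided by the total count. A small sketch with made-up labels (not the `Carseats` predictions) checks that both routes agree.

```{r}
accuracy = function(actual, predicted) {
  mean(actual == predicted)
}

# made-up labels for illustration only
toy_actual    = factor(c("High", "High", "Low", "Low", "Low",  "High"))
toy_predicted = factor(c("High", "Low",  "Low", "Low", "High", "High"))

toy_confusion = table(predicted = toy_predicted, actual = toy_actual)
toy_confusion

# correct predictions are on the diagonal of the confusion matrix
acc_from_table = sum(diag(toy_confusion)) / sum(toy_confusion)
acc_direct     = accuracy(actual = toy_actual, predicted = toy_predicted)
c(from_table = acc_from_table, direct = acc_direct)
```

The diagonal shortcut assumes both factors share the same levels in the same order, so that the table is square with matching row and column labels.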
```{r}
# train acc
accuracy(predicted = seat_trn_pred, actual = seat_trn$Sales)
# test acc
accuracy(predicted = seat_tst_pred, actual = seat_tst$Sales)
```
Here it is easy to see that the tree has been over-fit. It performs much better on the train set than on the test set.
We will now use cross-validation to find a tree by considering trees of different sizes which have been pruned from our original tree.
```{r}
set.seed(3)
seat_tree_cv = cv.tree(seat_tree, FUN = prune.misclass)
```
```{r}
# index of tree with minimum error
min_idx = which.min(seat_tree_cv$dev)
min_idx
```
```{r}
# number of terminal nodes in that tree
seat_tree_cv$size[min_idx]
```
```{r}
# misclassification rate of each tree
seat_tree_cv$dev / length(seat_idx)
```
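One detail worth keeping in mind: `which.min()` returns the first minimum only. If several tree sizes tie for the lowest cross-validated error, we may prefer the smallest of them. The sketch below uses made-up `size` and `dev` vectors (hypothetical values, not output from the fit above), assuming sizes are listed from largest to smallest.

```{r}
# hypothetical CV results, shaped like the `size` and `dev` components
toy_cv_size = c(12, 9, 6, 4, 2, 1)
toy_cv_dev  = c(60, 55, 55, 58, 70, 85)

# which.min() stops at the first minimum, here the larger of the tied trees
first_min_size = toy_cv_size[which.min(toy_cv_dev)]
first_min_size

# smallest tree among all sizes that achieve the minimum error
smallest_min_size = min(toy_cv_size[toy_cv_dev == min(toy_cv_dev)])
smallest_min_size
```

When there are no ties, both approaches select the same size.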
```{r}
par(mfrow = c(1, 2))
# default plot
plot(seat_tree_cv)
# better plot
plot(seat_tree_cv$size, seat_tree_cv$dev / nrow(seat_trn), type = "b",
     xlab = "Tree Size", ylab = "CV Misclassification Rate")
```
It appears that a tree of size 9 has the fewest misclassifications of the considered trees, via cross-validation.
We use `prune.misclass()` to obtain that tree from our original tree, and plot this smaller tree.
```{r}
seat_tree_prune = prune.misclass(seat_tree, best = 9)
summary(seat_tree_prune)
```
```{r, fig.height=8, fig.width=12}
plot(seat_tree_prune)
text(seat_tree_prune, pretty = 0)
title(main = "Pruned Classification Tree")
```
We again obtain predictions using this smaller tree, and evaluate on the test and train sets.
```{r}
# train
seat_prune_trn_pred = predict(seat_tree_prune, seat_trn, type = "class")
table(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
accuracy(predicted = seat_prune_trn_pred, actual = seat_trn$Sales)
```
```{r}
# test
seat_prune_tst_pred = predict(seat_tree_prune, seat_tst, type = "class")
table(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
accuracy(predicted = seat_prune_tst_pred, actual = seat_tst$Sales)
```
Performance on the train set is almost as good as before, and there was a **small** improvement on the test set, but it is still obvious that we have over-fit. Trees tend to do this. We will look at several ways to fix this, including bagging, boosting, and random forests.
## Regression Trees
To demonstrate regression trees, we will use the `Boston` data. Recall `medv` is the response. We first split the data in half.
```{r}
library(MASS)
set.seed(18)
boston_idx = sample(1:nrow(Boston), nrow(Boston) / 2)
boston_trn = Boston[boston_idx,]
boston_tst = Boston[-boston_idx,]
```
Then fit an unpruned regression tree to the training data.
```{r}
boston_tree = tree(medv ~ ., data = boston_trn)
summary(boston_tree)
```
```{r, fig.height=8, fig.width=12}
plot(boston_tree)
text(boston_tree, pretty = 0)
title(main = "Unpruned Regression Tree")
```
As with classification trees, we can use cross-validation to select a good pruning of the tree.
```{r}
set.seed(18)
boston_tree_cv = cv.tree(boston_tree)
plot(boston_tree_cv$size, sqrt(boston_tree_cv$dev / nrow(boston_trn)), type = "b",
     xlab = "Tree Size", ylab = "CV-RMSE")
```
While the tree of size 9 does have the lowest RMSE, we'll prune to a size of 7 as it seems to perform just as well. (Otherwise we would not be pruning.) The pruned tree is, as expected, smaller and easier to interpret.
```{r}
boston_tree_prune = prune.tree(boston_tree, best = 7)
summary(boston_tree_prune)
```
```{r, fig.height=8, fig.width=12}
plot(boston_tree_prune)
text(boston_tree_prune, pretty = 0)
title(main = "Pruned Regression Tree")
```
Let's compare this regression tree to an additive linear model and use RMSE as our metric.
```{r}
rmse = function(actual, predicted) {
  sqrt(mean((actual - predicted) ^ 2))
}
```
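As a quick check of the `rmse()` helper, here is a tiny worked example with made-up numbers. It also shows that the metric is symmetric in its two arguments, since the errors are squared.

```{r}
rmse = function(actual, predicted) {
  sqrt(mean((actual - predicted) ^ 2))
}

# made-up values for illustration only
rmse_toy_actual    = c(3, 5, 7)
rmse_toy_predicted = c(2, 5, 9)

# errors are 1, 0, -2, so the mean squared error is (1 + 0 + 4) / 3
toy_rmse = rmse(actual = rmse_toy_actual, predicted = rmse_toy_predicted)
toy_rmse

# squaring removes the sign, so swapping the arguments gives the same value
toy_rmse_swapped = rmse(actual = rmse_toy_predicted, predicted = rmse_toy_actual)
toy_rmse_swapped
```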
We obtain predictions on the train and test sets from the pruned tree. We also plot actual vs predicted. This plot may look odd. We'll compare it to a plot for linear regression below.
```{r}
# training RMSE two ways
sqrt(summary(boston_tree_prune)$dev / nrow(boston_trn))
boston_prune_trn_pred = predict(boston_tree_prune, newdata = boston_trn)
rmse(actual = boston_trn$medv, predicted = boston_prune_trn_pred)
```
```{r}
# test RMSE
boston_prune_tst_pred = predict(boston_tree_prune, newdata = boston_tst)
rmse(actual = boston_tst$medv, predicted = boston_prune_tst_pred)
```
```{r}
plot(boston_prune_tst_pred, boston_tst$medv, xlab = "Predicted", ylab = "Actual")
abline(0, 1)
```
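The plot above looks odd because a tree predicts the same value, the leaf mean, for every observation in a leaf, so the predictions take only as many distinct values as there are leaves. The sketch below builds a one-split tree (a "stump") by hand: it searches over cutpoints of a single predictor and keeps the one that minimizes the residual sum of squares. This is a simplified illustration on the full `Boston` data using only `lstat`; it is not how `tree()` is implemented, since a real tree considers all predictors and splits recursively. The helper `split_rss()` is a name made up for this sketch.

```{r}
library(MASS)  # Boston data

# RSS when each side of a split is predicted by its own mean
split_rss = function(x, y, cut) {
  left  = y[x < cut]
  right = y[x >= cut]
  sum((left - mean(left)) ^ 2) + sum((right - mean(right)) ^ 2)
}

stump_x = Boston$lstat
stump_y = Boston$medv

# candidate cutpoints: midpoints between consecutive distinct values
stump_xs   = sort(unique(stump_x))
stump_cuts = (head(stump_xs, -1) + tail(stump_xs, -1)) / 2

# evaluate every candidate and keep the one with the smallest RSS
stump_rss = sapply(stump_cuts, function(cut) split_rss(stump_x, stump_y, cut))
best_cut  = stump_cuts[which.min(stump_rss)]
best_cut

# the stump predicts one mean on each side, so only two distinct values
stump_pred = ifelse(stump_x < best_cut,
                    mean(stump_y[stump_x < best_cut]),
                    mean(stump_y[stump_x >= best_cut]))
length(unique(stump_pred))
```

Plotting `stump_pred` against `stump_y` would give two vertical bands of points, the same pattern as the tree plot but with fewer bands.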
Here, using an additive linear regression, the actual vs predicted plot looks much more like what we are used to.
```{r}
boston_lm = lm(medv ~ ., data = boston_trn)
boston_lm_pred = predict(boston_lm, newdata = boston_tst)
plot(boston_lm_pred, boston_tst$medv, xlab = "Predicted", ylab = "Actual")
abline(0, 1)
rmse(actual = boston_tst$medv, predicted = boston_lm_pred)
```
We also see a lower test RMSE. The most obvious linear regression beats the tree! Again, we'll improve on this tree soon. Also note the coefficients of the additive linear regression below. Which is easier to interpret, that output, or the small tree above?
```{r}
coef(boston_lm)
```
## `rpart` Package
The `rpart` package is an alternative for fitting trees in `R`. It is much more feature-rich: it fits trees over multiple cost-complexity values and performs cross-validation by default. It can also produce much nicer tree plots. With its default settings, it will often produce smaller trees than the `tree` package. See the links below for more information. `rpart` can also be tuned via `caret`.
```{r}
library(rpart)
set.seed(430)
# Note: when you fit a tree using rpart, the fitting routine automatically
# performs 10-fold CV and stores the errors for later use
# (such as for pruning the tree)
seat_rpart = rpart(Sales ~ ., data = seat_trn, method = "class")
# plot the cv error curve for the tree
# rpart tries different cost-complexities by default and stores the cv results
plotcp(seat_rpart)
# find best value of cp
min_cp = seat_rpart$cptable[which.min(seat_rpart$cptable[,"xerror"]),"CP"]
min_cp
# prune tree using best cp
seat_rpart_prune = prune(seat_rpart, cp = min_cp)
# nicer plots
library(rpart.plot)
prp(seat_rpart_prune)
prp(seat_rpart_prune, type = 4)
rpart.plot(seat_rpart_prune)
```
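Picking the `cp` with the minimum cross-validated error is one option. A common alternative, not used in this chapter, is the "one standard error" rule: choose the smallest tree whose cross-validated error is within one standard error of the minimum. The sketch below is one way this could look, using the `xstd` column of `cptable` and the `kyphosis` data that ships with `rpart` so that it runs on its own. The object names are made up for this example.

```{r}
library(rpart)
set.seed(430)

# fit a classification tree; kyphosis is a small dataset included in rpart
kyph_rpart = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                   method = "class")
kyph_cp = kyph_rpart$cptable

# row with the minimum cross-validated error
kyph_min_row = which.min(kyph_cp[, "xerror"])

# one standard error rule: allow any error up to min error + its std error
kyph_cutoff = kyph_cp[kyph_min_row, "xerror"] + kyph_cp[kyph_min_row, "xstd"]

# rows of cptable run from the smallest tree to the largest,
# so the first row under the cutoff is the smallest acceptable tree
kyph_1se_row = min(which(kyph_cp[, "xerror"] <= kyph_cutoff))
kyph_1se_cp  = kyph_cp[kyph_1se_row, "CP"]

# prune using the selected cp
kyph_prune = prune(kyph_rpart, cp = kyph_1se_cp)

# number of terminal nodes before and after pruning
c(full   = sum(kyph_rpart$frame$var == "<leaf>"),
  pruned = sum(kyph_prune$frame$var == "<leaf>"))
```

This rule trades a small amount of estimated accuracy for a simpler tree, in the same spirit as pruning the regression tree to a size that performs about as well as the best one.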
## External Links
- [An Introduction to Recursive Partitioning Using the `rpart` Routines](https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf) - Details of the `rpart` package.
- [`rpart.plot` Package](http://www.milbo.org/doc/prp.pdf) - Detailed manual on plotting with `rpart` using the `rpart.plot` package.
## `rmarkdown`
The `rmarkdown` file for this chapter can be found [**here**](26-tree.Rmd). The file was created using `R` version `r paste0(version$major, "." ,version$minor)`. The following packages (and their dependencies) were loaded when knitting this file:
```{r, echo = FALSE}
names(sessionInfo()$otherPkgs)
```