Part_6.html
1998 lines (1304 loc) · 51 KB
<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml" lang="" xml:lang="">
<head>
<title>Applied Machine Learning</title>
<meta charset="utf-8" />
<meta name="author" content="Max Kuhn and Davis Vaughan (RStudio)" />
<meta name="date" content="2020-01-26" />
<link href="libs/remark-css-0.0.1/default.css" rel="stylesheet" />
<script src="libs/kePrint-0.0.1/kePrint.js"></script>
<link href="libs/countdown-0.3.3/countdown.css" rel="stylesheet" />
<script src="libs/countdown-0.3.3/countdown.js"></script>
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.8.2/css/all.css" integrity="sha384-oS3vJWv+0UjzBfQzYUhtDYW+Pj2yciDJxpsK1OYPAYjqT085Qq/1cq5FLXAZQ7Ay" crossorigin="anonymous">
<link rel="stylesheet" href="assets/css/aml-theme.css" type="text/css" />
<link rel="stylesheet" href="assets/css/aml-fonts.css" type="text/css" />
</head>
<body>
<textarea id="source">
class: title-slide, center
<span class="fa-stack fa-4x">
<i class="fa fa-circle fa-stack-2x" style="color: #ffffff;"></i>
<strong class="fa-stack-1x" style="color:#E7553C;">6</strong>
</span>
# Applied Machine Learning
## Classification Models
---
# Outline
* Performance Measures
* Amazon Review Data
* Classification Trees
* Boosting
* Extra topics as time allows
---
# Load Packages
```r
library(tidymodels)
```
```
## ── Attaching packages ───────────────────────────────────────────── tidymodels 0.0.4 ──
```
```
## ✓ broom 0.5.3 ✓ recipes 0.1.9
## ✓ dials 0.0.4 ✓ rsample 0.0.5
## ✓ dplyr 0.8.3 ✓ tibble 2.1.3
## ✓ infer 0.5.1 ✓ tune 0.0.1
## ✓ parsnip 0.0.5 ✓ workflows 0.1.0
## ✓ purrr 0.3.3 ✓ yardstick 0.0.5
```
```
## ── Conflicts ──────────────────────────────────────────────── tidymodels_conflicts() ──
## x purrr::accumulate() masks foreach::accumulate()
## x purrr::discard() masks scales::discard()
## x dplyr::filter() masks stats::filter()
## x recipes::fixed() masks stringr::fixed()
## x dplyr::group_rows() masks kableExtra::group_rows()
## x dplyr::ident() masks dbplyr::ident()
## x dplyr::lag() masks stats::lag()
## x purrr::lift() masks caret::lift()
## x dials::margin() masks ggplot2::margin()
## x yardstick::precision() masks caret::precision()
## x dials::prune() masks rpart::prune()
## x yardstick::recall() masks caret::recall()
## x dplyr::select() masks MASS::select()
## x dplyr::sql() masks dbplyr::sql()
## x recipes::step() masks stats::step()
## x purrr::when() masks foreach::when()
## x recipes::yj_trans() masks scales::yj_trans()
```
---
layout: false
class: inverse, middle, center
# Measuring Performance in Classification
---
# Illustrative Example <img src="images/yardstick.png" class="title-hex">
`yardstick` contains another test set example in a data frame called `two_class_example`:
```r
two_class_example %>% head(4)
```
```
## truth Class1 Class2 predicted
## 1 Class2 0.00359 0.996 Class2
## 2 Class1 0.67862 0.321 Class1
## 3 Class2 0.11089 0.889 Class2
## 4 Class1 0.73516 0.265 Class1
```
Both `truth` and `predicted` are factors with the same levels. The other two columns represent _class probabilities_.
This reflects that most classification models can generate both "hard" (class) and "soft" (probability) predictions.
The class predictions are usually created by thresholding some numeric output of the model (e.g. a class probability) or by choosing the largest value.
---
# Class Prediction Metrics <img src="images/yardstick.png" class="title-hex">
.pull-left[
With class predictions, a common summary method is to produce a _confusion matrix_ which is a simple cross-tabulation between the observed and predicted classes:
```r
two_class_example %>%
conf_mat(truth = truth, estimate = predicted)
```
```
## Truth
## Prediction Class1 Class2
## Class1 227 50
## Class2 31 192
```
These can be visualized using [mosaic plots](https://en.wikipedia.org/wiki/Mosaic_plot).
]
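As a sketch, the confusion matrix object can be plotted directly with `yardstick`'s `autoplot()` method (the `type = "mosaic"` option draws a mosaic plot):

```r
# Visualize the confusion matrix as a mosaic plot
cm <- two_class_example %>%
  conf_mat(truth = truth, estimate = predicted)
autoplot(cm, type = "mosaic")
```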
.pull-right[
Accuracy is the most obvious metric for characterizing the performance of models.
```r
two_class_example %>%
accuracy(truth = truth, estimate = predicted)
```
```
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.838
```
However, it suffers when there is a _class imbalance_; suppose 95% of the data have a specific class. 95% accuracy can be achieved by predicting samples to be the majority class. There are measures that correct for the natural event rate, such as [Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa).
]
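For example, Cohen's Kappa can be computed with `yardstick`'s `kap()` function on the same data (a sketch):

```r
# Cohen's Kappa corrects accuracy for the agreement expected by chance
two_class_example %>%
  kap(truth = truth, estimate = predicted)
```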
---
# Two Classes <img src="images/yardstick.png" class="title-hex">
There are a number of specialized metrics that can be used when there are two classes. Usually, one of these classes can be considered the _event of interest_ or the _positive class_.
One common way to think about performance is to consider false negatives and false positives.
* The **sensitivity** is the _true positive rate_ (out of all of the actual positives, how many did you get right?).
* The **specificity** is the rate of correctly predicted negatives, or 1 - _false positive rate_ (out of all the actual negatives, how many did you get right?).
From this, assuming that `Class1` is the event of interest:
.pull-left[
```
## Truth
## Prediction Class1 Class2
## Class1 227 50
## Class2 31 192
```
]
.pull-right[
sensitivity = 227/(227 + 31) = 0.88
specificity = 192/(192 + 50) = 0.79
]
---
# Conditional and Unconditional Measures <img src="images/yardstick.png" class="title-hex">
Sensitivity and specificity can be computed from `sens()` and `spec()`, respectively.
It should be noted that these are _conditional measures_ since we need to know the true outcome.
The event rate is the _prevalence_ (or the Bayesian _prior_). Sensitivity and specificity are analogous to the _likelihood values_.
There are _unconditional_ analogs to the _posterior values_ called the positive predictive values and the negative predictive values.
A variety of other measures are available for two class systems, especially for _information retrieval_.
One thing to consider: what happens if our **threshold to call a sample an event is not optimal**?
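A sketch of these computations on the earlier example data (`ppv()` and `npv()` compute the unconditional analogs):

```r
# Conditional measures: require knowing the true outcome
two_class_example %>% sens(truth = truth, estimate = predicted)  # ~0.88
two_class_example %>% spec(truth = truth, estimate = predicted)  # ~0.79

# Unconditional analogs of the posterior values
two_class_example %>% ppv(truth = truth, estimate = predicted)
two_class_example %>% npv(truth = truth, estimate = predicted)
```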
---
# Changing the Probability Threshold <img src="images/yardstick.png" class="title-hex">
.pull-left[
For two classes, the 50% cutoff is customary; if the probability of `Class1` is >= 50%, the sample is labelled as `Class1`.
What happens when you change the cutoff?
* Increasing it makes it harder to be called `Class1` `\(\Rightarrow\)` fewer predicted events, specificity `\(\uparrow\)`, sensitivity `\(\downarrow\)`
* Decreasing the cutoff makes it easier to be called `Class1` `\(\Rightarrow\)` more predicted events, specificity `\(\downarrow\)`, sensitivity `\(\uparrow\)`
]
.pull-right[
With two classes, the **Receiver Operating Characteristic (ROC) curve** can be used to estimate performance using a combination of sensitivity and specificity.
To create the curve, many alternative cutoffs are evaluated.
For each cutoff, we calculate the sensitivity and specificity.
The ROC curve plots the sensitivity (i.e., the true positive rate) versus 1 - specificity (i.e., the false positive rate).
]
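A minimal sketch of re-labelling the predictions at a non-default cutoff (the `cutoff` value and `pred_80` column name are illustrative):

```r
library(dplyr)

cutoff <- 0.80  # stricter than the customary 0.50
hard_preds <- two_class_example %>%
  mutate(pred_80 = factor(ifelse(Class1 >= cutoff, "Class1", "Class2"),
                          levels = levels(truth)))

# Raising the cutoff should raise specificity and lower sensitivity
hard_preds %>% sens(truth = truth, estimate = pred_80)
hard_preds %>% spec(truth = truth, estimate = pred_80)
```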
---
# The Receiver Operating Characteristic (ROC) Curve <img src="images/yardstick.png" class="title-hex"><img src="images/ggplot2.png" class="title-hex"><img src="images/dplyr.png" class="title-hex">
.pull-left[
```r
roc_obj <-
two_class_example %>%
roc_curve(truth, Class1)
```
```r
two_class_example %>% roc_auc(truth, Class1)
```
```
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 roc_auc binary 0.939
```
```r
autoplot(roc_obj) + thm
```
]
.pull-right[
<img src="images/part-6-roc-plot-1.svg" width="100%" style="display: block; margin: auto;" />
]
---
# Changing the Threshold
<img src="images/part-6-unnamed-chunk-1-1.gif" width="50%" style="display: block; margin: auto;" />
---
# The Receiver Operating Characteristic (ROC) Curve
The ROC curve has some major advantages:
* It can allow models to be optimized for performance before a definitive cutoff is determined.
* It is _robust_ to class imbalances; no matter the event rate, it does a good job at characterizing model performance.
* The ROC curve can be used to pick an optimal cutoff based on the trade-offs between the types of errors that can occur.
When there are two classes, it is advisable to focus on the area under the ROC curve instead of sensitivity and specificity.
Once an acceptable model is determined, a proper cutoff can be determined.
---
layout: false
class: inverse, middle, center
# Example Data
---
# Amazon Review Data
These data are from Amazon, who describe it as
> "This dataset consists of reviews of fine foods from amazon. The data span a period of more than 10 years, including all ~500,000 reviews up to October 2012. Reviews include product and user information, ratings, and a plaintext review."
We will use the text data to predict whether the review has a five-star result or not.
This will involve some natural language processing methods, which we will walk through.
The data are found in the `modeldata` package:
```r
library(modeldata)
data(small_fine_foods)
```
---
# Feature Engineering
Most of the feature engineering work is to extract information from the text.
We will use the basics here; more information can be found in [_Text Mining with R_](https://www.tidytextmining.com/).
To do this, we will heavily rely on the `textrecipes` package.
---
# Defining roles <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
* update_role(product, new_role = "id")
```
]
]
.pull-right-a-little[
`product` is used for data splitting (as we'll see in a bit).
Arguably, it is not a predictor (although some might use it that way).
We update the role so that it is retained in the recipe but not used as a predictor.
]
---
# Copying a column <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
* step_mutate(review_raw = review)
```
]
]
.pull-right-a-little[
Two of the steps that we'll use will destroy the original predictor.
We'll use a basic `step_mutate()` to make a temporary copy.
]
---
# Initial feature set <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
* step_textfeature(review_raw)
```
]
]
.pull-right-a-little[
A set of numeric predictors is derived from the text.
Most are counts of text elements (e.g., words, punctuation, etc.).
]
---
# Tokenize <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
step_textfeature(review_raw) %>%
* step_tokenize(review)
```
]
]
.pull-right-a-little[
Tokenization splits the text into smaller units; here, each review is broken into individual words (tokens) for the steps that follow.
]
---
# Remove Stop Words <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
step_textfeature(review_raw) %>%
step_tokenize(review) %>%
* step_stopwords(review)
```
]
]
.pull-right-a-little[
Stop words are those that occur commonly in text, such as "the", "a", and so on.
Removing them from text _might_ be a good idea.
This largely depends on what you are doing with the text.
]
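To see which words would be removed, the `stopwords` package (which `step_stopwords()` relies on) can list them; a sketch:

```r
library(stopwords)

# The default English snowball list contains words like "the" and "a"
head(stopwords(language = "en"))
length(stopwords(language = "en"))  # a couple hundred common words
```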
---
# Word stemming <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
step_textfeature(review_raw) %>%
step_tokenize(review) %>%
step_stopwords(review) %>%
* step_stem(review)
```
]
]
.pull-right-a-little[
Stemming is a method that uses a common root of a word instead of the original value.
For example, these 7 words are fairly similar: "teach", "teacher", "teachers", "teaches", "teachable", "teaching", "teachings".
Stemming would reduce these to 3 unique values: "teach", "teacher", "teachabl".
Like stop word removal, this may or may not be a good idea.
]
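A sketch of this reduction using the `SnowballC` package (the stemmer that `step_stem()` uses by default):

```r
library(SnowballC)

words <- c("teach", "teacher", "teachers", "teaches",
           "teachable", "teaching", "teachings")
wordStem(words)
# The 7 words collapse to the 3 unique stems noted above
unique(wordStem(words))
```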
---
# Feature hashing <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
step_textfeature(review_raw) %>%
step_tokenize(review) %>%
step_stopwords(review) %>%
step_stem(review) %>%
* step_texthash(review, signed = FALSE, num_terms = 1024)
```
]
]
.pull-right-a-little[
Feature hashing creates numeric terms from words in a sentence (or some other token) similar to dummy variables.
However, there are big differences, including:
* There is no look-up table to consult to make the mapping
* The placement of the non-zero values is meant to emulate randomness.
* The new features are computed on the actual words.
]
---
# Feature hashing
.pull-left[
For the string "On Time and product looked like it", a sketch of the calculations to make 8 hashed values:
<table class="table" style="margin-left: auto; margin-right: auto;">
<thead>
<tr>
<th style="text-align:left;"> word </th>
<th style="text-align:right;"> hashed integer value </th>
<th style="text-align:right;"> (integer mod 8) + 1 </th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:left;"> On </td>
<td style="text-align:right;"> -182693672 </td>
<td style="text-align:right;"> 4 </td>
</tr>
<tr>
<td style="text-align:left;"> Time </td>
<td style="text-align:right;"> 1593484409 </td>
<td style="text-align:right;"> 8 </td>
</tr>
<tr>
<td style="text-align:left;"> and </td>
<td style="text-align:right;"> -1079337235 </td>
<td style="text-align:right;"> 8 </td>
</tr>
<tr>
<td style="text-align:left;"> product </td>
<td style="text-align:right;"> -979280496 </td>
<td style="text-align:right;"> 6 </td>
</tr>
<tr>
<td style="text-align:left;"> looked </td>
<td style="text-align:right;"> -2120797534 </td>
<td style="text-align:right;"> 2 </td>
</tr>
<tr>
<td style="text-align:left;"> like </td>
<td style="text-align:right;"> -592737581 </td>
<td style="text-align:right;"> 5 </td>
</tr>
<tr>
<td style="text-align:left;"> it </td>
<td style="text-align:right;"> 1278008556 </td>
<td style="text-align:right;"> 2 </td>
</tr>
</tbody>
</table>
]
.pull-right[
Note that multiple words end up in the same feature column. This is _aliasing_ (statistical term) or a _collision_ (comp sci term).
* We wouldn't be able to distinguish the effect of those two words.
We can encode this as a simple binary indicator or, as `textrecipes` does, use the count as the value.
* There are also _signed_ hashes that help avoid collisions.
Note that no words were mapped to feature columns one, three, or seven.
]
---
# Optional step: convert binary to factors <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
count_to_binary <- function(x) {
factor(ifelse(x != 0, "present", "absent"),
levels = c("present", "absent"))
}
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
step_textfeature(review_raw) %>%
step_tokenize(review) %>%
step_stopwords(review) %>%
step_stem(review) %>%
step_texthash(review, signed = FALSE, num_terms = 1024) %>%
* step_mutate_at(starts_with("review_hash"), fn = count_to_binary)
```
]
]
.pull-right-a-little[
The naive Bayes model will be used on these data.
It computes probability values from each predictor.
- If the predictor is numeric, its statistical density is used.
- If categorical, a contingency table is used.
.font100[
Since the hash values are really about the presence/absence of words, we should convert them to 2-level factor variables to ensure appropriate calculations.
]
]
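A quick check of the conversion on a small vector (a sketch):

```r
count_to_binary <- function(x) {
  factor(ifelse(x != 0, "present", "absent"),
         levels = c("present", "absent"))
}

count_to_binary(c(0, 2, 0, 1))
# [1] absent  present absent  present
# Levels: present absent
```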
---
# Optional step: remove zero-variance predictors <img src="images/recipes.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
count_to_binary <- function(x) {
factor(ifelse(x != 0, "present", "absent"),
levels = c("present", "absent"))
}
text_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
step_textfeature(review_raw) %>%
step_tokenize(review) %>%
step_stopwords(review) %>%
step_stem(review) %>%
step_texthash(review, signed = FALSE, num_terms = 1024) %>%
step_mutate_at(starts_with("review_hash"), fn = count_to_binary) %>%
* step_zv(all_predictors())
```
]
]
.pull-right-a-little[
Removing the features that are all zero in the training set increases computational efficiency and may stop model failures.
]
---
# Resampling and Analysis Strategy <img src="images/rsample.png" class="title-hex">
There are enough data here to do a simple 10-fold cross-validation.
Since there is a class imbalance, we will stratify the splits.
```r
set.seed(8935)
text_folds <- vfold_cv(training_data, strata = "score")
```
---
layout: false
class: inverse, middle, center
# Classification Trees
---
# Tree model structure
A classification tree searches through each predictor to find a value of a single variable that best splits the data into two groups.
For the resulting groups, the process is repeated until a hierarchical structure (a tree) is created.
* In effect, trees partition the `\(X\)` space into rectangular sections that assign a single value to samples within the rectangle.
The final structure in the tree is the _terminal node_ and each path through the tree is a _rule_.
.pull-left[
```r
# Example tree with three terminal nodes
if (x > 1) {
if (y < 3) {
class <- "A"
} else {
class <- "B"
}
} else {
class <- "A"
}
```
]
.pull-right[
```r
# Same tree, stated as rules
if (x > 1 & y < 3) class <- "A"
if (x > 1 & y >= 3) class <- "B"
if (x <= 1) class <- "A"
```
]
---
# Species of tree-based models
There are a variety of different methods for creating trees that vary over:
* The search method (e.g., greedy or global).
* The splitting criterion.
* The number of splits.
* Handling of missing values.
* Pruning method.
The most popular is the CART methodology, followed by the C5.0 model.
We will focus on CART for single trees.
The [CRAN Machine Learning Task View](https://cran.r-project.org/web/views/MachineLearning.html) has a good summary of the methods available in R.
---
# Growing phase
CART starts by _growing_ the tree.
* More and more splits are conducted until a pre-specified sample size requirement is exceeded (`min_n`).
* The criterion used is the _purity_ of the terminal nodes that are created by each split.
For example, for simulated data with a 50% event rate, which one of these splits is better?
.pull-left[
<img src="images/part-6-good-split-1.svg" width="70%" style="display: block; margin: auto;" />
]
.pull-right[
<img src="images/part-6-bad-split-1.svg" width="70%" style="display: block; margin: auto;" />
]
---
# Pruning phase
The deepest possible tree has a higher likelihood of overfitting the data.
CART conducts cost-complexity pruning to find the "right sized tree".
It basically penalizes the error rate of the tree by the number of terminal nodes by minimizing
`$$Error_{cv} + (c_p \times nodes)$$`
* The `\(c_p\)` value, usually between 0 and 0.1, controls the depth of the tree.
* CART has an internal 10-fold cross-validation that it uses to estimate the model error.
* If the outcome has a large class imbalance, this method optimizes the tree for the majority class.
For CART, `\(c_p\)` (aka `cost_complexity`) and the minimum splitting size (`min_n`) are the tuning parameters.
---
# Aspects of single trees
* The class percentages in the terminal node are used to make predictions.
* The number of possible class probabilities is typically low.
* Trees are _theoretically_ interpretable if the number of terminal nodes is low.
* The training time tends to be very fast.
* Trees are _unstable_; if the data are slightly changed, the entire tree structure can change. These are [low-bias/high-variance models](https://bookdown.org/max/FES/important-concepts.html#model-bias-and-variance).
* Very little, if any, data pre-processing is needed. _Dummy variables [are not required](https://bookdown.org/max/FES/categorical-trees.html)_.
* Trees automatically conduct _feature selection_.
---
# Fitting and tuning trees
Like MARS, there are two main ways to tune the CART model:
* Rely on the internal CV procedure to pick the tree depth via purity/error rate:
```r
decision_tree(min_n = tune()) %>%
set_engine("rpart") %>%
set_mode("classification")
```
* Manually specify `\(c_p\)` values and use external resampling with a metric of your choice:
```r
decision_tree(cost_complexity = tune(), min_n = tune()) %>%
set_engine("rpart") %>%
set_mode("classification")
```
I prefer the latter approach; I believe that the automated choice tends to pick overly simple models.
---
# {recipe} and {parsnip} objects <img src="images/tune.png" class="title-hex"><img src="images/recipes.png" class="title-hex"><img src="images/parsnip.png" class="title-hex">
.pull-left-a-lot[
.font80[
```r
library(textfeatures)
library(textrecipes)
tree_rec <-
recipe(score ~ product + review, data = training_data) %>%
update_role(product, new_role = "id") %>%
step_mutate(review_raw = review) %>%
step_textfeature(review_raw) %>%
step_tokenize(review) %>%
step_stopwords(review) %>%
step_stem(review) %>%
step_texthash(review, signed = FALSE, num_terms = tune()) %>%
step_zv(all_predictors())
# and
cart_mod <-
decision_tree(cost_complexity = tune(), min_n = tune()) %>%
set_engine("rpart") %>%
set_mode("classification")
ctrl <- control_grid(save_pred = TRUE)
```
]
]
.pull-right-a-little[
Note that:
All of the text processing operations are deterministic and not reliant on any other data in the training set.
We could pre-compute the data prior to the text hashing.
Also, if we were not tuning the number of hashing terms, we could pre-compute the whole feature set with the exception of the zero-variance filter.
]
---
# Model tuning <img src="images/tune.png" class="title-hex">
```r
cart_wfl <-
workflow() %>%
add_recipe(tree_rec) %>%
add_model(cart_mod)
set.seed(2553)
cart_tune <- tune_grid(
cart_wfl,
text_folds,
grid = 10,
metrics = metric_set(roc_auc),
control = ctrl
)
show_best(cart_tune, metric = "roc_auc")
```
```
## # A tibble: 5 x 8
## cost_complexity min_n num_terms .metric .estimator mean n std_err
## <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
## 1 0.000118 33 3253 roc_auc binary 0.711 10 0.00852
## 2 0.00000393 25 1840 roc_auc binary 0.698 10 0.00520
## 3 0.00000000286 22 816 roc_auc binary 0.684 10 0.00601
## 4 0.000000000303 20 1748 roc_auc binary 0.684 10 0.00486
## 5 0.000526 16 1274 roc_auc binary 0.676 10 0.0103
```
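After tuning, the best parameter combination can be plugged back into the workflow; a sketch (the final fit itself is not run here):

```r
# Pick the numerically best submodel and lock its parameters in
best_cart <- select_best(cart_tune, metric = "roc_auc")
final_cart_wfl <- finalize_workflow(cart_wfl, best_cart)
# final_cart_wfl could then be fit to the full training set with fit()
```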
---
# Parameter profiles
```r
autoplot(cart_tune)
```
<img src="images/part-6-cart-autoplot-1.svg" width="50%" style="display: block; margin: auto;" />
---
# Plotting ROC curves <img src="images/yardstick.png" class="title-hex"><img src="images/tune.png" class="title-hex"><img src="images/ggplot2.png" class="title-hex">
.font80[
.pull-left[
```r
cart_pred <- collect_predictions(cart_tune)
cart_pred %>% slice(1:5)
```
```
## # A tibble: 5 x 8
## id .pred_great .pred_other .row num_terms cost_complexity min_n score
## <chr> <dbl> <dbl> <int> <int> <dbl> <int> <fct>
## 1 Fold01 0 1 24 412 0.000000271 7 other
## 2 Fold01 0.939 0.0606 25 412 0.000000271 7 great
## 3 Fold01 0.734 0.266 26 412 0.000000271 7 other
## 4 Fold01 0.961 0.0390 46 412 0.000000271 7 other
## 5 Fold01 0.939 0.0606 48 412 0.000000271 7 great
```
```r