<!DOCTYPE html>
<html lang="" xml:lang="">
<head>
<title>Cross-validation: What does it estimate and how well does it do it?</title>
<meta charset="utf-8" />
<meta name="author" content="Gül İnan " />
<script src="libs/header-attrs-2.11/header-attrs.js"></script>
<link href="libs/remark-css-0.0.1/default.css" rel="stylesheet" />
<script src="libs/fabric-4.3.1/fabric.min.js"></script>
<link href="libs/xaringanExtra-scribble-0.0.1/scribble.css" rel="stylesheet" />
<script src="libs/xaringanExtra-scribble-0.0.1/scribble.js"></script>
<script>document.addEventListener('DOMContentLoaded', function() { window.xeScribble = new Scribble({"pen_color":["#1A292C"],"pen_size":3,"eraser_size":30,"palette":[]}) })</script>
<link href="libs/panelset-0.2.6/panelset.css" rel="stylesheet" />
<script src="libs/panelset-0.2.6/panelset.js"></script>
<link rel="stylesheet" href="css/metropolis.css" type="text/css" />
<link rel="stylesheet" href="css/cal.css" type="text/css" />
</head>
<body>
<textarea id="source">
class: center, middle, inverse, title-slide
# <b>Cross-validation: What does it estimate and how well does it do it?</b>
## by Stephen Bates, Trevor Hastie, and Robert Tibshirani <br>
### Gül İnan <br>
### <br>Istanbul Technical University
### <br><i>Machine Learning in Montpellier, Theory & Practice, <br> January 6th, 2022 </i>
---
# Motivation
<br>
- In machine learning applications, when deploying a <b><span style="color:#067373;">predictive model</span></b>, the main interest is in the <b><span style="color:#067373;">prediction accuracy</span></b> of the model on future test points.
- The standard measure of accuracy for a predictive model is the <b><span style="color:#067373;">prediction error</span></b>, i.e., the <b><span style="color:#067373;">expected loss on future test points</span></b>.
- For inference, we need both a <b><span style="color:#067373;">good point estimate</span></b> of and an <b><span style="color:#067373;">accurate confidence interval</span></b> for the <b><span style="color:#067373;">prediction error</span></b>.
---
# Motivation continued
<br>
- From a practical point of view, <b><span style="color:#067373;">cross-validation (CV)</span></b> is one of the widely-used resampling-based approaches to estimate a <b><span style="color:#067373;">point estimate</span></b> and to build a <b><span style="color:#067373;">confidence interval</span></b> for prediction error.
- However, <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> argue that the (statistical) properties of the CV estimator of prediction error are **not well understood**, despite the apparent simplicity of CV.
---
# Aim
<br>
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> first show that the <b><span style="color:#067373;">CV estimator of prediction error</span></b>:
- only **weakly tracks the accuracy of the specific model fit** and, instead,
- estimates **the average prediction error of models fit across many (hypothetical) data sets from the same population**.
---
# Aim continued
<br>
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> then show that <b><span style="color:#067373;">naive confidence intervals</span></b> based on the CV estimate of prediction error give **poor coverage**: the variance estimate used to set the interval width ignores the **correlation between error estimates** in different folds, a correlation that arises because each data point is used both in training and in testing.
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> propose the <b><span style="color:#067373;">nested cross-validation (NCV)</span></b> approach, which provides confidence intervals with **coverage close to the nominal level**.
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> validate their work through deep theory and extensive numerical experiments (both simulation studies and real data examples).
---
# Setting and notation
<br>
- Consider a <b><span style="color:#067373;">supervised learning </span></b> setting.
- We have data `\((X, Y)\)`, where `\(X=(X_1, \ldots, X_n) \in \mathcal{X}^{n \times p }\)` is the <b><span style="color:#067373;">feature matrix</span></b>
and `\(Y=(Y_1, \ldots, Y_n) \in \mathcal{Y}^{n}\)` is the <b><span style="color:#067373;">response vector</span></b>.
- We assume that each <b><span style="color:#067373;">data point </span></b> `\((X_i,Y_i)\)`, `\(i=1,\ldots,n\)`,
is i.i.d. from a <b><span style="color:#067373;"> distribution</span></b> `\(P\)`.
---
# Setting and notation continued
<br>
- Consider a <b><span style="color:#067373;">class of models </span></b> parameterized by vector `\(\theta\)`.
- We assume that `\(\widehat f(x, \theta)\)` is the <b><span style="color:#067373;">function that predicts</span></b> `\(y\)` from `\(x \in \mathbb{R}^{p}\)` using the model with parameter `\(\theta\)`, where `\(\theta\)` takes values in some space `\(\Theta\)`.
- We let `\(\mathcal{A}\)` be a <b><span style="color:#067373;">model-fitting algorithm</span></b> that returns the <b><span style="color:#067373;">fitted value</span></b> of the parameter vector, `\(\widehat \theta =\mathcal{A}(X,Y) \in \Theta\)`, based on the observed data `\((X, Y)\)`.
---
# Prediction error
<br>
- In <b><span style="color:#067373;">measuring the accuracy of a model</span></b>, we are interested in <b><span style="color:#067373;"> prediction error (out-of-sample error)</span></b> which is defined as the <b><span style="color:#067373;">expected loss on future data points</span></b> `\((X_{n+1},Y_{n+1})\)`:
`\begin{equation}
Err_{XY}: = \mathbb{E}\bigg[ \ell \big(\hat{f}(X_{n+1},\hat \theta ),Y_{n+1}\big) \vert (X,Y) \bigg] , \nonumber
\end{equation}`
- where `\((X_{n+1},Y_{n+1})\)` is an <b><span style="color:#067373;">independent test point</span></b> from the same distribution `\(P\)`.
- The expression `\(\hat{f}(X_{n+1},\hat \theta)\)` is the <b><span style="color:#067373;">predicted value</span></b> of `\(Y_{n+1}\)` at the future point `\(X_{n+1}\)` and
- `\(\hat \theta\)` is the fitted value of the parameter estimated through algorithm `\(\mathcal{A}\)` based on the training data `\((X,Y)\)`.
- The expression `\(\ell \big(\hat{f}(X_{n+1},\hat \theta ),Y_{n+1}\big)\)` is the <b><span style="color:#067373;">loss</span></b> between the predicted value of `\(Y_{n+1}\)` and `\(Y_{n+1}\)` itself.
- Here, the <b><span style="color:#067373;">loss function</span></b> `\(\ell(\cdot)\)` could be squared-error loss, misclassification error, or deviance (cross-entropy).
- Note that `\(Err_{XY}\)` is a <b><span style="color:#067373;">random quantity</span></b>, since it depends on the training data `\((X,Y)\)`.
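---
- Since `\(Err_{XY}\)` is an expectation over a single future point, in a simulation (where `\(P\)` is known) it can be approximated by averaging the loss over a large independent test set.
- A minimal R sketch with OLS and squared-error loss; all names and parameter values below are illustrative:

```r
set.seed(1)
n <- 100; p <- 20
theta <- rep(1, p)                           # true coefficients (illustrative)
X <- matrix(rnorm(n * p), n, p)              # training features
y <- drop(X %*% theta) + rnorm(n)            # training responses
theta_hat <- lm.fit(X, y)$coefficients       # theta_hat = A(X, Y), here OLS

# Approximate Err_XY by the average loss on a large test set from the same P
n_test <- 1e5
X_new <- matrix(rnorm(n_test * p), n_test, p)
y_new <- drop(X_new %*% theta) + rnorm(n_test)
mean((y_new - drop(X_new %*% theta_hat))^2)  # roughly 1 + p/(n - p - 1) on average
```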
---
# Expected prediction error
<br>
- On the other hand, when designing and comparing <b><span style="color:#067373;">algorithms</span></b>, we may instead be interested in a learning algorithm's <b><span style="color:#067373;">average performance in predicting future test points</span></b>.
- This quantity of interest can be formally defined as
the <b><span style="color:#067373;">expected value of prediction error</span></b> across possible training data sets of size n drawn from the same data distribution `\(P\)`:
`\begin{equation}
Err :=\mathbb{E}\big[Err_{XY}\big].\nonumber
\end{equation}`
- In short: it is the expectation of the prediction error across possible training sets (from the same distribution).
---
- Note that estimates of the quantities `\(Err_{XY}\)` and `\(Err\)` **cannot be computed** when the **data distribution** `\(P\)` is **unknown**.
- Then, <b><span style="color:#067373;">resampling-based methods</span></b> such as cross-validation, bootstrap, and the jackknife, or <b><span style="color:#067373;">analytical methods</span></b> such as AIC, BIC, Mallows' `\(C_p\)`, and covariance penalties, can be used to **estimate the quantities** `\(Err_{XY}\)` and `\(Err\)`.
---
# K-fold cross-validation
<br>
- In <b><span style="color:#067373;">K-fold cross-validation</span></b>, we randomly partition the data `\((X,Y) = \mathcal{I}\)` into `\(K\)` equally sized <b><span style="color:#067373;">disjoint
folds (subsets)</span></b> `\(\mathcal{I}_k\)` `\((k=1,\ldots,K)\)`.
- Here the fold size is `\(m=n/K\)` and the whole data is `\(\mathcal{I} = \cup_{k=1}^{K} \mathcal{I}_k\)`.
- When the data point `\((x_i,y_i) \in \mathcal{I}_k\)`, we will also write `\(i \in \mathcal{I}_k\)` `\((k=1,\ldots,K)\)`.
<div class="figure" style="text-align: center">
<img src="imgs/K_CV_single.png" alt="&lt;i&gt;Partition of the data into K-folds.&lt;/i&gt;" width="13%" />
<p class="caption"><i>Partition of the data into K-folds.</i></p>
</div>
<div style="text-align:right;"><font size="4px"><a href="https://www.researchgate.net/profile/Mingchao-Li/publication/331209203/figure/fig2/AS:728070977748994@1550597056956/K-fold-cross-validation-method.png">Image Source</a></font></div>
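---
- Such a partition can be sketched in R as follows (illustrative; assumes `\(n\)` is divisible by `\(K\)`):

```r
n <- 100; K <- 5
# Randomly assign the indices 1..n to K equally sized disjoint folds I_1, ..., I_K
folds <- split(sample(n), rep(1:K, each = n / K))
str(folds)  # a list of K index vectors, each of size m = n/K
```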
---
- Consider the first fold `\(\mathcal{I}_{1}\)` and hold it out as a future data set (or <b><span style="color:#067373;">test set</span></b>).
- The **remaining data points** `\((x_i,y_i) \in \mathcal{I} \backslash \mathcal{I}_{1}\)`, which are not in the first fold, form the <b><span style="color:#067373;">training set</span></b>.
- Let `\(\hat \theta^{(-1)} = \mathcal{A}\big((X_i,Y_i)_{i \in \mathcal{I} \backslash \mathcal{I}_{1}}\big)\)` be the **parameter estimate based on training data set**, then
we can calculate the **prediction error** for future data set `\(\mathcal{I}_{1}\)`, of size `\(m\)`, as follows:
`\begin{equation}
\frac{1}{m} \sum_{{i} \in \mathcal{I}_1}\ell \big(\hat{f}(x_{i},\hat \theta^{(-1) } ),y_i\big).
\end{equation}`
---
- In <b><span style="color:#067373;">K-fold cross-validation</span></b>, we **iteratively repeat** this process for each fold `\((k=1,\ldots,K)\)`.
<div class="figure" style="text-align: center">
<img src="imgs/K_CV.png" alt="&lt;i&gt;K-fold cross-validation.&lt;/i&gt;" width="60%" />
<p class="caption"><i>K-fold cross-validation.</i></p>
</div>
<div style="text-align:right;"><font size="4px"><a href="https://www.researchgate.net/profile/Mingchao-Li/publication/331209203/figure/fig2/AS:728070977748994@1550597056956/K-fold-cross-validation-method.png">Image Source</a></font></div>
---
# CV estimate of prediction error
<br>
- The **average of prediction errors over K folds** is given as follows:
`\begin{equation}
\widehat {Err} ^{(CV)} := \frac{1} {K} \sum_{k=1}^{K} \frac{1} {m}\sum_{{i} \in \mathcal{I}_k} \ell \big(\hat{f}(x_{i},\hat \theta^{(-k)}),y_i\big).
\end{equation}`
- This is usually called the <b><span style="color:#067373;">CV estimate of prediction error</span></b>.
- **Relationship between `\(\widehat {Err} ^{(CV)}\)`, `\(Err_{XY}\)`, and `\(Err\)`**: Intuitively, the inner sum is an estimate for `\(Err_{XY}\)` for a fixed fold, and the double sum estimates `\(Err\)` with `\(\mathcal{I}_k\)` `\(( k=1,\ldots,K)\)` being different samples from the same distribution <b><span style="color:#067373;">(De Benito Delgado, 2021)</span></b>.
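---
- A minimal R sketch of `\(\widehat {Err}^{(CV)}\)` with OLS as the fitting algorithm and squared-error loss (all names are illustrative):

```r
set.seed(1)
n <- 100; p <- 20; K <- 5; m <- n / K
X <- matrix(rnorm(n * p), n, p)
y <- drop(X %*% rep(1, p)) + rnorm(n)

folds <- split(sample(n), rep(1:K, each = m))
fold_err <- sapply(folds, function(idx) {
  # theta^(-k): fit on all data outside fold k
  theta_hat <- lm.fit(X[-idx, , drop = FALSE], y[-idx])$coefficients
  # average loss on the held-out fold k
  mean((y[idx] - drop(X[idx, , drop = FALSE] %*% theta_hat))^2)
})
err_cv <- mean(fold_err)  # the CV estimate of prediction error
```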
---
# What prediction error are we estimating?
<br>
- `\(Err_{XY}\)` is the <b><span style="color:#067373;">prediction error of the model which is fit on the training data set</span></b>.
- `\(Err\)` is the <b><span style="color:#067373;">average prediction error of the fitting algorithm over same-sized data sets drawn
from the same distribution `\(P\)`.</span></b>
- The **former quantity** is of the most interest to a practitioner **deploying a specific model**, whereas the **latter** may be of interest to a researcher **comparing different fitting algorithms**.
---
- Some earlier studies such as Zhang (1995), Hastie et al. (2009), and Yousef (2020)
have observed that the **cross-validation estimate provides little information** about `\(Err_{XY}\)`; this is known as the _weak correlation_ problem in the literature.
- For the special case of the linear model, <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> claim that CV estimate should be **considered as an estimate** of `\(Err\)` rather than `\(Err_{XY}\)`.
---
# `\(Err_{X}\)`: A different target of inference
<br>
- Assume a <b><span style="color:#067373;">homoskedastic Gaussian linear model</span></b> as follows:
`\begin{equation}
y_i = x_{i}^{T} \theta + \epsilon_i, \quad \epsilon_i \overset{i.i.d.}{\sim} N(0,\sigma^2), \quad i=1,\ldots,n.
\end{equation}`
- Define a new, intermediate key quantity as follows:
`\begin{equation}
Err_{X} := \mathbb{E}[Err_{XY} | X],
\end{equation}`
- which **falls between** `\(Err\)` and `\(Err_{XY}\)` as visualized below:
<br>
<div class="figure" style="text-align: center">
<img src="imgs/paper_fig2.png" alt="&lt;i&gt;Possible targets of inference for cross-validation.&lt;/i&gt;" width="60%" />
<p class="caption"><i>Possible targets of inference for cross-validation.</i></p>
</div>
---
$$ \newcommand\independent{\protect\mathpalette{\protect\independenT}{\perp}}
\def\independenT#1#2{\mathrel{\rlap{$#1#2$}\mkern2mu{#1#2}}}$$
- **Lemma 1:** When ordinary least squares (OLS) is used as the fitting
algorithm along with a squared-error loss function, the CV estimate of prediction error,
`\(\widehat {Err}^{(CV)}\)`, is linearly invariant.
--
- << Under this setting, the residuals turn out to be the same for both the **original data** `\((x_{1}, y_{1}),...,(x_{n}, y_{n})\)` and the **shifted data** `\((x_{1}, y_{1}^{'}),...,(x_{n}, y_{n}^{'})\)`, where `\(y_{i}^{'} = y_{i} + x_{i}^{T} \kappa\)`. Since the CV estimate of prediction error is the mean of the squared residuals, it is also **the same** for both data sets. >>
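---
- A quick numerical check of Lemma 1 in R (an illustrative sketch, not the authors' code): the CV estimate is unchanged when `\(y\)` is replaced by `\(y + X\kappa\)`.

```r
set.seed(1)
n <- 100; p <- 20; K <- 5
X <- matrix(rnorm(n * p), n, p)
y <- drop(X %*% rep(1, p)) + rnorm(n)
y_shift <- y + drop(X %*% rnorm(p))  # shifted responses y' = y + X kappa

cv_ols <- function(X, y, K) {        # K-fold CV with OLS and squared loss
  n <- nrow(X)
  folds <- split(sample(n), rep(1:K, each = n / K))
  mean(sapply(folds, function(idx) {
    th <- lm.fit(X[-idx, , drop = FALSE], y[-idx])$coefficients
    mean((y[idx] - drop(X[idx, , drop = FALSE] %*% th))^2)
  }))
}

set.seed(2); e1 <- cv_ols(X, y, K)        # same seed -> same folds
set.seed(2); e2 <- cv_ols(X, y_shift, K)
all.equal(e1, e2)                          # TRUE: linear invariance
```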
---
- **Theorem 1:** Assume the homoskedastic Gaussian linear model holds and that we use the
squared-error loss function. Let `\(\widehat {Err}\)` be a **linearly invariant estimate of prediction
error** (such as `\(\widehat {Err}^{(CV)}\)` using OLS as the fitting algorithm). Then,
`\begin{equation}
Err_{XY} \perp \widehat {Err} \quad | \quad X.
\end{equation}`
--
- << Recall from classical linear regression theory that when using OLS, the **estimated coefficient** vector `\(\widehat {\theta}\)` is independent of the **residuals** `\((Y-X\widehat {\theta})\)`:
`\begin{equation}
\widehat {\theta} \perp (Y-X\widehat {\theta}) \quad | \quad X.
\end{equation}`
- Since `\(Err_{XY}\)` is a function only of `\(\widehat {\theta}\)`, the OLS estimate of `\(\theta\)`, and any linearly invariant estimate of prediction error `\(\widehat{Err}\)` is a function only of the residuals `\(Y-X\widehat {\theta}\)`, it follows that `\(Err_{XY} \perp \widehat{Err} \quad | \quad X\)`. >>
---
- **Corollary 1:** Under the conditions of Theorem 1, we get the following decomposition:
`\begin{equation}
\mathbb{E}\Big[\big(\widehat{Err}-Err_{XY}\big)^2\Big] = \mathbb{E}\Big[\big(\widehat{Err}-Err_{X}\big)^2\Big] + \mathbb{E}\Big[Var\big(Err_{XY}|X\big)\Big].
\end{equation}`
--
- << Any linearly invariant estimator (such as cross-validation) has **larger mean squared error** (MSE) as an estimate of `\(Err_{XY}\)` than as an estimate of `\(Err_{X}\)`. >>
- This implies that `\(\widehat{Err}^{(CV)}\)` is a better estimate of the intermediate quantity `\(Err_{X}\)` than of `\(Err_{XY}\)`.
<br>
<div class="figure" style="text-align: center">
<img src="imgs/paper_fig2.png" alt="&lt;i&gt;Possible targets of inference for cross-validation.&lt;/i&gt;" width="60%" />
<p class="caption"><i>Possible targets of inference for cross-validation.</i></p>
</div>
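---
- **Why the decomposition holds:** conditioning on `\(X\)` and expanding around `\(Err_X = \mathbb{E}[Err_{XY}|X]\)`,
`\begin{align}
\mathbb{E}\big[(\widehat{Err}-Err_{XY})^2 \,|\, X\big]
&= \mathbb{E}\big[(\widehat{Err}-Err_{X})^2 \,|\, X\big] + Var\big(Err_{XY}\,|\,X\big) \nonumber \\
&\quad + 2\,\mathbb{E}\big[\widehat{Err}-Err_{X} \,|\, X\big]\;\mathbb{E}\big[Err_{X}-Err_{XY} \,|\, X\big]. \nonumber
\end{align}`
- The cross term factorizes by Theorem 1 ( `\(Err_{XY} \perp \widehat{Err} \,|\, X\)` ) and vanishes because `\(\mathbb{E}[Err_{X}-Err_{XY}\,|\,X]=0\)`; taking the expectation over `\(X\)` gives Corollary 1.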
---
# Example
<br>
- Consider an experiment with a simple linear model with `\(n=100\)` observations
and `\(p=20\)` features `\(\overset{i.i.d.}{\sim} N(0,1)\)`, replicated `\(2000\)` times.
- The **MSE** of `\(\widehat {Err} ^{(CV)}\)` is evaluated relative to **three estimands**: `\(Err\)`, `\(Err_{X}\)`, and `\(Err_{XY}\)`.
<img src="imgs/paper_fig3a.png" width="30%" style="display: block; margin: auto;" />
- **Side note:** Each pair of points connected by a line represents the 2000 replicates with the same feature matrix `\(X\)`.
---
- We see that `\(\widehat {Err}^{(CV)}\)` has **lower MSE** for `\(Err_{X}\)` than for `\(Err_{XY}\)`.
- These results suggest that `\(Err_{X}\)` is a more **natural target of inference** (estimand) than `\(Err_{XY}\)` for `\(\widehat {Err} ^{(CV)}\)`.
<br>
<img src="imgs/paper_fig3a.png" width="30%" style="display: block; margin: auto;" />
---
# Relationship between `\(Err\)` and `\(Err_{X}\)`
<br>
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> further investigate the **relationship** between `\(Err\)` and `\(Err_{X}\)` through asymptotic analysis.
<br>
<div class="figure" style="text-align: center">
<img src="imgs/paper_fig2.png" alt="&lt;i&gt;Possible targets of inference for cross-validation.&lt;/i&gt;" width="60%" />
<p class="caption"><i>Possible targets of inference for cross-validation.</i></p>
</div>
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> show that the variance of `\(Err_{X}\)` (which has mean `\(Err\)`) is **small** compared with the variance of `\(Err_{XY}\)` (which also has mean `\(Err\)`), implying that `\(Err_{X}\)` is **close** to `\(Err\)`.
---
- Note that in the example, MSE of `\(\widehat {Err} ^{(CV)}\)` is **similar** when estimating either `\(Err\)` or `\(Err_{X}\)`, but **significantly different** when estimating `\(Err_{XY}\)`.
<img src="imgs/paper_fig3a.png" width="30%" style="display: block; margin: auto;" />
---
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> show that `\( \widehat {Err} ^{(CV)}\)`
is **closer** to `\(Err\)` and `\(Err_{X}\)` than to `\(Err_{XY}\)` in the proportional asymptotic limit (for `\(n > p\)`, as `\(n, p \rightarrow \infty\)` with `\(n/p \rightarrow \lambda >1\)`).
- Combined with the earlier results, this implies that `\( \widehat {Err} ^{(CV)}\)` is a **better estimator** for `\(Err\)` than for `\(Err_{XY}\)`.
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> also show that `\( \widehat{Err}^{(CV)}\)` is **asymptotically uncorrelated** with `\(Err_{XY}\)`.
---
# Dependence structure of CV errors
<br>
- Let
`\(e_i =\ell \big(\hat{f}(x_{i},\hat \theta^{(-k)}),y_i\big)\)` be the **error** for each `\(i \in \mathcal{I}_k\)` ( `\(k=1,\ldots,K\)`), resulting in `\(m\)` different `\(e_i\)`'s for each fold `\(\mathcal{I}_k\)`.
- Then, we can **re-define** <b><span style="color:#067373;">CV point estimate of prediction error</span></b> as the **average of errors** as follows:
`\begin{equation}
\widehat {Err } ^{(CV)}:= \frac{1} {K} \sum_{k=1}^{K} \frac{1} {m}\sum_{{i} \in \mathcal{I}_k} \ell \big(\hat{f}(x_{i},\hat \theta^{(-k)}),y_i\big)
= \frac{1} {n} \sum_{i=1}^{n}e_i
= \bar e,
\end{equation}`
- where `\(n = K \times m\)`.
---
- If the `\(e_i\)`'s were i.i.d., the <b><span style="color:#067373;">estimate of the standard error of the CV point estimate of prediction error</span></b> would be:
`\begin{equation}
\widehat{se}^{(CV)}:=\frac{1} {\sqrt n} \times \sqrt {\frac{1} {n-1} \sum_{i=1}^{n}(e_i-\bar e)^2}, \nonumber
\end{equation}`
- where the second term in the multiplication refers to the <b><span style="color:#067373;">empirical standard deviation</span></b> of the `\(e_i\)`.
---
# A `\(100(1-\alpha)\%\)` confidence interval for prediction error
<br>
- A `\(100(1-\alpha)\%\)` <b><span style="color:#067373;">confidence interval for prediction error</span></b> can be constructed as follows:
`\begin{equation}
\Big(\widehat {Err}^{(CV)} - z_{1-(\frac{\alpha}{2})} \times \widehat{se}^{(CV)} \quad, \quad \widehat {Err}^{(CV)} + z_{1-(\frac{\alpha}{2})} \times \widehat{se}^{(CV)} \Big),
\end{equation}`
- where `\(0 < \alpha < 1\)` and `\(z_{1- (\frac{\alpha}{2}) }\)` is the `\(1- (\frac{\alpha}{2})\)` quantile of the standard normal distribution.
- These intervals are called <b><span style="color:#067373;">naive cross-validation intervals</span></b>.
- However, since every data point is used both in training and in testing, we **cannot assume** that the `\(e_i\)`'s are **independent** of each other.
- Any **confidence interval** built on this assumption would have **poor coverage**.
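---
- A minimal R sketch of the naive interval (illustrative); note that it treats the `\(e_i\)` as if they were independent:

```r
set.seed(1)
n <- 100; p <- 20; K <- 5; alpha <- 0.1
X <- matrix(rnorm(n * p), n, p)
y <- drop(X %*% rep(1, p)) + rnorm(n)

e <- numeric(n)  # pointwise CV errors e_i
folds <- split(sample(n), rep(1:K, each = n / K))
for (idx in folds) {
  th <- lm.fit(X[-idx, , drop = FALSE], y[-idx])$coefficients
  e[idx] <- (y[idx] - drop(X[idx, , drop = FALSE] %*% th))^2
}

err_cv <- mean(e)
se_naive <- sd(e) / sqrt(n)  # ignores the within- and between-fold correlations
z <- qnorm(1 - alpha / 2)
c(lower = err_cv - z * se_naive, upper = err_cv + z * se_naive)
```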
---
# Example re-visited
<br>
- The <b><span style="color:#067373;">naive cross-validation intervals</span></b> for **three estimands**: `\(Err\)`, `\(Err_{X}\)`, and `\(Err_{XY}\)` are built and miscoverage rates are estimated.
<img src="imgs/paper_fig3b.png" width="30%" style="display: block; margin: auto;" />
- The naive CV intervals **fail** due to large miscoverage rates.
- **Side note:** The nominal miscoverage rate is `\(10\%\)`.
---
<style type="text/css">
.pull-left {
float: left;
width: 50%;
}
.pull-right {
float: right;
width: 50%;
}
</style>
.pull-left[
- The foundational paper of <b><span style="color:#067373;">Bengio and Grandvalet (2004)</span></b> characterizes the structure of the covariance matrix of the `\(e_i\)`'s, which yields:
`\begin{equation}
Var(\widehat {Err}^{(CV)}) = \frac{1}{n}a_1 + \frac{n/K-1}{n}a_2 + \frac{n-n/K}{n}a_3,
\end{equation}`
- where `\(a_1=Var(e_i)\)` is the **common variance** (the diagonal elements),
- `\(a_2 = Cov(e_i, e_j)\)` is the **covariance of the off-diagonal elements within the same fold** (in-block covariance of errors, due to a common training set), and
- `\(a_3 = Cov(e_i, e_j)\)` is the **covariance between blocks**, due to the dependence between training sets `\(\mathcal{I}_k\)` `\((k=1,\ldots,K)\)`.
]
.pull-right[
<div class="figure" style="text-align: center">
<img src="imgs/covariance_structure.png" alt="&lt;i&gt;Structure of the covariance matrix of errors.&lt;/i&gt;" width="60%" />
<p class="caption"><i>Structure of the covariance matrix of errors.</i></p>
</div>
<div style="text-align:right;"><font size="4px"><a href="https://www.jmlr.org/papers/volume5/grandvalet04a/grandvalet04a.pdf">Image Source</a></font></div>
]
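---
- A quick numerical check of this identity in R, building the block covariance matrix above from assumed (illustrative) values of `\(a_1, a_2, a_3\)`:

```r
n <- 20; K <- 4; m <- n / K
a1 <- 1.0; a2 <- 0.3; a3 <- 0.1               # assumed (illustrative) values
fold_id <- rep(1:K, each = m)
Sigma <- matrix(a3, n, n)                     # between-fold covariance a3
Sigma[outer(fold_id, fold_id, "==")] <- a2    # within-fold covariance a2
diag(Sigma) <- a1                             # variances a1 on the diagonal
var_mean <- sum(Sigma) / n^2                  # Var(e_bar) = 1' Sigma 1 / n^2
all.equal(var_mean,
          a1 / n + (n / K - 1) / n * a2 + (n - n / K) / n * a3)  # TRUE
```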
---
- The constants `\(a_2\)` and `\(a_3\)` will typically be positive, in which case:
`\begin{equation}
Var(\widehat {Err}^{(CV)}) > \frac{1}{n}a_1,
\end{equation}`
- The naive cross-validation intervals **implicitly assume** `\(a_2=0\)` and `\(a_3=0\)`.
- Hence, estimating the variance of `\(\widehat{Err}^{(CV)}\)` by `\(\widehat{se}^2\)` yields an estimate that is too **small** and, in turn, **poor coverage**.
---
# Target of inference: Confidence intervals
<br>
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b>
develop an **estimator** that **empirically estimates the variance** of `\(\widehat{Err}^{(CV)}\)` across many subsamples.
- **Definition 2:** For a sample of size n split into K folds, the cross-validation MSE is:
`\begin{equation}
MSE_{K,n} := \mathbb{E}\big[ (\widehat {Err}^{(CV)} - Err_{XY})^2 \big].
\end{equation}`
- The MSE contains both a **bias term** and a **variance term**, but the **bias** is typically
**small** for cross-validation (Efron, 1983; Efron and Gong, 1983; Efron and Tibshirani, 1997).
- **MSE** can be viewed as a slightly conservative version of the **variance** of the cross-validation estimator.
- The estimate of the MSE can then be used to construct confidence intervals for `\(Err_{XY}\)`, the quantity that
typically matters for practical applications.
---
- **Lemma 2:** For a single split, randomly partition the data into a training set consisting of `\(K-1\)` folds, denoted by
`\(\mathcal{I}_{(train)} = \cup_{k=1}^{K-1} \mathcal{I}_k=(\tilde X, \tilde Y)\)`, and call the remaining fold
`\(\mathcal{I}_{(out)}\)`. Using only `\((\tilde X, \tilde Y)\)`, define the prediction error `\(Err_{\tilde X, \tilde Y}\)`
and an estimator `\(\widehat {Err}_{\tilde X, \tilde Y}\)` (such as cross-validation) as usual. For the hold-out data set, calculate the errors `\(\{e^{(out)}_i\}_{i \in \mathcal{I}_{(out)}}\)` and their average
`\(\bar e^{(out)}\)`.
- Then the MSE satisfies the following identity, whose right-hand-side terms can be estimated from the data:
`\begin{equation}
\mathbb{E}\big[ (\widehat {Err}_{\tilde X, \tilde Y} - Err_{\tilde X, \tilde Y})^2 \big] = \mathbb{E}\big[ (\widehat {Err}_{\tilde X, \tilde Y} -\bar e^{(out)})^2 \big] - \mathbb{E}\big[ (\bar e^{(out)} - Err_{\tilde X, \tilde Y})^2 \big].
\end{equation}`
<img src="imgs/paper_fig8.png" width="60%" style="display: block; margin: auto;" />
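---
- **Why Lemma 2 holds:** `\(\widehat{Err}_{\tilde X, \tilde Y}\)` depends only on the training data, and the hold-out points are fresh draws from `\(P\)`, so `\(\mathbb{E}\big[\bar e^{(out)} \,|\, \tilde X, \tilde Y\big] = Err_{\tilde X, \tilde Y}\)`. Hence the cross term in
`\begin{align}
\mathbb{E}\big[ (\widehat {Err}_{\tilde X, \tilde Y} -\bar e^{(out)})^2 \big]
= \mathbb{E}\big[ (\widehat {Err}_{\tilde X, \tilde Y} - Err_{\tilde X, \tilde Y})^2 \big] + \mathbb{E}\big[ (Err_{\tilde X, \tilde Y} - \bar e^{(out)})^2 \big] \nonumber
\end{align}`
vanishes, and rearranging gives Lemma 2; both expectations on its right-hand side can be estimated from the data, which is what the nested CV algorithm on the next slide does.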
---
# Nested cross-validation (NCV) estimate of MSE
#### <b><span style="color:#067373;">Nested CV algorithm</span></b>
- **Repeatedly** split the data into `\(K\)` folds, with `\(K-1\)` folds forming `\(\mathcal{I}_{(train)}\)` and the remaining
one being `\(\mathcal{I}_{(out)}\)`.
- For each split `\(j\)`:
- Compute `\(\epsilon_j := \widehat {Err}_{\tilde X \tilde Y}\)` by `\((K-1)\)`-fold
cross-validation over the `\(K-1\)` folds in `\(\mathcal{I}_{(train)}\)` (each inner model is trained on `\(K-2\)` folds).
- Train the model on `\(\mathcal{I}_{(train)}\)` and compute the errors `\(e_{i}\)` for all data points `\((x_i,y_i) \in \mathcal{I}_{(out)}\)`.
- Compute `\(\bar e^{(out)} := \text{mean of } \{e_i\}_{i \in \mathcal{I}_{(out)}}\)`.
- Set `\(a_j := (\widehat {Err}_{\tilde X \tilde Y} - \bar e^{(out)})^2\)` (estimate of the first term on the RHS of Lemma 2).
- Set `\(b_j := (\text{empirical variance of } \{e_i\}_{i \in \mathcal{I}_{(out)}})/m\)` (estimate of the second term, the variance of `\(\bar e^{(out)}\)`).
- Output `\(\widehat{MSE} := \text{mean}(a_j)-\text{mean}(b_j)\)`.
- Output `\(\widehat {Err}^{(NCV)} := \text{mean}(\epsilon_j)\)`.
<img src="imgs/paper_fig8.png" width="40%" style="display: block; margin: auto;" />
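---
- A minimal R sketch of the algorithm above (illustrative, not the authors' implementation: OLS with squared-error loss, one hold-out fold per split, `\(n\)` divisible by `\(K\)`; `\(b_j\)` is divided by the hold-out size so that it estimates the variance of `\(\bar e^{(out)}\)`):

```r
set.seed(1)
n <- 100; p <- 20; K <- 5; n_splits <- 200
X <- matrix(rnorm(n * p), n, p)
y <- drop(X %*% rep(1, p)) + rnorm(n)

sq_err <- function(tr, te) {   # pointwise squared errors on te, model fit on tr
  th <- lm.fit(X[tr, , drop = FALSE], y[tr])$coefficients
  (y[te] - drop(X[te, , drop = FALSE] %*% th))^2
}

a <- b <- eps <- numeric(n_splits)
for (j in 1:n_splits) {
  folds <- split(sample(n), rep(1:K, each = n / K))
  out <- folds[[K]]                      # I_out: the held-out fold
  train <- setdiff(1:n, out)             # I_train: the other K-1 folds

  # eps_j: CV estimate over the K-1 training folds (inner fits use K-2 folds)
  eps[j] <- mean(unlist(lapply(folds[1:(K - 1)], function(idx)
    sq_err(setdiff(train, idx), idx))))

  e_out <- sq_err(train, out)            # errors on the hold-out fold
  a[j] <- (eps[j] - mean(e_out))^2       # estimates 1st RHS term of Lemma 2
  b[j] <- var(e_out) / length(e_out)     # estimates 2nd RHS term (Var of e_bar)
}
mse_hat <- mean(a) - mean(b)             # MSE_hat
err_ncv <- mean(eps)                     # Err_hat^(NCV)
```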
---
- The nested CV algorithm provides both a **point estimate** of prediction error, denoted `\(\widehat {Err}^{(NCV)}\)`, and an **estimate of the MSE**, denoted `\(\widehat{MSE}\)`.
- **Theorem 2:** For nested CV with a sample of size `\(n\)`:
`\begin{equation}
\mathbb{E}\big[ \widehat{MSE} \big] = MSE_{K-1,\,(K-1)n/K},
\end{equation}`
- where `\(n/K\)` is the fold size.
- Since the estimation is done over `\(K-1\)` folds, `\(\widehat{MSE}\)` estimates the actual quantity of interest `\(MSE_{K,n}\)` with some **bias**.
- `\(\widehat{MSE}\)` is **rescaled** to obtain an unbiased estimate of `\(MSE_{K,n}\)`.
- Similarly, `\(\widehat {Err}^{(NCV)}\)` is also adjusted (**de-biased**).
---
# A `\(100(1-\alpha)\%\)` confidence interval
<br>
- Finally, a `\(100(1-\alpha)\%\)` confidence interval is obtained as follows:
`\begin{equation}
\Big(\widehat {Err}^{(NCV)} - \widehat {bias} - z_{1-(\frac{\alpha}{2})} \times \widehat{se}^{(NCV)} \quad, \quad \widehat {Err}^{(NCV)} - \widehat {bias} + z_{1-(\frac{\alpha}{2})} \times \widehat{se}^{(NCV)} \Big),
\end{equation}`
- where
`\begin{equation}
\widehat {bias} := \Big(1 + \big(\frac{K-2}{K}\big)\Big) \big(\widehat {Err}^{(NCV)} - \widehat {Err}^{(CV)}\big) \quad \text{and} \quad
\widehat{se}^{(NCV)} :=\sqrt{\frac{K-1}{K}}\sqrt{\widehat {MSE}}.
\end{equation}`
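---
- Putting the pieces together in R (a sketch continuing the one above; `err_ncv`, `err_cv`, and `mse_hat` are assumed to have been computed already, and the numbers in the usage line are hypothetical):

```r
ncv_interval <- function(err_ncv, err_cv, mse_hat, K, alpha = 0.1) {
  bias_hat <- (1 + (K - 2) / K) * (err_ncv - err_cv)   # bias estimate
  se_ncv <- sqrt((K - 1) / K) * sqrt(mse_hat)          # rescaled standard error
  z <- qnorm(1 - alpha / 2)
  c(lower = err_ncv - bias_hat - z * se_ncv,
    upper = err_ncv - bias_hat + z * se_ncv)
}
ncv_interval(err_ncv = 1.30, err_cv = 1.25, mse_hat = 0.004, K = 5)
```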
---
# Simulation experiments: Data generation scenario
<br>
- The **coverage** of the nested CV intervals is investigated for **classification** and **regression** problems on synthetic data sets (and real data sets).
- Consider a **sparse logistic data generating model**:
`\begin{align}
Pr(Y_i = 1 | X_i = x_i) = \frac{1}{1 + \exp(-x_{i}^{T}\theta)}, \quad i=1,\ldots,n,
\end{align}`
- where `\(n\)` is the number of observations, `\(p\)` is the number of features,
- `\(X_i\)` is the feature vector, consisting of i.i.d. standard Gaussian entries,
- the coefficient vector is
`\(\theta = c \times (1,1,1,1,0,0,...)^{T} \in \mathbb{R}^{p}\)`, and
- `\(c\)` is chosen such that the <b><span style="color:#067373;">Bayes error</span></b>, the **optimal lower bound** for `\(Err\)`, is either `\(33.2\%\)` or `\(22.5\%\)`.
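---
- Generating one synthetic data set from this model in R (a sketch; here the constant `c0` is set directly rather than calibrated to a target Bayes error):

```r
set.seed(1)
n <- 100; p <- 20; c0 <- 1
theta <- c0 * c(rep(1, 4), rep(0, p - 4))     # sparse coefficient vector
X <- matrix(rnorm(n * p), n, p)               # i.i.d. standard Gaussian features
prob <- 1 / (1 + exp(-drop(X %*% theta)))     # Pr(Y_i = 1 | X_i = x_i)
y <- rbinom(n, size = 1, prob = prob)         # binary responses
```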
---
# Simulation experiments: Performance metrics
<br>
- In each case:
- The **miscoverage** of the naive CV (CV) and nested CV (NCV) intervals is reported, where
the **nominal miscoverage** rate is `\(10\%\)`.
- The **width** of the NCV intervals is expressed relative to the width of the CV intervals.
- 10-fold CV and 10-fold NCV with `\(200\)` random splits are used.
- R scripts reproducing the experiments are available at: https://github.com/stephenbates19/nestedcv_experiments.
---
# Simulation results: Low-dimensional setting
<br>
- `\(n=100\)`, `\(p=20\)`, and (un-regularized) logistic regression is used as the fitting algorithm.
- Nested CV gives **coverage closer** to the nominal level.
<img src="imgs/result1.png" width="100%" style="display: block; margin: auto;" />
- A "Hi" miscoverage is one where the **confidence interval is too large** and the point estimate falls below the interval; conversely for a "Lo" miscoverage.
---
# Simulation results: High-dimensional setting
<br>
- `\(n=\{90,200\}\)`, `\(p=100\)`, and `\(\ell_1\)`-penalized logistic regression with a fixed
penalty level is used as the fitting algorithm.
- Nested CV gives **coverage closer** to the nominal level.
<img src="imgs/result2.png" width="100%" style="display: block; margin: auto;" />
---
# Discussion & Future work
<br>
- Nested CV is more **computationally intensive** than standard CV, but parallel computation can be used to speed it up.
- It is well known that CV is also commonly used for selecting good values of a learning algorithm's hyperparameters (tuning).
- <b><span style="color:#067373;">Bates, Hastie, and Tibshirani (2021)</span></b> expect that NCV would be of use for hyperparameter selection since it yields more accurate confidence intervals for prediction error.
---
class: center, middle
<br> Thank you! <br>
</textarea>
<style data-target="print-only">@media screen {.remark-slide-container{display:block;}.remark-slide-scaler{box-shadow:none;}}</style>
<script src="https://remarkjs.com/downloads/remark-latest.min.js"></script>
<script>var slideshow = remark.create({
"highlightStyle": "atelier-forest-light",
"highlightLines": true,
"highlightSpans": true,
"countIncrementalSlides": true,
"ratio": "16:9"
});
if (window.HTMLWidgets) slideshow.on('afterShowSlide', function (slide) {
window.dispatchEvent(new Event('resize'));
});
(function(d) {
var s = d.createElement("style"), r = d.querySelector(".remark-slide-scaler");
if (!r) return;
s.type = "text/css"; s.innerHTML = "@page {size: " + r.style.width + " " + r.style.height +"; }";
d.head.appendChild(s);
})(document);
(function(d) {
var el = d.getElementsByClassName("remark-slides-area");
if (!el) return;
var slide, slides = slideshow.getSlides(), els = el[0].children;
for (var i = 1; i < slides.length; i++) {
slide = slides[i];
if (slide.properties.continued === "true" || slide.properties.count === "false") {
els[i - 1].className += ' has-continuation';
}
}
var s = d.createElement("style");
s.type = "text/css"; s.innerHTML = "@media print { .has-continuation { display: none; } }";
d.head.appendChild(s);
})(document);
// delete the temporary CSS (for displaying all slides initially) when the user
// starts to view slides
(function() {
var deleted = false;
slideshow.on('beforeShowSlide', function(slide) {
if (deleted) return;
var sheets = document.styleSheets, node;
for (var i = 0; i < sheets.length; i++) {
node = sheets[i].ownerNode;
if (node.dataset["target"] !== "print-only") continue;
node.parentNode.removeChild(node);
}
deleted = true;
});
})();
(function() {
"use strict"
// Replace <script> tags in slides area to make them executable
var scripts = document.querySelectorAll(
'.remark-slides-area .remark-slide-container script'
);
if (!scripts.length) return;
for (var i = 0; i < scripts.length; i++) {
var s = document.createElement('script');
var code = document.createTextNode(scripts[i].textContent);
s.appendChild(code);
var scriptAttrs = scripts[i].attributes;
for (var j = 0; j < scriptAttrs.length; j++) {
s.setAttribute(scriptAttrs[j].name, scriptAttrs[j].value);
}
scripts[i].parentElement.replaceChild(s, scripts[i]);
}
})();
(function() {
var links = document.getElementsByTagName('a');
for (var i = 0; i < links.length; i++) {
if (/^(https?:)?\/\//.test(links[i].getAttribute('href'))) {
links[i].target = '_blank';
}
}
})();
// adds .remark-code-has-line-highlighted class to <pre> parent elements
// of code chunks containing highlighted lines with class .remark-code-line-highlighted
(function(d) {
const hlines = d.querySelectorAll('.remark-code-line-highlighted');
const preParents = [];
const findPreParent = function(line, p = 0) {
if (p > 1) return null; // traverse up no further than grandparent
const el = line.parentElement;
return el.tagName === "PRE" ? el : findPreParent(el, ++p);
};
for (let line of hlines) {
let pre = findPreParent(line);
if (pre && !preParents.includes(pre)) preParents.push(pre);
}
preParents.forEach(p => p.classList.add("remark-code-has-line-highlighted"));
})(document);</script>
<script>
slideshow._releaseMath = function(el) {
var i, text, code, codes = el.getElementsByTagName('code');
for (i = 0; i < codes.length;) {
code = codes[i];
if (code.parentNode.tagName !== 'PRE' && code.childElementCount === 0) {
text = code.textContent;
if (/^\\\((.|\s)+\\\)$/.test(text) || /^\\\[(.|\s)+\\\]$/.test(text) ||
/^\$\$(.|\s)+\$\$$/.test(text) ||
/^\\begin\{([^}]+)\}(.|\s)+\\end\{[^}]+\}$/.test(text)) {
code.outerHTML = code.innerHTML; // remove <code></code>
continue;
}
}
i++;
}
};
slideshow._releaseMath(document);
</script>
<!-- dynamically load mathjax for compatibility with self-contained -->
<script>
(function () {
var script = document.createElement('script');
script.type = 'text/javascript';
script.src = 'https://mathjax.rstudio.com/latest/MathJax.js?config=TeX-MML-AM_CHTML';
if (location.protocol !== 'file:' && /^https?:/.test(script.src))
script.src = script.src.replace(/^https?:/, '');
document.getElementsByTagName('head')[0].appendChild(script);
})();
</script>
</body>
</html>