-
Notifications
You must be signed in to change notification settings - Fork 2
/
validateit.ado
520 lines (348 loc) · 14.3 KB
/
validateit.ado
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
/*******************************************************************************
* *
* Handles the validation/testing part of the process *
* *
*******************************************************************************/
*! validateit
*! v 0.0.16
*! 05mar2024
// Drop program from memory if already loaded
cap prog drop validateit

// validateit
//
// Computes a single validation metric, plus any optional monitors, on the
// held-out portion(s) of a split dataset and returns the results.
//
// Options
//   metric(string)   : one metric (a Mata function name with optional args)
//   pstub(string)    : name (stub) of the variable holding the predictions; for
//                      K-Fold/LOO the variable `pstub'all is also used unless
//                      noall is specified
//   split(varname)   : variable identifying the split/fold of each observation
//   obs(varname)     : observed outcome; defaults to e(depvar) when omitted
//   monitors(string) : optional list of monitor functions; may not contain the
//                      metric
//   display          : print the table of results
//   kfold(integer)   : number of folds in the training set (default 1)
//   noall            : skip the metrics computed from `pstub'all
//   loo              : leave-one-out cross-validation (requires kfold > 1)
//   name(string)     : name of the collection used in Stata 17+ (default xvval)
//
// Returns
//   r(<name>)   : one scalar per name reported by getstats
//   r(allnames) : the names of all scalars returned
//   r(xv)       : matrix of results (rows = monitors/metric, cols = samples)
//
// Errors: 134 (metric in monitors, or more than one metric), 100 (no outcome
// variable available), 198 (invalid kfold or kfold/loo combination), 111
// (`pstub'all not found).
prog def validateit, rclass

    // Version statement
    version 15

    // Syntax
    syntax , MEtric(string asis) PStub(string asis) SPLit(varname)           ///
        [ Obs(varname) MOnitors(string asis) DISplay KFold(integer 1) noall loo ///
        NAme(string asis) ]

    // Test to ensure the metric is not included in the monitor
    if `: list metric in monitors' {

        // Display an informative message
        di as error "The metric `metric' is included in the monitors `monitors'."

        // Throw an error code
        err 134

    } // End IF Block to handle metric included in monitors

    // Test if missing observed outcome variable name
    if mi("`obs'") & mi("`e(depvar)'") {

        // Display an error message
        di as err "If the dependent variable is not passed to {opt obs} it " ///
        "must be accessible in e(depvar)."

        // Throw an error code and exit
        err 100

    } // End IF Block for unknown dependent variable

    // If no argument is passed to the option but it is found in e(depvar)
    else if mi("`obs'") & !mi("`e(depvar)'") loc obs `e(depvar)'

    // Test for invalid KFold option
    if `kfold' < 1 {

        // Display an error message
        di as err "There must always be at least 1 K-Fold. This would be " ///
        "the training set in a simple train/test split. You specified " ///
        "`kfold' K-Folds."

        // Return error code and exit
        err 198

    } // End IF Block for invalid K-Fold argument

    // Test for invalid kfold with loo option
    if `kfold' == 1 & !mi("`loo'") {

        // Display an error message
        di as err "Leave-One-Out cross-validation cannot be used with a " ///
        "single K-Fold."

        // Return error code and exit
        err 198

    } // End IF block for invalid kfold & loo combination

    // When using K-Fold and not specifying noall
    if `kfold' > 1 & mi(`"`all'"') {

        // Capture the code from confirming the *all variable's presence
        cap: confirm v `pstub'all

        // If this fails
        if _rc != 0 {

            // Print an error message to the console
            di as err "The variable `pstub'all was not found and you are " ///
            "requesting evaluating metrics that require that variable." _n ///
            "You can either pass the noall option or predict the values " ///
            "from your models again to generate the `pstub'all variable."

            // Throw an error code and exit
            err 111

        } // End IF Block for missing `pstub'all variable

    } // End IF Block for detecting missing `pstub'all w/K-Fold and missing noall

    // Parse the metric option
    _parse_monitors `metric'

    // Verify that there is only a single metric
    if `r(n)' > 1 {

        // Display an error message
        di as err "Users can only specify a single metric."

        // Throw an error code
        err 134

    } // End IF Block for invalid number of metric

    // Create macro to store all returned scalar names
    loc allnms

    // Mark the sample that will be used to compute the validation metrics for
    // each K-Fold
    tempvar touse

    // Create the tempvariable used to identify the set to use for validation
    qui: g byte `touse' = 0

    // Figure out the number of splits used in the dataset
    mata: st_numscalar("vals", rows(uniqrows(st_data(., "`split'"))))

    // There will be two ID values > kfold in a TVT split.  The count lives in
    // the Stata scalar vals (not in a local macro), so it must be read with
    // scalar(); scalar() also keeps a dataset variable named vals from being
    // picked up instead.
    if scalar(vals) - `kfold' == 2 loc ditxt "Validation Set"

    // Otherwise it should be a TT split
    else loc ditxt "Test Set"

    // Check if the name parameter is missing or not
    if mi(`"`name'"') loc name xvval

    // Create a collection using the default name
    if `c(stata_version)' >= 17 qui: collect create `name', replace

    // Locate the labels for the metrics
    cap: findfile xvlabels.stjson

    // If the file is located
    if _rc == 0 & `c(stata_version)' >= 17 {

        // Load the capture labels
        collect label use `"`r(fn)'"', name(`name')

    } // End IF Block to load collection labels for validation metrics

    // If there is only a single fold
    if `kfold' == 1 & mi("`loo'") {

        // Set the touse tempvariable
        qui: replace `touse' = cond(`split' == 2, 1, 0)

        // Calls subroutine to compute all of the validation metrics/monitors
        // and return them
        getstats, me(`metric') p(`pstub') o(`obs') t(`touse') st(xv) ///
        monitors(`monitors')

        // Adds the names so all monitor/metric names can be returned
        loc allnms `r(names)'

        // Loop over the returned names
        foreach i in `r(names)' {

            // Return the corresponding scalars
            ret sca `i' = r(`i')

        } // End Loop over the returned scalars

        // Return the matrix with all of the results
        matrix res = r(mtrx)

        // Set the rownames
        mat rownames res = `r(names)'

        // Set the column name
        mat colnames res = "`ditxt'"

    } // End IF Block for no-K-Folds

    // If this involves K-Fold CV
    else if `kfold' > 1 & mi("`loo'") {

        // Start with no column names; one quoted name is appended per fold
        loc colnms

        // Loop over the K-Folds
        forv k = 1/`kfold' {

            // Sets local macro with column names
            loc colnms `"`colnms' "Fold `k'""'

            // Set the value of the touse tempvariable
            qui: replace `touse' = cond(`split' == `k', 1, 0)

            // Calls subroutine to compute all of the validation metrics/monitors
            // and return them
            getstats, me(`metric') p(`pstub') o(`obs') t(`touse') st(xv) ///
            monitors(`monitors') sf(`k')

            // Appends this fold's names; the names carry the fold suffix, so
            // they must accumulate rather than replace the earlier folds
            loc allnms `allnms' `r(names)'

            // Loop over the returned names
            foreach i in `r(names)' {

                // Return the corresponding scalars
                ret sca `i' = r(`i')

            } // End Loop over the returned scalars

            // Gets the matrix returned by getstats
            if `k' == 1 mat res = r(mtrx)

            // Return the matrix with all of the results
            else mat res = (res, r(mtrx))

            // If the user does not specify noall
            if `k' == `kfold' & mi(`"`all'"') {

                // Adds the last column name
                loc colnms `"`colnms' "`ditxt'""'

                // Update the variable that IDs the sample to use for the metrics
                qui: replace `touse' = cond(`split' == `= `kfold' + 1', 1, 0)

                // Call the subroutine with modified arguments (note the use of all)
                getstats, me(`metric') p(`pstub'all) o(`obs') t(`touse') st(xv) ///
                monitors(`monitors') sf(all)

                // Adds the names of these scalars to the allnms macro
                loc allnms `allnms' `r(names)'

                // Loop over the returned scalar names
                foreach i in `r(names)' {

                    // Return those scalars
                    ret sca `i' = r(`i')

                } // End Loop over the returned scalars

                // Update the matrix to include the additional results from the
                // validation/test split
                matrix res = (res, r(mtrx))

            } // End IF Block to compute metrics on the validation/test split

        } // End Loop over K-Folds

        // Set rownames for the returned matrix based on the monitors/metrics
        // (r(names) still holds the names from the last getstats call)
        mat rownames res = `r(names)'

        // Set the column names for the returned matrix based on the number of
        // K-Folds and what style of split is used
        mat colnames res = `colnms'

    } // End ELSE Block for K-Fold CV

    // Otherwise it will be for leave-one-out CV
    else if `kfold' > 1 & !mi("`loo'") {

        // The first (and possibly only) column holds the LOO results
        loc colnms `""Leave-One-Out""'

        // Set the value of the touse tempvariable
        qui: replace `touse' = cond(`split' <= `kfold', 1, 0)

        // Calls subroutine to compute all of the validation metrics/monitors
        // and return them
        getstats, me(`metric') p(`pstub') o(`obs') t(`touse') st(xv) ///
        monitors(`monitors') sf(1)

        // Adds the names so all monitor/metric names can be returned
        loc allnms `r(names)'

        // Loop over the returned names
        foreach i in `r(names)' {

            // Return the corresponding scalars
            ret sca `i' = r(`i')

        } // End Loop over the returned scalars

        // Return the matrix with all of the results
        matrix res = r(mtrx)

        // If the user does not specify noall
        if mi(`"`all'"') {

            // A second column is only present when the all results are added
            loc colnms `"`colnms' "`ditxt'""'

            // Update the variable that IDs the sample to use for the metrics
            qui: replace `touse' = cond(`split' == `= `kfold' + 1', 1, 0)

            // Call the subroutine with modified arguments (note the use of all)
            getstats, me(`metric') p(`pstub'all) o(`obs') t(`touse') st(xv) ///
            monitors(`monitors') sf(all)

            // Adds the names of these scalars to the allnms macro
            loc allnms `allnms' `r(names)'

            // Loop over the returned scalar names
            foreach i in `r(names)' {

                // Return those scalars
                ret sca `i' = r(`i')

            } // End Loop over the returned scalars

            // Update the matrix to include the additional results from the
            // validation/test split
            matrix res = (res, r(mtrx))

        } // End IF Block to compute metrics on the validation/test split

        // Set rownames for the returned matrix based on the monitors/metrics
        mat rownames res = `r(names)'

        // Set column names for the returned matrix based on the samples; the
        // number of names now matches the number of columns in res
        mat colnames res = `colnms'

    } // End ELSEIF Block for LOO CV case

    // Returns a macro containing the names of all scalars returned
    ret loc allnames = "`allnms'"

    // Returns a matrix containing all of the results (copy keeps res available
    // for the display code below)
    ret mat xv = res, copy

    // If the display option is passed
    if !mi("`display'") {

        // Get the row names
        loc rnames : rown res, quoted

        // Get the column names
        loc cnames : coln res, quoted

        // Test the Stata version
        if `c(stata_version)' >= 17 {

            // Get the resulting matrix into the collection
            collect get xv = res, name(`name')

            // Create a title for the display
            collect title "Cross-Validation Results", name(`name')

            // Create a layout (explicitly on the collection created above)
            qui: collect layout (rowname[`rnames'])(colname[`cnames'])(cmdset), ///
            name(`name')

            // Display the metrics in a not horrible layout
            collect preview, name(`name')

        } // End IF Block for current Stata display

        // For older Stata
        else {

            // Display the matrix of results
            mat li res

        } // End ELSE Block for older Stata display

    } // End IF Block to display results if requested by the user

// End of program definition
end
// Subroutine to compute all of the stats and build a matrix that will persist
// over all of the loops to return results as a table instead of printing
// individually
//
// Options
//   metric(string)   : the metric (function name with optional arguments)
//   pstub(string)    : name of the variable with the predicted values
//   obs(string)      : name of the variable with the observed values
//   touse(string)    : name of the variable flagging the observations to use
//   sto(string)      : name of the Mata matrix used to store the results
//   monitors(string) : optional list of monitors
//   sfx(string)      : optional suffix appended to the returned scalar names
//
// Returns
//   r(<monitor><sfx>) : one scalar per monitor
//   r(metric<sfx>)    : the value of the metric
//   r(mtrx)           : column vector with the monitors followed by the metric
//   r(names)          : names of the returned scalars, in row order of r(mtrx)
//
// NOTE(review): getname() and getarg() are Mata functions defined outside this
// file; from their use below they appear to store the function name and its
// arguments in the local macro named by their second argument - confirm.
prog def getstats, rclass

    // Defines the syntax for the sub-routine
    syntax , MEtric(string asis) Pstub(string asis) Obs(string asis)         ///
        Touse(string asis) STo(string asis)                                  ///
        [ MOnitors(string asis) SFx(string asis)]

    // Parse the monitors option
    _parse_monitors `monitors'

    // Store the parsed monitors
    loc monargs `"`r(mons)'"'

    // Count the words in monitors
    loc mons `r(n)'

    // Create index for matrix (the metric goes in the row after the monitors)
    loc m `= `mons' + 1'

    // Initialize the storage matrix in mata
    mata: `sto' = J(`m', 1, .)

    // Create a macro with the names that get returned
    loc rnms

    // Only execute if there are monitors
    if !mi("`mons'") & `mons' >= 1 {

        // Loop over the monitors
        forv i = 1/`mons' {

            // Get the name of the function for monitoring
            loc mon : word `i' of `monargs'

            // Get the monitor name from the parsed string
            mata: getname(`"`mon'"', "monnm")

            // Get any arguments passed to the monitor
            mata: getarg(`"`mon'"', "mnopt")

            // Call the mata function
            mata: `sto'[`i', 1] = `monnm'("`pstub'", "`obs'", "`touse'", `mnopt')

            // Creates a Stata scalar with the appropriate value
            mata: st_numscalar("`monnm'`sfx'", `sto'[`i', 1])

            // Sets the return value for the scalar.  scalar() is required so a
            // dataset variable with the same name (or one abbreviated by it)
            // is not read instead, and it avoids a round trip through a macro
            return scalar `monnm'`sfx' = scalar(`monnm'`sfx')

            // Add this name to rnms
            loc rnms `rnms' `monnm'`sfx'

        } // End loop over monitors

    } // End IF Block to compute monitors only if requested

    // Get the name of the metric (in case there are options passed to it)
    mata: getname(`"`metric'"', "metnm")

    // Get any arguments passed to the metric
    mata: getarg(`"`metric'"', "meopt")

    // Call the mata function for the metric
    mata: `sto'[`m', 1] = `metnm'("`pstub'", "`obs'", "`touse'", `meopt')

    // Push the value into a scalar
    mata: st_numscalar("`metnm'sc", `sto'[`m', 1])

    // Sets the return value for the scalar (read explicitly as a scalar for
    // the same reason as the monitors above)
    return scalar metric`sfx' = scalar(`metnm'sc)

    // Add this name to rnms
    loc rnms `rnms' metric`sfx'

    // Return the column from the matrix of results to a stata matrix
    mata: st_matrix("vmat", `sto')

    // Sets the return matrix value
    return matrix mtrx = vmat

    // Returns the name of the metrics/monitors
    ret loc names = "`rnms'"

// End of subroutine to compute the statistics
end
// _parse_monitors
//
// Splits the text passed to it into tokens, keeping anything bound by
// parentheses/brackets together, and hands the tokens back individually
// quoted.
//
// Returns
//   r(n)    : the number of tokens found (0 when nothing is passed)
//   r(mons) : the tokens, each wrapped in compound double quotes (empty when
//             nothing is passed)
program define _parse_monitors, rclass

    // Everything typed after the command name lands in the macro monitors
    syntax [anything(name = monitors id = "Options passed to monitors")]

    // Something was passed, so it needs to be tokenized
    if !mi(`"`monitors'"') {

        // Peel the first token off of the input
        gettoken piece rest : monitors, bind

        // Start the list of quoted tokens with it
        local parsed `"`parsed' `"`piece'"' "'

        // Keep peeling tokens off until nothing is left
        while !mi(`"`rest'"') {

            // Take the next token from what remains
            gettoken piece rest : rest, bind

            // Append it, quoted, to the list
            local parsed `"`parsed' `"`piece'"' "'

        } // End WHILE loop over the remaining text

        // Count how many tokens were collected
        local cnt : word count `parsed'

        // Hand back the quoted tokens
        return local mons = `"`parsed'"'

        // Hand back the number of tokens
        return local n = `"`cnt'"'

    } // End IF Block for non-empty input

    // Nothing was passed
    else {

        // No tokens to report
        return local n = 0

        // And an empty list of monitors
        return local mons = ""

    } // End ELSE Block for empty input

// End of subroutine definition
end