Commit

Upgraded the function signature for the metric functions to allow passing optional arguments to the functions. If options are not supported by the underlying function, nothing is done with the optional argument. Updated help files to reflect that change and added titles to each of the help files. Added some more robustness checks for syntax for the xv and xvloo commands. Updated validateit to handle parsing optional arguments passed to metrics/monitors (still need to test). Added tests for libxv to check that passing optional arguments would not affect existing functions that do not support them. Updated help files for the main commands to include titles. Updated the README. Updated the distribution date for the package file. Recompiled the Mata library in Stata 15.
wbuchanan committed Feb 29, 2024
1 parent 7a523b8 commit 09e5086
Showing 17 changed files with 282 additions and 115 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -59,6 +59,8 @@ collections for Stata >= 17 and modify help files related to collections.
 - [x] Modify the replay option to adjust for the collection thing above
 - [x] Write a function that can do the same thing as `assertnested` and will
 work in Stata 15
+- [x] Update function signature to provide support for optional arguments
+- [x] Compile Mata library in Stata 15
 - [ ] Standardize language in help files
 - [ ] Finish writing test cases for Mata functions
 - [ ] Finish writing test cases for ADO commands
@@ -78,11 +80,12 @@ The program will allow users to define their own metrics/monitors that are not
 contained in libcrossvalidate. In order to do this, users must implement a
 specific method/function signature:

-`real scalar metric(string scalar pred, string scalar obs, string scalar touse)`
+`real scalar metric(string scalar pred, string scalar obs, string scalar touse, | transmorphic matrix opts)`

 The function must return a real valued scalar and take three arguments. The
 three arguments are used to access the data that would be used to compute the
-metrics/monitors.
+metrics/monitors and to provide a method to pass optional arguments to the
+underlying functions if supported.

 ### Data access
 Within the function body, we recommend using the following pattern to access
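
The new signature in the README hunk above can be illustrated with a minimal sketch of a user-defined metric. This is not code from the commit: the function name `mymae` and the meaning given to `opts` (a cap applied to each absolute error) are hypothetical. The README's own recommended data-access pattern is cut off in this view, so the `st_data()` calls are also an assumption.

```stata
mata:

// Hypothetical user-defined metric (not part of libxv): mean absolute error
// with an optional cap on each absolute error supplied through opts.
real scalar mymae(string scalar pred, string scalar obs, string scalar touse, | transmorphic matrix opts) {

    real colvector yhat, y, err

    // Assumed access pattern: read the flagged observations of each variable
    yhat = st_data(., pred, touse)
    y = st_data(., obs, touse)

    // Absolute prediction errors
    err = abs(y - yhat)

    // Use the optional argument only if one was passed and it is numeric
    if (args() == 4 & isreal(opts) & length(opts) > 0) {
        // Cap each error at the first element of opts
        err = rowmin((err, J(rows(err), 1, opts[1, 1])))
    }

    // mean() of a column vector returns a 1 x 1 result
    return(mean(err))
}

end
```

Under these assumptions, `mymae("pred", "obs", "touse")` and `mymae("pred", "obs", "touse", 5)` are both valid calls. The `|` in the argument list is what marks `opts` as optional in Mata, and `args()` reports how many arguments the caller actually supplied.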
147 changes: 98 additions & 49 deletions crossvalidate.mata

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion crossvalidate.pkg
@@ -20,7 +20,7 @@ d KW: evaluation
 d
 d Requires: Stata version 15
 d
-d Distribution-Date: 20240228
+d Distribution-Date: 20240229
 d
 d Author: Billy Buchanan, Ph.D.
 d Sr. Research Scientist, SAG Corporation
27 changes: 13 additions & 14 deletions crossvalidate.sthlp
@@ -1,5 +1,5 @@
 {smcl}
-{* *! version 0.0.1 17feb2024}{...}
+{* *! version 0.0.2 29feb2024}{...}
 {vieweralsosee "[R] predict" "mansection R predict"}{...}
 {vieweralsosee "[R] estat classification" "mansection R estat_classification"}{...}
 {vieweralsosee "[P] creturn" "mansection P creturn"}{...}
@@ -8,9 +8,9 @@
 {viewerjumpto "Commands" "crossvalidate##cmds"}{...}
 {viewerjumpto "Additional Information" "crossvalidate##additional"}{...}
 {viewerjumpto "Contact" "crossvalidate##contact"}{...}
+{title:Cross-Validation in Stata}

-{marker overview}
-{title:Overview}
+{marker overview}{title:Overview}

 {pstd}
 The crossvalidate package includes several commands and a Mata library that
@@ -27,25 +27,24 @@ This help file provides an overview of the commands included in the crossvalidat
 package. We leave detailed information to the documentation for each of the
 individual commands.

-{marker cmds}
-{title:Commands}
+{marker cmds}{title:Commands}

 {synoptset 15 tabbed}{...}
 {synoptline}
 {synopthdr:Command Name}
 {synoptline}
 {syntab:Prefix Commands}
-{synopt :{opt xv}}Cross-Validation{p_end}
-{synopt :{opt xvloo}}Leave-One-Out Cross-Validation{p_end}
+{synopt :{opt {help xv}}}Cross-Validation{p_end}
+{synopt :{opt {help xvloo}}}Leave-One-Out Cross-Validation{p_end}
 {syntab:Lower Level Commands}
-{synopt :{opt splitit}}Splits the dataset into train/test or train/validation/test splits{p_end}
-{synopt :{opt fitit}}Calls the estimation command on the appropriate split{p_end}
-{synopt :{opt predictit}}Predicts the outcome on the appropriate split{p_end}
-{synopt :{opt validateit}}Computes {p_end}
+{synopt :{opt {help splitit}}}Splits the dataset into train/test or train/validation/test splits{p_end}
+{synopt :{opt {help fitit}}}Calls the estimation command on the appropriate split{p_end}
+{synopt :{opt {help predictit}}}Predicts the outcome on the appropriate split{p_end}
+{synopt :{opt {help validateit}}}Computes {p_end}
 {syntab:Utility Commands}
-{synopt :{opt classify}}Used to manage {p_end}
-{synopt :{opt cmdmod}}Used for metaprogramming tasks in commands above{p_end}
-{synopt :{opt state}}Retrieves current settings and binds to the dataset{p_end}
+{synopt :{opt {help classify}}}Used to manage {p_end}
+{synopt :{opt {help cmdmod}}}Used for metaprogramming tasks in commands above{p_end}
+{synopt :{opt {help state}}}Retrieves current settings and binds to the dataset{p_end}
 {synoptline}

 {dlgtab:Prefix Commands}
3 changes: 2 additions & 1 deletion fitit.sthlp
@@ -1,5 +1,5 @@
 {smcl}
-{* *! version 0.0.6 28feb2024}{...}
+{* *! version 0.0.7 29feb2024}{...}
 {vieweralsosee "[R] estat classification" "mansection R estat_classification"}{...}
 {vieweralsosee "" "--"}{...}
 {viewerjumpto "Syntax" "fitit##syntax"}{...}
@@ -9,6 +9,7 @@
 {viewerjumpto "Returned Values" "fitit##retvals"}{...}
 {viewerjumpto "Additional Information" "fitit##additional"}{...}
 {viewerjumpto "Contact" "fitit##contact"}{...}
+{title:Model Fitting for Cross-Validation in Stata}

 {marker syntax}{...}
 {title:Syntax}
Binary file modified libxv.mlib
Binary file not shown.
3 changes: 2 additions & 1 deletion predictit.sthlp
@@ -1,5 +1,5 @@
 {smcl}
-{* *! version 0.0.3 23feb2024}{...}
+{* *! version 0.0.4 29feb2024}{...}
 {vieweralsosee "[R] estat classification" "mansection R estat_classification"}{...}
 {vieweralsosee "" "--"}{...}
 {viewerjumpto "Syntax" "predictit##syntax"}{...}
@@ -8,6 +8,7 @@
 {viewerjumpto "Examples" "predictit##examples"}{...}
 {viewerjumpto "Additional Information" "predictit##additional"}{...}
 {viewerjumpto "Contact" "predictit##contact"}{...}
+{title:Generating and Managing Model Predictions for Cross-Validation in Stata}

 {marker syntax}{...}
 {title:Syntax}
3 changes: 2 additions & 1 deletion splitit.sthlp
@@ -1,12 +1,13 @@
 {smcl}
-{* *! version 0.0.5 28feb2024}{...}
+{* *! version 0.0.6 29feb2024}{...}
 {viewerjumpto "Syntax" "splitit##syntax"}{...}
 {viewerjumpto "Description" "splitit##description"}{...}
 {viewerjumpto "Options" "splitit##options"}{...}
 {viewerjumpto "Examples" "splitit##examples"}{...}
 {viewerjumpto "Returned Values" "splitit##retvals"}{...}
 {viewerjumpto "Additional Information" "splitit##additional"}{...}
 {viewerjumpto "Contact" "splitit##contact"}{...}
+{title:Dataset Splitting and Folding for Cross-Validation in Stata}

 {marker syntax}{...}
 {title:Syntax}
4 changes: 4 additions & 0 deletions test/libxvtests.do
@@ -223,6 +223,10 @@ asserteq(round(fxbinr2, rf), round(binr2("pred", "obs", "touse"), rf))
 // Test equality of Matthews Correlation Coefficient
 asserteq(round(fxmcc, rf), round(mcc("pred", "obs", "touse"), rf))
+// Test equality of F1 metrics with new function signature
+asserteq(round(fxf1, rf), round(f1("pred", "obs", "touse", (1, 2, 3)), rf))
+asserteq(round(fxf1, rf), round(f1("pred", "obs", "touse", ("yes", "no")), rf))
 // End the Mata session
 end

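
The F1 checks above rely on a metric that accepts the optional argument without reading it. A self-contained sketch of that behavior follows. The function `mymse` and the toy data are hypothetical and stand in for the package's own metrics, whose source is not shown in this view.

```stata
clear
set obs 4
gen double obs = _n          // observed outcome
gen double pred = _n + 1     // predictions, each off by exactly 1
gen byte touse = 1           // flag every observation for use

mata:

// Hypothetical metric that accepts, but never reads, the optional argument
real scalar mymse(string scalar pred, string scalar obs, string scalar touse, | transmorphic matrix opts) {
    real colvector yhat, y
    yhat = st_data(., pred, touse)
    y = st_data(., obs, touse)
    return(mean((y - yhat):^2))
}

// The result should not depend on whether, or what, opts is passed
asserteq(mymse("pred", "obs", "touse"), mymse("pred", "obs", "touse", (1, 2, 3)))
asserteq(mymse("pred", "obs", "touse"), mymse("pred", "obs", "touse", ("yes", "no")))

end
```

Declaring `opts` as `transmorphic matrix` is what allows both a numeric and a string vector to be passed to the same function, mirroring the two F1 assertions in the hunk.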
50 changes: 47 additions & 3 deletions test/xvlootests.do
@@ -18,6 +18,9 @@ run crossvalidate.mata
* individual commands that are called subsequently in the prefix. *
*******************************************************************************/

// Test that correct error code is thrown when there are no results to replay
rcof "xvloo 0.8, replay: mlogit industry south" == 119

// Test that correct error code is thrown when
rcof "xvloo 0.98, pstub(pred) metric(mse) classes(12): mlogit industry south" == 1002

@@ -30,6 +33,21 @@ rcof "xvloo 0.7, metric(mse): reg wage i.industry" == 198
// Test that error code is thrown if the user tries to specify K-Folds
rcof "xvloo 0.7, pstub(pred) metric(mse) kfold(4): reg wage i.industry " == 184

// Create a variable that should trigger an error for an existing `pstub'all
// variable
qui: g byte predall = rbinomial(3, 0.5)

// Test that error code is thrown if the user is trying to create the `pstub'all
// variable if it already exists
rcof "xvloo 0.7, pstub(pred) metric(mse): reg wage i.industry " == 110

// Now test that the same will happen if the pstub variable already exists
rename predall pred

// Test that error code is thrown if the user is trying to create the `pstub'
// variable if it already exists
rcof "xvloo 0.7, pstub(pred) metric(mse): reg wage i.industry " == 110

// Clear everything
clear all

@@ -56,6 +74,32 @@ xvloo 0.8, metric(mse) pstub(pred) monitors(mae mape) display retain: reg mpg pr
// There should be 59 stored estimation results, pred, predall, and _xvsplit
// added as variables

// There should be values in e(splitter), e(training), e(validation),
// e(stype), e(flavor), e(estresname), e(estresall)

// Make sure the returned values are populated
assert !mi(`e(rng)')
assert !mi(`e(rngcurrent)')
assert !mi(`e(rngstate)')
assert !mi(`e(rngseed)')
assert !mi(`e(rngstream)')
assert !mi(`e(filename)')
assert !mi(`e(filedate)')
assert !mi(`e(version)')
assert !mi(`e(currentdate)')
assert !mi(`e(currenttime)')
assert !mi(`e(stflavor)')
assert !mi(`e(processors)')
assert !mi(`e(hostname)')
assert !mi(`e(machinetype)')
assert !mi(`e(splitter)')
assert !mi(`e(training)')
assert !mi(`e(validation)')
assert !mi(`e(testing)')
assert !mi(`e(stype)')
assert !mi(`e(flavor)')
assert mi(`e(forecastset)')
assert !mi(`e(estresnames)')
assert !mi(`e(estresall)')
assert !mi(`e(fitnm)')
assert !mi(`e(valnm)')
assert !mi(`e(xv)')

assert `"`e(stype)'"' == "Leave One Out"
49 changes: 44 additions & 5 deletions test/xvtests.do
@@ -13,13 +13,30 @@ run crossvalidate.mata
* individual commands that are called subsequently in the prefix. *
*******************************************************************************/

// Test that correct error code is thrown when there are no results to replay
rcof "xvloo 0.8, replay: mlogit industry south" == 119

// Test that error code is thrown for missing metric
rcof "xv 0.7, pstub(pred): reg wage i.industry" == 198

// Test that error code is thrown for missing predicted value stub
rcof "xv 0.7, metric(mse): reg wage i.industry" == 198

// Create a variable that should trigger an error for an existing `pstub'all
// variable
qui: g byte predall = rbinomial(3, 0.5)

// Test that error code is thrown if the user is trying to create the `pstub'all
// variable if it already exists
rcof "xv 0.7, pstub(pred) metric(mse): reg wage i.industry " == 110

// Now test that the same will happen if the pstub variable already exists
rename predall pred

// Test that error code is thrown if the user is trying to create the `pstub'
// variable if it already exists
rcof "xv 0.7, pstub(pred) metric(mse): reg wage i.industry " == 110

// Clear everything
clear all

@@ -40,9 +57,31 @@ run crossvalidate.mata
// variables
xv 0.8, metric(r2) pstub(pred) monitors() display retain: reg mpg price i.rep78, vce(rob)

// There should be N stored estimation results, pred, predall, and _xvsplit
// added as variables

// There should be values in e(splitter), e(training), e(validation),
// e(stype), e(flavor), e(estresname), e(estresall)

// Make sure the returned values are populated
assert !mi(`e(rng)')
assert !mi(`e(rngcurrent)')
assert !mi(`e(rngstate)')
assert !mi(`e(rngseed)')
assert !mi(`e(rngstream)')
assert !mi(`e(filename)')
assert !mi(`e(filedate)')
assert !mi(`e(version)')
assert !mi(`e(currentdate)')
assert !mi(`e(currenttime)')
assert !mi(`e(stflavor)')
assert !mi(`e(processors)')
assert !mi(`e(hostname)')
assert !mi(`e(machinetype)')
assert !mi(`e(splitter)')
assert !mi(`e(training)')
assert !mi(`e(validation)')
assert !mi(`e(testing)')
assert !mi(`e(stype)')
assert !mi(`e(flavor)')
assert mi(`e(forecastset)')
assert !mi(`e(estresnames)')
assert !mi(`e(estresall)')
assert !mi(`e(fitnm)')
assert !mi(`e(valnm)')
assert !mi(`e(xv)')
10 changes: 8 additions & 2 deletions validateit.ado
@@ -404,8 +404,11 @@ prog def getstats, rclass
 // Get the name of the function for monitoring
 loc monnm : word `i' of `monitors'

+// Get any arguments passed to the monitor
+mata: getargs("`monnm'", "mnopt")
+
 // Call the mata function
-mata: `sto'[`i', 1] = `monnm'("`pstub'", "`obs'", "`touse'")
+mata: `sto'[`i', 1] = `monnm'("`pstub'", "`obs'", "`touse'", `mnopt')

 // Creates a Stata scalar with the appropriate value
 mata: st_numscalar("`monnm'`sfx'", `sto'[`i', 1])
@@ -418,8 +421,11 @@ prog def getstats, rclass

 } // End loop over monitors

+// Get any arguments passed to the metric
+mata: getargs("`metric'", "meopt")
+
 // Call the mata function for the metric
-mata: `sto'[`m', 1] = `metric'("`pstub'", "`obs'", "`touse'")
+mata: `sto'[`m', 1] = `metric'("`pstub'", "`obs'", "`touse'", `meopt')

 // Push the value into a scalar
 mata: st_numscalar("`metric'sc", `sto'[`m', 1])