Commit

use grid_space_filling() instead of grid_latin_hypercube() (tidymodel…

topepo authored Aug 2, 2024
1 parent cb4a0ee commit 72ca7e2

Showing 12 changed files with 67 additions and 53 deletions.
8 changes: 4 additions & 4 deletions DESCRIPTION
@@ -18,7 +18,7 @@ Depends:
R (>= 4.0)
Imports:
cli (>= 3.3.0),
-dials (>= 1.0.0),
+dials (>= 1.3.0),
doFuture (>= 1.0.0),
dplyr (>= 1.1.0),
foreach,
@@ -33,13 +33,13 @@ Imports:
purrr (>= 1.0.0),
recipes (>= 1.0.4),
rlang (>= 1.1.0),
-rsample (>= 1.2.0),
+rsample (>= 1.2.1.9000),
tibble (>= 3.1.0),
tidyr (>= 1.2.0),
tidyselect (>= 1.1.2),
vctrs (>= 0.6.1),
withr,
-workflows (>= 1.1.4),
+workflows (>= 1.1.4.9000),
yardstick (>= 1.3.0)
Suggests:
C50,
@@ -66,4 +66,4 @@ Encoding: UTF-8
Language: en-US
LazyData: true
Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.3.1
+RoxygenNote: 7.3.2
2 changes: 2 additions & 0 deletions NEWS.md
@@ -4,6 +4,8 @@

* The package will now log a backtrace for errors and warnings that occur during tuning. When a tuning process encounters issues, see the new `trace` column in the `collect_notes(.Last.tune.result)` output to find precisely where the error occurred (#873).

* When automatic grids are used, `dials::grid_space_filling()` is now used instead of `dials::grid_latin_hypercube()`. The new function produces pre-optimized space-filling designs that do not depend on random numbers. Bayesian optimization still uses a Latin hypercube: 5,000 candidates are generated at that step, and producing a pre-optimized design of that size is too slow.
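The practical difference between the two grid generators can be sketched as follows. This is an illustration, not code from this commit; it assumes dials >= 1.3.0 is installed, and the `penalty()`/`mixture()` parameters are arbitrary examples:

```r
library(dials)

# Old default: a Latin hypercube design, which depends on the random seed
set.seed(1)
grid_old <- grid_latin_hypercube(penalty(), mixture(), size = 10)

# New default: a pre-optimized space-filling design; no RNG involved,
# so repeated calls return the same candidates
grid_new <- grid_space_filling(penalty(), mixture(), size = 10)
```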

# tune 1.2.1

* Addressed issue in `int_pctl()` where the function would error when parallelized using `makePSOCKcluster()` (#885).
2 changes: 1 addition & 1 deletion R/checks.R
@@ -116,7 +116,7 @@ check_grid <- function(grid, workflow, pset = NULL, call = caller_env()) {
}
check_workflow(workflow, pset = pset, check_dials = TRUE, call = call)

-  grid <- dials::grid_latin_hypercube(pset, size = grid)
+  grid <- dials::grid_space_filling(pset, size = grid)
grid <- dplyr::distinct(grid)
}
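In other words, when the user passes an integer for `grid`, the candidates are now generated by the space-filling design and then de-duplicated. A minimal standalone sketch of that path (the parameter set here is a hypothetical stand-in for the workflow's tunable parameters):

```r
library(dials)
library(dplyr)

# Build the candidate grid the way check_grid() does for an integer `grid`
grid <- grid_space_filling(penalty(), mixture(), size = 10)
grid <- distinct(grid)  # drop any duplicate candidates
```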

4 changes: 2 additions & 2 deletions R/tune_bayes.R
@@ -531,7 +531,7 @@ create_initial_set <- function(param, n = NULL, checks) {
if (any(checks == "bayes")) {
check_bayes_initial_size(nrow(param), n)
}
-  dials::grid_latin_hypercube(param, size = n)
+  dials::grid_space_filling(param, size = n)
}

check_iter <- function(iter, call) {
@@ -632,7 +632,7 @@ fit_gp <- function(dat, pset, metric, eval_time = NULL, control, ...) {

pred_gp <- function(object, pset, size = 5000, current = NULL, control) {
pred_grid <-
-    dials::grid_latin_hypercube(pset, size = size) %>%
+    dials::grid_space_filling(pset, size = size, type = "latin_hypercube") %>%
dplyr::distinct()

if (!is.null(current)) {
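Note the `type = "latin_hypercube"` argument in `pred_gp()`: the Gaussian process scoring step generates 5,000 candidates, and for a design that large `grid_space_filling()` is told to fall back to a (random) Latin hypercube rather than a pre-optimized design. A sketch under the same assumptions as above (illustrative parameters, dials >= 1.3.0):

```r
library(dials)
library(dplyr)

# 5,000 candidates: pre-optimized designs of this size are slow to build,
# so a Latin hypercube type is requested explicitly
pred_grid <-
  grid_space_filling(penalty(), mixture(), size = 5000, type = "latin_hypercube") %>%
  distinct()
```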
5 changes: 2 additions & 3 deletions R/tune_grid.R
@@ -58,9 +58,8 @@
#'
#' @section Parameter Grids:
#'
-#' If no tuning grid is provided, a semi-random grid (via
-#' [dials::grid_latin_hypercube()]) is created with 10 candidate parameter
-#' combinations.
+#' If no tuning grid is provided, a grid (via [dials::grid_space_filling()]) is
+#' created with 10 candidate parameter combinations.
#'
#' When provided, the grid should have column names for each parameter and
#' these should be named by the parameter name or `id`. For example, if a
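As the documentation above notes, the alternative to the automatic grid is a user-supplied data frame whose columns are named for each tuned parameter (or its `id`). A sketch of such a grid (the parameter names are hypothetical, not from this commit):

```r
library(tidyr)

# Each column name must match a tuned parameter's name or `id`;
# crossing() enumerates all 5 x 3 = 15 combinations
my_grid <- crossing(
  penalty = 10^seq(-3, 0, length.out = 5),
  mixture = c(0.25, 0.5, 0.75)
)
```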
3 changes: 2 additions & 1 deletion inst/WORDLIST
@@ -6,7 +6,6 @@ Codecov
Davison
Disambiguates
EI
-foreach
Hinkley
Isomap
Lifecycle
@@ -15,10 +14,12 @@ Olshen
PSOCK
RNGkind
Wadsworth
+backtrace
cdot
doi
el
finetune
+foreach
frac
geo
ggplot
5 changes: 2 additions & 3 deletions man/tune_grid.Rd

Some generated files are not rendered by default.

40 changes: 22 additions & 18 deletions tests/testthat/_snaps/bayes.md
@@ -170,29 +170,29 @@
-- Iteration 1 -----------------------------------------------------------------
-i Current best: rmse=2.453 (@iter 0)
+i Current best: rmse=2.461 (@iter 0)
i Gaussian process model
! The Gaussian process model is being fit using 1 features but only has 2
data points to do so. This may cause errors or a poor model fit.
v Gaussian process model
i Generating 3 candidates
i Predicted candidates
-i num_comp=4
+i num_comp=5
i Estimating performance
v Estimating performance
-(x) Newest results: rmse=2.461 (+/-0.37)
+<3 Newest results: rmse=2.453 (+/-0.381)
-- Iteration 2 -----------------------------------------------------------------
-i Current best: rmse=2.453 (@iter 0)
+i Current best: rmse=2.453 (@iter 1)
i Gaussian process model
v Gaussian process model
i Generating 2 candidates
i Predicted candidates
-i num_comp=3
+i num_comp=1
i Estimating performance
v Estimating performance
-<3 Newest results: rmse=2.418 (+/-0.357)
+(x) Newest results: rmse=2.646 (+/-0.286)
Output
# Tuning results
# 10-fold cross-validation
@@ -225,14 +225,14 @@
-- Iteration 1 -----------------------------------------------------------------
-i Current best: rmse=2.453 (@iter 0)
+i Current best: rmse=2.461 (@iter 0)
i Gaussian process model
! The Gaussian process model is being fit using 1 features but only has 2
data points to do so. This may cause errors or a poor model fit.
v Gaussian process model
i Generating 3 candidates
i Predicted candidates
-i num_comp=4
+i num_comp=5
i Estimating performance
i Fold01: preprocessor 1/1
v Fold01: preprocessor 1/1
@@ -295,16 +295,16 @@
i Fold10: preprocessor 1/1, model 1/1 (extracts)
i Fold10: preprocessor 1/1, model 1/1 (predictions)
v Estimating performance
-(x) Newest results: rmse=2.461 (+/-0.37)
+<3 Newest results: rmse=2.453 (+/-0.381)
-- Iteration 2 -----------------------------------------------------------------
-i Current best: rmse=2.453 (@iter 0)
+i Current best: rmse=2.453 (@iter 1)
i Gaussian process model
v Gaussian process model
i Generating 2 candidates
i Predicted candidates
-i num_comp=3
+i num_comp=1
i Estimating performance
i Fold01: preprocessor 1/1
v Fold01: preprocessor 1/1
@@ -367,7 +367,7 @@
i Fold10: preprocessor 1/1, model 1/1 (extracts)
i Fold10: preprocessor 1/1, model 1/1 (predictions)
v Estimating performance
-<3 Newest results: rmse=2.418 (+/-0.357)
+(x) Newest results: rmse=2.646 (+/-0.286)
Output
# Tuning results
# 10-fold cross-validation
@@ -523,12 +523,6 @@
data points to do so. This may cause errors or a poor model fit.
! For the rsq estimates, 1 missing value was found and removed before fitting
the Gaussian process model.
-! For the rsq estimates, 1 missing value was found and removed before fitting
-the Gaussian process model.
-! For the rsq estimates, 1 missing value was found and removed before fitting
-the Gaussian process model.
-! For the rsq estimates, 1 missing value was found and removed before fitting
-the Gaussian process model.
! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
! For the rsq estimates, 2 missing values were found and removed before
fitting the Gaussian process model.
@@ -545,6 +539,16 @@
! For the rsq estimates, 6 missing values were found and removed before
fitting the Gaussian process model.
! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+! For the rsq estimates, 7 missing values were found and removed before
+fitting the Gaussian process model.
+! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+! For the rsq estimates, 8 missing values were found and removed before
+fitting the Gaussian process model.
+! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+! For the rsq estimates, 9 missing values were found and removed before
+fitting the Gaussian process model.
+! validation: internal: A correlation computation is required, but `estimate` is constant and ha...
+! No improvement for 10 iterations; returning current results.

---

12 changes: 6 additions & 6 deletions tests/testthat/_snaps/fit_best.md
@@ -4,8 +4,8 @@
fit_best(knn_pca_res, verbose = TRUE)
Output
Using rmse as the metric, the optimal parameters were:
-neighbors: 10
-num_comp: 3
+neighbors: 1
+num_comp: 4
Message
i Fitting using 161 data points...
@@ -23,13 +23,13 @@
-- Model -----------------------------------------------------------------------
Call:
-kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(10L, data, 5))
+kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(1L, data, 5))
Type of response variable: continuous
-minimal mean absolute error: 1.690086
-Minimal mean squared error: 4.571625
+minimal mean absolute error: 1.015528
+Minimal mean squared error: 2.448261
Best kernel: optimal
-Best k: 10
+Best k: 1

---

20 changes: 10 additions & 10 deletions tests/testthat/_snaps/grid.md
@@ -66,14 +66,14 @@
# A tibble: 10 x 4
splits id .metrics .notes
<list> <chr> <list> <list>
-1 <split [28/4]> Fold01 <tibble [4 x 5]> <tibble [0 x 4]>
-2 <split [28/4]> Fold02 <tibble [4 x 5]> <tibble [0 x 4]>
-3 <split [29/3]> Fold03 <tibble [4 x 5]> <tibble [0 x 4]>
-4 <split [29/3]> Fold04 <tibble [4 x 5]> <tibble [0 x 4]>
-5 <split [29/3]> Fold05 <tibble [4 x 5]> <tibble [0 x 4]>
-6 <split [29/3]> Fold06 <tibble [4 x 5]> <tibble [0 x 4]>
-7 <split [29/3]> Fold07 <tibble [4 x 5]> <tibble [0 x 4]>
-8 <split [29/3]> Fold08 <tibble [4 x 5]> <tibble [0 x 4]>
-9 <split [29/3]> Fold09 <tibble [4 x 5]> <tibble [0 x 4]>
-10 <split [29/3]> Fold10 <tibble [4 x 5]> <tibble [0 x 4]>
+1 <split [28/4]> Fold01 <tibble [6 x 5]> <tibble [0 x 4]>
+2 <split [28/4]> Fold02 <tibble [6 x 5]> <tibble [0 x 4]>
+3 <split [29/3]> Fold03 <tibble [6 x 5]> <tibble [0 x 4]>
+4 <split [29/3]> Fold04 <tibble [6 x 5]> <tibble [0 x 4]>
+5 <split [29/3]> Fold05 <tibble [6 x 5]> <tibble [0 x 4]>
+6 <split [29/3]> Fold06 <tibble [6 x 5]> <tibble [0 x 4]>
+7 <split [29/3]> Fold07 <tibble [6 x 5]> <tibble [0 x 4]>
+8 <split [29/3]> Fold08 <tibble [6 x 5]> <tibble [0 x 4]>
+9 <split [29/3]> Fold09 <tibble [6 x 5]> <tibble [0 x 4]>
+10 <split [29/3]> Fold10 <tibble [6 x 5]> <tibble [0 x 4]>

17 changes: 13 additions & 4 deletions tests/testthat/test-autoplot.R
@@ -329,12 +329,21 @@ test_that("plot_perf_vs_iter with fairness metrics (#773)", {

test_that("regular grid plot", {
skip_if_not_installed("ggplot2", minimum_version = "3.5.0")
-set.seed(1)
-res <-
+
+svm_spec <-
 parsnip::svm_rbf(cost = tune()) %>%
 parsnip::set_engine("kernlab") %>%
-parsnip::set_mode("regression") %>%
-tune_grid(mpg ~ ., resamples = rsample::vfold_cv(mtcars, v = 5), grid = 1)
+parsnip::set_mode("regression")
+
+svm_grid <-
+svm_spec %>%
+extract_parameter_set_dials() %>%
+dials::grid_regular(levels = 1)
+
+set.seed(1)
+res <-
+svm_spec %>%
+tune_grid(mpg ~ ., resamples = rsample::vfold_cv(mtcars, v = 5), grid = svm_grid)

expect_snapshot(
error = TRUE,
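The updated test builds a one-level regular grid up front instead of passing `grid = 1` to `tune_grid()`. The standalone pattern looks like this (a sketch assuming the parsnip, dials, tune, and kernlab packages are installed):

```r
library(parsnip)
library(dials)
library(tune)

svm_spec <-
  svm_rbf(cost = tune()) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

# One level per tuned parameter -> a single-row regular grid
svm_grid <-
  extract_parameter_set_dials(svm_spec) %>%
  grid_regular(levels = 1)
```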
2 changes: 1 addition & 1 deletion tests/testthat/test-extract.R
@@ -170,7 +170,7 @@ test_that("tune model and recipe", {
grid_3 <-
extract_parameter_set_dials(wflow_3) %>%
update(num_comp = dials::num_comp(c(2, 5))) %>%
-    dials::grid_latin_hypercube(size = 4)
+    dials::grid_space_filling(size = 4)

expect_error(
res_3_1 <- tune_grid(
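The same swap applies when a grid is piped from an extracted parameter set: `grid_space_filling()` drops in where `grid_latin_hypercube()` was. A simplified sketch of the idea, using hand-built dials parameters in place of the test's workflow (the test instead calls `update()` on the extracted set to narrow `num_comp` to 2–5):

```r
library(dials)

# Four space-filling candidates over the narrowed num_comp range
grid_3 <- grid_space_filling(num_comp(c(2, 5)), neighbors(), size = 4)
```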
