diff --git a/R/Measure.R b/R/Measure.R index ff45b64d8..5c58b85e8 100644 --- a/R/Measure.R +++ b/R/Measure.R @@ -194,7 +194,7 @@ Measure = R6Class("Measure", #' #' @return `numeric(1)`. score = function(prediction, task = NULL, learner = NULL, train_set = NULL) { - assert_measure(self, task = task, learner = learner, prediction = prediction) + assert_scorable(self, task = task, learner = learner, prediction = prediction) assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties) # check should be added to assert_measure() @@ -395,7 +395,7 @@ score_measures = function(obj, measures, reassemble = TRUE, view = NULL, iters = tmp = unique(tab, by = c("task_hash", "learner_hash"))[, c("task", "learner"), with = FALSE] for (measure in measures) { - pmap(tmp, assert_measure, measure = measure) + pmap(tmp, assert_scorable, measure = measure) score = pmap_dbl(tab[, c("task", "learner", "resampling", "iteration", "prediction"), with = FALSE], function(task, learner, resampling, iteration, prediction) { diff --git a/R/assertions.R b/R/assertions.R index f4c76fa41..fa6618dbd 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -223,14 +223,6 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL if (!is.null(learner)) { - if ("requires_model" %in% measure$properties && is.null(learner$model)) { - stopf("Measure '%s' requires the trained model", measure$id) - } - - if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) { - stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id) - } - if (!is_scalar_na(measure$task_type) && measure$task_type != learner$task_type) { stopf("Measure '%s' is not compatible with type '%s' of learner '%s'", measure$id, learner$task_type, learner$id) @@ -263,6 +255,21 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL invisible(measure) } +#' @export +#' @param measure ([Measure]). +#' @param prediction ([Prediction]). +#' @rdname mlr_assertions +assert_scorable = function(measure, task, learner, prediction = NULL, .var.name = vname(measure)) { + if ("requires_model" %in% measure$properties && is.null(learner$model)) { + stopf("Measure '%s' requires the trained model", measure$id) + } + + if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) { + stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id) + } + + assert_measure(measure, task = task, learner = learner, prediction = prediction, .var.name = .var.name) +} #' @export #' @param measures (list of [Measure]).