Skip to content

Commit

Permalink
fix: extend assert_measure with checks for trained models in assert_s…
Browse files Browse the repository at this point in the history
…corable (#1218)
  • Loading branch information
be-marc authored Nov 27, 2024
1 parent 282b53a commit 565e3ff
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
4 changes: 2 additions & 2 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ Measure = R6Class("Measure",
#'
#' @return `numeric(1)`.
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
assert_measure(self, task = task, learner = learner, prediction = prediction)
assert_scorable(self, task = task, learner = learner, prediction = prediction)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)

# check should be added to assert_measure()
Expand Down Expand Up @@ -395,7 +395,7 @@ score_measures = function(obj, measures, reassemble = TRUE, view = NULL, iters =
tmp = unique(tab, by = c("task_hash", "learner_hash"))[, c("task", "learner"), with = FALSE]

for (measure in measures) {
pmap(tmp, assert_measure, measure = measure)
pmap(tmp, assert_scorable, measure = measure)

score = pmap_dbl(tab[, c("task", "learner", "resampling", "iteration", "prediction"), with = FALSE],
function(task, learner, resampling, iteration, prediction) {
Expand Down
23 changes: 15 additions & 8 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,6 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL

if (!is.null(learner)) {

if ("requires_model" %in% measure$properties && is.null(learner$model)) {
stopf("Measure '%s' requires the trained model", measure$id)
}

if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id)
}

if (!is_scalar_na(measure$task_type) && measure$task_type != learner$task_type) {
stopf("Measure '%s' is not compatible with type '%s' of learner '%s'",
measure$id, learner$task_type, learner$id)
Expand Down Expand Up @@ -263,6 +255,21 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL
invisible(measure)
}

#' @export
#' @param measure ([Measure]).
#' @param prediction ([Prediction]).
#' @rdname mlr_assertions
assert_scorable = function(measure, task, learner, prediction = NULL, .var.name = vname(measure)) {
if ("requires_model" %in% measure$properties && is.null(learner$model)) {
stopf("Measure '%s' requires the trained model", measure$id)
}

if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id)
}

assert_measure(measure, task = task, learner = learner, prediction = prediction, .var.name = .var.name)
}

#' @export
#' @param measures (list of [Measure]).
Expand Down

0 comments on commit 565e3ff

Please sign in to comment.