diff --git a/R/aggregations.R b/R/aggregations.R index d89efd82fb..885713813b 100644 --- a/R/aggregations.R +++ b/R/aggregations.R @@ -235,8 +235,13 @@ test.join = makeAggregation( df = as.data.frame(pred) f = if (length(group)) group[df$iter] else factor(rep(1L, nrow(df))) mean(vnapply(split(df, f), function(df) { + if (pred$predict.type == "response") y = df$response + if (pred$predict.type == "prob") { + y = df[,grepl("^prob[.]", colnames(df))] + colnames(y) = gsub("^prob[.]", "", colnames(y)) + } npred = makePrediction(task.desc = pred$task.desc, row.names = rownames(df), - id = NULL, truth = df$truth, predict.type = pred$predict.type, y = df$response, + id = NULL, truth = df$truth, predict.type = pred$predict.type, y = y, time = NA_real_) performance(npred, measure) })) diff --git a/tests/testthat/test_base_resample_repcv.R b/tests/testthat/test_base_resample_repcv.R index a897103ac3..8d99abca5a 100644 --- a/tests/testthat/test_base_resample_repcv.R +++ b/tests/testthat/test_base_resample_repcv.R @@ -62,4 +62,9 @@ test_that("test.join works somehow", { res = resample(learner = lrn, task = task, resampling = rin, measures = measures) expect_equal(res$measures.test[, 2L], res$measures.test[, 3L]) expect_true(diff(res$aggr) > 0) + + lrn = setPredictType(lrn, predict.type = "prob") + res.prob = resample(learner = lrn, task = task, resampling = cv2, measures = measures[[2]]) + expect_equal(res$measures.test[, 2L], res$measures.test[, 3L]) + expect_true(diff(res$aggr) > 0) })