Skip to content

Commit

Permalink
Issue #625: solve aggregation issue for test.join with probs
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseppec committed Oct 7, 2016
1 parent e55f1a5 commit 880b136
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
7 changes: 6 additions & 1 deletion R/aggregations.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,13 @@ test.join = makeAggregation(
df = as.data.frame(pred)
f = if (length(group)) group[df$iter] else factor(rep(1L, nrow(df)))
mean(vnapply(split(df, f), function(df) {
if (pred$predict.type == "response") y = df$response
if (pred$predict.type == "prob") {
y = df[,grepl("^prob[.]", colnames(df))]
colnames(y) = gsub("^prob[.]", "", colnames(y))
}
npred = makePrediction(task.desc = pred$task.desc, row.names = rownames(df),
id = NULL, truth = df$truth, predict.type = pred$predict.type, y = df$response,
id = NULL, truth = df$truth, predict.type = pred$predict.type, y = y,
time = NA_real_)
performance(npred, measure)
}))
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test_base_resample_repcv.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ test_that("test.join works somehow", {
res = resample(learner = lrn, task = task, resampling = rin, measures = measures)
expect_equal(res$measures.test[, 2L], res$measures.test[, 3L])
expect_true(diff(res$aggr) > 0)

lrn = setPredictType(lrn, predict.type = "prob")
res.prob = resample(learner = lrn, task = task, resampling = cv2, measures = measures[[2]])
expect_equal(res$measures.test[, 2L], res$measures.test[, 3L])
expect_true(diff(res$aggr) > 0)
})

0 comments on commit 880b136

Please sign in to comment.