Skip to content

Commit

Permalink
Fix class probability metrics (#62)
Browse files Browse the repository at this point in the history
* Fix class probability metrics

* Use getter function
  • Loading branch information
mikemahoney218 authored Apr 30, 2024
1 parent 1bfb0b5 commit 14591d9
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 13 deletions.
13 changes: 9 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# waywiser (development version)

* `ww_multi_scale()`, when called with raster arguments (either to `data` or to `truth`
and `estimate`) and a classification metric set, will now convert `truth` and
`estimate`to factors before passing them to the metric set. Thanks to @nowosad
for the report in #60 (#61).
* `ww_multi_scale()` now handles classification and class probability metrics better
when called with raster arguments (either to `data` or to `truth` and `estimate`):
* When called with classification metrics, `ww_multi_scale()` will now convert
`truth` and `estimate` to factors before passing them to the metric set.
Thanks to @nowosad for the report in #60 (#61).
* When called with class probability metrics, `ww_multi_scale()` will convert
`truth` to a factor and will pass `estimate` as an unnamed argument. (#62)
* When called with a mix of class and probability metrics, `ww_multi_scale()`
will error. (#62)

# waywiser 0.5.1

Expand Down
40 changes: 33 additions & 7 deletions R/multi_scale.R
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,36 @@ raster_method_notes <- function(grid_list) {
}

raster_method_summary <- function(grid_list, .notes, metrics, na_rm) {
class_metrics <- FALSE
prob_metrics <- FALSE
if (inherits(metrics, "class_prob_metric_set")) {
is_class_metric <- tibble::as_tibble(metrics)$class == "class_metric"

if (any(is_class_metric) && !all(is_class_metric)) {
rlang::abort(
c(
"`ww_multi_scale` can't handle mixed classification and class probability metric sets.",
i = "Call `ww_multi_scale()` twice: once for classification metrics, and once for class probability metrics."
),
class = "waywiser_mixed_metrics"
)
}

class_metrics <- all(is_class_metric)
prob_metrics <- !class_metrics
}

if (class_metrics || prob_metrics) {
lvls <- unique(
unlist(
lapply(
grid_list$grids,
function(grid) {
c(
levels(factor(grid$.truth)),
levels(factor(grid$.estimate))
)
out <- levels(factor(grid$.truth))
if (class_metrics) {
out <- c(out, levels(factor(grid$.estimate)))
}
out
}
)
)
Expand All @@ -345,15 +365,21 @@ raster_method_summary <- function(grid_list, .notes, metrics, na_rm) {
grid_list$grids,
function(grid) {
grid$.truth <- factor(grid$.truth, levels = lvls)
grid$.estimate <- factor(grid$.estimate, levels = lvls)
if (class_metrics) {
grid$.estimate <- factor(grid$.estimate, levels = lvls)
}
grid
}
)
)
}

out <- mapply(
function(grid, grid_arg, .notes) {
out <- metrics(grid, truth = .truth, estimate = .estimate, na_rm = na_rm)
if (prob_metrics) {
out <- metrics(as.data.frame(grid), truth = .truth, .estimate, na_rm = na_rm)
} else {
out <- metrics(grid, truth = .truth, estimate = .estimate, na_rm = na_rm)
}
out[attr(out, "sf_column")] <- NULL
out$.grid_args <- list(grid_list$grid_args[grid_arg, ])
out$.grid <- list(grid)
Expand Down
88 changes: 86 additions & 2 deletions tests/testthat/test-multi_scale.R
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ test_that("ww_multi_scale with raster args can handle classification metrics (#6
truth = l1,
estimate = l2,
metrics = list(yardstick::precision),
grid = list(sf::st_make_grid(l1))
grids = list(sf::st_make_grid(l1))
)$.estimate,
1
)
Expand All @@ -809,9 +809,93 @@ test_that("ww_multi_scale with raster data can handle classification metrics (#6
truth = "l1",
estimate = "l2",
metrics = list(yardstick::precision),
grid = list(sf::st_make_grid(l1))
grids = list(sf::st_make_grid(l1))
)$.estimate,
1
)

})

test_that("ww_multi_scale with raster args can handle class prob metrics", {
skip_if_not_installed("terra")
skip_if_not_installed("withr")
l1 <- withr::with_seed(
1107,
matrix(sample(1:2, 100, TRUE), nrow = 10)
)
l1 <- terra::rast(l1)
l2 <- withr::with_seed(
1107,
matrix(runif(100, 0, 1), nrow = 10)
)
l2 <- terra::rast(l2)

expect_equal(
ww_multi_scale(
truth = l1,
estimate = l2,
metrics = list(yardstick::pr_auc),
grids = list(sf::st_make_grid(l1))
)$.estimate,
0.47208959
)
})

test_that("ww_multi_scale with raster data can handle class prob metrics", {
skip_if_not_installed("terra")
skip_if_not_installed("withr")
l1 <- withr::with_seed(
1107,
matrix(sample(1:2, 100, TRUE), nrow = 10)
)
l1 <- terra::rast(l1)
l2 <- withr::with_seed(
1107,
matrix(runif(100, 0, 1), nrow = 10)
)
l2 <- terra::rast(l2)

r <- c(l1, l2)
names(r) <- c("l1", "l2")

expect_equal(
ww_multi_scale(
r,
truth = "l1",
estimate = "l2",
metrics = list(yardstick::pr_auc),
grids = list(sf::st_make_grid(l1))
)$.estimate,
0.47208959
)
})

test_that("ww_multi_scale with rasters fails if metrics are mixed", {
skip_if_not_installed("terra")
l1 <- terra::rast(matrix(sample(1:10, 100, TRUE), nrow = 10))
l2 <- l1

r <- c(l1, l2)
names(r) <- c("l1", "l2")

expect_error(
ww_multi_scale(
truth = l1,
estimate = l2,
metrics = list(yardstick::precision, yardstick::pr_auc),
grids = list(sf::st_make_grid(l1))
)$.estimate,
class = "waywiser_mixed_metrics"
)

expect_error(
ww_multi_scale(
r,
truth = "l1",
estimate = "l2",
metrics = list(yardstick::precision, yardstick::pr_auc),
grids = list(sf::st_make_grid(l1))
)$.estimate,
class = "waywiser_mixed_metrics"
)
})

0 comments on commit 14591d9

Please sign in to comment.