Skip to content

Commit

Permalink
feat(task): add cifar 10 and cifar 100
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 authored Jan 9, 2025
1 parent 01fe903 commit 3e477dd
Show file tree
Hide file tree
Showing 12 changed files with 398 additions and 4 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ Collate:
'PipeOpTorchReshape.R'
'PipeOpTorchSoftmax.R'
'Select.R'
'TaskClassif_cifar.R'
'TaskClassif_lazy_iris.R'
'TaskClassif_melanoma.R'
'TaskClassif_mnist.R'
Expand Down
175 changes: 175 additions & 0 deletions R/TaskClassif_cifar.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#' @title CIFAR Classification Tasks
#'
#' @name mlr_tasks_cifar
#'
#' @format [R6::R6Class] inheriting from [mlr3::TaskClassif].
#' @include aaa.R
#'
#' @description
#' The CIFAR-10 and CIFAR-100 datasets. A subset of the 80 million tiny images dataset
#' with noisy labels was supplied to student labelers, who were asked to filter out
#' incorrectly labeled images.
#'
#' CIFAR-10 contains 10 classes. CIFAR-100 contains 100 classes, which may be partitioned into 20 superclasses of 5 classes each.
#' The CIFAR-10 and CIFAR-100 classes are mutually exclusive.
#' See Chapter 3.1 of [the technical report](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) for more details.
#'
#' The data is obtained from [`torchvision::cifar10_dataset()`] (or `torchvision::cifar100_dataset()`).
#'
#' @section Construction:
#' ```
#' tsk("cifar10")
#' tsk("cifar100")
#' ```
#'
#' @template task_download
#'
#' @section Properties:
#' `r rd_info_task_torch("cifar10", missings = FALSE)`
#'
#' @references
#' `r format_bib("cifar2009")`
#' @examples
#' task_cifar10 = tsk("cifar10")
#' task_cifar100 = tsk("cifar100")
#' print(task_cifar10)
#' print(task_cifar100)
NULL

# Torch dataset generator over a 4-dimensional image array whose first
# dimension indexes the observations.
cifar_ds_generator = torch::dataset(
  # Stores the image array on the dataset instance.
  initialize = function(images) {
    self$images = images
  },
  # Returns the slice selected by `idx` as a tensor, wrapped in a list under
  # the name `x`.
  .getitem = function(idx) {
    force(idx)

    img = self$images[idx, , , ]

    list(x = torch_tensor(img))
  },
  # Number of observations, i.e. the extent of the first array dimension.
  .length = function() {
    nrow(self$images)
  }
)

# Loads CIFAR-10 or CIFAR-100 through torchvision and merges the train and
# test split into one set of labels and one image array.
#
# @param path Directory passed to torchvision as `root`; the class name files
#   are read from below this directory as well.
# @param type Either 10 (CIFAR-10) or 100 (CIFAR-100).
# @return A list with `class` (factor of the labels, train rows first) and
#   `images` (array of dimension 60000 x 3 x 32 x 32, train rows first).
constructor_cifar = function(path, type = 10) {
  if (type == 10) {
    d_train = torchvision::cifar10_dataset(root = path, train = TRUE, download = TRUE)
    # NOTE(review): presumably the train download already fetched the files
    # needed for the test split, hence download = FALSE -- confirm.
    d_test = torchvision::cifar10_dataset(root = path, train = FALSE, download = FALSE)
    class_names = readLines(file.path(path, "cifar-10-batches-bin", "batches.meta.txt"))
    # drop empty lines from the class name file
    class_names = class_names[class_names != ""]
  } else if (type == 100) {
    d_train = torchvision::cifar100_dataset(root = path, train = TRUE, download = TRUE)
    d_test = torchvision::cifar100_dataset(root = path, train = FALSE, download = FALSE)
    class_names = readLines(file.path(path, "cifar-100-binary", "fine_label_names.txt"))
  } else {
    # Fail early instead of falling through to an unrelated
    # "object 'd_train' not found" error further down.
    stop("`type` must be either 10 or 100.", call. = FALSE)
  }

  classes = c(d_train$y, d_test$y)
  images = array(NA, dim = c(60000, 3, 32, 32))
  # original data has channel dimension at the end
  perm_idx = c(1, 4, 2, 3)
  images[1:50000, , , ] = aperm(d_train$x, perm_idx, resize = TRUE)
  images[50001:60000, , , ] = aperm(d_test$x, perm_idx, resize = TRUE)

  list(class = factor(classes, labels = class_names), images = images)
}

# Constructor for the CIFAR-10 data: requires the torchvision namespace and
# delegates to constructor_cifar() with type = 10.
constructor_cifar10 = function(path) {
  require_namespaces("torchvision")
  constructor_cifar(path, type = 10)
}

# Creates the CIFAR-10 classification task on top of a lazy backend.
#
# @param id Identifier of the resulting task. Defaults to "cifar10".
# @return A TaskClassif with target `class` and `image` as its only feature.
load_task_cifar10 = function(id = "cifar10") {
  # Passed to DataBackendLazy as its `constructor`; builds the actual
  # data.table backend from the cached CIFAR-10 data.
  cached_constructor = function(backend) {
    data = cached(constructor_cifar10, "datasets", "cifar10")$data

    cifar10_ds = cifar_ds_generator(data$images)

    dd = as_data_descriptor(cifar10_ds, list(x = c(NA, 3, 32, 32)))
    lt = lazy_tensor(dd)

    dt = data.table(
      class = data$class,
      image = lt,
      # first 50000 rows are the train split, the remaining 10000 the test split
      split = factor(rep(c("train", "test"), c(50000, 10000))),
      ..row_id = seq_len(60000)
    )

    DataBackendDataTable$new(data = dt, primary_key = "..row_id")
  }

  backend = DataBackendLazy$new(
    constructor = cached_constructor,
    rownames = seq_len(60000),
    col_info = load_col_info("cifar10"),
    primary_key = "..row_id"
  )

  task = TaskClassif$new(
    backend = backend,
    # honour the `id` argument instead of always using "cifar10"
    id = id,
    target = "class",
    label = "CIFAR-10 Classification"
  )

  # only the image is a feature; `split` is kept in the backend but unused
  task$col_roles$feature = "image"

  backend$hash = "mlr3torch::mlr_tasks_cifar10"
  task$man = "mlr3torch::mlr_tasks_cifar"

  task
}

register_task("cifar10", load_task_cifar10)

# Constructor for the CIFAR-100 data: requires the torchvision namespace and
# delegates to constructor_cifar() with type = 100.
constructor_cifar100 = function(path) {
  require_namespaces("torchvision")
  constructor_cifar(path, type = 100)
}

# Creates the CIFAR-100 classification task on top of a lazy backend.
#
# @param id Identifier of the resulting task. Defaults to "cifar100".
# @return A TaskClassif with target `class` and `image` as its only feature.
load_task_cifar100 = function(id = "cifar100") {
  # Passed to DataBackendLazy as its `constructor`; builds the actual
  # data.table backend from the cached CIFAR-100 data.
  cached_constructor = function(backend) {
    data = cached(constructor_cifar100, "datasets", "cifar100")$data

    cifar100_ds = cifar_ds_generator(data$images)

    dd = as_data_descriptor(cifar100_ds, list(x = c(NA, 3, 32, 32)))
    lt = lazy_tensor(dd)

    dt = data.table(
      class = data$class,
      image = lt,
      # first 50000 rows are the train split, the remaining 10000 the test split
      split = factor(rep(c("train", "test"), c(50000, 10000))),
      ..row_id = seq_len(60000)
    )

    DataBackendDataTable$new(data = dt, primary_key = "..row_id")
  }

  backend = DataBackendLazy$new(
    constructor = cached_constructor,
    rownames = seq_len(60000),
    col_info = load_col_info("cifar100"),
    primary_key = "..row_id"
  )

  task = TaskClassif$new(
    backend = backend,
    # honour the `id` argument instead of always using "cifar100"
    id = id,
    target = "class",
    label = "CIFAR-100 Classification"
  )

  # only the image is a feature; `split` is kept in the backend but unused
  task$col_roles$feature = "image"

  backend$hash = "mlr3torch::mlr_tasks_cifar100"
  task$man = "mlr3torch::mlr_tasks_cifar"

  task
}

register_task("cifar100", load_task_cifar100)


1 change: 1 addition & 0 deletions R/TaskClassif_melanoma.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#'
#' @references
#' `r format_bib("melanoma2021")`
#' @examples
#' task = tsk("melanoma")
#' task
NULL
Expand Down
2 changes: 1 addition & 1 deletion R/TaskClassif_tiny_imagenet.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' The data is obtained from [`torchvision::tiny_imagenet_dataset()`].
#'
#' The underlying [`DataBackend`][mlr3::DataBackend] contains columns `"class"`, `"image"`, `"..row_id"`, `"split"`, where the last column
#' indicates whether the row belongs to the train, validation or test set that defined provided in torchvision.
#' indicates whether the row belongs to the train, validation or test set that are provided in torchvision.
#'
#' There are no labels for the test rows, so by default, these observations are inactive, which means that the task
#' uses only 110000 of the 120000 observations that are defined in the underlying data backend.
Expand Down
6 changes: 6 additions & 0 deletions R/bibentries.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ bibentries = c(# nolint start
pages = "34",
year = "2021",
doi = "10.1038/s41597-021-00815-z"
),
# Krizhevsky (2009): the technical report describing CIFAR-10 and CIFAR-100
cifar2009 = bibentry("article",
  title = "Learning Multiple Layers of Features from Tiny Images",
  author = "Krizhevsky, Alex",
  journal = "Master's thesis, Department of Computer Science, University of Toronto",
  year = "2009"
)
) # nolint end

110 changes: 110 additions & 0 deletions data-raw/cifar.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
devtools::load_all()

library(mlr3misc)
library(data.table)
library(torchvision)

# cached
# Torch dataset generator over a 4-dimensional image array whose first
# dimension indexes the observations.
cifar_ds_generator = torch::dataset(
  # Stores the image array on the dataset instance.
  initialize = function(images) {
    self$images = images
  },
  # Returns the slice selected by `idx` as a tensor, wrapped in a list under
  # the name `x`.
  .getitem = function(idx) {
    force(idx)

    img = self$images[idx, , , ]

    list(x = torch_tensor(img))
  },
  # Number of observations, i.e. the extent of the first array dimension.
  .length = function() {
    nrow(self$images)
  }
)

# Loads CIFAR-10 or CIFAR-100 through torchvision and merges the train and
# test split into one set of labels and one image array.
#
# @param path Directory passed to torchvision as `root`; the class name files
#   are read from below this directory as well.
# @param type Either 10 (CIFAR-10) or 100 (CIFAR-100).
# @return A list with `class` (factor of the labels, train rows first) and
#   `images` (array of dimension 60000 x 3 x 32 x 32, train rows first).
constructor_cifar = function(path, type = 10) {
  if (type == 10) {
    d_train = torchvision::cifar10_dataset(root = path, train = TRUE, download = TRUE)
    # NOTE(review): presumably the train download already fetched the files
    # needed for the test split, hence download = FALSE -- confirm.
    d_test = torchvision::cifar10_dataset(root = path, train = FALSE, download = FALSE)
    class_names = readLines(file.path(path, "cifar-10-batches-bin", "batches.meta.txt"))
    # drop empty lines from the class name file
    class_names = class_names[class_names != ""]
  } else if (type == 100) {
    d_train = torchvision::cifar100_dataset(root = path, train = TRUE, download = TRUE)
    d_test = torchvision::cifar100_dataset(root = path, train = FALSE, download = FALSE)
    class_names = readLines(file.path(path, "cifar-100-binary", "fine_label_names.txt"))
  } else {
    # Fail early instead of falling through to an unrelated
    # "object 'd_train' not found" error further down.
    stop("`type` must be either 10 or 100.", call. = FALSE)
  }

  classes = c(d_train$y, d_test$y)
  images = array(NA, dim = c(60000, 3, 32, 32))
  # original data has channel dimension at the end
  perm_idx = c(1, 4, 2, 3)
  images[1:50000, , , ] = aperm(d_train$x, perm_idx, resize = TRUE)
  images[50001:60000, , , ] = aperm(d_test$x, perm_idx, resize = TRUE)

  list(class = factor(classes, labels = class_names), images = images)
}

# Constructor for the CIFAR-10 data: requires the torchvision namespace and
# delegates to constructor_cifar() with type = 10.
constructor_cifar10 = function(path) {
  require_namespaces("torchvision")
  constructor_cifar(path, type = 10)
}

# Set the mlr3torch.cache option and build the path of the raw CIFAR-10 data
# inside the cache directory.
withr::local_options(mlr3torch.cache = TRUE)
path = file.path(get_cache_dir(), "datasets", "cifar10", "raw")

# begin CIFAR-10
# Load labels and images (train rows first, then test rows).
data <- constructor_cifar10(path)

cifar10_ds = cifar_ds_generator(data$images)

# Wrap the dataset in a lazy tensor column; NA marks the batch dimension.
dd = as_data_descriptor(cifar10_ds, list(x = c(NA, 3, 32, 32)))
lt = lazy_tensor(dd)

tsk_dt = data.table(
class = data$class,
image = lt,
split = factor(rep(c("train", "test"), c(50000, 10000))),
..row_id = seq_len(60000)
)

# Earlier variant, kept for reference:
# tsk_dt = cbind(data, data.table(image = lt))

tsk_cifar10 = as_task_classif(tsk_dt, target = "class", id = "cifar10")
tsk_cifar10$col_roles$feature = "image"

# Extract the column info of the task's backend and store it in the package.
# NOTE(review): presumably this is the file read by load_col_info("cifar10") -- confirm.
ci = col_info(tsk_cifar10$backend)

saveRDS(ci, here::here("inst/col_info/cifar10.rds"))
# end CIFAR-10

# Path of the raw CIFAR-100 data inside the cache directory.
path = file.path(get_cache_dir(), "datasets", "cifar100", "raw")

# begin CIFAR-100
# Constructor for the CIFAR-100 data: requires the torchvision namespace and
# delegates to constructor_cifar() with type = 100.
constructor_cifar100 = function(path) {
require_namespaces("torchvision")

return(constructor_cifar(path, type = 100))
}

# Load labels and images (train rows first, then test rows).
data = constructor_cifar100(path)

cifar100_ds = cifar_ds_generator(data$images)

# Wrap the dataset in a lazy tensor column; NA marks the batch dimension.
dd = as_data_descriptor(cifar100_ds, list(x = c(NA, 3, 32, 32)))
lt = lazy_tensor(dd)

dt = data.table(
class = data$class,
image = lt,
split = factor(rep(c("train", "test"), c(50000, 10000))),
..row_id = seq_len(60000)
)

task = as_task_classif(dt, target = "class")

task$col_roles$feature = "image"

# Extract the column info of the task's backend and store it in the package.
# NOTE(review): presumably this is the file read by load_col_info("cifar100") -- confirm.
ci = col_info(task$backend)

saveRDS(ci, here::here("inst/col_info/cifar100.rds"))

Binary file added inst/col_info/cifar10.rds
Binary file not shown.
Binary file added inst/col_info/cifar100.rds
Binary file not shown.
58 changes: 58 additions & 0 deletions man/mlr_tasks_cifar.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/mlr_tasks_melanoma.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 3e477dd

Please sign in to comment.