Skip to content

Commit

Permalink
Merge pull request #1430 from rstudio/fix-application_preprocess_inputs
Browse files Browse the repository at this point in the history
convert input to a writeable numpy array in preprocess_inputs
  • Loading branch information
t-kalinowski authored Apr 15, 2024
2 parents 778e39f + 6944a25 commit 3592976
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 1 deletion.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ export(image_from_array)
export(image_load)
export(image_smart_resize)
export(image_to_array)
export(imagenet_decode_predictions)
export(imagenet_preprocess_input)
export(initializer_constant)
export(initializer_glorot_normal)
export(initializer_glorot_uniform)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ New functions:

- `layer_lstm()` and `layer_gru()` gain arg `use_cudnn`, default `'auto'`.

- Fixed an issue where `application_preprocess_inputs()` would error if supplied
an R array as input.

- Doc improvements.

# keras3 0.1.0
Expand Down
83 changes: 82 additions & 1 deletion R/applications.R
Original file line number Diff line number Diff line change
Expand Up @@ -3953,7 +3953,88 @@ list_model_names <- function() {
}

set_preprocessing_attributes <- function(object, module) {
attr(object, "preprocess_input") <- module$preprocess_input
.preprocess_input <- r_to_py(module)$preprocess_input

attr(object, "preprocess_input") <-
as.function.default(c(formals(.preprocess_input), bquote({
args <- capture_args(list(
x = function(x) {
if (!is_py_object(x))
x <- np_array(x)
if (inherits(x, "numpy.ndarray") &&
!py_bool(x$flags$writeable))
x <- x$copy()
x
}
))
do.call(.(.preprocess_input), args)
})), envir = parent.env(environment()))

attr(object, "decode_predictions") <- module$decode_predictions
object
}


#' Decodes the prediction of an ImageNet model.
#'
#' @param preds Tensor encoding a batch of predictions.
#' @param top integer, how many top-guesses to return.
#'
#' @return List of data frames with variables `class_name`, `class_description`,
#' and `score` (one data frame per sample in batch input).
#'
#' @export
#' @keywords internal
imagenet_decode_predictions <- function(preds, top = 5) {

# decode predictions
decoded <- keras$applications$imagenet_utils$decode_predictions(
preds = preds,
top = as.integer(top)
)

# convert to a list of data frames
lapply(decoded, function(x) {
m <- t(sapply(1:length(x), function(n) x[[n]]))
data.frame(class_name = as.character(m[,1]),
class_description = as.character(m[,2]),
score = as.numeric(m[,3]),
stringsAsFactors = FALSE)
})
}


#' Preprocesses a tensor or array encoding a batch of images.
#'
#' @param x Input Numpy or symbolic tensor, 3D or 4D.
#' @param data_format Data format of the image tensor/array.
#' @param mode One of "caffe", "tf", or "torch"
#' - caffe: will convert the images from RGB to BGR,
#' then will zero-center each color channel with
#' respect to the ImageNet dataset,
#' without scaling.
#' - tf: will scale pixels between -1 and 1, sample-wise.
#' - torch: will scale pixels between 0 and 1 and then
#' will normalize each channel with respect to the
#' ImageNet dataset.
#'
#' @return Preprocessed tensor or array.
#'
#' @export
#' @keywords internal
imagenet_preprocess_input <- function(x, data_format = NULL, mode = "caffe") {
args <- capture_args(list(
x = function(x) {
if (!is_py_object(x))
x <- np_array(x)
if (inherits(x, "numpy.ndarray") &&
!py_bool(x$flags$writeable))
x <- x$copy()
x
}
))

preprocess_input <- r_to_py(keras$applications$imagenet_utils)$preprocess_input
do.call(preprocess_input, args)
}

21 changes: 21 additions & 0 deletions man/imagenet_decode_predictions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions man/imagenet_preprocess_input.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3592976

Please sign in to comment.