Skip to content

Commit

Permalink
use more performant data frame construction
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay committed Mar 6, 2024
1 parent 7d6d2d2 commit 7b67244
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 41 deletions.
12 changes: 6 additions & 6 deletions R/abstract_stat_slabinterval.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ AbstractStatSlabinterval = ggproto("AbstractStatSlabinterval", AbstractStat,
# @param trans the scale transformation object applied to the coordinate space
# @param ... other stat parameters created by children of stat_slabinterval
compute_limits = function(self, data, trans, ...) {
data.frame(.lower = NA, .upper = NA)
data_frame0(.lower = NA, .upper = NA)
},

# Compute the function that defines the slab. That takes a data frame of
Expand All @@ -81,7 +81,7 @@ AbstractStatSlabinterval = ggproto("AbstractStatSlabinterval", AbstractStat,
# @param trans the scale transformation object applied to the coordinate space
# @param ... other stat parameters created by children of stat_slabinterval
compute_slab = function(self, data, scales, trans, input, ...) {
data.frame()
data_frame0()
},

# Compute interval(s). Takes a data frame of aesthetics and a `.width`
Expand All @@ -99,7 +99,7 @@ AbstractStatSlabinterval = ggproto("AbstractStatSlabinterval", AbstractStat,
.width, na.rm,
...
) {
if (is.null(point_interval)) return(data.frame())
if (is.null(point_interval)) return(data_frame0())

define_orientation_variables(orientation)

Expand Down Expand Up @@ -128,9 +128,9 @@ AbstractStatSlabinterval = ggproto("AbstractStatSlabinterval", AbstractStat,
is_missing = is.na(data$dist)
if (any(is_missing)) {
data = data[!is_missing, ]
remove_missing(data.frame(dist = ifelse(is_missing, NA_real_, 0)), na.rm, "dist", name = "stat_slabinterval")
remove_missing(data_frame0(dist = ifelse(is_missing, NA_real_, 0)), na.rm, "dist", name = "stat_slabinterval")
}
if (nrow(data) == 0) return(data.frame())
if (nrow(data) == 0) return(data_frame0())


# figure out coordinate transformation
Expand Down Expand Up @@ -181,7 +181,7 @@ AbstractStatSlabinterval = ggproto("AbstractStatSlabinterval", AbstractStat,
...
)
} else {
data.frame(.input = numeric())
data_frame0(.input = numeric())
}
i_data = self$compute_interval(d,
trans = trans,
Expand Down
2 changes: 1 addition & 1 deletion R/binning_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ bin_dots = function(x, y, binwidth,
side = match.arg(side)
orientation = match.arg(orientation)

d = data.frame(x = x, y = y)
d = data_frame0(x = x, y = y)

# after this point `x` and `y` refer to column names in `d` according
# to the orientation
Expand Down
4 changes: 2 additions & 2 deletions R/curve_interval.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ curve_interval.matrix = function(
check_along_is_null(.along)

curve_interval(
data.frame(.value = posterior::rvar(.data)), .value,
data_frame0(.value = posterior::rvar(.data)), .value,
.width = .width, na.rm = na.rm,
.interval = .interval
)
Expand All @@ -167,7 +167,7 @@ curve_interval.rvar = function(
check_along_is_null(.along)

curve_interval(
data.frame(.value = .data), .value,
data_frame0(.value = .data), .value,
.width = .width, na.rm = na.rm,
.interval = .interval
)
Expand Down
4 changes: 2 additions & 2 deletions R/geom_slabinterval.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ rescale_slab_thickness = function(
s_data = ggplot2::remove_missing(s_data, na.rm, c(height, "justification", "scale"), name = name, finite = TRUE)
# side is a character vector, thus need finite = FALSE for it; x/y can be Inf here
s_data = ggplot2::remove_missing(s_data, na.rm, c(x, y, "side"), name = name)
if (nrow(s_data) == 0) return(list(data = s_data, subguide_params = data.frame()))
if (nrow(s_data) == 0) return(list(data = s_data, subguide_params = data_frame0()))

min_height = min(s_data[[height]])

Expand All @@ -100,7 +100,7 @@ rescale_slab_thickness = function(

thickness_scale = d$scale[[1]] * min_height

subguide_params = data.frame(
subguide_params = data_frame0(
group = d$group[[1]],
side = d$size[[1]],
justification = d$justification[[1]],
Expand Down
2 changes: 1 addition & 1 deletion R/guide_rampbar.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ guide_rampbar = function(..., to = "gray65", available_aes = c("fill_ramp", "col
if (length(bar) == 0) {
bar = unique(limits)
}
bar = data_frame(
bar = data_frame0(
colour = scale$map(bar),
value = bar,
.size = length(bar)
Expand Down
4 changes: 2 additions & 2 deletions R/point_interval.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ point_interval.default = function(.data, ..., .width = .95, .point = median, .in
map_dfr_(seq_len(nrow(row)), function(j) {
# get row of `data` with grouping factors
# faster version of row_j = row[j, , drop = FALSE]
row_j = tibble::new_tibble(lapply(row, vctrs::vec_slice, j), nrow = 1)
row_j = new_data_frame(lapply(row, vctrs::vec_slice, j), n = 1L)
row.names(row_j) = NULL
draws_j = draws[[j]]

Expand Down Expand Up @@ -322,7 +322,7 @@ point_interval.numeric = function(.data, ..., .width = .95, .point = median, .in

result = map_dfr_(.width, function(p) {
interval = .interval(data, .width = p, na.rm = na.rm)
data.frame(
data_frame0(
y = .point(data, na.rm = na.rm),
ymin = interval[, 1],
ymax = interval[, 2],
Expand Down
4 changes: 2 additions & 2 deletions R/stat_dotsinterval.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ compute_slab_dots = function(

dist = data$dist
if (distr_is_missing(dist)) {
return(data.frame(.input = NA_real_, f = NA_real_, n = NA_integer_))
return(data_frame0(.input = NA_real_, f = NA_real_, n = NA_integer_))
}

quantiles = quantiles %||% NA
Expand Down Expand Up @@ -64,7 +64,7 @@ compute_slab_dots = function(
se = 0
}

out = data.frame(
out = data_frame0(
.input = input,
f = 1,
n = length(input)
Expand Down
27 changes: 13 additions & 14 deletions R/stat_slabinterval.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ compute_limits_slabinterval = function(
) {
dist = check_one_dist(data$dist)
if (distr_is_missing(dist)) {
return(data.frame(.lower = NA, .upper = NA))
return(data_frame0(.lower = NA, .upper = NA))
}

if (distr_is_factor_like(dist)) {
# limits on factor-like dists are determined by the scale, which will
# have been set earlier (in layer_slabinterval()), so we don't have to
# do it here
return(data.frame(.lower = NA, .upper = NA))
return(data_frame0(.lower = NA, .upper = NA))
}

if (distr_is_constant(dist)) {
.median = distr_quantile(dist)(0.5)
return(data.frame(.lower = .median, .upper = .median))
return(data_frame0(.lower = .median, .upper = .median))
}

if (distr_is_sample(dist)) {
Expand Down Expand Up @@ -57,7 +57,7 @@ compute_limits_slabinterval = function(
lower_limit = min(quantile_fun(p_limits[[1]]))
upper_limit = max(quantile_fun(p_limits[[2]]))

data.frame(
data_frame0(
.lower = lower_limit,
.upper = upper_limit
)
Expand All @@ -74,7 +74,7 @@ compute_limits_sample = function(x, trans, trim, adjust, ..., density = "bounded
# determine limits of data based on the density estimator
x = trans$transform(x)
x_range = range(density(x, n = 2, range_only = TRUE, trim = trim, adjust = adjust, weights = weights)$x)
data.frame(
data_frame0(
.lower = trans$inverse(x_range[[1]]),
.upper = trans$inverse(x_range[[2]])
)
Expand All @@ -94,7 +94,7 @@ compute_slab_slabinterval = function(
dist = data$dist
# TODO: add support for multivariate distributions
if (distr_is_missing(dist) || distr_is_multivariate(dist)) {
return(data.frame(.input = NA_real_, f = NA_real_, n = NA_integer_))
return(data_frame0(.input = NA_real_, f = NA_real_, n = NA_integer_))
}

# calculate pdf and cdf
Expand Down Expand Up @@ -167,7 +167,7 @@ compute_slab_slabinterval = function(
cdf = cdf_fun(input)
}

data.frame(
data_frame0(
.input = input,
f = get_slab_function(slab_type, list(pdf = pdf, cdf = cdf)),
pdf = pdf,
Expand All @@ -191,7 +191,6 @@ compute_slab_sample = function(
...,
weights = NULL
) {

if (is.integer(x) || inherits(x, "mapped_discrete")) {
# discrete variables are always displayed as histograms
slab_type = "histogram"
Expand All @@ -207,7 +206,7 @@ compute_slab_sample = function(
breaks = breaks, align = align, outline_bars = outline_bars,
weights = weights
)
slab_df = data.frame(
slab_df = data_frame0(
.input = trans$inverse(d$x),
pdf = d$y,
cdf = d$cdf %||% weighted_ecdf(x, weights = weights)(d$x)
Expand All @@ -221,7 +220,7 @@ compute_slab_sample = function(
if (expand[[1]]) {
input_below_slab = input[input < min(slab_df$.input) - .Machine$double.eps]
if (length(input_below_slab) > 0) {
slab_df = rbind(data.frame(
slab_df = rbind(data_frame0(
.input = input_below_slab,
pdf = 0,
cdf = 0
Expand All @@ -231,7 +230,7 @@ compute_slab_sample = function(
if (expand[[2]]) {
input_above_slab = input[input > max(slab_df$.input) + .Machine$double.eps]
if (length(input_above_slab) > 0) {
slab_df = rbind(slab_df, data.frame(
slab_df = rbind(slab_df, data_frame0(
.input = input_above_slab,
pdf = 0,
cdf = 1
Expand All @@ -255,10 +254,10 @@ compute_interval_slabinterval = function(
.width, na.rm,
...
) {
if (is.null(point_interval)) return(data.frame())
if (is.null(point_interval)) return(data_frame0())
dist = data$dist
if (distr_is_missing(dist)) {
return(data.frame(.value = NA_real_, .lower = NA_real_, .upper = NA_real_, .width = .width))
return(data_frame0(.value = NA_real_, .lower = NA_real_, .upper = NA_real_, .width = .width))
}

distr_point_interval(dist, point_interval, trans = trans, .width = .width, na.rm = na.rm)
Expand Down Expand Up @@ -624,7 +623,7 @@ StatSlabinterval = ggproto("StatSlabinterval", AbstractStatSlabinterval,
# dist aesthetic is not provided but x aesthetic is, and x is not a dist
# this means we need to wrap it as a weighted dist_sample
data = summarise_by(data, c("PANEL", y, "group"), function(d) {
data.frame(dist = .dist_weighted_sample(list(trans$inverse(d[[x]])), list(d[["weight"]])))
data_frame0(dist = .dist_weighted_sample(list(trans$inverse(d[[x]])), list(d[["weight"]])))
})
data[[x]] = median(data$dist)
}
Expand Down
6 changes: 3 additions & 3 deletions R/stat_spike.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ compute_slab_spike = function(
# needs to be a vector (e.g. in cases of interval functions
# like qi() which return matrices)
input = unlist(input_nested, use.names = FALSE, recursive = FALSE)
names(input) = rep(names(at), times = lengths(input_nested))
input_names = rep(names(at), times = lengths(input_nested))

# evaluate functions
pdf = pdf_fun(input)
cdf = cdf_fun(input)

data.frame(
data_frame0(
.input = input,
at = names(input),
at = input_names,
f = if (length(input) > 0) get_slab_function(slab_type, list(pdf = pdf, cdf = cdf)),
pdf = pdf,
cdf = cdf,
Expand Down
4 changes: 2 additions & 2 deletions R/subguide.R
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,12 @@ draw_subguide_axis = function(
position = axis_position
)
params = guide$params
params$key = data_frame(
params$key = data_frame0(
!!aes := break_positions,
.value = break_positions,
.label = break_labels
)
params$decor = data_frame(
params$decor = data_frame0(
!!aes := c(0, 1),
!!opp := if (axis_position %in% c("top", "right")) 0 else 1
)
Expand Down
10 changes: 8 additions & 2 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ check_na = function(x, na.rm) {

# data frames -------------------------------------------------------------

#' Fast data frame creation
#'
#' Thin wrapper around `vctrs::data_frame()` used throughout the package in
#' place of base `data.frame()` for speed: it avoids base R's argument
#' checking, name repair, and string/factor handling overhead.
#' `.name_repair = "minimal"` tells vctrs to take column names exactly as
#' given (no uniqueness or syntactic-name enforcement).
#'
#' @param ... column vectors, usually named; forwarded as-is to
#'   `vctrs::data_frame()`, which recycles them to a common size per the
#'   vctrs recycling rules.
#' @return A `data.frame` constructed by vctrs. With no arguments, an empty
#'   0-row, 0-column data frame (used as the "nothing to compute" sentinel
#'   by the stat `compute_*` methods in this commit).
#' @noRd
data_frame0 = function(...) {
vctrs::data_frame(..., .name_repair = "minimal")
}

#' rename columns using a lookup table
#' @param data data frame
#' @param new_names lookup table of new column names, where names are
Expand Down Expand Up @@ -182,7 +188,7 @@ map_dfr_ = function(data, fun, ...) {
row_map_dfr_ = function(data, fun, ...) {
map_dfr_(seq_len(nrow(data)), function(row_i) {
# faster version of row_df = data[row_i, , drop = FALSE]
row_df = tibble::new_tibble(lapply(data, vctrs::vec_slice, row_i), nrow = 1)
row_df = new_data_frame(lapply(data, vctrs::vec_slice, row_i), n = 1L)
fun(row_df, ...)
})
}
Expand Down Expand Up @@ -221,7 +227,7 @@ dlply_ = function(data, groups, fun, ...) {

lapply(group_is, function(group_i) {
# faster version of row_df = data[group_i, , drop = FALSE]
row_df = tibble::new_tibble(lapply(data, vctrs::vec_slice, group_i), nrow = length(group_i))
row_df = new_data_frame(lapply(data, vctrs::vec_slice, group_i), n = length(group_i))
fun(row_df, ...)
})
} else {
Expand Down
19 changes: 15 additions & 4 deletions tests/testthat/test.util.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# Author: mjskay
###############################################################################

library(dplyr)
suppressPackageStartupMessages(suppressWarnings({
library(dplyr)
}))


test_that("all_names works", {
Expand Down Expand Up @@ -51,12 +53,21 @@ test_that("fct_rev_ works properly", {
# dlply_ ------------------------------------------------------------------

test_that("dlply_ works properly", {
df = tibble(
df = data.frame(
x = 1:8,
g = c(rep("a", 2), rep("(Missing)", 2), rep("(Missing)+", 2), rep(NA, 2))
g = c(rep("a", 2), rep("(Missing)", 2), rep("(Missing)+", 2), rep(NA, 2)),
stringsAsFactors = FALSE
)

expect_equal(dlply_(df, "g", identity), list(df[3:4,], df[5:6,], df[1:2,], df[7:8,]))
expect_equal(
dlply_(df, "g", identity),
list(
new_data_frame(df[3:4,]),
new_data_frame(df[5:6,]),
new_data_frame(df[1:2,]),
new_data_frame(df[7:8,])
)
)

expect_equal(dlply_(df, NULL, identity), list(df))

Expand Down

0 comments on commit 7b67244

Please sign in to comment.