Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add arguments to $profile() #429

Merged
merged 20 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ Suggests:
bit64,
callr,
data.table,
ggplot2,
knitr,
lubridate,
nanoarrow,
nycflights13,
patrick,
pillar,
rlang,
rmarkdown,
testthat (>= 3.0.0),
tibble,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
- New method `$write_csv()` for `DataFrame` (#414).
- New method `$sink_csv()` for `LazyFrame` (#432).
- New method `$dt$time()` to extract the time from a `datetime` variable (#428).
- Method `$profile()` gains optimization arguments and plot-related arguments (#429).
- New method `pl$read_parquet()` that is a shortcut for `pl$scan_parquet()$collect()` (#434).
- Rename `$str$str_explode()` to `$str$explode()` (#436).
- New argument `eager` of `LazyFrame$set_optimization_toggle()` (#439).
Expand Down
70 changes: 64 additions & 6 deletions R/lazyframe__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1066,11 +1066,11 @@ LazyFrame_sort = function(
#' table. They must have the same length.
#' @param strategy Strategy for where to find match:
#' * "backward" (default): search for the last row in the right table whose `on`
#' key is less than or equal to the left’s key.
#' key is less than or equal to the left key.
#' * "forward": search for the first row in the right table whose `on` key is
#' greater than or equal to the left’s key.
#' greater than or equal to the left key.
#' * "nearest": search for the last row in the right table whose value is nearest
#' to the left’s key. String keys are not currently supported for a nearest
#' to the left key. String keys are not currently supported for a nearest
#' search.
#' @param tolerance
#' Numeric tolerance. By setting this the join will only be done if the near
Expand Down Expand Up @@ -1360,11 +1360,18 @@ LazyFrame_fetch = function(
#' @description This will run the query and return a list containing the
#' materialized DataFrame and a DataFrame that contains profiling information
#' of each node that is executed.
#'
#' @inheritParams LazyFrame_collect
#' @param show_plot Show a Gantt chart of the profiling result
#' @param truncate_nodes Truncate the label lengths in the Gantt chart to this
#' number of characters. If `0` (default), do not truncate.
#'
#' @details The units of the timings are microseconds.
#'
#' @keywords LazyFrame
#' @return List of two `DataFrame`s: one with the collected result, the other
#' with the timings of each step.
#' with the timings of each step. If `show_plot = TRUE`, then the plot is
#' also stored in the list.
#' @seealso
#' - [`$collect()`][LazyFrame_collect] - regular collect.
#' - [`$fetch()`][LazyFrame_fetch] - fast limited query check
Expand Down Expand Up @@ -1400,8 +1407,59 @@ LazyFrame_fetch = function(
#' group_by("Species", maintain_order = TRUE)$
#' agg(pl$col(pl$Float64)$apply(r_func))$
#' profile()
LazyFrame_profile = function(
    type_coercion = TRUE,
    predicate_pushdown = TRUE,
    projection_pushdown = TRUE,
    simplify_expression = TRUE,
    slice_pushdown = TRUE,
    comm_subplan_elim = TRUE,
    comm_subexpr_elim = TRUE,
    streaming = FALSE,
    no_optimization = FALSE,
    inherit_optimization = FALSE,
    collect_in_background = FALSE,
    show_plot = FALSE,
    truncate_nodes = 0) {
  # NOTE(review): `collect_in_background` is accepted but never referenced in
  # this body -- confirm whether it should be forwarded or dropped.

  # `no_optimization = TRUE` overrides the pushdown / elimination toggles;
  # `type_coercion` and `simplify_expression` are left as passed by the caller.
  if (isTRUE(no_optimization)) {
    predicate_pushdown = FALSE
    projection_pushdown = FALSE
    slice_pushdown = FALSE
    comm_subplan_elim = FALSE
    comm_subexpr_elim = FALSE
  }

  # common subplan elimination is switched off whenever streaming is requested
  if (isTRUE(streaming)) {
    comm_subplan_elim = FALSE
  }

  lf = self

  # unless the caller asks to keep the optimization settings already carried by
  # the LazyFrame, apply the toggles resolved above before profiling
  if (isFALSE(inherit_optimization)) {
    lf = self$set_optimization_toggle(
      type_coercion,
      predicate_pushdown,
      projection_pushdown,
      simplify_expression,
      slice_pushdown,
      comm_subplan_elim,
      comm_subexpr_elim,
      streaming
    ) |> unwrap("in $profile():")
  }

  # run the query; `out` is the list holding the result and the timings
  out = lf |>
    .pr$LazyFrame$profile() |>
    unwrap("in $profile()")

  # optionally build the Gantt chart and store it in the output list as "plot"
  if (isTRUE(show_plot)) {
    out[["plot"]] = make_profile_plot(out, truncate_nodes) |>
      result() |>
      unwrap("in $profile()")
  }

  out
}

#' @title Explode columns containing a list of values
Expand Down
58 changes: 58 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,61 @@ is_bool = function(x) {
# Test each element of `dtypes` against the Struct datatype.
# Returns a logical vector with one entry per element of `dtypes`; an empty
# input gives `logical(0)` (sapply() would give an empty list instead).
dtypes_are_struct = function(dtypes) {
  vapply(dtypes, \(dt) pl$same_outer_dt(dt, pl$Struct()), logical(1))
}

# Build a Gantt-style chart of the node timings returned by `$profile()`.
#
# data: list with a `profile` element exposing `$to_data_frame()`; the
#   resulting data.frame is read through its `node`, `start` and `end` columns.
# truncate_nodes: maximum display width kept for each node label on the y
#   axis. If not greater than 0, labels are shown in full.
#
# Returns the ggplot object. The plot is also printed, except when the
# `TESTTHAT` environment variable is "true".
make_profile_plot = function(data, truncate_nodes) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop('The package "ggplot2" is required.')
  }
  timings = data$profile$to_data_frame()
  # keep the nodes in order of first appearance rather than alphabetical order
  timings$node = factor(timings$node, levels = unique(timings$node))
  total_timing = max(timings$end)

  # timings are documented as microseconds; rescale to a more readable unit
  # when the total duration is large enough
  if (total_timing > 10000000) {
    unit = "s"
    total_timing = paste0(total_timing / 1000000, "s")
    timings$start = timings$start / 1000000
    timings$end = timings$end / 1000000
  } else if (total_timing > 10000) {
    unit = "ms"
    total_timing = paste0(total_timing / 1000, "ms")
    timings$start = timings$start / 1000
    timings$end = timings$end / 1000
  } else {
    unit = "\U00B5s"
    total_timing = paste0(total_timing, "\U00B5s")
  }

  # for some reason, there's an error if I use rlang::.data directly in aes()
  .data = rlang::.data

  plot = ggplot2::ggplot(
    timings,
    ggplot2::aes(
      x = .data[["start"]], xend = .data[["end"]],
      y = .data[["node"]], yend = .data[["node"]]
    )
  ) +
    ggplot2::geom_segment(linewidth = 6) +
    ggplot2::xlab(
      paste0("Node duration in ", unit, ". Total duration: ", total_timing)
    ) +
    ggplot2::ylab(NULL) +
    ggplot2::theme(
      axis.text = ggplot2::element_text(size = 12)
    )

  if (truncate_nodes > 0) {
    # Labels are given as a function so that they are derived from the axis
    # breaks themselves: a fixed vector built from every row of `timings` can
    # have a different length than the breaks when a node name is repeated.
    # The ellipsis is only appended to labels that were actually shortened.
    shorten_labels = function(labels) {
      too_long = !is.na(labels) & nchar(labels, type = "width") > truncate_nodes
      labels[too_long] = paste0(strtrim(labels[too_long], truncate_nodes), "...")
      labels
    }
    plot = plot +
      ggplot2::scale_y_discrete(
        labels = shorten_labels,
        limits = rev
      )
  } else {
    plot = plot +
      ggplot2::scale_y_discrete(
        limits = rev
      )
  }

  # do not show the plot if we're running testthat
  if (!identical(Sys.getenv("TESTTHAT"), "true")) {
    print(plot)
  }
  plot
}
6 changes: 3 additions & 3 deletions man/DataFrame_join_asof.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/LazyFrame_join_asof.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 58 additions & 2 deletions man/LazyFrame_profile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions tests/testthat/test-lazy_profile.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,15 @@ test_that("<LazyFrame>$profile", {
p1$result$as_data_frame()
)
})


test_that("profile: show_plot returns a plot in the list of outputs", {
  # the chart is only built when ggplot2 is available
  skip_if_not_installed("ggplot2")
  p1 = pl$LazyFrame(iris)$
    sort("Sepal.Length")$
    group_by("Species", maintain_order = TRUE)$
    agg(pl$col(pl$Float64)$first()$add(5)$suffix("_apply"))$
    profile(show_plot = TRUE)

  expect_length(p1, 3)
  # the extra element must be stored under the "plot" name, not just make the
  # list one element longer
  expect_true("plot" %in% names(p1))
})