Skip to content

Commit

Permalink
Merge pull request #21 from grantmcdermott/deparse
Browse files: browse the repository at this point in the history
Handle complex formula expressions
  • Loading branch information
grantmcdermott authored Jan 11, 2025
2 parents bf23c14 + a4fd802 commit 72b179b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parttree
Title: Visualise simple decision tree partitions
Version: 0.0.1.9005
Version: 0.0.1.9006
Authors@R: c(
person(given = "Grant",
family = "McDermott",
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@ export(geom_parttree)
export(parttree)
importFrom(data.table,":=")
importFrom(data.table,.SD)
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(data.table,fifelse)
importFrom(data.table,rbindlist)
importFrom(data.table,tstrsplit)
importFrom(graphics,par)
importFrom(rpart,path.rpart)
importFrom(stats,reformulate)
importFrom(stats,terms)
importFrom(tinyplot,tinyplot)
importFrom(utils,packageVersion)
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# parttree 0.0.1.9005
# parttree 0.0.1.9006

To be released as 0.1.0

#### Breaking changes

* Move ggplot2 to Suggests, following the addition of native (base R)
* Move ggplot2 to Enhances, following the addition of a native (base R)
`plot.parttree` method. The `geom_parttree()` function now checks whether
ggplot2 is available on the user's system before executing any code. (#18)
* The `flipaxes` argument has been renamed to `flip`, e.g.
Expand All @@ -23,6 +23,7 @@ by @juliasilge).

* Support for negative values. (#6 by @pjgeens)
* Better handling of single-level factors and `flip(axes)`. (#5)
* Handling of complex formula expressions. (#17)

#### Internals

Expand Down
25 changes: 14 additions & 11 deletions R/parttree.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#' @returns A data frame comprising seven columns: the leaf node, its path, a
#' set of rectangle limits (i.e., xmin, xmax, ymin, ymax), and a final column
#' corresponding to the predicted value for that leaf.
#' @importFrom data.table := .SD fifelse
#' @importFrom data.table := .SD as.data.table data.table fifelse rbindlist tstrsplit
#' @importFrom rpart path.rpart
#' @importFrom stats reformulate terms
#' @export
#' @examples
#' library("parttree")
Expand Down Expand Up @@ -112,24 +114,25 @@ parttree.rpart =

## Get details about y variable for later
### y variable string (i.e. name)
y_var = paste(tree$terms)[2]
# y_var = attr(tree$terms, "variables")[[2]]
all_terms = terms(tree)
y_var = all.vars(all_terms)[attr(all_terms, "response")]
### y values
yvals = tree$frame[tree$frame$var == "<leaf>", ]$yval
y_factored = attr(tree$terms, "dataClasses")[paste(y_var)] == "factor"
# y_factored = attr(tree$terms, "dataClasses")[paste(y_var)] == "factor"
y_factored = attr(tree$terms, "dataClasses")[y_var] == "factor"
## factor equivalents (if factor)
if (y_factored) {
ylevs = attr(tree, "ylevels")
yvals = ylevs[yvals]
}

part_list = rpart::path.rpart(tree, node=nodes, print.it = FALSE)
part_list = lapply(part_list, data.table::as.data.table)
part_dt = data.table::rbindlist(part_list, idcol="node")[V1!="root"]
part_dt[, c("variable", "split") := data.table::tstrsplit(V1, split = "+<|+<=|>|+>=")][]
part_list = path.rpart(tree, nodes = nodes, print.it = FALSE)
part_list = lapply(part_list, as.data.table)
part_dt = rbindlist(part_list, idcol="node")[V1!="root"]
part_dt[, c("variable", "split") := tstrsplit(V1, split = "+<|+<=|>|+>=")][]
part_dt[, side := gsub("\\s$", "", gsub("\\w|\\.", "", V1))][]

yvals_dt = data.table::data.table(yvals, node = nodes)
yvals_dt = data.table(yvals, node = nodes)

part_dt = part_dt[yvals_dt, on = "node", all = TRUE]
part_dt[, V1 := NULL][, node := as.integer(node)][, split := as.double(split)][]
Expand Down Expand Up @@ -281,7 +284,7 @@ parttree.workflow =
y_name = names(tree$pre$mold$outcomes)[[1]]
raw_data = cbind(tree$pre$mold$predictors, tree$pre$mold$outcomes)
tree = workflows::extract_fit_engine(tree)
tree$terms[[2]] = y_name
tree$terms[[2]] = reformulate(y_name)[[2]]
attr(tree$terms, "variables")[[2]] = y_name
names(attr(tree$terms, "dataClasses"))[[1]] = y_name

Expand Down Expand Up @@ -456,7 +459,7 @@ parttree.constparty =
colnames(rval)[4L:7L] = c("xmin", "xmax", "ymin", "ymax")

## turn into data.table?
if(keep_as_dt) rval = data.table::as.data.table(rval)
if(keep_as_dt) rval = as.data.table(rval)

class(rval) = c("parttree", class(rval))
xvar = ifelse(isFALSE(flip), mx[1], mx[2])
Expand Down

0 comments on commit 72b179b

Please sign in to comment.