diff --git a/DESCRIPTION b/DESCRIPTION index 328406d..8da6e0d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parttree Title: Visualise simple decision tree partitions -Version: 0.0.1.9005 +Version: 0.0.1.9006 Authors@R: c( person(given = "Grant", family = "McDermott", diff --git a/NAMESPACE b/NAMESPACE index 18e0a04..1246915 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -11,8 +11,14 @@ export(geom_parttree) export(parttree) importFrom(data.table,":=") importFrom(data.table,.SD) +importFrom(data.table,as.data.table) +importFrom(data.table,data.table) importFrom(data.table,fifelse) +importFrom(data.table,rbindlist) +importFrom(data.table,tstrsplit) importFrom(graphics,par) +importFrom(rpart,path.rpart) importFrom(stats,reformulate) +importFrom(stats,terms) importFrom(tinyplot,tinyplot) importFrom(utils,packageVersion) diff --git a/NEWS.md b/NEWS.md index 47d0414..11f147d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,10 +1,10 @@ -# parttree 0.0.1.9005 +# parttree 0.0.1.9006 To be released as 0.1.0 #### Breaking changes -* Move ggplot2 to Suggests, following the addition of native (base R) +* Move ggplot2 to Enhances, following the addition of native (base R) `plot.parttree` method. The `geom_parttree()` function now checks whether ggplot2 is available on the user's system before executing any code. (#18) * The `flipaxes` argument has been renamed to `flip`, e.g. @@ -23,6 +23,7 @@ by @juliasilge). * Support for negative values. (#6 by @pjgeens) * Better handling of single-level factors and `flip(axes)`. (#5) +* Handling of complex formula expressions. (#17) #### Internals diff --git a/R/parttree.R b/R/parttree.R index ded35f4..3a66551 100644 --- a/R/parttree.R +++ b/R/parttree.R @@ -20,7 +20,9 @@ #' @returns A data frame comprising seven columns: the leaf node, its path, a #' set of rectangle limits (i.e., xmin, xmax, ymin, ymax), and a final column #' corresponding to the predicted value for that leaf. -#' @importFrom data.table := .SD fifelse +#' @importFrom data.table := .SD as.data.table data.table fifelse rbindlist tstrsplit +#' @importFrom rpart path.rpart +#' @importFrom stats reformulate terms #' @export #' @examples #' library("parttree") @@ -112,24 +114,25 @@ parttree.rpart = ## Get details about y variable for later ### y variable string (i.e. name) - y_var = paste(tree$terms)[2] - # y_var = attr(tree$terms, "variables")[[2]] + all_terms = terms(tree) + y_var = all.vars(all_terms)[attr(all_terms, "response")] ### y values yvals = tree$frame[tree$frame$var == "", ]$yval - y_factored = attr(tree$terms, "dataClasses")[paste(y_var)] == "factor" + # y_factored = attr(tree$terms, "dataClasses")[paste(y_var)] == "factor" + y_factored = attr(tree$terms, "dataClasses")[y_var] == "factor" ## factor equivalents (if factor) if (y_factored) { ylevs = attr(tree, "ylevels") yvals = ylevs[yvals] } - part_list = rpart::path.rpart(tree, node=nodes, print.it = FALSE) - part_list = lapply(part_list, data.table::as.data.table) - part_dt = data.table::rbindlist(part_list, idcol="node")[V1!="root"] - part_dt[, c("variable", "split") := data.table::tstrsplit(V1, split = "+<|+<=|>|+>=")][] + part_list = path.rpart(tree, nodes = nodes, print.it = FALSE) + part_list = lapply(part_list, as.data.table) + part_dt = rbindlist(part_list, idcol="node")[V1!="root"] + part_dt[, c("variable", "split") := tstrsplit(V1, split = "+<|+<=|>|+>=")][] part_dt[, side := gsub("\\s$", "", gsub("\\w|\\.", "", V1))][] - yvals_dt = data.table::data.table(yvals, node = nodes) + yvals_dt = data.table(yvals, node = nodes) part_dt = part_dt[yvals_dt, on = "node", all = TRUE] part_dt[, V1 := NULL][, node := as.integer(node)][, split := as.double(split)][] @@ -281,7 +284,7 @@ parttree.workflow = y_name = names(tree$pre$mold$outcomes)[[1]] raw_data = cbind(tree$pre$mold$predictors, tree$pre$mold$outcomes) tree = workflows::extract_fit_engine(tree) - tree$terms[[2]] = y_name + tree$terms[[2]] = reformulate(y_name)[[2]] attr(tree$terms, "variables")[[2]] = y_name names(attr(tree$terms, "dataClasses"))[[1]] = y_name @@ -456,7 +459,7 @@ parttree.constparty = colnames(rval)[4L:7L] = c("xmin", "xmax", "ymin", "ymax") ## turn into data.table? - if(keep_as_dt) rval = data.table::as.data.table(rval) + if(keep_as_dt) rval = as.data.table(rval) class(rval) = c("parttree", class(rval)) xvar = ifelse(isFALSE(flip), mx[1], mx[2])