Plot a policy_tree object.
# S3 method for policy_tree
plot(x, leaf.labels = NULL, ...)
x | The tree to plot. |
---|---|
leaf.labels | An optional character vector of leaf labels for each treatment. |
... | Additional arguments (currently ignored). |
# Plot a policy_tree object
if (FALSE) {
  n <- 250
  p <- 10
  X <- matrix(rnorm(n * p), n, p)
  W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
  Y <- X[, 1] + X[, 2] * (W == "B") + X[, 3] * (W == "C") + runif(n)
  multi.forest <- grf::multi_arm_causal_forest(X = X, Y = Y, W = W)
  Gamma.matrix <- double_robust_scores(multi.forest)
  tree <- policy_tree(X, Gamma.matrix, depth = 2)
  plot(tree)

  # Provide optional names for the treatments in each leaf node
  # `action.names` is by default the column names of the reward matrix
  plot(tree, leaf.labels = tree$action.names)

  # Providing a custom character vector
  plot(tree, leaf.labels = c("treatment A", "treatment B", "placebo C"))

  # Saving a plot in a vectorized SVG format can be done with the `DiagrammeRsvg` package.
  install.packages("DiagrammeRsvg")
  tree.plot = plot(tree)
  cat(DiagrammeRsvg::export_svg(tree.plot), file = 'plot.svg')
}