Predict values based on a fitted policy_tree object.

# S3 method for policy_tree
predict(object, newdata, type = c("action.id", "node.id"), ...)



object: A fitted policy_tree object.


newdata: Points at which predictions should be made. Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order.


type: The type of prediction required. "action.id" returns the action id, and "node.id" returns the integer id of the leaf node the sample falls into. Default is "action.id".


...: Additional arguments (currently ignored).


A vector of predictions. For type = "action.id", each element is an integer from 1 to d, where d is the number of columns in the reward matrix. For type = "node.id", each element is an integer corresponding to the node the sample falls into (level-ordered).


# \donttest{
# Construct doubly robust scores using a causal forest.
n <- 10000
p <- 10
# Discretizing continuous covariates decreases runtime for policy learning.
X <- round(matrix(rnorm(n * p), n, p), 2)
colnames(X) <- make.names(1:p)
W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
Y <- X[, 3] + W * tau + rnorm(n)
c.forest <- grf::causal_forest(X, Y, W)

# Retrieve doubly robust scores.
dr.scores <- double_robust_scores(c.forest)

# Learn a depth-2 tree on a training set.
train <- sample(1:n, n / 2)
tree <- policy_tree(X[train, ], dr.scores[train, ], 2)
tree
#> policy_tree object
#> Tree depth: 2
#> Actions: 1: control 2: treated
#> Variable splits:
#> (1) split_variable: X2 split_value: 0.17
#>   (2) split_variable: X1 split_value: 0.16
#>     (4) * action: 2
#>     (5) * action: 1
#>   (3) split_variable: X9 split_value: -0.02
#>     (6) * action: 1
#>     (7) * action: 2
# Evaluate the tree on a test set.
test <- -train

# One way to assess the policy is to see whether the leaf node (group) the test set samples
# are predicted to belong to have mean outcomes in accordance with the prescribed policy.

# Get the leaf node assigned to each test sample.
node.id <- predict(tree, X[test, ], type = "node.id")

# Doubly robust estimates of E[Y(control)] and E[Y(treated)] by leaf node.
values <- aggregate(dr.scores[test, ], by = list(leaf.node = node.id),
                    FUN = function(dr) c(mean = mean(dr), se = sd(dr) / sqrt(length(dr))))
print(values, digits = 1)
#> leaf.node control.mean treated.mean #> 1 4 0.039 0.043 0.19 0.04 #> 2 5 -0.006 0.052 -0.09 0.05 #> 3 6 -0.078 0.056 -0.07 0.05 #> 4 7 0.012 0.057 -0.09 0.06
# Take cost of treatment into account by, for example, offsetting the objective
# with an estimate of the average treatment effect.
ate <- grf::average_treatment_effect(c.forest)
#> Warning: Estimated treatment propensities go as high as 0.961 which means that treatment effects for some treated units may not be well identified. In this case, using `target.sample=control` may be helpful.
cost.offset <- ate[["estimate"]]
dr.scores[, "treated"] <- dr.scores[, "treated"] - cost.offset
tree.cost <- policy_tree(X, dr.scores, 2)

# Predict treatment assignment for each sample.
predicted <- predict(tree, X)

# If there are too many covariates to make tree search computationally feasible, then one
# approach is to consider for example only the top features according to GRF's variable importance.
var.imp <- grf::variable_importance(c.forest)
top.5 <- order(var.imp, decreasing = TRUE)[1:5]
tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
# }