Finds the optimal depth-k tree, i.e. the tree that maximizes the sum of rewards, by exhaustive search. If the optimal action is the same in both the left and right leaf of a node, the node is pruned.
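To make the objective and the pruning rule concrete, the following is a minimal base R sketch of a brute-force search for a depth-1 tree. It is an illustration only and not the package's implementation: `best_stump` is a hypothetical helper name, and the convention that samples with a value less than or equal to the threshold go left is an assumption of this sketch.

```r
# Brute-force depth-1 search. Illustration only, not the package's algorithm.
# best_stump is a hypothetical helper; the "<=" split convention is assumed.
best_stump <- function(X, Gamma) {
  best <- list(reward = -Inf)
  for (j in seq_len(ncol(X))) {
    # Candidate thresholds: every distinct value except the largest,
    # so that both children are non-empty.
    values <- sort(unique(X[, j]))
    for (v in values[-length(values)]) {
      left <- X[, j] <= v
      # Summed reward of each action within each child.
      left.sums <- colSums(Gamma[left, , drop = FALSE])
      right.sums <- colSums(Gamma[!left, , drop = FALSE])
      # Each child is assigned the action with the largest summed reward.
      reward <- max(left.sums) + max(right.sums)
      if (reward > best$reward) {
        best <- list(
          reward = reward, split_variable = j, split_value = v,
          left_action = which.max(left.sums),
          right_action = which.max(right.sums)
        )
      }
    }
  }
  # If both children prescribe the same action, the split is redundant
  # and the node could be pruned to a single leaf.
  best$prunable <- best$left_action == best$right_action
  best
}

# Toy data: action 2 pays off only when the first covariate exceeds 0.
X <- cbind(c(-2, -1, 1, 2), c(5, 3, 4, 6))
Gamma <- cbind(c(1, 1, 0, 0), c(0, 0, 1, 1))
stump <- best_stump(X, Gamma)
```

On this toy data the search selects the first covariate with a threshold of -1, assigning action 1 to the left child and action 2 to the right child, so the node is not prunable.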
policy_tree(X, Gamma, depth = 2, split.step = 1, min.node.size = 1, verbose = TRUE)
X | The covariates used. Dimension \(N \times p\) where \(p\) is the number of features. |
Gamma | The rewards for each action. Dimension \(N \times d\) where \(d\) is the number of actions. |
depth | The depth of the fitted tree. Default is 2. |
split.step | An optional approximation parameter, the number of possible splits to consider when performing tree search. split.step = 1 (default) considers every possible split, split.step = 10 considers splitting at every 10th sample and may yield a substantial speedup for dense features. Manually rounding or re-encoding continuous covariates with very high cardinality in a problem-specific manner allows for finer-grained control of the accuracy/runtime tradeoff and may in some cases be the preferred approach. |
min.node.size | An integer indicating the smallest terminal node size permitted. Default is 1. |
verbose | Give verbose output. Default is TRUE. |
A policy_tree object.
Exact tree search is intended as a way to find shallow (i.e. depth 2 or 3) globally optimal tree-based policies on datasets of "moderate" size. The amortized runtime of exact tree search is \(O(p^k n^k (\log n + d) + pn \log n)\) where p is the number of features, n the number of distinct observations, d the number of treatments, and k >= 1 the tree depth. Due to the exponents in this expression, exact tree search will not scale to datasets of arbitrary size.
As an example, the runtime of a depth two tree scales quadratically with the number of observations, implying that doubling the number of samples will quadruple the runtime.
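The scaling claim can be checked by evaluating the dominant terms of the bound directly. The sketch below ignores constant factors, so it says nothing about actual wall-clock time; `search_cost` is a hypothetical helper name, and the values of n, p and d are arbitrary choices for illustration.

```r
# Dominant terms of the stated runtime bound, with constant factors ignored.
# search_cost is a hypothetical helper; it does not measure real runtime.
search_cost <- function(n, p, d, k) {
  p^k * n^k * (log(n) + d) + p * n * log(n)
}

# Relative cost of doubling n from 10000 to 20000 with p = 10 and d = 2.
ratio.depth2 <- search_cost(20000, p = 10, d = 2, k = 2) /
  search_cost(10000, p = 10, d = 2, k = 2)
ratio.depth3 <- search_cost(20000, p = 10, d = 2, k = 3) /
  search_cost(10000, p = 10, d = 2, k = 3)
```

Under these assumptions the depth-2 ratio comes out slightly above 4, because the logarithmic factor also grows with n, and the depth-3 ratio slightly above 8.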
Because n refers to the number of distinct observations, substantial speedups can be gained when the features are discrete (with all binary features, the runtime will be approximately linear in n). It is therefore beneficial to round down or re-encode very dense data to a lower cardinality (the optional parameter split.step emulates this, though rounding/re-encoding allows for finer-grained control).
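As a rough illustration of how rounding or re-encoding lowers cardinality, the base R sketch below counts the distinct values of a simulated continuous covariate before and after rounding, and also shows one possible re-encoding into quantile bins. The choice of 20 bins and the variable names are assumptions made for this example, not recommendations from the package.

```r
set.seed(42)
n <- 10000
x <- rnorm(n)

# Number of distinct values, i.e. candidate split points, per encoding.
distinct <- c(
  raw = length(unique(x)),
  rounded.2 = length(unique(round(x, 2))),
  rounded.1 = length(unique(round(x, 1)))
)

# One possible re-encoding: replace each value by the index of its
# quantile bin (20 bins here, an arbitrary choice for illustration).
breaks <- unique(quantile(x, probs = seq(0, 1, length.out = 21)))
x.binned <- as.numeric(cut(x, breaks = breaks, include.lowest = TRUE))
```

Note that after binning, split values reported by a fitted tree would refer to bin indices rather than to the original scale of the covariate.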
Athey, Susan, and Stefan Wager. "Policy Learning With Observational Data." Econometrica 89.1 (2021): 133-161.
Sverdrup, Erik, Ayush Kanodia, Zhengyuan Zhou, Susan Athey, and Stefan Wager. "policytree: Policy learning via doubly robust empirical welfare maximization over trees." Journal of Open Source Software 5, no. 50 (2020): 2232.
Zhou, Zhengyuan, Susan Athey, and Stefan Wager. "Offline multi-action policy learning: Generalization and optimization." Operations Research 71.1 (2023).
for building deeper trees.
# \donttest{
# Construct doubly robust scores using a causal forest.
n <- 10000
p <- 10
# Discretizing continuous covariates decreases runtime for policy learning.
X <- round(matrix(rnorm(n * p), n, p), 2)
colnames(X) <- make.names(1:p)
W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
Y <- X[, 3] + W * tau + rnorm(n)
c.forest <- grf::causal_forest(X, Y, W)

# Retrieve doubly robust scores.
dr.scores <- double_robust_scores(c.forest)

# Learn a depth-2 tree on a training set.
train <- sample(1:n, n / 2)
tree <- policy_tree(X[train, ], dr.scores[train, ], 2)
tree
#> policy_tree object
#> Tree depth:  2
#> Actions:  1: control 2: treated
#> Variable splits:
#> (1) split_variable: X1  split_value: -0.82
#>   (2) split_variable: X2  split_value: 1.18
#>     (4) * action: 2
#>     (5) * action: 1
#>   (3) split_variable: X2  split_value: -1.34
#>     (6) * action: 2
#>     (7) * action: 1

# Evaluate the tree on a test set.
test <- -train

# One way to assess the policy is to see whether the leaf node (group) the test set samples
# are predicted to belong to have mean outcomes in accordance with the prescribed policy.

# Get the leaf node assigned to each test sample.
node.id <- predict(tree, X[test, ], type = "node.id")

# Doubly robust estimates of E[Y(control)] and E[Y(treated)] by leaf node.
values <- aggregate(dr.scores[test, ], by = list(leaf.node = node.id),
                    FUN = function(dr) c(mean = mean(dr), se = sd(dr) / sqrt(length(dr))))
print(values, digits = 1)
#>   leaf.node control.mean control.se treated.mean treated.se
#> 1         4        -0.09       0.07         0.11       0.07
#> 2         5         0.07       0.14        -0.11       0.12
#> 3         6         0.21       0.10         0.36       0.11
#> 4         7         0.05       0.03        -0.05       0.03

# Take cost of treatment into account by, for example, offsetting the objective
# with an estimate of the average treatment effect.
ate <- grf::average_treatment_effect(c.forest)
cost.offset <- ate[["estimate"]]
dr.scores[, "treated"] <- dr.scores[, "treated"] - cost.offset
tree.cost <- policy_tree(X, dr.scores, 2)

# Predict treatment assignment for each sample.
predicted <- predict(tree, X)

# If there are too many covariates to make tree search computationally feasible, then one
# approach is to consider for example only the top features according to GRF's variable importance.
var.imp <- grf::variable_importance(c.forest)
top.5 <- order(var.imp, decreasing = TRUE)[1:5]
tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
# }