In addition to providing out-of-the-box forests for quantile regression and instrumental variables, grf provides a framework for creating forests tailored to new statistical tasks. Certain components around splitting and prediction can be swapped out, within the general infrastructure for growing and predicting on trees.


This repository follows the standard open source workflow with git, for which there are plenty of existing resources to get up to speed (see, for example, the contributing guidelines of well-known packages in other languages, such as scikit-learn, SciPy, and pandas).

Condensed greatly, the workflow is to fork this repository, check out a branch, commit your changes (ideally forming a legible commit history), then submit a pull request explaining your contribution, ideally referring to the issue you created or chose to work on.

Working with the code

The core forest implementation is written in C++, with an R interface powered by Rcpp. We recommend using a full-featured C++ IDE such as CLion, Xcode, or Visual Studio when working with the core code.

R package

To build the R package from source, cd into r-package and run build_package.R. Required development dependencies are listed there. This mimics the tests run when submitting a pull request. Additional online package documentation is built with pkgdown as part of continuous integration: the usage examples (R Markdown files under grf/vignettes) are rendered, and the R method reference and package articles are laid out as defined in pkgdown.yml. To build the site locally, run pkgdown::build_site() from the R package directory.

An alternative development workflow is to use the accompanying grf.Rproj and build and test the package with RStudio’s build menu, which can be convenient for quickly iterating C++/R code changes. Compiling the package locally with PKG_CPPFLAGS="-UNDEBUG" (or set in ~/.R/Makevars) may give helpful debug assertions, as the Eigen library will then perform bounds checks on matrix algebra.

CRAN releases have a 3-digit version number, while the development version has 4 digits. This convention allows one to easily distinguish if a package was installed from CRAN or from GitHub (R does not allow version numbers with a “devel” tacked on).

Note for Windows users:

Symlinks in the src directory point to the core C++ and R bindings. On Windows one has to clone this repository with symlinks enabled: git clone -c core.symlinks=true (this command needs to be run as an administrator: right click Command Prompt -> Run as administrator). Caveat: the above RStudio workflow is not tested on Windows.

Core C++

Code structure

[Figure: GRF architecture diagram]

The forest implementation is composed of two top-level components, ForestTrainer and ForestPredictor.

ForestTrainer drives the tree-growing process, and has two pluggable components.

  • RelabelingStrategy is applied before every split, and produces a set of relabelled outcomes given the observations for a group of samples. In the case of quantile forests, for example, this strategy computes the quantiles for the group of samples, then relabels them with a factor corresponding to the quantile they belong to.
  • SplittingRule is called to find the best split for a particular node, given a set of outcomes. There are currently implementations for standard regression and multinomial splitting.
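
The quantile relabeling idea described above can be sketched in a few lines of Python. This is a minimal illustration and not grf's C++ code: the function names `quantile_cutoffs` and `relabel_by_quantile` are hypothetical, and the choice of order statistic used as a cutoff is an assumption made for the example.

```python
import math
from bisect import bisect_left


def quantile_cutoffs(outcomes, probs):
    """Pick one order statistic per requested probability (assumed convention)."""
    ordered = sorted(outcomes)
    n = len(ordered)
    return [ordered[min(n - 1, max(0, math.ceil(p * n) - 1))] for p in probs]


def relabel_by_quantile(outcomes, probs):
    """Replace each outcome by the index of the quantile bin it falls into."""
    cutoffs = quantile_cutoffs(outcomes, probs)
    # bisect_left counts the cutoffs strictly below y, which is the bin index.
    return [bisect_left(cutoffs, y) for y in outcomes]
```

With three cutoffs the relabeled outcomes take values in 0..3, so the splitting rule sees a categorical response rather than the raw outcome.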

The trained forest produces a Forest object. This can then be passed to the ForestPredictor to predict on test samples. The predictor has a pluggable ‘prediction strategy’, which computes a prediction given a test sample. Prediction strategies can be one of two types:

  • DefaultPredictionStrategy computes a prediction given a weighted list of training sample IDs that share a leaf with the test sample. Taking quantile forests as an example, this strategy would compute the quantiles of the weighted leaf samples.
  • OptimizedPredictionStrategy does not predict using a list of neighboring samples and weights, but instead precomputes summary values for each leaf during training, and uses these during prediction. This type of strategy will also be passed to ForestTrainer, so it can define how the summary values are computed.
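
To make the default prediction path concrete, here is a hedged Python sketch of computing a quantile from a weighted list of neighboring training samples. The function name `weighted_quantile` and the tie-breaking convention (smallest outcome whose cumulative weight reaches the target) are assumptions for illustration, not the signature or behavior of grf's prediction strategy.

```python
def weighted_quantile(outcomes, weights, q):
    """Smallest outcome whose cumulative weight reaches a fraction q of the total."""
    # Keep only neighbors that actually share a leaf with the test sample.
    pairs = sorted((y, w) for y, w in zip(outcomes, weights) if w > 0)
    total = sum(w for _, w in pairs)
    cumulative = 0.0
    for y, w in pairs:
        cumulative += w
        if cumulative >= q * total:
            return y
    return pairs[-1][0]
```

For example, with outcomes 10, 20, 30 and forest weights 1, 1, 2, the weighted median is 20, while the 0.75 quantile is 30 because the last sample carries half of the total weight.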

Prediction strategies can also compute variance estimates for the predictions, given a forest trained with grouped trees. Because of performance constraints, only ‘optimized’ prediction strategies can provide variance estimates.

A particular type of forest is created by pulling together a set of pluggable components. As an example, a quantile forest is composed of a QuantileRelabelingStrategy, ProbabilitySplittingRule, and QuantilePredictionStrategy. The factory classes ForestTrainers and ForestPredictors define the common types of forests like regression, quantile, and causal forests.

Creating a custom forest

When designing a forest for a new statistical task, we recommend following the format outlined in the paper: provide a custom RelabelingStrategy and PredictionStrategy, but use the standard RegressionSplittingRule. If the splitting rule must be modified, do so carefully, as it is an extremely performance-sensitive piece of code.

To get started with setting up new classes and Rcpp bindings, we suggest looking at one of the simpler forests, such as regression_forest, and using it as a template.

Current forests and main components

The following table shows the current collection of forests implemented and the C++ components.

| R forest name | RelabelingStrategy | SplittingRule | PredictionStrategy |
| --- | --- | --- | --- |
| causal_forest | InstrumentalRelabelingStrategy | InstrumentalSplittingRule | InstrumentalPredictionStrategy |
| causal_forest with ll_causal_predict | InstrumentalRelabelingStrategy | InstrumentalSplittingRule | LLCausalPredictionStrategy |
| causal_survival_forest | CausalSurvivalRelabelingStrategy | CausalSurvivalSplittingRule | CausalSurvivalPredictionStrategy |
| instrumental_forest | InstrumentalRelabelingStrategy | InstrumentalSplittingRule | InstrumentalPredictionStrategy |
| ll_regression_forest | LLRegressionRelabelingStrategy | RegressionSplittingRule | LocalLinearPredictionStrategy |
| multi_arm_causal_forest | MultiCausalRelabelingStrategy | MultiCausalSplittingRule | MultiCausalPredictionStrategy |
| multi_regression_forest | MultiNoopRelabelingStrategy | MultiRegressionSplittingRule | MultiRegressionPredictionStrategy |
| probability_forest | NoopRelabelingStrategy | ProbabilitySplittingRule | ProbabilityPredictionStrategy |
| quantile_forest | QuantileRelabelingStrategy | ProbabilitySplittingRule | QuantilePredictionStrategy |
| regression_forest | NoopRelabelingStrategy | RegressionSplittingRule | RegressionPredictionStrategy |
| survival_forest | NoopRelabelingStrategy | SurvivalSplittingRule | SurvivalPredictionStrategy |

Tree Splitting Algorithm

The following section outlines pseudocode for some of the components listed above.

Each splitting rule follows the same pattern: given a set of samples and a candidate split variable, iterate over all the split points for that variable, compute an error criterion on both sides of the split, and select the split where the decrease in the summed error criterion is largest. For RegressionSplittingRule the criterion is the mean squared error, which is usually referred to as standard CART splitting. InstrumentalSplittingRule uses the same error criterion, but incorporates additional constraints on potential splits, such as requiring a certain number of treated and control observations (more details are in the Algorithm Reference).

grf’s splitting rules add two features beyond basic CART: sample weights, and support for missing values through Missingness Incorporated in Attributes (MIA) splitting. The algebra behind the decrease calculation is detailed below.

Algorithm (RegressionSplittingRule): find the best split for variable var at n samples samples = i...j

Input: X: covariate matrix, weights: weight n-vector, response: the responses (pseudo-outcomes) n-vector, min_child_size: split constraint on number of samples

Output: The best split value and the direction to send missing values

let x = X[samples, var]

for each unique value u in x:
  let count_left[u] be the number of samples with x <= u
  let sum_left[u], weight_sum_left[u] be the sum of weight * response and weight with x <= u

  let count_right[u], sum_right[u], weight_sum_right[u] be the same for samples with x > u
  let count_missing, sum_missing, weight_sum_missing be the same for samples with x = missing

  decrease[u, missing = on_left] = (sum_left[u] + sum_missing)^2 / (weight_sum_left[u] + weight_sum_missing)
                                  + sum_right[u]^2 / weight_sum_right[u]

  decrease[u, missing = on_right] = sum_left[u]^2 / weight_sum_left[u]
                                  + (sum_right[u] + sum_missing)^2 / (weight_sum_right[u] + weight_sum_missing)

return (u, missing) that maximizes decrease, subject to
count_left[u] >= min_child_size and count_right[u] >= min_child_size
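
The pseudocode above can be turned into a small runnable Python sketch. It is an illustration only, not grf's C++ implementation: the function name `best_regression_split` is hypothetical, missing values are assumed to be encoded as NaN, weights are assumed to be strictly positive, and, as in the pseudocode, the child size constraint counts only samples with an observed value.

```python
import math


def best_regression_split(x, response, weights, min_child_size):
    """Return (split_value, missing_goes_left, decrease), or None if no split is valid."""
    # Observed samples as (x, weight * response, weight), sorted by x.
    observed = sorted(
        (xi, wi * yi, wi)
        for xi, yi, wi in zip(x, response, weights)
        if not math.isnan(xi)
    )
    sum_missing = sum(wi * yi for xi, yi, wi in zip(x, response, weights) if math.isnan(xi))
    weight_sum_missing = sum(wi for xi, wi in zip(x, weights) if math.isnan(xi))
    sum_all = sum(s for _, s, _ in observed)
    weight_sum_all = sum(w for _, _, w in observed)

    best = None
    sum_left = weight_sum_left = 0.0
    for k in range(len(observed) - 1):
        u, s, w = observed[k]
        sum_left += s
        weight_sum_left += w
        if observed[k + 1][0] == u:
            continue  # ties: only split after the last copy of a value
        count_left, count_right = k + 1, len(observed) - (k + 1)
        if count_left < min_child_size or count_right < min_child_size:
            continue
        sum_right = sum_all - sum_left
        weight_sum_right = weight_sum_all - weight_sum_left
        # Evaluate both MIA directions: missing values sent left, then right.
        candidates = (
            (True, (sum_left + sum_missing) ** 2 / (weight_sum_left + weight_sum_missing)
                   + sum_right ** 2 / weight_sum_right),
            (False, sum_left ** 2 / weight_sum_left
                    + (sum_right + sum_missing) ** 2 / (weight_sum_right + weight_sum_missing)),
        )
        for missing_goes_left, decrease in candidates:
            if best is None or decrease > best[2]:
                best = (u, missing_goes_left, decrease)
    return best
```

The running sums make each candidate split an O(1) update after the initial sort, which is the reason the decrease is written in terms of sums and weight sums rather than explicit squared errors.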

Algorithm (Decrease calculation in weighted CART): find the best split Sj for variable Xj. The impurity criterion is the sum of weighted mean squared errors, which for a specific split Sj is minimized by the weighted sample average Ybar (see for example Elements of Statistical Learning, Ch 9).

Input: Xj : Covariate j, w: sample weight vector, y: response vector

Output: The best split criterion

minimize wrt. s: MSE(s, left) + MSE(s, right)


MSE(s, left) = \sum_{i \in L(s)} wi [yi - Ybar_L(s)]^2
MSE(s, right) = \sum_{i \in R(s)} wi [yi - Ybar_R(s)]^2

L(s) are all points Xj such that Xj <= s, and R(s) are all points such that Xj > s.
Ybar_L(s) and Ybar_R(s) are the weighted sample averages for the respective partitions.

Expand the child impurity and write out the averages (minimize wrt. s):
= \sum_{i \in L(s)} wi yi^2 - W_L(s) [Ybar_L(s)]^2
 + \sum_{i \in R(s)} wi yi^2 - W_R(s) [Ybar_R(s)]^2
= \sum_i wi yi^2
  - 1/W_L(s) [\sum_{i \in L(s)} wi yi]^2
  - 1/W_R(s) [\sum_{i \in R(s)} wi yi]^2

Where W is the sum of sample weights.

The parent impurity does not depend on s:
= \sum_i wi [yi - Ybar]^2
= \sum_i wi yi^2 - W Ybar^2
= \sum_i wi yi^2 - 1/W [\sum_i wi yi]^2

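Since the term \sum_i wi yi^2 appears in both expressions, child impurity <= parent impurity reduces to

1/W_L(s) [\sum_{i \in L(s)} wi yi]^2 + 1/W_R(s) [\sum_{i \in R(s)} wi yi]^2 >= 1/W [\sum_i wi yi]^2

The right-hand side is constant in s, so minimizing the child impurity is equivalent to maximizing the left-hand side, which in the notation of the pseudocode above is sum_left^2 / weight_sum_left + sum_right^2 / weight_sum_right.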

For multivariate CART, sum_left^2 and sum_right^2 become the squared L2 norms of the corresponding sum vectors.
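
The algebra above can be checked numerically. The following Python sketch, with hypothetical helper names `weighted_sse` and `decrease`, computes the weighted squared error directly and through the sum-based shortcut, so the two can be compared on any small example.

```python
def weighted_sse(y, w):
    """Weighted sum of squared deviations from the weighted mean."""
    ybar = sum(wi * yi for wi, yi in zip(w, y)) / sum(w)
    return sum(wi * (yi - ybar) ** 2 for wi, yi in zip(w, y))


def decrease(y_left, w_left, y_right, w_right):
    """sum_left^2 / weight_sum_left + sum_right^2 / weight_sum_right."""
    sum_left = sum(wi * yi for wi, yi in zip(w_left, y_left))
    sum_right = sum(wi * yi for wi, yi in zip(w_right, y_right))
    return sum_left ** 2 / sum(w_left) + sum_right ** 2 / sum(w_right)
```

For any split, `weighted_sse(left) + weighted_sse(right)` should equal the weighted sum of squared responses minus `decrease(...)`, up to floating point error.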

Algorithm (SurvivalSplittingRule): find the best split for variable var at n samples samples = i...j.

Input: X: covariate matrix, response: the failure times, min_child_size: split constraint on number of failures, m: the number of failure times in this node.

Output: The best split value and the direction to send missing values

grf performs survival splits by finding the partition that maximizes the log-rank test statistic for the comparison of the survival distributions in the left and right nodes. Using the notation from Ishwaran et al. (2008), equation 1, the statistic is:

logrank(x, split.value) =  sum over all times k up to m
  (dk,l - Yk,l * dk/Yk)  /
  (Yk,l / Yk * (1 - Yk,l / Yk) * (Yk - dk) / (Yk - 1) * dk)

dk,l: the count of failures in the left node at time k
Yk,l: the count of observations at risk in the left node at time k (At risk: all i such that Ti >= k)
dk: the count of failures in the parent node at time k
Yk: the count of observations at risk in the parent node at time k

Note that the absolute value of the response variable (the failure time) is never used; only counts matter, and these are constructed from the relative ordering of the failure times. Therefore, the response variable is always relabeled to consecutive integers, starting at zero and ending at the number of failure times. This can be done quickly with binary search:

let failure_times be a sorted vector of the original failure values (T1, T2, .. TM)
initialize response_relabeled = zeros(n)
for each i=1:n
  relabeled_failure_time = binary_search(response[i], failure_times)
  response_relabeled[i] = relabeled_failure_time

A label of 0 is given to all samples with a response less than the first failure time, a label of 1 is given to all samples with a response greater than or equal to the first failure time and less than the second failure time, etc.
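
The relabeling step maps directly onto Python's standard `bisect` module. The sketch below assumes `failure_times` is already sorted and contains the distinct failure times; the function name `relabel_failure_times` is hypothetical.

```python
from bisect import bisect_right


def relabel_failure_times(response, failure_times):
    """Map each response to the number of failure times that are <= it."""
    # bisect_right returns 0 below the first failure time, 1 from the first
    # failure time up to (but excluding) the second, and so on.
    return [bisect_right(failure_times, t) for t in response]
```

Each lookup costs O(log m) for m failure times, so relabeling n samples takes O(n log m) in total.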

The algorithm proceeds in the same manner as outlined above for RegressionSplittingRule: iterate over all possible splits and calculate the logrank statistic twice, once with all missing values on the left and once with all missing values on the right. Then select the split that yields the maximum logrank statistic, subject to the constraint that there is a sufficient number of failures on both sides of the split.
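
As a rough illustration of how the counts enter the statistic for one candidate partition, here is a Python sketch. Several things are assumptions rather than statements about grf's code: the function name and arguments are hypothetical, the numerator and the variance term are accumulated as two separate sums over the failure times and combined as a squared ratio (a common way to write the log-rank statistic), and missing values are taken to be already assigned to a side via `goes_left`.

```python
def logrank_statistic(relabeled_time, is_failure, goes_left, num_failure_times):
    """Squared log-rank statistic for a fixed left/right assignment (sketch)."""
    numerator = 0.0
    variance = 0.0
    n = len(relabeled_time)
    for k in range(1, num_failure_times + 1):
        # At risk at time k: every sample whose relabeled time is >= k.
        at_risk = [i for i in range(n) if relabeled_time[i] >= k]
        failed = [i for i in at_risk if relabeled_time[i] == k and is_failure[i]]
        Y_k, d_k = len(at_risk), len(failed)
        if Y_k < 2 or d_k == 0:
            continue  # this time point contributes nothing
        Y_kl = sum(1 for i in at_risk if goes_left[i])
        d_kl = sum(1 for i in failed if goes_left[i])
        numerator += d_kl - Y_kl * d_k / Y_k
        variance += (Y_kl / Y_k) * (1 - Y_kl / Y_k) * (Y_k - d_k) / (Y_k - 1) * d_k
    return numerator ** 2 / variance if variance > 0 else 0.0
```

The loop recomputes the risk sets from scratch for clarity; a performance-sensitive implementation would instead update the counts incrementally while sweeping over candidate split points.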

Algorithm (ProbabilitySplittingRule): This splitting rule uses the Gini impurity measure from CART for categorical data. Sample weights are incorporated by counting the sample weight of an observation when forming class counts. The missing values adjustment is the same as for the other splitting rules.
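
A minimal Python sketch of the weighted class counts and the resulting Gini impurity is shown below. The function name `weighted_gini` is hypothetical, and the code only illustrates the impurity of a single node, not the full split search or the missing values adjustment.

```python
from collections import defaultdict


def weighted_gini(labels, weights):
    """Gini impurity where each observation counts with its sample weight."""
    class_weight = defaultdict(float)
    for label, w in zip(labels, weights):
        class_weight[label] += w  # weighted class count
    total = sum(class_weight.values())
    if total == 0:
        return 0.0
    return 1.0 - sum((cw / total) ** 2 for cw in class_weight.values())
```

With unit weights this reduces to the usual CART Gini impurity; an observation with weight 3 has the same effect as three identical observations with weight 1.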