library(grf)

The following example demonstrates how the number of trees affects the variance estimates.

# Simulated training data: n observations of p standard-normal covariates.
n <- 2000
p <- 10
X <- matrix(rnorm(n * p), nrow = n, ncol = p)

# Evaluation grid: 101 rows where the first covariate sweeps evenly over
# [-2, 2] and every other covariate is held at 0.
X.test <- matrix(0, nrow = 101, ncol = p)
X.test[, 1] <- seq(from = -2, to = 2, length.out = 101)

# Binary treatment whose assignment probability is 0.4, raised by 0.2 for
# observations with a positive first covariate.
W <- rbinom(n, size = 1, prob = 0.4 + 0.2 * (X[, 1] > 0))

# Outcome: W enters only through the pmax(X[, 1], 0) * W term; the remaining
# terms depend on X[, 2], the negative part of X[, 3], and standard-normal
# noise.
Y <- pmax(X[, 1], 0) * W +
  X[, 2] +
  pmin(X[, 3], 0) +
  rnorm(n)

# Forest sizes to compare, from very small to large.
num.trees.grid <- c(10, 20, 30, 40, 100, 500, 1000, 2000, 3000, 4000)

# Preallocate one slot per grid entry rather than growing the vector with
# c() on every iteration.
median.variances <- numeric(length(num.trees.grid))
for (i in seq_along(num.trees.grid)) {
  num.trees <- num.trees.grid[i]

  # Fit a causal forest on the simulated data with the current tree count.
  tau.forest <- causal_forest(X, Y, W, num.trees = num.trees)

  # predict() is called without new data; summarise its variance estimates
  # by their median, ignoring any NA entries.
  hn <- median(
    predict(tau.forest, estimate.variance = TRUE)$variance.estimates,
    na.rm = TRUE
  )

  median.variances[i] <- hn
  print(hn)
}
#> [1] 0.1046173
#> [1] 0.09344524
#> [1] 0.07327336
#> [1] 0.07266573
#> [1] 0.0496021
#> [1] 0.02753671
#> [1] 0.02338877
#> [1] 0.01892714
#> [1] 0.0201111
#> [1] 0.01866248
# Scatter plot of the median variance estimate against the number of trees,
# with a line connecting the points. The y-axis label is spelled out to match
# the label plot() would otherwise derive from the argument expression.
plot(
  num.trees.grid, median.variances,
  xlab = "num.trees",
  ylab = "median.variances",
  main = "Median prediction variances"
)
lines(x = num.trees.grid, y = median.variances)