18  Machine Learning with mlr3

The mlr3 package is an R package that makes it easy to implement standard machine learning procedures (e.g., training, prediction, cross-validation) in a unified framework. It is similar to the Python scikit-learn package. In this section, we cover the basic use cases of the mlr3 package, accompanied by a real-world example at the end. The authors of the package are currently writing a book on how to use it. The material covered in this section is basically a condensed (and far less comprehensive) version of that book. Yet, it should still be sufficient for the majority of ML tasks you typically implement in practice.

To start working with mlr3, it is convenient to load the mlr3verse package, which is a collection of packages that provide useful extensions (e.g., quick visualization (mlr3viz), machine learning methods (mlr3learners), etc.). In this sense, it is like the tidyverse package.

See here for the list of included packages.

library(mlr3verse)
Note

An alternative ML utility package in R is the tidymodels package. In this book, we do not cover how to use it. Interested readers can consult an excellent free online book here.

Those who have been using solely R for their programming needs are not likely to be familiar with the way mlr3 works. It uses R6 classes provided by the R6 package, which implements encapsulated object-oriented programming, the way Python works. So, if you are familiar with Python, then mlr3 should come quite naturally to you. Fortunately, understanding how R6 works is not too hard (especially for us who are just using it). Reading the introduction to R6 provided here should suffice.
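
For a flavor of what this means, here is a minimal sketch of an R6 class (hypothetical, purely for illustration): you define a class once, instantiate it with new(), and call its methods with $. Note the reference semantics: calling a method can modify the object itself.

#=== a minimal R6 class for illustration ===#
library(R6)

Counter <- R6Class("Counter",
  public = list(
    count = 0,
    add = function(x = 1) {
      self$count <- self$count + x
      invisible(self)
    }
  )
)

#=== instantiate with new() and call methods with $ ===#
counter <- Counter$new()
counter$add(5)

#=== the object itself was modified (reference semantics) ===#
counter$count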

To implement ML tasks, we need at least two core components: a task and a learner. Roughly speaking, a task is a dataset bundled with information on the roles of its variables (which variable is the dependent variable and which ones are explanatory variables), while a learner is a machine learning algorithm (e.g., random forest) that can be trained on a task and then used for prediction.

We will take a deeper look at these two components first, and then move on to training, prediction, and other ML tasks.

18.1 Tasks

Packages to load for replication

library(data.table)
library(DoubleML)
library(mlr3verse)
library(tidyverse)
library(mltools)
library(parallel)

A task in the mlr3 parlance is basically a dataset (called backend) with information on which variable is the dependent variable (called target) and which variables are explanatory variables (called features).

Here, we will use mtcars data.

#=== load mtcars ===#
data("mtcars", package = "datasets")

#=== see the first 6 rows ===#
head(mtcars)
                   mpg cyl disp  hp drat    wt  qsec vs am gear carb
Mazda RX4         21.0   6  160 110 3.90 2.620 16.46  0  1    4    4
Mazda RX4 Wag     21.0   6  160 110 3.90 2.875 17.02  0  1    4    4
Datsun 710        22.8   4  108  93 3.85 2.320 18.61  1  1    4    1
Hornet 4 Drive    21.4   6  258 110 3.08 3.215 19.44  1  0    3    1
Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2
Valiant           18.1   6  225 105 2.76 3.460 20.22  1  0    3    1

When you create a task, you need to recognize what type of analysis you will be doing and use an appropriate class. Here, we will be running regression, so we use the TaskRegr class (see here for other task types).

Now, let’s create a task using the TaskRegr class with

  • mtcars as backend (data)
  • mpg as target (dependent variable)
  • example as id (this is the identifier for the task; you can give it any name)

You can instantiate a new TaskRegr instance using the new() method on the TaskRegr class.

(
reg_task <-
  TaskRegr$new(
    id = "example",
    backend = mtcars,
    target = "mpg"
  )
)
<TaskRegr:example> (32 x 11)
* Target: mpg
* Properties: -
* Features (10):
  - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt

As you can see, mpg is the Target and the rest of the variables were automatically set as Features.
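
As an aside, mlr3 also ships with a number of predefined example tasks that you can retrieve by passing a task id to tsk(). The built-in mtcars task below should be essentially equivalent to the one we just created by hand.

#=== a predefined example task ===#
reg_task_builtin <- tsk("mtcars")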

You can access the col_roles field to see the roles of the variables.

reg_task$col_roles 
$feature
 [1] "am"   "carb" "cyl"  "disp" "drat" "gear" "hp"   "qsec" "vs"   "wt"  

$target
[1] "mpg"

$name
character(0)

$order
character(0)

$stratum
character(0)

$group
character(0)

$weight
character(0)

You can extract information from the task using various methods:

  • $nrow: returns the number of rows
  • $ncol: returns the number of columns
  • $feature_names: returns the names of the feature variables
  • $target_names: returns the name(s) of the target variable(s)
  • $row_ids: returns the row ids (integers from 1 to the number of rows)
  • $data(): returns the backend (data) as a data.table

Let’s see some of these.

#=== row ids ===#
reg_task$row_ids
 [1]  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
[26] 26 27 28 29 30 31 32
#=== data ===#
reg_task$data()
     mpg am carb cyl  disp drat gear  hp  qsec vs    wt
 1: 21.0  1    4   6 160.0 3.90    4 110 16.46  0 2.620
 2: 21.0  1    4   6 160.0 3.90    4 110 17.02  0 2.875
 3: 22.8  1    1   4 108.0 3.85    4  93 18.61  1 2.320
 4: 21.4  0    1   6 258.0 3.08    3 110 19.44  1 3.215
 5: 18.7  0    2   8 360.0 3.15    3 175 17.02  0 3.440
 6: 18.1  0    1   6 225.0 2.76    3 105 20.22  1 3.460
 7: 14.3  0    4   8 360.0 3.21    3 245 15.84  0 3.570
 8: 24.4  0    2   4 146.7 3.69    4  62 20.00  1 3.190
 9: 22.8  0    2   4 140.8 3.92    4  95 22.90  1 3.150
10: 19.2  0    4   6 167.6 3.92    4 123 18.30  1 3.440
11: 17.8  0    4   6 167.6 3.92    4 123 18.90  1 3.440
12: 16.4  0    3   8 275.8 3.07    3 180 17.40  0 4.070
13: 17.3  0    3   8 275.8 3.07    3 180 17.60  0 3.730
14: 15.2  0    3   8 275.8 3.07    3 180 18.00  0 3.780
15: 10.4  0    4   8 472.0 2.93    3 205 17.98  0 5.250
16: 10.4  0    4   8 460.0 3.00    3 215 17.82  0 5.424
17: 14.7  0    4   8 440.0 3.23    3 230 17.42  0 5.345
18: 32.4  1    1   4  78.7 4.08    4  66 19.47  1 2.200
19: 30.4  1    2   4  75.7 4.93    4  52 18.52  1 1.615
20: 33.9  1    1   4  71.1 4.22    4  65 19.90  1 1.835
21: 21.5  0    1   4 120.1 3.70    3  97 20.01  1 2.465
22: 15.5  0    2   8 318.0 2.76    3 150 16.87  0 3.520
23: 15.2  0    2   8 304.0 3.15    3 150 17.30  0 3.435
24: 13.3  0    4   8 350.0 3.73    3 245 15.41  0 3.840
25: 19.2  0    2   8 400.0 3.08    3 175 17.05  0 3.845
26: 27.3  1    1   4  79.0 4.08    4  66 18.90  1 1.935
27: 26.0  1    2   4 120.3 4.43    5  91 16.70  0 2.140
28: 30.4  1    2   4  95.1 3.77    5 113 16.90  1 1.513
29: 15.8  1    4   8 351.0 4.22    5 264 14.50  0 3.170
30: 19.7  1    6   6 145.0 3.62    5 175 15.50  0 2.770
31: 15.0  1    8   8 301.0 3.54    5 335 14.60  0 3.570
32: 21.4  1    2   4 121.0 4.11    4 109 18.60  1 2.780
     mpg am carb cyl  disp drat gear  hp  qsec vs    wt
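
The other accessors work similarly. For instance (the expected values are shown as comments):

#=== dimensions of the backend ===#
reg_task$nrow # 32
reg_task$ncol # 11

#=== variable names by role ===#
reg_task$feature_names # "am" "carb" "cyl" ...
reg_task$target_names # "mpg"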

It is possible to retrieve only a portion of the data using the rows and cols arguments of data() as follows:

reg_task$data(rows = 1:10, cols = c("mpg", "wt"))
     mpg    wt
 1: 21.0 2.620
 2: 21.0 2.875
 3: 22.8 2.320
 4: 21.4 3.215
 5: 18.7 3.440
 6: 18.1 3.460
 7: 14.3 3.570
 8: 24.4 3.190
 9: 22.8 3.150
10: 19.2 3.440

To retrieve the complete data from the task, you can apply as.data.table() to the task.

(
data_extracted <- as.data.table(reg_task)
)
     mpg am carb cyl  disp drat gear  hp  qsec vs    wt
 1: 21.0  1    4   6 160.0 3.90    4 110 16.46  0 2.620
 2: 21.0  1    4   6 160.0 3.90    4 110 17.02  0 2.875
 3: 22.8  1    1   4 108.0 3.85    4  93 18.61  1 2.320
 4: 21.4  0    1   6 258.0 3.08    3 110 19.44  1 3.215
 5: 18.7  0    2   8 360.0 3.15    3 175 17.02  0 3.440
 6: 18.1  0    1   6 225.0 2.76    3 105 20.22  1 3.460
 7: 14.3  0    4   8 360.0 3.21    3 245 15.84  0 3.570
 8: 24.4  0    2   4 146.7 3.69    4  62 20.00  1 3.190
 9: 22.8  0    2   4 140.8 3.92    4  95 22.90  1 3.150
10: 19.2  0    4   6 167.6 3.92    4 123 18.30  1 3.440
11: 17.8  0    4   6 167.6 3.92    4 123 18.90  1 3.440
12: 16.4  0    3   8 275.8 3.07    3 180 17.40  0 4.070
13: 17.3  0    3   8 275.8 3.07    3 180 17.60  0 3.730
14: 15.2  0    3   8 275.8 3.07    3 180 18.00  0 3.780
15: 10.4  0    4   8 472.0 2.93    3 205 17.98  0 5.250
16: 10.4  0    4   8 460.0 3.00    3 215 17.82  0 5.424
17: 14.7  0    4   8 440.0 3.23    3 230 17.42  0 5.345
18: 32.4  1    1   4  78.7 4.08    4  66 19.47  1 2.200
19: 30.4  1    2   4  75.7 4.93    4  52 18.52  1 1.615
20: 33.9  1    1   4  71.1 4.22    4  65 19.90  1 1.835
21: 21.5  0    1   4 120.1 3.70    3  97 20.01  1 2.465
22: 15.5  0    2   8 318.0 2.76    3 150 16.87  0 3.520
23: 15.2  0    2   8 304.0 3.15    3 150 17.30  0 3.435
24: 13.3  0    4   8 350.0 3.73    3 245 15.41  0 3.840
25: 19.2  0    2   8 400.0 3.08    3 175 17.05  0 3.845
26: 27.3  1    1   4  79.0 4.08    4  66 18.90  1 1.935
27: 26.0  1    2   4 120.3 4.43    5  91 16.70  0 2.140
28: 30.4  1    2   4  95.1 3.77    5 113 16.90  1 1.513
29: 15.8  1    4   8 351.0 4.22    5 264 14.50  0 3.170
30: 19.7  1    6   6 145.0 3.62    5 175 15.50  0 2.770
31: 15.0  1    8   8 301.0 3.54    5 335 14.60  0 3.570
32: 21.4  1    2   4 121.0 4.11    4 109 18.60  1 2.780
     mpg am carb cyl  disp drat gear  hp  qsec vs    wt

You can mutate tasks using the select() and filter() methods. It is important to remember that these methods modify the task in place: the instance on which they are called is itself mutated. Let’s see what this means.

#=== first select few variables ===#
reg_task$select(c("am", "carb", "cyl"))

#=== see the backend now ===#
reg_task$data()
     mpg am carb cyl
 1: 21.0  1    4   6
 2: 21.0  1    4   6
 3: 22.8  1    1   4
 4: 21.4  0    1   6
 5: 18.7  0    2   8
 6: 18.1  0    1   6
 7: 14.3  0    4   8
 8: 24.4  0    2   4
 9: 22.8  0    2   4
10: 19.2  0    4   6
11: 17.8  0    4   6
12: 16.4  0    3   8
13: 17.3  0    3   8
14: 15.2  0    3   8
15: 10.4  0    4   8
16: 10.4  0    4   8
17: 14.7  0    4   8
18: 32.4  1    1   4
19: 30.4  1    2   4
20: 33.9  1    1   4
21: 21.5  0    1   4
22: 15.5  0    2   8
23: 15.2  0    2   8
24: 13.3  0    4   8
25: 19.2  0    2   8
26: 27.3  1    1   4
27: 26.0  1    2   4
28: 30.4  1    2   4
29: 15.8  1    4   8
30: 19.7  1    6   6
31: 15.0  1    8   8
32: 21.4  1    2   4
     mpg am carb cyl

As you can see, reg_task now holds only the variables that were selected (plus the target variable mpg). This behavior is similar to how data.table works when you create a new variable using the data[, := ] syntax, and it is different from how dplyr::select() works.

#=== create a dataset ===#
data_temp <- reg_task$data()

#=== select mpg, carb ===#
dplyr::select(data_temp, mpg, carb)
     mpg carb
 1: 21.0    4
 2: 21.0    4
 3: 22.8    1
 4: 21.4    1
 5: 18.7    2
 6: 18.1    1
 7: 14.3    4
 8: 24.4    2
 9: 22.8    2
10: 19.2    4
11: 17.8    4
12: 16.4    3
13: 17.3    3
14: 15.2    3
15: 10.4    4
16: 10.4    4
17: 14.7    4
18: 32.4    1
19: 30.4    2
20: 33.9    1
21: 21.5    1
22: 15.5    2
23: 15.2    2
24: 13.3    4
25: 19.2    2
26: 27.3    1
27: 26.0    2
28: 30.4    2
29: 15.8    4
30: 19.7    6
31: 15.0    8
32: 21.4    2
     mpg carb
#=== look at data_temp ===#
data_temp
     mpg am carb cyl
 1: 21.0  1    4   6
 2: 21.0  1    4   6
 3: 22.8  1    1   4
 4: 21.4  0    1   6
 5: 18.7  0    2   8
 6: 18.1  0    1   6
 7: 14.3  0    4   8
 8: 24.4  0    2   4
 9: 22.8  0    2   4
10: 19.2  0    4   6
11: 17.8  0    4   6
12: 16.4  0    3   8
13: 17.3  0    3   8
14: 15.2  0    3   8
15: 10.4  0    4   8
16: 10.4  0    4   8
17: 14.7  0    4   8
18: 32.4  1    1   4
19: 30.4  1    2   4
20: 33.9  1    1   4
21: 21.5  0    1   4
22: 15.5  0    2   8
23: 15.2  0    2   8
24: 13.3  0    4   8
25: 19.2  0    2   8
26: 27.3  1    1   4
27: 26.0  1    2   4
28: 30.4  1    2   4
29: 15.8  1    4   8
30: 19.7  1    6   6
31: 15.0  1    8   8
32: 21.4  1    2   4
     mpg am carb cyl

As you can see, data_temp is not mutated after select(). To save the result of dplyr::select(), you need to explicitly assign it to another R object.

Note that the target variable cannot be selected (or dropped) in select(); it is always retained automatically.

If you would like to keep an independent copy of the original task, you can use the clone() method to create a distinct instance before mutating it.

#=== create a clone ===#
reg_task_independent <- reg_task$clone()

Let’s filter the data using the filter() method.

#=== filter ===#
reg_task$filter(1:10)

#=== see the backend ===#
reg_task$data()
     mpg am carb cyl
 1: 21.0  1    4   6
 2: 21.0  1    4   6
 3: 22.8  1    1   4
 4: 21.4  0    1   6
 5: 18.7  0    2   8
 6: 18.1  0    1   6
 7: 14.3  0    4   8
 8: 24.4  0    2   4
 9: 22.8  0    2   4
10: 19.2  0    4   6

However, reg_task_independent is not affected.

reg_task_independent$data()
     mpg am carb cyl
 1: 21.0  1    4   6
 2: 21.0  1    4   6
 3: 22.8  1    1   4
 4: 21.4  0    1   6
 5: 18.7  0    2   8
 6: 18.1  0    1   6
 7: 14.3  0    4   8
 8: 24.4  0    2   4
 9: 22.8  0    2   4
10: 19.2  0    4   6
11: 17.8  0    4   6
12: 16.4  0    3   8
13: 17.3  0    3   8
14: 15.2  0    3   8
15: 10.4  0    4   8
16: 10.4  0    4   8
17: 14.7  0    4   8
18: 32.4  1    1   4
19: 30.4  1    2   4
20: 33.9  1    1   4
21: 21.5  0    1   4
22: 15.5  0    2   8
23: 15.2  0    2   8
24: 13.3  0    4   8
25: 19.2  0    2   8
26: 27.3  1    1   4
27: 26.0  1    2   4
28: 30.4  1    2   4
29: 15.8  1    4   8
30: 19.7  1    6   6
31: 15.0  1    8   8
32: 21.4  1    2   4
     mpg am carb cyl

You can use the rbind() and cbind() methods to append data vertically and horizontally, respectively.

Here is an example use of rbind().

reg_task$rbind(
  data.table(mpg = 20, am = 1, carb = 3, cyl = 99)
)

#=== see the change ===#
reg_task$data()
     mpg am carb cyl
 1: 21.0  1    4   6
 2: 21.0  1    4   6
 3: 22.8  1    1   4
 4: 21.4  0    1   6
 5: 18.7  0    2   8
 6: 18.1  0    1   6
 7: 14.3  0    4   8
 8: 24.4  0    2   4
 9: 22.8  0    2   4
10: 19.2  0    4   6
11: 20.0  1    3  99
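
Similarly, cbind() appends columns. The data you supply should have the same number of rows as the current backend (11 after the rbind() above); the new_var column below is made up purely for illustration.

#=== append a (made-up) column ===#
reg_task$cbind(
  data.table(new_var = rnorm(11))
)

#=== new_var is now part of the backend ===#
reg_task$data()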

18.2 Learners

In Python, the majority of ML packages are written to be compatible with the scikit-learn framework. In R, however, there is no single framework equivalent to scikit-learn to which all developers of ML packages conform. Fortunately, the authors of the mlr3 package picked popular ML packages (e.g., ranger, xgboost) and made it easy for us to use those packages under a unified framework.

tidymodels has its own collection of packages, which is very similar to what mlr3 has.

Here is the list of learners that are available after loading the mlr3verse package.

mlr_learners
<DictionaryLearner> with 129 stored values
Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
  classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
  classif.earth, classif.featureless, classif.fnn, classif.gam,
  classif.gamboost, classif.gausspr, classif.gbm, classif.glmboost,
  classif.glmnet, classif.IBk, classif.J48, classif.JRip, classif.kknn,
  classif.ksvm, classif.lda, classif.liblinear, classif.lightgbm,
  classif.LMT, classif.log_reg, classif.lssvm, classif.mob,
  classif.multinom, classif.naive_bayes, classif.nnet, classif.OneR,
  classif.PART, classif.qda, classif.randomForest, classif.ranger,
  classif.rfsrc, classif.rpart, classif.svm, classif.xgboost,
  clust.agnes, clust.ap, clust.cmeans, clust.cobweb, clust.dbscan,
  clust.diana, clust.em, clust.fanny, clust.featureless, clust.ff,
  clust.hclust, clust.kkmeans, clust.kmeans, clust.MBatchKMeans,
  clust.meanshift, clust.pam, clust.SimpleKMeans, clust.xmeans,
  dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
  dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
  regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.debug,
  regr.earth, regr.featureless, regr.fnn, regr.gam, regr.gamboost,
  regr.gausspr, regr.gbm, regr.glm, regr.glmboost, regr.glmnet,
  regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
  regr.lightgbm, regr.lm, regr.lmer, regr.M5Rules, regr.mars, regr.mob,
  regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
  regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
  surv.coxboost, surv.coxtime, surv.ctree, surv.cv_coxboost,
  surv.cv_glmnet, surv.deephit, surv.deepsurv, surv.dnnsurv,
  surv.flexible, surv.gamboost, surv.gbm, surv.glmboost, surv.glmnet,
  surv.loghaz, surv.mboost, surv.nelson, surv.obliqueRSF,
  surv.parametric, surv.pchazard, surv.penalized, surv.ranger,
  surv.rfsrc, surv.svm, surv.xgboost

As you can see, packages that we have seen earlier (e.g., glmnet, ranger, xgboost, gam) are available. The mlr3extralearners package provides additional, but less well supported, learners. You can check the complete list of learners here.

You set up a learner by passing the name of the learner you would like to use from the list above to lrn(), like below.

learner <- lrn("regr.ranger")

Note that you need to pick the learner with the appropriate task type prefix (the prefixes are self-explanatory). Here, since we are interested in regression, we picked "regr.ranger".
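
As an aside, lrn() also accepts hyper-parameter values directly, so you can construct a learner and set its parameters in one step. (We use a separate learner_alt object here so that the learner we work with below is unaffected.)

#=== construct and parameterize in one step ===#
learner_alt <- lrn("regr.ranger", num.trees = 1000, min.node.size = 5)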

Once you set up a learner, you can see the set of the learner’s hyper-parameters.

learner$param_set
<ParamSet>
                              id    class lower upper nlevels        default
 1:                        alpha ParamDbl  -Inf   Inf     Inf            0.5
 2:       always.split.variables ParamUty    NA    NA     Inf <NoDefault[3]>
 3:                      holdout ParamLgl    NA    NA       2          FALSE
 4:                   importance ParamFct    NA    NA       4 <NoDefault[3]>
 5:                   keep.inbag ParamLgl    NA    NA       2          FALSE
 6:                    max.depth ParamInt     0   Inf     Inf               
 7:                min.node.size ParamInt     1   Inf     Inf              5
 8:                     min.prop ParamDbl  -Inf   Inf     Inf            0.1
 9:                      minprop ParamDbl  -Inf   Inf     Inf            0.1
10:                         mtry ParamInt     1   Inf     Inf <NoDefault[3]>
11:                   mtry.ratio ParamDbl     0     1     Inf <NoDefault[3]>
12:            num.random.splits ParamInt     1   Inf     Inf              1
13:                  num.threads ParamInt     1   Inf     Inf              1
14:                    num.trees ParamInt     1   Inf     Inf            500
15:                    oob.error ParamLgl    NA    NA       2           TRUE
16:                     quantreg ParamLgl    NA    NA       2          FALSE
17:        regularization.factor ParamUty    NA    NA     Inf              1
18:      regularization.usedepth ParamLgl    NA    NA       2          FALSE
19:                      replace ParamLgl    NA    NA       2           TRUE
20:    respect.unordered.factors ParamFct    NA    NA       3         ignore
21:              sample.fraction ParamDbl     0     1     Inf <NoDefault[3]>
22:                  save.memory ParamLgl    NA    NA       2          FALSE
23: scale.permutation.importance ParamLgl    NA    NA       2          FALSE
24:                    se.method ParamFct    NA    NA       2        infjack
25:                         seed ParamInt  -Inf   Inf     Inf               
26:         split.select.weights ParamUty    NA    NA     Inf               
27:                    splitrule ParamFct    NA    NA       3       variance
28:                      verbose ParamLgl    NA    NA       2           TRUE
29:                 write.forest ParamLgl    NA    NA       2           TRUE
                              id    class lower upper nlevels        default
       parents value
 1:  splitrule      
 2:                 
 3:                 
 4:                 
 5:                 
 6:                 
 7:                 
 8:                 
 9:  splitrule      
10:                 
11:                 
12:  splitrule      
13:                1
14:                 
15:                 
16:                 
17:                 
18:                 
19:                 
20:                 
21:                 
22:                 
23: importance      
24:                 
25:                 
26:                 
27:                 
28:                 
29:                 
       parents value

Right now, only num.threads is set explicitly.

learner$param_set$values
$num.threads
[1] 1

You can update or assign the value of a parameter like this:

#=== set max.depth to 5 ===#
learner$param_set$values$max.depth <- 5

#=== see the values ===#
learner$param_set$values
$num.threads
[1] 1

$max.depth
[1] 5

When you would like to set values for multiple hyper-parameters at the same time, you can provide a named list to $param_set$values. Here is an example:

#=== create a named list ===#
parameter_values <- list("min.node.size" = 10, "mtry" = 5, "num.trees" = 500)

#=== assign them ===#
learner$param_set$values <- parameter_values

#=== see the values ===#
learner$param_set$values
$min.node.size
[1] 10

$mtry
[1] 5

$num.trees
[1] 500

But, notice that the values we set previously for max.depth and num.threads are gone.
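
This is because assigning to $param_set$values replaces the entire list. If you want to add or update values while keeping the existing ones, one option is base R's modifyList(), as in this sketch:

#=== update max.depth while keeping the other values ===#
learner$param_set$values <-
  modifyList(learner$param_set$values, list(max.depth = 5))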

18.3 Train, predict, assessment

18.3.1 Train

Training a model can be done by calling the $train() method on a learner and supplying a task to it. Let’s first set up a task and a learner.

#=== define a task ===#
reg_task <-
  TaskRegr$new(
    id = "example",
    backend = mtcars,
    target = "mpg"
  )

#=== set up a learner ===#
learner <- lrn("regr.ranger")
learner$param_set$values <- 
  list(
    "min.node.size" = 10, 
    "mtry" = 5, 
    "num.trees" = 500
  )

Notice that the model attribute of the learner is empty at this point.

learner$model
NULL

Now, let’s train.

learner$train(reg_task)

We now have information about the trained model in the model attribute.

learner$model
Ranger result

Call:
 ranger::ranger(dependent.variable.name = task$target_names, data = task$data(),      case.weights = task$weights$weight, min.node.size = 10L,      mtry = 5L, num.trees = 500L) 

Type:                             Regression 
Number of trees:                  500 
Sample size:                      32 
Number of independent variables:  10 
Mtry:                             5 
Target node size:                 10 
Variable importance mode:         none 
Splitrule:                        variance 
OOB prediction error (MSE):       6.369001 
R squared (OOB):                  0.8246618 

Notice that this is exactly what you would get if you used ranger::ranger() to train your model.

The train() method has a row_ids argument, which you can use to specify which rows of the data backend in the task are used for training.

Let’s extract the row_ids attribute and then split it into train and test sets.

#=== extract row ids ===#
row_ids <- reg_task$row_ids

#=== train ===#
train_ids <- row_ids[1:(length(row_ids) / 2)]

#=== test ===#
test_ids <- row_ids[!(row_ids %in% train_ids)]
Note

Here, we are doing the split manually. But in practice, you should use resampling methods, which we will look at later.

Now train using the train data.

#=== train ===#
learner$train(reg_task, row_ids = train_ids)

#=== see the trained model ===#
learner$model
Ranger result

Call:
 ranger::ranger(dependent.variable.name = task$target_names, data = task$data(),      case.weights = task$weights$weight, min.node.size = 10L,      mtry = 5L, num.trees = 500L) 

Type:                             Regression 
Number of trees:                  500 
Sample size:                      16 
Number of independent variables:  10 
Mtry:                             5 
Target node size:                 10 
Variable importance mode:         none 
Splitrule:                        variance 
OOB prediction error (MSE):       5.24416 
R squared (OOB):                  0.6951542 

18.3.2 Prediction

We can use the predict() method to make predictions by supplying a task to it. Just like the train() method, predict() accepts a row_ids argument so that prediction is applied only to a portion of the data. Let’s use the test_ids we created above.

prediction <- learner$predict(reg_task, row_ids = test_ids)

prediction is an object of the Prediction class. You can make the predictions available as a data.table by using as.data.table().

as.data.table(prediction)
    row_ids truth response
 1:      17  14.7 13.71786
 2:      18  32.4 21.80367
 3:      19  30.4 21.79294
 4:      20  33.9 21.80987
 5:      21  21.5 21.69345
 6:      22  15.5 16.47427
 7:      23  15.2 17.81916
 8:      24  13.3 14.88102
 9:      25  19.2 16.24036
10:      26  27.3 21.79751
11:      27  26.0 21.71119
12:      28  30.4 21.63942
13:      29  15.8 17.17378
14:      30  19.7 20.59941
15:      31  15.0 15.42186
16:      32  21.4 21.67216

You can predict on a new dataset by supplying it as a data.frame/data.table to the predict_newdata() method. Here, we just use mtcars again (pretend this is a new dataset).

prediction <- learner$predict_newdata(mtcars)

18.3.3 Performance assessment

There are many measures of performance we can use under mlr3. Here is the list:

mlr_measures
<DictionaryMeasure> with 63 stored values
Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
  classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
  classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
  classif.logloss, classif.mbrier, classif.mcc, classif.npv,
  classif.ppv, classif.prauc, classif.precision, classif.recall,
  classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
  classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
  clust.silhouette, clust.wss, debug, oob_error, regr.bias, regr.ktau,
  regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
  regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
  regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
  selected_features, sim.jaccard, sim.phi, time_both, time_predict,
  time_train

We can access a measure using the msr() function.

#=== get a measure ===#
measure <- msr("regr.mse")

#=== check the class ===#
class(measure)
[1] "MeasureRegrSimple" "MeasureRegr"       "Measure"          
[4] "R6"               

We can do performance evaluation using the $score() method on a Prediction object by supplying a Measure object.

prediction$score(measure)
regr.mse 
16.13093 
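
You can also compute several measures at once by passing a list of Measure objects, conveniently created with msrs():

#=== multiple measures at once ===#
prediction$score(msrs(c("regr.mse", "regr.rmse")))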

18.4 Resampling, cross-validation, and cross-fitting

18.4.1 Resampling

mlr3 offers the following resampling methods:

as.data.table(mlr_resamplings)
           key                         label        params iters
1:   bootstrap                     Bootstrap ratio,repeats    30
2:      custom                 Custom Splits                  NA
3:   custom_cv Custom Split Cross-Validation                  NA
4:          cv              Cross-Validation         folds    10
5:     holdout                       Holdout         ratio     1
6:    insample           Insample Resampling                   1
7:         loo                 Leave-One-Out                  NA
8: repeated_cv     Repeated Cross-Validation folds,repeats   100
9: subsampling                   Subsampling ratio,repeats    30

You can access a resampling method using the rsmp() function. You can specify parameters at the same time.

(
resampling <- rsmp("repeated_cv", repeats = 2, folds = 3)
)
<ResamplingRepeatedCV>: Repeated Cross-Validation
* Iterations: 6
* Instantiated: FALSE
* Parameters: repeats=2, folds=3

You can check the number of iterations (the number of train-test dataset combinations) by accessing the iters attribute.

resampling$iters
[1] 6

You can override parameters just like you did for a learner.

#=== update ===#
resampling$param_set$values = list(repeats = 3, folds = 4)

#=== see the updates ===#
resampling
<ResamplingRepeatedCV>: Repeated Cross-Validation
* Iterations: 12
* Instantiated: FALSE
* Parameters: repeats=3, folds=4

We can use the instantiate() method to apply the specified resampling method to a task:

resampling$instantiate(reg_task)

You can access the train and test datasets using the train_set() and test_set() methods, respectively. Since repeats = 3 and folds = 4, we have 12 pairs of train and test datasets. You indicate which pair you want inside train_set() and test_set().

First pair:

resampling$train_set(1)
 [1]  5  6 13 14 18 25 28 31  2  9 15 16 20 23 27 30  3  4 17 19 21 22 24 26
resampling$test_set(1)
[1]  1  7  8 10 11 12 29 32

Last pair:

resampling$train_set(12)
 [1]  3  5 12 14 19 20 25 29  9 15 18 21 24 28 30 31  2  4  6 10 11 22 26 32
resampling$test_set(12)
[1]  1  7  8 13 16 17 23 27

18.4.2 Cross-validation and cross-fitting

Now that data splits are determined (along with a task and learner), we can conduct a cross-validation using the resample() function like below.

cv_results <- resample(reg_task, learner, resampling)
INFO  [10:52:25.594] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/12) 
INFO  [10:52:25.616] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/12) 
INFO  [10:52:25.629] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 10/12) 
INFO  [10:52:25.638] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 6/12) 
INFO  [10:52:25.648] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 12/12) 
INFO  [10:52:25.658] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 5/12) 
INFO  [10:52:25.667] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 8/12) 
INFO  [10:52:25.676] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/12) 
INFO  [10:52:25.686] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 7/12) 
INFO  [10:52:25.695] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 11/12) 
INFO  [10:52:25.704] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 4/12) 
INFO  [10:52:25.714] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 9/12) 

This code applies the method specified in learner to each of the 12 train datasets, and evaluates the trained model on each of the corresponding 12 test datasets.

You can look at the prediction results using the predictions() method.

cv_results$predictions()
[[1]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          1  21.0 21.34871
          7  14.3 14.71893
          8  24.4 23.00590
---                       
         12  16.4 16.02169
         29  15.8 18.13388
         32  21.4 23.81583

[[2]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          5  18.7 14.83204
          6  18.1 19.85426
         13  17.3 15.03344
---                       
         25  19.2 14.59602
         28  30.4 24.80455
         31  15.0 14.89332

[[3]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          2  21.0 20.66545
          9  22.8 22.73046
         15  10.4 16.03584
---                       
         23  15.2 17.66704
         27  26.0 24.92584
         30  19.7 19.87509

[[4]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          3  22.8 27.84492
          4  21.4 19.01035
         17  14.7 13.06798
---                       
         22  15.5 16.55219
         24  13.3 15.67540
         26  27.3 28.81684

[[5]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          2  21.0 20.62481
          3  22.8 25.88562
         11  17.8 19.87309
---                       
         23  15.2 17.86239
         28  30.4 25.34370
         32  21.4 23.44607

[[6]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          4  21.4 18.79353
          5  18.7 16.19981
         10  19.2 18.58758
---                       
         24  13.3 15.79908
         26  27.3 29.61911
         30  19.7 20.85556

[[7]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          1  21.0 20.54647
          7  14.3 14.97206
          9  22.8 23.21974
---                       
         18  32.4 27.26674
         21  21.5 23.42506
         27  26.0 24.08510

[[8]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          6  18.1 19.57434
          8  24.4 22.70844
         16  10.4 14.07726
---                       
         25  19.2 16.03038
         29  15.8 16.94906
         31  15.0 15.58641

[[9]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          3  22.8 25.67939
          5  18.7 15.04072
         12  16.4 15.76059
---                       
         20  33.9 27.56829
         25  19.2 14.73703
         29  15.8 16.71143

[[10]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          9  22.8 23.56904
         15  10.4 13.90063
         18  32.4 27.14532
---                       
         28  30.4 24.69515
         30  19.7 20.17987
         31  15.0 15.55283

[[11]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          2  21.0 21.85924
          4  21.4 20.37669
          6  18.1 19.36490
---                       
         22  15.5 16.19947
         26  27.3 29.14830
         32  21.4 24.31793

[[12]]
<PredictionRegr> for 8 observations:
    row_ids truth response
          1  21.0 20.60021
          7  14.3 15.26024
          8  24.4 22.88854
---                       
         17  14.7 14.66568
         23  15.2 17.57585
         27  26.0 24.52656

You can get all the predictions combined using the prediction() method.

#=== all combined ===#
all_predictions <- cv_results$prediction()

#=== check the class ===#
class(all_predictions)
[1] "PredictionRegr" "Prediction"     "R6"            

Since it is a Prediction object, we can apply the score() method like this.

all_predictions$score(msr("regr.mse"))
regr.mse 
 7.51057 
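
Alternatively, you can score each of the 12 test sets separately with score() and average across them with aggregate(), both invoked on the ResampleResult object. Note that averaging per-iteration scores is not in general identical to scoring the pooled predictions as we did above.

#=== MSE for each of the 12 iterations ===#
cv_results$score(msr("regr.mse"))

#=== averaged across the 12 iterations ===#
cv_results$aggregate(msr("regr.mse"))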

Of course, you are also effectively cross-fitting when you are doing cross-validation. You can simply take the prediction results and use them if you are implementing DML, for example.

(
cross_fitted_yhat <-
  as.data.table(all_predictions) %>%
  .[, .(y_hat_cf = mean(response)), by = row_ids]
)
    row_ids y_hat_cf
 1:       1 20.83180
 2:       7 14.98374
 3:       8 22.86763
 4:      10 20.01379
 5:      11 20.45234
 6:      12 15.81812
 7:      29 17.26479
 8:      32 23.85994
 9:       5 15.35752
10:       6 19.59783
11:      13 15.45907
12:      14 15.70010
13:      18 27.18675
14:      25 15.12114
15:      28 24.94780
16:      31 15.34419
17:       2 21.04983
18:       9 23.17308
19:      15 14.85905
20:      16 14.75365
21:      20 28.18517
22:      23 17.70176
23:      27 24.51250
24:      30 20.30351
25:       3 26.46998
26:       4 19.39352
27:      17 13.58068
28:      19 28.22748
29:      21 24.24876
30:      22 16.28554
31:      24 15.64745
32:      26 29.19475
    row_ids y_hat_cf

18.5 Hyper-parameter tuning

To conduct hyper-parameter tuning under the mlr3 framework, you define a TuningInstance* object, select a tuning method, and then trigger the tuning.

18.5.1 TuningInstance

There are two tuning instance classes.

  • TuningInstanceSingleCrit
  • TuningInstanceMultiCrit

The difference should be clear from the names of the classes (tuning with respect to a single criterion versus multiple criteria). We focus on the TuningInstanceSingleCrit class here.

A tuning instance consists of six elements:

  • task
  • learner
  • resampling
  • measure
  • search_space
  • terminator

We have covered all except search_space and terminator. Let’s look at these two.

Let’s quickly create the first four elements.

#=== task ===#
reg_task <-
  TaskRegr$new(
    id = "example",
    backend = mtcars,
    target = "mpg"
  )

#=== learner ===#
learner <- lrn("regr.ranger")
learner$param_set$values$max.depth <- 10

#=== resampling ===#
resampling <- rsmp("cv", folds = 3) # k-fold cv

#=== measure ===#
measure <- msr("regr.mse")

search_space defines which hyper-parameters to tune and their ranges. You can use ps() to create a search space. Before doing so, let’s look at the parameters of the learner.

learner$param_set
<ParamSet>
                              id    class lower upper nlevels        default
 1:                        alpha ParamDbl  -Inf   Inf     Inf            0.5
 2:       always.split.variables ParamUty    NA    NA     Inf <NoDefault[3]>
 3:                      holdout ParamLgl    NA    NA       2          FALSE
 4:                   importance ParamFct    NA    NA       4 <NoDefault[3]>
 5:                   keep.inbag ParamLgl    NA    NA       2          FALSE
 6:                    max.depth ParamInt     0   Inf     Inf               
 7:                min.node.size ParamInt     1   Inf     Inf              5
 8:                     min.prop ParamDbl  -Inf   Inf     Inf            0.1
 9:                      minprop ParamDbl  -Inf   Inf     Inf            0.1
10:                         mtry ParamInt     1   Inf     Inf <NoDefault[3]>
11:                   mtry.ratio ParamDbl     0     1     Inf <NoDefault[3]>
12:            num.random.splits ParamInt     1   Inf     Inf              1
13:                  num.threads ParamInt     1   Inf     Inf              1
14:                    num.trees ParamInt     1   Inf     Inf            500
15:                    oob.error ParamLgl    NA    NA       2           TRUE
16:                     quantreg ParamLgl    NA    NA       2          FALSE
17:        regularization.factor ParamUty    NA    NA     Inf              1
18:      regularization.usedepth ParamLgl    NA    NA       2          FALSE
19:                      replace ParamLgl    NA    NA       2           TRUE
20:    respect.unordered.factors ParamFct    NA    NA       3         ignore
21:              sample.fraction ParamDbl     0     1     Inf <NoDefault[3]>
22:                  save.memory ParamLgl    NA    NA       2          FALSE
23: scale.permutation.importance ParamLgl    NA    NA       2          FALSE
24:                    se.method ParamFct    NA    NA       2        infjack
25:                         seed ParamInt  -Inf   Inf     Inf               
26:         split.select.weights ParamUty    NA    NA     Inf               
27:                    splitrule ParamFct    NA    NA       3       variance
28:                      verbose ParamLgl    NA    NA       2           TRUE
29:                 write.forest ParamLgl    NA    NA       2           TRUE
                              id    class lower upper nlevels        default
       parents value
 1:  splitrule      
 2:                 
 3:                 
 4:                 
 5:                 
 6:               10
 7:                 
 8:                 
 9:  splitrule      
10:                 
11:                 
12:  splitrule      
13:                1
14:                 
15:                 
16:                 
17:                 
18:                 
19:                 
20:                 
21:                 
22:                 
23: importance      
24:                 
25:                 
26:                 
27:                 
28:                 
29:                 
       parents value

Let’s tune two parameters here: mtry.ratio and min.node.size. For each parameter to tune, we need to use an appropriate function to define the range. You can see which function to use from the class column.

learner$param_set %>% 
  as.data.table() %>% 
  .[id %in% c("mtry.ratio", "min.node.size"), .(id, class, lower, upper)]
              id    class lower upper
1: min.node.size ParamInt     1   Inf
2:    mtry.ratio ParamDbl     0     1

In this case, we use p_int() for “min.node.size”, as its class is ParamInt, and p_dbl() for “mtry.ratio”, as its class is ParamDbl. You cannot specify a range that goes beyond the lower and upper bounds of each parameter.

search_space <- 
  ps(
    mtry.ratio = p_dbl(lower = 0.5, upper = 0.9),
    min.node.size = p_int(lower = 1, upper = 20)
  )

Let’s define a terminator now. mlr3 offers five different options. We will just look at the most common one here, which is TerminatorEvals. TerminatorEvals terminates tuning after a given number of evaluations specified by the user (see here for other terminator options).

You can use the trm() function to define a Terminator object.

terminator <- trm("evals", n_evals = 100) 

Inside trm(), “evals” indicates that we would like to use the TerminatorEvals option.
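
Other terminators are set up the same way. For example, the following should stop the tuning after 60 seconds of run time (TerminatorRunTime):

#=== stop tuning after 60 seconds instead ===#
terminator_time <- trm("run_time", secs = 60)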

Now that we have all the components specified, we can instantiate (generate) an object of the TuningInstanceSingleCrit class.

(
tuning_instance <- 
  TuningInstanceSingleCrit$new(
    task = reg_task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    search_space = search_space,
    terminator = terminator
  )
)
<TuningInstanceSingleCrit>
* State:  Not optimized
* Objective: <ObjectiveTuning:regr.ranger_on_example>
* Search Space:
              id    class lower upper nlevels
1:    mtry.ratio ParamDbl   0.5   0.9     Inf
2: min.node.size ParamInt   1.0  20.0      20
* Terminator: <TerminatorEvals>

18.5.2 Tuner

Let’s now define the method of tuning. mlr3 offers four options:

  • Grid Search (TunerGridSearch)
  • Random Search (TunerRandomSearch)
  • Generalized Simulated Annealing (TunerGenSA)
  • Non-Linear Optimization (TunerNLoptr)

Here we use TunerGridSearch. We can set up a tuner using the tnr() function.

tuner <- tnr("grid_search", resolution = 4) 

resolution = 4 means that each parameter takes four values, equidistant between the lower and upper bounds specified in search_space. So, this tuning will look at \(4^2 = 16\) parameter configurations. Notice that we set the number of evaluations to 100 above, so all \(16\) cases will be evaluated. However, if you set n_evals lower than \(16\), the tuning will not examine all the cases.
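
If you want to inspect the 16 configurations the grid search will evaluate, you can reproduce the grid yourself with paradox’s generate_design_grid() (purely a sanity check; the tuner does this internally):

#=== the 4 x 4 grid of candidate configurations ===#
generate_design_grid(search_space, resolution = 4)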

18.5.3 Tuning

You can trigger tuning by supplying a tuning instance to the optimize() method of the tuner.

tuner$optimize(tuning_instance)

Since the execution of this tuning prints many lines of output, the results are not presented here.

Once the tuning is done, you can get the optimized parameters by accessing the result_learner_param_vals attribute of the tuning instance.

tuning_instance$result_learner_param_vals
$num.threads
[1] 1

$max.depth
[1] 10

$mtry.ratio
[1] 0.5

$min.node.size
[1] 1

You can look at the evaluation results of other parameter configurations by accessing the archive attribute of the tuning instance.

as.data.table(tuning_instance$archive) %>% head()
   mtry.ratio min.node.size  regr.mse x_domain_mtry.ratio
1:  0.5000000            20 10.095093           0.5000000
2:  0.9000000            14  8.107527           0.9000000
3:  0.7666667            14  7.884286           0.7666667
4:  0.9000000             1  5.375487           0.9000000
5:  0.9000000             7  5.606771           0.9000000
6:  0.6333333             1  5.150592           0.6333333
   x_domain_min.node.size runtime_learners           timestamp batch_nr
1:                     20            0.022 2022-08-19 10:52:25        1
2:                     14            0.019 2022-08-19 10:52:25        2
3:                     14            0.022 2022-08-19 10:52:26        3
4:                      1            0.035 2022-08-19 10:52:26        4
5:                      7            0.024 2022-08-19 10:52:26        5
6:                      1            0.032 2022-08-19 10:52:26        6
   warnings errors      resample_result
1:        0      0 <ResampleResult[22]>
2:        0      0 <ResampleResult[22]>
3:        0      0 <ResampleResult[22]>
4:        0      0 <ResampleResult[22]>
5:        0      0 <ResampleResult[22]>
6:        0      0 <ResampleResult[22]>

18.5.4 AutoTuner

AutoTuner sounds like it is a tuner, but it is really a learner for which tuning is automatically implemented when training is triggered with a task. An AutoTuner can be instantiated using the new() method on the AutoTuner class like below.

(
auto_tuning_learner <- 
  AutoTuner$new(
    learner = learner,
    resampling = resampling,
    measure = measure,
    search_space = search_space,
    terminator = terminator,
    tuner = tuner
  )
)
<AutoTuner:regr.ranger.tuned>
* Model: -
* Search Space:
<ParamSet>
              id    class lower upper nlevels        default value
1:    mtry.ratio ParamDbl   0.5   0.9     Inf <NoDefault[3]>      
2: min.node.size ParamInt   1.0  20.0      20 <NoDefault[3]>      
* Packages: mlr3, mlr3tuning, mlr3learners, ranger
* Predict Type: response
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights

Note that unlike TuningInstance, AutoTuner does not take a task as an element, and it takes a Tuner instead. Once an AutoTuner is instantiated, you can use it like a learner and invoke the train() method with a task to train a model. The difference from a regular learner is that it automatically tunes the hyper-parameters internally and uses the optimized parameter values to train the final model.

auto_tuning_learner$train(reg_task)
INFO  [10:52:26.915] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=100, k=0]' 
INFO  [10:52:26.916] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:26.925] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:26.928] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:26.938] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:26.947] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:26.956] [mlr3] Finished benchmark 
INFO  [10:52:26.966] [bbotk] Result of batch 1: 
INFO  [10:52:26.967] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:26.967] [bbotk]         0.5            14 8.853036        0      0            0.015 
INFO  [10:52:26.967] [bbotk]                                 uhash 
INFO  [10:52:26.967] [bbotk]  92b99ee4-e7ba-4cee-b61a-5fe8bf3c6a6d 
INFO  [10:52:26.967] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:26.976] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:26.979] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:26.988] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:26.997] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.006] [mlr3] Finished benchmark 
INFO  [10:52:27.022] [bbotk] Result of batch 2: 
INFO  [10:52:27.023] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.023] [bbotk]         0.9            20 11.74412        0      0            0.018 
INFO  [10:52:27.023] [bbotk]                                 uhash 
INFO  [10:52:27.023] [bbotk]  2686d082-d6cb-4430-8ef2-0c9d4f18048e 
INFO  [10:52:27.023] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.032] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.035] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.050] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.064] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.078] [mlr3] Finished benchmark 
INFO  [10:52:27.089] [bbotk] Result of batch 3: 
INFO  [10:52:27.090] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.090] [bbotk]         0.9             1 5.814905        0      0            0.034 
INFO  [10:52:27.090] [bbotk]                                 uhash 
INFO  [10:52:27.090] [bbotk]  640b968f-e097-42f8-8094-49563ba2fe32 
INFO  [10:52:27.090] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.099] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.102] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.110] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.119] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.134] [mlr3] Finished benchmark 
INFO  [10:52:27.145] [bbotk] Result of batch 4: 
INFO  [10:52:27.146] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.146] [bbotk]   0.6333333            14 9.373763        0      0            0.019 
INFO  [10:52:27.146] [bbotk]                                 uhash 
INFO  [10:52:27.146] [bbotk]  abce4942-9d23-4ab6-ae96-b0158a64b60f 
INFO  [10:52:27.146] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.155] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.157] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.166] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.174] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.183] [mlr3] Finished benchmark 
INFO  [10:52:27.193] [bbotk] Result of batch 5: 
INFO  [10:52:27.194] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.194] [bbotk]         0.5            20 11.11123        0      0            0.016 
INFO  [10:52:27.194] [bbotk]                                 uhash 
INFO  [10:52:27.194] [bbotk]  56524c5b-292b-4795-8b0b-7f5032c906c8 
INFO  [10:52:27.195] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.203] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.206] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.214] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.228] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.238] [mlr3] Finished benchmark 
INFO  [10:52:27.249] [bbotk] Result of batch 6: 
INFO  [10:52:27.249] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.249] [bbotk]   0.7666667            14 9.556742        0      0            0.019 
INFO  [10:52:27.249] [bbotk]                                 uhash 
INFO  [10:52:27.249] [bbotk]  0985f273-95d0-4844-920e-837591ad4787 
INFO  [10:52:27.250] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.258] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.261] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.274] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.288] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.302] [mlr3] Finished benchmark 
INFO  [10:52:27.312] [bbotk] Result of batch 7: 
INFO  [10:52:27.313] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.313] [bbotk]   0.7666667             1 5.640281        0      0             0.03 
INFO  [10:52:27.313] [bbotk]                                 uhash 
INFO  [10:52:27.313] [bbotk]  63c479e2-0ebf-43bf-b1f2-cccee6a03fe4 
INFO  [10:52:27.314] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.322] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.325] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.346] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.360] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.374] [mlr3] Finished benchmark 
INFO  [10:52:27.384] [bbotk] Result of batch 8: 
INFO  [10:52:27.385] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.385] [bbotk]   0.6333333             1 5.344656        0      0            0.031 
INFO  [10:52:27.385] [bbotk]                                 uhash 
INFO  [10:52:27.385] [bbotk]  e377f9f3-9222-449e-8bc1-aec5d6f9bea4 
INFO  [10:52:27.386] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.394] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.397] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.405] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.414] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.423] [mlr3] Finished benchmark 
INFO  [10:52:27.434] [bbotk] Result of batch 9: 
INFO  [10:52:27.434] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.434] [bbotk]   0.7666667            20 11.70993        0      0            0.017 
INFO  [10:52:27.434] [bbotk]                                 uhash 
INFO  [10:52:27.434] [bbotk]  5ad30c70-2251-4700-bebe-135f5ad756b4 
INFO  [10:52:27.435] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.449] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.453] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.466] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.477] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.489] [mlr3] Finished benchmark 
INFO  [10:52:27.500] [bbotk] Result of batch 10: 
INFO  [10:52:27.501] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.501] [bbotk]   0.6333333             7 6.243863        0      0            0.022 
INFO  [10:52:27.501] [bbotk]                                 uhash 
INFO  [10:52:27.501] [bbotk]  efc59800-b9af-4500-bf2a-375273866b87 
INFO  [10:52:27.501] [bbotk] Evaluating 1 configuration(s) 
INFO  [10:52:27.511] [mlr3] Running benchmark with 3 resampling iterations 
INFO  [10:52:27.514] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3) 
INFO  [10:52:27.524] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3) 
INFO  [10:52:27.533] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3) 
INFO  [10:52:27.543] [mlr3] Finished benchmark 
INFO  [10:52:27.555] [bbotk] Result of batch 11: 
INFO  [10:52:27.556] [bbotk]  mtry.ratio min.node.size regr.mse warnings errors runtime_learners 
INFO  [10:52:27.556] [bbotk]         0.9            14 9.585749        0      0            0.019 
... (analogous log output for batches 12 through 16 omitted; every evaluated configuration is listed in the archive below) ...
INFO  [10:52:27.864] [bbotk] Finished optimizing after 16 evaluation(s) 
INFO  [10:52:27.864] [bbotk] Result: 
INFO  [10:52:27.865] [bbotk]  mtry.ratio min.node.size learner_param_vals  x_domain regr.mse 
INFO  [10:52:27.865] [bbotk]         0.5             1          <list[4]> <list[2]> 5.168695 

You can access the optimized learner through the model$learner attribute.

auto_tunning_learner$model$learner
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, max.depth=10, mtry.ratio=0.5,
  min.node.size=1
* Packages: mlr3, mlr3learners, ranger
* Predict Type: response
* Feature types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights
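
If you only need the tuned hyper-parameter values (say, to set them on a fresh learner), you can pull them out of the extracted learner. Here is a minimal sketch:

#=== tuned hyper-parameter values as a named list ===#
auto_tunning_learner$model$learner$param_set$values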

You can look at the history of the tuning process as follows:

auto_tunning_learner$model$tuning_instance
<TuningInstanceSingleCrit>
* State:  Optimized
* Objective: <ObjectiveTuning:regr.ranger_on_example>
* Search Space:
              id    class lower upper nlevels
1:    mtry.ratio ParamDbl   0.5   0.9     Inf
2: min.node.size ParamInt   1.0  20.0      20
* Terminator: <TerminatorEvals>
* Result:
   mtry.ratio min.node.size regr.mse
1:        0.5             1 5.168695
* Archive:
    mtry.ratio min.node.size  regr.mse
 1:  0.5000000            14  8.853036
 2:  0.9000000            20 11.744119
 3:  0.9000000             1  5.814905
 4:  0.6333333            14  9.373763
 5:  0.5000000            20 11.111230
 6:  0.7666667            14  9.556742
 7:  0.7666667             1  5.640281
 8:  0.6333333             1  5.344656
 9:  0.7666667            20 11.709931
10:  0.6333333             7  6.243863
11:  0.9000000            14  9.585749
12:  0.6333333            20 11.304420
13:  0.5000000             1  5.168695
14:  0.7666667             7  6.306153
15:  0.9000000             7  6.025976
16:  0.5000000             7  6.040561
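
The archive can also be converted to a data.table with the standard as.data.table() conversion, which is convenient for further inspection. For example, the following sketch sorts the evaluated configurations by MSE:

#=== archive as a data.table, best configurations first ===#
as.data.table(auto_tunning_learner$model$tuning_instance$archive) %>% 
  .[order(regr.mse), .(mtry.ratio, min.node.size, regr.mse)]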

You can of course use the predict() method as well.

auto_tunning_learner$predict(reg_task)
<PredictionRegr> for 32 observations:
    row_ids truth response
          1  21.0 20.90324
          2  21.0 20.90364
          3  22.8 23.93983
---                       
         30  19.7 19.82130
         31  15.0 14.92800
         32  21.4 21.84063
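
You can also score these predictions against a measure. Keep in mind that this is in-sample performance (the tuned learner was trained on all 32 observations), so it will look optimistic:

#=== in-sample MSE of the tuned learner ===#
auto_tunning_learner$predict(reg_task)$score(msr("regr.mse"))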

18.6 mlr3 in action

Here, we present an example of using the mlr3 framework as part of a research workflow. Suppose our goal is to estimate heterogeneous treatment effects with a causal forest using the R grf package. The grf::causal_forest() function implements causal forest as an R-learner, and by default it uses random forest for its first-stage estimations. However, the function allows users to supply their own first-stage estimates of \(Y\) (dependent variable) and \(T\) (treatment). We will code the first stage using mlr3 and then use grf::causal_forest() to estimate heterogeneous treatment effects.
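
Roughly speaking, causal forest as an R-learner residualizes the outcome and the treatment using these first-stage predictions,

\[
\tilde{Y}_i = Y_i - \hat{E}[Y_i|X_i], \quad \tilde{T}_i = T_i - \hat{E}[T_i|X_i]
\]

and then estimates heterogeneous treatment effects from the relationship between \(\tilde{Y}\) and \(\tilde{T}\). This is why we want out-of-sample (cross-fitted) estimates of \(E[Y|X]\) and \(E[T|X]\) before calling grf::causal_forest().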

#=== load the Treatment dataset ===#
data("Treatment", package = "Ecdat")

#=== convert to a data.table ===#
(
data <- 
  data.table(Treatment)
)
      treat age educ     ethn married    re74    re75     re78   u74   u75
   1:  TRUE  37   11    black    TRUE     0.0     0.0  9930.05  TRUE  TRUE
   2:  TRUE  30   12    black   FALSE     0.0     0.0 24909.50  TRUE  TRUE
   3:  TRUE  27   11    black   FALSE     0.0     0.0  7506.15  TRUE  TRUE
   4:  TRUE  33    8    black   FALSE     0.0     0.0   289.79  TRUE  TRUE
   5:  TRUE  22    9    black   FALSE     0.0     0.0  4056.49  TRUE  TRUE
  ---                                                                     
2671: FALSE  47    8    other    TRUE 44667.4 33837.1 38568.70 FALSE FALSE
2672: FALSE  32    8    other    TRUE 47022.4 67137.1 59109.10 FALSE FALSE
2673: FALSE  47   10    other    TRUE 48198.0 47968.1 55710.30 FALSE FALSE
2674: FALSE  54    0 hispanic    TRUE 49228.5 44221.0 20540.40 FALSE FALSE
2675: FALSE  40    8    other    TRUE 50940.9 55500.0 53198.20 FALSE FALSE

The dependent variable is re78 (\(Y\)), which is real annual earnings in 1978 (after treatment). The treatment variable of interest is treat (\(T\)), which is TRUE if a person went through training and FALSE otherwise. The features included as potential drivers of the heterogeneity in the impact of treat on re78 are age, educ, ethn, and married (\(X\)). Note that the focus of this section is just showcasing the use of mlr3, and no attention is paid to potential endogeneity problems.

18.6.1 First stage

We will use ranger() and xgboost() as learners. While ranger() accepts factor variables, xgboost() does not. So, we will one-hot-encode the data using mltools::one_hot() so that we can create a single Task and use it for both. We also convert treat to a factor so that it can serve as the target of the classification task for both learners.

(
data_trt <- 
  mltools::one_hot(data) %>% 
  .[, treat := factor(treat)]
)
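
As a quick sanity check (a minimal sketch; the ethn_* column names follow the default naming of mltools::one_hot(), and the task definitions below rely on them):

#=== confirm the one-hot-encoded columns and the factor treat ===#
str(data_trt[, .(treat, ethn_black, ethn_hispanic, ethn_other)])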

Note that we need separate procedures for estimating \(E[Y|X]\) and \(E[T|X]\): the former is a regression task, while the latter is a classification task.

Let’s first work on estimating \(E[Y|X]\). We start by setting up a task.

y_est_task <- 
  TaskRegr$new(
    id = "estimate_y_on_x",
    backend = data_trt[, .(
      re78, age, educ, married, 
      ethn_other, ethn_black, ethn_hispanic 
    )],
    target = "re78"
  )
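
It is worth a quick check that the task picked up the intended target and features:

#=== inspect the task ===#
y_est_task$target_names
y_est_task$feature_names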

We consider two modeling approaches: random forest via ranger and gradient boosting via xgboost.

y_learner_ranger <- lrn("regr.ranger")
y_learner_xgboost <- lrn("regr.xgboost")

We will use K-fold cross-validation to select the better of the two models, each with optimized hyper-parameter values. We do this by running a Tuner on tuning instances defined separately for the two learners.

Let’s define the resampling method, measure, terminator, and tuner that will be shared by the two approaches.

resampling_y <- rsmp("cv", folds = 4)
measure_y <- msr("regr.mse")
terminator <- trm("evals", n_evals = 100)
tuner <- tnr("grid_search", resolution = 4)

Now, when we compare multiple models, we should use the same CV splits to have a fair comparison of their performance. To ensure this, we need to use an instantiated Resampling object.

If you provide an un-instantiated Resampling object to a tuning instance (or an AutoTuner), it will be instantiated internally, and two separate tuning processes can end up with two distinct sets of splits.

#=== instantiate ===#
resampling_y$instantiate(y_est_task)

#=== confirm it is indeed instantiated ===#
resampling_y
<ResamplingCV>: Cross-Validation
* Iterations: 4
* Instantiated: TRUE
* Parameters: folds=4
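
Once instantiated, the train/test row ids of each fold are fixed, so every learner evaluated with this object sees identical splits. You can inspect them directly (a minimal sketch):

#=== row ids of the training and test sets for fold 1 ===#
head(resampling_y$train_set(1))
head(resampling_y$test_set(1))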

Let’s define search space for each of the learners.

search_space_ranger <-
  ps(
    mtry = p_int(lower = 1, upper = length(y_est_task$feature_names)),
    min.node.size = p_int(lower = 1, upper = 20)
  )

search_space_xgboost <-
  ps(
    nrounds = p_int(lower = 100, upper = 400),
    eta = p_dbl(lower = 0.01, upper = 1)
  )
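
Since we use a grid search with resolution = 4 below, you can preview the candidate configurations with paradox::generate_design_grid() (a minimal sketch; the tuner builds an equivalent grid internally):

#=== the 4 x 4 grid of candidate values for ranger ===#
generate_design_grid(search_space_ranger, resolution = 4)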

We have all the ingredients to set up TuningInstances for the learners.

tuning_instance_ranger <-
  TuningInstanceSingleCrit$new(
    task = y_est_task,
    learner = y_learner_ranger,
    resampling = resampling_y,
    measure = measure_y,
    search_space = search_space_ranger,
    terminator = terminator
  )

tuning_instance_xgboost <-
  TuningInstanceSingleCrit$new(
    task = y_est_task,
    learner = y_learner_xgboost,
    resampling = resampling_y,
    measure = measure_y,
    search_space = search_space_xgboost,
    terminator = terminator
  )

Let’s tune them now (this can take a while).

We are using the same tuner for both here, but you could use different tuning setups for the two learners. For example, you could set resolution = 5 when tuning regr.xgboost.

#=== tune ranger ===#
tuner$optimize(tuning_instance_ranger)

#=== tune xgboost ===#
tuner$optimize(tuning_instance_xgboost)
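
If you would rather not see the verbose logging these calls produce, you can raise the logging thresholds beforehand. This is a minimal sketch using the lgr package, which mlr3 and bbotk rely on for logging:

#=== show only warnings and errors from mlr3 and bbotk ===#
lgr::get_logger("mlr3")$set_threshold("warn")
lgr::get_logger("bbotk")$set_threshold("warn")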

Here are the MSEs from the two individually tuned learners.

tuning_instance_ranger$result_y
 regr.mse 
193034514 
tuning_instance_xgboost$result_y
 regr.mse 
197112372 

So, in this example, we go with regr.ranger. Let’s update our learner with its optimized hyper-parameter values.

(
y_learner_ranger$param_set$values <- tuning_instance_ranger$result_learner_param_vals
)
$num.threads
[1] 1

$mtry
[1] 2

$min.node.size
[1] 14

Now that we have decided on the model to use for predicting \(E[Y|X]\), let’s implement cross-fitting.

cv_results_y <-
  resample(
    y_est_task, 
    y_learner_ranger, 
    rsmp("repeated_cv", repeats = 4, folds = 3) 
  )
INFO  [10:52:55.515] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 4/12) 
... (log output for the remaining 11 resampling iterations omitted) ...
#=== all combined ===#
all_predictions_y <- 
  cv_results_y$prediction() %>% 
  as.data.table() %>% 
  .[, .(y_hat = mean(response)), by = row_ids] %>%
  .[order(row_ids), ]
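
Since rsmp("repeated_cv", repeats = 4, folds = 3) holds each observation out exactly once per repeat, every row appears four times in the combined predictions, and the mean() above averages four out-of-fold predictions per observation. You can verify this with a quick sketch:

#=== each observation should have exactly 4 out-of-fold predictions ===#
cv_results_y$prediction() %>% 
  as.data.table() %>% 
  .[, .N, by = row_ids] %>% 
  .[, table(N)]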

We basically follow the same process for estimating \(E[T|X]\). Let’s set up tuning processes for the learners.

t_est_task <-
  TaskClassif$new(
    id = "estimate_t_on_x",
    backend = 
      data_trt[, .(
        treat, age, educ, married, 
        ethn_other, ethn_black, ethn_hispanic 
      )],
    target = "treat"
  )

t_learner_ranger <- lrn("classif.ranger")
t_learner_xgboost <- lrn("classif.xgboost")

resampling_t <- rsmp("cv", folds = 4)
resampling_t$instantiate(t_est_task)
measure_t <- msr("classif.ce")
terminator <- trm("evals", n_evals = 100)
tuner <- tnr("grid_search", resolution = 4)

tuning_instance_ranger <-
  TuningInstanceSingleCrit$new(
    task = t_est_task,
    learner = t_learner_ranger,
    resampling = resampling_t,
    measure = measure_t,
    search_space = search_space_ranger,
    terminator = terminator
  )

tuning_instance_xgboost <-
  TuningInstanceSingleCrit$new(
    task = t_est_task,
    learner = t_learner_xgboost,
    resampling = resampling_t,
    measure = measure_t,
    search_space = search_space_xgboost,
    terminator = terminator
  )

Let’s tune them and look at the classification errors from the two individually tuned learners.

#=== tune ranger ===#
tuner$optimize(tuning_instance_ranger)
tuning_instance_ranger$result_y
#=== tune xgboost ===#
tuner$optimize(tuning_instance_xgboost)
tuning_instance_xgboost$result_y

So, we are picking classif.ranger here as well, as it has a lower classification error.

t_learner_ranger$param_set$values <- tuning_instance_ranger$result_learner_param_vals

Now that we have decided on the model to use for predicting \(E[T|X]\), let’s implement cross-fitting. Before cross-fitting, we need to tell t_learner_ranger to predict probabilities instead of class labels (TRUE or FALSE).

t_learner_ranger$predict_type <- "prob"

cv_results_t <-
  resample(
    t_est_task, 
    t_learner_ranger, 
    rsmp("repeated_cv", repeats = 4, folds = 3) 
  )
INFO  [10:53:16.482] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 2/12) 
... (log output for the remaining 11 resampling iterations omitted) ...
#=== all combined ===#
all_predictions_t <- 
  cv_results_t$prediction() %>% 
  as.data.table() %>% 
  .[, .(t_hat = mean(prob.TRUE)), by = row_ids] %>% 
  .[order(row_ids), ]

We now use all_predictions_y and all_predictions_t for Y.hat and W.hat in grf::causal_forest().

grf::causal_forest(
  X = data_trt[, .(age, educ, married, ethn_other, ethn_black, ethn_hispanic)] %>% as.matrix(),
  Y = data_trt[, re78],
  W = data_trt[, fifelse(treat == TRUE, 1, 0)],
  Y.hat = all_predictions_y[, y_hat],
  W.hat = all_predictions_t[, t_hat]
)
GRF forest object of type causal_forest 
Number of trees: 2000 
Number of training samples: 2675 
Variable importance: 
    1     2     3     4     5     6 
0.506 0.302 0.116 0.023 0.044 0.009
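
From here, you would typically store the forest and extract estimates from it. Here is a minimal sketch, assuming the call above is assigned to an object (cf_trained is a hypothetical name):

#=== store the forest (same call as above) ===#
cf_trained <- 
  grf::causal_forest(
    X = data_trt[, .(age, educ, married, ethn_other, ethn_black, ethn_hispanic)] %>% as.matrix(),
    Y = data_trt[, re78],
    W = data_trt[, fifelse(treat == TRUE, 1, 0)],
    Y.hat = all_predictions_y[, y_hat],
    W.hat = all_predictions_t[, t_hat]
  )

#=== individual-level treatment effect estimates (CATEs) ===#
head(predict(cf_trained)$predictions)

#=== doubly-robust average treatment effect estimate ===#
grf::average_treatment_effect(cf_trained)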