18 Machine Learning with mlr3
The mlr3 package makes it easy to implement standard machine learning procedures (e.g., training, prediction, cross-validation) in R under a unified framework. It is similar to the Python scikit-learn package. In this section, we cover the basic use cases of the mlr3 package, accompanied by a real-world example at the end. The authors of the package are currently writing a book on how to use it. The material covered in this section is basically a condensed (and far less comprehensive) version of that book. Yet, it should still be sufficient for the majority of ML tasks you typically implement in practice.
To start working with mlr3, it is convenient to load the mlr3verse package, which is a collection of packages that provide useful extensions (e.g., quick visualization (mlr3viz), machine learning methods (mlr3learners), etc.). It is analogous to the tidyverse package. See here for the list of included packages.
Those who have been using solely R for their programming needs are not likely to be familiar with the way mlr3 works. It uses R6 classes provided by the R6 package, which implements encapsulated object-oriented programming, the paradigm Python follows. So, if you are familiar with Python, then mlr3 should come quite naturally to you. Fortunately, understanding how R6 works is not too hard (especially for us who are just using it). Reading the introduction to R6 provided here should suffice.
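To get a feel for the reference semantics that distinguish R6 objects from ordinary R objects, here is a minimal sketch (the Counter class below is a made-up illustration, not part of mlr3):

```r
library(R6)

#=== a minimal R6 class (hypothetical example) ===#
Counter <- R6Class("Counter",
  public = list(
    count = 0,
    add = function(x = 1) {
      # R6 methods mutate the object in place (reference semantics)
      self$count <- self$count + x
      invisible(self)
    }
  )
)

counter <- Counter$new() # instantiate with $new(), as we will do with mlr3 tasks
counter$add(2)           # mutates counter itself; no reassignment needed
counter$count            # 2
```

Note that `counter$add(2)` changes `counter` without any `<-` reassignment. This is exactly the behavior you will see with mlr3 tasks and learners below.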
To implement ML tasks, we need at least two core components: a task and a learner. Roughly speaking, here is what they are.
- Task: data
- Learner: model
We will take a deeper look at these two components first, and then training, prediction, and other ML tasks.
18.1 Tasks
Packages to load for replication
library(data.table)
library(DoubleML)
library(mlr3verse)
library(tidyverse)
library(mltools)
library(parallel)
A task in the mlr3 parlance is basically a dataset (called the backend) with information on which variable is the dependent variable (called the target) and which variables are explanatory variables (called features).
Here, we will use the mtcars data.
#=== load mtcars ===#
data("mtcars", package = "datasets")
#=== see the first 6 rows ===#
head(mtcars)
mpg cyl disp hp drat wt qsec vs am gear carb
Mazda RX4 21.0 6 160 110 3.90 2.620 16.46 0 1 4 4
Mazda RX4 Wag 21.0 6 160 110 3.90 2.875 17.02 0 1 4 4
Datsun 710 22.8 4 108 93 3.85 2.320 18.61 1 1 4 1
Hornet 4 Drive 21.4 6 258 110 3.08 3.215 19.44 1 0 3 1
Hornet Sportabout 18.7 8 360 175 3.15 3.440 17.02 0 0 3 2
Valiant 18.1 6 225 105 2.76 3.460 20.22 1 0 3 1
When you create a task, you need to recognize what type of analysis you will be doing and use an appropriate class. Here, we will be running regression, so we use the TaskRegr
class (see here for other task types).
Now, let’s create a task using the TaskRegr class with
- mtcars as backend (data)
- mpg as target (dependent variable)
- "example" as id (this is the id for the task; you can give it any name)
You can instantiate a new TaskRegr instance using the new() method on the TaskRegr class.
(
reg_task <-
  TaskRegr$new(
    id = "example",
    backend = mtcars,
    target = "mpg"
  )
)
<TaskRegr:example> (32 x 11)
* Target: mpg
* Properties: -
* Features (10):
- dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
As you can see, mpg is the Target and the rest were automatically set as Features.
You can use the col_roles field to see the roles of the variables.
reg_task$col_roles
$feature
[1] "am" "carb" "cyl" "disp" "drat" "gear" "hp" "qsec" "vs" "wt"
$target
[1] "mpg"
$name
character(0)
$order
character(0)
$stratum
character(0)
$group
character(0)
$weight
character(0)
You can extract information from the task using various fields and methods:
- $nrow: returns the number of rows
- $ncol: returns the number of columns
- $feature_names: returns the names of the feature variables
- $target_names: returns the name of the target variable(s)
- $row_ids: returns the row ids (integers from 1 to the number of rows)
- $data(): returns the backend (data) as a data.table
Let’s see some of these.
#=== row ids ===#
reg_task$row_ids
[1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
[26] 26 27 28 29 30 31 32
#=== data ===#
reg_task$data()
mpg am carb cyl disp drat gear hp qsec vs wt
1: 21.0 1 4 6 160.0 3.90 4 110 16.46 0 2.620
2: 21.0 1 4 6 160.0 3.90 4 110 17.02 0 2.875
3: 22.8 1 1 4 108.0 3.85 4 93 18.61 1 2.320
4: 21.4 0 1 6 258.0 3.08 3 110 19.44 1 3.215
5: 18.7 0 2 8 360.0 3.15 3 175 17.02 0 3.440
6: 18.1 0 1 6 225.0 2.76 3 105 20.22 1 3.460
7: 14.3 0 4 8 360.0 3.21 3 245 15.84 0 3.570
8: 24.4 0 2 4 146.7 3.69 4 62 20.00 1 3.190
9: 22.8 0 2 4 140.8 3.92 4 95 22.90 1 3.150
10: 19.2 0 4 6 167.6 3.92 4 123 18.30 1 3.440
11: 17.8 0 4 6 167.6 3.92 4 123 18.90 1 3.440
12: 16.4 0 3 8 275.8 3.07 3 180 17.40 0 4.070
13: 17.3 0 3 8 275.8 3.07 3 180 17.60 0 3.730
14: 15.2 0 3 8 275.8 3.07 3 180 18.00 0 3.780
15: 10.4 0 4 8 472.0 2.93 3 205 17.98 0 5.250
16: 10.4 0 4 8 460.0 3.00 3 215 17.82 0 5.424
17: 14.7 0 4 8 440.0 3.23 3 230 17.42 0 5.345
18: 32.4 1 1 4 78.7 4.08 4 66 19.47 1 2.200
19: 30.4 1 2 4 75.7 4.93 4 52 18.52 1 1.615
20: 33.9 1 1 4 71.1 4.22 4 65 19.90 1 1.835
21: 21.5 0 1 4 120.1 3.70 3 97 20.01 1 2.465
22: 15.5 0 2 8 318.0 2.76 3 150 16.87 0 3.520
23: 15.2 0 2 8 304.0 3.15 3 150 17.30 0 3.435
24: 13.3 0 4 8 350.0 3.73 3 245 15.41 0 3.840
25: 19.2 0 2 8 400.0 3.08 3 175 17.05 0 3.845
26: 27.3 1 1 4 79.0 4.08 4 66 18.90 1 1.935
27: 26.0 1 2 4 120.3 4.43 5 91 16.70 0 2.140
28: 30.4 1 2 4 95.1 3.77 5 113 16.90 1 1.513
29: 15.8 1 4 8 351.0 4.22 5 264 14.50 0 3.170
30: 19.7 1 6 6 145.0 3.62 5 175 15.50 0 2.770
31: 15.0 1 8 8 301.0 3.54 5 335 14.60 0 3.570
32: 21.4 1 2 4 121.0 4.11 4 109 18.60 1 2.780
mpg am carb cyl disp drat gear hp qsec vs wt
It is possible to retrieve only a portion of the data using the rows and cols arguments of data() as follows:
reg_task$data(rows = 1:10, cols = c("mpg", "wt"))
mpg wt
1: 21.0 2.620
2: 21.0 2.875
3: 22.8 2.320
4: 21.4 3.215
5: 18.7 3.440
6: 18.1 3.460
7: 14.3 3.570
8: 24.4 3.190
9: 22.8 3.150
10: 19.2 3.440
To retrieve the complete data from the task, you can apply as.data.table() to the task.
(
data_extracted <- as.data.table(reg_task)
)
mpg am carb cyl disp drat gear hp qsec vs wt
1: 21.0 1 4 6 160.0 3.90 4 110 16.46 0 2.620
2: 21.0 1 4 6 160.0 3.90 4 110 17.02 0 2.875
3: 22.8 1 1 4 108.0 3.85 4 93 18.61 1 2.320
4: 21.4 0 1 6 258.0 3.08 3 110 19.44 1 3.215
5: 18.7 0 2 8 360.0 3.15 3 175 17.02 0 3.440
6: 18.1 0 1 6 225.0 2.76 3 105 20.22 1 3.460
7: 14.3 0 4 8 360.0 3.21 3 245 15.84 0 3.570
8: 24.4 0 2 4 146.7 3.69 4 62 20.00 1 3.190
9: 22.8 0 2 4 140.8 3.92 4 95 22.90 1 3.150
10: 19.2 0 4 6 167.6 3.92 4 123 18.30 1 3.440
11: 17.8 0 4 6 167.6 3.92 4 123 18.90 1 3.440
12: 16.4 0 3 8 275.8 3.07 3 180 17.40 0 4.070
13: 17.3 0 3 8 275.8 3.07 3 180 17.60 0 3.730
14: 15.2 0 3 8 275.8 3.07 3 180 18.00 0 3.780
15: 10.4 0 4 8 472.0 2.93 3 205 17.98 0 5.250
16: 10.4 0 4 8 460.0 3.00 3 215 17.82 0 5.424
17: 14.7 0 4 8 440.0 3.23 3 230 17.42 0 5.345
18: 32.4 1 1 4 78.7 4.08 4 66 19.47 1 2.200
19: 30.4 1 2 4 75.7 4.93 4 52 18.52 1 1.615
20: 33.9 1 1 4 71.1 4.22 4 65 19.90 1 1.835
21: 21.5 0 1 4 120.1 3.70 3 97 20.01 1 2.465
22: 15.5 0 2 8 318.0 2.76 3 150 16.87 0 3.520
23: 15.2 0 2 8 304.0 3.15 3 150 17.30 0 3.435
24: 13.3 0 4 8 350.0 3.73 3 245 15.41 0 3.840
25: 19.2 0 2 8 400.0 3.08 3 175 17.05 0 3.845
26: 27.3 1 1 4 79.0 4.08 4 66 18.90 1 1.935
27: 26.0 1 2 4 120.3 4.43 5 91 16.70 0 2.140
28: 30.4 1 2 4 95.1 3.77 5 113 16.90 1 1.513
29: 15.8 1 4 8 351.0 4.22 5 264 14.50 0 3.170
30: 19.7 1 6 6 145.0 3.62 5 175 15.50 0 2.770
31: 15.0 1 8 8 301.0 3.54 5 335 14.60 0 3.570
32: 21.4 1 2 4 121.0 4.11 4 109 18.60 1 2.780
mpg am carb cyl disp drat gear hp qsec vs wt
You can mutate tasks using the select() and filter() methods. It is important to remember that these methods modify the task object in place. Let’s see what I mean by this.
#=== first select few variables ===#
reg_task$select(c("am", "carb", "cyl"))

#=== see the backend now ===#
reg_task$data()
mpg am carb cyl
1: 21.0 1 4 6
2: 21.0 1 4 6
3: 22.8 1 1 4
4: 21.4 0 1 6
5: 18.7 0 2 8
6: 18.1 0 1 6
7: 14.3 0 4 8
8: 24.4 0 2 4
9: 22.8 0 2 4
10: 19.2 0 4 6
11: 17.8 0 4 6
12: 16.4 0 3 8
13: 17.3 0 3 8
14: 15.2 0 3 8
15: 10.4 0 4 8
16: 10.4 0 4 8
17: 14.7 0 4 8
18: 32.4 1 1 4
19: 30.4 1 2 4
20: 33.9 1 1 4
21: 21.5 0 1 4
22: 15.5 0 2 8
23: 15.2 0 2 8
24: 13.3 0 4 8
25: 19.2 0 2 8
26: 27.3 1 1 4
27: 26.0 1 2 4
28: 30.4 1 2 4
29: 15.8 1 4 8
30: 19.7 1 6 6
31: 15.0 1 8 8
32: 21.4 1 2 4
mpg am carb cyl
As you can see, reg_task now holds only the selected variables (plus the target variable mpg). This behavior is similar to how data.table works when you create a new variable using the data[, := ] syntax, and it is different from how dplyr::select() works.
#=== create a dataset ===#
data_temp <- reg_task$data()

#=== select mpg, carb ===#
dplyr::select(data_temp, mpg, carb)
mpg carb
1: 21.0 4
2: 21.0 4
3: 22.8 1
4: 21.4 1
5: 18.7 2
6: 18.1 1
7: 14.3 4
8: 24.4 2
9: 22.8 2
10: 19.2 4
11: 17.8 4
12: 16.4 3
13: 17.3 3
14: 15.2 3
15: 10.4 4
16: 10.4 4
17: 14.7 4
18: 32.4 1
19: 30.4 2
20: 33.9 1
21: 21.5 1
22: 15.5 2
23: 15.2 2
24: 13.3 4
25: 19.2 2
26: 27.3 1
27: 26.0 2
28: 30.4 2
29: 15.8 4
30: 19.7 6
31: 15.0 8
32: 21.4 2
mpg carb
#=== look at data_temp ===#
data_temp
mpg am carb cyl
1: 21.0 1 4 6
2: 21.0 1 4 6
3: 22.8 1 1 4
4: 21.4 0 1 6
5: 18.7 0 2 8
6: 18.1 0 1 6
7: 14.3 0 4 8
8: 24.4 0 2 4
9: 22.8 0 2 4
10: 19.2 0 4 6
11: 17.8 0 4 6
12: 16.4 0 3 8
13: 17.3 0 3 8
14: 15.2 0 3 8
15: 10.4 0 4 8
16: 10.4 0 4 8
17: 14.7 0 4 8
18: 32.4 1 1 4
19: 30.4 1 2 4
20: 33.9 1 1 4
21: 21.5 0 1 4
22: 15.5 0 2 8
23: 15.2 0 2 8
24: 13.3 0 4 8
25: 19.2 0 2 8
26: 27.3 1 1 4
27: 26.0 1 2 4
28: 30.4 1 2 4
29: 15.8 1 4 8
30: 19.7 1 6 6
31: 15.0 1 8 8
32: 21.4 1 2 4
mpg am carb cyl
As you can see, data_temp is not mutated after select(). To save the result of dplyr::select(), you need to explicitly assign it to another R object.
The target variable cannot be selected in select(); it is automatically retained.
If you would like to keep the original task, you can use the clone()
method to create a distinct instance.
#=== create a clone ===#
reg_task_independent <- reg_task$clone()
Let’s filter the data using the filter()
method.
#=== filter ===#
reg_task$filter(1:10)

#=== see the backend ===#
reg_task$data()
mpg am carb cyl
1: 21.0 1 4 6
2: 21.0 1 4 6
3: 22.8 1 1 4
4: 21.4 0 1 6
5: 18.7 0 2 8
6: 18.1 0 1 6
7: 14.3 0 4 8
8: 24.4 0 2 4
9: 22.8 0 2 4
10: 19.2 0 4 6
However, reg_task_independent
is not affected.
reg_task_independent$data()
mpg am carb cyl
1: 21.0 1 4 6
2: 21.0 1 4 6
3: 22.8 1 1 4
4: 21.4 0 1 6
5: 18.7 0 2 8
6: 18.1 0 1 6
7: 14.3 0 4 8
8: 24.4 0 2 4
9: 22.8 0 2 4
10: 19.2 0 4 6
11: 17.8 0 4 6
12: 16.4 0 3 8
13: 17.3 0 3 8
14: 15.2 0 3 8
15: 10.4 0 4 8
16: 10.4 0 4 8
17: 14.7 0 4 8
18: 32.4 1 1 4
19: 30.4 1 2 4
20: 33.9 1 1 4
21: 21.5 0 1 4
22: 15.5 0 2 8
23: 15.2 0 2 8
24: 13.3 0 4 8
25: 19.2 0 2 8
26: 27.3 1 1 4
27: 26.0 1 2 4
28: 30.4 1 2 4
29: 15.8 1 4 8
30: 19.7 1 6 6
31: 15.0 1 8 8
32: 21.4 1 2 4
mpg am carb cyl
You can use the rbind()
and cbind()
methods to append data vertically and horizontally, respectively.
Here is an example use of rbind()
.
reg_task$rbind(
  data.table(mpg = 20, am = 1, carb = 3, cyl = 99)
)

#=== see the change ===#
reg_task$data()
mpg am carb cyl
1: 21.0 1 4 6
2: 21.0 1 4 6
3: 22.8 1 1 4
4: 21.4 0 1 6
5: 18.7 0 2 8
6: 18.1 0 1 6
7: 14.3 0 4 8
8: 24.4 0 2 4
9: 22.8 0 2 4
10: 19.2 0 4 6
11: 20.0 1 3 99
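cbind() works analogously, appending columns to the backend. A sketch (the new_var column is made up for illustration; the supplied data must have as many rows as the current backend):

```r
#=== append a column horizontally ===#
# new_var is a hypothetical variable used only for illustration
reg_task$cbind(
  data.table(new_var = rnorm(reg_task$nrow))
)

#=== see the change ===#
reg_task$data()
```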
18.2 Learners
In Python, the majority of ML packages are written to be compatible with scikit-learn
framework. However, in R, there is no single framework that is equivalent to scikit-learn
to which all the developers of ML packages conform to. Fortunately, the author of the mlr3
package picked popular ML packages (e.g., ranger
, xgboost
) and made it easier for us to use those packages under the unified framework.
tidymodels
have their own collection of packages, which is very similar to what mlr3
has.
Here is the list of learners available after loading the mlr3verse package.
mlr_learners
<DictionaryLearner> with 129 stored values
Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
classif.earth, classif.featureless, classif.fnn, classif.gam,
classif.gamboost, classif.gausspr, classif.gbm, classif.glmboost,
classif.glmnet, classif.IBk, classif.J48, classif.JRip, classif.kknn,
classif.ksvm, classif.lda, classif.liblinear, classif.lightgbm,
classif.LMT, classif.log_reg, classif.lssvm, classif.mob,
classif.multinom, classif.naive_bayes, classif.nnet, classif.OneR,
classif.PART, classif.qda, classif.randomForest, classif.ranger,
classif.rfsrc, classif.rpart, classif.svm, classif.xgboost,
clust.agnes, clust.ap, clust.cmeans, clust.cobweb, clust.dbscan,
clust.diana, clust.em, clust.fanny, clust.featureless, clust.ff,
clust.hclust, clust.kkmeans, clust.kmeans, clust.MBatchKMeans,
clust.meanshift, clust.pam, clust.SimpleKMeans, clust.xmeans,
dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.debug,
regr.earth, regr.featureless, regr.fnn, regr.gam, regr.gamboost,
regr.gausspr, regr.gbm, regr.glm, regr.glmboost, regr.glmnet,
regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
regr.lightgbm, regr.lm, regr.lmer, regr.M5Rules, regr.mars, regr.mob,
regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
surv.coxboost, surv.coxtime, surv.ctree, surv.cv_coxboost,
surv.cv_glmnet, surv.deephit, surv.deepsurv, surv.dnnsurv,
surv.flexible, surv.gamboost, surv.gbm, surv.glmboost, surv.glmnet,
surv.loghaz, surv.mboost, surv.nelson, surv.obliqueRSF,
surv.parametric, surv.pchazard, surv.penalized, surv.ranger,
surv.rfsrc, surv.svm, surv.xgboost
As you can see, packages that we have seen earlier (e.g., glmnet, ranger, xgboost, gam) are available. The mlr3extralearners package provides additional, but less well supported, learners. You can check the complete list of learners here.
You set up a learner by passing the name of the learner you would like to use from the list to lrn(), like below.
learner <- lrn("regr.ranger")
Note that you need to pick the one with the appropriate prediction type prefix (the prefixes are self-explanatory). Here, since we are interested in regression, we picked "regr.ranger".
Once you set up a learner, you can see the set of the learner’s hyper-parameters.
learner$param_set
<ParamSet>
id class lower upper nlevels default
1: alpha ParamDbl -Inf Inf Inf 0.5
2: always.split.variables ParamUty NA NA Inf <NoDefault[3]>
3: holdout ParamLgl NA NA 2 FALSE
4: importance ParamFct NA NA 4 <NoDefault[3]>
5: keep.inbag ParamLgl NA NA 2 FALSE
6: max.depth ParamInt 0 Inf Inf
7: min.node.size ParamInt 1 Inf Inf 5
8: min.prop ParamDbl -Inf Inf Inf 0.1
9: minprop ParamDbl -Inf Inf Inf 0.1
10: mtry ParamInt 1 Inf Inf <NoDefault[3]>
11: mtry.ratio ParamDbl 0 1 Inf <NoDefault[3]>
12: num.random.splits ParamInt 1 Inf Inf 1
13: num.threads ParamInt 1 Inf Inf 1
14: num.trees ParamInt 1 Inf Inf 500
15: oob.error ParamLgl NA NA 2 TRUE
16: quantreg ParamLgl NA NA 2 FALSE
17: regularization.factor ParamUty NA NA Inf 1
18: regularization.usedepth ParamLgl NA NA 2 FALSE
19: replace ParamLgl NA NA 2 TRUE
20: respect.unordered.factors ParamFct NA NA 3 ignore
21: sample.fraction ParamDbl 0 1 Inf <NoDefault[3]>
22: save.memory ParamLgl NA NA 2 FALSE
23: scale.permutation.importance ParamLgl NA NA 2 FALSE
24: se.method ParamFct NA NA 2 infjack
25: seed ParamInt -Inf Inf Inf
26: split.select.weights ParamUty NA NA Inf
27: splitrule ParamFct NA NA 3 variance
28: verbose ParamLgl NA NA 2 TRUE
29: write.forest ParamLgl NA NA 2 TRUE
id class lower upper nlevels default
parents value
1: splitrule
2:
3:
4:
5:
6:
7:
8:
9: splitrule
10:
11:
12: splitrule
13: 1
14:
15:
16:
17:
18:
19:
20:
21:
22:
23: importance
24:
25:
26:
27:
28:
29:
parents value
Right now, only num.threads
is set explicitly.
learner$param_set$values
$num.threads
[1] 1
You can update or assign the value of a parameter like this:
#=== set max.depth to 5 ===#
learner$param_set$values$max.depth <- 5
#=== see the values ===#
learner$param_set$values
$num.threads
[1] 1
$max.depth
[1] 5
When you would like to set values for multiple hyper-parameters at the same time, you can provide a named list to $param_set$values
. Here is an example:
#=== create a named list ===#
parameter_values <- list("min.node.size" = 10, "mtry" = 5, "num.trees" = 500)
#=== assign them ===#
learner$param_set$values <- parameter_values
#=== see the values ===#
learner$param_set$values
$min.node.size
[1] 10
$mtry
[1] 5
$num.trees
[1] 500
But notice that the values we set previously for max.depth and num.threads are gone, because assigning to $param_set$values replaces the entire list.
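If you want to add or change a few values without wiping out the rest, one option is to merge the new values into the existing list with base R's modifyList() (a sketch; this is plain base R, not an mlr3-specific API):

```r
#=== update parameters without discarding existing ones ===#
# modifyList() (base R) merges the new values into the current list,
# overwriting only the entries that are named in the second argument
learner$param_set$values <-
  modifyList(learner$param_set$values, list(max.depth = 5))

#=== see the values ===#
learner$param_set$values
```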
18.3 Train, predict, assessment
18.3.1 Train
Training a model can be done by calling the $train() method on a learner and supplying a task to it. Let’s first set up a task and a learner.
#=== define a task ===#
reg_task <-
  TaskRegr$new(
    id = "example",
    backend = mtcars,
    target = "mpg"
  )
#=== set up a learner ===#
learner <- lrn("regr.ranger")
learner$param_set$values <-
  list(
    "min.node.size" = 10,
    "mtry" = 5,
    "num.trees" = 500
  )
Notice that the model
attribute of the learner is empty at this point.
learner$model
NULL
Now, let’s train.
learner$train(reg_task)
We now have information about the trained model in the model attribute.
learner$model
Ranger result
Call:
ranger::ranger(dependent.variable.name = task$target_names, data = task$data(), case.weights = task$weights$weight, min.node.size = 10L, mtry = 5L, num.trees = 500L)
Type: Regression
Number of trees: 500
Sample size: 32
Number of independent variables: 10
Mtry: 5
Target node size: 10
Variable importance mode: none
Splitrule: variance
OOB prediction error (MSE): 6.369001
R squared (OOB): 0.8246618
Notice that this is exactly what you would get if you use ranger()
to train your model.
The $train() method has a row_ids argument, where you can specify which rows of the data backend in the task are used for training. Let’s extract the row_ids attribute and then split it into train and test sets.
#=== extract row ids ===#
row_ids <- reg_task$row_ids
#=== train ===#
train_ids <- row_ids[1:(length(row_ids) / 2)]
#=== test ===#
test_ids <- row_ids[!(row_ids %in% train_ids)]
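As an aside, in practice a random split is usually preferable to taking the first half of the rows. A sketch, using made-up names so the objects above are left untouched:

```r
#=== random 50/50 split (sketch; names are hypothetical) ===#
set.seed(123) # for reproducibility

# draw half of the row ids at random for training
train_ids_random <- sample(row_ids, size = floor(length(row_ids) / 2))

# the remaining ids form the test set
test_ids_random <- setdiff(row_ids, train_ids_random)
```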
Now train using the train data.
#=== train ===#
learner$train(reg_task, row_ids = train_ids)
#=== see the trained model ===#
learner$model
Ranger result
Call:
ranger::ranger(dependent.variable.name = task$target_names, data = task$data(), case.weights = task$weights$weight, min.node.size = 10L, mtry = 5L, num.trees = 500L)
Type: Regression
Number of trees: 500
Sample size: 16
Number of independent variables: 10
Mtry: 5
Target node size: 10
Variable importance mode: none
Splitrule: variance
OOB prediction error (MSE): 5.24416
R squared (OOB): 0.6951542
18.3.2 Prediction
We can use the $predict() method to make predictions by supplying a task to it. Just like the $train() method, we can use the row_ids argument inside predict() to apply the prediction only to a portion of the data. Let’s use the test_ids we created above.
prediction <- learner$predict(reg_task, row_ids = test_ids)
prediction
is of a class called Prediction
. You can make the prediction available as a data.table
by using as.data.table()
.
as.data.table(prediction)
row_ids truth response
1: 17 14.7 13.71786
2: 18 32.4 21.80367
3: 19 30.4 21.79294
4: 20 33.9 21.80987
5: 21 21.5 21.69345
6: 22 15.5 16.47427
7: 23 15.2 17.81916
8: 24 13.3 14.88102
9: 25 19.2 16.24036
10: 26 27.3 21.79751
11: 27 26.0 21.71119
12: 28 30.4 21.63942
13: 29 15.8 17.17378
14: 30 19.7 20.59941
15: 31 15.0 15.42186
16: 32 21.4 21.67216
You can predict on a new dataset by supplying it as a data.frame/data.table to the predict_newdata() method. Here, we just use parts of mtcars (just pretend this is a new dataset).
prediction <- learner$predict_newdata(mtcars)
18.3.3 Performance assessment
There are many measures of performance we can use under mlr3. Here is the list:
mlr_measures
<DictionaryMeasure> with 63 stored values
Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
classif.logloss, classif.mbrier, classif.mcc, classif.npv,
classif.ppv, classif.prauc, classif.precision, classif.recall,
classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
clust.silhouette, clust.wss, debug, oob_error, regr.bias, regr.ktau,
regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
selected_features, sim.jaccard, sim.phi, time_both, time_predict,
time_train
We can access a measure using the msr()
function.
#=== get a measure ===#
measure <- msr("regr.mse")
#=== check the class ===#
class(measure)
[1] "MeasureRegrSimple" "MeasureRegr" "Measure"
[4] "R6"
We can do performance evaluation using the $score()
method on a Prediction
object by supplying a Measure
object.
prediction$score(measure)
regr.mse
16.13093
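You can also evaluate several measures at once by building a list of Measure objects with msrs() and passing it to $score() (a sketch; the particular measure keys chosen here are for illustration):

```r
#=== evaluate several measures at once (sketch) ===#
# msrs() is the plural counterpart of msr() and returns a list of measures
prediction$score(msrs(c("regr.mse", "regr.rmse", "regr.mae")))
```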
18.4 Resampling, cross-validation, and cross-fitting
18.4.1 Resampling
mlr3
offers the following resampling methods:
as.data.table(mlr_resamplings)
key label params iters
1: bootstrap Bootstrap ratio,repeats 30
2: custom Custom Splits NA
3: custom_cv Custom Split Cross-Validation NA
4: cv Cross-Validation folds 10
5: holdout Holdout ratio 1
6: insample Insample Resampling 1
7: loo Leave-One-Out NA
8: repeated_cv Repeated Cross-Validation folds,repeats 100
9: subsampling Subsampling ratio,repeats 30
You can access a resampling method using the rsmp()
function. You can specify parameters at the same time.
(
resampling <- rsmp("repeated_cv", repeats = 2, folds = 3)
)
<ResamplingRepeatedCV>: Repeated Cross-Validation
* Iterations: 6
* Instantiated: FALSE
* Parameters: repeats=2, folds=3
You can check the number of iterations (the number of train-test dataset combinations) by accessing the iters attribute.
resampling$iters
[1] 6
You can override parameters just like you did for a learner.
#=== update ===#
resampling$param_set$values = list(repeats = 3, folds = 4)
#=== see the updates ===#
resampling
<ResamplingRepeatedCV>: Repeated Cross-Validation
* Iterations: 12
* Instantiated: FALSE
* Parameters: repeats=3, folds=4
We can use the instantiate()
method to implement the specified resampling method:
resampling$instantiate(reg_task)
You can access the train and test datasets using the train_set() and test_set() methods, respectively. Since repeats = 3 and folds = 4, we have 12 pairs of train and test datasets. You indicate which pair you want inside train_set() and test_set().
First pair:
resampling$train_set(1)
[1] 5 6 13 14 18 25 28 31 2 9 15 16 20 23 27 30 3 4 17 19 21 22 24 26
resampling$test_set(1)
[1] 1 7 8 10 11 12 29 32
Last pair:
resampling$train_set(12)
[1] 3 5 12 14 19 20 25 29 9 15 18 21 24 28 30 31 2 4 6 10 11 22 26 32
resampling$test_set(12)
[1] 1 7 8 13 16 17 23 27
18.4.2 Cross-validation and cross-fitting
Now that data splits are determined (along with a task and leaner), we can conduct a cross-validation using the resample()
function like below.
cv_results <- resample(reg_task, learner, resampling)
INFO [10:52:25.594] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/12)
INFO [10:52:25.616] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/12)
INFO [10:52:25.629] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 10/12)
INFO [10:52:25.638] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 6/12)
INFO [10:52:25.648] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 12/12)
INFO [10:52:25.658] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 5/12)
INFO [10:52:25.667] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 8/12)
INFO [10:52:25.676] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/12)
INFO [10:52:25.686] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 7/12)
INFO [10:52:25.695] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 11/12)
INFO [10:52:25.704] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 4/12)
INFO [10:52:25.714] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 9/12)
This code applies the method specified in learner to each of the 12 train datasets and evaluates the trained models on the corresponding 12 test datasets.
You can look at the prediction results using the predictions()
method.
cv_results$predictions()
[[1]]
<PredictionRegr> for 8 observations:
row_ids truth response
1 21.0 21.34871
7 14.3 14.71893
8 24.4 23.00590
---
12 16.4 16.02169
29 15.8 18.13388
32 21.4 23.81583
[[2]]
<PredictionRegr> for 8 observations:
row_ids truth response
5 18.7 14.83204
6 18.1 19.85426
13 17.3 15.03344
---
25 19.2 14.59602
28 30.4 24.80455
31 15.0 14.89332
[[3]]
<PredictionRegr> for 8 observations:
row_ids truth response
2 21.0 20.66545
9 22.8 22.73046
15 10.4 16.03584
---
23 15.2 17.66704
27 26.0 24.92584
30 19.7 19.87509
[[4]]
<PredictionRegr> for 8 observations:
row_ids truth response
3 22.8 27.84492
4 21.4 19.01035
17 14.7 13.06798
---
22 15.5 16.55219
24 13.3 15.67540
26 27.3 28.81684
[[5]]
<PredictionRegr> for 8 observations:
row_ids truth response
2 21.0 20.62481
3 22.8 25.88562
11 17.8 19.87309
---
23 15.2 17.86239
28 30.4 25.34370
32 21.4 23.44607
[[6]]
<PredictionRegr> for 8 observations:
row_ids truth response
4 21.4 18.79353
5 18.7 16.19981
10 19.2 18.58758
---
24 13.3 15.79908
26 27.3 29.61911
30 19.7 20.85556
[[7]]
<PredictionRegr> for 8 observations:
row_ids truth response
1 21.0 20.54647
7 14.3 14.97206
9 22.8 23.21974
---
18 32.4 27.26674
21 21.5 23.42506
27 26.0 24.08510
[[8]]
<PredictionRegr> for 8 observations:
row_ids truth response
6 18.1 19.57434
8 24.4 22.70844
16 10.4 14.07726
---
25 19.2 16.03038
29 15.8 16.94906
31 15.0 15.58641
[[9]]
<PredictionRegr> for 8 observations:
row_ids truth response
3 22.8 25.67939
5 18.7 15.04072
12 16.4 15.76059
---
20 33.9 27.56829
25 19.2 14.73703
29 15.8 16.71143
[[10]]
<PredictionRegr> for 8 observations:
row_ids truth response
9 22.8 23.56904
15 10.4 13.90063
18 32.4 27.14532
---
28 30.4 24.69515
30 19.7 20.17987
31 15.0 15.55283
[[11]]
<PredictionRegr> for 8 observations:
row_ids truth response
2 21.0 21.85924
4 21.4 20.37669
6 18.1 19.36490
---
22 15.5 16.19947
26 27.3 29.14830
32 21.4 24.31793
[[12]]
<PredictionRegr> for 8 observations:
row_ids truth response
1 21.0 20.60021
7 14.3 15.26024
8 24.4 22.88854
---
17 14.7 14.66568
23 15.2 17.57585
27 26.0 24.52656
You can get all the predictions combined using the prediction() method.
#=== all combined ===#
all_predictions <- cv_results$prediction()
#=== check the class ===#
class(all_predictions)
[1] "PredictionRegr" "Prediction" "R6"
Since it is a Prediction object, we can apply the $score() method like this.
all_predictions$score(msr("regr.mse"))
regr.mse
7.51057
Of course, cross-validation also gives you cross-fitted predictions. You can take the prediction results and use them directly if you are implementing DML, for example.
(
cross_fitted_yhat <-
  as.data.table(all_predictions) %>%
  .[, .(y_hat_cf = mean(response)), by = row_ids]
)
row_ids y_hat_cf
1: 1 20.83180
2: 7 14.98374
3: 8 22.86763
4: 10 20.01379
5: 11 20.45234
6: 12 15.81812
7: 29 17.26479
8: 32 23.85994
9: 5 15.35752
10: 6 19.59783
11: 13 15.45907
12: 14 15.70010
13: 18 27.18675
14: 25 15.12114
15: 28 24.94780
16: 31 15.34419
17: 2 21.04983
18: 9 23.17308
19: 15 14.85905
20: 16 14.75365
21: 20 28.18517
22: 23 17.70176
23: 27 24.51250
24: 30 20.30351
25: 3 26.46998
26: 4 19.39352
27: 17 13.58068
28: 19 28.22748
29: 21 24.24876
30: 22 16.28554
31: 24 15.64745
32: 26 29.19475
row_ids y_hat_cf
18.5 Hyper-parameter tuning
In conducting hyper-parameter tuning under the mlr3
framework, you define TuningInstance*
class, select the tuning method, and then trigger it.
18.5.1 TuningInstance
There are two tuning instance classes.
- TuningInstanceSingleCrit
- TuningInstanceMultiCrit
The difference should be clear by looking at the name of the classes. We focus on the TuningInstanceSingleCrit
class here.
A tuning instance consists of six elements:
- task
- learner
- resampling
- measure
- search_space
- terminator
We have covered all except search_space
and terminator
. Let’s look at these two.
Let’s quickly create the first four elements.
#=== task ===#
reg_task <-
  TaskRegr$new(
    id = "example",
    backend = mtcars,
    target = "mpg"
  )
#=== learner ===#
learner <- lrn("regr.ranger")
learner$param_set$values$max.depth <- 10
#=== resampling ===#
resampling <- rsmp("cv", folds = 3) # k-fold cv
#=== measure ===#
measure <- msr("regr.mse")
search_space defines which hyper-parameters to tune and over what ranges. You can use ps() to create a search space. Before doing so, let’s look at the parameters of the learner.
learner$param_set
<ParamSet>
id class lower upper nlevels default
1: alpha ParamDbl -Inf Inf Inf 0.5
2: always.split.variables ParamUty NA NA Inf <NoDefault[3]>
3: holdout ParamLgl NA NA 2 FALSE
4: importance ParamFct NA NA 4 <NoDefault[3]>
5: keep.inbag ParamLgl NA NA 2 FALSE
6: max.depth ParamInt 0 Inf Inf
7: min.node.size ParamInt 1 Inf Inf 5
8: min.prop ParamDbl -Inf Inf Inf 0.1
9: minprop ParamDbl -Inf Inf Inf 0.1
10: mtry ParamInt 1 Inf Inf <NoDefault[3]>
11: mtry.ratio ParamDbl 0 1 Inf <NoDefault[3]>
12: num.random.splits ParamInt 1 Inf Inf 1
13: num.threads ParamInt 1 Inf Inf 1
14: num.trees ParamInt 1 Inf Inf 500
15: oob.error ParamLgl NA NA 2 TRUE
16: quantreg ParamLgl NA NA 2 FALSE
17: regularization.factor ParamUty NA NA Inf 1
18: regularization.usedepth ParamLgl NA NA 2 FALSE
19: replace ParamLgl NA NA 2 TRUE
20: respect.unordered.factors ParamFct NA NA 3 ignore
21: sample.fraction ParamDbl 0 1 Inf <NoDefault[3]>
22: save.memory ParamLgl NA NA 2 FALSE
23: scale.permutation.importance ParamLgl NA NA 2 FALSE
24: se.method ParamFct NA NA 2 infjack
25: seed ParamInt -Inf Inf Inf
26: split.select.weights ParamUty NA NA Inf
27: splitrule ParamFct NA NA 3 variance
28: verbose ParamLgl NA NA 2 TRUE
29: write.forest ParamLgl NA NA 2 TRUE
id class lower upper nlevels default
parents value
1: splitrule
2:
3:
4:
5:
6: 10
7:
8:
9: splitrule
10:
11:
12: splitrule
13: 1
14:
15:
16:
17:
18:
19:
20:
21:
22:
23: importance
24:
25:
26:
27:
28:
29:
parents value
Let’s tune two parameters here: mtry.ratio and min.node.size. For each parameter to tune, we need to use an appropriate function to define its range. You can see which function to use from the class column.
learner$param_set %>%
  as.data.table() %>%
  .[id %in% c("mtry.ratio", "min.node.size"), .(id, class, lower, upper)]
id class lower upper
1: min.node.size ParamInt 1 Inf
2: mtry.ratio ParamDbl 0 1
In this case, we use p_int() for min.node.size since its class is ParamInt, and p_dbl() for mtry.ratio since its class is ParamDbl. You cannot specify a range that goes beyond the lower and upper bounds of each parameter.
search_space <-
  ps(
    mtry.ratio = p_dbl(lower = 0.5, upper = 0.9),
    min.node.size = p_int(lower = 1, upper = 20)
  )
Let’s define a terminator now. mlr3 offers five different options. We will just look at the most common one here, TerminatorEvals, which terminates tuning after a user-specified number of evaluations (see here for other terminator options).
You can use the trm()
function to define a Terminator object.
terminator <- trm("evals", n_evals = 100)
Inside trm()
, “evals” indicates that we would like to use the TerminatorEvals
option.
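If the number of evaluations is not the natural budget for your problem, the other terminators follow the same trm() pattern. Below is a small sketch; the keys "run_time" and "stagnation" and their parameters are taken from the bbotk documentation, so check the terminator list linked above for the options available in your installed version.

```r
library(mlr3verse)

#=== stop after a fixed wall-clock budget (in seconds) ===#
terminator_time <- trm("run_time", secs = 60)

#=== stop when the best score has not improved for 10 batches ===#
terminator_stag <- trm("stagnation", iters = 10, threshold = 1e-5)
```

Either object can be passed as the terminator argument exactly like the TerminatorEvals object used in this section.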
Now that we have all the components specified, we can instantiate (generate) an object of the TuningInstanceSingleCrit class.
(
tuning_instance <-
  TuningInstanceSingleCrit$new(
    task = reg_task,
    learner = learner,
    resampling = resampling,
    measure = measure,
    search_space = search_space,
    terminator = terminator
  )
)
<TuningInstanceSingleCrit>
* State: Not optimized
* Objective: <ObjectiveTuning:regr.ranger_on_example>
* Search Space:
id class lower upper nlevels
1: mtry.ratio ParamDbl 0.5 0.9 Inf
2: min.node.size ParamInt 1.0 20.0 20
* Terminator: <TerminatorEvals>
18.5.2 Tuner
Let’s now define the method of tuning. mlr3
offers four options:
- Grid Search (TunerGridSearch)
- Random Search (TunerRandomSearch)
- Generalized Simulated Annealing (TunerGenSA)
- Non-Linear Optimization (TunerNLoptr)
Here we use TunerGridSearch
. We can set up a tuner using the tnr()
function.
tuner <- tnr("grid_search", resolution = 4)
resolution = 4
means that each parameter takes four values where the values are equidistant between the upper and lower bounds specified in search_space
. So, this tuning will look at \(4^2 = 16\) parameter configurations. Notice that we set the number of evaluations to 100 above. So, all \(16\) cases will be evaluated. However, if you set n_evals
lower than \(16\), then the tuning will not look at all the cases.
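You can preview the exact grid the tuner will evaluate with paradox’s generate_design_grid() before running anything, which is a quick sanity check that the resolution gives you the configurations you expect. The search space is re-defined here so the snippet stands alone:

```r
library(mlr3verse) # attaches paradox

search_space <-
  ps(
    mtry.ratio = p_dbl(lower = 0.5, upper = 0.9),
    min.node.size = p_int(lower = 1, upper = 20)
  )

#=== 4 values per parameter: 4^2 = 16 configurations ===#
design <- generate_design_grid(search_space, resolution = 4)
nrow(design$data)
```

The returned object holds the configurations as a data.table in its data field, one row per configuration.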
18.5.3 Tuning
You can trigger tuning by supplying a tuning instance to the optimize() method of tuner.
tuner$optimize(tuning_instance)
Since the execution of this tuning prints so many lines of results, it is not presented here.
Once the tuning is done, you can get the optimized parameters by accessing the result_learner_param_vals attribute of the tuning instance.
tuning_instance$result_learner_param_vals
$num.threads
[1] 1
$max.depth
[1] 10
$mtry.ratio
[1] 0.5
$min.node.size
[1] 1
You can look at the evaluation results of other parameter configurations by accessing the archive
attribute of the tuning instance.
as.data.table(tuning_instance$archive) %>% head()
mtry.ratio min.node.size regr.mse x_domain_mtry.ratio
1: 0.5000000 20 10.095093 0.5000000
2: 0.9000000 14 8.107527 0.9000000
3: 0.7666667 14 7.884286 0.7666667
4: 0.9000000 1 5.375487 0.9000000
5: 0.9000000 7 5.606771 0.9000000
6: 0.6333333 1 5.150592 0.6333333
x_domain_min.node.size runtime_learners timestamp batch_nr
1: 20 0.022 2022-08-19 10:52:25 1
2: 14 0.019 2022-08-19 10:52:25 2
3: 14 0.022 2022-08-19 10:52:26 3
4: 1 0.035 2022-08-19 10:52:26 4
5: 7 0.024 2022-08-19 10:52:26 5
6: 1 0.032 2022-08-19 10:52:26 6
warnings errors resample_result
1: 0 0 <ResampleResult[22]>
2: 0 0 <ResampleResult[22]>
3: 0 0 <ResampleResult[22]>
4: 0 0 <ResampleResult[22]>
5: 0 0 <ResampleResult[22]>
6: 0 0 <ResampleResult[22]>
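The archive can also be inspected graphically with mlr3viz (attached by mlr3verse). The sketch below assumes a tuned tuning_instance from above and that your mlr3viz version supports these plot types:

```r
library(mlr3verse) # attaches mlr3viz

#=== performance against each tuned parameter value ===#
autoplot(tuning_instance, type = "marginal")

#=== performance over evaluation batches (tuning progress) ===#
autoplot(tuning_instance, type = "performance")
```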
18.5.4 AutoTuner
AutoTuner sounds like it is a tuner, but it is really a learner for which tuning is automatically implemented when training is triggered with a task. An AutoTuner object can be instantiated using the new() method on the AutoTuner class like below.
(
auto_tunning_learner <-
  AutoTuner$new(
    learner = learner,
    resampling = resampling,
    measure = measure,
    search_space = search_space,
    terminator = terminator,
    tuner = tuner
  )
)
<AutoTuner:regr.ranger.tuned>
* Model: -
* Search Space:
<ParamSet>
id class lower upper nlevels default value
1: mtry.ratio ParamDbl 0.5 0.9 Inf <NoDefault[3]>
2: min.node.size ParamInt 1.0 20.0 20 <NoDefault[3]>
* Packages: mlr3, mlr3tuning, mlr3learners, ranger
* Predict Type: response
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights
Note that unlike a TuningInstance, an AutoTuner does not take a task as one of its elements and takes a Tuner instead. Once an AutoTuner is instantiated, you can use it like a learner and invoke the train() method with a task to train a model. The difference from a regular learner is that it automatically tunes the parameters internally and uses the optimized parameter values for training.
auto_tunning_learner$train(reg_task)
INFO [10:52:26.915] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=100, k=0]'
INFO [10:52:26.916] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:26.925] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:26.928] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:26.938] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:26.947] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:26.956] [mlr3] Finished benchmark
INFO [10:52:26.966] [bbotk] Result of batch 1:
INFO [10:52:26.967] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:26.967] [bbotk] 0.5 14 8.853036 0 0 0.015
INFO [10:52:26.967] [bbotk] uhash
INFO [10:52:26.967] [bbotk] 92b99ee4-e7ba-4cee-b61a-5fe8bf3c6a6d
INFO [10:52:26.967] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:26.976] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:26.979] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:26.988] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:26.997] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.006] [mlr3] Finished benchmark
INFO [10:52:27.022] [bbotk] Result of batch 2:
INFO [10:52:27.023] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.023] [bbotk] 0.9 20 11.74412 0 0 0.018
INFO [10:52:27.023] [bbotk] uhash
INFO [10:52:27.023] [bbotk] 2686d082-d6cb-4430-8ef2-0c9d4f18048e
INFO [10:52:27.023] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.032] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.035] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.050] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.064] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.078] [mlr3] Finished benchmark
INFO [10:52:27.089] [bbotk] Result of batch 3:
INFO [10:52:27.090] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.090] [bbotk] 0.9 1 5.814905 0 0 0.034
INFO [10:52:27.090] [bbotk] uhash
INFO [10:52:27.090] [bbotk] 640b968f-e097-42f8-8094-49563ba2fe32
INFO [10:52:27.090] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.099] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.102] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.110] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.119] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.134] [mlr3] Finished benchmark
INFO [10:52:27.145] [bbotk] Result of batch 4:
INFO [10:52:27.146] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.146] [bbotk] 0.6333333 14 9.373763 0 0 0.019
INFO [10:52:27.146] [bbotk] uhash
INFO [10:52:27.146] [bbotk] abce4942-9d23-4ab6-ae96-b0158a64b60f
INFO [10:52:27.146] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.155] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.157] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.166] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.174] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.183] [mlr3] Finished benchmark
INFO [10:52:27.193] [bbotk] Result of batch 5:
INFO [10:52:27.194] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.194] [bbotk] 0.5 20 11.11123 0 0 0.016
INFO [10:52:27.194] [bbotk] uhash
INFO [10:52:27.194] [bbotk] 56524c5b-292b-4795-8b0b-7f5032c906c8
INFO [10:52:27.195] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.203] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.206] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.214] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.228] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.238] [mlr3] Finished benchmark
INFO [10:52:27.249] [bbotk] Result of batch 6:
INFO [10:52:27.249] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.249] [bbotk] 0.7666667 14 9.556742 0 0 0.019
INFO [10:52:27.249] [bbotk] uhash
INFO [10:52:27.249] [bbotk] 0985f273-95d0-4844-920e-837591ad4787
INFO [10:52:27.250] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.258] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.261] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.274] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.288] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.302] [mlr3] Finished benchmark
INFO [10:52:27.312] [bbotk] Result of batch 7:
INFO [10:52:27.313] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.313] [bbotk] 0.7666667 1 5.640281 0 0 0.03
INFO [10:52:27.313] [bbotk] uhash
INFO [10:52:27.313] [bbotk] 63c479e2-0ebf-43bf-b1f2-cccee6a03fe4
INFO [10:52:27.314] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.322] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.325] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.346] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.360] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.374] [mlr3] Finished benchmark
INFO [10:52:27.384] [bbotk] Result of batch 8:
INFO [10:52:27.385] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.385] [bbotk] 0.6333333 1 5.344656 0 0 0.031
INFO [10:52:27.385] [bbotk] uhash
INFO [10:52:27.385] [bbotk] e377f9f3-9222-449e-8bc1-aec5d6f9bea4
INFO [10:52:27.386] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.394] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.397] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.405] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.414] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.423] [mlr3] Finished benchmark
INFO [10:52:27.434] [bbotk] Result of batch 9:
INFO [10:52:27.434] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.434] [bbotk] 0.7666667 20 11.70993 0 0 0.017
INFO [10:52:27.434] [bbotk] uhash
INFO [10:52:27.434] [bbotk] 5ad30c70-2251-4700-bebe-135f5ad756b4
INFO [10:52:27.435] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.449] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.453] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.466] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.477] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.489] [mlr3] Finished benchmark
INFO [10:52:27.500] [bbotk] Result of batch 10:
INFO [10:52:27.501] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.501] [bbotk] 0.6333333 7 6.243863 0 0 0.022
INFO [10:52:27.501] [bbotk] uhash
INFO [10:52:27.501] [bbotk] efc59800-b9af-4500-bf2a-375273866b87
INFO [10:52:27.501] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.511] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.514] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.524] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.533] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.543] [mlr3] Finished benchmark
INFO [10:52:27.555] [bbotk] Result of batch 11:
INFO [10:52:27.556] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.556] [bbotk] 0.9 14 9.585749 0 0 0.019
INFO [10:52:27.556] [bbotk] uhash
INFO [10:52:27.556] [bbotk] 9c144243-5261-463e-a177-77b3bc34e7aa
INFO [10:52:27.556] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.573] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.576] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.586] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.595] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.605] [mlr3] Finished benchmark
INFO [10:52:27.617] [bbotk] Result of batch 12:
INFO [10:52:27.618] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.618] [bbotk] 0.6333333 20 11.30442 0 0 0.018
INFO [10:52:27.618] [bbotk] uhash
INFO [10:52:27.618] [bbotk] 1aed11b1-a85f-44ce-a675-8fb756b0de19
INFO [10:52:27.618] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.628] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.630] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.644] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.657] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.671] [mlr3] Finished benchmark
INFO [10:52:27.682] [bbotk] Result of batch 13:
INFO [10:52:27.683] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.683] [bbotk] 0.5 1 5.168695 0 0 0.029
INFO [10:52:27.683] [bbotk] uhash
INFO [10:52:27.683] [bbotk] ef6b5f2a-887f-4ce3-bccc-ed1caf5b8f39
INFO [10:52:27.684] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.700] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.703] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.714] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.724] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.735] [mlr3] Finished benchmark
INFO [10:52:27.745] [bbotk] Result of batch 14:
INFO [10:52:27.746] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.746] [bbotk] 0.7666667 7 6.306153 0 0 0.022
INFO [10:52:27.746] [bbotk] uhash
INFO [10:52:27.746] [bbotk] 52e7534f-a6fd-4b06-a5db-ec9f16149880
INFO [10:52:27.747] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.756] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.758] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.769] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.780] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.791] [mlr3] Finished benchmark
INFO [10:52:27.807] [bbotk] Result of batch 15:
INFO [10:52:27.808] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.808] [bbotk] 0.9 7 6.025976 0 0 0.023
INFO [10:52:27.808] [bbotk] uhash
INFO [10:52:27.808] [bbotk] 6657aa2d-e805-438d-81ff-de9b8fff72c6
INFO [10:52:27.808] [bbotk] Evaluating 1 configuration(s)
INFO [10:52:27.818] [mlr3] Running benchmark with 3 resampling iterations
INFO [10:52:27.821] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 1/3)
INFO [10:52:27.830] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 2/3)
INFO [10:52:27.840] [mlr3] Applying learner 'regr.ranger' on task 'example' (iter 3/3)
INFO [10:52:27.851] [mlr3] Finished benchmark
INFO [10:52:27.861] [bbotk] Result of batch 16:
INFO [10:52:27.862] [bbotk] mtry.ratio min.node.size regr.mse warnings errors runtime_learners
INFO [10:52:27.862] [bbotk] 0.5 7 6.040561 0 0 0.02
INFO [10:52:27.862] [bbotk] uhash
INFO [10:52:27.862] [bbotk] cddb6675-643b-4549-a1fc-af3cc26bf1cd
INFO [10:52:27.864] [bbotk] Finished optimizing after 16 evaluation(s)
INFO [10:52:27.864] [bbotk] Result:
INFO [10:52:27.865] [bbotk] mtry.ratio min.node.size learner_param_vals x_domain regr.mse
INFO [10:52:27.865] [bbotk] 0.5 1 <list[4]> <list[2]> 5.168695
You can retrieve the optimized learner from the model$learner attribute.
auto_tunning_learner$model$learner
<LearnerRegrRanger:regr.ranger>
* Model: ranger
* Parameters: num.threads=1, max.depth=10, mtry.ratio=0.5,
min.node.size=1
* Packages: mlr3, mlr3learners, ranger
* Predict Type: response
* Feature types: logical, integer, numeric, character, factor, ordered
* Properties: hotstart_backward, importance, oob_error, weights
You can look at the history of the tuning process like below:
auto_tunning_learner$model$tuning_instance
<TuningInstanceSingleCrit>
* State: Optimized
* Objective: <ObjectiveTuning:regr.ranger_on_example>
* Search Space:
id class lower upper nlevels
1: mtry.ratio ParamDbl 0.5 0.9 Inf
2: min.node.size ParamInt 1.0 20.0 20
* Terminator: <TerminatorEvals>
* Result:
mtry.ratio min.node.size regr.mse
1: 0.5 1 5.168695
* Archive:
mtry.ratio min.node.size regr.mse
1: 0.5000000 14 8.853036
2: 0.9000000 20 11.744119
3: 0.9000000 1 5.814905
4: 0.6333333 14 9.373763
5: 0.5000000 20 11.111230
6: 0.7666667 14 9.556742
7: 0.7666667 1 5.640281
8: 0.6333333 1 5.344656
9: 0.7666667 20 11.709931
10: 0.6333333 7 6.243863
11: 0.9000000 14 9.585749
12: 0.6333333 20 11.304420
13: 0.5000000 1 5.168695
14: 0.7666667 7 6.306153
15: 0.9000000 7 6.025976
16: 0.5000000 7 6.040561
You can of course use the predict()
method as well.
auto_tunning_learner$predict(reg_task)
<PredictionRegr> for 32 observations:
row_ids truth response
1 21.0 20.90324
2 21.0 20.90364
3 22.8 23.93983
---
30 19.7 19.82130
31 15.0 14.92800
32 21.4 21.84063
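As with any PredictionRegr object, you can score the predictions against a measure. Keep in mind that this is an in-sample fit (the model was trained on the same task), so the MSE below will be optimistic relative to the cross-validated scores seen during tuning:

```r
#=== score in-sample predictions with the MSE measure ===#
prediction <- auto_tunning_learner$predict(reg_task)
prediction$score(msr("regr.mse"))
```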
18.6 mlr3 in action
Here, an example usage of the mlr3 framework as part of a research process is presented. Suppose our goal is to estimate heterogeneous treatment effects with causal forest using the R grf package. The grf::causal_forest() function implements causal forest as an R-learner and uses random forest for its first-stage estimations by default. However, the function allows users to provide their own estimated values of \(y\) (dependent variable) and \(T\) (treatment). We will code the first stage using mlr3 and then use grf::causal_forest() to estimate heterogeneous treatment effects.
#=== load the Treatment dataset ===#
data("Treatment", package = "Ecdat")
#=== convert to a data.table ===#
(
data <- data.table(Treatment)
)
treat age educ ethn married re74 re75 re78 u74 u75
1: TRUE 37 11 black TRUE 0.0 0.0 9930.05 TRUE TRUE
2: TRUE 30 12 black FALSE 0.0 0.0 24909.50 TRUE TRUE
3: TRUE 27 11 black FALSE 0.0 0.0 7506.15 TRUE TRUE
4: TRUE 33 8 black FALSE 0.0 0.0 289.79 TRUE TRUE
5: TRUE 22 9 black FALSE 0.0 0.0 4056.49 TRUE TRUE
---
2671: FALSE 47 8 other TRUE 44667.4 33837.1 38568.70 FALSE FALSE
2672: FALSE 32 8 other TRUE 47022.4 67137.1 59109.10 FALSE FALSE
2673: FALSE 47 10 other TRUE 48198.0 47968.1 55710.30 FALSE FALSE
2674: FALSE 54 0 hispanic TRUE 49228.5 44221.0 20540.40 FALSE FALSE
2675: FALSE 40 8 other TRUE 50940.9 55500.0 53198.20 FALSE FALSE
The dependent variable is re78 (\(Y\)), which is real annual earnings in 1978 (after treatment). The treatment variable of interest is treat (\(T\)), which is TRUE if a person went through training and FALSE otherwise. The features included as potential drivers of heterogeneity in the impact of treat on re78 are age, educ, ethn, and married (\(X\)). Note that the focus of this section is just showcasing the use of mlr3, and no attention is paid to potential endogeneity problems.
18.6.1 First stage
We will use ranger() and xgboost() as learners. While ranger() accepts factor variables, xgboost() does not. So, we will one-hot-encode the data using mltools::one_hot() so that we can create a single Task and use it for both. We also turn treat into a factor so it is amenable to classification by the two learner functions.
(
data_trt <-
  mltools::one_hot(data) %>%
  .[, treat := factor(treat)]
)
Note that we need separate procedures for estimating \(E[Y|X]\) and \(E[T|X]\). The former is a regression and the latter is a classification task.
Let’s first work on estimating \(E[Y|X]\). First, we set up a task.
y_est_task <-
  TaskRegr$new(
    id = "estimate_y_on_x",
    backend = data_trt[, .(
      re78, age, educ, married,
      ethn_other, ethn_black, ethn_hispanic
    )],
    target = "re78"
  )
We consider two modeling approaches: random forest by ranger
and gradient boosted forest by xgboost
.
y_learner_ranger <- lrn("regr.ranger")
y_learner_xgboost <- lrn("regr.xgboost")
We will implement a K-fold cross-validation to select the better model with optimized hyper-parameter values. We do this by triggering Tuner classes defined separately for the two learners.
Let’s define the resampling method, measure, terminator, and tuner that will be shared by the two approaches.
resampling_y <- rsmp("cv", folds = 4)
measure_y <- msr("regr.mse")
terminator <- trm("evals", n_evals = 100)
tuner <- tnr("grid_search", resolution = 4)
Now, when we compare multiple models, we should use the same CV splits to have a fair comparison of their model performance. To ensure this, we need to use an instantiated Resampling object.
If you provide an un-instantiated Resampling object to an AutoTuner, it will instantiate the Resampling object internally, and two separate AutoTuners can end up with two distinct splits.
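Once instantiated, the row assignments are fixed and can be inspected directly, which makes it easy to confirm that every learner evaluated against resampling_y sees the identical splits:

```r
#=== row ids of the first training fold (fixed after instantiation) ===#
head(resampling_y$train_set(1))

#=== row ids of the first test fold ===#
head(resampling_y$test_set(1))
```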
#=== instantiate ===#
resampling_y$instantiate(y_est_task)

#=== confirm it is indeed instantiated ===#
resampling_y
<ResamplingCV>: Cross-Validation
* Iterations: 4
* Instantiated: TRUE
* Parameters: folds=4
Let’s define a search space for each of the learners.
search_space_ranger <-
  ps(
    mtry = p_int(lower = 1, upper = length(y_est_task$feature_names)),
    min.node.size = p_int(lower = 1, upper = 20)
  )

search_space_xgboost <-
  ps(
    nrounds = p_int(lower = 100, upper = 400),
    eta = p_dbl(lower = 0.01, upper = 1)
  )
We have all the ingredients to set up TuningInstances for the learners.
tuning_instance_ranger <-
  TuningInstanceSingleCrit$new(
    task = y_est_task,
    learner = y_learner_ranger,
    resampling = resampling_y,
    measure = measure_y,
    search_space = search_space_ranger,
    terminator = terminator
  )

tuning_instance_xgboost <-
  TuningInstanceSingleCrit$new(
    task = y_est_task,
    learner = y_learner_xgboost,
    resampling = resampling_y,
    measure = measure_y,
    search_space = search_space_xgboost,
    terminator = terminator
  )
Let’s tune them now (this can take a while).
We use the same tuner here, but you can use different tuning processes for the learners. For example, you can have resolution = 5 for tuning regr.xgboost.
#=== tune ranger ===#
tuner$optimize(tuning_instance_ranger)

#=== tune xgboost ===#
tuner$optimize(tuning_instance_xgboost)
Here are the MSEs from the two individually tuned learners.
tuning_instance_ranger$result_y
regr.mse
193034514
tuning_instance_xgboost$result_y
regr.mse
197112372
So, in this example, we go with regr.ranger. Let’s update our learner with the optimized hyper-parameter values.
(
y_learner_ranger$param_set$values <- tuning_instance_ranger$result_learner_param_vals
)
$num.threads
[1] 1
$mtry
[1] 2
$min.node.size
[1] 14
Now that we have decided on the model to use for predicting \(E[y|X]\), let’s implement cross-fitting.
cv_results_y <-
  resample(
    y_est_task,
    y_learner_ranger,
    rsmp("repeated_cv", repeats = 4, folds = 3)
  )
INFO [10:52:55.515] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 4/12)
INFO [10:52:55.643] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 11/12)
INFO [10:52:55.769] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 9/12)
INFO [10:52:55.893] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 12/12)
INFO [10:52:56.018] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 10/12)
INFO [10:52:56.144] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 3/12)
INFO [10:52:56.274] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 5/12)
INFO [10:52:56.404] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 8/12)
INFO [10:52:56.528] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 1/12)
INFO [10:52:56.653] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 7/12)
INFO [10:52:56.777] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 2/12)
INFO [10:52:56.900] [mlr3] Applying learner 'regr.ranger' on task 'estimate_y_on_x' (iter 6/12)
#=== all combined ===#
all_predictions_y <-
  cv_results_y$prediction() %>%
  as.data.table() %>%
  .[, .(y_hat = mean(response)), by = row_ids] %>%
  .[order(row_ids), ]
We basically follow the same process for estimating \(E[T|X]\). Let’s set up tuning processes for the learners.
t_est_task <-
  TaskClassif$new(
    id = "estimate_t_on_x",
    backend =
      data_trt[, .(
        treat, age, educ, married,
        ethn_other, ethn_black, ethn_hispanic
      )],
    target = "treat"
  )

t_learner_ranger <- lrn("classif.ranger")
t_learner_xgboost <- lrn("classif.xgboost")

resampling_t <- rsmp("cv", folds = 4)
resampling_t$instantiate(t_est_task)
measure_t <- msr("classif.ce")
terminator <- trm("evals", n_evals = 100)
tuner <- tnr("grid_search", resolution = 4)

tuning_instance_ranger <-
  TuningInstanceSingleCrit$new(
    task = t_est_task,
    learner = t_learner_ranger,
    resampling = resampling_t,
    measure = measure_t,
    search_space = search_space_ranger,
    terminator = terminator
  )

tuning_instance_xgboost <-
  TuningInstanceSingleCrit$new(
    task = t_est_task,
    learner = t_learner_xgboost,
    resampling = resampling_t,
    measure = measure_t,
    search_space = search_space_xgboost,
    terminator = terminator
  )
Here are the classification errors from the two individually tuned learners.
#=== tune ranger ===#
tuner$optimize(tuning_instance_ranger)
tuning_instance_ranger$result_y

#=== tune xgboost ===#
tuner$optimize(tuning_instance_xgboost)
tuning_instance_xgboost$result_y
So, we are picking the classif.ranger option here as well, as it has a lower classification error.
t_learner_ranger$param_set$values <- tuning_instance_ranger$result_learner_param_vals
Now that we have decided on the model to use for predicting \(E[T|X]\), let’s implement cross-fitting. Before cross-fitting, we need to tell t_learner_ranger to predict probabilities instead of class labels.
t_learner_ranger$predict_type <- "prob"

cv_results_t <-
  resample(
    t_est_task,
    t_learner_ranger,
    rsmp("repeated_cv", repeats = 4, folds = 3)
  )
INFO [10:53:16.482] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 2/12)
INFO [10:53:16.614] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 11/12)
INFO [10:53:16.742] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 10/12)
INFO [10:53:16.869] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 8/12)
INFO [10:53:16.998] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 12/12)
INFO [10:53:17.124] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 6/12)
INFO [10:53:17.251] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 7/12)
INFO [10:53:17.387] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 3/12)
INFO [10:53:17.516] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 5/12)
INFO [10:53:17.645] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 4/12)
INFO [10:53:17.772] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 9/12)
INFO [10:53:17.896] [mlr3] Applying learner 'classif.ranger' on task 'estimate_t_on_x' (iter 1/12)
#=== all combined ===#
all_predictions_t <-
  cv_results_t$prediction() %>%
  as.data.table() %>%
  .[, .(t_hat = mean(prob.TRUE)), by = row_ids] %>%
  .[order(row_ids), ]
We now use all_predictions_y
and all_predictions_t
for Y.hat
and W.hat
in grf::causal_forest()
.
grf::causal_forest(
  X = data_trt[, .(age, educ, married, ethn_other, ethn_black, ethn_hispanic)] %>% as.matrix(),
  Y = data_trt[, re78],
  W = data_trt[, fifelse(treat == TRUE, 1, 0)],
  Y.hat = all_predictions_y[, y_hat],
  W.hat = all_predictions_t[, t_hat]
)
GRF forest object of type causal_forest
Number of trees: 2000
Number of training samples: 2675
Variable importance:
1 2 3 4 5 6
0.506 0.302 0.116 0.023 0.044 0.009
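Once the forest is fitted, the usual grf post-estimation tools apply. A short sketch, assuming the fitted forest is stored in an object (here called cf) rather than printed directly as above:

```r
#=== store the fitted forest instead of printing it ===#
cf <-
  grf::causal_forest(
    X = data_trt[, .(age, educ, married, ethn_other, ethn_black, ethn_hispanic)] %>% as.matrix(),
    Y = data_trt[, re78],
    W = data_trt[, fifelse(treat == TRUE, 1, 0)],
    Y.hat = all_predictions_y[, y_hat],
    W.hat = all_predictions_t[, t_hat]
  )

#=== average treatment effect with a standard error ===#
grf::average_treatment_effect(cf)

#=== individual-level (out-of-bag) CATE estimates ===#
head(predict(cf)$predictions)
```

The individual-level predictions are what you would relate back to age, educ, ethn, and married to characterize the heterogeneity.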