Skip to content

apoorvalal/dmlUtils

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

5 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

dmlUtils: Utilities for DoubleML

Utility functions to quickly fit DML models with cross-fitting and make tables with estimates and RMSEs.

# Setup ----
# NOTE(review): dropped `rm(list = ls())` — it wipes the user's workspace
# without restarting R (loaded packages, options, etc. survive); restart the
# session for a clean state instead.
set.seed(42)  # reproducibility for the cross-fitting fold splits
library(pacman)
p_load(knitr, DoubleML, mlr3, mlr3learners)

# this library
library(dmlUtils)

Data Prep

# %% data prep
data(lalonde.exp)  # experimental NSW sample — TODO confirm providing package is attached
# Flexible specification: quartic polynomials in the continuous covariates,
# plus all two-way interactions of the covariate set.
# Wrap in as.formula() explicitly rather than relying on model.frame()
# coercing a character string.
formula_flex = as.formula(
  "re78 ~ treat +
   (poly(age, 4, raw = TRUE) + poly(education, 4, raw = TRUE) +
    poly(re74, 4, raw = TRUE) + poly(re75, 4, raw = TRUE) +
    black + hispanic + married + nodegree + u74 + u75)^2"
)
model_flex = as.data.table(model.frame(formula_flex, lalonde.exp))
# %%
# everything except the outcome (col 1) and treatment (col 2) is a control
x_cols = colnames(model_flex)[-c(1, 2)]
data_ml = DoubleMLData$new(model_flex, y_col = "re78", d_cols = "treat",
                           x_cols = x_cols)
          
# alternative setup with interactions
# alternative setup with interactions
data(lalonde.exp)
# %%
y = 're78'
w = 'treat'
# Cubic polynomials in the continuous covariates, all two-way interactions,
# no intercept. Use TRUE rather than the reassignable shorthand T.
fx = "~ -1 + (poly(age, 3, raw = TRUE) + poly(education, 3, raw = TRUE) +
              poly(re74, 3, raw = TRUE) + poly(re75, 3, raw = TRUE) +
              black + hispanic + married + nodegree + u74 + u75
             )^2"
X = model.matrix(as.formula(fx), data = lalonde.exp)
# drop constant (zero-variance) columns. Logical indexing is safe when no
# column is constant; the previous `X[, -which(... == 0)]` form returned a
# ZERO-column matrix in that case, because `-integer(0)` selects nothing.
X = X[, apply(X, 2, var) != 0, drop = FALSE]
# stack outcome, treatment, and design matrix
# NOTE(review): clean_names() comes from janitor and %>% from magrittr —
# neither is loaded above; confirm they are attached (e.g. via p_load).
d = data.table(y = lalonde.exp[[y]], w = lalonde.exp[[w]], X) %>% clean_names()
# init object
data_ml = DoubleMLData$new(d, y_col = 'y', d_cols = 'w',
          x_cols = setdiff(colnames(d), c('y', 'w')))

The `DoubleMLData` object initialised above has the following structure:

# Classes 'DoubleMLData', 'R6' <DoubleMLData>
#   Public:
#     all_variables: active binding
#     clone: function (deep = FALSE)
#     d_cols: active binding
#     data: active binding
#     data_model: active binding
#     initialize: function (data = NULL, x_cols = NULL, y_col = NULL, d_cols = NULL,
#     n_instr: active binding
#     n_obs: active binding
#     n_treat: active binding
#     other_treat_cols: active binding
#     print: function ()
#     set_data_model: function (treatment_var)
#     treat_col: active binding
#     use_other_treat_as_covariate: active binding
#     x_cols: active binding
#     y_col: active binding
#     z_cols: active binding
#   Private:
#     check_disjoint_sets: function ()
#     d_cols_: treat
#     data_: data.table, data.frame
#     data_model_: data.table, data.frame
#     other_treat_cols_: NULL
#     treat_col_: treat
#     use_other_treat_as_covariate_: TRUE
#     x_cols_: poly.age..4..raw...TRUE..1 poly.age..4..raw...TRUE..2 po ...
#     y_col_: re78
#     z_cols_: NULL
#

Initialise mlr3 learners

# %% learners
lgr::get_logger("mlr3")$set_threshold("warn")  # suppress info-level mlr3 logging

# Small local constructor: build a learner and parallelise it.
# set_threads() mutates the R6 learner in place, so returning it is safe.
make_lrn = function(key, ...) {
  learner = lrn(key, ...)
  set_threads(learner)
  learner
}

# regression / classification learner pairs
lasso       = make_lrn("regr.cv_glmnet",    nfolds = 5, s = "lambda.min")
lasso_class = make_lrn("classif.cv_glmnet", nfolds = 5, s = "lambda.min")
rf          = make_lrn("regr.ranger")
rf_class    = make_lrn("classif.ranger")
trees       = make_lrn("regr.rpart")
trees_class = make_lrn("classif.rpart")
boost       = make_lrn("regr.glmboost")
boost_class = make_lrn("classif.glmboost")
# <LearnerRegrCVGlmnet:regr.cv_glmnet>
# * Model: -
# * Parameters: family=gaussian, nfolds=5, s=lambda.min
# * Packages: mlr3, mlr3learners, glmnet
# * Predict Type: response
# * Feature types: logical, integer, numeric
# * Properties: selected_features, weights

# <LearnerClassifCVGlmnet:classif.cv_glmnet>
# * Model: -
# * Parameters: nfolds=5, s=lambda.min
# * Packages: mlr3, mlr3learners, glmnet
# * Predict Type: response
# * Feature types: logical, integer, numeric
# * Properties: multiclass, selected_features, twoclass, weights
<truncated> ...

Fit models

# partially linear model (PLR): Y = theta*D + g(X) + e, cross-fit with each
# learner pair (regression learner for Y|X, classifier for D|X)
lassoPLR = plrFit(data_ml, lasso, lasso_class)
rforsPLR = plrFit(data_ml, rf,    rf_class)
treesPLR = plrFit(data_ml, trees, trees_class)
boostPLR = plrFit(data_ml, boost, boost_class)

# fully nonparametric (interactive regression model, IRM) — same learner
# pairs; presumably targets the ATE — confirm against dmlUtils::irmFit docs
lassoIRM = irmFit(data_ml, lasso, lasso_class)
rforsIRM = irmFit(data_ml, rf,    rf_class)
treesIRM = irmFit(data_ml, trees, trees_class)
boostIRM = irmFit(data_ml, boost, boost_class)

Pass the model fits to the table utility, which reports point estimates, standard errors, and nuisance-model RMSEs:

# tabulate all eight fits: point estimates, SEs, and nuisance-model RMSEs
dmlTab(lassoPLR, rforsPLR, treesPLR, boostPLR, lassoIRM, rforsIRM, treesIRM, boostIRM)
LASSO RF CART BOOST LASSO RF CART BOOST
Estimate 1889.1452 1869.2201 1805.9770 1709.8805 1587.8331 1315.4074 2140.5224 1790.9077
SE 684.0436 634.5607 649.6598 667.7938 668.2180 806.9806 2220.9821 713.1665
RMSE: Y 6772.3834 7180.0214 6751.6879 6529.5300 6573.6237 6902.0071 6922.6472 6777.6399
RMSE: D 0.4910 0.5202 0.5173 0.4980 0.4947 0.5144 0.5012 0.4943
CE: D 0.4135 0.4472 0.4247 0.3978 0.4292 0.4562 0.4022 0.4270

About

utils to fit DoubleML models

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages