1. Introduction
There are generally two approaches to causal inference in observational studies. The first is to draw population-level causal inference, which goes back at least to the 1970s [1]. The second is to draw conditional causal inference, which has received attention more recently [2,3]. An example of a population-level causal parameter is the average treatment effect (ATE); the corresponding quantity defined conditionally on the covariates is referred to as the conditional average treatment effect (CATE) [4,5,6,7,8,9,10]. The CATE is not an individual-level causal parameter. The latter is impossible to estimate accurately unless both potential outcomes are observed for each individual (under parallel worlds!), or W contains all the varying factors that make the causal relationship deterministic; neither is likely to hold in practice. That said, under certain assumptions, the counterfactual loss, i.e., the loss due to the absence of the counterfactual outcome, can be upper bounded [11]. The present article focuses on the estimation of the ATE, which does not require those assumptions.
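In standard potential-outcome notation, the two estimands are commonly defined as follows (the article's own display equations were typeset in the original; this is the standard form):

```latex
\tau = \mathbb{E}\big[Y(1) - Y(0)\big] \quad \text{(ATE)}, \qquad
\tau(w) = \mathbb{E}\big[Y(1) - Y(0) \mid W = w\big] \quad \text{(CATE)},
```

where Y(1) and Y(0) denote the potential outcomes under treatment and control, and W the covariates.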
In a number of attempts, researchers have utilized ML models for causal parameter estimation [12,13,14,15,16,17]. While the ultimate goal of an ML algorithm is to predict the outcome of interest as accurately as possible, this does not optimally serve the main purpose of causal parameter estimation. In fact, ML algorithms minimize some prediction loss containing the treatment or the observed outcome (and not the counterfactual outcome), without targeting any particular predictors such as confounding variables [18].
Including confounders in the estimation of the ATE in observational studies avoids potential selection bias [19]. In practice, however, we do not have a priori knowledge of the confounders, and an ML algorithm minimizes its loss function without discriminating between the input covariates. The algorithm can successfully learn the linear and non-linear relationships between the confounders and the treatment and outcome, but it might at the same time learn from potential Instrumental Variables (IVs) present in the data (variables that predict the treatment, but not the outcome). If there are strong confounders or IVs among the covariates, the predicted treatment probabilities (i.e., the propensity scores) can become extreme (near zero or one), which in turn can make the estimates unstable: while the bias may be reduced, the variance is elevated at the same time. Less complex models, on the other hand, may suffer from large bias (under-fitting) but yield more stable causal parameter estimates. This conflict between the necessary model complexity and the bias-variance tradeoff motivates developing ML algorithms for step 1 that strike a compromise between learning from confounders and IVs, entailing a balance between the bias and variance of the causal parameter estimator in step 2. In addition to a favorable bias-variance tradeoff, asymptotic normality of the causal effect estimator is desired for inferential statistics.
Chernozhukov et al. [16] investigated the asymptotic normality of orthogonal estimators of the ATE (including Augmented Inverse Probability Weighting (AIPW)) when two separate ML algorithms model the treatment and the outcome, an approach referred to as Double Machine Learning (DML). With the same objective, Farrell et al. [17] utilized two separate neural networks (which we refer to as the double NN, or dNN), without any regularization other than the use of Stochastic Gradient Descent (SGD) for model optimization. SGD does impose some regularization, but it is insufficient to control the complexity of NN algorithms when strong predictors exist in the data [20]. Rostami and Saarela [20] experimentally showed that when AIPW is utilized, dNN performs poorly. The normalization of AIPW helps control both the bias and the variance of the estimator. Further, they illustrated that imposing regularization on all of the parameters (without targeting a specific set of input features) reduces the bias, variance, and Mean Squared Error (MSE) of the ATE estimators to some extent. Simulations indicated that when dNN is used, with or without regularization, the normalized AIPW (nAIPW) outperforms AIPW. For a comprehensive literature review on doubly robust estimators (including AIPW), see Moosavi et al. [21].
The strategy of targeting a specific type of feature can be designed into NN architectures along with the necessary optimization and regularization techniques. Flexible NN structures, optimizations, and regularization techniques are easily programmed in deep learning platforms such as PyTorch.
Shi et al. [22] propose a neural network architecture, referred to as DragonNet, that jointly models the treatment and the outcome using a multi-tasking optimization technique. In the DragonNet architecture, the interactions of the treatment with non-linear transformations of the input variables are considered. Chernozhukov et al. [23] use the Riesz representer [16] as the minimizer of a stochastic loss, which provides an alternative to propensity score estimation and aims to prevent violations of the empirical consistency assumption [20]. Chernozhukov et al. [23] also model the Riesz representer and the outcome jointly through multi-tasking, and they call their method auto Double Machine Learning (Auto-DML). Chernozhukov et al. [24] optimized a regularized loss function to estimate weights directly, rather than estimating propensity scores and plugging them into the AIPW estimator. Chernozhukov et al. [25] proposed optimizing a minimax loss function for the same purpose. In this body of work, it remains unclear how to tune the hyperparameters of the chosen NN architecture for causal inference, especially for ATE estimation.
Other techniques of feature selection before propensity score estimation have been proposed in the literature [26]. However, hard thresholding might discard important information hidden in the features.
The objective of this research is to experimentally investigate how NN-type methods can be utilized for ATE estimation, and how their hyperparameters can be tuned to achieve the best bias-variance tradeoff for the ATE estimators, in the presence of strong IVs and confounders. The papers cited above do not consider this general scenario.
Our goal in this research is none of the following: 1. We do not aim to compare NNs with other ML algorithms to determine which outperforms the others; by the no-free-lunch theorem [27], there is no single algorithm that can learn all relationships sufficiently well, so some ML algorithms are expected to be better in some scenarios and other algorithms in other scenarios. 2. We do not aim to study different types of causal parameters. 3. We do not aim to study different estimators of the Average Treatment Effect. 4. We do not aim to study feature selection or other methods that can prevent IVs from feeding into the treatment model in the first inference step.
Throughout this research, we utilize nAIPW, as it outperforms the AIPW estimator in the presence of strong confounders and IVs [20]. To target the relevant inputs, we propose two methods. First, we employ a targeted regularization on top of the common regularization applied to all the network parameters. Second, we propose jointly modeling the treatment and the outcome in a Neural Network (jNN) architecture, where we place both the treatment and the outcome on the output layer of a multi-layer perceptron [28]. This architecture is appealing as it models the treatment and outcome simultaneously, which can potentially target the covariates that are predictive of both the treatment and the outcome (i.e., the confounders) and can mitigate or ignore the IVs' effects on the predictions. We will investigate whether both or either of these ideas improves the bias-variance tradeoff of the causal effect estimator as compared to a dNN model.
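For concreteness, a minimal sketch of the AIPW and normalized AIPW point estimates, assuming the first-step predictions are already available. The normalization shown is a common Hajek-type normalization of the inverse-probability weights; the article's exact estimator is given in its Equation (6) and may differ in detail.

```python
import numpy as np

def aipw(y, a, m1, m0, ps):
    """Augmented IPW estimate of the ATE from outcome predictions m1, m0
    (treated/untreated) and propensity scores ps."""
    return np.mean(m1 - m0
                   + a * (y - m1) / ps
                   - (1 - a) * (y - m0) / (1 - ps))

def naipw(y, a, m1, m0, ps):
    """Normalized (Hajek-type) AIPW: the inverse-probability weights in each
    augmentation term are divided by their sum, which stabilizes the estimate
    when some propensity scores are extreme. This is our sketch of the
    estimator the article denotes nAIPW."""
    w1 = a / ps                 # weights for the treated group
    w0 = (1 - a) / (1 - ps)     # weights for the untreated group
    return (np.mean(m1 - m0)
            + np.sum(w1 * (y - m1)) / np.sum(w1)
            - np.sum(w0 * (y - m0)) / np.sum(w0))
```

When the propensity scores are bounded away from zero and one, the two estimators nearly coincide; the normalization matters precisely in the near-positivity-violation settings studied here.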
The NN architecture that jointly models the treatment and the outcome is here referred to as jNN. The parameters (weights) are estimated by minimizing a regularized multi-task loss, the sum of the Cross-Entropy loss (for modeling the binary treatment) and the MSE loss (for modeling the continuous outcome) [29]. Multi-task learning can help target the predictors of both the treatment and the outcome, which are placed on the output layer, and it also helps resist over-fitting in the presence of many irrelevant inputs [30]. Other benefits of multi-task learning are listed in Section 2.2. Also, two types of regularization terms are used in order to dampen the instrumental variables and strong confounders.
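A minimal sketch of such a multi-task objective, assuming a single shared hidden layer and a simple unstructured penalty. The architecture depth, the targeted form of the article's regularizer, and all names here are our own illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def jnn_forward(X, params):
    """One shared ReLU layer feeding two heads: a treatment (propensity)
    head and an outcome head, mimicking the joint-modeling idea."""
    H = np.maximum(X @ params["W1"], 0.0)   # shared representation
    ps = sigmoid(H @ params["w_a"])         # propensity head (binary A)
    yhat = H @ params["w_y"]                # outcome head (continuous Y)
    return ps, yhat

def multitask_loss(y, a, yhat, ps, l1=0.0, params=None):
    """Cross-entropy (treatment) + MSE (outcome) + an optional L1-type
    penalty. The penalty form is an assumption; the article's targeted
    regularizer was typeset in the original."""
    eps = 1e-12
    ce = -np.mean(a * np.log(ps + eps) + (1 - a) * np.log(1 - ps + eps))
    mse = np.mean((y - yhat) ** 2)
    pen = 0.0
    if params is not None and l1 > 0:
        pen = l1 * sum(np.abs(v).sum() for v in params.values())
    return ce + mse + pen
```

Because both heads share the hidden representation, gradient updates from the outcome loss and the treatment loss both shape the same features, which is the mechanism by which confounders (predictive of both) can be favored over IVs.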
To show the effectiveness of jNN and dNN, a thorough simulation study is performed, and the methods are compared in terms of the number of confounders and IVs captured in each scenario, the prediction measures, and the bias and variance of the causal estimators. To investigate whether our network targets confounders rather than IVs, and also dampens the impact of strong confounders on the propensity scores, we calculate the bias-variance tradeoff of the causal estimators (i.e., the minimal MSE) utilizing the NN predictions: low bias means the model has learned from the confounders and the other covariate types relevant to the outcome, and low variance means the model has ignored the IVs and dampened the strong confounders in the treatment model. Further, a comparison between the methods is made on the Canadian Community Health Survey (CCHS) dataset, where the intervention/treatment is food security versus food insecurity and the outcome is individuals' body mass index (BMI).
The organization of this paper is as follows. In
Section 1.2 we define the problem setting and the causal parameter to be estimated. In
Section 2 we introduce the NN-type methods, their loss functions, and hyperparameters.
Section 3 provides a quick review of the ATE estimators. In
Section 4 our simulation scenarios are stated along with their results in
Section 4.2. The results of the application of our methods on a real dataset are presented in
Section 5. We conclude the paper in
Section 6 with some discussion on the results and future work.
1.1. Notation
Let the data be generated by a data generating process F, where each observation is a finite-dimensional vector comprising Y as the outcome, A as the treatment, and the covariates W, where we assume the treatment and the outcome are generated by some functions of the covariates. The covariates W comprise four subsets: the confounders, the instrumental variables, the y-predictors (covariates predictive of the outcome but independent of the treatment), and a set of given noise or irrelevant inputs (Figure 1).
P is the true observed data distribution. Estimation is carried out under a distribution whose marginal with respect to W is the empirical distribution, so that the relevant conditional expectations can be estimated. We denote the prediction function of the observed outcome given the covariates in the treated group, the corresponding prediction function in the untreated group, and the propensity score (the probability of treatment given the covariates). Throughout, the expectations are with respect to P. A hat on a population-level quantity indicates the corresponding finite-sample estimator, in which P is replaced by its empirical counterpart.
1.2. Problem Setup and Assumptions
The fundamental problem of causal inference states that individual-level causal effects cannot be exactly determined, since each person can experience only one value of the treatment A. Thus, it is customary to estimate a population-level causal parameter, in this research the Average Treatment Effect (ATE) (1).
For identifiability of the parameter, the following assumptions must hold. The first assumption is Conditional Independence (also called Ignorability or Unconfoundedness), stating that, given the confounders, the potential outcomes are independent of the treatment assignment. The second assumption is Positivity, which entails that the assignment to the treatment groups is not deterministic ([18], page 344). The third assumption is Consistency, which states that the observed outcomes equal their corresponding potential outcomes. Other modeling assumptions are also made, such as time order (i.e., the covariates W are measured before the treatment), IID subjects, and a linear causal effect.
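In standard notation, with propensity score $\pi(w) = P(A = 1 \mid W = w)$, these three assumptions are commonly written as:

```latex
\text{Unconfoundedness:} \quad (Y(0), Y(1)) \perp\!\!\!\perp A \mid W, \\
\text{Positivity:} \quad 0 < \pi(w) < 1 \ \text{ for all } w, \\
\text{Consistency:} \quad Y = A\,Y(1) + (1 - A)\,Y(0).
```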
4. Simulations
A simulation study (with 100 iterations) was performed to compare the prediction methods jNN and dNN by inserting their predictions into the nAIPW (causal) estimator (6). There are a total of 8 scenarios according to the size of the data (i.e., the number of subjects and the number of covariates) and the strengths of the confounding and instrumental-variable effects. We fixed the sample sizes to be
and
, with the number of covariates
and
, respectively. The four sets of covariates had the same sizes
and, independently of each other, were drawn from the Multivariate Normal (MVN) distribution as
, with
and
. Let
The models to generate the treatment assignment and the outcome were specified through functions that select 30% of the columns and apply the interactions and non-linear functions listed below (8). The strengths of the instrumental-variable and confounding effects were each varied between two levels.
The non-linearities for each pair of covariates were randomly selected from among a fixed set of interaction and non-linear functions.
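A simplified sketch of such a data generating process. The sizes, coefficient strengths, and variable names here are illustrative assumptions, and the interactions and non-linear transformations of (8) are omitted.

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative sizes; the article's actual n, p, and effect strengths
# were given in the original and are not reproduced here.
n, q = 1000, 8                      # q covariates per type
Xc  = rng.normal(size=(n, q))       # confounders: affect A and Y
Xiv = rng.normal(size=(n, q))       # instruments: affect A only
Xy  = rng.normal(size=(n, q))       # y-predictors: affect Y only
Xn  = rng.normal(size=(n, q))       # irrelevant noise covariates

gamma_c, gamma_iv = 1.0, 1.0        # hypothetical effect strengths
logit = gamma_c * Xc.sum(1) / np.sqrt(q) + gamma_iv * Xiv.sum(1) / np.sqrt(q)
A = rng.binomial(1, 1.0 / (1.0 + np.exp(-logit)))   # treatment assignment

beta = 1.0                          # true ATE in this sketch
Y = beta * A + Xc.sum(1) / np.sqrt(q) + Xy.sum(1) / np.sqrt(q) + rng.normal(size=n)

W = np.hstack([Xc, Xiv, Xn, Xy])    # full covariate matrix fed to the NNs
```

Raising `gamma_iv` relative to `gamma_c` reproduces the strong-IV regime in which the fitted propensity scores become extreme and the causal estimators destabilize.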
In order to find the best hyperparameter values for the NN architectures, we ran an initial series of simulations across all scenarios; the selected settings are presented here. The networks' activation function is the Rectified Linear Unit (ReLU), with 3 hidden layers as large as the input size (p), with regularization, a fixed batch size, and 200 epochs. The Adaptive Moment Estimation (Adam) optimizer of Kingma and Ba [39] with learning rate 0.01 and momentum 0.95 was used to estimate the network's parameters, including the causal parameter (ATE).
As in practice the RMSE and the covariate types are unknown, prediction measures of the outcome and the treatment should be used to choose the best model in a K-fold cross-validation. The outcome-model and treatment-model prediction measures each provide insight about their respective models, but in our framework both models should be satisfactory. To measure the goodness of the prediction models (jNN and dNN) for causal inference purposes, we therefore define and utilize a statistic that is a compromise (geometric average) between the outcome prediction measure and the treatment prediction measure, the latter based on the Somers' D index. This measure was not utilized in the optimization process (i.e., in training the neural networks); rather, it is introduced here to observe whether the compromise between the two prediction measures agrees with the models that capture more confounders than IVs.
4.1. Selected Covariate Types
In order to identify which types of covariates (confounders, IVs, y-predictors, and irrelevant covariates) the prediction methods have learned from, we calculate the association between the inputs and the predicted values (the predicted outcome and the propensity score), and, after sorting the inputs from largest to smallest association, we count the number of each type of covariate among the top 15 inputs. The association between two variables is estimated using the distance correlation statistic [40], whose zero values entail independence and whose non-zero values entail statistical dependence between the two variables.
4.2. Results
Figure 3, Figure 4, Figure 5, Figure 6, Figure 7 and Figure 8 present the overall comparison of different hyperparameter settings of the jNN and dNN architectures in terms of five different measures: (1) the average number of captured confounders/IVs/y-predictors; (2) the average Root Mean Square Error (RMSE) of the causal estimators; (3) the average outcome and treatment prediction measures and their mixture measure (9); (4) the bias; and (5) the MC standard deviation of nAIPW. Bootstrap confidence intervals for the bias, the standard deviation, and the RMSE are calculated to capture significant differences between the simulation scenarios. The x-axis includes 16 hyperparameter settings; as a general rule, models on the left are the most complex (less regularization and wider neural nets) and those on the right the least complex. Note that the targeted regularization is applied only to the treatment model.
Figure 3 and Figure 4 show how the complexity of both dNN and jNN (x-axis) impacts the number of captured covariate types (i.e., confounders/IVs/y-predictors; top graph), the RMSE (middle graph), and the prediction measures (bottom graph). In almost all the hyperparameter settings, especially when the targeted regularization is non-zero, the number of captured confounders is larger and the number of captured IVs is smaller for jNN as compared to dNN. This shows that the joint modeling has the benefit of focusing on the confounders rather than the IVs, especially in the large-data scenario.
The RMSE of jNN is larger than that of dNN for models with zero targeted regularization (the scenarios on the left). As the complexity of the treatment model decreases, the RMSE of both jNN and dNN declines. jNN outperforms dNN in almost all of the hyperparameter settings for one of the sample sizes, but does not show a clear pattern for the other. Further, the impact of the width of the architectures (H) changes with the regularization: wider architectures (relative to p, the number of covariates) with strong regularization outperform other combinations of these two hyperparameters. This observation is clearer for the smaller-sized data and for the dNN model. In the small-size scenarios, when the width is small, the outcome model is affected and its prediction measure is smaller; that is, there are not enough neurons (in the first or last layer) to provide accurate outcome predictions. In the best scenarios, the RMSE confidence intervals of the jNN model lie below those of dNN, illustrating a small preference for jNN over dNN in terms of RMSE. Comparing the three hyperparameters, the targeted regularization is the most effective, and zero values of this hyperparameter result in very large RMSEs for both dNN and jNN.
From Figure 3 and Figure 4, it is observed that both the jNN and dNN models have roughly the same outcome-model performance across the hyperparameter settings and for both data sizes. That is, the targeted regularization in jNN does not impact the performance of the outcome model. The treatment-model performance, on the other hand, declines with stronger targeted regularization, and is almost always smaller than or equal in dNN as compared to jNN. Further, stronger targeted regularization in the small-size data corresponds to smaller RMSE, but no such pattern is seen in the large-data scenario.
Overall, the trends favor the idea that more complex treatment models capture a larger number of IVs, have higher treatment prediction performance (but a smaller value of the compromise measure), and have larger RMSE. That is, more complex models are less favorable.
Figure 7 and Figure 8 illustrate the bias and standard deviation of the causal estimators. As expected, and as discussed in Section 1, the models that do not dampen the IVs suffer from large bias and standard deviation. The bias and standard deviation behave in opposite directions across the settings: settings that produce a larger standard deviation result in a smaller bias, and vice versa, except for the one setting that produces both the largest bias and the largest standard deviation. The fluctuations of the bias and variance across hyperparameter settings are larger for one sample size than for the other. For the small sample, the best scenario for jNN is the one in which both the bias and the standard deviation of jNN are small simultaneously; for the large sample, the best scenario for jNN behaves similarly. The best scenarios for dNN are slightly different for the small and the large samples.
5. Application: Food Insecurity and BMI
The Canadian Community Health Survey (CCHS) is a cross-sectional survey that collects data related to health status, health care utilization, and health determinants for the Canadian population in multiple cycles. The 2021 CCHS covers the population 12 years of age and over living in the ten provinces and the three territorial capitals. Excluded from the survey's coverage are persons living on reserves and other Aboriginal settlements in the provinces, and some other sub-populations that altogether represent less than 3% of the Canadian population aged 12 and over. Examples of modules asked in most cycles are general health, chronic conditions, and smoking and alcohol use. For the 2021 cycle, thematic content on food security, home care, sedentary behavior, and depression, among many others, was included. In addition to the health component of the survey, there are questions about respondent characteristics such as labor market activities, income, and socio-demographics.
In this article, we use the CCHS dataset to investigate the causal relationship between food insecurity and body mass index (BMI). Other information gathered in the CCHS is used, as it might contain potential confounders, y-predictors, and instrumental variables. The data come from a survey and require special methods, such as resampling or bootstrap methods, to estimate the standard errors. Here, however, we use the data to illustrate the utilization of jNN and dNN with different hyperparameter choices in the presence of possible empirical positivity violations. In order to reduce the amount of variability in the data, we have focused on the sub-population 18–65 years of age.
Figure 5 and Figure 6 present the ATE estimates and their 95% asymptotic confidence intervals for the nIPW, AIPW, and nAIPW methods.
Figure 5 contains the hyperparameter settings with no targeted regularization, and it shows how important this regularization technique is, especially for the AIPW estimator, which has no normalization. We have removed these scenarios in Figure 6 for a clearer comparison among the remaining scenarios. Across those settings, the estimates and 95% CIs seem similar, with little difference between those of AIPW and nAIPW. This suggests that for this dataset, once the targeted regularization is applied, normalization might not be needed, as the propensity scores are not extreme and AIPW does not blow up.
6. Discussion
In this paper, we have studied how the hyperparameters of the first-step Neural Network predictions can affect the Average Treatment Effect (ATE) estimator. We have considered a general Data Generating Process (DGP) in which four types of covariates exist in the dataset: confounders, IVs, y-predictors, and irrelevant covariates. Two general NN architectures have been studied, jNN and dNN, where in the former both the outcome and the treatment are modeled jointly (with an appropriate loss function), and in the latter they are modeled separately. We have observed that regularization, especially the type that targets the treatment model, is an effective hyperparameter for achieving a better bias-variance tradeoff for the normalized Augmented Inverse Probability Weighting (nAIPW) estimator, and that the number of neurons in the first and last layers of the network becomes irrelevant as long as the targeted regularization is sufficiently strong. Further, we have observed that in the hyperparameter settings where the IV effects are controlled, the estimation is less biased and more stable. Thus, the targeted regularization is successful in dampening the IVs and preventing perfect prediction in the treatment model.
Figure 3, Figure 4, Figure 5, Figure 6, Figure 7 and Figure 8 illustrate that jNN is overall more stable and has a smaller RMSE in the small-sample scenario as compared to dNN. We utilized nAIPW in our simulations as it outperforms, or at least does not underperform, other estimators such as IPW, nIPW, AIPW, and SR. The nAIPW estimator has a normalization factor in the denominator which can dilute the impact of extreme predictions from the propensity score model and protect the estimator against positivity assumption violations (Van der Laan and Rose [18]).
We utilized a geometric-type average of the outcome and treatment prediction measures to choose among the first-step models. As the objective of the first-step optimization is to increase prediction performance, which is not necessarily aligned with the causal inference objectives, the usage of either measure, or of their geometric average, is sub-optimal. In a future study, alternative approaches will be explored and compared with this prediction measure.
A real strength of NNs would be to uncover hidden information (and thus confounder effects) in unstructured data such as text or image data. However, in this article, we have not studied the presence of unstructured data and it is left for future research.
There are limitations due to the assumptions and simulation scenarios, and thus some questions are left for future studies. For example, the outcome here was assumed to be continuous and the treatment binary. We also did not cover heavy-tailed outcomes or rare-treatment scenarios. Also, the ratio of the dimension to the size of the data was considered to be fairly small, and we have not studied the high-dimensional case. Furthermore, we did not study the asymptotic behavior of nAIPW when jNN or dNN predictions are used.
A limitation of jNN as compared to dNN is that if one needs to shrink the final hidden layer to control the complexity of the treatment model, then by construction the complexity of the outcome model is also limited, which might not be necessary. This might be resolved by another architectural design, which is left to future studies on the subject.
The usage of another regularization technique that controls the extremeness of the propensity score values is a plausible approach. For example, a data-dependent term can be added to the loss function. Such a term discourages the network from producing values extremely close to zero or one, as opposed to the negative log-likelihood term, which encourages such tendencies. This approach might also place less weight on the inputs that cause extreme values, such as strong confounders or IVs. Examination of this approach is left to future studies.
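The article does not specify the form of such a term; one hypothetical choice is an overlap-encouraging penalty that grows as the fitted propensity scores approach 0 or 1:

```python
import numpy as np

def extremeness_penalty(ps, eps=1e-12):
    """Hypothetical data-dependent penalty: large when propensity scores
    approach 0 or 1, encouraging overlap. This is our illustration, not
    the article's specification."""
    ps = np.clip(ps, eps, 1.0 - eps)
    return -np.mean(np.log(ps * (1.0 - ps)))
```

Added to the treatment loss with a tuning weight, this term pulls the fitted scores toward the interior of (0, 1), working against the log-likelihood's incentive to separate the treatment groups perfectly.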
In the design of the optimization, we did not consider formal early stopping as a regularization technique. However, in preliminary explorations, our simulations performed better with fewer iterations (in fact, epochs). In modern NNs, researchers usually run the algorithms for many iterations, but that is partly due to the dropout regularization technique. We did not use dropout (or L2) regularization in the final simulations, as the preliminary results did not show dropout to be as promising as the regularization we used.
Further, we utilized NNs to learn the underlying relationships between the covariates and the outcome and treatment by targeting the relevant features through regularization and the joint modeling of the treatment and outcome. NNs with other structures that might target confounders have not been explored, nor have other machine learning algorithms such as tree-based models. The Gradient Boosting Machines (GBM) algorithm of Friedman [41] could alternatively be used to learn these non-linear relationships while targeting the right set of features. This is postponed to a future article.