Causal Machine Learning: What Can We Accomplish with a Single Theorem?
Causal inference, and specifically causal machine learning, is an indispensable tool that can help us make decisions by understanding cause and effect. Optimizing prices, reducing customer churn, running targeted ad campaigns, and deciding which patients would benefit most from medical treatment are all example use cases for causal machine learning.
There are many techniques for causal machine learning problems, but the technique that seems to stand out most is known as Double Machine Learning (DML) or Debiased/Orthogonal Machine Learning. Beyond the empirical success of DML, this technique stands out because of its rich theoretical backing rooted in a simple theorem from econometrics.
In this article, we'll unpack the theorem that grounds DML through hands-on examples. We'll discuss the intuition for DML and empirically verify its generality on increasingly complex examples. This article is not a tutorial on DML; instead, it serves as motivation for how DML models see past mere correlation to understand and predict cause and effect.
A (Very) Brief Primer On Causal Inference
Causal inference is all about measuring the effect of a treatment (T) on an outcome (Y). Examples include measuring the effect of exercise on weight loss, marketing on customer conversion, price on sales, or a medical intervention on a health outcome.
When T is randomly assigned to observations, such as in randomized controlled trials (RCTs), we can directly estimate the causal relationship between T and Y, at least in aggregate, by analyzing how Y varies with T. In other words, if T is randomly assigned, we don't need any other information about our observations to estimate the aggregate effect of T on Y. In practice, we estimate this effect with techniques like linear regression, where the coefficient on T, say θ1, tells us the average change in Y for a one-unit increase in T:
$$Y = \theta_0 + \theta_1 T + \epsilon$$
If we run this regression, and T is randomly assigned, then we call θ1 the average treatment effect (ATE). In general, for randomly assigned treatments, the ATE is defined as the expected change in Y when we change T:
$$\text{ATE} = E\left[\frac{\partial Y}{\partial T}\right]$$
The problem is that most data we encounter in the real world is observational, and treatments aren't randomly assigned to observations. In these circumstances, we can't directly estimate the effect of T on Y because of confounding variables. A confounding variable causes both the treatment and the outcome, making it difficult to discern whether changes in the outcome are coming from changes in the treatment or the confounder.
For example, suppose we're interested in understanding the effect of different diets on life expectancy. By analyzing how life expectancy varies across different diets, we notice that people who follow vegan diets tend to have higher life expectancies than people who don't. The issue is that people who follow vegan diets tend to be more health-conscious overall. This means they might be more likely to exercise, go to the doctor, or take fewer risks.
The overall health consciousness of vegans is a confounder in the relationship between their diet and life expectancy. In other words, unless we account for these health-consciousness variables, it's hard to know whether a vegan's increased life expectancy is due to their diet or other lifestyle factors.
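A quick simulation makes the cost of confounding concrete. The NumPy sketch below is illustrative only: the confounder c (think of a health-consciousness score), the true treatment effect of 2, the confounder weight of 5, and the sample size are made-up assumptions, and outcome and ols are helpers written for this example. With a randomly assigned treatment, the simple regression slope lands near 2. When c also pushes the treatment up, the simple slope absorbs part of the confounder's effect (its population value in this setup is Cov(t, y)/Var(t) = 9/2 = 4.5), and adding c to the regression brings the estimate back to about 2.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 50_000
TRUE_EFFECT = 2.0

# Hypothetical confounder, e.g. a "health-consciousness" score
c = rng.normal(size=n)


def outcome(t):
    """Outcome depends on the treatment t and on the confounder c."""
    return TRUE_EFFECT * t + 5.0 * c + rng.normal(size=n)


def ols(y, *columns):
    """Least-squares coefficients of y ~ 1 + columns (intercept first)."""
    design = np.column_stack([np.ones(len(y)), *columns])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef


# Randomized treatment: assignment ignores c
t_random = rng.normal(size=n)
y_random = outcome(t_random)

# Confounded treatment: a larger c tends to produce a larger t
t_conf = c + rng.normal(size=n)
y_conf = outcome(t_conf)

naive_random = ols(y_random, t_random)[1]
naive_conf = ols(y_conf, t_conf)[1]
adjusted_conf = ols(y_conf, t_conf, c)[1]

print(f"randomized, y ~ t:     {naive_random:.3f}")
print(f"confounded, y ~ t:     {naive_conf:.3f}")
print(f"confounded, y ~ t + c: {adjusted_conf:.3f}")
```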
Moreover, even if treatments are randomly assigned, implying there's no confounding, it's possible for the effect of the treatment on the outcome to vary depending on an individual's characteristics. We call these characteristics background variables.
For instance, suppose we conduct a marketing campaign in an attempt to increase sales for a new product. It's possible, and likely, that different people will respond differently to our marketing. Factors like age, socioeconomic status, sex, and previous product interests can all impact whether an individual responds to our marketing campaign. If we conduct an experiment where we randomly market to half the population and don't market to the other half, none of these factors are confounders. Nonetheless, they impact how our treatment (marketing) affects the outcome (new product sales).
We refer to confounding and background variables together as nuisance parameters, and we'll represent them with a matrix X. We call these nuisance parameters because we're only interested in the effect of T on Y, but we can't accurately estimate this effect without accounting for X.
In causal machine learning, we're often more interested in estimating a quantity known as the conditional average treatment effect (CATE) rather than the ATE. The CATE is generally defined as the expected change in Y when we change T for a specific set of nuisance parameters x:
$$\text{CATE}(x) = E\left[\frac{\partial Y}{\partial T} \,\middle|\, X = x\right]$$
If we assume Y follows a linear relationship with T and X, we can estimate the CATE by including interaction terms in the regression equation:
$$Y = \theta_0 + \theta_1 T + \theta_2 X + \theta_3 (T \cdot X) + \epsilon$$
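With this formulation, we can differentiate Y with respect to T to derive a formula for the CATE:
$$\text{CATE}(x) = E\left[\frac{\partial Y}{\partial T} \,\middle|\, X = x\right] = \theta_1 + \theta_3 x$$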
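As a sanity check of this formula, here is a minimal NumPy sketch on simulated data. The values θ1 = 1 and θ3 = 0.5, the single background variable x, and the randomized treatment are assumptions made for illustration, and cate is a helper defined here rather than a library function. After fitting the regression with the T·X interaction, the CATE at any x is read off as θ1 + θ3·x.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000

# Hypothetical data: randomized t, one background variable x, and a
# treatment effect that grows with x (theta1 = 1.0, theta3 = 0.5)
x = rng.normal(size=n)
t = rng.normal(size=n)
y = 3.0 + 1.0 * t + 2.0 * x + 0.5 * t * x + rng.normal(size=n)

# Regress y on [1, t, x, t*x]
design = np.column_stack([np.ones(n), t, x, t * x])
theta0, theta1, theta2, theta3 = np.linalg.lstsq(design, y, rcond=None)[0]


def cate(x_value):
    """dY/dT under the interaction model: theta1 + theta3 * x."""
    return theta1 + theta3 * x_value


for x_value in (-1.0, 0.0, 1.0):
    print(f"CATE at x = {x_value:+.1f}: {cate(x_value):.3f}")
```

With this seed, the three printed estimates should sit close to the true values 0.5, 1.0, and 1.5.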
Beyond linear regression, there are many more advanced machine learning techniques that aim to estimate the CATE. These are often referred to as meta learners, with double machine learning (the R-learner) being one of them.
Lastly, if X is a set of variables that, when controlled for, allows for the estimation of the causal effect of T by removing confounding bias, then we call X a sufficient adjustment set. We must record as many nuisance parameters as possible when doing causal inference. This is essential for us to remove confounding bias and to get accurate CATE estimates. In particular, if we leave confounders out of X, we risk being misled by our estimates.
There's much more to cover in causal inference and machine learning, but what we've covered so far is enough for this article. In the next section, we'll take a deeper dive into the nature of linear regression by exploring a simple yet profound theorem from econometrics. This will give us an intuition as to why linear regression is the cornerstone of causal inference, and it will help us build the foundations for more sophisticated techniques.
Frisch-Waugh-Lovell (FWL) Theorem
The Frisch-Waugh-Lovell (FWL) theorem is a result from econometrics that gives us profound insight into the nature of linear regression and motivates why linear regression works so well for causal inference. While we'll view FWL theorem through the lens of causal inference in this article, keep in mind that this theorem describes a general property of linear regression.
Suppose we want to understand the effect of a treatment (T) on an outcome (Y) while accounting for nuisance parameters (X). We assume the relationship Y = f(T, X) can be estimated with a linear equation:
$$Y = \theta_1 T + \theta_2 X + \epsilon$$
As is customary, we run an ordinary least squares (OLS) regression to obtain the coefficients θ1 and θ2. From a causal inference perspective, we're interested in θ1 because it tells us the effect of T on Y while accounting for X. But what does it mean to take X into account? FWL theorem gives us an answer.
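FWL theorem states that we can obtain the same θ1 coefficient in two steps. First, regress Y on X and T on X separately:
$$T = \beta_T X + \epsilon_T \quad\Rightarrow\quad T_{lr} = \hat{\beta}_T X$$
$$Y = \beta_Y X + \epsilon_Y \quad\Rightarrow\quad Y_{lr} = \hat{\beta}_Y X$$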
Here, Tlr and Ylr are the predicted values of T and Y, respectively, from a linear regression on X. In the second step, FWL theorem states that we can recover θ1 by performing the following regression:
$$Y - Y_{lr} = \theta_1 (T - T_{lr}) + \epsilon$$
This tells us that regressing the residuals of Y on the residuals of T yields the same θ1 that we get from the full regression of Y on T and X.
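FWL theorem is easy to verify numerically. The sketch below uses NumPy only and made-up data in which two confounders drive both T and Y (the coefficients are arbitrary assumptions), and fit_predict is a helper written for this example. It fits the full regression and the two-step residual regression and compares the coefficients on T, which should agree up to floating-point error.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 1_000

# Hypothetical confounded data: both columns of X drive T and Y
X = rng.normal(size=(n, 2))
T = X @ np.array([1.5, -0.7]) + rng.normal(size=n)
Y = 2.0 * T + X @ np.array([-3.0, 4.0]) + rng.normal(size=n)


def fit_predict(features, target):
    """Fitted values from an OLS regression of target on [1, features]."""
    design = np.column_stack([np.ones(len(target)), features])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return design @ coef


# For comparison: the full regression Y ~ 1 + T + X
full_design = np.column_stack([np.ones(n), T, X])
theta1_full = np.linalg.lstsq(full_design, Y, rcond=None)[0][1]

# Step 1: residualize T and Y on X
T_resid = T - fit_predict(X, T)
Y_resid = Y - fit_predict(X, Y)

# Step 2: regress Y_resid on T_resid (both have mean zero, so no intercept is needed)
theta1_fwl = (T_resid @ Y_resid) / (T_resid @ T_resid)

print(theta1_full, theta1_fwl)
```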
At first glance, this theorem might not seem to have much value. Why would we perform three separate regressions when we get the same results performing just one?
The power of FWL theorem comes from the intuition it gives us about linear regression and how we can extrapolate this idea to work for non-linear data.
To understand this intuition, let's interpret each residual term:
- Y – Ylr: This gives us the residualized value of Y after removing any variance in Y that can be explained by a linear relationship with X. If X perfectly predicts Y, then Y – Ylr will be zero, implying T is not needed to predict Y. From a causal perspective, the better our nuisance parameters predict the outcome, the less evidence there is that the treatment affects the outcome.
- T – Tlr: This is the residualized value of T after removing any variance that can be explained by a linear relationship with X. This residual allows us to remove confounding bias from our treatment caused by the nuisance parameters. If T – Tlr is 0, this implies that the treatment assignment is deterministic, and we have no way to perform causal inference. If our nuisance parameters are a sufficient adjustment set, and we can assume there's a non-deterministic linear relationship with the treatment, then we can think of T – Tlr as a debiased treatment that's as good as random.
When we regress these residuals onto one another, we isolate the effect of T on Y because we've removed any variance that can be explained by a linear relationship with X. This is why the coefficient θ1 can be interpreted as the expected change in Y when we increase T by one unit, while holding X fixed.
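One way to see what residualization removes: OLS residuals are, by construction, uncorrelated with every regressor. In the hypothetical NumPy sketch below (the coefficients and noise level are arbitrary assumptions), T is strongly correlated with the columns of X, while T – Tlr has essentially zero correlation with them and retains roughly the variance of the noise term. Keep in mind that this only strips out linear dependence on X; treating the residual as good as random still relies on X being a sufficient adjustment set.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 5_000

# Hypothetical confounders and a treatment that depends on them linearly
X = rng.normal(size=(n, 3))
T = X @ np.array([0.8, -1.2, 0.3]) + rng.normal(size=n)

# Residualize T on an intercept plus X
design = np.column_stack([np.ones(n), X])
coef, *_ = np.linalg.lstsq(design, T, rcond=None)
T_resid = T - design @ coef

raw_corr = np.array([np.corrcoef(T, X[:, j])[0, 1] for j in range(X.shape[1])])
resid_corr = np.array([np.corrcoef(T_resid, X[:, j])[0, 1] for j in range(X.shape[1])])

print("corr(T, X):      ", np.round(raw_corr, 3))
print("corr(T - Tlr, X):", np.round(resid_corr, 6))
print("var(T - Tlr):    ", round(float(T_resid.var()), 3))  # near the noise variance of 1
```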
To solidify the idea FWL theorem conveys, let's look at two examples in Python.
Example 1: A Confounded Treatment Effect
For this example, we'll use a synthetic dataset generated according to the following equations:

We want to evaluate the effect of t on y, but the relationship between t and y is confounded by the nuisance parameters x1 and x2. We see this because the first equation tells us that t can be calculated as a linear combination of x1 and x2 plus Gaussian noise, and x1 and x2 also appear in the equation for y. We also see that t is a relatively weak predictor of y compared to x1 and x2, as indicated by its small coefficient.
Nonetheless, we're only interested in the effect of t on y, regardless of its strength. Here are the dependencies we'll need for the remainder of this article:
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
from lightgbm.sklearn import LGBMRegressor
from sklearn.model_selection import cross_val_predict

# Directory containing the example CSV files
DATA_PATH = "data/"
Let's read in the dataset:
data_linear_1 = pd.read_csv(DATA_PATH + "linear_example1.csv")
print(data_linear_1.shape)
"(10000, 4)"
print(data_linear_1.head())
"""
t x1 x2 y
0 31.139704 1.101262 0.682916 874.357591
1 -46.437301 0.338431 0.064071 -181.341797
2 -94.685877 -0.539972 1.060375 2190.136391
3 -30.234819 -1.260242 1.994309 4525.823182
4 43.309347 -1.894621 -0.086365 870.204667
"""
The linear_example1.csv dataset has 10,000 observations of t, x1, x2, and y. Here's what a scatter plot of t vs y looks like:

From this scatter plot, there doesn't appear to be a discernible relationship between t and y. We can confirm this by running a simple linear regression of y on t:
t_ols_model = smf.ols("y ~ t", data=data_linear_1)
print(t_ols_model.fit().summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept -11.8412 20.760 -0.570 0.568 -52.535 28.853
t 0.2665 0.418 0.637 0.524 -0.553 1.086
==============================================================================
"""
We use statsmodels to regress y on t. The summary table shows us that the estimated coefficient on t is 0.2665, much lower than the true coefficient of 2. Moreover, the p-value for the coefficient on t is 0.524, meaning there's insufficient evidence to conclude that this coefficient is significantly different from 0. If we incorrectly assumed that there was no confounding between t and y, we might conclude that t does not affect y (i.e. the ATE is 0).
We can run the full regression of y on t, x1, and x2 to get a better estimate of t's coefficient:
full_linear_model = smf.ols("y ~ t + x1 + x2", data=data_linear_1)
print(full_linear_model.fit().summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 0.1830 0.504 0.363 0.716 -0.805 1.171
t 1.9998 0.010 196.444 0.000 1.980 2.020
x1 -499.9225 0.505 -990.130 0.000 -500.912 -498.933
x2 2000.2940 0.500 4000.091 0.000 1999.314 2001.274
==============================================================================
"""
When we include x1 and x2 in the regression, we get a much different estimate of the coefficient on t. The coefficient on t is now 1.9998, much closer to the true coefficient of 2. We also see that the p-value on this coefficient is effectively 0, indicating statistical significance. We can interpret this as follows:
While holding x1 and x2 fixed, for every one unit increase in t, we expect y to increase by about 1.9998 on average. If x1 and x2 are a sufficient adjustment set, we can conclude that the ATE of t on y is about 1.9998.
This makes a lot of sense because x1 and x2 are confounders of the relationship between t and y. Therefore, to accurately estimate the effect of t on y, we have to hold x1 and x2 constant and measure how much y changes when we change t. This is exactly what linear regression does.
Now, FWL theorem tells us we can obtain the same coefficient on t by decomposing the regression. To do this, we fit one model that predicts t from x1 and x2, and a second model that predicts y from x1 and x2:
t_resid_model = smf.ols("t ~ x1 + x2", data=data_linear_1).fit()
y_resid_model = smf.ols("y ~ x1 + x2", data=data_linear_1).fit()
We then use these models to create residualized versions of t and y:
data_linear_1["t_resid"] = data_linear_1["t"].values - t_resid_model.predict(data_linear_1)
data_linear_1["y_resid"] = data_linear_1["y"].values - y_resid_model.predict(data_linear_1)
Here, t_resid is the value of t after removing variance that can be predicted by a linear relationship with x1 and x2, and y_resid is the value of y after removing variance that can be predicted by a linear relationship with x1 and x2. Hence, if x1 and x2 are a sufficient adjustment set, any variance left in y_resid can only be attributed to noise or changes in t_resid.
Here's what happens when we regress y_resid on t_resid:
resid_linear_model = smf.ols("y_resid ~ t_resid", data=data_linear_1)
print(resid_linear_model.fit().summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 2.526e-13 0.504 5.02e-13 1.000 -0.987 0.987
t_resid 1.9998 0.010 196.464 0.000 1.980 2.020
==============================================================================
"""
Just as FWL theorem tells us, we get the same coefficient on the residualized version of t (1.9998) as we do when we fit the full regression model. We also get essentially the same standard error, p-value, and confidence interval; only the t-statistics differ slightly (196.444 vs. 196.464). By removing the effects of x1 and x2 on t and y, we recover the effect of t on y using simple linear regression.
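Why do the two t-statistics differ in the last digits? Both regressions leave exactly the same residuals, but the full model estimates four parameters (intercept, t, x1, x2) while the residual model estimates two, so their residual variances divide the same sum of squares by n - 4 and n - 2, respectively. The standard errors therefore differ by a factor of sqrt((n - 2)/(n - 4)), about 1.0001 for n = 10,000. The NumPy sketch below checks this on hypothetical data (not the article's CSV); ols_coef_and_se is a helper written for this illustration that computes classical, non-robust OLS standard errors.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 10_000

# Hypothetical linear data with two confounders x[:, 0] and x[:, 1]
x = rng.normal(size=(n, 2))
t = x @ np.array([1.0, -2.0]) + rng.normal(scale=5.0, size=n)
y = 2.0 * t - 5.0 * x[:, 0] + 20.0 * x[:, 1] + rng.normal(scale=5.0, size=n)


def ols_coef_and_se(design, target, column):
    """Classical OLS coefficient and standard error for one design column."""
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    resid = target - design @ coef
    dof = len(target) - design.shape[1]  # residual degrees of freedom
    sigma2 = resid @ resid / dof
    cov = sigma2 * np.linalg.inv(design.T @ design)
    return coef[column], np.sqrt(cov[column, column])


ones = np.ones(n)
b_full, se_full = ols_coef_and_se(np.column_stack([ones, t, x]), y, column=1)

# FWL: residualize t and y on [1, x], then regress residual on residual
x_design = np.column_stack([ones, x])
t_resid = t - x_design @ np.linalg.lstsq(x_design, t, rcond=None)[0]
y_resid = y - x_design @ np.linalg.lstsq(x_design, y, rcond=None)[0]
b_resid, se_resid = ols_coef_and_se(np.column_stack([ones, t_resid]), y_resid, column=1)

print(b_full, b_resid)                                 # same coefficient
print(se_full / se_resid, np.sqrt((n - 2) / (n - 4)))  # same ratio
```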
We can visualize the results of residualization by plotting t_resid vs y_resid:

This looks very different from the original scatter plot of t vs y. The relationship between the t and y residuals is linear, in contrast to the noisy relationship between the raw values of t and y. By removing the confounding effects of x1 and x2, we've uncovered the true relationship between t and y.
Example 2: Spurious Correlations
Let's do the same exercise with data generated according to the following functions:

Notice in these equations, t is a linear function of x1 and Gaussian random noise. What stands out, however, is that t does not affect y since it's missing from the equation for y.
Let's read in the dataset for this example:
data_linear_2 = pd.read_csv(DATA_PATH + "linear_example2.csv")
print(data_linear_2.shape)
"(10000, 4)"
print(data_linear_2.head())
"""
t x1 x2 y
0 44.758318 1.101262 0.682916 -44.527435
1 12.625887 0.338431 0.064071 -63.033848
2 -23.450555 -0.539972 1.060375 36.979009
3 -50.945122 -1.260242 1.994309 70.451553
4 -74.999631 -1.894621 -0.086365 102.008320
"""
As before, let's plot t vs y:

From this scatter plot, it appears that there's a strong negative linear relationship between t and y. We confirm this by running a simple regression of y on t:
t_ols_model = smf.ols("y ~ t", data=data_linear_2)
print(t_ols_model.fit().summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 0.0916 0.544 0.169 0.866 -0.974 1.158
t -1.2452 0.014 -91.498 0.000 -1.272 -1.219
==============================================================================
"""
When we regress y on t, we get a statistically significant negative coefficient (-1.2452) on t. If we were to incorrectly assume that there are no confounders interfering with the relationship between t and y, we might conclude that t has a causal effect on y. This is a classic example of correlation not implying causation.
Let's skip the full regression of y on t, x1, and x2 and create residuals for t and y:
t_resid_model = smf.ols("t ~ x1 + x2", data=data_linear_2).fit()
y_resid_model = smf.ols("y ~ x1 + x2", data=data_linear_2).fit()
data_linear_2["t_resid"] = data_linear_2["t"].values - t_resid_model.predict(data_linear_2)
data_linear_2["y_resid"] = data_linear_2["y"].values - y_resid_model.predict(data_linear_2)
Here's what the t and y residuals look like:

After removing the variance from t and y due to a linear relationship with x1 and x2, we see no relationship between the t and y residuals. The spurious correlation we saw by directly comparing t to y appeared because t is highly correlated with x1, and x1 causes y. The association between t and y vanishes once we account for x1 and x2, revealing the true relationship between t and y.
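The same mechanism can be reproduced from scratch. In this NumPy sketch the coefficients are made up: x1 drives both t and y, and t has no effect on y at all; residualize is a helper written for this example. The naive slope of y on t comes out clearly negative (its population value here is (40 × -50)/(40² + 10²) ≈ -1.18), while the slope between the residuals is essentially zero.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 10_000

# Hypothetical setup: x1 drives both t and y, but t has NO effect on y
x1 = rng.normal(size=n)
x2 = rng.normal(size=n)
t = 40.0 * x1 + rng.normal(scale=10.0, size=n)
y = -50.0 * x1 + 5.0 * x2 + rng.normal(scale=5.0, size=n)


def residualize(target, *features):
    """Residual of an OLS fit of target on an intercept plus features."""
    design = np.column_stack([np.ones(len(target)), *features])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return target - design @ coef


naive_slope = np.polyfit(t, y, 1)[0]  # slope of y ~ 1 + t
t_resid = residualize(t, x1, x2)
y_resid = residualize(y, x1, x2)
fwl_slope = np.polyfit(t_resid, y_resid, 1)[0]
print(f"naive slope: {naive_slope:.3f}, residualized slope: {fwl_slope:.3f}")
```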
We can confirm this by regressing the t and y residuals again:
resid_linear_model = smf.ols("y_resid ~ t_resid", data=data_linear_2)
print(resid_linear_model.fit().summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 8.59e-15 0.504 1.71e-14 1.000 -0.987 0.987
t_resid -0.0122 0.509 -0.024 0.981 -1.010 0.985
==============================================================================
"""
This regression confirms that there is no discernible effect of t on y after we remove variance caused by x1 and x2. The coefficient on t_resid has a p-value of 0.981 and a confidence interval that comfortably contains 0, so it is statistically indistinguishable from 0, consistent with t having no effect on y.
With this, we've seen FWL theorem in action and have an intuitive feel for the nature of linear regression. Crucially, we saw how removing variance from t and y predicted by a linear relationship with the nuisance parameters, i.e. residualization, reveals the true relationship between t and y.
The next step is to figure out how we can use FWL theorem on nonlinear data. Can we still use FWL theorem to uncover the relationship between t and y? We'll cover this next.
Applying FWL Theorem to Nonlinear Data
FWL theorem is only guaranteed to be exact for linear regression, but real-world data can be complicated and nonlinear. Fortunately, we can extrapolate the intuition FWL theorem gives us and apply it to arbitrary data to uncover causal relationships.
Partial Linear Models
While we ultimately don't want to make assumptions about the functional form of the outcome with respect to the treatment and nuisance parameters, a natural next step from FWL theorem is to see if we can recover a linear treatment effect in the presence of nonlinear nuisance parameters.
That is, we assume our outcome can be written as follows:
$$Y = \theta_1 T + g(X) + \epsilon$$
Y is a linear function of T plus an arbitrary, potentially nonlinear, function of X. We can also assume that T is an arbitrary function of X, allowing for confounding. We call this a partial linear model because it has linear and nonlinear components. While the assumptions of a partial linear model are still restrictive, they are much more flexible than those of a linear model, which brings us a step closer to our goal.
So how do we fit a partial linear model? We know that modern machine learning models, such as tree-based models, can fit nonlinear functions. Let's denote an arbitrary machine learning model that predicts Y as a function of X by Yml = E[Y | X], and a model that predicts T as a function of X by Tml = E[T | X]. To remove the variance in T and Y caused by an arbitrary function of X, we can redefine the residuals from FWL theorem as follows:
$$\text{residual}_T = T - T_{ml}, \qquad \text{residual}_Y = Y - Y_{ml}$$
Recall that in the original FWL theorem, E[T | X] and E[Y | X] are estimated with linear models. However, we can use any machine learning model to estimate E[T | X] and E[Y | X], which lets the residuals remove nonlinear relationships with X as well.
With these residuals, assuming Tml and Yml are accurate estimates of E[T | X] and E[Y | X], we can recover θ1 with the following regression:
$$Y - Y_{ml} = \theta_1 (T - T_{ml}) + \epsilon$$
In other words, after using arbitrary machine learning models to calculate residuals for T and Y, we can estimate θ1 by regressing the residuals of Y on the residuals of T.
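Below is a minimal, NumPy-only sketch of this residual-on-residual recipe with cross-fitting. To stay self-contained, it swaps in a simple k-nearest-neighbors average (knn_predict, a helper written for this example) as the machine learning model, standing in for the LightGBM models used later in this article; cross_fit_residuals makes every prediction out of fold, similar in spirit to cross_val_predict. The data-generating process is made up, with a true θ1 of 2 and sin(x2) entering both T and Y as a nonlinear confounder; the 5 folds and k = 25 are arbitrary choices.

```python
import numpy as np


def knn_predict(X_train, y_train, X_test, k=25):
    """Average target of the k nearest training points (Euclidean distance)."""
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argpartition(d2, k, axis=1)[:, :k]
    return y_train[nearest].mean(axis=1)


def cross_fit_residuals(X, target, n_folds=5, k=25, seed=0):
    """target - E_hat[target | X], with every prediction made out of fold."""
    fold = np.random.default_rng(seed).permutation(len(target)) % n_folds
    prediction = np.empty(len(target))
    for f in range(n_folds):
        test = fold == f
        prediction[test] = knn_predict(X[~test], target[~test], X[test], k)
    return target - prediction


# Hypothetical partial linear data: Y = 2*T + g(X) + noise, T = h(X) + noise,
# with sin(x2) entering both T and Y (nonlinear confounding)
rng = np.random.default_rng(5)
n = 3_000
X = rng.normal(size=(n, 2))
x1, x2 = X[:, 0], X[:, 1]
T = np.cos(x1) + 2.0 * np.sin(x2) + rng.normal(size=n)
Y = 2.0 * T + 3.0 * np.sin(x2) + np.cos(x1) + rng.normal(size=n)

naive = np.polyfit(T, Y, 1)[0]               # ignores the confounding
T_resid = cross_fit_residuals(X, T)
Y_resid = cross_fit_residuals(X, Y)
theta1 = np.polyfit(T_resid, Y_resid, 1)[0]  # residual-on-residual slope
print(f"naive slope: {naive:.3f}, residual-on-residual slope: {theta1:.3f}")
```

The naive slope should come out noticeably above 2, while the residual-on-residual slope should land near 2.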
To try this out, we'll use data generated by the following process:

Here, we only know that t is a function of x1 and x2, and y is a linear function of t plus an arbitrary function of x1 and x2.
As before, let's read in this dataset:
data_partial_linear = pd.read_csv(DATA_PATH + "partial_linear_example.csv")
print(data_partial_linear.shape)
"(10000, 4)"
print(data_partial_linear.head())
"""
t x1 x2 y
0 2.052069 1.101262 0.682916 145.112924
1 0.363920 0.338431 0.064071 81.352175
2 -1.508082 -0.539972 1.060375 -37.170164
3 -1.182025 -1.260242 1.994309 -119.741583
4 -0.481000 -1.894621 -0.086365 11.796844
"""
And visualize t vs y with a scatter plot:

We again see a seemingly noisy relationship between t and y. Let's confirm this with a simple linear regression of y on t:
t_ols_model = smf.ols("y ~ t", data=data_partial_linear)
print(t_ols_model.fit().summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 27.1079 0.598 45.330 0.000 25.936 28.280
t 17.8144 0.428 41.616 0.000 16.975 18.653
==============================================================================
"""
The simple linear model severely overestimates the true t coefficient of 2. Let's see what happens when we include the nuisance parameters in the regression:
full_linear_model = smf.ols("y ~ t + x1 + x2 + x1*x2", data=data_partial_linear)
print(full_linear_model.fit().summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept 34.2972 0.442 77.676 0.000 33.432 35.163
t 4.9527 0.349 14.188 0.000 4.268 5.637
x1 34.6709 0.443 78.290 0.000 33.803 35.539
x2 -0.0551 0.388 -0.142 0.887 -0.816 0.706
x1:x2 21.1510 0.386 54.814 0.000 20.395 21.907
==============================================================================
"""
The full linear model gets closer to the true t coefficient but still overestimates it. Notice that we've included an interaction term between x1 and x2, since we don't fully know the functional form of y, which could include interactions.
Let's continue as if we don't know that y is a partial linear function of x1, x2, and t, and assume y is a linear function. We've already run a full linear regression and estimated the coefficient on t to be around 4.95. To check our assumption about y being a linear function of x1, x2, and t, let's compute and visualize the t and y residuals using linear models:
t_resid_model = smf.ols("t ~ x1 + x2 + x1*x2", data=data_partial_linear).fit()
y_resid_model = smf.ols("y ~ x1 + x2 + x1*x2", data=data_partial_linear).fit()
data_partial_linear["t_resid"] = data_partial_linear["t"].values - t_resid_model.predict(data_partial_linear)
data_partial_linear["y_resid"] = data_partial_linear["y"].values - y_resid_model.predict(data_partial_linear)
Just as we did in the first two examples, we use the nuisance parameters (x1 and x2) as features in linear models that estimate t and y. We then use these linear models to create t and y residuals. Here's what the residuals look like:

Yikes! From inspecting this scatter plot, it's hard to maintain our assumption that y is a linear function of t, even though it is. So what's going on here? The issue is that we can't accurately estimate t and y using linear functions of the nuisance parameters, and this is essential for us to uncover the relationship between t and y.
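The failure mode can be reproduced in a few lines. In this hypothetical NumPy example, both t and y depend on x1**2, which a model that is linear in x1 and x2 cannot represent; residual_slope is a helper written for this illustration. For a standard normal x1, x1**2 is uncorrelated with x1 and has variance 2, so the linearly residualized estimate converges to (5·2 + 2·1)/(2 + 1) = 4 rather than 2. Giving the nuisance models the missing x1**2 feature (a hand-specified fix here, standing in for a flexible learner such as LightGBM) brings the estimate back to about 2.

```python
import numpy as np


def residual_slope(t, y, features):
    """Residualize t and y on [1, features] with OLS, then regress residual on residual."""
    design = np.column_stack([np.ones(len(t)), features])
    t_res = t - design @ np.linalg.lstsq(design, t, rcond=None)[0]
    y_res = y - design @ np.linalg.lstsq(design, y, rcond=None)[0]
    return (t_res @ y_res) / (t_res @ t_res)


# Hypothetical data: y = 2*t + 3*x1**2 + noise, t = x1**2 + noise
rng = np.random.default_rng(6)
n = 20_000
x1, x2 = rng.normal(size=(2, n))
t = x1**2 + rng.normal(size=n)
y = 2.0 * t + 3.0 * x1**2 + rng.normal(size=n)

estimates = {
    "linear in x1, x2": residual_slope(t, y, np.column_stack([x1, x2])),
    "adds x1**2": residual_slope(t, y, np.column_stack([x1, x2, x1**2])),
}
for name, theta1 in estimates.items():
    print(f"{name}: theta1 = {theta1:.3f}")
```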
To fix this, let's replace the linear models we used to estimate t and y with LightGBM regression models:
nuisance = ["x1", "x2"]
model_t = LGBMRegressor(n_estimators=800, max_depth=6)
model_y = LGBMRegressor(n_estimators=800, max_depth=6)
data_partial_linear["t_hat"] = cross_val_predict(model_t, data_partial_linear[nuisance], data_partial_linear["t"], cv=5)
data_partial_linear["y_hat"] = cross_val_predict(model_y, data_partial_linear[nuisance], data_partial_linear["y"], cv=5)
data_partial_linear["t_resid"] = data_partial_linear["t"] - data_partial_linear["t_hat"]
data_partial_linear["y_resid"] = data_partial_linear["y"] - data_partial_linear["y_hat"]
We use LGBMRegressor to estimate t and y as a function of the nuisance parameters. We then add the predictions for t and y to data_partial_linear using cross_val_predict with 5 folds, ensuring that each predicted value of t and y is made out of sample. Lastly, we calculate and store the residuals in data_partial_linear.
Here's what the residuals look like:

What we see here is really fascinating. While there is still noise and some outliers, the linear relationship between t and y is much more evident after removing variance caused by nonlinear relationships with x1 and x2 using LightGBM models.
We can confirm this by regressing the new t and y residuals:
resid_linear_model = smf.ols("y_resid ~ t_resid", data=data_partial_linear).fit()
print(resid_linear_model.summary().tables[1])
"""
==============================================================================
coef std err t P>|t| [0.025 0.975]
------------------------------------------------------------------------------
Intercept -0.0246 0.039 -0.631 0.528 -0.101 0.052
t_resid 1.9824 0.037 53.344 0.000 1.910 2.055
==============================================================================
"""
This is a pretty remarkable result. The coefficient on t_resid is about 1.98, much closer to the true value of 2 than our estimate from multiple linear regression. We used the intuition from FWL theorem to remove variance from t and y caused by a nonlinear relationship with the nuisance parameters, effectively revealing the linear relationship between t and y.
One advantage of the partial linear model is that we can make statistical statements about the linear relationship between T and Y using techniques like ordinary least squares or Bayesian regression. After removing the effects of the nuisance parameters on T and Y using a machine learning model of our choice, we've reduced the problem down to a simple linear regression, and we reap the benefits of p-values and confidence/credible intervals on the T coefficient.
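For reference, here is a NumPy sketch of the classical quantities statsmodels reports for the final residual regression: the slope, its standard error using n - 2 degrees of freedom, and an approximate 95% interval. residual_slope_inference is a helper written for this illustration, and it uses the normal critical value 1.96 instead of the t-distribution, a large-sample approximation. These classical standard errors treat the residuals as fixed data and assume homoskedastic errors; heteroskedasticity-robust standard errors are a common alternative.

```python
import numpy as np


def residual_slope_inference(t_resid, y_resid, z=1.96):
    """Slope of y_resid ~ 1 + t_resid, its classical OLS standard error, and a ~95% CI."""
    n = len(t_resid)
    t_c = t_resid - t_resid.mean()
    y_c = y_resid - y_resid.mean()
    slope = (t_c @ y_c) / (t_c @ t_c)
    errors = y_c - slope * t_c            # residuals of the fitted line
    sigma2 = (errors @ errors) / (n - 2)  # two estimated parameters
    se = np.sqrt(sigma2 / (t_c @ t_c))
    return slope, se, (slope - z * se, slope + z * se)


# Usage on hypothetical residuals with a true slope of 2
rng = np.random.default_rng(8)
t_resid = rng.normal(size=5_000)
y_resid = 2.0 * t_resid + rng.normal(size=5_000)
slope, se, (low, high) = residual_slope_inference(t_resid, y_resid)
print(f"theta1 = {slope:.3f} (SE {se:.4f}), 95% CI [{low:.3f}, {high:.3f}]")
```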
Of course, the assumptions of a partial linear model may be too restrictive for many use cases. If we can't assume a linear relationship between T and Y, or if we suspect a nonlinear interaction between T and X, a partial linear model will not be sufficient to uncover the causal relationship between T and Y.
Our ultimate goal is to learn causal relationships between T and Y when Y is a completely unknown function of T and X (i.e. Y = f(T, X)). We'll attempt to do this in the next section.
A Leap of Faith – Nonlinear Outcome Functions
In this final section, we'll relax all assumptions about the outcome function (i.e. Y = f(T, X)), and attempt to learn the causal relationship between T and Y. This is no easy feat. The causal relationship between T and Y can be quite complex and may be heterogeneous with respect to X. Moreover, it's difficult to evaluate how well we've approximated this function. Nonetheless, we'll show that, at least empirically, this can be done.
Let's start by recalling the equation for the residualized partial linear model:
$$Y - Y_{ml} = \theta_1 (T - T_{ml}) + \epsilon$$
In the partial linear model, we remove variance in Y and T using machine learning models to approximate E[Y | X] and E[T | X], but we're still assuming there's a homogeneous linear relationship between T and Y. Removing this assumption is, in principle, straightforward. We'll take the residuals we computed from the partial linear model and use another machine learning model to try to predict the Y residuals from the T residuals and X:
$$Y - Y_{ml} = f(T - T_{ml}, X) + \epsilon$$
Here, we're removing assumptions by stating that the Y residuals are a function of the T residuals and X. We include X in the function to account for interactions between the T residuals and X in relation to the Y residuals. We already have the models needed to approximate the T and Y residuals, and we can use a third machine learning model to approximate f.
With our final model, f, we can plug in different values of the T residuals and X to see how the estimated Y residuals change. This model will effectively allow us to make counterfactual predictions on new data. That is, if we hold our nuisance parameters (X) constant, how do we predict Y will change as T changes?
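To make the final stage concrete, here is a NumPy sketch under deliberately simple, hypothetical assumptions: T = x1 + noise and Y = -0.5·T² + x2 + noise. To isolate the third model f, it uses the known conditional means E[T | X] and E[Y | X] in place of fitted nuisance models, and it represents f with a hand-picked polynomial basis fit by least squares rather than a general machine learning model (basis and f_hat are helpers written for this example). In this setup the Y residual works out to -x1·(T - E[T | X]) - 0.5·(T - E[T | X])² + 0.5 plus noise, so the effect of the treatment residual depends on x1, which is exactly why X is passed to f.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 20_000

# Hypothetical data: T = x1 + noise, Y = -0.5 * T**2 + x2 + noise
X = rng.normal(size=(n, 2))
T = X[:, 0] + rng.normal(size=n)
Y = -0.5 * T**2 + X[:, 1] + rng.normal(scale=0.5, size=n)

# Simplification: known conditional means stand in for fitted nuisance models
T_resid = T - X[:, 0]                                  # T - E[T | X]
Y_resid = Y - (-0.5 * (X[:, 0] ** 2 + 1.0) + X[:, 1])  # Y - E[Y | X]


def basis(t_res, X):
    """Hand-picked features for f: 1, t, t^2, and t-by-X interactions."""
    return np.column_stack(
        [np.ones(len(t_res)), t_res, t_res**2, t_res * X[:, 0], t_res * X[:, 1]]
    )


# Third model: least-squares fit of f(T_resid, X) to Y_resid
coef, *_ = np.linalg.lstsq(basis(T_resid, X), Y_resid, rcond=None)


def f_hat(t_res, x):
    """Predicted Y residual for treatment residuals t_res at one fixed x."""
    t_res = np.asarray(t_res, dtype=float)
    return basis(t_res, np.tile(x, (len(t_res), 1))) @ coef


# Counterfactual curve: hold x = (1, 0) fixed and vary the treatment residual
grid = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
curve = f_hat(grid, np.array([1.0, 0.0]))
print(np.round(curve, 2))
```

With this seed, the curve should come out close to the true values 0.5, 1.0, 0.5, -1.0, and -3.5.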
As Matheus Facure points out in his book on causal inference, this is a potentially dangerous thing to do because we might not know how the treatment functions in a real-world dataset. On the other hand, if we feel confident in our estimation of f, then we could have a really powerful causal model that can help us understand the impact of a range of treatments on an outcome.
Let's see how this works with one last example. The only assumption we'll make about our dataset is that it was generated according to the following equations:

We're only assuming that t is a function of x1 and x2, and y is a function of t, x1, x2, and x3. This implies that x1 and x2 are confounders, but x3 isn't. However, x3 could have interactions with t that affect its relationship with y.
Let's read in this dataset:
data_non_linear = pd.read_csv(DATA_PATH + "non_linear_example.csv")
print(data_non_linear.shape)
"(15000, 5)"
print(data_non_linear.head())
"""
t x1 x2 x3 y
0 0.815358 1.101262 -0.378001 7 -0.448740
1 -4.789706 0.338431 -0.585686 7 -0.756166
2 -0.201015 -0.539972 0.964218 8 -0.861377
3 -4.267054 -1.260242 0.105409 2 1.447061
4 -2.871802 -1.894621 -1.388249 5 -0.199239
"""
print(data_non_linear["x3"].nunique())
"9"
Notice that x3 is an integer that takes on 9 unique values – this will be important in a moment. Next, let's visualize the relationship between t and y:

Again, the relationship between t and y isn't discernible from the raw data. To try to remove any variance in t and y caused by the nuisance parameters, we'll fit LightGBM models as we did with the partial linear model:
nuisance = ["x1", "x2", "x3"]
model_t = LGBMRegressor(n_estimators=300, max_depth=6)
model_y = LGBMRegressor(n_estimators=300, max_depth=6)
data_non_linear["t_hat"] = cross_val_predict(
model_t, data_non_linear[nuisance], data_non_linear["t"], cv=5
)
data_non_linear["y_hat"] = cross_val_predict(
model_y, data_non_linear[nuisance], data_non_linear["y"], cv=5
)
data_non_linear["t_resid"] = data_non_linear["t"] - data_non_linear["t_hat"]
data_non_linear["y_resid"] = data_non_linear["y"] - data_non_linear["y_hat"]
Let's see what the residuals look like:

Astonishingly, a clear nonlinear relationship appears between the treatment and outcome residuals. In particular, this scatter plot suggests there's a quadratic relationship between t and y. We can see how well a quadratic polynomial fits these residuals by regressing the y residuals on the squared t residuals:
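import statsmodels.formula.api as smf
# Regress the y residuals on the squared t residuals: y_resid = a + b * t_resid**2
data_non_linear["t_resid_squared"] = data_non_linear["t_resid"]**2
resid_quadratic_model = smf.ols("y_resid ~ t_resid_squared", data=data_non_linear).fit()
print(resid_quadratic_model.summary())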
"""
OLS Regression Results
==============================================================================
Dep. Variable: y_resid R-squared: 0.547
Model: OLS Adj. R-squared: 0.547
Method: Least Squares F-statistic: 1.808e+04
Date: Tue, 19 Mar 2024 Prob (F-statistic): 0.00
Time: 09:54:06 Log-Likelihood: 13503.
No. Observations: 15000 AIC: -2.700e+04
Df Residuals: 14998 BIC: -2.699e+04
Df Model: 1
Covariance Type: nonrobust
===================================================================================
coef std err t P>|t| [0.025 0.975]
-----------------------------------------------------------------------------------
Intercept 0.0758 0.001 77.241 0.000 0.074 0.078
t_resid_squared -0.0028 2.12e-05 -134.461 0.000 -0.003 -0.003
==============================================================================
Omnibus: 3204.283 Durbin-Watson: 2.004
Prob(Omnibus): 0.000 Jarque-Bera (JB): 72693.392
Skew: -0.461 Prob(JB): 0.00
Kurtosis: 13.745 Cond. No. 56.7
==============================================================================
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
"""
The R-squared of 0.547 indicates a moderately strong fit between the squared treatment residuals and the outcome residuals, and the negative coefficient on t_resid_squared indicates a downward-opening (concave) curve. In other words, there's decent evidence that the treatment affects the outcome quadratically once we've removed the effects of the nuisance parameters.
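Under the hood, the regression above is ordinary least squares of y_resid on an intercept and t_resid squared. The following NumPy sketch reproduces that fit and its R-squared, which can be handy for checking what smf.ols reports. The function name quadratic_resid_fit and the toy numbers are ours.

```python
import numpy as np


def quadratic_resid_fit(t_resid: np.ndarray, y_resid: np.ndarray):
    """OLS of y_resid on an intercept and t_resid**2; returns (intercept, slope, r_squared)."""
    design = np.column_stack([np.ones_like(t_resid), t_resid**2])
    coef, *_ = np.linalg.lstsq(design, y_resid, rcond=None)
    fitted = design @ coef
    ss_res = np.sum((y_resid - fitted) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_resid - y_resid.mean()) ** 2)  # total (centered) sum of squares
    return coef[0], coef[1], 1.0 - ss_res / ss_tot


# Toy residuals: a concave quadratic signal plus noise (made-up numbers)
rng = np.random.default_rng(0)
t_toy = rng.normal(scale=4.0, size=2_000)
y_toy = 0.07 - 0.003 * t_toy**2 + rng.normal(scale=0.03, size=2_000)
intercept, slope, r2 = quadratic_resid_fit(t_toy, y_toy)
print(f"intercept={intercept:.4f}, slope={slope:.5f}, R^2={r2:.3f}")
```

With the article's DataFrame, passing data_non_linear["t_resid"].to_numpy() and data_non_linear["y_resid"].to_numpy() should give the same Intercept, t_resid_squared coefficient, and R-squared as the summary above, up to rounding.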
We can plot the predicted quadratic polynomial against the residuals to visualize this:

Indeed, the quadratic model seems to be a good fit, but this is only the first step. Recall that the treatment can interact with the nuisance parameters, meaning the relationship between t and y depends on the values of x1, x2, and x3. What we have so far is the y residuals as a function of the t residuals, but we want the y residuals as a function of the t residuals and the nuisance parameters.
To get this, we can use the t and y residuals we already computed and fit a new model with the nuisance parameters included:
# Out-of-fold predictions of the y residuals from the nuisance parameters and t residuals
data_non_linear["y_resid_2"] = cross_val_predict(
    LGBMRegressor(n_estimators=300, max_depth=6),
    data_non_linear[nuisance + ["t_resid"]],
    data_non_linear["y_resid"],
    cv=5,
)
Now let's see how well the quadratic model fits these out-of-fold predictions of the y residuals, stored in y_resid_2:
resid_quadratic_model_2 = smf.ols(
"y_resid_2 ~ t_resid_squared", data=data_non_linear
).fit()
print(resid_quadratic_model_2.summary())
""" OLS Regression Results
==============================================================================
Dep. Variable: y_resid_2 R-squared: 0.776
Model: OLS Adj. R-squared: 0.776
Method: Least Squares F-statistic: 5.190e+04
Date: Tue, 26 Mar 2024 Prob (F-statistic): 0.00
Time: 08:30:40 Log-Likelihood: 21629.
No. Observations: 15000 AIC: -4.325e+04
Df Residuals: 14998 BIC: -4.324e+04
Df Model: 1
Covariance Type: nonrobust
===================================================================================
coef std err t P>|t| [0.025 0.975]
-----------------------------------------------------------------------------------
Intercept 0.0743 0.001 130.244 0.000 0.073 0.075
t_resid_squared -0.0028 1.23e-05 -227.816 0.000 -0.003 -0.003
==============================================================================
Omnibus: 2487.996 Durbin-Watson: 1.992
Prob(Omnibus): 0.000 Jarque-Bera (JB): 35598.209
Skew: -0.349 Prob(JB): 0.00
Kurtosis: 10.515 Cond. No. 56.7
==============================================================================
"""
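Notice that the R-squared when regressing y_resid_2 on t_resid_squared (0.776) is much higher than the R-squared when regressing y_resid on t_resid_squared (0.547). Part of this gain is expected, since y_resid_2 contains model predictions, which are smoother than the raw residuals. Let's visualize this fit: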

By allowing t to interact with the nuisance parameters, the nonlinear relationship between t and y becomes even more evident, but we still can't see how the relationship between t and y changes depending on the nuisance parameters.
To see this, instead of using cross_val_predict(), we'll fit each model once on the full dataset, including a final model that uses the t residuals and nuisance parameters to predict the y residuals:
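# Nuisance models: predict t and y from the nuisance parameters
t_model = LGBMRegressor(n_estimators=300, max_depth=6)
y_model = LGBMRegressor(n_estimators=300, max_depth=6)
# Final model: predicts the y residuals from the nuisance parameters and t residuals
final_model = LGBMRegressor(n_estimators=300, max_depth=6)
t_model.fit(data_non_linear[nuisance], data_non_linear["t"])
y_model.fit(data_non_linear[nuisance], data_non_linear["y"])
# Trained on the out-of-fold residuals computed earlier with cross_val_predict()
final_model.fit(data_non_linear[nuisance + ["t_resid"]], data_non_linear["y_resid"])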
t_model predicts t as a function of the nuisance parameters, and y_model predicts y as a function of the nuisance parameters. final_model predicts the y residuals as a function of the nuisance parameters and t residuals. We can use these three models to make counterfactual predictions on new data.
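The three-model recipe can be wrapped in one helper, which makes it easy to score any grid of treatment values. This is a minimal sketch that assumes models with a scikit-learn-style predict method and a final model trained on the nuisance columns followed by t_resid, as above. The function name predict_counterfactual is ours, not part of any library.

```python
import numpy as np
import pandas as pd


def predict_counterfactual(t_model, y_model, final_model, X: pd.DataFrame, t) -> np.ndarray:
    """Predicted y if the treatment were set to t, given nuisance columns X.

    t_model, y_model: predict t and y from X.
    final_model: predicts the y residual from X's columns plus a trailing t_resid column.
    """
    t = np.asarray(t, dtype=float)
    t_hat = t_model.predict(X)   # expected treatment given the nuisance parameters
    y_hat = y_model.predict(X)   # expected outcome given the nuisance parameters
    # assign() returns a copy with t_resid appended last, matching the training column order
    features = X.assign(t_resid=t - t_hat)
    y_resid_hat = final_model.predict(features)
    return y_hat + y_resid_hat
```

For example, predict_counterfactual(t_model, y_model, final_model, test_data[nuisance], test_data["t"]) should match the y_counterfactual column that is computed step by step below.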
Let's read in some test data to make predictions on:
test_data = pd.read_csv(DATA_PATH + "non_linear_test.csv")
print(test_data.shape)
"(1800, 4)"
print(test_data)
"""
x1 x2 x3 t
0 0.003099 -0.014613 1 -10.000000
1 0.003099 -0.014613 1 -9.889447
2 0.003099 -0.014613 1 -9.778894
3 0.003099 -0.014613 1 -9.668342
4 0.003099 -0.014613 1 -9.557789
... ... ... .. ...
1795 0.003099 -0.014613 9 11.557789
1796 0.003099 -0.014613 9 11.668342
1797 0.003099 -0.014613 9 11.778894
1798 0.003099 -0.014613 9 11.889447
1799 0.003099 -0.014613 9 12.000000
"""
print(test_data["x1"].unique())
"[0.00309858]"
print(test_data["x2"].unique())
"[-0.01461324]"
print(test_data["x3"].unique())
"[1 2 3 4 5 6 7 8 9]"
Notice that in this test dataset, x1 and x2 each take a single value, their mean in the training data, while x3 takes all 9 of its training values. Within each x3 level, t sweeps over a grid from -10 to 12, so this test set lets us see how the predicted response varies as a function of t while the nuisance parameters are held fixed.
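The article loads this grid from non_linear_test.csv. The printed rows appear consistent with crossing each of the 9 x3 levels with np.linspace(-10, 12, 200) (9 × 200 = 1800 rows, with a spacing of 22/199 ≈ 0.1106), so a grid like it could be built along these lines. This is a sketch under that assumption, and make_counterfactual_grid is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def make_counterfactual_grid(train: pd.DataFrame, t_min: float = -10.0,
                             t_max: float = 12.0, n_points: int = 200) -> pd.DataFrame:
    """Fix x1 and x2 at their training means and cross every x3 level with a t grid."""
    t_grid = np.linspace(t_min, t_max, n_points)
    x3_levels = np.sort(train["x3"].unique())
    return pd.DataFrame({
        "x1": train["x1"].mean(),                        # scalar, broadcast to every row
        "x2": train["x2"].mean(),                        # scalar, broadcast to every row
        "x3": np.repeat(x3_levels, n_points),            # 1,1,...,1,2,2,...,9
        "t": np.tile(t_grid, len(x3_levels)),            # same t grid for each x3 level
    })
```

Calling make_counterfactual_grid(data_non_linear) would produce the columns x1, x2, x3, and t in the same order as test_data.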
First, we'll predict y and t from the nuisance parameters:
test_data["y_hat"] = y_model.predict(test_data[nuisance])
test_data["t_hat"] = t_model.predict(test_data[nuisance])
Then, we'll calculate the t residuals and use them to predict the y residuals using final_model:
test_data["t_resid"] = test_data["t"] - test_data["t_hat"]
test_data["y_resid"] = final_model.predict(test_data[nuisance + ["t_resid"]])
Finally, we'll add the y predictions to the predicted y residuals to get final y predictions:
test_data["y_counterfactual"] = test_data["y_hat"] + test_data["y_resid"]
test_data
"""
x1 x2 x3 t y_hat t_hat t_resid y_resid y_counterfactual
0 0.003099 -0.014613 1 -10.000000 1.55532 0.201618 -10.201618 -0.240874 1.314446
1 0.003099 -0.014613 1 -9.889447 1.55532 0.201618 -10.091065 -0.240874 1.314446
2 0.003099 -0.014613 1 -9.778894 1.55532 0.201618 -9.980512 -0.240874 1.314446
3 0.003099 -0.014613 1 -9.668342 1.55532 0.201618 -9.869960 -0.233701 1.321619
4 0.003099 -0.014613 1 -9.557789 1.55532 0.201618 -9.759407 -0.233701 1.321619
... ... ... .. ... ... ... ... ... ...
1795 0.003099 -0.014613 9 11.557789 -1.34144 1.441577 10.116212 -0.224248 -1.565688
1796 0.003099 -0.014613 9 11.668342 -1.34144 1.441577 10.226765 -0.230682 -1.572122
1797 0.003099 -0.014613 9 11.778894 -1.34144 1.441577 10.337317 -0.230682 -1.572122
1798 0.003099 -0.014613 9 11.889447 -1.34144 1.441577 10.447870 -0.230682 -1.572122
1799 0.003099 -0.014613 9 12.000000 -1.34144 1.441577 10.558423 -0.230682 -1.572122
"""
Now, we know x1 and x2 are fixed at their mean value, and x3 takes on 9 discrete values. At these values of x1 and x2, we can visualize how the predicted relationship between t and y varies at different levels of x3:

What we see here is that, at different values of x3, holding x1 and x2 fixed, the final model predicts different relationships between t and y. In other words, the model has picked up on nonlinear heterogeneous treatment effects. The quadratic relationship that we observed in the residuals varies depending on the nuisance parameters.
Here's what the predicted relationship between t and y looks like for the remaining values of x3:

So what do we make of all this? The main takeaway here is that final_model has learned a representation of a nonlinear causal relationship between t and y that varies depending on the nuisance parameters. This is possible because the nuisance parameters, x1, x2, and x3, form a sufficient adjustment set, and we were able to remove confounding bias from t and y with machine learning models.
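One way to quantify how the learned relationship differs across x3 is to differentiate each predicted curve numerically with respect to t, giving a local effect of t on y at every grid point. The sketch below assumes a DataFrame shaped like test_data with a y_counterfactual column. The helper name local_slopes is ours, and finite differencing is our choice of technique, not something the article prescribes.

```python
import numpy as np
import pandas as pd


def local_slopes(grid: pd.DataFrame, pred_col: str = "y_counterfactual") -> pd.DataFrame:
    """Finite-difference estimate of d(pred_col)/dt, computed separately for each x3 level."""
    out = grid.sort_values(["x3", "t"]).reset_index(drop=True)
    slopes = np.empty(len(out))
    for level in out["x3"].unique():
        mask = (out["x3"] == level).to_numpy()
        # np.gradient accepts the (possibly uneven) t coordinates directly;
        # interior points use central differences, endpoints one-sided differences
        slopes[mask] = np.gradient(
            out.loc[mask, pred_col].to_numpy(), out.loc[mask, "t"].to_numpy()
        )
    out["dy_dt"] = slopes
    return out
```

Because LightGBM predictions are step functions of t_resid, raw slopes computed this way will be mostly zero with occasional spikes; smoothing the curves first, or differencing over a wider t window, is one way to get more readable local effects.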
While this method can work for counterfactual predictions, it's difficult to evaluate in practice. Instead, what's commonly used is the R-learner from the double machine learning framework. Intuitively, we can think of the R-learner as an approximation of the derivative of f(T, X) with respect to T evaluated at a point (x, t). This is exactly the CATE, and we learn this by optimizing the following loss function:
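$$\hat{\tau} = \arg\min_{\tau} \frac{1}{n} \sum_{i=1}^{n} \Big[ \big(y_i - y_{ml}(x_i)\big) - \tau(x_i, t_i)\,\big(t_i - t_{ml}(x_i)\big) \Big]^2$$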
Here, t_ml and y_ml are predictions from the machine learning models we used to estimate T and Y as a function of X. τ(X, T) is the model we're optimizing; it predicts the CATE at a point (x_i, t_i). This is a slightly more complicated model to fit, but it's easier to evaluate, and it uses the same machine learning-derived residuals that we've used since the partial linear model.
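To see the mechanics of this loss, here is a minimal NumPy sketch under a simplifying assumption of ours: τ is linear in a feature vector φ(x, t), so minimizing the loss becomes ordinary least squares of the y residuals on φ(x_i, t_i) multiplied by the t residuals. The nuisance predictions t_ml and y_ml are taken as given (in practice, out-of-fold predictions like the LightGBM ones above). Practical R-learner implementations usually use a more flexible model for τ. The function name fit_linear_r_learner and the toy data are hypothetical.

```python
import numpy as np


def fit_linear_r_learner(t: np.ndarray, y: np.ndarray, t_ml: np.ndarray,
                         y_ml: np.ndarray, features: np.ndarray) -> np.ndarray:
    """Minimize sum_i [(y_i - y_ml_i) - tau_i * (t_i - t_ml_i)]^2 with tau_i = features_i @ beta."""
    t_res = t - t_ml
    y_res = y - y_ml
    # Each design row is phi(x_i, t_i) scaled by that observation's treatment residual
    design = features * t_res[:, None]
    beta, *_ = np.linalg.lstsq(design, y_res, rcond=None)
    return beta


# Toy check with a known CATE tau(x) = 1 + 2x and exactly known nuisance functions
rng = np.random.default_rng(0)
n = 5_000
x = rng.normal(size=n)
m_t = np.sin(x)                          # E[t | x]
t = m_t + rng.normal(size=n)
tau = 1.0 + 2.0 * x
g = x**2
y = tau * t + g + rng.normal(scale=0.1, size=n)
y_ml = tau * m_t + g                     # E[y | x]
phi = np.column_stack([np.ones(n), x])   # hypothetical feature map phi(x) = [1, x]
beta = fit_linear_r_learner(t, y, m_t, y_ml, phi)
print(beta)  # close to [1.0, 2.0]
```

Adding t (or functions of t) as columns of features lets τ vary with the treatment level as well, in the spirit of τ(X, T) above.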
The big picture here is that, whether we want to generate counterfactual predictions or estimate CATEs, we can take the ideas from the FWL theorem to generate residuals that account for nonlinear bias interfering with the causal relationship at hand.
Final Thoughts and Warnings
In this article, we explored the foundations of DML by building on the insights we get from the FWL theorem. We saw through examples with synthetic data that it is possible to uncover nonlinear causal relationships by replacing linear models with machine learning models to compute residuals.
However, we have to be cautious. The sophistication of DML and related techniques can be a double-edged sword. While these methods are powerful in adjusting for confounding and unveiling hidden causal relationships, they also require scrutiny of their assumptions and careful interpretation of their results. Overreliance on model outputs without understanding the underlying mechanics can lead to misleading conclusions.
Of course, the real world is messy, and it's very unlikely that we'll know how a sophisticated treatment affects an outcome in the presence of high-dimensional nuisance parameters. Machine learning is all about acquiring, understanding, and exploiting the right data, and causal machine learning is no exception. We have to have confidence that our nuisance parameters form a sufficient adjustment set, and we have to have a solid framework for evaluating causal machine learning models. With all of this in place, we can leverage the theoretical framework built from the FWL theorem to make impactful causal predictions.