Exploring Counterfactual Insights: From Correlation to Causation in Data Analysis
Picture this: A world where the sky takes on a serene shade of lemon yellow, birds have come to their senses and eloquently converse in fluent English, and where fruit trees defy gravity, displaying their deep lilac and electric purple leaves while offering up the most delectable fruits at your every whim.
And you think: finally, the world makes sense!
Well hello there!
Let's step back into reality, but fear not, for we're about to embark on a journey that's just as captivating: the world of counterfactuals. While our initial vision may be a delightful flight of fancy, counterfactuals open the door to a different kind of wonder, one that allows us to explore the 'what ifs' of our own world.
The term 'counterfactual' might initially sound complex, but it simply means considering scenarios that are contrary to what actually happened. Although the term itself was coined in 1946, the idea dates back centuries, to when humans first began pondering 'what if' scenarios.
In psychology, counterfactual thinking is frequently used to delve into scenarios that are different from events that have already occurred. For instance, we might ponder whether a criminal would have chosen a different life path if offered alternative opportunities.
However, as data scientists, our focus isn't on the intricacies of criminal psychology. Instead, we're interested in harnessing the power of counterfactuals in the world of data. We aim to understand why data appears as it does and how to draw meaningful inferences from it. Our realm is one of equations, objective methods, and practical applications of these concepts in the world of data.
To explore this intriguing and imaginative concept further, let's dive into an example. But first, we'll provide some context…
Setting the stage with an intriguing scenario
Picture a scenario where you're a data scientist working for MM Securities, a fictitious security firm specializing in assessing system vulnerabilities. The firm is in the midst of securing a substantial client contract, but an important challenge has emerged. The client has a unique requirement: they want to know whether the vulnerabilities MM Securities assesses are one of the causes of ransomware attacks. If MM Securities can convincingly demonstrate that these vulnerabilities contribute to ransomware attacks, the client will eagerly engage in business.
This situation has piqued the interest of MM Securities' senior leadership, who believe that demonstrating a causal link can bring significant value to the organization. They turn to their amazing data science team for answers, with the following hypothesis:
"Organizations with the specific vulnerabilities we assess are at a heightened risk of falling victim to ransomware attacks."
Thankfully, MM Securities has a history of successfully tackling such challenges and possesses a relevant dataset for this particular issue. This is your moment to shine, to dive deep into the heart of the matter, and uncover the underlying causal relationship. Given your expertise in this domain, your investigative journey begins.
Making the first causal assumption…
Having started your investigation, you begin with the simple assumption that vulnerabilities have a direct impact on ransomware attacks.
To break it down more clearly:
The Independent Variable or the Suspected Cause: Vulnerabilities in a system
The Dependent Variable or the Suspected Effect: A ransomware attack
Now that we've got our data, our hypothesis, and our variables all lined up, it's time to put our theory to the test.
Python Code to Spice up our Analysis
We'll begin by constructing a straightforward Bayesian network using the pgmpy library. For this demonstration, we're also going to generate some synthetic data.
Our synthetic data reflects a pretty even split, with roughly 50% of the instances having vulnerabilities in their systems. We also design it to have some positive correlation between the presence of vulnerabilities and the occurrence of ransomware attacks.
With that, let's jump right into the code.
#Importing packages
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
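num_samples = 500
# Base probability of a ransomware attack when no vulnerability is present.
# The original snippet used `prob` without defining it; 0.2 is an assumed value,
# which makes P(Ransomware = 1 | Vulnerability = 1) = 2 * prob = 0.4.
prob = 0.2
# Let's generate synthetic data for vulnerabilities (binary: 0 or 1)
vulnerabilities = np.random.choice([0, 1], num_samples)
# An attack occurs with probability prob when V = 0 and 2 * prob when V = 1
ransomware = [1 if np.random.uniform(0, 1) < prob * (v + 1) else 0 for v in vulnerabilities]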
data = pd.DataFrame({
'Vulnerability': vulnerabilities,
'Ransomware': ransomware
})
# Designing the Network
model = BayesianNetwork([('Vulnerability', 'Ransomware')])
model.fit(data,estimator=MaximumLikelihoodEstimator)
# Check model for early errors
assert model.check_model()
inference = VariableElimination(model)
# Calculating marginal probabilities
prob_vulnerability = inference.query(variables=['Vulnerability']).values
prob_ransomware = inference.query(variables=['Ransomware']).values
# Calculating conditional probabilities i.e. P(Ransomware | Vulnerability) 
evidence_vulnerability = {'Vulnerability': 1}
prob_ransomware_with_vulnerability = inference.query(variables=['Ransomware'], evidence=evidence_vulnerability)
evidence_no_vulnerability = {'Vulnerability': 0}
prob_ransomware_without_vulnerability = inference.query(variables=['Ransomware'], evidence=evidence_no_vulnerability)
# Visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))
# Subplot 1: Data Distribution of Instances with and without vulnerabilities
ax1.pie(
    prob_vulnerability, 
    labels = ['Data Points without Vulnerabilities', 'Data Points with Vulnerabilities'], 
    autopct = '%1.1f%%', colors = ['gold', 'tomato'],
    explode = (0.0, 0.1))
ax1.set_title('Data Distribution for Vulnerabilities')
# Subplot 2: Data Distribution of Instances with and without Ransomware
ax2.pie(
    prob_ransomware, 
    labels = ['Data Points without Ransomware', 'Data Points with Ransomware'], 
    autopct = '%1.1f%%', colors = ['gold', 'tomato'],
    explode = (0.0, 0.1))
ax2.set_title('Data Distribution for Ransomware')
# Subplot 3: NetworkX Showing assumed relationship between the variables
G = nx.DiGraph()
G.add_node('Vulnerability', pos=(-0.8, 0.0))
G.add_node('Ransomware', pos=(0.8, 0.0))
G.add_edge('Vulnerability', 'Ransomware', label='Direct Effect')
node_pos = nx.get_node_attributes(G, 'pos')
edge_labels = nx.get_edge_attributes(G, 'label')
# Use the same fixed positions for the nodes, node labels, and edge labels
nx.draw_networkx_edge_labels(G, node_pos,
                             edge_labels = edge_labels,
                             verticalalignment = 'top',
                             font_size = 12, ax = ax3)
nx.draw_networkx_labels(G, node_pos,
                        verticalalignment = 'bottom',
                        font_size = 12, ax = ax3)
nx.draw(G, node_pos, with_labels = False, node_size = 2500, node_color = 'tab:olive', ax = ax3)
ax3.set_title('Our Data Relationship Assumption Model')
# Subplot 4: Bar chart for conditional probability of Ransomware given Vulnerability
values = [prob_ransomware_without_vulnerability.values[0],
          prob_ransomware_without_vulnerability.values[1],
          prob_ransomware_with_vulnerability.values[0],
          prob_ransomware_with_vulnerability.values[1]]
labels = ['P(R = 0|V = 0)',  # P(No Ransomware Attack | No Vulnerabilities)
          'P(R = 1|V = 0)',  # P(Ransomware Attack | No Vulnerabilities)
          'P(R = 0|V = 1)',  # P(No Ransomware Attack | Vulnerabilities)
          'P(R = 1|V = 1)']  # P(Ransomware Attack | Vulnerabilities)
ax4.bar(labels, values, color = ['gold', 'tab:olive', 'lightcoral', 'tomato'])
for i, value in enumerate(values):
    ax4.annotate(f'{value:.2f}', (i, value), ha='center', va='bottom', fontsize=12)
ax4.set_xlabel('Conditional event')
ax4.set_ylabel('Probability')
ax4.set_title('Conditional Probability of Ransomware given Vulnerability')
# Rotate the existing tick labels without replacing the tick formatter
ax4.tick_params(axis='x', labelrotation=45)
plt.tight_layout()
plt.show()
Now, let's delve into the visualizations to glean deeper insights. We plot the initial data distributions with matplotlib and our assumed graphical model with the NetworkX library. These visualizations reveal a clear correlation between the presence of vulnerabilities and the occurrence of ransomware attacks.
In the figure below, we illustrate this correlation by displaying the conditional probability of ransomware given vulnerability, i.e.
P(Ransomware | Vulnerability)
Before we go further, let's take a moment to understand conditional probability.
Conditional Probability
Conditional probability, denoted P(X|Y), is simply the probability that X occurs given that Y has occurred. It's important to note that conditional probability doesn't imply causation or an order of events; it only describes how the two are associated.
In general, P(X|Y) ≠ P(Y|X), but neither says anything about causality: each is just computed over a different subset of instances (those where Y holds versus those where X holds).
Now that we've clarified this fundamental concept, you might recall the famous saying from a statistics class:
Correlation is not causation.
Therefore, our current analysis, which has revealed a correlation between vulnerability presence and ransomware attacks, isn't sufficient to prove causation.
So, what is causation, how do we define it, and what does it have to do with counterfactuals?
In data science, discussing counterfactuals often intertwines with causality, interventions, and model interpretability. Up to this point, we've merely examined an existing dataset to identify correlation, but we haven't determined whether this correlation implies causation.
To explore causation, we start with counterfactual analysis. Let's consider a possible counterfactual scenario and test it:
Possible Counterfactual: Would there have been a ransomware attack had there been no vulnerability?
Possible Counterfactual Statement: The organization would not have been affected by ransomware had there been no vulnerabilities.
To establish causation, we need to investigate whether removing vulnerabilities eliminates, or at least reduces, the likelihood of ransomware. In theory this is simple to state, but in practice it often requires extensive time and complex data collection efforts.
Before we proceed with the code, let's discuss a critical concept that connects our technical analysis to the broader notion of causation: the 'do' operator. Understanding how this operator works is essential for testing our counterfactual assumptions.
The 'do' Calculus and Intervention
In causal inference, interventions are represented mathematically with the 'do' operator. P(Y | do(X = x)) denotes the distribution of Y when X is actively set to x, which in general differs from the observational P(Y | X = x), the distribution of Y among cases where X merely happens to equal x.
In our scenario, the intervention is to treat vulnerabilities and observe the impact on ransomware attacks. Unlike passive observation, an intervention actively changes something, which introduces an order of events: we first change Vulnerability and then look at what happens to Ransomware. This lets us assess the effect that changes to Vulnerability have on Ransomware.
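To see why P(Y | do(X)) can differ from P(Y | X), here is a minimal sketch with a hypothetical confounder Z that influences both X and Y (graph Z → X, Z → Y, X → Y, all binary). Every probability below is invented for illustration. Conditioning uses Bayes' rule over the joint distribution, while intervening uses the back-door adjustment P(y | do(x)) = Σ_z P(z) P(y | x, z), which is what you get by deleting the Z → X edge.

```python
# Hypothetical confounded model: Z -> X, Z -> Y, X -> Y (all binary).
# All numbers are invented for illustration only.
p_z = {0: 0.5, 1: 0.5}                      # P(Z = z)
p_x1_given_z = {0: 0.2, 1: 0.8}             # P(X = 1 | Z = z)
p_y1_given_xz = {(0, 0): 0.1, (1, 0): 0.2,  # P(Y = 1 | X = x, Z = z)
                 (0, 1): 0.5, (1, 1): 0.6}

def p_x_given_z(x, z):
    return p_x1_given_z[z] if x == 1 else 1 - p_x1_given_z[z]

def p_y1_given_x(x):
    """Observational P(Y = 1 | X = x): Z is weighted by P(z | x)."""
    joint = sum(p_z[z] * p_x_given_z(x, z) * p_y1_given_xz[(x, z)] for z in (0, 1))
    marginal_x = sum(p_z[z] * p_x_given_z(x, z) for z in (0, 1))
    return joint / marginal_x

def p_y1_do_x(x):
    """Interventional P(Y = 1 | do(X = x)): the Z -> X edge is cut, so Z keeps its prior."""
    return sum(p_z[z] * p_y1_given_xz[(x, z)] for z in (0, 1))

print(round(p_y1_given_x(1), 3))  # 0.52
print(round(p_y1_do_x(1), 3))     # 0.4
```

The gap arises because observing X = 1 also tells us Z is probably 1, which raises Y on its own; forcing X = 1 carries no such information. In the article's two-variable graph (Vulnerability → Ransomware) Vulnerability has no modeled causes, so the two quantities coincide there.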
However, it's important to acknowledge that in many real-world situations, conducting such interventions is impractical or impossible, due to various constraints and ethical considerations.
Now that we have clarified this concept, let's move on to the actual analysis.
Counterfactual Analysis in Python
To conduct the intervention in our model, we introduce an additional variable called 'treatment'. In this context, treatment means remediation of the vulnerabilities in a system. Here we assume that MM Securities takes proactive steps to address vulnerabilities in its customers' systems.
Specifically, we assume that MM Securities applies this treatment to about 60% of the organizations it serves, independently of any other factor, and that the treatment successfully remediates vulnerabilities about 90% of the time. If the treatment is not applied, nothing changes, and the distribution of vulnerabilities remains the same as in our initial data.
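The treatment mechanism just described can be written as a small data-generating process. This is a sketch under the stated assumptions (60% treated, 90% remediation when treated, and an untreated vulnerability rate of 0.48 taken from our initial data); the function `simulate_treatment` and its parameter names are hypothetical and not part of pgmpy.

```python
import random

def simulate_treatment(n, p_treat=0.6, p_fixed=0.9, p_vuln_untreated=0.48, seed=0):
    """Hypothetical generator of (treatment, vulnerability) pairs under the assumptions above."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        treated = int(rng.random() < p_treat)
        if treated:
            # Treatment leaves a vulnerability behind only (1 - p_fixed) of the time
            vulnerable = int(rng.random() >= p_fixed)
        else:
            # Untreated organizations keep the original vulnerability rate
            vulnerable = int(rng.random() < p_vuln_untreated)
        rows.append((treated, vulnerable))
    return rows

rows = simulate_treatment(100_000)
treated = [v for t, v in rows if t == 1]
untreated = [v for t, v in rows if t == 0]
print(f"P(T=1)       ~ {len(treated) / len(rows):.3f}")          # close to 0.6
print(f"P(V=1 | T=1) ~ {sum(treated) / len(treated):.3f}")       # close to 0.1
print(f"P(V=1 | T=0) ~ {sum(untreated) / len(untreated):.3f}")   # close to 0.48
```

Recovering these rates from simulated rows is a quick sanity check that the CPD values we hand to pgmpy encode the mechanism we intend.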
It's worth noting that this modeling process demands careful consideration of data, a deep understanding of the system, and domain expertise, which is often necessary when working with graphical models.
For the next step in our analysis, we can code this model in Python using these known probabilities. Fortunately, pgmpy offers a TabularCPD class, where CPD stands for Conditional Probability Distribution.
Our new scenario can be summarized as follows:
- We introduce a 'treatment' variable that directly affects vulnerabilities.
- 40% of instances do not receive the treatment, P(Treatment = 0) = 0.4, while 60% do, P(Treatment = 1) = 0.6.
- When no treatment is applied, the distribution of vulnerabilities is unaffected and can be taken from our initial data: P(Vulnerability = 0 | do(Treatment = 0)) = 0.52 and P(Vulnerability = 1 | do(Treatment = 0)) = 0.48.
- When treatment is applied, regardless of whether the instance originally had vulnerabilities, 90% of instances end up with zero vulnerabilities, P(Vulnerability = 0 | do(Treatment = 1)) = 0.9, while 10% still have vulnerabilities, P(Vulnerability = 1 | do(Treatment = 1)) = 0.1.
We also have prior knowledge of the conditional probabilities for ransomware and vulnerabilities from our previous data, which we incorporate into our analysis:
- P(Ransomware = 0 | Vulnerability = 0) = 0.80
- P(Ransomware = 0 | Vulnerability = 1) = 0.56
- P(Ransomware = 1 | Vulnerability = 0) = 0.20
- P(Ransomware = 1 | Vulnerability = 1) = 0.44
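These numbers are enough to work out the effect by hand. Assuming, as our graph states, that Treatment affects Ransomware only through Vulnerability, the attack probability under each treatment setting is a weighted sum over the two vulnerability states (T, V and R abbreviate Treatment, Vulnerability and Ransomware):

```latex
\begin{aligned}
P(R=1 \mid do(T=t)) &= \sum_{v \in \{0,1\}} P(V=v \mid do(T=t))\, P(R=1 \mid V=v) \\
P(R=1 \mid do(T=0)) &= 0.52 \times 0.20 + 0.48 \times 0.44 = 0.104 + 0.2112 = 0.3152 \\
P(R=1 \mid do(T=1)) &= 0.90 \times 0.20 + 0.10 \times 0.44 = 0.180 + 0.044 = 0.224 \\
\text{Total effect} &= P(R=1 \mid do(T=1)) - P(R=1 \mid do(T=0)) = 0.224 - 0.3152 = -0.0912
\end{aligned}
```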
With this information, we can now proceed to analyze the effects of our intervention.
So let's jump right into the Python implementation:
from pgmpy.models import BayesianNetwork
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import VariableElimination
import matplotlib.pyplot as plt
import networkx as nx
# Define Conditional Probability Distributions (CPDs)
cpd_treatment = TabularCPD(variable='Treatment', variable_card=2, values=[[0.4], [0.6]])
cpd_vulnerability = TabularCPD(variable='Vulnerability', variable_card=2, values=[[0.52, 0.9], [0.48, 0.1]],
                            evidence=['Treatment'], evidence_card=[2])
cpd_ransomware = TabularCPD(variable='Ransomware', variable_card=2, values=[[0.80, 0.56], [0.20, 0.44]],
                            evidence=['Vulnerability'], evidence_card=[2])
# Create Network and add CPDs to the model
model = BayesianNetwork([('Treatment','Vulnerability'),('Vulnerability', 'Ransomware')])
model.add_cpds(cpd_treatment ,cpd_vulnerability, cpd_ransomware)
# Check model consistency
assert model.check_model()
inference = VariableElimination(model)
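# Calculating the total effect of treatment.
# Treatment has no parents in this graph, so conditioning on it (evidence)
# gives the same distribution as intervening on it with do(Treatment).
evidence_treatment = {'Treatment': 1}
prob_treatment = inference.query(variables=['Ransomware'], evidence=evidence_treatment)
evidence_no_treatment = {'Treatment': 0}
prob_no_treatment = inference.query(variables=['Ransomware'], evidence=evidence_no_treatment)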
total_effect = prob_treatment.values[1] - prob_no_treatment.values[1]
print('Total Effect of the treatment: ', total_effect)
# Creating Visualization
# 2 Subplots
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
# Subplot 1: NetworkX Showing treatment on the assumed relationship between the variables
G = nx.DiGraph()
G.add_nodes_from(['Treatment', 'Vulnerability', 'Ransomware'])
pos = {
    'Treatment': (0, 1),
    'Vulnerability': (1, 0),
    'Ransomware': (2, 1),
}
G.add_edge('Vulnerability', 'Ransomware', label='Direct Effect')
G.add_edge('Treatment', 'Vulnerability', label='Treatment')
edge_labels = nx.get_edge_attributes(G, 'label')
nx.draw_networkx_edge_labels(G, pos, 
                             edge_labels = edge_labels, 
                             verticalalignment = 'top',
                             font_size = 12, ax = ax1)
nx.draw(G, pos, with_labels = True, 
        node_size=2500, node_color='tab:olive',
        arrowstyle="-|>,head_width=0.5,head_length=1", ax=ax1)
ax1.set_title('Our Data Relationship Assumption Model')
# Subplot 2: Bar chart for ransomware under do(No Treatment) and do(Treatment)
# (total_effect was already computed above)
values = [prob_no_treatment.values[1],
          prob_treatment.values[1],
          total_effect]
labels = ['E(Ransomware|do(No Treatment))',
          'E(Ransomware|do(Treatment))',
          'Total Effect']
ax2.bar(
    labels,
    values, 
    color = ['gold', 'lightcoral', 'tomato']
)
for i, value in enumerate(values):
    ax2.annotate(f'{value:.2f}', (i, value), ha='center', va='bottom', fontsize=12)
ax2.set_xlabel('Scenario')
ax2.set_ylabel('Expectation of Ransomware')
ax2.set_title('Effect of treatment on ransomware')
# Apply the layout before saving so the saved file matches what is shown
plt.tight_layout()
plt.savefig('treated_plots.png', dpi=300, bbox_inches='tight')
plt.show()
Our primary objective in conducting this counterfactual analysis was to determine whether the absence of vulnerabilities would reduce ransomware cases. To assess this, we measure the total effect of treating vulnerabilities on ransomware incidents.
To measure this effect, we compute the difference in the expected value of Ransomware (for a binary variable, simply the probability of an attack) between two scenarios: one where treatment is actively administered and one where it is not.
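One subtlety: the pgmpy code above queries with evidence (conditioning), not with an explicit intervention. That is valid here because Treatment has no parents in our graph, so P(Ransomware | Treatment = t) equals P(Ransomware | do(Treatment = t)). The sketch below checks this by brute-force enumeration, using plain Python dictionaries that mirror the three CPDs; it is an illustrative re-implementation, not a pgmpy call, and the helper names are hypothetical.

```python
from itertools import product

# Plain-dict copies of the article's CPDs
p_t = {0: 0.4, 1: 0.6}                                          # P(T = t)
p_v_given_t = {0: {0: 0.52, 1: 0.48}, 1: {0: 0.9, 1: 0.1}}      # P(V = v | T = t)
p_r_given_v = {0: {0: 0.80, 1: 0.20}, 1: {0: 0.56, 1: 0.44}}    # P(R = r | V = v)

def joint(t, v, r, do_t=None):
    """Joint probability; under do(T = do_t) the factor P(T) becomes an indicator."""
    p_treat = p_t[t] if do_t is None else float(t == do_t)
    return p_treat * p_v_given_t[t][v] * p_r_given_v[v][r]

def p_attack_given_t(t):
    """Observational P(R = 1 | T = t), computed from the joint by Bayes' rule."""
    num = sum(joint(t, v, 1) for v in (0, 1))
    den = sum(joint(t, v, r) for v, r in product((0, 1), repeat=2))
    return num / den

def p_attack_do_t(t):
    """Interventional P(R = 1 | do(T = t)) via the truncated factorization."""
    return sum(joint(tt, v, 1, do_t=t) for tt, v in product((0, 1), repeat=2))

for t in (0, 1):
    print(t, round(p_attack_given_t(t), 4), round(p_attack_do_t(t), 4))
# 0 0.3152 0.3152
# 1 0.224 0.224
print("Total effect:", round(p_attack_do_t(1) - p_attack_do_t(0), 4))  # -0.0912
```

If Treatment had a parent (for example, if MM Securities chose whom to treat based on some risk factor), the two columns would generally differ and the evidence queries would no longer measure the intervention.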
With the probabilities above, the total effect is negative (about -0.09), as the figure shows: applying the treatment lowers the probability of a ransomware attack from roughly 0.32 without treatment to roughly 0.22 with it.
Because this result follows directly from the probabilities we assumed for the model, it doesn't conclusively prove causation. It does, however, suggest that vulnerability is a factor influencing ransomware attacks and that treating vulnerabilities is likely to reduce such attacks.
Counterfactuals and Their Limitations
As we move toward the end, it's worth emphasizing again that correlation does not equal causation. Counterfactuals help establish causation by creating controlled comparisons between scenarios with and without an intervention or treatment, and so they can help us evaluate causal relationships.
Counterfactuals can be thought of as a nifty little way of evaluating the bubbles of causal happiness. In data science, that can be an extremely helpful thing to do.
Still, counterfactuals alone may not be enough to prove causation, so it's essential to keep several considerations in mind:
- When talking about causality, counterfactual analysis relies heavily on the quality of data and the validity of assumptions made during modeling.
- It is often not viable to conduct counterfactual analysis in the real world, for practical or ethical reasons.
- We haven't addressed confounding variables, which may be what actually drives an observed relationship. Counterfactual analysis as done here isn't enough for those relationships; they call for more complete tools such as Structural Causal Models.
- Similar to other statistical analyses, counterfactual analysis should be evaluated for statistical significance to ensure that the observed effects are not due to chance.
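As one way to check significance, here is a minimal permutation-test sketch for the difference in attack rates between treated and untreated organizations. The outcomes are simulated from the model's attack probabilities (0.224 treated, 0.3152 untreated) purely for illustration; the group sizes, seeds, and function name are hypothetical choices.

```python
import random

def permutation_test(treated, untreated, n_perm=2_000, seed=0):
    """Two-sided permutation test for a difference in means of 0/1 outcomes."""
    rng = random.Random(seed)
    observed = sum(treated) / len(treated) - sum(untreated) / len(untreated)
    pooled = list(treated) + list(untreated)
    n_t = len(treated)
    extreme = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)  # break any link between group label and outcome
        diff = sum(pooled[:n_t]) / n_t - sum(pooled[n_t:]) / (len(pooled) - n_t)
        if abs(diff) >= abs(observed):
            extreme += 1
    # The add-one correction keeps the p-value strictly positive
    return observed, (extreme + 1) / (n_perm + 1)

# Simulated outcomes (hypothetical sample sizes) using the model's attack probabilities
sim = random.Random(42)
treated = [int(sim.random() < 0.224) for _ in range(300)]
untreated = [int(sim.random() < 0.3152) for _ in range(200)]

effect, p_value = permutation_test(treated, untreated)
print(f"observed difference = {effect:.3f}, p-value = {p_value:.4f}")
```

A small p-value says the observed gap would be unusual if treatment had no association with attacks; on its own it does not rule out confounding.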
Let's talk about some of the other applications of counterfactuals…
In this article we only looked at counterfactuals in causal inference, but the topic is far too vast to cover in one post. Counterfactuals are also used in model interpretation, risk minimization, A/B testing, bias detection in models, and more.
Wrapping Up
This post focused on the role of counterfactuals in causal inference, but keep in mind that the nuances of causation and counterfactual analysis often depend heavily on human judgment and on the accurate interpretation of domain knowledge and data.
Nevertheless, thinking counterfactually should be a routine part of any hypothesis testing.
Now that we've scratched the surface of this complex and multifaceted topic, you might be eager to explore it further. While we can't cover everything in a single post, I've compiled additional resources to satisfy your curiosity and delve deeper into the world of counterfactuals.
Other fantastic resources on Counterfactuals…
Building counterfactuals for sklearn models
GitHub – MaheepChaudhary/Causation-inComputerVision: The repository contains lists of papers on…
Don't forget to read some of my other intriguing articles!
P-Values: Understanding Statistical Significance in Plain Language
Beyond Bar Charts: Data with Sankey, Circular Packing, and Network Graphs
Feel free to share your thoughts in the comments.

