SHAP explanations¶
SHAP Explanation of a simple NN regressor
TL;DR: SHAP decomposes the value of a prediction $f(x)$, assigning a portion of it to each feature $x_i$. This is great and helps us understand how the output value was composed. However, if the model wasn't causal in the first place, we cannot interpret those portions causally! In other words, they do not tell us how strongly each $x_i$ actually drives the outcome.
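Concretely, for a prediction $f(x)$ with features $x_1, \dots, x_d$, the SHAP values $\phi_i(x)$ together with the base value $\phi_0$ (the expected model output over the background data) satisfy the local-accuracy property $f(x) = \phi_0 + \sum_{i=1}^{d} \phi_i(x)$, i.e. the attributions sum back to the prediction.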
In [2]:
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
warnings.filterwarnings("ignore", category=ConvergenceWarning)
Generate synthetic data¶
In [3]:
np.random.seed(42)
n_samples = 200

X1 = np.random.uniform(0, 10, n_samples)
X2 = np.random.uniform(0, 10, n_samples)
X3 = np.random.uniform(0, 10, n_samples)

# Define the target with a non-linear relationship
def f(X1, X2, X3):
    return 5 * X1**2 - 10 * X2 + 5 * X3

Y = f(X1, X2, X3) + np.random.normal(0, 2, n_samples)
X = pd.DataFrame({"X1": X1, "X2": X2, "X3": X3})
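As a quick optional sanity check (a small sketch, not part of the original notebook), we can scatter $Y$ against each feature to see the quadratic dependence on $X_1$ and the linear dependence on $X_2$ and $X_3$:

# Optional sanity check: marginal relationship between Y and each feature
fig, axes = plt.subplots(1, 3, figsize=(12, 3), sharey=True)
for ax, name in zip(axes, ["X1", "X2", "X3"]):
    ax.scatter(X[name], Y, alpha=0.5, s=10)
    ax.set_xlabel(name)
axes[0].set_ylabel("Y")
plt.tight_layout()
plt.show()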
Train a simple NN regressor¶
In [4]:
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# max_iter=1 together with warm_start=True makes each fit() call run a single epoch,
# so we can record the train and test losses after every epoch
mlp = MLPRegressor(
    hidden_layer_sizes=(32, 16),
    activation="relu",
    max_iter=1,
    warm_start=True,
    random_state=42,
)

n_epochs = 10000
test_losses = []
train_losses = []

for epoch in range(n_epochs):
    mlp.fit(X_train, y_train)
    train_losses.append(mean_squared_error(y_train, mlp.predict(X_train)))
    test_losses.append(mean_squared_error(y_test, mlp.predict(X_test)))
    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1}, Test MSE: {test_losses[-1]:.4f}")

plt.plot(test_losses, label="Test Loss")
plt.plot(train_losses, label="Train Loss")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.title("Train and Test Loss Across Epochs")
plt.yscale("log")
plt.show()
Epoch 100, Test MSE: 38622.5712
Epoch 200, Test MSE: 29233.1952
Epoch 300, Test MSE: 17852.8761
...
Epoch 9900, Test MSE: 186.8037
Epoch 10000, Test MSE: 186.8036
(log truncated; the test MSE decreases steadily and plateaus around 186.80 from roughly epoch 6500 onward)
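As an aside, the manual epoch loop above is only needed because we want to record the losses after every epoch. A hedged alternative sketch (hypothetical, not part of the original notebook) would let scikit-learn handle the epochs and stop early on its own:

# Hypothetical alternative: a single fit() call with built-in early stopping
mlp_alt = MLPRegressor(
    hidden_layer_sizes=(32, 16),
    activation="relu",
    max_iter=10000,
    early_stopping=True,   # holds out 10% of the training data as a validation set
    n_iter_no_change=50,   # stop if the validation score does not improve for 50 epochs
    random_state=42,
)
mlp_alt.fit(X_train, y_train)
print(f"Stopped after {mlp_alt.n_iter_} epochs, "
      f"test MSE: {mean_squared_error(y_test, mlp_alt.predict(X_test)):.2f}")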
Assess performance on the test set¶
In [5]:
idxs = np.random.choice(len(X_test), size=40, replace=False)
X_sample = X_test.iloc[idxs]
y_true = f(X_sample["X1"], X_sample["X2"], X_sample["X3"])
y_pred = mlp.predict(X_sample)

plt.scatter(y_true, y_pred, alpha=0.7)
plt.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--', label='Ideal')
plt.xlabel("True Values")
plt.ylabel("Predicted Values")
plt.title("True vs Predicted Values")
plt.legend()
plt.grid()
plt.show()
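To complement the scatter plot with numbers, a small additional check (not part of the original notebook) reports the test-set RMSE and $R^2$ using scikit-learn:

# Hypothetical extra check: quantitative test-set metrics
from sklearn.metrics import r2_score

y_test_pred = mlp.predict(X_test)
print(f"Test RMSE: {mean_squared_error(y_test, y_test_pred) ** 0.5:.2f}")
print(f"Test R^2:  {r2_score(y_test, y_test_pred):.3f}")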
Explain the model with SHAP¶
In [6]:
# Create a SHAP explainer around the model's predict function;
# with a generic callable, SHAP uses a model-agnostic explainer and X serves as the background data
explainer = shap.Explainer(mlp.predict, X)
shap_values = explainer(X)
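The call returns a shap.Explanation object. As a quick optional inspection (not part of the original notebook), we can look at its shape and its main fields, .values (the per-feature attributions) and .base_values (the expected model output used as the baseline):

# Optional: inspect the Explanation object
print(shap_values.shape)            # (n_samples, n_features)
print(shap_values.values[:2])       # SHAP values of the first two samples
print(shap_values.base_values[:2])  # baseline: expected prediction over the background data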
In [ ]:
# Explain a single prediction (sample i = 20)
i = 20
print(shap_values[i].data)
shap.plots.waterfall(shap_values[i])

shap.initjs()
shap.force_plot(shap_values.base_values[i], shap_values.values[i], feature_names=['X1', 'X2', 'X3'])
[6.11852895 6.57612892 7.91579044]
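By SHAP's local-accuracy property, the base value plus the three attributions should reproduce the model's prediction for this sample, up to the explainer's approximation error. A small verification sketch, not part of the original notebook:

# Check local accuracy: base value + sum of SHAP values ~ model prediction
reconstructed = shap_values.base_values[i] + shap_values.values[i].sum()
predicted = mlp.predict(X.iloc[[i]])[0]
print(f"base + sum(SHAP) = {reconstructed:.2f}, model prediction = {predicted:.2f}")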
In [10]:
# Explain all of the examples
shap.summary_plot(shap_values, X)
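A complementary global view (an optional sketch, not part of the original notebook) is the bar plot of mean absolute SHAP values, which ranks the features by their average impact on the model output:

# Optional: rank features by their mean |SHAP value|
shap.plots.bar(shap_values)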
Interpretability pitfalls¶
TBW
Conclusions¶
TBW