In my previous post on Machine Learning for Churn prediction, I showed that machine learning models (ML, also known as AI) are the most accurate. But in the past, machine learning models were hard to explain. In my book, Fighting Churn With Data, I wrote that you should use Logistic Regression to explain churn, even though I also showed that the machine learning model XGBoost gave higher accuracy. I recently learned about a great new technique called SHAP, which makes machine learning and AI models more explainable than they have ever been! SHAP stands for SHapley Additive exPlanations. In this post I will demonstrate using SHAP to achieve explainable machine learning churn prediction.
SHAP analyzes how much each metric (or "feature" or "independent variable") contributes to a machine learning prediction, treating each variable as if it were a player on a team whose share of the team's result (the prediction) is being measured. The analysis depends on comparing the model's predictions with and without each feature, across every subset of the other features. It relies on clever algorithms that solve the problem exactly for tree models and approximate it for other models. There are many good articles explaining SHAP, so I will focus on the application to churn. I recommend:
- This blog post by Dr. Dataman on Medium, which explains the general theory of SHAP values.
- The documentation of the SHAP Python package.
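To make the "players on a team" idea concrete, here is a minimal brute-force sketch of the Shapley value calculation for a single example. It assumes a simplified setup: a feature that is "off the team" is replaced by one fixed baseline value, whereas the SHAP package averages over a background data set and uses much faster algorithms. The function name shapley_values and the toy model are hypothetical, for illustration only.

```python
from itertools import combinations
from math import factorial


def shapley_values(predict, x, baseline):
    """Exact Shapley values for one example, by enumerating every feature subset.

    predict: function taking a list of feature values and returning a number.
    x: the example to explain.
    baseline: hypothetical reference values used for features "off the team".
    """
    n = len(x)

    def value(subset):
        # Features in the subset keep the example's values; the rest use the baseline.
        z = [x[i] if i in subset else baseline[i] for i in range(n)]
        return predict(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for a coalition of this size that excludes feature i.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                # Marginal contribution of feature i when it joins coalition s.
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Toy model with an interaction between the first and third features.
toy_model = lambda z: 0.5 * z[0] - 2.0 * z[1] + z[0] * z[2]
print(shapley_values(toy_model, [2.0, 1.0, 3.0], [0.0, 0.0, 0.0]))
```

For this toy model at the example (2, 1, 3) with a baseline of zeros, the contributions should come out close to [4.0, -2.0, 3.0]: the interaction term is split evenly between the two features that share it, and the contributions add up to the change in prediction from the baseline (5.0). The loop over all subsets grows exponentially with the number of features, which is why the clever algorithms mentioned above matter.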
The SHAP Explanation for XGBoost Machine Learning Churn Prediction
Below is an example of the output of SHAP for the churn case study from the book, Fighting Churn With Data. This gives a visual illustration of the contribution of each metric in the model.
- The features are listed on the left: each row shows the results for one feature.
- The colored bars are made of stacked dots, one per example, each showing the contribution of the feature to that example's prediction.
- The vertical line in the center marks the baseline model prediction. If a point is to the left of the line, the feature reduces the model's prediction for that example. If a point is to the right of the line, the feature increases the model's prediction for that example.
- The color of each point shows whether the feature's value in that example is high (red) or low (blue).
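The baseline line works because SHAP explanations are additive: the baseline plus all of an example's contributions adds back up to the model's output for that example. A minimal sketch, assuming the baseline and contributions are in log-odds units (as far as I know, that is what SHAP reports by default for a binary XGBoost classifier, rather than probabilities). The function name and the numbers are made up for illustration.

```python
import math


def explain_to_probability(base_value, shap_values):
    """Rebuild one prediction from a SHAP explanation.

    Assumes base_value and shap_values are in log-odds units (an assumption
    about the explainer's output scale, not something checked here).
    """
    # Additivity: the baseline plus every feature's contribution.
    log_odds = base_value + sum(shap_values)
    # Logistic (sigmoid) link turns log-odds into a probability.
    return 1.0 / (1.0 + math.exp(-log_odds))


# Made-up baseline and contributions for a single customer.
print(explain_to_probability(-1.5, [-0.8, -0.4, 0.6, 0.1]))
```

Here the contributions add up to log-odds of -2.0, which is a churn probability of about 0.12. Negative contributions correspond to points to the left of the line, positive ones to points on the right.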
This is a model for churn probability, which is the model output. Here is a summary of the key findings:
- Message per month and New friend per month are the two most influential variables. High values of these metrics are associated with a lower churn probability.
- Adview per post is associated with a higher churn risk forecast for examples with high values on this metric. But this metric is less influential than Message per month and New friend per month: you can see this because its stacked dots never extend as far from the 0.0 influence line as they do for Message per month and New friend per month. Dislike percent is similar, but somewhat less influential.
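The ordering of the rows reflects how far the dots spread from the 0.0 line: the SHAP summary plot sorts features by the total absolute size of their contributions. A minimal sketch of that ranking using the mean absolute SHAP value; the function name, the metric names and the numbers are made up for illustration and are not results from the case study.

```python
def rank_features(shap_matrix, feature_names):
    """Order features by mean absolute SHAP value, largest first.

    shap_matrix: one row per example, one column per feature.
    feature_names: the column labels, in the same order as the columns.
    """
    n_rows = len(shap_matrix)
    scores = {}
    for j, name in enumerate(feature_names):
        # Average size of this feature's contribution, ignoring its sign.
        scores[name] = sum(abs(row[j]) for row in shap_matrix) / n_rows
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Made-up SHAP values for three customers and two hypothetical metrics.
toy_shap = [[-0.9, 0.2], [0.7, -0.1], [-0.8, 0.3]]
print(rank_features(toy_shap, ["message_per_month", "adview_per_post"]))
```

Taking the absolute value is what makes a metric with large contributions in both directions rank high; the sign information is deliberately dropped here, so the color pattern in the plot is still needed to read the direction.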
Why this is an improvement over previous methods
If you have ever used other methods to look at the importance and influence of features in XGBoost, you will recognize what a big advance this represents! (The older methods were equally limited for other ensembles of decision trees, like Random Forest.) You can try this yourself by calling the function xgboost.plot_importance(xgb_model) from the xgboost module. That older approach will tell you which features were the most important, according to one of a few different definitions of importance. But those methods do not give you a sense of the direction or magnitude of each variable's influence in the model. Understanding the direction and magnitude of influence is key to explainability, so SHAP is a big improvement!
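One way to see the difference: the built-in importance scores are never negative, so they cannot tell you whether high values of a metric raise or lower churn risk. With SHAP values you can read the direction off directly, for example by correlating each metric's values with its SHAP values. This is a hypothetical helper, a sketch rather than anything provided by the xgboost or shap packages.

```python
import math


def influence_direction(feature_values, shap_values):
    """Pearson correlation between a feature's values and its SHAP values.

    Positive: high values of the feature tend to push the prediction up
    (red dots to the right of the line). Negative: they push it down.
    """
    n = len(feature_values)
    mean_x = sum(feature_values) / n
    mean_s = sum(shap_values) / n
    cov = sum((x - mean_x) * (s - mean_s) for x, s in zip(feature_values, shap_values))
    var_x = sum((x - mean_x) ** 2 for x in feature_values)
    var_s = sum((s - mean_s) ** 2 for s in shap_values)
    if var_x == 0 or var_s == 0:
        return 0.0  # No variation, so no direction can be read off.
    return cov / math.sqrt(var_x * var_s)


# Made-up values: higher metric values come with lower (more negative) SHAP values.
print(influence_direction([1, 2, 3, 4], [0.3, 0.1, -0.1, -0.3]))
```

A result near -1, as in this made-up case, corresponds to the pattern the post describes for Message per month: high values lower the churn forecast. Correlation only summarizes a roughly monotonic trend, so a metric whose effect reverses partway along its range is better judged from the plot itself.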
Code for SHAP Explainable XGBoost Machine Learning Churn Prediction
Now let’s look at the code. You can find this example in the Fighting Churn With Data Github repo. The code assumes there is already a saved XGBoost model and a data set. The data is necessary because the explanation runs on actual examples. But the analysis is done in 3 lines of code! It’s that easy.
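import pickle

import matplotlib.pyplot as plt
import pandas as pd
import shap


def shap_explain_xgb(model_pickle_path, data_set_path):
    # Load the saved XGBoost model
    with open(model_pickle_path, 'rb') as fid:
        xgb_model = pickle.load(fid)
    # Load the examples to explain (assumes the CSV holds only the model's feature columns)
    current_df = pd.read_csv(data_set_path)

    # Create the SHAP explanation summary
    explainer = shap.Explainer(xgb_model)
    shap_values = explainer(current_df)
    shap.summary_plot(shap_values, current_df, show=False)
    # show=False leaves the figure open, so save it (this output file name is an assumption)
    plt.savefig(data_set_path.replace('.csv', '_shap_summary.png'), bbox_inches='tight')
    plt.close()
    return shap_values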
SHAP for Logistic Regression Churn Prediction
For comparison, here is the result from using SHAP on the Logistic Regression model. For this model, the result was already explainable using the model coefficients (as explained previously in my post on Customer churn probability forecasting). You can see that both XGBoost and Logistic Regression find many of the same variables to be influential:
- New friend per month is the most influential, and high values reduce the churn probability.
- Metric group 2, a result of dimension reduction, consists of the combination of messages and replies. (Explained in the post on Customer Behavior Correlation and Churn.)
- Adview per post is (as in the XGBoost case) the most influential metric with a negative effect, in the sense that high values increase the churn probability.
There are also differences, most notably the very different shape of the influence! For XGBoost (above), the influence of a variable is exactly the same for many points, because XGBoost is based on decision trees: for a given cut point in a tree, all values above the cut point may reach the same terminal node. But Logistic Regression is a linear model, so increasing one metric always changes its influence monotonically (up for a positive coefficient, down for a negative one). This results in the more continuous type of influence shown in the plot below.
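The difference in shape can be reproduced with a toy example. For a linear model such as Logistic Regression, and assuming the features are independent, a metric's SHAP value (in log-odds units) is its coefficient times the metric's distance from its average, so it changes smoothly and monotonically. For a single tree split, every value on the same side of the cut point gets the same SHAP value. The function names, coefficient and threshold below are hypothetical.

```python
def linear_shap(coef, values):
    """SHAP values of one feature in a linear (log-odds) model, assuming
    independent features: coefficient times distance from the feature's mean."""
    mean = sum(values) / len(values)
    return [coef * (v - mean) for v in values]


def stump_shap(threshold, low_out, high_out, values):
    """SHAP values for a one-split decision stump on a single feature:
    the stump's output minus its average output over the data."""
    outputs = [low_out if v <= threshold else high_out for v in values]
    mean_out = sum(outputs) / len(outputs)
    # Every value on the same side of the cut gets exactly the same SHAP value.
    return [o - mean_out for o in outputs]


metric_values = [1, 2, 3, 4, 5, 6]
print(linear_shap(-0.5, metric_values))           # smooth and decreasing
print(stump_shap(3, 0.4, -0.4, metric_values))    # only two distinct values
```

With metric values 1 through 6 and a coefficient of -0.5, the linear version gives six distinct, steadily decreasing values (1.25 down to -1.25), while the stump with a cut at 3 gives just two. Real XGBoost models add up many trees, so the steps are finer, but the same flat stretches show up as the repeated values in the summary plot.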
Adding Explainable Machine Learning To Fighting Churn With Data
Unfortunately, when I wrote Fighting Churn With Data I did not know about using SHAP to explain XGBoost, so in the book I said to use Logistic Regression as the sole method for interpreting the influence of the metrics on churn. This new method will be a great improvement for the second edition! For now, I have included the code to produce SHAP plots for XGBoost and Logistic Regression along with the rest of the code for the book in the Github repository. You can find them as the last two listings under Chapter 9.
If you have the rest of the code from the book running, it should be no problem for you to run those as well. If you want more information about using XGBoost to predict churn, check out Chapter 9 of my book, Fighting Churn With Data. It's packed with more information about analyzing churn, coming up with good metrics (features) for predicting churn, and using all that information to actually reduce churn in practice.