Quick introduction to generating counterfactual explanations using DiCE

[1]:
%load_ext autoreload
%autoreload 2
[2]:
# import DiCE
import dice_ml
from dice_ml.utils import helpers # helper functions

DiCE requires two inputs: a training dataset and a pre-trained ML model. When the training dataset is unavailable (e.g., for privacy reasons), DiCE can also work without access to the full dataset (see this notebook for examples).

Preliminaries: Loading a dataset and a ML model trained over it

Loading the Adult dataset

We use the “adult” income dataset from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/adult). For demonstration purposes, we transform the data as described in the dice_ml.utils.helpers module.

[3]:
dataset = helpers.load_adult_income_dataset()

This dataset has 8 features. The outcome is income, which is binarized to 0 (low income, <=50K) or 1 (high income, >50K).

[4]:
dataset.head()
[4]:
age workclass education marital_status occupation race gender hours_per_week income
0 28 Private Bachelors Single White-Collar White Female 60 0
1 30 Self-Employed Assoc Married Professional White Male 65 1
2 32 Private Some-college Married White-Collar White Male 50 0
3 20 Private Some-college Single Service White Female 35 0
4 41 Self-Employed Some-college Married White-Collar White Male 50 0
[5]:
# description of transformed features
adult_info = helpers.get_adult_data_info()
adult_info
[5]:
{'age': 'age',
 'workclass': 'type of industry (Government, Other/Unknown, Private, Self-Employed)',
 'education': 'education level (Assoc, Bachelors, Doctorate, HS-grad, Masters, Prof-school, School, Some-college)',
 'marital_status': 'marital status (Divorced, Married, Separated, Single, Widowed)',
 'occupation': 'occupation (Blue-Collar, Other/Unknown, Professional, Sales, Service, White-Collar)',
 'race': 'white or other race?',
 'gender': 'male or female?',
 'hours_per_week': 'total work hours per week',
 'income': '0 (<=50K) vs 1 (>50K)'}

Given this dataset, we construct a data object for DiCE. Since continuous and discrete features are perturbed in different ways, we need to specify the names of the continuous features. DiCE also requires the name of the output variable that the ML model will predict.

[6]:
# Step 1: dice_ml.Data
d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income')

Loading the ML model

DiCE supports sklearn, TensorFlow, and PyTorch models.

The variable backend below indicates the implementation type of DiCE we want to use. Four backends are supported: sklearn, TensorFlow 1.x with backend='TF1', TensorFlow 2.x with backend='TF2', and PyTorch with backend='PYT'.

Below we use a classification model trained with sklearn.

[7]:
# Split data into train and test
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier

target = dataset["income"]
datasetX = dataset.drop("income", axis=1)
x_train, x_test, y_train, y_test = train_test_split(datasetX,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=0,
                                                    stratify=target)

numerical = ["age", "hours_per_week"]
categorical = x_train.columns.difference(numerical)

# One-hot encode the categorical features; pass the numerical features through
# unchanged (without remainder='passthrough', ColumnTransformer would drop them)
categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

transformations = ColumnTransformer(
    transformers=[
        ('cat', categorical_transformer, categorical)],
    remainder='passthrough')

# Append classifier to preprocessing pipeline.
# Now we have a full prediction pipeline.
clf = Pipeline(steps=[('preprocessor', transformations),
                      ('classifier', RandomForestClassifier())])
model = clf.fit(x_train, y_train)

Generating counterfactual examples using DiCE

We now initialize the DiCE explainer, which needs a dataset and a model. DiCE provides local explanations for the model m and requires a query input whose outcome needs to be explained.

[8]:
# Using sklearn backend
m = dice_ml.Model(model=model, backend="sklearn")
# Using method=random for generating CFs
exp = dice_ml.Dice(d, m, method="random")

The method parameter specifies the explanation method. DiCE supports three methods for sklearn models: random sampling, genetic algorithm search, and kd-tree based generation.
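
The other two methods can be selected in the same way; a quick sketch (both explainers expose the same generate_counterfactuals interface):

[ ]:
# the method parameter picks the generation strategy
exp_genetic = dice_ml.Dice(d, m, method="genetic")  # genetic algorithm search
exp_kdtree = dice_ml.Dice(d, m, method="kdtree")    # kd-tree based generation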

The next code snippet shows how to generate and visualize counterfactuals. The first argument of the generate_counterfactuals method is the set of query instances on which counterfactuals are desired; this can be a dataframe with one or more rows.

Below we provide a sample input whose outcome is 0 (low-income) as per the ML model object m. Given this query input, we can now generate counterfactual explanations: perturbed versions of the original input for which the ML model outputs class 1 (high-income). The last column shows the output of the classifier: an income output >= 0.5 is class 1 and an income output < 0.5 is class 0.
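
As a quick sanity check (using sklearn's standard predict_proba, not a DiCE API), you can inspect the model's class probabilities for this query instance:

[ ]:
# class-1 probability >= 0.5 means the model predicts high income
print(model.predict_proba(x_train[0:1]))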

[9]:
e1 = exp.generate_counterfactuals(x_train[0:1], total_CFs=2, desired_class="opposite")
e1.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age workclass education marital_status occupation race gender hours_per_week income
0 38 Private HS-grad Married Blue-Collar White Male 44 0

Diverse Counterfactual set (new outcome: 1.0)
age workclass education marital_status occupation race gender hours_per_week income
0 80.0 Self-Employed Some-college - Professional - Female 92.0 1
1 52.0 Self-Employed Some-college - Professional - Female 35.0 1

The show_only_changes parameter highlights the changes from the query instance. If you would like to see the full feature values for the counterfactuals, set it to False.

[10]:
e1.visualize_as_dataframe(show_only_changes=False)
Query instance (original outcome : 0)
age workclass education marital_status occupation race gender hours_per_week income
0 38 Private HS-grad Married Blue-Collar White Male 44 0

Diverse Counterfactual set (new outcome: 1.0)
age workclass education marital_status occupation race gender hours_per_week income
0 80.0 Self-Employed Some-college Married Professional White Female 92.0 1
1 52.0 Self-Employed Some-college Married Professional White Female 35.0 1

That’s it! You can try generating counterfactual explanations for other examples using the same code; as shown in the sketch below, generate_counterfactuals also accepts a dataframe with several query rows at once. It is also possible to restrict the features to vary while generating the counterfactuals, and to specify a permitted range for each feature within which the counterfactual should be generated, as the next two cells demonstrate.
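
A minimal sketch of passing multiple query instances at once (each row gets its own counterfactual set):

[ ]:
# counterfactuals for three test instances in one call
e_multi = exp.generate_counterfactuals(x_test[0:3], total_CFs=2, desired_class="opposite")
e_multi.visualize_as_dataframe(show_only_changes=True)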

[11]:
# Changing only age, education, and race
e2 = exp.generate_counterfactuals(x_train[0:1],
                                  total_CFs=2,
                                  desired_class="opposite",
                                  features_to_vary=["age", "education", "race"]
                                  )
e2.visualize_as_dataframe(show_only_changes=True)

Query instance (original outcome : 0)
age workclass education marital_status occupation race gender hours_per_week income
0 38 Private HS-grad Married Blue-Collar White Male 44 0

Diverse Counterfactual set (new outcome: 1.0)
age workclass education marital_status occupation race gender hours_per_week income
0 67.0 - Masters - - Other - - 1
1 66.0 - Prof-school - - Other - - 1
[12]:
# Restricting age to be between [20, 30] and education to be either 'Doctorate' or 'Prof-school'
e3 = exp.generate_counterfactuals(x_train[0:1],
                                  total_CFs=2,
                                  desired_class="opposite",
                                  permitted_range={'age':[20,30], 'education':['Doctorate', 'Prof-school']})
e3.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age workclass education marital_status occupation race gender hours_per_week income
0 38 Private HS-grad Married Blue-Collar White Male 44 0

Diverse Counterfactual set (new outcome: 1.0)
age workclass education marital_status occupation race gender hours_per_week income
0 28.0 Self-Employed Doctorate - Professional - Female 21.0 1
1 27.0 Self-Employed Doctorate - Professional - Female 50.0 1

Generating feature attributions (local and global) using DiCE

DiCE can generate feature importance scores using a summary of the counterfactuals generated. Intuitively, a feature that is changed more often when generating proximal counterfactuals for an input is locally important for causing the model’s prediction at the input. Formally, counterfactuals operationalize the necessity criterion for a model explanation: is the feature value necessary for the given model output?

For more details, refer to the paper, Towards Unifying Feature Attribution and Counterfactual Explanations: Different Means to the Same End.

Local feature importance scores

These scores are computed for a given query instance (input point) by summarizing a set of counterfactual examples around the point.

[13]:
query_instance = x_train[0:1]
imp = exp.local_feature_importance(query_instance, total_CFs=10)
print(imp.local_importance)
[{'workclass': 1.0, 'education': 1.0, 'occupation': 1.0, 'gender': 1.0, 'age': 1.0, 'hours_per_week': 1.0, 'marital_status': 0.0, 'race': 0.0}]

The total_CFs parameter denotes the number of counterfactuals used to compute the local importance scores; the more counterfactuals, the better the estimate.
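
Intuitively, each local score is the fraction of the generated counterfactuals in which that feature changed. A minimal sketch of that interpretation (illustrative only; changed_fraction is a hypothetical helper, not part of the DiCE API):

[ ]:
def changed_fraction(query_df, cfs_df):
    # fraction of counterfactual rows that differ from the query, per feature
    query_row = query_df.iloc[0]
    return {col: float((cfs_df[col] != query_row[col]).mean())
            for col in query_df.columns}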

Global feature importance scores

A global importance score per feature can be estimated by aggregating the local scores over a set of individual inputs. The more inputs, the better the estimate of a feature's global importance.

[14]:
query_instances = x_train[0:20]
imp = exp.global_feature_importance(query_instances)
print(imp.summary_importance)
{'hours_per_week': 0.975, 'occupation': 0.96, 'age': 0.915, 'workclass': 0.88, 'education': 0.785, 'marital_status': 0.7, 'gender': 0.675, 'race': 0.095}
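
Conceptually, the global score is an average of the local scores across the query instances. A minimal sketch of that aggregation (illustrative only; aggregate_local_scores is a hypothetical helper, not part of the DiCE API):

[ ]:
from collections import defaultdict

def aggregate_local_scores(local_scores):
    # average per-instance importance dicts into one global importance dict
    totals = defaultdict(float)
    for scores in local_scores:
        for feature, value in scores.items():
            totals[feature] += value
    return {feature: total / len(local_scores) for feature, total in totals.items()}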

Working with deep learning models (TensorFlow and PyTorch)

We now show examples of gradient-based methods with TensorFlow and PyTorch models. Since the gradient-based methods optimize a loss rather than simply sampling points, they can be slower to generate counterfactuals. The loss is defined by three components: validity (does the CF have the desired model output?), proximity (the CF's distance from the original point should be low), and diversity (multiple CFs should change different features). The DiCE loss formulation is described in the paper, Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations.
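
To make this structure concrete, here is a minimal numpy sketch of such a three-part loss. This is a simplified, assumed form for illustration only, not DiCE's actual implementation; the paper uses a hinge-style yloss and a determinantal point process (DPP) diversity term, weighted by hyperparameters lambda1 and lambda2.

[ ]:
import numpy as np

# Illustrative sketch only -- a simplified stand-in for the DiCE loss.
def dice_loss_sketch(cf_probs, cf_points, query, lambda1=0.5, lambda2=1.0):
    # validity: push each CF's class-1 probability past the 0.5 boundary
    validity = np.sum(np.maximum(0.0, 0.5 - cf_probs))
    # proximity: mean L1 distance of the CFs from the original input
    proximity = np.mean(np.linalg.norm(cf_points - query, ord=1, axis=1))
    # diversity: determinant of a similarity kernel rewards mutually distant CFs
    pairwise = np.linalg.norm(cf_points[:, None] - cf_points[None, :], axis=2)
    diversity = np.linalg.det(1.0 / (1.0 + pairwise))
    # lower is better: valid, close to the query, and mutually diverse
    return validity + lambda1 * proximity - lambda2 * diversity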

Below, we use a pre-trained ML model that achieves accuracy comparable to other baselines. For convenience, we include the sample trained model with the DiCE package.

Explaining a Tensorflow model

[15]:
# suppress deprecation warnings from TF
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

backend = 'TF' + tf.__version__[0]  # 'TF1' or 'TF2', depending on the installed version
ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)
# Step 2: dice_ml.Model
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)

Note that the time required to find counterfactuals with TensorFlow 2.x's eager execution is significantly greater than with TensorFlow 1.x's graph execution.

Based on the data object d and the model object m, we can now instantiate the DiCE class for generating explanations.

[16]:
# Step 3: initiate DiCE
exp = dice_ml.Dice(d, m)

Below, we provide the query instance as a dict.

[17]:
# query instance in the form of a dictionary or a dataframe; keys: feature name, values: feature value
query_instance = {'age':22,
                  'workclass':'Private',
                  'education':'HS-grad',
                  'marital_status':'Single',
                  'occupation':'Service',
                  'race': 'White',
                  'gender':'Female',
                  'hours_per_week': 45}
[18]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class="opposite")
Diverse Counterfactuals found! total time taken: 01 min 01 sec
[19]:
# visualize the result, highlight only the changes
dice_exp.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age workclass education marital_status occupation race gender hours_per_week income
0 22.0 Private HS-grad Single Service White Female 45.0 0.019

Diverse Counterfactual set (new outcome: 1.0)
age workclass education marital_status occupation race gender hours_per_week income
0 70.0 - Masters - White-Collar - - 51.0 0
1 - Self-Employed Doctorate Married - - - - 0
2 47.0 - - Married - - - - 0
3 36.0 - Prof-school Married - - - 62.0 0

The counterfactuals generated above are slightly different from those shown in our paper, where the loss convergence condition was made more conservative for rigorous experimentation. To replicate the results in the paper, add the argument loss_converge_maxiter=2 (the default value is 1) to the exp.generate_counterfactuals() call above. For more info, see the generate_counterfactuals() method in dice_ml.dice_interfaces.dice_tensorflow.py.
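
For example:

[ ]:
# replicate the paper's more conservative convergence condition
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4,
                                        desired_class="opposite",
                                        loss_converge_maxiter=2)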

Explaining a Pytorch model

Just change the backend variable to 'PYT' to use DiCE with PyTorch. Below, we use a pre-trained PyTorch model that achieves accuracy comparable to other baselines. For convenience, we include the sample trained model with the DiCE package.

[20]:
try:
    import torch
    print('PyTorch installed.')
except ImportError as e:
    print("Import Error!", e.name, "not found. Please install from https://pytorch.org/")
PyTorch installed.
[21]:
backend = 'PYT'
ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)

Instantiate the DiCE class with the new PyTorch model object m.

[22]:
exp = dice_ml.Dice(d, m)
[23]:
# query instance in the form of a dictionary; keys: feature name, values: feature value
query_instance = {'age':22,
                  'workclass':'Private',
                  'education':'HS-grad',
                  'marital_status':'Single',
                  'occupation':'Service',
                  'race': 'White',
                  'gender':'Female',
                  'hours_per_week': 45}
[24]:
# generate counterfactuals
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class="opposite")
Diverse Counterfactuals found! total time taken: 00 min 26 sec
[25]:
# highlight only the changes
dice_exp.visualize_as_dataframe(show_only_changes=True)
Query instance (original outcome : 0)
age workclass education marital_status occupation race gender hours_per_week income
0 22.0 Private HS-grad Single Service White Female 45.0 0.0

Diverse Counterfactual set (new outcome: 1.0)
age workclass education marital_status occupation race gender hours_per_week income
0 72.0 - - Married White-Collar - - - -
1 29.0 - Prof-school Married - - Male - -
2 52.0 - Doctorate Married - - - - -
3 47.0 - Masters Married - - - 73.0 -

More resources: What’s next?

DiCE has multiple configurable options and support for different kinds of models. Follow these notebooks to learn more.

  1. You can constrain the features to vary, weigh the relative importance of different features for computing distance, or specify permitted ranges for each feature. Check out the Customizing Counterfactuals Notebook.

  2. You can use it for multi-class classification or regression models. Counterfactuals for Multi-class Classification and Regression Models Notebook.

  3. Explore the different model-agnostic explanation methods in DiCE. Model-agnostic Counterfactual Generation Methods.

  4. You can generate CFs even without access to training data (e.g., for privacy-sensitive data). DiCE with Private Data Notebook.

  5. Feasibility of counterfactuals is an important consideration. You can try out this VAE-based method that adds a CF likelihood term to the loss. VAE-based Counterfactuals (PyTorch only).