# Source code for dice_ml.dice

"""Module pointing to different implementations of DiCE based on different
   frameworks such as Tensorflow or PyTorch or sklearn, and different methods
   such as RandomSampling, DiCEKD or DiCEGenetic"""

from dice_ml.constants import BackEndTypes, SamplingStrategy
from dice_ml.data_interfaces.private_data_interface import PrivateData
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
from dice_ml.utils.exception import UserConfigValidationException


class Dice(ExplainerBase):
    """Single entry point to the different DiCE implementations.

    Constructing a ``Dice`` object does not leave you with a plain ``Dice``
    instance: the constructor asks ``decide`` for the explainer class that
    matches the model backend and the requested method, rebinds
    ``self.__class__`` to it and then runs that class's ``__init__``.
    """

    def __init__(self, data_interface, model_interface, method="random", **kwargs):
        """Select and initialise the concrete explainer.

        :param data_interface: an interface to access data related params.
        :param model_interface: an interface to access the output or gradients
            of a trained ML model.
        :param method: name of the method to use for generating counterfactuals.
        :param kwargs: extra keyword arguments, passed on unchanged to the
            ``__init__`` of the selected explainer class.
        """
        self.decide_implementation_type(data_interface, model_interface, method, **kwargs)

    def decide_implementation_type(self, data_interface, model_interface, method, **kwargs):
        """Turn this object into the explainer chosen by ``decide``.

        :param data_interface: forwarded to the selected explainer's ``__init__``.
        :param model_interface: its ``backend`` attribute drives the selection;
            also forwarded to the selected explainer's ``__init__``.
        :param method: sampling method handed to ``decide``.
        :param kwargs: forwarded to the selected explainer's ``__init__``.
        :raises UserConfigValidationException: if the sklearn kd-tree explainer
            is requested together with a ``PrivateData`` data interface.
        """
        # Reject the one unsupported combination before touching the instance.
        # The checks short-circuit in the order backend -> method -> data type.
        if (model_interface.backend == BackEndTypes.Sklearn
                and method == SamplingStrategy.KdTree
                and isinstance(data_interface, PrivateData)):
            message = ('Private data interface is not supported with sklearn kdtree explainer'
                       ' since kdtree explainer needs access to entire training data')
            raise UserConfigValidationException(message)

        # Rebind the instance to the concrete explainer class, then run that
        # class's initialiser. `method` is deliberately not forwarded here.
        implementation = decide(model_interface, method)
        self.__class__ = implementation
        self.__init__(data_interface, model_interface, **kwargs)

    def _generate_counterfactuals(self, query_instance, total_CFs,
                                  desired_class="opposite", desired_range=None,
                                  permitted_range=None, features_to_vary="all",
                                  stopping_threshold=0.5, posthoc_sparsity_param=0.1,
                                  posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
        """Placeholder that always raises ``NotImplementedError``.

        None of the arguments are used here; the error message states that
        concrete classes are expected to provide the implementation.

        :raises NotImplementedError: unconditionally.
        """
        message = ("This method should be implemented by the concrete classes "
                   "that inherit from ExplainerBase")
        raise NotImplementedError(message)
def decide(model_interface, method):
    """Return the DiCE explainer class for a model backend and method.

    To add a new implementation of DiCE, add the class to the
    ``explainer_interfaces`` subpackage and import-and-return it from a new
    ``elif`` branch in this function.

    :param model_interface: object whose ``backend`` attribute selects the
        implementation. For backends not listed in ``BackEndTypes`` it is
        subscripted with ``'explainer'`` and must yield a string of the form
        ``"<module>.<ClassName>"``, resolved relative to
        ``dice_ml.explainer_interfaces``. ``<module>`` may itself be dotted.
    :param method: sampling strategy; only consulted for the sklearn backend.
    :returns: the explainer class itself, not an instance.
    :raises UserConfigValidationException: if the backend is sklearn and
        ``method`` is not one of the supported sampling strategies.
    """
    # Local import, matching the function-scope import style used below.
    import importlib

    backend = model_interface.backend

    # Explainer modules are imported inside the branch that needs them
    # (presumably so that unused frameworks do not have to be installed).
    if backend == BackEndTypes.Sklearn:
        if method == SamplingStrategy.Random:
            # random sampling of CFs
            from dice_ml.explainer_interfaces.dice_random import DiceRandom
            return DiceRandom
        elif method == SamplingStrategy.Genetic:
            from dice_ml.explainer_interfaces.dice_genetic import DiceGenetic
            return DiceGenetic
        elif method == SamplingStrategy.KdTree:
            from dice_ml.explainer_interfaces.dice_KD import DiceKD
            return DiceKD
        else:
            raise UserConfigValidationException("Unsupported sample strategy {0} provided. "
                                                "Please choose one of {1}, {2} or {3}".format(
                                                    method, SamplingStrategy.Random,
                                                    SamplingStrategy.Genetic,
                                                    SamplingStrategy.KdTree
                                                ))

    elif backend == BackEndTypes.Tensorflow1:
        # pretrained Keras Sequential model with Tensorflow 1.x backend
        from dice_ml.explainer_interfaces.dice_tensorflow1 import \
            DiceTensorFlow1
        return DiceTensorFlow1

    elif backend == BackEndTypes.Tensorflow2:
        # pretrained Keras Sequential model with Tensorflow 2.x backend
        from dice_ml.explainer_interfaces.dice_tensorflow2 import \
            DiceTensorFlow2
        return DiceTensorFlow2

    elif backend == BackEndTypes.Pytorch:
        # PyTorch backend
        from dice_ml.explainer_interfaces.dice_pytorch import DicePyTorch
        return DicePyTorch

    else:
        # all other backends
        explainer_path = backend['explainer']
        # Split on the LAST dot only, so "pkg.module.ClassName" resolves to
        # module "pkg.module"; the one-dot form behaves exactly as before.
        module_name, class_name = explainer_path.rsplit('.', 1)
        # importlib.import_module returns the named submodule directly, which
        # is what __import__ with a non-empty fromlist was being used for.
        module = importlib.import_module("dice_ml.explainer_interfaces." + module_name)
        return getattr(module, class_name)