Source code for humancompatible.explain.glance.local_cfs.dice_method

from ..base import LocalCounterfactualMethod
import dice_ml
import pandas as pd


class DiceMethod(LocalCounterfactualMethod):
    """
    Local counterfactual method backed by DiCE (https://interpret.ml/DiCE/).

    A ``dice_ml.Dice`` explainer is built from a model and a dataset in
    ``fit`` and then queried in ``explain_instances`` to produce
    counterfactual rows for the given instances.

    Methods:
    --------
    __init__():
        Creates an unfitted DiceMethod instance.
    fit(model, data, outcome_name, continuous_features, feat_to_vary, random_seed=13):
        Builds the DiCE counterfactual generator from the model and data.
    explain_instances(instances, num_counterfactuals):
        Generates counterfactuals for the given instances.
    """

    def __init__(self):
        """
        Creates an unfitted DiceMethod instance.

        Attributes:
        ----------
        cf_generator : None or dice_ml.Dice
            Counterfactual generator; None until ``fit`` is called.
        random_seed : None or int
            Seed forwarded to DiCE on every query; set by ``fit``.
        feat_to_vary : None or List[str]
            Features DiCE is allowed to change; set by ``fit``.
        """
        super().__init__()
        self.cf_generator = None
        # Declared here so the attributes exist before ``fit`` is called.
        self.random_seed = None
        self.feat_to_vary = None

    def fit(
        self,
        model,
        data: pd.DataFrame,
        outcome_name: str,
        continuous_features,
        feat_to_vary,
        random_seed: int = 13,
    ) -> None:
        """
        Builds the DiCE counterfactual generator from the model and data.

        Parameters:
        ----------
        model : object
            A machine learning model used for predictions. It is wrapped
            with DiCE's "sklearn" backend.
        data : pd.DataFrame
            The dataset containing features and the outcome variable.
        outcome_name : str
            The name of the outcome variable in the dataset.
        continuous_features : List[str]
            A list of names for continuous (numerical) features.
        feat_to_vary : List[str]
            A list of feature names that can be varied to generate
            counterfactuals.
        random_seed : int, optional
            Seed forwarded to DiCE when generating counterfactuals,
            by default 13.
        """
        dice_dataset = dice_ml.Data(
            dataframe=data,
            continuous_features=continuous_features,
            outcome_name=outcome_name,
        )
        dice_model = dice_ml.Model(model=model, backend="sklearn", func=None)
        cf_generator = dice_ml.Dice(dice_dataset, dice_model, method="random")

        # Assign state only after every DiCE object was built successfully,
        # so a failed fit does not leave the instance half-updated.
        self.random_seed = random_seed
        self.feat_to_vary = feat_to_vary
        self.cf_generator = cf_generator

    def explain_instances(
        self, instances: pd.DataFrame, num_counterfactuals: int
    ) -> pd.DataFrame:
        """
        Generates counterfactual instances for the specified input instances.

        Counterfactuals are always requested for ``desired_class=1``, using
        the seed and the variable features stored by ``fit``.

        Parameters:
        ----------
        instances : pd.DataFrame
            DataFrame containing the instances for which counterfactuals
            are generated.
        num_counterfactuals : int
            The number of counterfactuals to request for each instance.

        Returns:
        -------
        pd.DataFrame
            The counterfactuals of all instances concatenated row-wise, with
            the last column of each DiCE result removed. Instances for which
            DiCE produced no result frame contribute no rows. The original
            row labels of each result frame are kept.

        Raises:
        -------
        ValueError
            If the counterfactual generator has not been initialized
            (fit method not called).
        """
        if self.cf_generator is None:
            raise ValueError("Fit the Local Counterfactual method first.")

        explanations = self.cf_generator.generate_counterfactuals(
            instances,
            total_CFs=num_counterfactuals,
            desired_class=1,
            random_seed=self.random_seed,
            features_to_vary=self.feat_to_vary,
            posthoc_sparsity_param=None,
        )

        # Skip entries without a result frame instead of crashing on
        # ``None.iloc``. The dropped last column is presumably the outcome
        # column appended by DiCE -- TODO confirm.
        cf_frames = [
            example.final_cfs_df.iloc[:, :-1]
            for example in explanations.cf_examples_list
            if example.final_cfs_df is not None
        ]
        if not cf_frames:
            # Defensive: pd.concat raises on an empty list. DiCE is expected
            # to raise on its own when no query has counterfactuals -- verify.
            return pd.DataFrame(columns=instances.columns)
        return pd.concat(cf_frames, ignore_index=False)