"""Source code for humancompatible.explain.fcx.scripts.fcx_vae_model: the FCX_VAE conditional VAE and an unconditional AutoEncoder."""

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable

class FCX_VAE(nn.Module):
    """
    Conditional Variational Autoencoder for generating feasible counterfactual explanations.

    The model encodes an input feature vector ``x`` concatenated with a target
    class label ``c`` into a latent code, then decodes the latent code (again
    concatenated with ``c``) back to a counterfactual example. ``forward``
    draws several Monte Carlo samples; ``compute_elbo`` returns the KL term
    together with the black-box classifier's predictions on the decoded
    counterfactuals.

    Note: despite the ``logvar`` naming used throughout, the encoder's second
    head yields a *variance* (it is square-rooted and logged directly), not a
    log-variance.

    Attributes:
        immutables (int): Slice bound separating mutable from immutable
            columns: ``x[:, :immutables]`` is the mutable part and the
            remaining columns are immutable (with the default -2, the last two
            columns). Used by ``compute_elbo(..., ret=True)``.
        encoded_size (int): Dimensionality of the latent code.
        data_size (int): Number of (mutable) input features fed to the encoder.
        encoded_categorical_feature_indexes (List[List[int]]): Index groups of
            one-hot categorical features.
        encoded_continuous_feature_indexes (List[int]): Indices in
            ``range(data_size)`` that belong to no categorical group.
        encoded_start_cat (int): Number of continuous features (presumably the
            index where categorical features begin if continuous features come
            first -- TODO confirm ordering).
        encoder_mean (nn.Sequential): Network mapping (x, c) -> latent mean.
        encoder_var (nn.Sequential): Network mapping (x, c) -> value in (0, 1);
            ``encoder`` adds 0.5 to it to obtain the latent variance.
        decoder_mean (nn.Sequential): Network mapping (z, c) -> reconstructed
            features in (0, 1).
    """

    def __init__(self, data_size, encoded_size, d, immutables=-2):
        """
        Initialize the FCX-VAE model.

        Args:
            data_size (int): Number of mutable encoded input features (excluding label).
            encoded_size (int): Dimensionality of the latent space.
            d: Data object exposing ``get_data_params()``; element ``[2]`` of its
                result is used as the list of categorical feature index groups.
            immutables (int, optional): Slice bound separating mutable from
                immutable columns; ``x[:, :immutables]`` is the mutable part.
                Defaults to -2 (the last two columns are immutable).
        """
        super().__init__()
        self.immutables = immutables
        self.encoded_size = encoded_size
        self.data_size = data_size
        self.encoded_categorical_feature_indexes = d.get_data_params()[2]

        # Every index not covered by a categorical group is continuous.
        categorical = {
            idx for group in self.encoded_categorical_feature_indexes for idx in group
        }
        self.encoded_continuous_feature_indexes = [
            i for i in range(self.data_size) if i not in categorical
        ]
        self.encoded_start_cat = len(self.encoded_continuous_feature_indexes)

        hidden = (20, 16, 14, 12)
        # +1 on the input width: the target class label is concatenated to x.
        self.encoder_mean = self._mlp(self.data_size + 1, hidden, self.encoded_size)
        self.encoder_var = self._mlp(
            self.data_size + 1, hidden, self.encoded_size, sigmoid=True
        )
        # +1 on the input width: the target class label is concatenated to z.
        self.decoder_mean = self._mlp(
            self.encoded_size + 1, hidden[::-1], self.data_size, sigmoid=True
        )

    @staticmethod
    def _mlp(in_features, hidden, out_features, sigmoid=False):
        """
        Build [Linear, BatchNorm1d, Dropout(0.1), ReLU] per hidden width, then a final Linear.

        An optional trailing Sigmoid is appended. The layer order is fixed so
        that Sequential indices (and therefore ``state_dict`` keys such as
        ``encoder_mean.0.weight``) stay stable for saved checkpoints.
        """
        layers = []
        width = in_features
        for h in hidden:
            layers += [nn.Linear(width, h), nn.BatchNorm1d(h), nn.Dropout(0.1), nn.ReLU()]
            width = h
        layers.append(nn.Linear(width, out_features))
        if sigmoid:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def encoder(self, x):
        """
        Encode input (x + label) into latent mean and variance.

        Args:
            x (torch.Tensor): Features concatenated with the class label,
                shape (batch_size, data_size + 1).

        Returns:
            tuple: (mean, var), each of shape (batch_size, encoded_size).
            ``var`` is a sigmoid output shifted by 0.5, i.e. in (0.5, 1.5).
        """
        mean = self.encoder_mean(x)
        # The +0.5 keeps the variance bounded away from zero.
        logvar = 0.5 + self.encoder_var(x)
        return mean, logvar

    def decoder(self, z):
        """
        Decode latent code (+ label) back to feature space.

        Args:
            z (torch.Tensor): Latent code concatenated with the class label,
                shape (batch_size, encoded_size + 1).

        Returns:
            torch.Tensor: Reconstructed features in (0, 1), shape (batch_size, data_size).
        """
        return self.decoder_mean(z)

    def sample_latent_code(self, mean, logvar):
        """
        Reparameterization trick: z = mean + sqrt(logvar) * eps, eps ~ N(0, I).

        Args:
            mean (torch.Tensor): Latent means.
            logvar (torch.Tensor): Latent *variances* (not log-variances).

        Returns:
            torch.Tensor: Sampled latent code, same shape as ``mean``.
        """
        eps = torch.randn_like(logvar)
        return mean + torch.sqrt(logvar) * eps

    def normal_likelihood(self, x, mean, logvar, raxis=1):
        """
        Diagonal-Gaussian log-density of x, without the -0.5*log(2*pi) constant.

        Computes sum(-0.5 * ((x - mean)^2 / var + log(var))) along ``raxis``.

        Args:
            x (torch.Tensor): Values to evaluate.
            mean (torch.Tensor): Gaussian means.
            logvar (torch.Tensor): Gaussian *variances* (not log-variances).
            raxis (int): Dimension along which to sum. Defaults to 1.

        Returns:
            torch.Tensor: Log-likelihood reduced over ``raxis``.
        """
        sq = (x - mean) * (1.0 / logvar) * (x - mean)
        return torch.sum(-0.5 * (sq + torch.log(logvar)), dim=raxis)

    def forward(self, x, c, mc_samples=50):
        """
        Generate multiple Monte Carlo counterfactual draws.

        Args:
            x (torch.Tensor): Mutable features, shape (batch_size, data_size).
            c (torch.Tensor or array-like): Target class labels, batch_size
                elements (reshaped to (batch_size, 1) and cast to float32).
            mc_samples (int, optional): Number of latent draws. Defaults to 50.

        Returns:
            dict: {
                'em': encoder means,
                'ev': encoder variances,
                'z': list of sampled latent codes,
                'x_pred': list of reconstructed counterfactuals,
                'mc_samples': number of MC draws
            }
        """
        # as_tensor avoids the copy + UserWarning of torch.tensor(<tensor>)
        # and also accepts plain lists / arrays.
        c = torch.as_tensor(c, dtype=torch.float32)
        c = c.reshape(c.shape[0], 1)
        em, ev = self.encoder(torch.cat((x, c), 1))
        res = {'em': em, 'ev': ev, 'z': [], 'x_pred': [], 'mc_samples': mc_samples}
        for _ in range(mc_samples):
            z = self.sample_latent_code(em, ev)
            x_pred = self.decoder(torch.cat((z, c), 1))
            res['z'].append(z)
            res['x_pred'].append(x_pred)
        return res

    def compute_elbo(self, x, c, pred_model, ret=False):
        """
        Compute ELBO components and the classifier's labels for one latent draw.

        Args:
            x (torch.Tensor): If ``ret`` is False, the mutable features
                (batch_size, data_size). If ``ret`` is True, the full feature
                vector; ``x[:, :self.immutables]`` is encoded and the remaining
                (immutable) columns are copied unchanged into the counterfactual.
            c (torch.Tensor or array-like): Target labels, batch_size elements.
            pred_model (nn.Module): Black-box classifier; its argmax over dim 1
                gives the counterfactual labels.
            ret (bool): Whether to return the full-width tensors and latent code.

        Returns:
            If ret=False: (log_px_z, kl_div, x, x_pred, cf_labels)
            If ret=True: (log_px_z, kl_div, x_full, x_cf, cf_labels, z)
            ``log_px_z`` is always 0 here: no reconstruction likelihood is
            computed in this method.
        """
        c = torch.as_tensor(c, dtype=torch.float32)
        c = c.reshape(c.shape[0], 1)
        immutables = self.immutables
        if ret:
            x_full = x.clone()
            x = x[:, :immutables]

        em, ev = self.encoder(torch.cat((x, c), 1))
        # KL(N(em, ev) || N(0, I)), averaged (not summed) over latent dims.
        kl_divergence = 0.5 * torch.mean(em ** 2 + ev - torch.log(ev) - 1, dim=1)
        z = self.sample_latent_code(em, ev)
        x_pred = self.decoder(torch.cat((z, c), 1))
        log_px_z = torch.tensor(0.0)

        if ret:
            # Splice the decoded mutable part back in front of the immutable columns.
            x_cf = x_full.clone()
            x_cf[:, :immutables] = x_pred
            return (
                torch.mean(log_px_z),
                torch.mean(kl_divergence),
                x_full,
                x_cf,
                torch.argmax(pred_model(x_cf), dim=1),
                z,
            )
        return (
            torch.mean(log_px_z),
            torch.mean(kl_divergence),
            x,
            x_pred,
            torch.argmax(pred_model(x_pred), dim=1),
        )
class AutoEncoder(nn.Module):
    """
    Unconditional variational autoencoder over the encoded feature vector.

    Uses the same layer widths as FCX_VAE but does not concatenate a class
    label to the encoder or decoder inputs. The tensors named ``logvar`` are
    *variances* (square-rooted and logged directly), not log-variances.

    Attributes:
        encoded_size (int): Dimensionality of the latent code.
        data_size (int): Number of input features.
        encoded_categorical_feature_indexes (List[List[int]]): Index groups of
            one-hot categorical features.
        encoded_continuous_feature_indexes (List[int]): Indices in
            ``range(data_size)`` that belong to no categorical group.
        encoded_start_cat (int): Number of continuous features.
        encoder_mean (nn.Sequential): Network mapping x -> latent mean.
        encoder_var (nn.Sequential): Network mapping x -> value in (0, 1);
            ``encoder`` adds 0.05 to it to obtain the latent variance.
        decoder_mean (nn.Sequential): Network mapping z -> reconstruction in (0, 1).
    """

    def __init__(self, data_size, encoded_size, d):
        """
        Initialize the autoencoder.

        Args:
            data_size (int): Number of encoded input features.
            encoded_size (int): Dimensionality of the latent space.
            d: Data object exposing ``get_data_params()``; element ``[2]`` of its
                result is used as the list of categorical feature index groups.
        """
        super().__init__()
        self.encoded_size = encoded_size
        self.data_size = data_size
        self.encoded_categorical_feature_indexes = d.get_data_params()[2]

        # Every index not covered by a categorical group is continuous.
        categorical = {
            idx for group in self.encoded_categorical_feature_indexes for idx in group
        }
        self.encoded_continuous_feature_indexes = [
            i for i in range(self.data_size) if i not in categorical
        ]
        self.encoded_start_cat = len(self.encoded_continuous_feature_indexes)

        hidden = (20, 16, 14, 12)
        self.encoder_mean = self._mlp(self.data_size, hidden, self.encoded_size)
        self.encoder_var = self._mlp(
            self.data_size, hidden, self.encoded_size, sigmoid=True
        )
        self.decoder_mean = self._mlp(
            self.encoded_size, hidden[::-1], self.data_size, sigmoid=True
        )

    @staticmethod
    def _mlp(in_features, hidden, out_features, sigmoid=False):
        """
        Build [Linear, BatchNorm1d, Dropout(0.1), ReLU] per hidden width, then a final Linear.

        An optional trailing Sigmoid is appended. The layer order is fixed so
        that Sequential indices (and therefore ``state_dict`` keys) stay stable
        for saved checkpoints.
        """
        layers = []
        width = in_features
        for h in hidden:
            layers += [nn.Linear(width, h), nn.BatchNorm1d(h), nn.Dropout(0.1), nn.ReLU()]
            width = h
        layers.append(nn.Linear(width, out_features))
        if sigmoid:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def encoder(self, x):
        """
        Encode x into latent mean and variance.

        Returns:
            tuple: (mean, var), each of shape (batch_size, encoded_size).
            ``var`` is a sigmoid output shifted by 0.05, i.e. in (0.05, 1.05).
        """
        mean = self.encoder_mean(x)
        logvar = 0.05 + self.encoder_var(x)
        return mean, logvar

    def decoder(self, z):
        """Decode latent code z to a reconstruction of shape (batch_size, data_size)."""
        return self.decoder_mean(z)

    def sample_latent_code(self, mean, logvar):
        """Reparameterization trick: mean + sqrt(logvar) * eps, eps ~ N(0, I); ``logvar`` is a variance."""
        eps = torch.randn_like(logvar)
        return mean + torch.sqrt(logvar) * eps

    def normal_likelihood(self, x, mean, logvar, raxis=1):
        """
        Diagonal-Gaussian log-density of x, without the -0.5*log(2*pi) constant.

        Computes sum(-0.5 * ((x - mean)^2 / var + log(var))) along ``raxis``,
        where ``logvar`` holds variances. ``raxis`` defaults to 1.
        """
        sq = (x - mean) * (1.0 / logvar) * (x - mean)
        return torch.sum(-0.5 * (sq + torch.log(logvar)), dim=raxis)

    def forward(self, x, mc_samples=50):
        """
        Encode x once and decode ``mc_samples`` latent draws.

        Args:
            x (torch.Tensor): Input features, shape (batch_size, data_size).
            mc_samples (int, optional): Number of latent draws. Defaults to 50.

        Returns:
            dict: {'em', 'ev', 'z' (list), 'x_pred' (list), 'mc_samples'}.
        """
        em, ev = self.encoder(x)
        res = {'em': em, 'ev': ev, 'z': [], 'x_pred': [], 'mc_samples': mc_samples}
        for _ in range(mc_samples):
            z = self.sample_latent_code(em, ev)
            x_pred = self.decoder(z)
            res['z'].append(z)
            res['x_pred'].append(x_pred)
        return res