Source code for humancompatible.explain.fcx.FCX_binary_generation_adult

import sys
import os
import json
import time
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from dataloader import DataLoader
from fcx_vae_model import FCX_VAE
from blackboxmodel import BlackBox
from helpers import load_adult_income_dataset
from causal_modules import causal_regularization_enhanced, binarize_adj_matrix, ensure_dag
from LOFLoss import LOFLoss

cuda = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

[docs] def compute_loss( model, model_out, x, target_label, normalise_weights, validity_reg, margin,adj_matrix,pred_model=None ): """ Compute the combined ELBO, validity hinge-loss, sparsity penalty, and causal regularization for a batch of examples. Args: model (FCX_VAE): The VAE counterfactual generator. model_out (dict): Output of `model.forward`, containing: - em (Tensor): encoder means, shape (batch, latent_dim) - ev (Tensor): encoder variances, shape (batch, latent_dim) - z (Tensor): latent samples, list length mc_samples - x_pred (list[Tensor]): reconstructions per MC sample - mc_samples (int): number of Monte Carlo draws x (Tensor): Original input features, shape (batch, d). target_label (Tensor): True class labels (0 or 1), shape (batch,). normalise_weights (dict[int, tuple(float, float)]): Mapping feature index → (min, max) for proximity weighting. validity_reg (float): Weight of the validity hinge-loss term. margin (float): Hinge margin hyperparameter. adj_matrix (Tensor): Binary causal adjacency matrix, shape (d, d). Returns: Tensor: Scalar loss combining: - reconstruction error (L1 distance on mutable features) - KL divergence - validity hinge‑loss - sparsity penalty - causal regularization """ lambda_nc = 1 lambda_c = 1 em = model_out['em'] ev = model_out['ev'] z = model_out['z'] dm = model_out['x_pred'] mc_samples = model_out['mc_samples'] #KL Divergence kl_divergence = 0.5*torch.mean( em**2 +ev - torch.log(ev) - 1, axis=1 ) #Reconstruction Term #Proximity: L1 Loss x_pred = dm[0] # immutables temp_copy = x.clone() temp_copy[:,:-4] = x_pred x_pred = temp_copy reg_loss = causal_regularization_enhanced(x_pred[:,:-4], adj_matrix, lambda_nc=lambda_nc, lambda_c=lambda_c) s= model.encoded_start_cat recon_err = -torch.sum( torch.abs(x[:,s:-1] - x_pred[:,s:-1]), axis=1 ) for key in normalise_weights.keys(): recon_err+= -(normalise_weights[key][1] - normalise_weights[key][0])*torch.abs(x[:,key] - x_pred[:,key]) # Sum to 1 over the categorical indexes of a feature for v in model.encoded_categorical_feature_indexes: temp = -torch.abs( 1.0-torch.sum( x_pred[:, v[0]:v[-1]+1], axis=1) ) recon_err += temp count=0 count+= torch.sum(x_pred[:,:s]<0,axis=1).float() count+= torch.sum(x_pred[:,:s]>1,axis=1).float() #Validity temp_logits = pred_model(x_pred) validity_loss= torch.zeros(1).to(cuda) temp_1= temp_logits[target_label==1,:] temp_0= temp_logits[target_label==0,:] validity_loss += F.hinge_embedding_loss( F.sigmoid(temp_1[:,1]).to(cuda) - F.sigmoid(temp_1[:,0]).to(cuda), torch.tensor(-1).to(cuda), margin, reduction='mean') validity_loss += F.hinge_embedding_loss( F.sigmoid(temp_0[:,0]).to(cuda) - F.sigmoid(temp_0[:,1]).to(cuda), torch.tensor(-1).to(cuda), margin, reduction='mean') sparsity=torch.zeros(1).to(cuda) for sample in range(0,x.shape[0]): temp=0 for v in model.encoded_categorical_feature_indexes[:-2]: temp +=0.5*torch.sum( torch.sum( torch.norm(x_pred[sample, v[0]:v[-1]+1]-x[sample, v[0]:v[-1]+1],p=1)>0.01) )#/x.shape[0] for t in [0,1]: temp +=0.5*torch.sum( torch.sum( torch.norm(x_pred[sample, t]-x[sample, t])) ) sparsity += temp sparsity =1*(sparsity/x.shape[0]) for i in range(1,mc_samples): x_pred = dm[i] # immutable variables at the end temp_copy = x.clone() temp_copy[:,:-4] = x_pred x_pred = temp_copy reg_loss+=causal_regularization_enhanced(x_pred[:,:-4], adj_matrix, lambda_nc=lambda_nc, lambda_c=lambda_c) recon_err += -torch.sum( torch.abs(x[:,s:-1] - x_pred[:,s:-1]), axis=1 ) for key in normalise_weights.keys(): recon_err+= -(normalise_weights[key][1] - normalise_weights[key][0])*torch.abs(x[:,key] - x_pred[:,key]) # Sum to 1 over the categorical indexes of a feature for v in model.encoded_categorical_feature_indexes: temp = -torch.abs( 1.0-torch.sum( x_pred[:, v[0]:v[-1]+1], axis=1) ) recon_err += temp count+= torch.sum(x_pred[:,:s]<0,axis=1).float() count+= torch.sum(x_pred[:,:s]>1,axis=1).float() temp_logits = pred_model(x_pred) temp_1= temp_logits[target_label==1,:] temp_0= temp_logits[target_label==0,:] validity_loss += F.hinge_embedding_loss( F.sigmoid(temp_1[:,1]).to(cuda) - F.sigmoid(temp_1[:,0]).to(cuda), torch.tensor(-1).to(cuda), margin, reduction='mean') validity_loss += F.hinge_embedding_loss( F.sigmoid(temp_0[:,0]).to(cuda) - F.sigmoid(temp_0[:,1]).to(cuda), torch.tensor(-1).to(cuda), margin, reduction='mean') recon_err = recon_err / mc_samples validity_loss = -1*validity_reg*validity_loss/mc_samples reg_loss=reg_loss/mc_samples sparsity = 1*1*sparsity print('recon: ',-torch.mean(recon_err), ' KL: ', torch.mean(kl_divergence), ' Validity: ', -validity_loss,'sparsity: ',sparsity,'reg_loss: ',reg_loss) #return -torch.mean(recon_err - kl_divergence) - validity_loss + sparsity +reg_loss*5 #20 #den kserw mhpws thelei meion loss = ( -torch.mean(recon_err - kl_divergence) - validity_loss + sparsity + reg_loss * 5 ) # ensure it’s a 0‑d tensor, not a 1‑element vector return loss.squeeze()
[docs] def train_constraint_loss(model, train_dataset, optimizer, normalise_weights, validity_reg, constraint_reg, margin, epochs=1000, batch_size=1024,adj_matrix=None,ed_dict=None,pred_model=None): """ Perform one epoch of FCX‑VAE training under causal and LOF constraints. This function iterates over the provided training data, computes the combined VAE ELBO loss, a causal constraint hinge loss education features, and a Local Outlier Factor (LOF) anomaly penalty on the latent codes. It then backpropagates and updates the model parameters. Args: model (FCX_VAE): The counterfactual VAE model to train. train_dataset (ndarray): Array of shape (N, D+1), where the last column is the binary label used for counterfactual conditioning. optimizer (torch.optim.Optimizer): Optimizer instance for model parameters. normalise_weights (dict[int, tuple[float, float]]): Per-feature (min, max) normalization weights for continuous features. validity_reg (float): Weight for the validity hinge loss component. constraint_reg (float): Weight for the causal/LOF constraint loss. margin (float): Margin hyperparameter for hinge losses. epochs (int, optional): Number of times to sample each datapoint (MC samples). Defaults to 1000. batch_size (int, optional): Mini-batch size for training. Defaults to 1024. adj_matrix (torch.Tensor, optional): Binary adjacency matrix enforcing causal dependencies. Shape (D, D). ed_dict (dict[int, int], optional): Mapping from categorical feature index to an education-level penalty coefficient. Returns: float: The sum of training losses over all mini-batches. """ batch_num=0 train_loss=0.0 train_size=0 criterion=LOFLoss(n_neighbors=30) train_dataset= torch.tensor( train_dataset ).float().to(cuda) train_dataset= torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) lof_loss=0 good_cf_count=0 for train_x in enumerate(train_dataset): optimizer.zero_grad() train_x= train_x[1] train_x_back = train_x.clone() train_y = 1.0-torch.argmax( pred_model(train_x), dim=1 ) train_x = train_x[:,:-4] train_size += train_x.shape[0] out= model(train_x, train_y) train_x_back[:,:-4] = train_x train_x= train_x_back loss = compute_loss(model, out, train_x, train_y, normalise_weights, validity_reg, margin,adj_matrix,pred_model) dm = out['x_pred'] mc_samples = out['mc_samples'] x_pred = dm[0] #Age not decreasing constraint_loss = F.hinge_embedding_loss( x_pred[:,0] - train_x[:,0], torch.tensor(-1).to(cuda), 0).to(cuda) #Ed should not decrease for key in ed_dict.keys(): constraint_loss += -0.05*ed_dict[key]*torch.mean( x_pred[:,key] - train_x[:,key] ) for j in range(1, mc_samples): x_pred = dm[j] constraint_loss+= F.hinge_embedding_loss( x_pred[:,0] - train_x[:,0], torch.tensor(-1).to(cuda), 0).to(cuda) for key in ed_dict.keys(): constraint_loss += -0.05*ed_dict[key]*torch.mean( x_pred[:,key] - train_x[:,key] ) z_t = out['z'] temp_lof_loss=0 for z_temp in z_t: temp_lof_loss += criterion(z_temp) temp_lof_loss= temp_lof_loss/mc_samples constraint_loss= constraint_loss/mc_samples constraint_loss= constraint_reg*(constraint_loss) lof_loss=temp_lof_loss loss=loss+lof_loss*80 loss=loss+ torch.mean(constraint_loss) train_loss += loss.item() batch_num+=1 if loss.requires_grad: loss.backward() optimizer.step() else: pass ret= train_loss print('Train Avg Loss: ', ret, train_size) return ret
[docs] def train_binary_fcx_vae( dataset_name: str, base_data_dir: str = 'data/', base_model_dir: str = 'models/', batch_size: int = 64, epochs: int = 50, validity: float = 20.0, feasibility: float = 1.0, margin: float = 0.5, ): """ Train the FCX‑VAE model to generate binary counterfactuals for the specified dataset. This function performs the full training pipeline: 1. Sets random seeds and device. 2. Builds an education-level penalty dictionary. 3. Loads and preprocesses the Adult dataset, filtering only low-income samples. 4. Loads normalization weights and the pretrained black-box classifier. 5. Constructs the causal adjacency matrix and ensures it is a DAG. 6. Initializes the FCX-VAE model and its optimizer. 7. Runs a validity + feasibility constrained training loop for the given number of epochs. 8. Saves the best model checkpoint and returns the trained VAE. Args: dataset_name (str): Dataset identifier, e.g. `'adult'`. base_data_dir (str): Path to the directory containing `{dataset_name}-train-set.npy`, `{dataset_name}-val-set.npy`, normalization JSONs, and adjacency CSVs. base_model_dir (str): Directory to load the black‐box model from and save the final VAE checkpoint. batch_size (int): Mini-batch size for VAE training. epochs (int): Number of training epochs. validity (float): Weight for the validity hinge‐loss term. feasibility (float): Weight for the causal feasibility regularizer. margin (float): Margin parameter for all hinge losses. Returns: FCX_VAE: The trained VAE model instance, with its `state_dict` saved to `{base_model_dir}/{dataset_name}-margin-{margin}-feasibility-{feasibility}-validity-{validity}-epoch-{epochs}-fcx-binary.pth`. """ constraint_reg=feasibility # Globals and reproducibility #global pred_model torch.manual_seed(10000000) global cuda, ed_dict cuda = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Encoded features & education mapping encoded_feature_names = ['age','hours_per_week','workclass_Government','workclass_Other/Unknown', 'workclass_Private','workclass_Self-Employed','education_Assoc','education_Bachelors', 'education_Doctorate','education_HS-grad','education_Masters','education_Prof-school', 'education_School','education_Some-college','marital_status_Divorced', 'marital_status_Married','marital_status_Separated','marital_status_Single', 'marital_status_Widowed','occupation_Blue-Collar','occupation_Other/Unknown', 'occupation_Professional','occupation_Sales','occupation_Service','occupation_White-Collar', 'race_Other','race_White','gender_Female','gender_Male'] education_score = {'HS-grad':0,'School':0,'Bachelors':1,'Assoc':1,'Some-college':1,'Masters':2,'Prof-school':2,'Doctorate':3} ed_dict = {encoded_feature_names.index(f'education_{k}'):v for k,v in education_score.items()} # 1. Data prep dataset = load_adult_income_dataset() params= {'dataframe':dataset.copy(), 'continuous_features':['age','hours_per_week'], 'outcome_name':'income'} d = DataLoader(params) feat_to_change = d.get_indexes_of_features_to_vary(['age','hours_per_week','workclass','education','marital_status','occupation']) vae_train_dataset= np.load(base_data_dir+dataset_name+'-train-set.npy') vae_val_dataset= np.load(base_data_dir+dataset_name+'-val-set.npy') # CF Generation for only low to high income data points if dataset_name == 'adult': vae_train_dataset= vae_train_dataset[vae_train_dataset[:,-1]==0,:] vae_val_dataset= vae_val_dataset[vae_val_dataset[:,-1]==0,:] vae_train_dataset= vae_train_dataset[:,:-1] vae_val_dataset= vae_val_dataset[:,:-1] with open(base_data_dir+dataset_name+'-normalise_weights.json') as f: normalise_weights= json.load(f) normalise_weights = {int(k):v for k,v in normalise_weights.items()} with open(base_data_dir+dataset_name+'-mad.json') as f: mad_feature_weights= json.load(f) # GRAPH LOADING and ADJACENCY MATRIX dd_check = pd.read_csv(base_data_dir+'adult-train-set_check.csv') adj = pd.read_csv(base_data_dir+'adult_causal_graph_adjacency_matrix.csv',index_col=0) adj = adj.reindex(index=dd_check.columns, columns=dd_check.columns) if 'income' in adj.columns: adj = adj.drop('income',axis=0) adj = adj.drop('income',axis=1) adj = adj.drop('gender_Male',axis=0) adj = adj.drop('gender_Male',axis=1) adj = adj.drop('gender_Female',axis=0) adj = adj.drop('gender_Female',axis=1) adj = adj.drop('race_Other',axis=0) adj = adj.drop('race_Other',axis=1) adj = adj.drop('race_White',axis=0) adj = adj.drop('race_White',axis=1) # Create a directed graph #G = nx.from_numpy_matrix(adj.values, create_using=nx.DiGraph()) G = nx.from_numpy_array(adj.to_numpy(), create_using=nx.DiGraph()) # Check for cycles try: topo_order = nx.topological_sort(G) print("Topological Order:", list(topo_order)) except nx.NetworkXUnfeasible: print("The graph has at least one cycle.") # Visualize the graph #nx.draw(G, with_labels=True, arrows=True) #plt.show() # Create a directed graph #G = nx.from_numpy_matrix(adj.values, create_using=nx.DiGraph()) G = nx.from_numpy_array(adj.to_numpy(), create_using=nx.DiGraph()) # Detect cycles try: cycles = list(nx.find_cycle(G, orientation='original')) print("Cycles found:", cycles) # Remove the last edge in the cycle to break it if cycles: edge_to_remove = cycles[-1][0], cycles[-1][1] G.remove_edge(*edge_to_remove) print(f"Removed edge: {edge_to_remove}") except nx.exception.NetworkXNoCycle: print("No cycles found.") #adj2 = nx.to_numpy_matrix(G).astype(int) adj2 = nx.to_numpy_array(G).astype(int) adj = pd.DataFrame(adj2, index=adj.columns, columns=adj.columns) adj_values = adj.values adj_values = binarize_adj_matrix(adj_values, threshold=0.5) adj_values = ensure_dag(adj_values) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") adj_values = torch.tensor(adj_values).float().to(device) #Load Black Box Model data_size= len(d.encoded_feature_names) pred_model= BlackBox(data_size).to(cuda) path= base_model_dir + dataset_name +'.pth' pred_model.load_state_dict(torch.load(path)) pred_model.eval() # VAE MODEL wm1=1e-2 wm2=1e-2 wm3=1e-2 wm4=1e-2 data_size= len(feat_to_change) encoded_size=10 fcx_vae = FCX_VAE(data_size, encoded_size, d).to(cuda) learning_rate = 1e-2 fcx_vae_optimizer = optim.Adam([ {'params': filter(lambda p: p.requires_grad, fcx_vae.encoder_mean.parameters()),'weight_decay': wm1}, {'params': filter(lambda p: p.requires_grad, fcx_vae.encoder_var.parameters()),'weight_decay': wm2}, {'params': filter(lambda p: p.requires_grad, fcx_vae.decoder_mean.parameters()),'weight_decay': wm3} ], lr=learning_rate) #Train VAE loss_val = [] likelihood_val = [] valid_cf_count = [] patience = 0 epoch_time_list = [] for epoch in range(epochs): np.random.shuffle(vae_train_dataset) start_time = time.time() loss_val.append( train_constraint_loss( fcx_vae, vae_train_dataset, fcx_vae_optimizer, normalise_weights, validity, constraint_reg, margin, 1, batch_size,adj_values,ed_dict,pred_model) ) end_time = time.time() epoch_time_list.append(end_time-start_time) if epoch==0: best_loss= loss_val[-1] else: if loss_val[-1]-best_loss<5: best_loss= loss_val[-1] patience=0 #Saving the final model torch.save(fcx_vae.state_dict(), base_model_dir + dataset_name + '-margin-' + str(margin) + '-feasibility-' + str(feasibility) + '-validity-'+ str(validity) + '-epoch-' + str(epochs) + '-' + 'fcx-binary' + '.pth') else: patience+=1 print('----Epoch: ', epoch, ' Loss: ', loss_val[-1], ' Best: ', best_loss) if patience>5:# or epoch==12: break #mean time print("Mean time: ",np.mean(epoch_time_list)) #sum time print("Sum time: ",np.sum(epoch_time_list)) #Saving the final model torch.save(fcx_vae.state_dict(), base_model_dir + dataset_name + '-margin-' + str(margin) + '-feasibility-' + str(feasibility) + '-validity-'+ str(validity) + '-epoch-' + str(epochs) + '-' + 'fcx-binary' + '.pth') # plot loss val """ plt.plot(loss_val) plt.ylabel('loss') plt.xlabel('epoch') plt.show()""" return fcx_vae
if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--dataset_name', type=str, default='adult') parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--epoch', type=int, default=50) parser.add_argument('--validity', type=float, default=20) parser.add_argument('--feasibility', type=float, default=1) parser.add_argument('--margin', type=float, default=0.5) args = parser.parse_args() train_binary_fcx_vae( args.dataset_name, base_data_dir='data/', base_model_dir='models/', batch_size=args.batch_size, epochs=args.epoch, validity=args.validity, feasibility=args.feasibility, margin=args.margin, constraint_reg=args.constraint_reg )