FCX Binary Constraint Training (Adult)
- humancompatible.explain.fcx.FCX_binary_generation_adult.compute_loss(model, model_out, x, target_label, normalise_weights, validity_reg, margin, adj_matrix, pred_model=None)[source]
Compute the combined ELBO, validity hinge-loss, sparsity penalty, and causal regularization for a batch of examples.
- Parameters:
model (FCX_VAE) – The VAE counterfactual generator.
model_out (dict) – Output of model.forward, containing:
- em (Tensor): encoder means, shape (batch, latent_dim)
- ev (Tensor): encoder variances, shape (batch, latent_dim)
- z (Tensor): latent samples, list of length mc_samples
- x_pred (list[Tensor]): reconstructions per MC sample
- mc_samples (int): number of Monte Carlo draws
x (Tensor) – Original input features, shape (batch, d).
target_label (Tensor) – True class labels (0 or 1), shape (batch,).
normalise_weights (dict[int, tuple[float, float]]) – Mapping feature index → (min, max) for proximity weighting.
validity_reg (float) – Weight of the validity hinge-loss term.
margin (float) – Hinge margin hyperparameter.
adj_matrix (Tensor) – Binary causal adjacency matrix, shape (d, d).
- Returns:
- Scalar loss combining:
reconstruction error (L1 distance on mutable features)
KL divergence
validity hinge‑loss
sparsity penalty
causal regularization
- Return type:
Tensor
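The validity hinge-loss term can be illustrated with a small sketch. The exact form used by compute_loss is not given here, so the function name `validity_hinge` and the choice to penalize the gap between target-class and other-class probabilities are assumptions; NumPy arrays stand in for the torch tensors.

```python
import numpy as np

def validity_hinge(probs, target_label, margin):
    """Hypothetical hinge: zero once the target class leads by `margin`."""
    idx = np.arange(len(target_label))
    p_target = probs[idx, target_label]       # probability of the desired class
    p_other = probs[idx, 1 - target_label]    # probability of the opposite class
    return np.maximum(0.0, margin - (p_target - p_other)).mean()

# Illustrative classifier outputs for three counterfactuals (binary task).
probs = np.array([[0.9, 0.1], [0.4, 0.6], [0.5, 0.5]])
target = np.array([0, 1, 1])
loss = validity_hinge(probs, target, margin=0.5)
```

Under this assumed form, the first example incurs no penalty because the target class already leads by more than the margin, while the other two are penalized in proportion to how far they fall short. In the documented loss this term would then be scaled by validity_reg.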
- humancompatible.explain.fcx.FCX_binary_generation_adult.train_binary_fcx_vae(dataset_name: str, base_data_dir: str = 'data/', base_model_dir: str = 'models/', batch_size: int = 64, epochs: int = 50, validity: float = 20.0, feasibility: float = 1.0, margin: float = 0.5)[source]
Train the FCX‑VAE model to generate binary counterfactuals for the specified dataset.
This function performs the full training pipeline:
1. Sets random seeds and device.
2. Builds an education-level penalty dictionary.
3. Loads and preprocesses the Adult dataset, filtering only low-income samples.
4. Loads normalization weights and the pretrained black-box classifier.
5. Constructs the causal adjacency matrix and ensures it is a DAG.
6. Initializes the FCX-VAE model and its optimizer.
7. Runs a validity + feasibility constrained training loop for the given number of epochs.
8. Saves the best model checkpoint and returns the trained VAE.
- Parameters:
dataset_name (str) – Dataset identifier, e.g. ‘adult’.
base_data_dir (str) – Path to the directory containing {dataset_name}-train-set.npy, {dataset_name}-val-set.npy, normalization JSONs, and adjacency CSVs.
base_model_dir (str) – Directory to load the black‐box model from and save the final VAE checkpoint.
batch_size (int) – Mini-batch size for VAE training.
epochs (int) – Number of training epochs.
validity (float) – Weight for the validity hinge‐loss term.
feasibility (float) – Weight for the causal feasibility regularizer.
margin (float) – Margin parameter for all hinge losses.
- Returns:
The trained VAE model instance, with its state_dict saved to {base_model_dir}/{dataset_name}-margin-{margin}-feasibility-{feasibility}-validity-{validity}-epoch-{epochs}-fcx-binary.pth.
- Return type:
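To locate a saved checkpoint later, the documented file-name pattern can be rebuilt from the training arguments. The helper name `checkpoint_path` is hypothetical and not part of the package, and joining with `os.path.join` is an assumption about how the directory and file name are combined.

```python
import os

def checkpoint_path(dataset_name, base_model_dir="models/", epochs=50,
                    validity=20.0, feasibility=1.0, margin=0.5):
    # Hypothetical helper: mirrors the documented file-name pattern only.
    name = (f"{dataset_name}-margin-{margin}-feasibility-{feasibility}"
            f"-validity-{validity}-epoch-{epochs}-fcx-binary.pth")
    return os.path.join(base_model_dir, name)

# Path for the default arguments of train_binary_fcx_vae on 'adult'.
path = checkpoint_path("adult")
```

Because the hyperparameters are interpolated as floats, a value such as 20.0 appears in the name with its decimal point.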
- humancompatible.explain.fcx.FCX_binary_generation_adult.train_constraint_loss(model, train_dataset, optimizer, normalise_weights, validity_reg, constraint_reg, margin, epochs=1000, batch_size=1024, adj_matrix=None, ed_dict=None, pred_model=None)[source]
Perform one epoch of FCX‑VAE training under causal and LOF constraints.
This function iterates over the provided training data and computes the combined VAE ELBO loss, a causal-constraint hinge loss on the education features, and a Local Outlier Factor (LOF) anomaly penalty on the latent codes. It then backpropagates and updates the model parameters.
- Parameters:
model (FCX_VAE) – The counterfactual VAE model to train.
train_dataset (ndarray) – Array of shape (N, D+1), where the last column is the binary label used for counterfactual conditioning.
optimizer (torch.optim.Optimizer) – Optimizer instance for model parameters.
normalise_weights (dict[int, tuple[float, float]]) – Per-feature (min, max) normalization weights for continuous features.
validity_reg (float) – Weight for the validity hinge loss component.
constraint_reg (float) – Weight for the causal/LOF constraint loss.
margin (float) – Margin hyperparameter for hinge losses.
epochs (int, optional) – Number of times to sample each datapoint (MC samples). Defaults to 1000.
batch_size (int, optional) – Mini-batch size for training. Defaults to 1024.
adj_matrix (torch.Tensor, optional) – Binary adjacency matrix enforcing causal dependencies. Shape (D, D).
ed_dict (dict[int, int], optional) – Mapping from categorical feature index to an education-level penalty coefficient.
- Returns:
The sum of training losses over all mini-batches.
- Return type:
float
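The layout of train_dataset, with the label in the last column, suggests a batching step along the following lines. This is a minimal sketch, not the function's actual loop: the generator name `iterate_minibatches` is hypothetical, and any shuffling or label handling done by the real implementation is not shown.

```python
import numpy as np

def iterate_minibatches(train_dataset, batch_size):
    """Yield (features, labels) pairs; the last column holds the binary label."""
    for start in range(0, len(train_dataset), batch_size):
        batch = train_dataset[start:start + batch_size]
        yield batch[:, :-1], batch[:, -1]

# Illustrative array of shape (N, D+1) with N=10 rows and D=4 features.
rng = np.random.default_rng(0)
data = np.hstack([rng.random((10, 4)), rng.integers(0, 2, size=(10, 1))])
batches = list(iterate_minibatches(data, batch_size=4))
shapes = [(x.shape, y.shape) for x, y in batches]
```

The final batch is smaller when N is not a multiple of batch_size; the returned value described above would be the sum of the per-batch losses over such an iteration.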
- class humancompatible.explain.fcx.scripts.fcx_vae_model.FCX_VAE(data_size, encoded_size, d, immutables=-2)[source]
Bases:
Module
Conditional Variational Autoencoder for generating feasible counterfactual explanations.
This model encodes an input feature vector x, concatenated with a target class label c, into a latent representation, then decodes it back into a counterfactual example. It supports multiple Monte Carlo draws and computes the ELBO together with validity-conditioned outputs.
- immutables
Number of initial immutable features (suffix features are mutable).
- Type:
int
- encoded_size
Dimensionality of the latent code.
- Type:
int
- data_size
Number of input features.
- Type:
int
- encoded_categorical_feature_indexes
Index ranges for one-hot categorical features.
- Type:
List[List[int]]
- encoded_continuous_feature_indexes
Indices for continuous features.
- Type:
List[int]
- encoded_start_cat
Index where categorical features begin.
- Type:
int
- encoder_mean
Network mapping input → latent mean.
- Type:
nn.Sequential
- encoder_var
Network mapping input → latent variance.
- Type:
nn.Sequential
- decoder_mean
Network mapping latent + label → reconstructed features.
- Type:
nn.Sequential
Initialize the FCX-VAE model.
- Parameters:
data_size (int) – Number of encoded input features (excluding label).
encoded_size (int) – Dimensionality of the latent space.
d (DataLoader) – Provides feature metadata (categorical splits, decoding).
immutables (int, optional) – Count of immutable feature columns at the end of x. Defaults to -2 (last two features are immutable).
- compute_elbo(x, c, pred_model, ret=False)[source]
Compute Evidence Lower Bound (ELBO) components for given inputs.
If ret is True, also return de-normalized originals, reconstructions, predicted labels, and latent code for analysis.
- Parameters:
x (torch.Tensor) – Original features; may be truncated to the mutable features if ret is True.
c (torch.Tensor) – Target labels, shape (batch_size,).
pred_model (nn.Module) – BlackBox classifier to evaluate validity.
ret (bool) – Whether to return extras for visualization.
- Returns:
If ret=False: (log_px_z, kl_div, x, x_pred, cf_labels).
If ret=True: (log_px_z, kl_div, x_orig, x_pred_mod, cf_labels, z_code).
- Return type:
tuple
- decoder(z)[source]
Decode latent code back to feature space.
- Parameters:
z (torch.Tensor) – Concatenated latent code and label, shape (batch_size, encoded_size+1).
- Returns:
Reconstructed features, shape (batch_size, data_size).
- Return type:
torch.Tensor
- encoder(x)[source]
Encode input (x + label) into latent mean and log‑variance.
- Parameters:
x (torch.Tensor) – Concatenated input and class label, shape (batch_size, data_size+1).
- Returns:
mean (torch.Tensor): Latent means, shape (batch_size, encoded_size).
logvar (torch.Tensor): Latent log-variances, shape (batch_size, encoded_size).
- Return type:
tuple[torch.Tensor, torch.Tensor]
- forward(x, c)[source]
Generate multiple Monte Carlo counterfactual draws.
- Parameters:
x (torch.Tensor) – Original features, shape (batch_size, data_size).
c (torch.Tensor) – Target class label (0/1), shape (batch_size,).
- Returns:
Dictionary with the keys:
- 'em': encoder means
- 'ev': encoder variances
- 'z': list of sampled latent codes
- 'x_pred': list of reconstructed counterfactuals
- 'mc_samples': int number of MC draws
- Return type:
dict
- normal_likelihood(x, mean, logvar, raxis=1)[source]
Compute the log-probability of x under N(mean, logvar).
- Parameters:
x (torch.Tensor) – Original or reconstructed features.
mean (torch.Tensor) – Latent means.
logvar (torch.Tensor) – Latent variances.
raxis (int) – Axis along which to sum log-likelihood.
- Returns:
Log-likelihood per example.
- Return type:
torch.Tensor
- sample_latent_code(mean, logvar)[source]
Perform the reparameterization trick to sample latent code.
z = mean + sqrt(logvar) * eps, where eps ~ N(0, I).
- Parameters:
mean (torch.Tensor) – Latent means.
logvar (torch.Tensor) – Latent log-variances.
- Returns:
Sampled latent code.
- Return type:
torch.Tensor
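The sampling formula above can be sketched as follows, with NumPy standing in for torch. The sketch follows the documented expression literally, which assumes the argument named logvar holds strictly positive values so that the square root is defined; the explicit `rng` argument is an addition for reproducibility and is not part of the documented signature.

```python
import numpy as np

def sample_latent_code(mean, logvar, rng):
    # Follows the documented formula literally: z = mean + sqrt(logvar) * eps.
    # Assumption: `logvar` holds positive values, so the square root is defined.
    eps = rng.standard_normal(mean.shape)   # eps ~ N(0, I)
    return mean + np.sqrt(logvar) * eps

rng = np.random.default_rng(0)
mean = np.zeros((10000, 2))
spread = np.full((10000, 2), 4.0)           # illustrative positive values
z = sample_latent_code(mean, spread, rng)
```

Sampling this way keeps the randomness in eps, so gradients can flow through mean and the scale term when the same expression is written with torch tensors.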
- humancompatible.explain.fcx.scripts.causal_modules.binarize_adj_matrix(adj_matrix, threshold=0.5)[source]
Converts the adjacency matrix to binary by applying a threshold.
- Parameters:
adj_matrix (np.ndarray) – Original adjacency matrix.
threshold (float) – Threshold to determine edge existence.
- Returns:
Binarized adjacency matrix.
- Return type:
np.ndarray
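A thresholding step of this kind can be sketched in a few lines. Whether the library treats a weight exactly equal to the threshold as an edge is not stated, so the strict comparison below is an assumption, and the function name `binarize` is illustrative.

```python
import numpy as np

def binarize(adj_matrix, threshold=0.5):
    # Assumption: an edge is kept only when its weight is strictly above the threshold.
    return (adj_matrix > threshold).astype(int)

# Illustrative weighted adjacency matrix for three features.
weights = np.array([[0.0, 0.7, 0.2],
                    [0.1, 0.0, 0.9],
                    [0.5, 0.3, 0.0]])
binary = binarize(weights)
```

Here only the edges 0 → 1 and 1 → 2 survive; the weight of exactly 0.5 is dropped under the strict comparison.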
- humancompatible.explain.fcx.scripts.causal_modules.causal_regularization_enhanced(outputs, adj_matrix, lambda_nc=0.001, lambda_c=0.001, theta=0.2)[source]
Enhanced causal regularization to enforce dependencies in outputs based on input adjacency matrix.
- Parameters:
outputs (torch.Tensor) – Reconstructed outputs, shape (batch_size, input_dim)
adj_matrix (torch.Tensor) – Adjacency matrix from input space, shape (input_dim, input_dim)
lambda_nc (float) – Regularization strength for non-connected pairs
lambda_c (float) – Regularization strength for connected pairs
theta (float) – Covariance threshold for connected pairs
- Returns:
Enhanced regularization loss
- Return type:
torch.Tensor
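One plausible reading of these parameters is sketched below. It is an assumption inferred from the parameter descriptions, not the library's formula: absolute covariance between non-connected feature pairs is penalized with weight lambda_nc, and connected pairs are penalized with weight lambda_c when their absolute covariance falls below theta. The function name `causal_reg_sketch` is hypothetical and NumPy stands in for torch.

```python
import numpy as np

def causal_reg_sketch(outputs, adj, lambda_nc=0.001, lambda_c=0.001, theta=0.2):
    # Assumed formulation; the library's actual loss may differ.
    cov = np.abs(np.cov(outputs, rowvar=False))          # (input_dim, input_dim)
    off_diag = 1.0 - np.eye(adj.shape[0])
    connected = np.clip(adj + adj.T, 0, 1) * off_diag    # pairs linked in either direction
    non_connected = (1.0 - connected) * off_diag
    loss_nc = (non_connected * cov).sum()                       # discourage spurious dependence
    loss_c = (connected * np.maximum(0.0, theta - cov)).sum()   # ask for covariance of at least theta
    return lambda_nc * loss_nc + lambda_c * loss_c

# Features 0 and 1 move together; feature 2 is constant.
outputs = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]])
adj = np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]])
with_edge = causal_reg_sketch(outputs, adj)
without_edge = causal_reg_sketch(outputs, np.zeros((3, 3)))
```

With the edge 0 → 1 declared, the strong covariance between the two features costs nothing; with no edges declared, the same covariance is treated as spurious and penalized.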
- humancompatible.explain.fcx.scripts.causal_modules.ensure_dag(adj_matrix)[source]
Enforce that a given adjacency matrix represents a Directed Acyclic Graph (DAG).
- Parameters:
adj_matrix (np.ndarray) – Square binary adjacency matrix of shape (n, n), where entry (i, j) == 1 indicates a directed edge i → j.
- Returns:
A modified adjacency matrix of the same shape, guaranteed to be acyclic (a DAG).
- Return type:
np.ndarray
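How ensure_dag repairs cycles is not described here. One possible strategy, shown purely as an illustration, is a depth-first search that deletes every edge pointing back into the current search path; the function name `drop_back_edges` is hypothetical.

```python
import numpy as np

def drop_back_edges(adj_matrix):
    """Return a copy with every DFS back edge removed, which leaves a DAG."""
    adj = np.array(adj_matrix, dtype=int)
    n = adj.shape[0]
    state = [0] * n            # 0 = unvisited, 1 = on the DFS path, 2 = finished

    def visit(i):
        state[i] = 1
        for j in range(n):
            if adj[i, j] == 0:
                continue
            if state[j] == 1:      # edge points back into the current path: a cycle
                adj[i, j] = 0
            elif state[j] == 0:
                visit(j)
        state[i] = 2

    for i in range(n):
        if state[i] == 0:
            visit(i)
    return adj

# A 3-cycle 0 -> 1 -> 2 -> 0; the closing edge 2 -> 0 is removed.
cyclic = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
dag = drop_back_edges(cyclic)
```

Which edge of a cycle gets removed depends on the traversal order, so a real implementation may prefer a criterion based on edge weights or domain knowledge.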
- class humancompatible.explain.fcx.LOFLoss.LOFLoss(n_neighbors=20, epsilon=1e-08)[source]
Bases:
Module
Local Outlier Factor (LOF) as a differentiable loss for anomaly detection.
This loss computes the average LOF score over a batch of input vectors, penalizing points that are outliers in the feature space.
- Parameters:
n_neighbors (int) – Number of nearest neighbors to consider (excluding self) when computing reachability distances. Default is 20.
epsilon (float) – Small constant to ensure numerical stability (no zero distances or division by zero). Default is 1e-8.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(input)[source]
Compute the LOF loss (mean LOF score) for the input batch.
- Steps:
Compute pairwise distances.
Identify the k nearest neighbors for each point (excluding self).
Compute reachability distances: max(distance_to_neighbor, neighbor_kth_distance).
Compute local reachability density (LRD) for each point.
Compute LOF score as the average ratio of neighbors’ LRD to own LRD.
Return the mean LOF score over all points.
- Parameters:
input (torch.Tensor) – Input tensor of shape (batch_size, feature_dim).
- Returns:
Scalar tensor containing the mean LOF score across the batch.
- Return type:
torch.Tensor
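The steps listed for forward can be sketched as follows. This is a NumPy illustration of the standard LOF computation rather than the module's code: it is not differentiable, the function name `lof_scores` is hypothetical, and it returns the per-point scores alongside the mean for inspection.

```python
import numpy as np

def lof_scores(x, n_neighbors=2, epsilon=1e-8):
    diff = x[:, None, :] - x[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1) + epsilon)         # 1. pairwise distances
    np.fill_diagonal(dist, np.inf)                         #    exclude self
    knn_idx = np.argsort(dist, axis=1)[:, :n_neighbors]    # 2. k nearest neighbors
    knn_dist = np.take_along_axis(dist, knn_idx, axis=1)
    kth_dist = knn_dist[:, -1]                             #    distance to the k-th neighbor
    reach = np.maximum(knn_dist, kth_dist[knn_idx])        # 3. reachability distances
    lrd = 1.0 / (reach.mean(axis=1) + epsilon)             # 4. local reachability density
    lof = lrd[knn_idx].mean(axis=1) / lrd                  # 5. neighbors' LRD over own LRD
    return lof.mean(), lof                                 # 6. mean LOF over all points

# Four clustered points and one far-away point.
x = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [10.0, 10.0]])
mean_lof, lof = lof_scores(x)
```

Points inside the cluster score close to 1, while the isolated point scores far above 1, which is what makes the mean usable as a penalty on outlying codes.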
- pairwise_distances(x)[source]
Compute the pairwise Euclidean distances between rows of x.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, feature_dim).
- Returns:
A (batch_size × batch_size) distance matrix, where entry (i, j) is the Euclidean distance between x[i] and x[j].
- Return type:
torch.Tensor
- humancompatible.explain.fcx.scripts.datagen_func.prepare_datasets(dataset_name: str, base_dir: str = '../data/') → None[source]
Prepare and save train/validation/test splits plus normalization metadata for a given dataset. Use this only when the provided csv/npy files are missing, or for custom solutions.
- This will:
Load one of: ‘adult’, ‘folktables_adult’, ‘census’, or ‘law’.
Filter and slice the raw DataFrame according to the dataset-specific preprocessing logic.
Compute median absolute deviations (MAD) and per-feature min/max.
One‑hot encode, normalize continuous features, and reorder columns.
Split into 10% test, 10% validation, and the rest training.
Save:
- {base_dir}/{dataset_name}-train-set.npy
- {base_dir}/{dataset_name}-val-set.npy
- {base_dir}/{dataset_name}-test-set.npy
- {base_dir}/{dataset_name}-normalise_weights.json
- {base_dir}/{dataset_name}-mad.json
- Parameters:
dataset_name (str) – which dataset to prepare; one of ‘adult’, ‘folktables_adult’, ‘census’, or ‘law’.
base_dir (str) – directory under which to write out the .npy and .json files.
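The 10% test, 10% validation split described above can be sketched as follows. Whether the function shuffles rows before splitting, and with which seed, is not stated; the shuffle, the seed, and the helper name `split_indices` are assumptions made for illustration.

```python
import numpy as np

def split_indices(n_rows, seed=0):
    # Assumption: rows are shuffled before splitting; the seed is illustrative.
    order = np.random.default_rng(seed).permutation(n_rows)
    n_test = n_val = n_rows // 10          # 10% each for test and validation
    test = order[:n_test]
    val = order[n_test:n_test + n_val]
    train = order[n_test + n_val:]         # the rest is training data
    return train, val, test

train, val, test = split_indices(1000)
```

The resulting index arrays can be used to slice the encoded array before each part is written to its .npy file.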