FCX Unary Constraint Training (Adult)
- humancompatible.explain.fcx.FCX_unary_generation_adult.compute_loss(model, model_out, x, target_label, normalise_weights, validity_reg, margin, adj_matrix, pred_model)[source]
Compute the combined ELBO, validity hinge-loss, sparsity penalty, and causal regularization for a batch of counterfactual examples.
- This loss aggregates:
KL divergence between the encoder’s distribution and prior.
Reconstruction error (L1 proximity on mutable features).
Validity hinge-loss to enforce classifier flip.
Sparsity penalty to encourage minimal changes.
Causal regularization via a provided adjacency matrix.
- Parameters:
model (FCX_VAE) – The counterfactual VAE model instance.
model_out (dict) –
- Outputs from model.forward, containing:
'em' (Tensor): encoder means, shape (batch, latent_dim)
'ev' (Tensor): encoder variances, shape (batch, latent_dim)
'z' (list[Tensor]): latent samples for each MC draw
'x_pred' (list[Tensor]): reconstructions per MC sample
'mc_samples' (int): number of Monte Carlo draws
x (Tensor) – Original input features, shape (batch, d).
target_label (Tensor) – True class labels (0 or 1), shape (batch,).
normalise_weights (dict[int, tuple(float, float)]) – Mapping feature index → (min, max) for proximity weighting.
validity_reg (float) – Weight of the validity hinge‐loss term.
margin (float) – Hinge loss margin hyperparameter.
adj_matrix (Tensor) – Binary causal adjacency matrix, shape (d, d).
pred_model (nn.Module) – BlackBox classifier used to evaluate the validity of the counterfactuals.
- Returns:
Scalar total loss combining KL, reconstruction, validity, sparsity, and causal regularization.
- Return type:
Tensor
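The validity term listed above can be illustrated with a minimal sketch. It assumes the classifier emits one score per class and that the hinge penalises counterfactuals whose target-class score does not lead the other class by at least margin. The helper name validity_hinge and the plain-Python inputs are hypothetical; the exact form used inside compute_loss may differ.

```python
def validity_hinge(scores, target_labels, margin):
    """Mean hinge penalty over a batch of binary-classifier scores.

    scores: list of (score_class0, score_class1) pairs, one per counterfactual.
    target_labels: list of 0/1 target classes.
    """
    total = 0.0
    for (s0, s1), target in zip(scores, target_labels):
        target_score = s1 if target == 1 else s0
        other_score = s0 if target == 1 else s1
        # No penalty once the target class leads by at least `margin`.
        total += max(0.0, margin - (target_score - other_score))
    return total / len(scores)


# First counterfactual is still classified as class 0, second one has flipped.
scores = [(0.9, 0.1), (0.2, 0.8)]
targets = [1, 1]
loss = validity_hinge(scores, targets, margin=0.5)
```

In this sketch only the first example contributes to the loss, and validity_reg would scale the result before it is added to the other terms.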
- humancompatible.explain.fcx.FCX_unary_generation_adult.train_constraint_loss(model, train_dataset, optimizer, normalise_weights, validity_reg, constraint_reg, margin, epochs=1000, batch_size=1024, adj_matrix=None, pred_model=None)[source]
Perform one epoch of FCX_VAE training under causal and LOF constraints.
- This routine:
Wraps the raw train_dataset into a DataLoader for batching.
For each batch:
- Computes counterfactuals via model(train_x, train_y).
- Calculates the combined ELBO + validity + sparsity + causal loss using compute_loss.
Adds a hinge-loss enforcing age monotonicity as a causal constraint.
Computes a Local Outlier Factor (LOF) penalty on the latent code.
Backpropagates and steps the optimizer.
Accumulates and returns the total loss over all batches.
- Parameters:
model (FCX_VAE) – The VAE-based counterfactual generator.
train_dataset (array-like) – Training data array of shape (N, d+1), with labels in the last column.
optimizer (torch.optim.Optimizer) – Optimizer instance for updating VAE parameters.
normalise_weights (dict[int, tuple(float, float)]) – Feature (min, max) pairs for proximity scaling.
validity_reg (float) – Weight for the validity hinge-loss term.
constraint_reg (float) – Weight for the causal age‐constraint hinge-loss.
margin (float) – Margin hyperparameter for all hinge-loss terms.
epochs (int, optional) – Number of epochs to train (currently only one epoch is run per call).
batch_size (int, optional) – Mini-batch size for the DataLoader.
adj_matrix (torch.Tensor, optional) – Binary causal adjacency matrix for compute_loss.
pred_model (nn.Module, optional) – BlackBox classifier passed on to compute_loss.
- Returns:
The sum of all batch losses over the epoch.
- Return type:
float
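The age-monotonicity hinge mentioned in the routine above might look like the following sketch. It assumes the constraint means "age in the counterfactual must not be lower than the original age" and that ages are compared on the normalised scale; the helper name age_monotonicity_hinge and the example values are hypothetical.

```python
def age_monotonicity_hinge(age_orig, age_cf, margin=0.0):
    """Mean penalty for counterfactuals whose age drops below the original."""
    penalties = [
        # Positive only when the counterfactual age is too low.
        max(0.0, margin + a - a_cf)
        for a, a_cf in zip(age_orig, age_cf)
    ]
    return sum(penalties) / len(penalties)


age_orig = [0.25, 0.5]
age_cf = [0.5, 0.25]  # the second counterfactual makes the person younger
penalty = age_monotonicity_hinge(age_orig, age_cf)

# constraint_reg scales the penalty before it is added to the batch loss.
constraint_reg = 2.0
weighted_penalty = constraint_reg * penalty
```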
- humancompatible.explain.fcx.FCX_unary_generation_adult.train_unary_fcx_vae(dataset_name, base_data_dir, base_model_dir, batch_size, epochs, validity, feasibility, margin)[source]
Train a unary FCX-VAE model to generate counterfactual explanations on the Adult dataset.
- This function:
Loads the Adult dataset.
Loads normalization weights and MAD feature weights.
Loads a pre-trained BlackBox classifier.
Initializes and trains the FCX-VAE under causal (+LOF) constraints by calling train_constraint_loss for the specified number of epochs.
Prints timing and loss statistics, and saves the final VAE checkpoint to disk.
- Parameters:
dataset_name (str) – Name of the dataset (e.g. “adult”).
base_data_dir (str) – Path to the directory containing {dataset_name}-*.npy and weight JSONs.
base_model_dir (str) – Directory in which to read/write model .pth files.
batch_size (int) – Mini-batch size for VAE training.
epochs (int) – Number of training epochs to run.
validity (float) – Weight for the validity hinge-loss term.
feasibility (float) – Weight for the feasibility constraint.
margin (float) – Margin hyperparameter for hinge-losses.
- Returns:
None
- class humancompatible.explain.fcx.scripts.fcx_vae_model.FCX_VAE(data_size, encoded_size, d, immutables=-2)[source]
Bases: Module
Conditional Variational Autoencoder for generating feasible counterfactual explanations.
This model encodes an input feature vector x concatenated with a target class label c into a latent representation, then decodes it back into a counterfactual example. It supports multiple Monte Carlo draws and computes the ELBO along with validity-conditioned outputs.
- immutables
Number of initial immutable features (suffix features are mutable).
- Type:
int
- encoded_size
Dimensionality of the latent code.
- Type:
int
- data_size
Number of input features.
- Type:
int
- encoded_categorical_feature_indexes
Index ranges for one-hot categorical features.
- Type:
List[List[int]]
- encoded_continuous_feature_indexes
Indices for continuous features.
- Type:
List[int]
- encoded_start_cat
Index where categorical features begin.
- Type:
int
- encoder_mean
Network mapping input → latent mean.
- Type:
nn.Sequential
- encoder_var
Network mapping input → latent variance.
- Type:
nn.Sequential
- decoder_mean
Network mapping latent + label → reconstructed features.
- Type:
nn.Sequential
Initialize the FCX-VAE model.
- Parameters:
data_size (int) – Number of encoded input features (excluding label).
encoded_size (int) – Dimensionality of the latent space.
d (DataLoader) – Provides feature metadata (categorical splits, decoding).
immutables (int, optional) – Count of immutable feature columns at the end of x. Defaults to -2 (last two features are immutable).
- compute_elbo(x, c, pred_model, ret=False)[source]
Compute Evidence Lower Bound (ELBO) components for given inputs.
If ret is True, also return de-normalized originals, reconstructions, predicted labels, and latent code for analysis.
- Parameters:
x (torch.Tensor) – Original features; may be truncated to the mutable features when ret is True.
c (torch.Tensor) – Target labels, shape (batch_size,).
pred_model (nn.Module) – BlackBox classifier to evaluate validity.
ret (bool) – Whether to return extras for visualization.
- Returns:
If ret=False: (log_px_z, kl_div, x, x_pred, cf_labels).
If ret=True: (log_px_z, kl_div, x_orig, x_pred_mod, cf_labels, z_code).
- Return type:
tuple
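As an illustration of the kl_div component, here is a minimal sketch of the closed-form KL divergence for a diagonal Gaussian. It assumes the prior is a standard normal, which the documentation does not state, and it uses plain Python lists rather than tensors; the helper name kl_to_standard_normal is hypothetical.

```python
import math


def kl_to_standard_normal(mean, var):
    """KL( N(mean, diag(var)) || N(0, I) ) for one example."""
    return 0.5 * sum(
        v + m * m - 1.0 - math.log(v)
        for m, v in zip(mean, var)
    )


# An encoder output that already matches the assumed prior costs nothing.
kl_zero = kl_to_standard_normal([0.0, 0.0], [1.0, 1.0])

# Shifting one mean by 1 adds 0.5 to the divergence.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [1.0, 1.0])
```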
- decoder(z)[source]
Decode latent code back to feature space.
- Parameters:
z (torch.Tensor) – Concatenated latent code and label, shape (batch_size, encoded_size+1).
- Returns:
Reconstructed features, shape (batch_size, data_size).
- Return type:
torch.Tensor
- encoder(x)[source]
Encode input (x + label) into latent mean and log‑variance.
- Parameters:
x (torch.Tensor) – Concatenated input and class label, shape (batch_size, data_size+1).
- Returns:
mean (torch.Tensor): Latent means, shape (batch_size, encoded_size).
logvar (torch.Tensor): Latent log-variances, shape (batch_size, encoded_size).
- Return type:
tuple[torch.Tensor, torch.Tensor]
- forward(x, c)[source]
Generate multiple Monte Carlo counterfactual draws.
- Parameters:
x (torch.Tensor) – Original features, shape (batch_size, data_size).
c (torch.Tensor) – Target class label (0/1), shape (batch_size,).
- Returns:
A dictionary with the keys:
'em': encoder means
'ev': encoder variances
'z': list of sampled latent codes
'x_pred': list of reconstructed counterfactuals
'mc_samples': int number of MC draws
- Return type:
dict
- normal_likelihood(x, mean, logvar, raxis=1)[source]
Compute the log-probability of x under N(mean, logvar).
- Parameters:
x (torch.Tensor) – Original or reconstructed features.
mean (torch.Tensor) – Latent means.
logvar (torch.Tensor) – Latent variances.
raxis (int) – Axis along which to sum log-likelihood.
- Returns:
Log-likelihood per example.
- Return type:
torch.Tensor
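For reference, the log-density of a diagonal Gaussian summed over the feature axis can be sketched as below. This is an assumption-laden illustration: it treats the second distribution parameter as a variance, as the notation N(mean, logvar) suggests, and it keeps the constant log(2π) term, which the library method may or may not include. The helper name normal_log_likelihood is hypothetical.

```python
import math


def normal_log_likelihood(x, mean, var):
    """Sum over features of log N(x_i | mean_i, var_i) for one example."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * v) + (xi - m) ** 2 / v)
        for xi, m, v in zip(x, mean, var)
    )


# Density of a standard normal evaluated at its mean.
at_mean = normal_log_likelihood([0.0], [0.0], [1.0])

# Moving x away from the mean lowers the log-likelihood.
off_mean = normal_log_likelihood([2.0], [0.0], [1.0])
```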
- sample_latent_code(mean, logvar)[source]
Perform the reparameterization trick to sample latent code.
z = mean + sqrt(logvar) * eps, where eps ~ N(0, I).
- Parameters:
mean (torch.Tensor) – Latent means.
logvar (torch.Tensor) – Latent log-variances.
- Returns:
Sampled latent code.
- Return type:
torch.Tensor
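The sampling formula above can be written out as a short sketch. It follows the documented expression literally, taking the square root of the second argument, which amounts to treating that argument as a variance; if it were a true log-variance, the standard deviation would instead be exp(0.5 * logvar). The helper name reparameterize and the use of Python's random module in place of tensor operations are assumptions made for illustration.

```python
import math
import random


def reparameterize(mean, var, rng):
    """Draw z = mean + sqrt(var) * eps with eps ~ N(0, 1) per dimension."""
    return [
        m + math.sqrt(v) * rng.gauss(0.0, 1.0)
        for m, v in zip(mean, var)
    ]


rng = random.Random(0)  # fixed seed so the draw is repeatable
z = reparameterize([0.0, 1.0], [1.0, 0.25], rng)

# With zero variance the sample collapses onto the mean.
z_deterministic = reparameterize([2.0], [0.0], random.Random(1))
```

Because the randomness enters only through eps, gradients can flow through mean and the variance term when the same computation is done on tensors.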
- class humancompatible.explain.fcx.LOFLoss.LOFLoss(n_neighbors=20, epsilon=1e-08)[source]
Bases: Module
Local Outlier Factor (LOF) as a differentiable loss for anomaly detection.
This loss computes the average LOF score over a batch of input vectors, penalizing points that are outliers in the feature space.
- Parameters:
n_neighbors (int) – Number of nearest neighbors to consider (excluding self) when computing reachability distances. Default is 20.
epsilon (float) – Small constant to ensure numerical stability (no zero distances or division by zero). Default is 1e-8.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(input)[source]
Compute the LOF loss (mean LOF score) for the input batch.
- Steps:
Compute pairwise distances.
Identify the k nearest neighbors for each point (excluding self).
Compute reachability distances: max(distance_to_neighbor, neighbor_kth_distance).
Compute local reachability density (LRD) for each point.
Compute LOF score as the average ratio of neighbors’ LRD to own LRD.
Return the mean LOF score over all points.
- Parameters:
input (torch.Tensor) – Input tensor of shape (batch_size, feature_dim).
- Returns:
Scalar tensor containing the mean LOF score across the batch.
- Return type:
torch.Tensor
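The steps listed for forward can be traced in a small plain-Python sketch. It is not the library implementation: it works on tuples instead of tensors, is not differentiable, and the helper name lof_scores as well as the placement of the epsilon term are assumptions.

```python
import math


def lof_scores(points, k, eps=1e-8):
    """Return one LOF score per point, following the steps listed above."""
    n = len(points)
    # Step 1: pairwise Euclidean distances.
    dist = [[math.dist(p, q) for q in points] for p in points]
    # Step 2: indices of the k nearest neighbours, excluding the point itself.
    neighbors = [
        sorted((j for j in range(n) if j != i), key=lambda j: dist[i][j])[:k]
        for i in range(n)
    ]
    # Distance from each point to its k-th nearest neighbour.
    k_dist = [dist[i][neighbors[i][-1]] for i in range(n)]
    # Steps 3 and 4: reachability distances, then local reachability density.
    lrd = []
    for i in range(n):
        reach = [max(dist[i][j], k_dist[j]) for j in neighbors[i]]
        lrd.append(1.0 / (sum(reach) / k + eps))
    # Step 5: average ratio of the neighbours' LRD to the point's own LRD.
    return [sum(lrd[j] for j in neighbors[i]) / (k * lrd[i]) for i in range(n)]


# Four corners of a unit square plus one distant point.
points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (5.0, 5.0)]
scores = lof_scores(points, k=2)
# Step 6: the loss is the mean score over the batch.
mean_lof = sum(scores) / len(scores)
```

Points inside the cluster score close to 1, while the distant point scores well above 1, so minimising the mean pulls outlying latent codes toward denser regions.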
- pairwise_distances(x)[source]
Compute the pairwise Euclidean distances between rows of x.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, feature_dim).
- Returns:
A (batch_size × batch_size) distance matrix, where entry (i, j) is the Euclidean distance between x[i] and x[j].
- Return type:
torch.Tensor
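A distance matrix of this shape can be sketched with the identity ||a - b||² = ||a||² + ||b||² - 2 a·b, which is a common way to vectorise the computation. Whether the library uses this identity is an assumption, and the helper name pairwise_euclidean is hypothetical.

```python
import math


def pairwise_euclidean(x):
    """Return an n x n matrix of Euclidean distances between rows of x."""
    n = len(x)
    sq_norms = [sum(v * v for v in row) for row in x]
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            dot = sum(a * b for a, b in zip(x[i], x[j]))
            # Clamp at zero to guard against small negative rounding errors.
            out[i][j] = math.sqrt(max(sq_norms[i] + sq_norms[j] - 2.0 * dot, 0.0))
    return out


d = pairwise_euclidean([[0.0, 0.0], [3.0, 4.0]])
```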
- humancompatible.explain.fcx.scripts.datagen_func.prepare_datasets(dataset_name: str, base_dir: str = '../data/') → None[source]
Prepare and save train/validation/test splits plus normalization metadata for a given dataset. This is needed only when the provided csv/npy files are missing, or for custom setups.
- This will:
Load one of: ‘adult’, ‘folktables_adult’, ‘census’, or ‘law’.
Filter and slice the raw DataFrame using the dataset-specific preprocessing logic.
Compute median absolute deviations (MAD) and per-feature min/max.
One‑hot encode, normalize continuous features, and reorder columns.
Split into 10% test, 10% validation, and the rest training.
Save:
- {base_dir}/{dataset_name}-train-set.npy
- {base_dir}/{dataset_name}-val-set.npy
- {base_dir}/{dataset_name}-test-set.npy
- {base_dir}/{dataset_name}-normalise_weights.json
- {base_dir}/{dataset_name}-mad.json
- Parameters:
dataset_name (str) – which dataset to prepare; one of ‘adult’, ‘folktables_adult’, ‘census’, or ‘law’.
base_dir (str) – directory under which to write out the .npy and .json files.
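The 10% / 10% / remainder split described above can be sketched on row indices as follows. Whether prepare_datasets shuffles before splitting, and how it rounds the split sizes, is not stated here; the shuffling, the seed, and the helper name split_indices are all assumptions.

```python
import random


def split_indices(n_rows, seed=0):
    """Split row indices into train, validation, and test lists."""
    idx = list(range(n_rows))
    random.Random(seed).shuffle(idx)  # assumed: rows are shuffled first
    n_test = n_rows // 10
    n_val = n_rows // 10
    test = idx[:n_test]
    val = idx[n_test:n_test + n_val]
    train = idx[n_test + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100)
```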