FCX Binary Constraint Training (Adult)


humancompatible.explain.fcx.FCX_binary_generation_adult.compute_loss(model, model_out, x, target_label, normalise_weights, validity_reg, margin, adj_matrix, pred_model=None)[source]

Compute the combined ELBO, validity hinge-loss, sparsity penalty, and causal regularization for a batch of examples.

Parameters:
  • model (FCX_VAE) – The VAE counterfactual generator.

  • model_out (dict) – Output of model.forward, containing:
      - em (Tensor): encoder means, shape (batch, latent_dim)
      - ev (Tensor): encoder variances, shape (batch, latent_dim)
      - z (list[Tensor]): latent samples, list of length mc_samples
      - x_pred (list[Tensor]): reconstructions per MC sample
      - mc_samples (int): number of Monte Carlo draws

  • x (Tensor) – Original input features, shape (batch, d).

  • target_label (Tensor) – True class labels (0 or 1), shape (batch,).

  • normalise_weights (dict[int, tuple[float, float]]) – Mapping feature index → (min, max) for proximity weighting.

  • validity_reg (float) – Weight of the validity hinge-loss term.

  • margin (float) – Hinge margin hyperparameter.

  • adj_matrix (Tensor) – Binary causal adjacency matrix, shape (d, d).

Returns:

Scalar loss combining:
  • reconstruction error (L1 distance on mutable features)

  • KL divergence

  • validity hinge‑loss

  • sparsity penalty

  • causal regularization

Return type:

Tensor
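
As a rough illustration of how the validity hinge-loss term and validity_reg might enter the total, the sketch below scores a single counterfactual in plain Python. The helper name hinge_validity, the example scores, and the exact form max(0, margin - (score_target - score_other)) are assumptions made for illustration; the module's actual computation may differ.

```python
def hinge_validity(score_target, score_other, margin):
    """Hinge penalty: zero once the target class leads by at least `margin`."""
    return max(0.0, margin - (score_target - score_other))

# Hypothetical black-box classifier scores for two counterfactuals.
validity_reg, margin = 20.0, 0.5
penalty_valid = hinge_validity(0.9, 0.1, margin)    # target leads by 0.8
penalty_invalid = hinge_validity(0.4, 0.6, margin)  # target trails by 0.2

# The weighted term that would be added to the ELBO and the other penalties.
weighted = validity_reg * penalty_invalid
```

A larger margin asks the counterfactual to sit further inside the target class before the penalty vanishes, while validity_reg controls how strongly that requirement competes with reconstruction quality.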

humancompatible.explain.fcx.FCX_binary_generation_adult.train_binary_fcx_vae(dataset_name: str, base_data_dir: str = 'data/', base_model_dir: str = 'models/', batch_size: int = 64, epochs: int = 50, validity: float = 20.0, feasibility: float = 1.0, margin: float = 0.5)[source]

Train the FCX‑VAE model to generate binary counterfactuals for the specified dataset.

This function performs the full training pipeline:
  1. Sets random seeds and device.

  2. Builds an education-level penalty dictionary.

  3. Loads and preprocesses the Adult dataset, filtering only low-income samples.

  4. Loads normalization weights and the pretrained black-box classifier.

  5. Constructs the causal adjacency matrix and ensures it is a DAG.

  6. Initializes the FCX-VAE model and its optimizer.

  7. Runs a validity + feasibility constrained training loop for the given number of epochs.

  8. Saves the best model checkpoint and returns the trained VAE.

Parameters:
  • dataset_name (str) – Dataset identifier, e.g. ‘adult’.

  • base_data_dir (str) – Path to the directory containing {dataset_name}-train-set.npy, {dataset_name}-val-set.npy, normalization JSONs, and adjacency CSVs.

  • base_model_dir (str) – Directory to load the black‐box model from and save the final VAE checkpoint.

  • batch_size (int) – Mini-batch size for VAE training.

  • epochs (int) – Number of training epochs.

  • validity (float) – Weight for the validity hinge‐loss term.

  • feasibility (float) – Weight for the causal feasibility regularizer.

  • margin (float) – Margin parameter for all hinge losses.

Returns:

The trained VAE model instance, with its state_dict saved to {base_model_dir}/{dataset_name}-margin-{margin}-feasibility-{feasibility}-validity-{validity}-epoch-{epochs}-fcx-binary.pth.

Return type:

FCX_VAE
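
To locate the saved checkpoint after training, the documented filename pattern can be reproduced with a small helper. This is a sketch only: fcx_checkpoint_path is a hypothetical name, and it assumes the hyperparameters are formatted with Python's default float-to-string conversion, which may not match the training function exactly.

```python
import os

def fcx_checkpoint_path(dataset_name, base_model_dir="models/", margin=0.5,
                        feasibility=1.0, validity=20.0, epochs=50):
    """Build the checkpoint path following the documented filename pattern."""
    filename = (
        f"{dataset_name}-margin-{margin}-feasibility-{feasibility}"
        f"-validity-{validity}-epoch-{epochs}-fcx-binary.pth"
    )
    return os.path.join(base_model_dir, filename)

# Path for the default hyperparameters of train_binary_fcx_vae.
path = fcx_checkpoint_path("adult")
```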

humancompatible.explain.fcx.FCX_binary_generation_adult.train_constraint_loss(model, train_dataset, optimizer, normalise_weights, validity_reg, constraint_reg, margin, epochs=1000, batch_size=1024, adj_matrix=None, ed_dict=None, pred_model=None)[source]

Perform one epoch of FCX‑VAE training under causal and LOF constraints.

This function iterates over the provided training data and computes the combined VAE ELBO loss, a causal constraint hinge loss on the education features, and a Local Outlier Factor (LOF) anomaly penalty on the latent codes. It then backpropagates and updates the model parameters.

Parameters:
  • model (FCX_VAE) – The counterfactual VAE model to train.

  • train_dataset (ndarray) – Array of shape (N, D+1), where the last column is the binary label used for counterfactual conditioning.

  • optimizer (torch.optim.Optimizer) – Optimizer instance for model parameters.

  • normalise_weights (dict[int, tuple[float, float]]) – Per-feature (min, max) normalization weights for continuous features.

  • validity_reg (float) – Weight for the validity hinge loss component.

  • constraint_reg (float) – Weight for the causal/LOF constraint loss.

  • margin (float) – Margin hyperparameter for hinge losses.

  • epochs (int, optional) – Number of times to sample each datapoint (MC samples). Defaults to 1000.

  • batch_size (int, optional) – Mini-batch size for training. Defaults to 1024.

  • adj_matrix (torch.Tensor, optional) – Binary adjacency matrix enforcing causal dependencies. Shape (D, D).

  • ed_dict (dict[int, int], optional) – Mapping from categorical feature index to an education-level penalty coefficient.

Returns:

The sum of training losses over all mini-batches.

Return type:

float
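
The data layout described above (last column is the label, losses summed over mini-batches) can be sketched in plain Python. The helper iter_batches and the toy rows are hypothetical, and the per-batch "loss" is a stand-in value rather than the real objective.

```python
def iter_batches(train_dataset, batch_size):
    """Yield (features, labels) per mini-batch; the label is the last column."""
    for start in range(0, len(train_dataset), batch_size):
        rows = train_dataset[start:start + batch_size]
        features = [row[:-1] for row in rows]
        labels = [row[-1] for row in rows]
        yield features, labels

# Toy data: N=5 rows, D=2 features plus a binary label.
toy = [[0.1, 0.2, 0], [0.3, 0.4, 1], [0.5, 0.6, 0], [0.7, 0.8, 1], [0.9, 1.0, 0]]
batches = list(iter_batches(toy, batch_size=2))

# Stand-in per-batch loss (just the batch size), accumulated the way the
# documented return value sums losses over all mini-batches.
total_loss = float(sum(len(labels) for _, labels in batches))
```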

class humancompatible.explain.fcx.scripts.fcx_vae_model.FCX_VAE(data_size, encoded_size, d, immutables=-2)[source]

Bases: Module

Conditional Variational Autoencoder for generating feasible counterfactual explanations.

This model encodes an input feature vector x concatenated with a target class label c into a latent representation, then decodes it back to a counterfactual example. It supports multiple Monte Carlo draws and computes the ELBO and validity-conditioned outputs.

immutables

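Number of immutable feature columns, given as an index counted from the end of x (the default -2 marks the last two features as immutable; the preceding features are mutable).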

Type:

int

encoded_size

Dimensionality of the latent code.

Type:

int

data_size

Number of input features.

Type:

int

encoded_categorical_feature_indexes

Index ranges for one-hot categorical features.

Type:

List[List[int]]

encoded_continuous_feature_indexes

Indices for continuous features.

Type:

List[int]

encoded_start_cat

Index where categorical features begin.

Type:

int

encoder_mean

Network mapping input → latent mean.

Type:

nn.Sequential

encoder_var

Network mapping input → latent variance.

Type:

nn.Sequential

decoder_mean

Network mapping latent + label → reconstructed features.

Type:

nn.Sequential

Initialize the FCX-VAE model.

Parameters:
  • data_size (int) – Number of encoded input features (excluding label).

  • encoded_size (int) – Dimensionality of the latent space.

  • d (DataLoader) – Provides feature metadata (categorical splits, decoding).

  • immutables (int, optional) – Count of immutable feature columns at the end of x. Defaults to -2 (last two features are immutable).

compute_elbo(x, c, pred_model, ret=False)[source]

Compute Evidence Lower Bound (ELBO) components for given inputs.

If ret is True, also return de-normalized originals, reconstructions, predicted labels, and latent code for analysis.

Parameters:
  • x (torch.Tensor) – Original features; may be truncated to the mutable features when ret is True.

  • c (torch.Tensor) – Target labels, shape (batch_size,).

  • pred_model (nn.Module) – Black-box classifier used to evaluate validity.

  • ret (bool) – Whether to return extras for visualization.

Returns:

If ret=False: (log_px_z, kl_div, x, x_pred, cf_labels)

If ret=True: (log_px_z, kl_div, x_orig, x_pred_mod, cf_labels, z_code)

Return type:

tuple

decoder(z)[source]

Decode latent code back to feature space.

Parameters:

z (torch.Tensor) – Concatenated latent code and label, shape (batch_size, encoded_size+1).

Returns:

Reconstructed features, shape (batch_size, data_size).

Return type:

torch.Tensor

encoder(x)[source]

Encode input (x + label) into latent mean and log‑variance.

Parameters:

x (torch.Tensor) – Concatenated input and class label, shape (batch_size, data_size+1).

Returns:

  • mean (torch.Tensor): Latent means, shape (batch_size, encoded_size).

  • logvar (torch.Tensor): Latent log-variances, shape (batch_size, encoded_size).

Return type:

tuple[torch.Tensor, torch.Tensor]

forward(x, c)[source]

Generate multiple Monte Carlo counterfactual draws.

Parameters:
  • x (torch.Tensor) – Original features, shape (batch_size, data_size).

  • c (torch.Tensor) – Target class label (0/1), shape (batch_size,).

Returns:

A dictionary with the keys:
  • 'em': encoder means

  • 'ev': encoder variances

  • 'z': list of sampled latent codes

  • 'x_pred': list of reconstructed counterfactuals

  • 'mc_samples': int number of MC draws

Return type:

dict

normal_likelihood(x, mean, logvar, raxis=1)[source]

Compute the log-probability of x under N(mean, logvar).

Parameters:
  • x (torch.Tensor) – Original or reconstructed features.

  • mean (torch.Tensor) – Latent means.

  • logvar (torch.Tensor) – Latent variances.

  • raxis (int) – Axis along which to sum log-likelihood.

Returns:

Log-likelihood per example.

Return type:

torch.Tensor
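
For reference, a diagonal Gaussian log-likelihood summed over features can be written in plain Python as below. This is a sketch under assumptions: the hypothetical normal_log_likelihood treats its third argument as a variance and includes the constant 2π term; whether the method does either is not stated here.

```python
import math

def normal_log_likelihood(x, mean, var):
    """Sum of per-feature log N(x_i; mean_i, var_i) for one example."""
    return sum(
        -0.5 * (math.log(2 * math.pi * v) + (xi - m) ** 2 / v)
        for xi, m, v in zip(x, mean, var)
    )

# Likelihood is highest at the mean and drops as x moves away from it.
ll_at_mean = normal_log_likelihood([0.0, 1.0], [0.0, 1.0], [1.0, 1.0])
ll_off_mean = normal_log_likelihood([1.0, 1.0], [0.0, 1.0], [1.0, 1.0])
```

Summing over the feature axis corresponds to raxis=1 for a batch laid out as (batch_size, features).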

sample_latent_code(mean, logvar)[source]

Perform the reparameterization trick to sample latent code.

z = mean + sqrt(logvar) * eps, where eps ~ N(0, I).

Parameters:
  • mean (torch.Tensor) – Latent means.

  • logvar (torch.Tensor) – Latent log-variances.

Returns:

Sampled latent code.

Return type:

torch.Tensor
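
The sampling formula can be illustrated without PyTorch. The sketch follows the formula exactly as written, so its second argument is used as a variance; if the model stored true log-variances, the standard deviation would instead be exp(0.5 * logvar). Which convention the implementation uses is not stated here, and the function below is a hypothetical stand-in, not the method itself.

```python
import math
import random

def sample_latent_code(mean, var, rng):
    """Draw z = mean + sqrt(var) * eps with eps ~ N(0, 1) per dimension."""
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mean, var)]

rng = random.Random(0)  # seeded so the draw is repeatable
z = sample_latent_code([0.0, 1.0, -1.0], [1.0, 0.25, 0.0], rng)
```

Because the noise eps is drawn independently of mean and var, gradients can flow through both when the same expression is written with tensors.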

humancompatible.explain.fcx.scripts.causal_modules.binarize_adj_matrix(adj_matrix, threshold=0.5)[source]

Converts the adjacency matrix to binary by applying a threshold.

Parameters:
  • adj_matrix (np.ndarray) – Original adjacency matrix.

  • threshold (float) – Threshold to determine edge existence.

Returns:

Binarized adjacency matrix.

Return type:

np.ndarray
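
A minimal thresholding sketch on nested lists is shown below. It assumes a strict comparison (weight > threshold); whether the library uses > or >= is not stated here, and the helper name binarize is hypothetical.

```python
def binarize(adj_matrix, threshold=0.5):
    """Return a 0/1 matrix: 1 where the weight exceeds `threshold`."""
    return [[1 if w > threshold else 0 for w in row] for row in adj_matrix]

weights = [[0.0, 0.9, 0.2],
           [0.1, 0.0, 0.7],
           [0.5, 0.3, 0.0]]
binary = binarize(weights)  # the 0.5 entry is dropped under the strict comparison
```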

humancompatible.explain.fcx.scripts.causal_modules.causal_regularization_enhanced(outputs, adj_matrix, lambda_nc=0.001, lambda_c=0.001, theta=0.2)[source]

Enhanced causal regularization to enforce dependencies in outputs based on input adjacency matrix.

Parameters:
  • outputs (torch.Tensor) – Reconstructed outputs, shape (batch_size, input_dim)

  • adj_matrix (torch.Tensor) – Adjacency matrix from input space, shape (input_dim, input_dim)

  • lambda_nc (float) – Regularization strength for non-connected pairs

  • lambda_c (float) – Regularization strength for connected pairs

  • theta (float) – Covariance threshold for connected pairs

Returns:

Enhanced regularization loss

Return type:

torch.Tensor
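
One plausible reading of these parameters is sketched below in plain Python: covariance between non-connected feature pairs is penalised with weight lambda_nc, and connected pairs are penalised with weight lambda_c when their covariance falls below theta. The formula, the helper names, and the toy data are all assumptions for illustration; the library's actual loss may be defined differently.

```python
def covariance(a, b):
    """Population covariance of two equal-length sequences."""
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    return sum((p - mean_a) * (q - mean_b) for p, q in zip(a, b)) / len(a)

def causal_reg_sketch(outputs, adj, lambda_nc=0.001, lambda_c=0.001, theta=0.2):
    """Penalise covariance on non-edges and too-weak covariance on edges."""
    d = len(adj)
    cols = [[row[j] for row in outputs] for j in range(d)]
    loss = 0.0
    for i in range(d):
        for j in range(d):
            if i == j:
                continue
            cov = abs(covariance(cols[i], cols[j]))
            if adj[i][j] == 1 or adj[j][i] == 1:
                loss += lambda_c * max(0.0, theta - cov)  # connected pair
            else:
                loss += lambda_nc * cov                   # non-connected pair
    return loss

# Features 0 and 1 move together; feature 2 is constant.
outputs = [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 1.0]]
with_edge = [[0, 1, 0], [0, 0, 0], [0, 0, 0]]     # edge 0 -> 1
without_edge = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]  # no edges at all
loss_with_edge = causal_reg_sketch(outputs, with_edge)
loss_without_edge = causal_reg_sketch(outputs, without_edge)
```

Under this reading, the same correlated outputs incur no penalty when the graph declares the dependency and a positive penalty when it does not.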

humancompatible.explain.fcx.scripts.causal_modules.ensure_dag(adj_matrix)[source]

Enforce that a given adjacency matrix represents a Directed Acyclic Graph (DAG).

Parameters:

adj_matrix (np.ndarray) – Square binary adjacency matrix of shape (n, n), where entry (i, j) == 1 indicates a directed edge i → j.

Returns:

A modified adjacency matrix of the same shape, guaranteed to be acyclic (a DAG).

Return type:

np.ndarray
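
How cycles are removed is not described here, but the property being guaranteed can be checked independently. The sketch below uses Kahn's algorithm on a nested-list adjacency matrix with the same (i, j) == 1 means i → j convention; is_dag is a hypothetical helper, not part of the module.

```python
from collections import deque

def is_dag(adj):
    """Kahn's algorithm: True if every node can be removed in topological order."""
    n = len(adj)
    indegree = [sum(adj[i][j] for i in range(n)) for j in range(n)]
    queue = deque(j for j in range(n) if indegree[j] == 0)
    visited = 0
    while queue:
        i = queue.popleft()
        visited += 1
        for j in range(n):
            if adj[i][j]:
                indegree[j] -= 1
                if indegree[j] == 0:
                    queue.append(j)
    # Nodes on a cycle never reach indegree 0, so they are never visited.
    return visited == n

chain = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]   # 0 -> 1 -> 2
cycle = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]   # 0 -> 1 -> 2 -> 0
chain_ok = is_dag(chain)
cycle_ok = is_dag(cycle)
```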

class humancompatible.explain.fcx.LOFLoss.LOFLoss(n_neighbors=20, epsilon=1e-08)[source]

Bases: Module

Local Outlier Factor (LOF) as a differentiable loss for anomaly detection.

This loss computes the average LOF score over a batch of input vectors, penalizing points that are outliers in the feature space.

Parameters:
  • n_neighbors (int) – Number of nearest neighbors to consider (excluding self) when computing reachability distances. Default is 20.

  • epsilon (float) – Small constant to ensure numerical stability (no zero distances or division by zero). Default is 1e-8.

forward(input)[source]

Compute the LOF loss (mean LOF score) for the input batch.

Steps:
  1. Compute pairwise distances.

  2. Identify the k nearest neighbors for each point (excluding self).

  3. Compute reachability distances: max(distance_to_neighbor, neighbor_kth_distance).

  4. Compute local reachability density (LRD) for each point.

  5. Compute LOF score as the average ratio of neighbors’ LRD to own LRD.

  6. Return the mean LOF score over all points.

Parameters:

input (torch.Tensor) – Input tensor of shape (batch_size, feature_dim).

Returns:

Scalar tensor containing the mean LOF score across the batch.

Return type:

torch.Tensor
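
The six steps can be traced in plain Python on a handful of points. This is a non-differentiable sketch of the standard LOF computation, not the module's tensor implementation; the function name mean_lof, the placement of epsilon, and the extra per-point return value are assumptions made for illustration.

```python
import math

def mean_lof(points, k, epsilon=1e-8):
    """Return (mean LOF, per-point LOF scores) using k nearest neighbours."""
    n = len(points)
    # Step 1: pairwise Euclidean distances.
    dist = [[math.dist(p, q) for q in points] for p in points]
    # Step 2: k nearest neighbours of each point, excluding the point itself.
    neighbours = [
        sorted((j for j in range(n) if j != i), key=lambda j: dist[i][j])[:k]
        for i in range(n)
    ]
    # Distance from each point to its k-th nearest neighbour.
    k_dist = [dist[i][neighbours[i][-1]] for i in range(n)]
    lrd = []
    for i in range(n):
        # Step 3: reachability distance to each neighbour.
        reach = [max(dist[i][j], k_dist[j]) for j in neighbours[i]]
        # Step 4: local reachability density (inverse mean reachability).
        lrd.append(1.0 / (sum(reach) / k + epsilon))
    # Step 5: LOF = average ratio of neighbours' LRD to own LRD.
    lof = [sum(lrd[j] for j in neighbours[i]) / (k * lrd[i]) for i in range(n)]
    # Step 6: mean LOF over all points.
    return sum(lof) / n, lof

# Four points on a unit square plus one far-away outlier.
points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (10.0, 10.0)]
mean_score, scores = mean_lof(points, k=2)
```

Points inside the cluster score close to 1, while the isolated point scores far higher, which is what raises the mean and hence the loss.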

pairwise_distances(x)[source]

Compute the pairwise Euclidean distances between rows of x.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, feature_dim).

Returns:

A (batch_size × batch_size) distance matrix, where entry (i, j) is the Euclidean distance between x[i] and x[j].

Return type:

torch.Tensor

humancompatible.explain.fcx.scripts.datagen_func.prepare_datasets(dataset_name: str, base_dir: str = '../data/') → None[source]

Prepare and save train/validation/test splits plus normalization metadata for a given dataset. This is only needed when the provided csv/npy files are missing, or for custom setups.

This will:
  1. Load one of: ‘adult’, ‘folktables_adult’, ‘census’, or ‘law’.

  2. Filter and slice the raw DataFrame using the dataset-specific preprocessing logic.

  3. Compute median absolute deviations (MAD) and per-feature min/max.

  4. One‑hot encode, normalize continuous features, and reorder columns.

  5. Split into 10% test, 10% validation, and the rest training.

  6. Save:
       - {base_dir}/{dataset_name}-train-set.npy
       - {base_dir}/{dataset_name}-val-set.npy
       - {base_dir}/{dataset_name}-test-set.npy
       - {base_dir}/{dataset_name}-normalise_weights.json
       - {base_dir}/{dataset_name}-mad.json

Parameters:
  • dataset_name (str) – which dataset to prepare; one of ‘adult’, ‘folktables_adult’, ‘census’, or ‘law’.

  • base_dir (str) – directory under which to write out the .npy and .json files.
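
The 10% / 10% / remainder split can be sketched as follows. The helper split_dataset, the shuffling step, and the fixed seed are assumptions for illustration; the function's actual splitting and ordering may differ.

```python
import random

def split_dataset(rows, test_frac=0.1, val_frac=0.1, seed=0):
    """Shuffle rows, then carve off test and validation slices."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    n_test = int(len(rows) * test_frac)
    n_val = int(len(rows) * val_frac)
    test = rows[:n_test]
    val = rows[n_test:n_test + n_val]
    train = rows[n_test + n_val:]
    return train, val, test

# 100 placeholder rows: 80 train, 10 validation, 10 test.
train, val, test = split_dataset(range(100))
```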