FCX BlackBox Model training (Adult)

humancompatible.explain.fcx.scripts.blackbox_model_train.train_blackbox(dataset_name: str, base_data_dir: str = '../../data/', base_model_dir: str = '../models/', seed: int = 10000000, epochs: int = 100, batch_size: int | None = None, learning_rate: float | None = None) → BlackBox

Train (or fine-tune) a BlackBox classifier on a specified dataset and save its weights.

This function will:
  1. Load the chosen dataset’s training and validation splits.

  2. Instantiate a BlackBox MLP with input dimension matching the encoded features.

  3. Configure an optimizer (Adam or SGD) and cross‑entropy loss.

  4. Optionally balance the census training set by down‑sampling.

  5. Run a standard training loop for the number of epochs given by epochs, reporting training accuracy.

  6. Evaluate on the held-out validation split, reporting validation accuracy.

  7. Save the trained model’s state_dict() to {base_model_dir}/{dataset_name}.pth.
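The steps above can be sketched roughly as follows. This is a minimal illustration on synthetic data, not the library's actual implementation: the stand-in model, its hidden width, the learning rate, and the temporary output directory are all assumptions made for the example, and the optional down-sampling step is omitted.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)

# Synthetic stand-in for an encoded dataset (assumption): 200 training and
# 50 validation rows with 8 features each.
n_features = 8
x_train, x_val = torch.randn(200, n_features), torch.randn(50, n_features)
# The label depends on the first feature so there is something to learn.
y_train, y_val = (x_train[:, 0] > 0).long(), (x_val[:, 0] > 0).long()

# Hypothetical stand-in for the BlackBox MLP; the hidden width is an assumption.
model = nn.Sequential(nn.Linear(n_features, 16), nn.ReLU(), nn.Linear(16, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

epochs, batch_size = 5, 32
for epoch in range(epochs):
    model.train()
    correct = 0
    perm = torch.randperm(x_train.shape[0])  # reshuffle every epoch
    for start in range(0, x_train.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        logits = model(x_train[idx])
        loss = loss_fn(logits, y_train[idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        correct += (logits.argmax(dim=1) == y_train[idx]).sum().item()
    train_acc = correct / x_train.shape[0]
    print(f"epoch {epoch + 1}: train accuracy {train_acc:.3f}")

# Evaluate on the held-out validation split.
model.eval()
with torch.no_grad():
    val_acc = (model(x_val).argmax(dim=1) == y_val).float().mean().item()
print(f"validation accuracy {val_acc:.3f}")

# Save only the weights, following the {base_model_dir}/{dataset_name}.pth pattern.
base_model_dir = tempfile.mkdtemp()  # temporary directory, for the example only
dataset_name = "adult"
checkpoint = os.path.join(base_model_dir, f"{dataset_name}.pth")
torch.save(model.state_dict(), checkpoint)
```

Saving the state_dict() rather than the whole module means the checkpoint can only be restored into a model built with the same architecture, so the input dimension has to be known again at load time.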

Parameters:
  • dataset_name (str) – Which dataset to train on. One of: ‘adult’, ‘census’, ‘law’, ‘folktables_adult’.

  • base_data_dir (str) – Path to the root data directory containing {dataset_name}-*.npy or CSV files.

  • base_model_dir (str) – Directory in which to save the final model checkpoint.

  • seed (int) – Random seed for both NumPy and PyTorch for reproducibility.

  • epochs (int) – Number of training epochs to run.

  • batch_size (int, optional) – Mini-batch size for training. If None, a dataset‑specific default is used.

  • learning_rate (float, optional) – Learning rate for the optimizer. If None, a dataset‑specific default is used.
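To illustrate what the seed parameter is described as controlling, the sketch below seeds NumPy and PyTorch together and checks that random draws repeat. The helper name set_seed is hypothetical and not part of the documented API; how train_blackbox applies the seed internally may differ.

```python
import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Seed the NumPy and PyTorch global random number generators."""
    np.random.seed(seed)
    torch.manual_seed(seed)


# Seeding twice with the same value reproduces the same draws.
set_seed(10000000)
first = (np.random.rand(3), torch.rand(3))
set_seed(10000000)
second = (np.random.rand(3), torch.rand(3))

print(np.array_equal(first[0], second[0]), torch.equal(first[1], second[1]))
```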

Returns:

The trained BlackBox model instance. Its weights are also saved to disk.

Return type:

BlackBox

Raises:

ValueError – If dataset_name is not one of the supported options.
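A minimal sketch of this kind of validation, combined with the checkpoint path pattern described above, might look as follows. The helper checkpoint_path and the SUPPORTED_DATASETS tuple are hypothetical names introduced for illustration; the actual check inside train_blackbox may be written differently.

```python
import os

# The dataset names listed under dataset_name above.
SUPPORTED_DATASETS = ("adult", "census", "law", "folktables_adult")


def checkpoint_path(dataset_name: str, base_model_dir: str = "../models/") -> str:
    """Return {base_model_dir}/{dataset_name}.pth, rejecting unknown datasets."""
    if dataset_name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; "
            f"expected one of {SUPPORTED_DATASETS}"
        )
    return os.path.join(base_model_dir, f"{dataset_name}.pth")


print(checkpoint_path("adult"))

try:
    checkpoint_path("mnist")  # not a supported dataset
except ValueError as err:
    print(err)
```

Validating the name before any data is loaded keeps the failure early and cheap, which is presumably the intent of raising ValueError here.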

class humancompatible.explain.fcx.scripts.blackboxmodel.BlackBox(inp_shape: int)

Bases: Module

Feedforward MLP classifier used as the oracle black-box model. It consists of two linear layers that map the encoded feature space to logits for the two classes.

Initialize the BlackBox classifier.

Parameters:

inp_shape (int) – Number of input features.

forward(x: Tensor) → Tensor

Compute logits for binary classification.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, inp_shape).

Returns:

Logits of shape (batch_size, 2), where each entry corresponds to the unnormalized score for classes 0 and 1.

Return type:

torch.Tensor
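The shape contract of forward can be illustrated with a stand-in module. BlackBoxSketch below is a hypothetical re-creation based only on the description above (two linear layers, two output logits); the hidden width, the ReLU activation, and the input size of 29 are assumptions and may not match the real BlackBox class.

```python
import torch
from torch import nn


class BlackBoxSketch(nn.Module):
    """Hypothetical two-layer MLP mirroring the documented interface."""

    def __init__(self, inp_shape: int, hidden_dim: int = 10):
        super().__init__()
        # hidden_dim and the ReLU between the layers are assumptions.
        self.net = nn.Sequential(
            nn.Linear(inp_shape, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns unnormalized scores of shape (batch_size, 2).
        return self.net(x)


model = BlackBoxSketch(inp_shape=29)  # 29 is an arbitrary example size
x = torch.randn(4, 29)                # batch of 4 encoded rows
logits = model(x)

# Logits are unnormalized; softmax turns them into class probabilities,
# and argmax picks the predicted class (0 or 1) per row.
probs = torch.softmax(logits, dim=1)
preds = logits.argmax(dim=1)
print(logits.shape, preds)
```

Because the model returns raw logits, it pairs directly with a cross-entropy loss during training, which applies the softmax internally.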