FCX BlackBox Model training (Adult)
- humancompatible.explain.fcx.scripts.blackbox_model_train.train_blackbox(dataset_name: str, base_data_dir: str = '../../data/', base_model_dir: str = '../models/', seed: int = 10000000, epochs: int = 100, batch_size: int | None = None, learning_rate: float | None = None) → BlackBox
Train (or fine-tune) a BlackBox classifier on a specified dataset and save its weights.
- This function will:
Load the chosen dataset’s training and validation splits.
Instantiate a BlackBox MLP with input dimension matching the encoded features.
Configure an optimizer (Adam or SGD) and cross‑entropy loss.
Optionally balance the census training set by down‑sampling.
Run a standard training loop for the number of epochs given by epochs, reporting training accuracy.
Evaluate on the held-out validation split, reporting validation accuracy.
Save the trained model’s state_dict() to {base_model_dir}/{dataset_name}.pth.
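The steps above can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not the library's train_blackbox: it trains on synthetic tensors instead of the real dataset loaders, uses an nn.Sequential stand-in for BlackBox (the hidden width of 10 and the ReLU are assumed), always uses Adam, and omits the census down-sampling step. The names train_and_save and demo are hypothetical.

```python
import os
import tempfile

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_and_save(x_train, y_train, x_val, y_val, *, dataset_name, base_model_dir,
                   seed=10000000, epochs=100, batch_size=64, learning_rate=1e-3):
    """Hypothetical sketch of the listed steps; not the library's train_blackbox."""
    # Seed NumPy and PyTorch, as the seed parameter describes.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Stand-in for BlackBox: two linear layers producing binary logits.
    # The hidden width (10) and the ReLU are assumptions.
    model = nn.Sequential(nn.Linear(x_train.shape[1], 10), nn.ReLU(), nn.Linear(10, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)

    train_acc = 0.0
    for _ in range(epochs):
        model.train()
        correct = 0
        for xb, yb in loader:
            optimizer.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            optimizer.step()
            # Count correct predictions using the logits computed before this update.
            correct += (logits.argmax(dim=1) == yb).sum().item()
        train_acc = correct / len(y_train)

    # Evaluate on the held-out validation split.
    model.eval()
    with torch.no_grad():
        val_acc = (model(x_val).argmax(dim=1) == y_val).float().mean().item()

    # Save only the weights, at {base_model_dir}/{dataset_name}.pth.
    os.makedirs(base_model_dir, exist_ok=True)
    path = os.path.join(base_model_dir, f"{dataset_name}.pth")
    torch.save(model.state_dict(), path)
    return model, train_acc, val_acc, path


def demo():
    """Train on synthetic data in a temporary directory and report what was saved."""
    g = torch.Generator().manual_seed(0)
    x = torch.randn(200, 5, generator=g)
    y = (x[:, 0] > 0).long()  # integer class labels, as CrossEntropyLoss expects
    with tempfile.TemporaryDirectory() as tmp:
        model, train_acc, val_acc, path = train_and_save(
            x[:160], y[:160], x[160:], y[160:],
            dataset_name="adult", base_model_dir=tmp,
            epochs=20, batch_size=32, learning_rate=1e-2,
        )
        saved = torch.load(path)
        return {
            "file_name": os.path.basename(path),
            "keys_match": set(saved) == set(model.state_dict()),
            "train_acc": train_acc,
            "val_acc": val_acc,
        }
```

Seeding before the model is built makes both the weight initialization and the DataLoader shuffling repeatable across runs.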
- Parameters:
dataset_name (str) – Which dataset to train on. One of ‘adult’, ‘census’, ‘law’, or ‘folktables_adult’.
base_data_dir (str) – Path to the root data directory containing {dataset_name}-*.npy or CSV files.
base_model_dir (str) – Directory in which to save the final model checkpoint.
seed (int) – Random seed applied to both NumPy and PyTorch, for reproducibility.
epochs (int) – Number of training epochs to run.
batch_size (int, optional) – Mini-batch size for training. If None, a dataset-specific default is used.
learning_rate (float, optional) – Learning rate for the optimizer. If None, a dataset-specific default is used.
- Returns:
The trained BlackBox model instance. Its weights are also saved to disk.
- Return type:
BlackBox
- Raises:
ValueError – If dataset_name is not one of the supported options.
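The dataset check behind this ValueError and the documented checkpoint location can be illustrated with a small helper. This is a hypothetical sketch: checkpoint_path and SUPPORTED_DATASETS are illustrative names, not part of the library, and where exactly the real function performs its check is an assumption.

```python
import os

# The four dataset names listed under Parameters above.
SUPPORTED_DATASETS = ("adult", "census", "law", "folktables_adult")


def checkpoint_path(dataset_name: str, base_model_dir: str = "../models/") -> str:
    """Hypothetical helper: validate dataset_name and build the checkpoint path."""
    if dataset_name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of {SUPPORTED_DATASETS}"
        )
    # Mirrors the documented save location {base_model_dir}/{dataset_name}.pth.
    return os.path.join(base_model_dir, f"{dataset_name}.pth")
```

With the default base_model_dir, a call such as train_blackbox('adult') is documented to write its weights to ../models/adult.pth, which is the path this helper builds for 'adult'.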
- class humancompatible.explain.fcx.scripts.blackboxmodel.BlackBox(inp_shape: int)
Bases: Module
Feedforward MLP classifier used as the oracle black-box model. This model consists of two linear layers that map the encoded feature space to binary logits.
Initialize the BlackBox classifier.
- Parameters:
inp_shape (int) – Number of input features.
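A module of this shape might look roughly like the sketch below. BlackBoxSketch is a hypothetical name, not the library class; the hidden width (10) and the ReLU between the two layers are assumptions, since the documentation states only that there are two linear layers producing binary logits.

```python
import torch
from torch import nn


class BlackBoxSketch(nn.Module):
    """Hypothetical two-layer MLP in the spirit of BlackBox (not the library class)."""

    def __init__(self, inp_shape: int, hidden_dim: int = 10):
        super().__init__()
        # hidden_dim and the ReLU are assumptions; the docs only say "two linear layers".
        self.hidden = nn.Linear(inp_shape, hidden_dim)
        self.act = nn.ReLU()
        self.out = nn.Linear(hidden_dim, 2)  # two logits for the binary label

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, inp_shape); the result has shape (batch, 2).
        return self.out(self.act(self.hidden(x)))
```

Because train_blackbox saves only state_dict(), restoring a checkpoint means first building a model with the same inp_shape and then calling load_state_dict(torch.load(path)) on it.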