Source code for humancompatible.explain.fcx.scripts.blackboxmodel

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable


class BlackBox(nn.Module):
    """
    Two-layer feedforward classifier used as the oracle black-box model.

    The network maps the encoded feature space to two logits (classes 0
    and 1) via ``Linear(inp_shape, hidden_dim) -> Linear(hidden_dim, 2)``.

    No activation function sits between the two layers, so their
    composition is a single affine map of the input. The model therefore
    has a linear decision boundary; it is not a nonlinear MLP despite
    having a hidden layer.
    """

    def __init__(self, inp_shape: int, hidden_dim: int = 10):
        """
        Initialize the BlackBox classifier.

        Args:
            inp_shape (int): Number of input features.
            hidden_dim (int): Width of the intermediate linear layer.
                Defaults to 10, the value previously hard-coded here, so
                existing callers get an identical architecture.
        """
        super().__init__()
        self.inp_shape = inp_shape
        self.hidden_dim = hidden_dim
        # NOTE: inserting a nonlinearity (e.g. nn.ReLU()) between the layers
        # would change the model's outputs and shift nn.Sequential's child
        # indices. Any previously saved state_dict (keys predict_net.0.* /
        # predict_net.1.*) would then fail to load, so keep this layout fixed.
        self.predict_net = nn.Sequential(
            nn.Linear(self.inp_shape, self.hidden_dim),
            nn.Linear(self.hidden_dim, 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute logits for binary classification.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, inp_shape).

        Returns:
            torch.Tensor: Logits of shape (batch_size, 2), where each entry
            is the unnormalized score for class 0 and class 1 respectively.
        """
        return self.predict_net(x)