# Source code for humancompatible.explain.fcx.LOFLoss


import torch
import torch.utils.data
from torch import nn
from torch.nn import functional as F

class LOFLoss(nn.Module):
    """
    Local Outlier Factor (LOF) as a differentiable loss for anomaly detection.

    This loss computes the average LOF score over a batch of input vectors,
    penalizing points that are outliers in the feature space. A point whose
    local reachability density matches that of its neighbors scores about 1;
    points in sparser regions than their neighbors score higher.

    Args:
        n_neighbors (int): Number of nearest neighbors to consider (excluding
            self) when computing reachability distances. If a batch has fewer
            than ``n_neighbors + 1`` points, ``batch_size - 1`` neighbors are
            used instead. Default is 20.
        epsilon (float): Small constant to ensure numerical stability (no zero
            distances or division by zero). Default is 1e-8.
    """

    def __init__(self, n_neighbors=20, epsilon=1e-8):
        super().__init__()
        self.n_neighbors = n_neighbors
        self.epsilon = epsilon

    def pairwise_distances(self, x):
        """
        Compute the pairwise Euclidean distances between rows of x.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, feature_dim).

        Returns:
            torch.Tensor: A (batch_size x batch_size) distance matrix, where
            entry (i, j) is the Euclidean distance between x[i] and x[j].
            Squared distances are clamped below at ``epsilon`` before the
            square root, so the diagonal (and any exact duplicate pair)
            equals ``sqrt(epsilon)`` rather than 0.
        """
        x_norm = (x ** 2).sum(dim=1, keepdim=True)
        y_norm = x_norm.view(1, -1)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs at once.
        dist = x_norm + y_norm - 2.0 * torch.mm(x, x.t())
        # Round-off can make the expansion slightly negative; the positive floor
        # also keeps sqrt's gradient finite at zero distance.
        dist = torch.sqrt(torch.clamp(dist, min=self.epsilon))
        return dist

    def forward(self, input):
        """
        Compute the LOF loss (mean LOF score) for the input batch.

        Steps:
            1. Compute pairwise distances and exclude each point from its own
               neighborhood.
            2. Identify the k nearest neighbors of each point, where
               k = min(n_neighbors, batch_size - 1).
            3. Compute reachability distances:
               max(distance_to_neighbor, neighbor_kth_distance).
            4. Compute the local reachability density (LRD) of each point:
               k / sum of its reachability distances.
            5. Compute the LOF score as the mean of the neighbors' LRD divided
               by the point's own LRD.
            6. Return the mean LOF score over all points.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, feature_dim).
                Non-tensor inputs are converted with ``dtype=torch.float32``.

        Returns:
            torch.Tensor: Scalar tensor containing the mean LOF score across
            the batch.

        Raises:
            ValueError: If no point can have a neighbor, i.e. the batch has
                fewer than 2 samples or ``n_neighbors < 1``.
        """
        if not isinstance(input, torch.Tensor):
            input = torch.tensor(input, dtype=torch.float32)
        n_samples = input.size(0)

        # Never ask topk for more neighbors than other points exist.
        k = min(self.n_neighbors, n_samples - 1)
        if k < 1:
            raise ValueError(
                "LOFLoss requires n_neighbors >= 1 and a batch of at least 2 samples."
            )

        distances = self.pairwise_distances(input)
        # Mask the diagonal so a point is never its own neighbor. Dropping the
        # first topk column is unreliable when duplicates or round-off make
        # another distance tie with or undercut the clamped self-distance.
        self_mask = torch.eye(n_samples, dtype=torch.bool, device=input.device)
        distances = distances.masked_fill(self_mask, float("inf"))

        knn_distances, knn_indices = torch.topk(distances, k, dim=1, largest=False)
        # topk returns the smallest values in ascending order (sorted=True by
        # default), so the last column is each point's k-distance.
        k_distance = knn_distances[:, -1]

        # reach-dist_k(p, o) = max(k-distance(o), d(p, o)) for every neighbor o of p.
        reachability_distances = torch.maximum(knn_distances, k_distance[knn_indices])

        # Local reachability density: inverse of the mean reachability distance.
        lrd = k / (reachability_distances.sum(dim=1) + self.epsilon)

        # LOF(p) = mean over neighbors o of lrd(o) / lrd(p).
        lof_scores = lrd[knn_indices].mean(dim=1) / (lrd + self.epsilon)

        return lof_scores.mean()