from enum import Enum
from typing import Any
import numpy as np
from spn.algorithms.Inference import EPSILON, log_likelihood
from spn.algorithms.LearningWrappers import learn_mspn
from spn.structure.Base import Context, Leaf
from spn.structure.Base import Node as SPFlow_Node
from spn.structure.Base import Product, Sum, get_topological_order
from spn.structure.StatisticalTypes import MetaType
from ..data.DataHandler import DataHandler
from ..data.Features import Binary, Categorical, Contiguous, Feature, Mixed
from ..data.Types import DataLike
class NodeType(Enum):
    """Kind of an SPN node as seen by the :class:`Node` wrapper.

    ``LEAF`` tags histogram leaves over contiguous/mixed features, while
    ``LEAF_CATEGORICAL`` and ``LEAF_BINARY`` tag leaves over categorical and
    binary features respectively.
    """

    # inner nodes
    SUM = 0
    PRODUCT = 1
    # leaves (distinguished by the type of feature they model)
    LEAF = 2
    LEAF_CATEGORICAL = 3
    LEAF_BINARY = 4
class Node:
    """A representation of a node in an SPN.

    Wraps an SPFlow node and exposes a small, type-tagged interface:

    * every node has ``type`` (a :class:`NodeType`), ``name``, ``id`` and
      ``predecessors`` (the SPFlow node's ``children``, or ``[]`` when the
      node has none);
    * sum nodes additionally have ``weights``;
    * leaves have ``scope`` (a single feature index), ``feature`` and
      ``densities``; categorical leaves also have ``options``, and all other
      non-binary leaves have ``discrete`` and ``breaks``.
    """

    def __init__(
        self,
        node: SPFlow_Node,
        feature_list: list[Feature],
        normalize: bool,
        min_density: float,
    ):
        """
        Initializes a custom Node object from an SPFlow_Node.

        Parameters:
        -----------
        node : SPFlow_Node
            The raw SPFlow node (a ``Leaf``, ``Product`` or ``Sum``).
        feature_list : list[Feature]
            Features indexed by SPFlow scope. A leaf's feature is looked up
            here and decides the leaf's ``NodeType``.
        normalize : bool
            Whether the data used to learn the SPN was normalized to [0, 1].
            Used by ``get_breaks_densities`` to pick the bounds and to decide
            whether the breaks still have to be normalized.
        min_density : float
            Density used by ``get_breaks_densities`` to pad the regions between
            the feature bounds and the outermost histogram breaks.

        Raises:
        -------
        NotImplementedError
            If a leaf's scope spans more than one feature.
        ValueError
            If ``node`` is not a ``Leaf``, ``Product`` or ``Sum``.
        """
        self.__normalize = normalize
        self.__min_density = min_density
        if isinstance(node, Leaf):
            self.densities = list(node.densities)
            if isinstance(node.scope, list):
                if len(node.scope) > 1:
                    raise NotImplementedError("Multivariate leaves are not supported.")
                self.scope = node.scope[0]
            else:
                self.scope = node.scope
            self.feature = feature_list[self.scope]
            if isinstance(self.feature, Categorical):
                self.type = NodeType.LEAF_CATEGORICAL
                self.options = self.feature.numeric_vals
            elif isinstance(self.feature, Binary):
                self.type = NodeType.LEAF_BINARY
            else:
                self.type = NodeType.LEAF
                self.discrete = self.feature.discrete
                if self.discrete:
                    # NOTE(review): presumably shifts integer-valued edges so
                    # each integer lies inside its bin rather than on an edge.
                    breaks = [b - 0.5 for b in node.breaks]
                else:
                    breaks = list(node.breaks)
                self.breaks, self.densities = self._merge_equal_bins(
                    breaks, node.densities
                )
        elif isinstance(node, Product):
            self.type = NodeType.PRODUCT
        elif isinstance(node, Sum):
            self.type = NodeType.SUM
            self.weights = node.weights
        else:
            raise ValueError(f"Unsupported SPFlow node type: {type(node).__name__}")
        self.name = node.name
        self.id = node.id
        self.predecessors = node.children if hasattr(node, "children") else []

    @staticmethod
    def _merge_equal_bins(breaks, densities) -> tuple[list, list]:
        """Merge adjacent histogram bins whose densities are numerically equal.

        ``breaks`` holds the bin edges, one more than ``densities``. Whenever
        bin ``i + 1`` has the same density as bin ``i``, the edge between them
        (``breaks[i + 1]``) is dropped, so the result describes the same
        piecewise-constant function with fewer pieces.

        Returns the merged ``(breaks, densities)`` as lists.
        """
        dens = np.asarray(densities)
        # atol=0.0: numpy's default atol=1e-8 would otherwise dominate the
        # tiny rtol and merge small but clearly different densities
        # (e.g. 1e-9 vs 5e-9, whose logarithms differ by ~1.6).
        duplicate = np.isclose(dens[1:], dens[:-1], rtol=1e-10, atol=0.0)
        merged_densities = [densities[0]] + list(dens[1:][~duplicate])
        merged_breaks = (
            [breaks[0]]
            + list(np.asarray(breaks[1:-1])[~duplicate])
            + [breaks[-1]]
        )
        return merged_breaks, merged_densities

    def get_breaks_densities(
        self, span_all=True
    ) -> tuple[np.ndarray[float], np.ndarray[float]]:
        """
        Returns the breakpoints and corresponding density values for a leaf
        over a ``Contiguous`` feature.

        Parameters:
        -----------
        span_all : bool, optional
            If True, the breakpoints are extended to cover the whole input
            range ([0, 1] if the SPN was learned on normalized data, otherwise
            the feature's ``bounds``); any added outer region gets the
            ``min_density`` given at construction. Defaults to True.

        Returns:
        --------
        tuple[np.ndarray[float], np.ndarray[float]]
            The breakpoints (normalized through ``feature.encode`` when the
            SPN was learned on unnormalized data) and the density of each
            segment between consecutive breakpoints.

        Raises:
        -------
        ValueError
            If this node is not a leaf over a ``Contiguous`` feature.
        AssertionError
            If ``span_all`` is True and the feature bounds are not fixed.
        """
        if not hasattr(self, "feature") or not isinstance(self.feature, Contiguous):
            raise ValueError("Only available to leaves over contiguous features")
        density_vals = self.densities
        breaks = self.breaks
        if span_all:
            lb, ub = (0, 1) if self.__normalize else self.feature.bounds
            if lb is None or ub is None:
                raise AssertionError("SPN input variables must have fixed bounds.")
            # pad with min_density where the histogram is narrower than the bounds
            if lb < breaks[0]:
                breaks = [lb] + breaks
                density_vals = [self.__min_density] + density_vals
            if ub > breaks[-1]:
                breaks = breaks + [ub]
                density_vals = density_vals + [self.__min_density]
        # if the breaks are not normalized, normalize them now
        if not self.__normalize:
            breaks = self.feature.encode(breaks, normalize=True, one_hot=False)
        return np.array(breaks), np.array(density_vals)
class SPN:
    """
    A wrapper class for Sum-Product Networks (SPNs).

    Learns an SPN from data with SPFlow's ``learn_mspn``, exposes its structure
    as a list of :class:`Node` objects, and evaluates exact (``compute_ll``)
    and max-approximated (``compute_max_approx``) log-likelihoods. Inputs are
    encoded with the supplied ``DataHandler`` before being passed to SPFlow.
    """

    def __init__(
        self,
        data: DataLike,
        data_handler: DataHandler,
        normalize_data: bool = False,
        learn_mspn_kwargs: dict[str, Any] | None = None,
    ):
        """
        Learns an SPN from ``data`` and builds its :class:`Node` representation.

        Parameters:
        -----------
        data : DataLike
            The data used to learn the SPN; encoded with ``data_handler``.
        data_handler : DataHandler
            Provides the features (``features`` plus ``target_feature``) and
            the encoding applied before learning and inference.
        normalize_data : bool, optional
            If True, the data is encoded with normalization before learning;
            the same setting is used for every later inference call.
            Defaults to False.
        learn_mspn_kwargs : dict[str, Any] | None, optional
            Extra keyword arguments forwarded to SPFlow's ``learn_mspn``
            (e.g. ``min_instances_slice``). ``None`` (the default) means no
            extra arguments.

        Raises:
        -------
        ValueError
            If a feature is not Contiguous, Categorical, Binary or Mixed.
        """
        if learn_mspn_kwargs is None:
            learn_mspn_kwargs = {}
        self.__feature_list = data_handler.features + [data_handler.target_feature]
        types = []
        domains = []
        for feature in self.__feature_list:
            meta_type, domain = self._meta_type_and_domain(feature)
            types.append(meta_type)
            domains.append(domain)
        context = Context(
            meta_types=types,
            domains=domains,
            feature_names=[f.name for f in self.__feature_list],
        )
        self.__normalize_data = normalize_data
        enc_data = data_handler.encode_all(
            data, normalize=normalize_data, one_hot=False
        )
        if len(domains) != data_handler.n_features + 1:
            print("recomputing domains")
            context.add_domains(enc_data)
        self.__data_handler = data_handler
        self.__mspn = learn_mspn(enc_data, context, **learn_mspn_kwargs)
        self.__nodes = self._build_nodes()

    @staticmethod
    def _meta_type_and_domain(feature: Feature) -> tuple[MetaType, np.ndarray]:
        """Return the SPFlow meta type and domain describing ``feature``."""
        if isinstance(feature, Contiguous):
            if feature.discrete:
                return MetaType.DISCRETE, np.arange(
                    feature.bounds[0], feature.bounds[1] + 1
                )
            return MetaType.REAL, np.asarray(feature.bounds)
        if isinstance(feature, Categorical):
            return MetaType.DISCRETE, np.asarray(feature.numeric_vals)
        if isinstance(feature, Binary):
            return MetaType.BINARY, np.asarray([0, 1])
        if isinstance(feature, Mixed):
            return MetaType.REAL, np.asarray(feature.bounds)
        raise ValueError(f"Unsupported feature type of feature {feature}")

    def _build_nodes(self) -> list[Node]:
        """Wrap every SPFlow node, in SPFlow's topological order, as a Node."""
        return [
            Node(node, self.__feature_list, self.__normalize_data, self.min_density)
            for node in get_topological_order(self.__mspn)
        ]

    def compute_ll(self, data: DataLike) -> np.ndarray[float]:
        """
        Computes the exact log-likelihood of the given data using the learned SPN.

        Parameters:
        -----------
        data : DataLike
            A single sample (1-D) or multiple samples (2-D).

        Returns:
        --------
        np.ndarray[float]
            The output of SPFlow's ``log_likelihood`` on the encoded data. For
            a 1-D input, only the first entry of that batched result is
            returned.
        """
        if len(data.shape) == 1:
            return self.compute_ll(data.reshape(1, -1))[0]
        return log_likelihood(
            self.__mspn,
            self.__data_handler.encode_all(
                data, normalize=self.__normalize_data, one_hot=False
            ),
        )

    def compute_max_approx(
        self, data: DataLike, return_all: bool = False
    ) -> float | dict[int, float]:
        """
        Computes an approximate log-likelihood for a single data sample by
        replacing each sum node's log-sum-exp with a max, as the MIO
        approximation does.

        Parameters:
        -----------
        data : DataLike
            One sample: a 1-D array, or a 2-D array with exactly one row.
        return_all : bool, optional
            If True, return a dict mapping every node ID to its approximate
            log-likelihood; otherwise return only the root's value.
            Defaults to False.

        Returns:
        --------
        float | dict[int, float]
            The root node's approximate log-likelihood, or the per-node dict.

        Raises:
        -------
        ValueError
            If ``data`` holds more than one sample.
        """
        n_dims = len(data.shape)
        if n_dims == 1:
            data = data.reshape(1, -1)
        elif n_dims != 2 or data.shape[0] != 1:
            raise ValueError("Can do only one sample, so far...")
        input_data = self.__data_handler.encode_all(
            data, normalize=self.__normalize_data, one_hot=False
        )[0]
        node_vals = {}
        # self.nodes lists every node after its predecessors (children), so
        # their values are always available when a parent is evaluated.
        for node in self.nodes:
            if node.type == NodeType.LEAF:
                value = self._histogram_log_density(node, input_data[node.scope])
            elif node.type in (NodeType.LEAF_BINARY, NodeType.LEAF_CATEGORICAL):
                value = np.log(node.densities[input_data[node.scope].astype(int)])
            elif node.type == NodeType.PRODUCT:
                value = sum(node_vals[child.id] for child in node.predecessors)
            elif node.type == NodeType.SUM:
                value = max(
                    node_vals[child.id] + np.log(w)
                    for child, w in zip(node.predecessors, node.weights)
                )
            node_vals[node.id] = value
        if return_all:
            return node_vals
        return node_vals[self.__mspn.id]

    def _histogram_log_density(self, node: Node, x) -> float:
        """Log-density of a histogram leaf at the encoded value ``x``.

        Values below the first break or at/above the last break get
        ``min_density``; a value equal to an inner break belongs to the bin
        on its right.
        """
        padded_densities = [self.min_density] + node.densities + [self.min_density]
        for density, upper in zip(padded_densities, node.breaks + [np.inf]):
            value = np.log(density)
            if upper > x:
                break
        return value

    @property
    def nodes(self) -> list[Node]:
        """
        Returns the SPN's structure as a list of :class:`Node` objects.

        The list follows SPFlow's ``get_topological_order``; every node
        appears after its predecessors (children), which ``compute_max_approx``
        relies on. The list is built once and cached, so repeated accesses
        return the same objects.

        Returns:
        --------
        list[Node]
            The wrapped nodes, children before parents.
        """
        try:
            return self.__nodes
        except AttributeError:
            self.__nodes = self._build_nodes()
            return self.__nodes

    @property
    def min_density(self) -> float:
        """
        Returns the minimum density value used for SPN calculations.

        This is SPFlow's ``EPSILON`` constant, used in place of zero densities
        to avoid log(0).

        Returns:
        --------
        float
            The minimum density value (epsilon).
        """
        return EPSILON

    @property
    def out_node_id(self) -> int:
        """
        Returns the ID of the root node (output node) of the learned SPN.

        Returns:
        --------
        int
            The ID of the SPN's root node.
        """
        return self.__mspn.id

    @property
    def spn_model(self) -> SPFlow_Node:
        """
        Returns the raw SPFlow SPN model object (its root node), for direct
        use with SPFlow's own functionality.

        Returns:
        --------
        SPFlow_Node
            The root node of the learned SPFlow SPN.
        """
        return self.__mspn