The Spectral Dictionary: PyTorch Implementations & Factory Assembly
Navigation:
Theory introduction: See the Intro
Related mathematical theory: See the Mathematical Theory
Core Interface: _base_effects.py
Factory Builder: _factory.py
This document details the complete software engineering pipeline behind the TAM mathematical dictionary. For each effect, it exposes:
The Factory Mapping: How the text formula is parsed and instantiated.
The Feature Map: How the mathematical projection \(\Phi(X)\) is computed.
The Penalty Matrix: How the structural constraint \(P\) is built.
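The Base Contract and Factory Assembly
Every mathematical basis inherits from the abstract BaseEffect class, which fixes a three-method contract: the number of coefficients, the feature map \(\Phi(X)\), and the square penalty matrix \(P\). This keeps the core solver agnostic to the underlying topology; each concrete effect is responsible for supplying a symmetric positive semi-definite penalty, so that the induced function space satisfies Aronszajn's conditions for a valid Reproducing Kernel Hilbert Space (RKHS) [Aronszajn, 1950].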
@abstractmethod
def get_n_coeffs(self) -> int:
    """Return the number of basis coefficients (feature-map columns) of this effect.

    Concrete effects must override this; the base implementation always raises.
    """
    raise NotImplementedError()
@abstractmethod
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    """Map raw input values onto this effect's basis.

    Args:
        x_col (torch.Tensor): Input data tensor for the effect's feature(s).

    Returns:
        torch.Tensor: The partial feature map (design) matrix of this effect; the
        concrete effects in this module place their basis columns on the last axis.
    """
    raise NotImplementedError()
@abstractmethod
def build_penalty_matrix(self) -> torch.Tensor:
    """Return this effect's square penalty matrix (one row/column per coefficient)."""
    raise NotImplementedError()
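The factory merges all distinct penalty matrices into a global block-diagonal system. When the total number of coefficients exceeds 5,000, the blocks are assembled into a sparse COO (coordinate) tensor that stores only their non-zero entries, instead of a dense matrix; this keeps GPU memory usage manageable for large models. Smaller systems are assembled as a dense matrix.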
def build_penalty_from_effects(
    effects_list: List[BaseEffect], sparse_threshold: int = 5000
) -> torch.Tensor:
    """
    Constructs the global block-diagonal Penalty Matrix.

    Each effect contributes one square penalty block ``P_i`` (dense or sparse COO).
    The blocks are placed on the diagonal in ``effects_list`` order, so the block
    of effect ``i`` starts at the summed sizes of the preceding blocks.

    Args:
        effects_list: Effects exposing ``build_penalty_matrix()``.
        sparse_threshold: If the total number of coefficients exceeds this value,
            the result is assembled as a sparse COO tensor holding only non-zero
            entries (to avoid OOM); otherwise a dense matrix is returned.
            Defaults to 5000.

    Returns:
        A ``(total, total)`` tensor in ``torch.get_default_dtype()`` on the device
        of the first penalty block. The result is dense or sparse COO depending on
        ``sparse_threshold``. An empty ``effects_list`` yields a dense ``(0, 0)``
        tensor.
    """
    matrices = [e.build_penalty_matrix() for e in effects_list]
    dtype = torch.get_default_dtype()
    if not matrices:
        return torch.zeros((0, 0), dtype=dtype)

    total_size = sum(m.shape[0] for m in matrices)
    run_device = matrices[0].device

    if total_size > sparse_threshold:
        indices_list = []
        values_list = []
        offset = 0
        for m in matrices:
            k = m.shape[0]
            if m.is_sparse:
                m = m.coalesce()
                rows, cols = m.indices()
                vals = m.values()
            else:
                rows, cols = m.nonzero(as_tuple=True)
                vals = m[rows, cols]
            if vals.numel() > 0:
                shifted = torch.stack([rows + offset, cols + offset], dim=0)
                indices_list.append(shifted.to(run_device))
                # Cast every block to the common dtype so the sparse result has the
                # same dtype as the dense branch (some blocks, e.g. tree penalties,
                # may be float64 while the default dtype is float32).
                values_list.append(vals.to(device=run_device, dtype=dtype))
            offset += k
        if indices_list:
            return torch.sparse_coo_tensor(
                torch.cat(indices_list, dim=1),
                torch.cat(values_list, dim=0),
                size=(total_size, total_size),
                dtype=dtype,
                device=run_device,
            )
        return torch.sparse_coo_tensor(
            size=(total_size, total_size), dtype=dtype, device=run_device
        )

    # Dense assembly: slice assignment casts each block to P's dtype/device.
    P = torch.zeros((total_size, total_size), dtype=dtype, device=run_device)
    offset = 0
    for m in matrices:
        k = m.shape[0]
        block = m.to_dense() if m.is_sparse else m
        P[offset:offset + k, offset:offset + k] = block
        offset += k
    return P
Linear
1. Factory Parsing
if ttype == 'l':
    # 'scaled' is the constant multiplier applied to the input column (default: pi).
    scaled = float(params_resolved.get('scaled', np.pi))
    # Extrapolation mode forwarded to the effect; linear terms default to 'continue'.
    extrap_val = params_resolved.get('extrapolate', 'continue')
    effects_list.append(LinearEffect(feature_name, scaled, lambda_p, extrap_val))
2. Feature Map & 3. Penalty Matrix The linear effect multiplies the input by a constant factor (scaled, default \(\pi\)) and returns it as a single feature column. Its penalty is a \(1 \times 1\) Ridge (\(L_2\)) term \(\lambda_p\) that stabilizes the coefficient [Hoerl and Kennard, 1970].
class LinearEffect(BaseEffect):
    r"""Linear effect: a single feature ``scaled * x`` with a scalar Ridge penalty."""

    def __init__(self, feature_name: str, scaled: float, lambda_p: float, extrapolate: str):
        super().__init__(feature_name, "linear", lambda_p, extrapolate)
        # Constant multiplier applied to the input before it enters the design matrix.
        self.scaled = scaled

    def get_n_coeffs(self) -> int:
        # A single slope coefficient.
        return 1

    def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
        r"""Multiply the input by ``self.scaled`` and append a trailing feature axis."""
        return (self.scaled * x_col)[..., None]

    def build_penalty_matrix(self) -> torch.Tensor:
        r"""Return the 1x1 Ridge penalty ``[[lambda_p]]``."""
        penalty = [[self.lambda_p]]
        return torch.tensor(penalty, dtype=torch.get_default_dtype(), device=TORCH_DEVICE)
Fourier
1. Factory Parsing
elif ttype == 'f':
    # Number of harmonics; the basis has 2*m columns (a cosine and a sine per harmonic).
    m = int(params_resolved.get('m', 10))
    # Sobolev order of the diagonal penalty lambda_p * (1 + k**(2*s)).
    s = int(params_resolved.get('s', 1))
    # 'cyclic' may arrive as text, so the usual truthy spellings are accepted.
    cyclic_raw = params_resolved.get('cyclic', 'False')
    is_cyclic = str(cyclic_raw).strip().lower() in ['true', '1', 't', 'y', 'yes']
    extrap_val = params_resolved.get('extrapolate', 'continue')
    effects_list.append(FourierEffect(feature_name, m, s, lambda_p, is_cyclic, extrap_val))
2. Feature Map
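Evaluates the trigonometric basis for all harmonics \(k = 1, \dots, m\) in one batched pass. With cyclic=True the angles are \(k\pi x\), so the basis is periodic on \([-1, 1]\). Otherwise the angles are halved to \(k\pi x / 2\), so the fit is not forced to take equal values at the two endpoints.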
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""Evaluate the real Fourier basis for harmonics k = 1..m.

    The angle of harmonic ``k`` is ``k * pi * x`` when ``self.cyclic`` is set and
    ``k * pi * x / 2`` otherwise. The result has shape ``(*x_col.shape, 2 * m)``,
    holding the m cosine columns followed by the m sine columns.
    """
    harmonics = torch.arange(1, self.m + 1, device=x_col.device, dtype=torch.get_default_dtype())
    # Leading singleton axes so the harmonics broadcast against a new trailing axis of x.
    harmonics = harmonics.reshape((1,) * x_col.dim() + (-1,))
    angle = (x_col * np.pi).unsqueeze(-1) * harmonics
    if not self.cyclic:
        angle = angle / 2
    return torch.cat([torch.cos(angle), torch.sin(angle)], dim=-1)
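3. Penalty Matrix Builds the purely diagonal Sobolev-norm penalty \(P_{kk} = \lambda_p\,(1 + k^{2s})\), applied identically to the cosine and sine coefficients of harmonic \(k\). Because the weight grows with \(k\), high frequencies are damped and the penalty acts as an analytical (soft) low-pass filter [Doumèche et al., 2025].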
def build_penalty_matrix(self) -> torch.Tensor:
    r"""Diagonal Sobolev penalty ``lambda_p * (1 + k**(2*s))`` for the real Fourier basis.

    Each weight appears twice, matching the ``[cos_1..cos_m, sin_1..sin_m]`` column layout.
    """
    k = torch.arange(1, self.m + 1, device=TORCH_DEVICE, dtype=torch.get_default_dtype())
    weights = self.lambda_p * (1 + k ** (2 * self.s))
    return torch.diag(weights.repeat(2))
P-Splines
1. Factory Parsing
elif ttype == 's':
    # Number of knots (SplineEffect's n_knots, judging by the keyword call in UniversalPhysicsEffect).
    k = int(params_resolved.get('k', 10))
    # Polynomial degree of the B-splines (3 = cubic).
    deg = int(params_resolved.get('deg', 3))
    # Order of the finite-difference penalty (0 = plain ridge, 2 = second differences).
    p = int(params_resolved.get('p', 2))
    # Splines default to the 'linear' extrapolation mode (handled outside this snippet).
    extrap_val = params_resolved.get('extrapolate', 'linear')
    effects_list.append(SplineEffect(feature_name, k, deg, p, lambda_p, extrap_val))
2. Feature Map Executes the recursive Cox-de Boor algorithm dynamically on the GPU.
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""Build the B-spline design matrix with pure PyTorch ops.

    Knots come from ``self._get_knots``, and the basis is evaluated by the Cox-de Boor
    recursion in ``self._cox_de_boor``. The output is reshaped to
    ``(*x_col.shape, n_basis)``. An input whose last axis has size 1 is flagged as
    a dummy (memory-probe) pass when the knots are requested.
    """
    original_shape = x_col.shape
    probe_pass = x_col.dim() > 0 and x_col.shape[-1] == 1
    knot_vector = self._get_knots(x_col, is_dummy=probe_pass)
    basis = self._cox_de_boor(x_col, knot_vector, self.spline_degree)
    return basis.reshape(*original_shape, -1)
3. Penalty Matrix
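Constructs the structural penalty from the finite-difference operator \(D\) = torch.diff(I, n=p) of order \(p\) (penalty_order), giving \(P = \lambda_p D^\top D\). For \(p = 0\), \(D = I\) and the penalty reduces to a plain ridge term.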
def build_penalty_matrix(self) -> torch.Tensor:
    r"""Return the P-spline penalty ``lambda_p * D^T D``.

    ``D`` is the identity when ``penalty_order`` is 0. Otherwise it is the
    ``penalty_order``-th finite-difference operator, obtained by differencing the
    rows of the identity.
    """
    identity = torch.eye(self.get_n_coeffs(), device=TORCH_DEVICE, dtype=torch.get_default_dtype())
    if self.penalty_order == 0:
        diff_op = identity
    else:
        diff_op = torch.diff(identity, n=self.penalty_order, dim=0)
    return self.lambda_p * (diff_op.T @ diff_op)
Chebyshev
1. Factory Parsing
elif ttype == 'p':
    # Number of Chebyshev columns T_1..T_deg (presumably stored as ChebyshevEffect.degree).
    deg = int(params_resolved.get('deg', 5))
    # Order of the diagonal penalty lambda_p * (1 + k**(2*s)); s=0 gives a uniform ridge 2*lambda_p.
    s = int(params_resolved.get('s', 0))
    # Chebyshev terms default to the 'saturation' extrapolation mode (handled outside this snippet).
    extrap_val = params_resolved.get('extrapolate', 'saturation')
    effects_list.append(ChebyshevEffect(feature_name, deg, s, lambda_p, extrap_val))
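2. Feature Map Evaluates the stable three-term Chebyshev recurrence \(T_{n+1}(x) = 2x\,T_n(x) - T_{n-1}(x)\), writing each column into a preallocated tensor. The columns are \(T_1, \dots, T_{\mathrm{deg}}\); the constant \(T_0\) is omitted, and inputs must lie in \([-1, 1]\). Because \(|T_k(x)| \le 1\) on that interval, the global polynomial basis stays bounded, which helps mitigate Runge's phenomenon [Rivlin, 1990].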
def build_feature_map(self, x: torch.Tensor) -> torch.Tensor:
    r"""
    Builds the Chebyshev design matrix via the stable three-term recurrence
    T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x).

    Column ``i`` holds T_{i+1}(x): the columns are T_1 .. T_degree, and the constant
    T_0 is not part of the design.

    Args:
        x: Input data. Must be normalized to [-1, 1].

    Returns:
        Design matrix of shape ``(*x.shape, degree)`` in the default dtype, on
        ``x.device``. For ``degree == 0`` the trailing axis is empty.
    """
    # Pre-allocate the full tensor to avoid `cat` memory spikes.
    out_shape = list(x.shape) + [self.degree]
    phi_out = torch.empty(out_shape, dtype=torch.get_default_dtype(), device=x.device)
    if self.degree == 0:
        # No columns to fill; writing column 0 below would raise an IndexError.
        return phi_out
    # Seed the recurrence with T_0 = 1 and T_1 = x.
    t_n_minus_1 = torch.ones_like(x)
    t_n = x
    phi_out[..., 0] = t_n
    for i in range(1, self.degree):
        t_next = 2 * x * t_n - t_n_minus_1
        phi_out[..., i] = t_next
        t_n_minus_1, t_n = t_n, t_next
    return phi_out
3. Penalty Matrix Applies a diagonal smoothness penalty mirroring the Fourier Sobolev norm.
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Builds a diagonal spectral penalty matrix.

    P_{kk} = lambda_p * (1 + k^{2s}) for k = 1..degree, matching the T_1..T_degree
    columns of the feature map.
    """
    # Use a floating-point arange, as the Fourier penalty does. With the default
    # int64 arange, k ** (2 * s) silently wraps around for large degree/s (e.g.
    # 20 ** 20 > 2**63), and negative s would raise for integer powers.
    degrees = torch.arange(
        1, self.degree + 1, device=TORCH_DEVICE, dtype=torch.get_default_dtype()
    )
    diag_values = self.lambda_p * (1 + degrees ** (2 * self.s))
    return torch.diag(diag_values)
Categorical
1. Factory Parsing
elif ttype == 'c':
    # Number of categories: explicit 'n_cat', otherwise looked up in data_info by feature name.
    n_cat = params_resolved.get('n_cat')
    if n_cat is None:
        if data_info is not None and feature_name in data_info:
            n_cat = data_info[feature_name]
        else:
            raise ValueError(f"Categorical term 'c({feature_name})' requires 'n_cat' or data context to infer it.")
    else:
        n_cat = int(n_cat)
    # Topology: 'nominal', 'ordinal' or 'fourier' (the values CategoricalEffect branches on).
    topo = params_resolved.get('topo', 'nominal')
    # Finite-difference order of the ordinal penalty.
    p_order = int(params_resolved.get('p_order', 1))
    extrap_val = params_resolved.get('extrapolate', 'continue')
    effects_list.append(CategoricalEffect(feature_name, n_cat, topo, lambda_p, p_order, extrap_val))
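2. Feature Map For nominal and ordinal topologies, each input in \([-1, 1]\) is mapped to the category index \(\mathrm{round}\big((x + 1)/2 \cdot (n_{\mathrm{cat}} - 1)\big)\), clamped to the valid range, and one-hot encoded. The fourier topology uses a sine/cosine basis instead.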
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""
    Projects inputs based on topology:
    - Nominal/Ordinal: One-Hot basis with ``n_categories`` columns.
    - Fourier: Sin/cos basis with ``m`` harmonics (``2 * m`` columns).

    Args:
        x_col: Category codes encoded in [-1, 1]. For the one-hot topologies the
            index ``round((x + 1) / 2 * (n_categories - 1))`` is used. It is clamped
            to the valid range, so out-of-range codes map to the first/last category.

    Returns:
        Tensor of shape ``(*x_col.shape, n_columns)`` in the default dtype, placed on
        ``x_col.device`` for every topology.

    Raises:
        ValueError: If ``self.topology`` is not 'nominal', 'ordinal' or 'fourier'.
    """
    if self.topology in ('nominal', 'ordinal'):
        x_indices = torch.round((x_col + 1.0) / 2.0 * (self.n_categories - 1)).long()
        x_safe = torch.clamp(x_indices, 0, self.n_categories - 1)
        phi = torch.nn.functional.one_hot(x_safe, num_classes=self.n_categories)
        # Stay on the input's device, like the Fourier branch, instead of forcing
        # the global TORCH_DEVICE, so chunks computed elsewhere are not moved implicitly.
        return phi.to(device=x_col.device, dtype=torch.get_default_dtype())
    if self.topology == 'fourier':
        x_scaled = x_col * np.pi
        x_expanded = x_scaled.unsqueeze(-1)
        freqs = torch.arange(1, self.m + 1, device=x_col.device, dtype=torch.get_default_dtype())
        dims_to_add = x_expanded.dim() - 1
        freqs_expanded = freqs.view(*([1] * dims_to_add), -1)
        # Half angular frequency: the basis is not forced to be periodic on [-1, 1].
        theta = x_expanded * freqs_expanded / 2
        return torch.cat([torch.cos(theta), torch.sin(theta)], dim=-1)
    # Previously this fell through and returned None; fail loudly instead.
    raise ValueError(f"Unknown categorical topology: {self.topology!r}")
3. Penalty Matrix
Routes structurally: isotropic Ridge for nominal, and finite difference \(D^\top D\) for ordinal features to enforce sequential class transitions.
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Builds the penalty matrix for the categorical topology.

    - nominal: ``lambda_p * I`` (isotropic ridge over the one-hot coefficients).
    - ordinal: ``lambda_p * D^T D`` with ``D`` the ``penalty_order``-th
      finite-difference operator, penalizing jumps between neighbouring categories.
    - fourier: diagonal Sobolev penalty ``lambda_p * (1 + k^{2s})``, repeated for the
      cosine and the sine coefficient of each harmonic ``k``.

    Raises:
        ValueError: If ``self.topology`` is not one of the above.
    """
    dtype = torch.get_default_dtype()
    if self.topology == 'nominal':
        return torch.eye(self.n_categories, device=TORCH_DEVICE, dtype=dtype) * self.lambda_p
    if self.topology == 'ordinal':
        # Identity and its row differences, built directly on TORCH_DEVICE.
        I = torch.eye(self.n_categories, device=TORCH_DEVICE, dtype=dtype)
        D = torch.diff(I, n=self.penalty_order, dim=0)
        return self.lambda_p * (D.T @ D)
    if self.topology == 'fourier':
        freqs = torch.arange(1, self.m + 1, device=TORCH_DEVICE, dtype=dtype)
        penalty_half = self.lambda_p * (1 + freqs ** (2 * self.s))
        return torch.diag(torch.cat([penalty_half, penalty_half]))
    # An unknown topology has no meaningful penalty. Raise instead of returning a
    # 1-D placeholder whose shape does not match the effect's coefficients.
    raise ValueError(f"Unknown categorical topology: {self.topology!r}")
Radial Basis Function (RBF)
1. Factory Parsing
elif ttype == 'rbf':
n_centers = int(params_resolved.get('n_centers', 50))
gamma = params_resolved.get('gamma', None)
if gamma is not None: gamma = float(gamma)
nu = params_resolved.get('nu', None)
others_str = params_resolved.get('others', None)
additional_features = None
if others_str:
additional_features = [s.strip() for s in others_str.split('|') if s.strip()]
extrap_val = params_resolved.get('extrapolate', 'continue')
effects_list.append(RBFEffect(feature_name, n_centers, gamma, nu, lambda_p, additional_features, extrap_val))
2. Feature Map
Computes pairwise Euclidean distances to centroids natively via torch.cdist before applying the Gaussian or Matérn kernel [Rasmussen and Williams, 2006].
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""
    Builds the continuous Radial Basis Function (RBF) kernel feature map.

    Computes the kernel similarity between every input point and each center:
        Phi(x) = [k(x, c_1), ..., k(x, c_K)]
    with k(x, c) = exp(-gamma * ||x - c||^2) when ``self.nu`` is None, and
    ``self._matern_kernel(||x - c||, gamma)`` otherwise.

    Includes shape routing so the same method serves univariate inputs,
    multivariate Tensor Products (Kronecker te) and Out-Of-Distribution (OOD) wrappers.

    Args:
        x_col: Input tensor. A 1-D input is treated as a single feature. A 2-D input
            whose last axis equals ``len(self.input_features)`` is taken as
            (samples, features); any other 2-D input gets a feature axis appended.
            Inputs with 3+ dims are used as-is, with the last axis as features.

    Returns:
        Tensor of shape ``(*x_in.shape[:-1], K)``, where ``x_in`` is the routed input
        and ``K = c.shape[0]`` is the number of centers actually used.

    Note:
        When ``self.centers`` is None, centers and gamma come from
        ``self._init_params``. Whether that helper caches them on ``self`` is not
        visible here (TODO confirm).
    """
    # 1. Architectural Shape Normalization: guarantee an explicit trailing feature axis.
    if x_col.dim() == 1:
        x_in = x_col.unsqueeze(-1)
    elif x_col.dim() == 2:
        if x_col.shape[-1] == len(self.input_features):
            # Safeguard for the framework's (1, 1) VRAM memory probe: before
            # initialization, a univariate (1, 1) input is routed to (1, 1, 1).
            if self.centers is None and x_col.shape == (1, 1) and len(self.input_features) == 1:
                x_in = x_col.unsqueeze(-1)
            else:
                x_in = x_col
        else:
            # Last axis is not the feature axis (e.g. stripped upstream): append one.
            x_in = x_col.unsqueeze(-1)
    else:
        x_in = x_col
    # 2. Dummy-pass (memory probe) detection, only evaluated before initialization.
    # is_dummy is forwarded to _init_params (presumably so it can avoid fitting
    # centers on probe data; confirm in that helper).
    is_dummy = False
    if self.centers is None:
        # 3-D+ inputs drop the feature axis; 2-D inputs compare their full shape.
        batch_shape = list(x_in.shape[:-1]) if x_in.dim() > 2 else list(x_in.shape)
        # Standardize [1] or [1,1] dummy signals
        if batch_shape == [1, 1] or batch_shape == [1]:
            is_dummy = True
    # 3. Lazy Initialization
    if self.centers is None:
        c, g = self._init_params(x_in, is_dummy)
    else:
        c, g = self.centers, self.gamma
    # Align centers with the current chunk's compute device (see _align_device).
    # NOTE(review): g is not aligned; fine if it is a Python float, verify if it can be a tensor.
    c = self._align_device(x_in, c)
    # Flatten to (Total_Samples, D) so torch.cdist sees a plain 2-D point set.
    x_flat = x_in.reshape(-1, x_in.shape[-1])
    # Pairwise Euclidean distances to every center: (Total_Samples, K).
    dists_flat = torch.cdist(x_flat, c, p=2.0)
    # Restore the routed batch structure.
    # Output Shape: (..., N_samples, N_centers)
    output_shape = list(x_in.shape[:-1]) + [c.shape[0]]
    dists = dists_flat.view(*output_shape)
    # 4. Apply the kernel function (Matérn if nu is set, Gaussian otherwise).
    if self.nu is not None:
        phi = self._matern_kernel(dists, g)
    else:
        # Standard Gaussian Kernel: exp(-gamma * d^2)
        phi = torch.exp(-g * (dists ** 2))
    return phi
3. Penalty Matrix Applies an isotropic identity penalty over the spatial prototypes.
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Ridge penalty ``lambda_p * I`` of size ``n_centers``.

    Every center is treated alike, so all kernel coefficients are shrunk by the same
    amount.
    """
    dtype = torch.get_default_dtype()
    return self.lambda_p * torch.eye(self.n_centers, device=TORCH_DEVICE, dtype=dtype)
Wavelets
1. Factory Parsing
elif ttype == 'w':
    # Number of dilation levels of the wavelet grid.
    n_scales = int(params_resolved.get('n_scales', 5))
    # Translations per scale; presumably stored as WaveletEffect.n_locations (confirm).
    n_locs = int(params_resolved.get('n_locations', 20))
    extrap_val = params_resolved.get('extrapolate', 'continue')
    effects_list.append(WaveletEffect(feature_name, n_scales, n_locs, lambda_p, extrap_val))
2. Feature Map
Uses tensor broadcasting (.unsqueeze()) to evaluate the Ricker (Mexican-hat) wavelet \(\psi(z) = (1 - z^2)\,e^{-z^2/2} / \sqrt{s}\), with \(z = (x - \ell)/s\), over the entire scale × location grid in a single pass [Torrence and Compo, 1998].
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    """Evaluate the Ricker (Mexican-hat) wavelet on the full scale x location grid.

    For each input value x, scale s and location l the feature is
        psi = (1 - z^2) * exp(-z^2 / 2) / sqrt(s),   z = (x - l) / s
    The result has shape ``(*x_col.shape, S * L)``, ordered scale-major.

    If ``self.locations`` is None, the grid is obtained from ``self._init_grid``.
    An input whose last axis has size 1 is flagged as the memory-estimation
    dummy pass for that call.
    """
    probe_pass = x_col.dim() > 0 and x_col.shape[-1] == 1
    if self.locations is not None:
        centres, widths = self.locations, self.real_scales
    else:
        centres, widths = self._init_grid(x_col, probe_pass)
    widths = widths.view(1, -1, 1)     # (1, S, 1)
    centres = centres.view(1, 1, -1)   # (1, 1, L)
    # x -> (..., N, 1, 1) broadcasts against the grid to (..., N, S, L).
    z = (x_col[..., None, None] - centres) / widths
    z_sq = z ** 2
    psi = (1 - z_sq) * torch.exp(-0.5 * z_sq) / torch.sqrt(widths)
    return psi.reshape(*x_col.shape, -1)
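3. Penalty Matrix Applies a scale-dependent diagonal ridge penalty: every coefficient at scale factor \(s\) receives the weight \(\lambda_p / s^2\), so the penalty changes only across scales (a "whitening" of the scale hierarchy). This scale-dependent shrinkage prior is intended to help isolate transient shocks [Donoho and Johnstone, 1994].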
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Scale-dependent diagonal penalty.

    For each entry of ``self.scale_factors`` (taken in order), ``n_locations``
    diagonal entries of value ``lambda_p / s_factor**2`` are emitted.
    """
    n_loc = self.n_locations
    per_scale = [
        torch.ones(n_loc, device=TORCH_DEVICE, dtype=torch.get_default_dtype())
        * (self.lambda_p / (factor ** 2))
        for factor in self.scale_factors
    ]
    return torch.diag(torch.cat(per_scale))
Neural
1. Factory Parsing
elif ttype == 'n':
    # Width of the random feature layer; NeuralEffect's penalty is an n_neurons x n_neurons ridge.
    n_neurons = int(params_resolved.get('n_neurons', 500))
    # Activation: NeuralEffect.build_feature_map accepts 'relu', 'cos' or 'tanh'.
    act = params_resolved.get('act', 'relu')
    # Seed for the frozen random weights (presumably consumed by _init_random_weights; confirm).
    seed = int(params_resolved.get('seed', 42))
    n_hidden_layers = int(params_resolved.get('n_hidden_layers', 1))
    # Optional extra input columns, given as 'a|b|c'.
    others_str = params_resolved.get('others', None)
    additional_features = None
    if others_str:
        additional_features = [s.strip() for s in others_str.split('|') if s.strip()]
    extrap_val = params_resolved.get('extrapolate', 'linear')
    effects_list.append(NeuralEffect(feature_name, n_neurons, act, lambda_p, additional_features, seed, n_hidden_layers, extrap_val))
2. Feature Map Projects features through a frozen Neural Network.
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""
    Executes the Explicit Primal Tensorization (EPT).

    Projects input data through the frozen multi-layer network, iteratively applying
    the activation function $\sigma(z^{(l-1)} W^{(l)} + b^{(l)})$ for every
    (weight, bias) pair in ``self.weights_list`` / ``self.bias_list``. The last
    hidden layer is returned multiplied by ``self.scale``, documented as
    $1/\sqrt{N_L}$, which normalizes the final output variance.

    Shape routing resolves broadcasting ambiguities between callers (OOD wrapper,
    te(), the factory matrix builder); see the branch comments below.

    Args:
        x_col (torch.Tensor): Input tensor of varying dimensionality.

    Returns:
        torch.Tensor: The finite-dimensional Primal block $\phi_{neural}(x)$, of shape
        ``(*x_in.shape[:-1], width of the last layer)``.

    Raises:
        ValueError: If ``self.activation`` is not 'relu', 'cos' or 'tanh'. It is
            raised inside the layer loop, after the first projection is computed.

    Side effects:
        On the first call (``self.weights_list is None``), ``self._init_random_weights``
        is called with the input feature dimension.
    """
    # 1. Architectural Shape Normalization: ensure a trailing feature axis.
    if x_col.dim() == 1:
        # 1D: From OOD Wrapper (Univariate) -> [N_ood]
        x_in = x_col.unsqueeze(-1)
    elif x_col.dim() == 2:
        # 2D: Resolve ambiguity between OOD Wrapper [N_ood, Features] and te() [Batch, Time]
        if x_col.shape[-1] == len(self.input_features):
            # Safeguard for the framework's (1, 1) VRAM memory probe
            if self.weights_list is None and x_col.shape == (1, 1) and len(self.input_features) == 1:
                x_in = x_col.unsqueeze(-1)
            else:
                x_in = x_col  # Last axis already holds the features
        else:
            x_in = x_col.unsqueeze(-1)  # Feature dimension was stripped by te()
    else:
        # 3D+: Natively structured from the _factory.py matrix builder
        x_in = x_col
    # 2. Lazy Initialization: weights are sized by the input feature count on first use.
    if self.weights_list is None:
        input_dim = x_in.shape[-1]
        self._init_random_weights(input_dim)
    # 3. Deep Linear Projection & Activation Loop (Explicit Primal Tensorization)
    phi = x_in
    for w, b in zip(self.weights_list, self.bias_list):
        # Presumably moves this layer's parameters to x_in's device (see _align_device).
        w_aligned, b_aligned = self._align_device(x_in, w, b)
        projection = phi @ w_aligned + b_aligned
        if self.activation == 'relu':
            phi = torch.relu(projection)
        elif self.activation == 'cos':
            phi = torch.cos(projection)
        elif self.activation == 'tanh':
            phi = torch.tanh(projection)
        else:
            raise ValueError(f"Unknown activation function: {self.activation}")
    # Return final scaled features (1 / sqrt(N_L))
    return (phi * self.scale)
3. Penalty Matrix Applies an isotropic Ridge penalty (\(L_2\)) to the final linear readout coefficients.
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Isotropic Ridge penalty ``lambda_p * I`` on the readout coefficients.

    Learning is restricted to the final readout layer $\theta$, so the objective is a
    convex quadratic in $\theta$. The identity penalty ($P_{neural} = \lambda_p I$),
    together with the output variance scaling, keeps this block a valid RKHS penalty.
    """
    width = self.n_neurons
    return self.lambda_p * torch.eye(width, device=TORCH_DEVICE, dtype=torch.get_default_dtype())
Tree / Random Forest
1. Factory Parsing
elif ttype == 't':
    n_trees = int(params_resolved.get('n_trees', 50))
    max_depth = int(params_resolved.get('max_depth', 5))
    # Optional cap on leaves per tree; None when not given.
    max_leaves_raw = params_resolved.get('max_leaves', None)
    max_leaves = int(max_leaves_raw) if max_leaves_raw is not None else None
    seed = int(params_resolved.get('seed', 42))
    # Sparsity-adaptive penalty exponent (TreeEffect.sparsity_alpha); 0 = isotropic ridge.
    sp_alpha = float(params_resolved.get('sp_alpha', 0))
    # Threshold placement strategy, normalized to lower case.
    split_strategy = params_resolved.get('split_strategy', 'uniform').lower().strip()
    # Optional extra input columns, given as 'a|b|c'.
    others_str = params_resolved.get('others', None)
    additional_features = None
    if others_str:
        additional_features = [s.strip() for s in others_str.split('|') if s.strip()]
    extrap_val = params_resolved.get('extrapolate', 'continue')
    effects_list.append(TreeEffect(feature_name, n_trees, max_depth, max_leaves, lambda_p, additional_features, seed, extrap_val, sp_alpha, split_strategy))
2. Feature Map
Converts Oblivious Trees into sparse Euclidean bit-strings (one one-hot leaf indicator per tree), modeling Random Binning Features [Wu et al., 2016]. The output is scaled in place with .mul_(), so no second full-size tensor is allocated.
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""
    Projects inputs into the sparse bit-string Primal representation.

    Every tree routes each input to exactly one leaf. The output concatenates one
    one-hot leaf indicator per tree, multiplied by ``self.scale``.

    Args:
        x_col: Input tensor. The shapes are routed as follows:
            - 1-D -> (samples,) univariate.
            - 2-D -> (samples, features) when the last axis equals
              ``len(self.input_features)``, otherwise (batch, time) univariate.
            - 3-D+ -> (..., features).

    Returns:
        Tensor of shape ``(*batch_shape, self.total_leaves)`` in
        ``torch.get_default_dtype()``.

    Side effects:
        If the forest is not yet initialized (``self.split_features is None``),
        ``self._init_forest`` is called. On the [1, 1] memory-probe pass, a zero
        tensor is then returned without evaluating the trees.
    """
    # 1. Architectural Shape Normalization: split into batch_shape + trailing feature axis.
    if x_col.dim() == 1:
        # 1D: From OOD Wrapper (Univariate) -> [N_ood]
        batch_shape = list(x_col.shape)
        x_in = x_col.unsqueeze(-1)
    elif x_col.dim() == 2:
        if x_col.shape[-1] == len(self.input_features):
            # Ambiguity Resolution: is it the te() dummy pass [1, 1], or OOD Wrapper [N_ood, Features]?
            if self.split_features is None and x_col.shape == (1, 1) and len(self.input_features) == 1:
                batch_shape = list(x_col.shape)
                x_in = x_col.unsqueeze(-1)
            else:
                batch_shape = list(x_col.shape[:-1])
                x_in = x_col
        else:
            # 2D: From te() regular pass [Batch, Time]
            batch_shape = list(x_col.shape)
            x_in = x_col.unsqueeze(-1)
    else:
        # 3D+: From _factory.py [Batch, Time, Features]
        batch_shape = list(x_col.shape[:-1])
        x_in = x_col
    # 2. Dummy Pass Detection (Memory Probe), only checked before initialization.
    is_dummy = False
    if self.split_features is None:
        if batch_shape == [1, 1]:
            is_dummy = True
    if self.split_features is None:
        self._init_forest(x_in, is_dummy)
    if is_dummy:
        return torch.zeros(
            *batch_shape, self.total_leaves,
            device=x_in.device, dtype=torch.get_default_dtype()
        )
    # 3. Vectorized Evaluation: x_expanded is (*batch_shape, 1, 1, F).
    x_expanded = x_in.unsqueeze(-2).unsqueeze(-2)
    if self.is_oblivious_binary:
        # Oblivious trees: one (feature, threshold) split per depth level. The leaf
        # index is the sum of binary decisions weighted by self.depth_multipliers.
        index_shape = batch_shape + [self.n_trees, self.max_depth, 1]
        split_feat_expanded = self.split_features.view(
            *([1] * len(batch_shape)), self.n_trees, self.max_depth, 1
        ).expand(*index_shape)
        input_expanded = x_expanded.expand(
            *batch_shape, self.n_trees, self.max_depth, x_in.shape[-1]
        )
        x_splits = torch.gather(input_expanded, dim=-1, index=split_feat_expanded).squeeze(-1)
        split_decisions = (x_splits > self.split_thresholds).to(torch.long)
        leaf_indices = torch.sum(split_decisions * self.depth_multipliers, dim=-1)
    else:
        # One split feature per tree with n_splits_per_tree thresholds. The leaf index
        # is the number of thresholds the value exceeds.
        index_shape = batch_shape + [self.n_trees, 1, 1]
        split_feat_expanded = self.split_features.view(
            *([1] * len(batch_shape)), self.n_trees, 1, 1
        ).expand(*index_shape)
        input_expanded = x_expanded.expand(
            *batch_shape, self.n_trees, 1, x_in.shape[-1]
        )
        x_splits = torch.gather(input_expanded, dim=-1, index=split_feat_expanded).squeeze(-1).squeeze(-1)
        x_splits_expand = x_splits.unsqueeze(-1)
        thresh_expand = self.split_thresholds.view(*([1] * len(batch_shape)), self.n_trees, self.n_splits_per_tree)
        leaf_indices = torch.sum((x_splits_expand > thresh_expand).to(torch.long), dim=-1)
    one_hot_bins = torch.nn.functional.one_hot(leaf_indices, num_classes=self.leaves_per_tree)
    # 4. Final Output Tensor [*batch_shape, total_leaves]. Use the framework's default
    # dtype, as the dummy-pass branch above does, instead of a hard-coded float64.
    # float64 doubled memory and upcast every Kronecker product whenever the default
    # dtype is float32.
    phi_tensor = one_hot_bins.view(*batch_shape, self.total_leaves).to(torch.get_default_dtype())
    # Apply the RKHS normalization bound in place (avoids a second full-size allocation).
    phi_tensor.mul_(self.scale)
    return phi_tensor
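3. Penalty Matrix Builds the Anisotropic Sparsity-Adaptive Ridge penalty over the terminal leaves directly as a sparse COO diagonal tensor. When sp_alpha \(= \alpha > 0\), leaf \(i\) receives \(\lambda_p \big((C_i + 1)/\bar C\big)^{-\alpha}\), where \(C_i\) is the empirical data count captured during initialization and \(\bar C\) the mean count. The \(L_2\) shrinkage therefore grows as the leaf's data density falls, heavily penalizing starved edge leaves. As long as \(\lambda_p > 0\), every diagonal entry is strictly positive, keeping the penalty block full rank.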
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Constructs the diagonal leaf penalty as a sparse COO tensor.

    If ``sparsity_alpha > 0`` and ``empirical_counts`` were recorded on the effect,
    leaf ``i`` receives ``lambda_p * ((C_i + 1) / C_bar) ** (-sparsity_alpha)``.
    Here ``C_i`` is the leaf's empirical data count and ``C_bar`` the mean count, so
    leaves with fewer samples than about ``C_bar`` get penalties above ``lambda_p``
    and denser leaves get penalties below it. In every other case, including when
    no data was counted (``C_bar`` is 0 or NaN), each leaf gets the isotropic
    penalty ``lambda_p``.

    Returns:
        Sparse COO tensor of shape ``(total_leaves, total_leaves)`` in
        ``torch.get_default_dtype()`` on ``TORCH_DEVICE``.
    """
    import contextlib

    dtype = torch.get_default_dtype()
    diag_vals = None
    if self.sparsity_alpha > 0.0 and hasattr(self, 'empirical_counts'):
        # Evaluate the power law in float64, then cast to the common penalty dtype.
        C_i = self.empirical_counts.to(TORCH_DEVICE, dtype=torch.float64)
        C_bar = C_i.mean()
        # With no counted data, C_bar is 0 (or NaN for an empty tensor). The ratio
        # would then be inf and inf ** (-alpha) == 0, i.e. a zero penalty on every
        # leaf. Fall back to the isotropic penalty instead.
        if C_bar > 0:
            epsilon = 1.0  # Smoothing constant: empty leaves (C_i = 0) keep a finite ratio
            penalty_scaling = ((C_i + epsilon) / C_bar) ** (-self.sparsity_alpha)
            diag_vals = (self.lambda_p * penalty_scaling).to(dtype)
    if diag_vals is None:
        # Pure Isotropic Penalty (alpha = 0, no counts, or no counted data).
        diag_vals = torch.full(
            (self.total_leaves,), self.lambda_p,
            device=TORCH_DEVICE, dtype=dtype
        )
    # Build as a sparse COO diagonal directly to save VRAM.
    diag_idx = torch.arange(self.total_leaves, device=TORCH_DEVICE)
    indices = torch.stack([diag_idx, diag_idx], dim=0)
    # Diagonal indices are unique and in range, so invariant checks can be skipped
    # where this PyTorch version provides the context manager.
    if hasattr(torch.sparse, 'check_sparse_tensor_invariants'):
        invariants_ctx = torch.sparse.check_sparse_tensor_invariants(False)
    else:
        invariants_ctx = contextlib.nullcontext()
    with invariants_ctx:
        return torch.sparse_coo_tensor(
            indices,
            diag_vals,
            size=(self.total_leaves, self.total_leaves),
            device=TORCH_DEVICE
        )
Tensor Product
1. Factory Parsing
Parses the te() token recursively to instantiate the nested sub-effects.
elif ttype == 'te':
raw_arguments = list(params.keys())
functional_sub_term_strings = [
s for s in raw_arguments
if re.match(r'^\s*(\w+)\s*\(', s)
]
dummy_formula = "DUMMY ~ " + " + ".join(functional_sub_term_strings)
try:
_, sub_parsed_terms = parse_formula_to_terms(dummy_formula)
sub_effects = create_effects_from_parsed_terms(
sub_parsed_terms,
token_values,
default_alpha_p,
include_offset=False,
data_info=data_info
)
if len(sub_effects) < 2:
raise ValueError(f"Tensor Product requires at least two functional sub-terms. Found {len(sub_effects)}.")
extrap_val = params_resolved.get('extrapolate', 'continue')
effects_list.append(TensorProductEffect(sub_effects, lambda_p, extrap_val))
except ValueError as e:
raise ValueError(f"Tensor Product failed to parse functional sub-terms {functional_sub_term_strings}: {e}")
2. Feature Map
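Generates the interaction surface using row-wise Kronecker products. Each product is computed as a broadcasted outer product over the last axis; despite its name, kronecker_product_einsum does not call torch.einsum.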
@staticmethod
def kronecker_product_einsum(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """
    Row-wise (batched) Kronecker product of the last axes of ``t1`` and ``t2``.

    For ``t1`` of shape ``(..., d1)`` and ``t2`` of shape ``(..., d2)``, the result has
    shape ``(..., d1 * d2)`` with ``out[..., i * d2 + j] = t1[..., i] * t2[..., j]``.
    Leading axes broadcast. Despite the name, it is implemented as a broadcasted
    multiply. It is static so composite effects (like LinearTree) can reuse it.
    """
    outer = t1[..., None] * t2.unsqueeze(-2)
    return outer.reshape(*outer.shape[:-2], t1.shape[-1] * t2.shape[-1])
def build_feature_map(self, x_data: torch.Tensor) -> torch.Tensor:
    r"""
    Builds the tensor-product feature map as a left-fold of row-wise Kronecker products.

    Args:
        x_data: Input tensor (Batch, N_samples, N_effects). Columns are consumed in
            sub-effect order. Each sub-effect takes ``len(input_features)`` columns,
            or one column if it has no ``input_features``.
            Note: The factory ensures columns are correctly ordered/selected.

    Returns:
        The Kronecker product of all sub-effect feature maps along the last axis.
    """
    # Spatial kernels that require the feature axis even for a single column.
    keep_feature_axis = ('TreeEffect', 'LinearTreeEffect', 'RBFEffect', 'NeuralEffect')
    blocks = []
    start = 0
    for sub_effect in self.effects:
        width = len(getattr(sub_effect, 'input_features', [sub_effect.feature_name]))
        cols = x_data[..., start:start + width]
        # Legacy 1-D math effects expect the single column squeezed away.
        if width == 1 and sub_effect.__class__.__name__ not in keep_feature_axis:
            cols = cols.squeeze(-1)
        blocks.append(sub_effect.transform(cols))
        start += width
    return functools.reduce(self.kronecker_product_einsum, blocks)
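3. Penalty Matrix Preserves marginal regularization independence via the Kronecker sum \(P = \lambda_p \sum_{i=1}^{K} I_{d_1} \otimes \cdots \otimes P_i \otimes \cdots \otimes I_{d_K}\). Here \(P_i\) is the penalty of the \(i\)-th sub-effect and \(d_j\) its number of coefficients.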
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Builds the anisotropic Kronecker-sum penalty.

    Term ``i`` is ``I ⊗ ... ⊗ P_i ⊗ ... ⊗ I``: the i-th sub-effect's penalty in
    position ``i`` and identities of matching size elsewhere. Sparse factors are
    densified first. The sum of all terms is scaled by this effect's ``lambda_p``.
    """
    sizes = [sub.get_n_coeffs() for sub in self.effects]
    marginals = [sub.build_penalty_matrix() for sub in self.effects]
    total = 0
    for active in range(len(sizes)):
        term = None
        for pos, size in enumerate(sizes):
            if pos == active:
                factor = marginals[pos]
            else:
                factor = torch.eye(size, device=TORCH_DEVICE, dtype=torch.get_default_dtype())
            if factor.is_sparse:
                factor = factor.to_dense()
            term = factor if term is None else torch.kron(term, factor)
        total = total + term
    return total * self.lambda_p
Universal Physics (PIKL)
1. Factory Parsing
elif ttype == 'phys':
    # Approximation basis handled by UniversalPhysicsEffect: 'spline' (default), 'fourier' or 'neural'.
    basis = params_resolved.get('basis', 'spline')
    # Basis size: read from 'k' for non-Fourier bases and from 'n_coeffs' for 'fourier' (default 20).
    n_coeffs = int(params_resolved.get('k', 20) if basis != 'fourier' else params_resolved.get('n_coeffs', 20))
    # Operator weights: each key 'D<digits>' (e.g. D0, D2) weights that derivative order in L(f).
    diff_weights = {}
    for pk, pv in params_resolved.items():
        if pk.startswith('D') and pk[1:].isdigit():
            diff_weights[pk] = float(pv)
    # Default operator: the pure second derivative.
    if not diff_weights: diff_weights = {'D2': 1.0}
    reserved_keys = ['k', 'n_coeffs', 'ap', 'basis', 'extrapolate']
    # Remaining parameters are forwarded unconverted to the basis constructor.
    # NOTE(review): every key starting with 'D' is dropped here, not only 'D<digits>'.
    basis_kwargs = {k: v for k, v in params_resolved.items() if k not in reserved_keys and not k.startswith('D')}
    extrap_val = params_resolved.get('extrapolate', 'continue')
    effects_list.append(UniversalPhysicsEffect(
        feature_name, basis, n_coeffs, diff_weights, lambda_p, extrap_val, **basis_kwargs
    ))
2. Feature Map Delegates the feature extraction entirely to the underlying differentiable topology (Spline, Fourier, Neural).
class UniversalPhysicsEffect(BaseEffect):
    """
    Universal Physics Effect constrained by a linear differential operator.

    The learned function f is driven to make the residual of L(f) ≈ 0 small,
    where L is a weighted sum of derivatives:
        L(f) = w_0 * f + w_1 * df/dt + w_2 * d^2f/dt^2 + ...
    The feature map is delegated to an inner, unpenalized basis effect
    (``lambda_p=0.0``); the operator enters only through this effect's penalty.

    Attributes:
        basis_type (str): 'spline', 'fourier' or 'neural'.
        diff_weights (dict): Operator weights keyed 'D<order>' (e.g. {'D2': 1.0}).
        base_effect: The inner basis effect that produces the features.
        n_coeffs_val (int): Number of coefficients of this effect.
    """

    # Textual spellings accepted as True for boolean options (same list as the
    # Fourier term parser in the factory).
    _TRUE_STRINGS = ('true', '1', 't', 'y', 'yes')

    def __init__(
        self,
        feature_name: str,
        basis_type: str,
        n_coeffs: int,
        diff_weights: dict,
        lambda_p: float,
        extrapolate: str,
        **basis_params
    ):
        """
        Args:
            feature_name: Input column name.
            basis_type: 'spline', 'fourier' or 'neural'.
            n_coeffs: Requested basis size. For 'fourier' this counts cosine and
                sine columns together, i.e. ``n_coeffs // 2`` harmonics.
            diff_weights: Operator weights keyed 'D<order>'.
            lambda_p: Penalty strength of the physics operator.
            extrapolate: Extrapolation mode of this effect.
            **basis_params: Basis-specific options. The factory forwards them
                unconverted (possibly as strings), so numeric options are passed
                through ``int`` and ``cyclic`` is parsed textually.

        Raises:
            ValueError: For an unknown ``basis_type`` or a non-numeric numeric option.
        """
        super().__init__(feature_name, f"phys_{basis_type}", lambda_p, extrapolate)
        self.basis_type = basis_type
        self.diff_weights = diff_weights
        if basis_type == 'spline':
            deg = int(basis_params.get('spline_degree', 3))
            self.base_effect = SplineEffect(
                feature_name, n_knots=n_coeffs, spline_degree=deg, penalty_order=0,
                lambda_p=0.0, extrapolate='continue'
            )
            self.n_coeffs_val = self.base_effect.get_n_coeffs()
        elif basis_type == 'fourier':
            m = n_coeffs // 2
            s_val = int(basis_params.get('s', 0))
            # A textual 'False' would be truthy; parse it as the Fourier factory does.
            cyclic = str(basis_params.get('cyclic', False)).strip().lower() in self._TRUE_STRINGS
            self.base_effect = FourierEffect(
                feature_name, m=m, s=s_val, cyclic=cyclic,
                lambda_p=0.0, extrapolate='continue'
            )
            self.n_coeffs_val = self.base_effect.get_n_coeffs()
        elif basis_type == 'neural':
            act = basis_params.get('act', 'tanh')
            seed = int(basis_params.get('seed', 42))
            layers = int(basis_params.get('n_hidden_layers', 1))
            others_str = basis_params.get('others', None)
            additional_features = (
                [s.strip() for s in str(others_str).split('|') if s.strip()]
                if others_str else None
            )
            self.base_effect = NeuralEffect(
                feature_name,
                n_neurons=n_coeffs,
                activation=act,
                additional_features=additional_features,
                seed=seed,
                n_hidden_layers=layers,
                lambda_p=0.0, extrapolate='continue'
            )
            self.n_coeffs_val = n_coeffs
        else:
            raise ValueError(f"Unknown basis type for physics effect: {basis_type}")

    def get_n_coeffs(self) -> int:
        return self.n_coeffs_val

    def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
        # Features come straight from the inner (unpenalized) basis effect.
        return self.base_effect.build_feature_map(x_col)
3. Penalty Matrix Enforces the explicit linear differential operator by constructing the analytic stiffness matrix \(P\), circumventing the instability of PINNs via exact Physics-Informed Kernel Learning [Doumèche et al., 2025].
def build_penalty_matrix(self) -> torch.Tensor:
    """Builds the Stiffness Matrix P such that f^T P f approximates the integral of (L(f))^2.

    Dispatches to the basis-specific builder selected by ``self.basis_type``.

    Returns:
        torch.Tensor: The square stiffness matrix for the wrapped basis.

    Raises:
        ValueError: If ``self.basis_type`` is not ``'spline'``, ``'fourier'``
            or ``'neural'``.
    """
    if self.basis_type == 'spline':
        return self._build_spline_penalty()
    if self.basis_type == 'fourier':
        return self._build_fourier_penalty()
    if self.basis_type == 'neural':
        return self._build_neural_penalty()
    # The constructor already rejects unknown bases; fail loudly here too
    # instead of returning a 1-D placeholder that breaks the square-matrix
    # contract and only surfaces later as an obscure shape error.
    raise ValueError(f"Unknown basis type for physics effect: {self.basis_type}")
PID (Autoregressive Control)¶
The PID effect code¶
1. Factory Parsing
Parses the pid() token to extract the target lag feature, the look-back window for the rolling integral, and the artificial stiffness multiplier for the derivative action.
elif ttype == 'pid':
# Parse the pid() token: 'w' is the integral look-back window (default 7),
# 'd_pen' multiplies the derivative-term penalty (default 10.0).
w = int(params_resolved.get('w', 7))
d_pen = float(params_resolved.get('d_pen', 10.0))
extrap_val = params_resolved.get('extrapolate', 'continue')
# NOTE(review): PIDEffect is constructed positionally; this assumes its
# signature is (feature_name, window, lambda_p, d_penalty_multiplier,
# extrapolate) — confirm against the PIDEffect definition.
effects_list.append(PIDEffect(feature_name, w, lambda_p, d_pen, extrap_val))
2. Feature Map
Projects endogenous target lags into a discrete Proportional-Integral-Derivative control space. It relies only on vectorized tensor operations (torch.diff for the derivative and a zero-padded cumsum difference for the rolling-mean integral). As a result, the cost stays linear in the sequence length and no dense lag matrix is ever materialized in VRAM.
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""
    Constructs the discrete PID feature space using vectorized tensor operations.

    Args:
        x_col (torch.Tensor): The raw autoregressive lag tensor.
            Expected shape: (..., Time).

    Returns:
        torch.Tensor: The concatenated feature map [P, I, D].
            Shape: (..., Time, 3).

    Notes:
        All channels assume zero initial conditions. D at the first step is
        ``x_0 - 0``, and the integral window is zero-padded on the left.
        The rolling sum is always divided by the configured ``self.window``
        (floored at 1), never by a shorter sequence length. This keeps the
        scale of the I channel, and therefore the meaning of its learned
        coefficient, independent of how many time steps are passed in.
    """
    # 1. Proportional (P): the raw lag state.
    P = x_col.unsqueeze(-1)

    # 2. Derivative (D): first difference y_t - y_{t-1}. Prepending a zero
    #    keeps the time dimension size (zero initial condition).
    zero_prep = torch.zeros_like(x_col[..., :1])
    D = torch.diff(x_col, dim=-1, prepend=zero_prep).unsqueeze(-1)

    # 3. Integral (I): rolling sum over the last `window` samples, computed as
    #    cumsum[t] - cumsum[t - window] with left zero-padding, in O(Time).
    cumsum_x = torch.cumsum(x_col, dim=-1)
    # The slice length cannot exceed the sequence; if it equals the sequence,
    # the shifted cumsum is all zeros and I_val is the plain cumsum.
    w = min(self.window, x_col.shape[-1])
    if w > 0:
        shift_cumsum = F.pad(cumsum_x[..., :-w], pad=(w, 0), mode='constant', value=0.0)
        I_val = cumsum_x - shift_cumsum
    else:
        I_val = cumsum_x
    # Divide by the configured window (not the clipped `w`) so the rolling mean
    # has the same scale for short and long inputs. This keeps it on P's scale
    # and consistent with Ki / window in the Bode diagnostics.
    I = (I_val / max(1, self.window)).unsqueeze(-1)

    # 4. Assemble the final feature block.
    return torch.cat([P, I, D], dim=-1)
3. Penalty Matrix
Constructs a \(3 \times 3\) diagonal stiffness matrix. It applies an extra multiplier (d_penalty_multiplier) to the derivative term. This shrinks the derivative gain more strongly, favouring low-pass behaviour and limiting the amplification of high-frequency stochastic noise [Åström and Murray, 2021].
def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Constructs the stiffness matrix for the PID terms.

    Returns:
        torch.Tensor: ``lambda_p * diag(1, 1, d_penalty_multiplier)``, a 3x3
        matrix ordered as [P, I, D]. The derivative entry is scaled by
        ``d_penalty_multiplier`` so the derivative gain is shrunk harder.
    """
    diag_weights = torch.ones(3, dtype=torch.get_default_dtype(), device=TORCH_DEVICE)
    diag_weights[2] = self.d_penalty_multiplier
    return self.lambda_p * torch.diag(diag_weights)
Control Diagnostics (Bode Stability)¶
Core Interface: bode.py
Because the pid() effect introduces an endogenous closed-loop dynamical system, the framework provides a native Control Diagnostics suite. This module extracts the learned structural weights and maps them to the frequency domain via the Z-transform. It then generates Bode plots and checks Bounded-Input Bounded-Output (BIBO) stability by verifying that every pole lies strictly inside the unit circle.
1. Public API
Exposes the plot_bode method, which dynamically routes requests for both static PID constraints and localized Gain-Scheduled (Tensor Product) controllers.
def plot_bode(
    self,
    pid_feature: str,
    target_group: Optional[str] = None,
    cond_feature: Optional[str] = None,
    cond_value: Optional[float] = None
):
    """
    Generates a Control Theory Bode Plot for an autoregressive PID component.

    The Gain-Scheduled (Tensor Product) path is taken only when both
    ``cond_feature`` and ``cond_value`` are given; otherwise the static PID
    weights are used. The integral window size comes from the fitted effect.

    Args:
        pid_feature (str): The name of the autoregressive target lag feature.
        target_group (str, optional): The specific group (e.g., time of day) to analyze.
        cond_feature (str, optional): The conditional feature name (e.g., 'temperature')
            if analyzing a Gain-Scheduled (Tensor Product) PID.
        cond_value (float, optional): The physical value of the condition to evaluate.

    Raises:
        RuntimeError: If the model has not been fitted.
    """
    if getattr(self.model, 'coefficients_', None) is None:
        raise RuntimeError("Model must be fitted before running Control Diagnostics.")

    group_idx, group_key, target_group = self._resolve_group_routing(target_group)
    gain_scheduled = cond_feature is not None and cond_value is not None

    if gain_scheduled:
        weights = self._extract_dynamic_weights(
            pid_feature, cond_feature, cond_value, group_idx, group_key
        )
        title = f"Local Filter Response: {pid_feature} | {cond_feature}={cond_value} (tod={target_group})"
    else:
        weights = self._extract_static_weights(pid_feature, group_idx, group_key)
        title = f"Autoregressive Filter Response: {pid_feature} (tod={target_group})"
    Kp, Ki, Kd, window = weights

    print(f"\n--- Dynamics for Group (tod): {target_group} ---")
    print(f"Physical Control Weights -> Kp: {Kp:.4f} | Ki: {Ki:.4f} | Kd: {Kd:.4f} | Window: {window}")
    self._render_bode_plot(Kp, Ki, Kd, window, title)
2. Physical Unscaling
Reverses the Primal feature map normalization (the “Normalization Illusion”) to recover the physical control weights (\(K_p, K_i, K_d\)). When the PID is conditioned on an exogenous variable (such as temperature), it slices the tensor-product coefficient space and evaluates it at the requested condition value.
def _extract_static_weights(self, pid_feature: str, group_idx: int, group_key: str) -> Tuple[float, float, float, int]:
"""
Extracts and unscales physical coefficients for a standard PID effect.
Returns Kp, Ki, Kd, and the integral window size.
"""
coeff_idx = 0
pid_effect = None
for effect in self.model.effects_list_:
if getattr(effect, 'effect_type', None) == "pid" and effect.feature_name == pid_feature:
pid_effect = effect
break
coeff_idx += effect.get_n_coeffs()
if pid_effect is None:
raise ValueError(f"No standalone PIDEffect found for feature: {pid_feature}")
group_coeffs = self.model.coefficients_[group_idx]
Kp_raw = group_coeffs[coeff_idx, 0].item()
Ki_raw = group_coeffs[coeff_idx + 1, 0].item()
Kd_raw = group_coeffs[coeff_idx + 2, 0].item()
window = getattr(pid_effect, 'window', 1)
scale = self._get_feature_scale(pid_feature, group_key)
return Kp_raw / scale, Ki_raw / scale, Kd_raw / scale, window
def _extract_dynamic_weights(
self, pid_feature: str, cond_feature: str, cond_value: float, group_idx: int, group_key: str
) -> Tuple[float, float, float, int]:
"""
Extracts, evaluates, and unscales coefficients for a Gain-Scheduled (Tensor Product) PID.
Returns Kp, Ki, Kd, and the integral window size.
"""
coeff_idx = 0
target_te_effect = None
for effect in self.model.effects_list_:
if getattr(effect, 'effect_type', None) == "tensor_product":
sub_names = [e.feature_name for e in effect.effects]
if pid_feature in sub_names and cond_feature in sub_names:
target_te_effect = effect
break
coeff_idx += effect.get_n_coeffs()
if not target_te_effect:
raise ValueError(f"No Tensor Product found linking {pid_feature} and {cond_feature}.")
pid_eff, cond_eff = target_te_effect.effects[0], target_te_effect.effects[1]
if getattr(pid_eff, 'effect_type', None) != "pid":
raise ValueError("The PID effect must be the FIRST argument in the tensor product formula.")
group_coeffs = self.model.coefficients_[group_idx]
te_coeffs_raw = group_coeffs[coeff_idx : coeff_idx + target_te_effect.get_n_coeffs(), 0]
te_coeffs_2d = te_coeffs_raw.view(3, cond_eff.get_n_coeffs())
# 1. Dynamically extract the device where the model's coefficients reside
target_device = te_coeffs_2d.device
scale_cond, center_cond = self._get_feature_scale(cond_feature, group_key, return_center=True)
cond_norm = (cond_value - center_cond) / scale_cond
# 2. Force the condition tensor to spawn on that exact same device
cond_tensor = torch.tensor([cond_norm], dtype=torch.get_default_dtype(), device=target_device)
phi_cond = cond_eff.build_feature_map(cond_tensor)
local_raw_pid = torch.matmul(te_coeffs_2d, phi_cond.squeeze())
window = getattr(pid_eff, 'window', 1)
scale_pid = self._get_feature_scale(pid_feature, group_key)
return local_raw_pid[0].item() / scale_pid, local_raw_pid[1].item() / scale_pid, local_raw_pid[2].item() / scale_pid, window
3. Digital Filter Construction & Rendering
Builds the discrete transfer function polynomial, distributing the integral gain evenly over the rolling window. It renders the Bode magnitude and phase responses and checks that all poles lie strictly inside the unit circle.
def _render_bode_plot(self, Kp: float, Ki: float, Kd: float, window: int, title: str):
"""
Builds the discrete digital filter transfer function and renders the Bode diagram.
"""
import matplotlib.pyplot as plt
Ki_w = Ki / max(1, window)
a_1 = Kp + Ki_w + Kd
a_2 = Ki_w - Kd
num = [1.0] + [0.0] * window
den = [1.0, -a_1, -a_2]
for _ in range(3, window + 1):
den.append(-Ki_w)
sys = signal.TransferFunction(num, den, dt=1.0)
w_freq, mag, phase = signal.dbode(sys)
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
fig.suptitle(title, fontsize=14, fontweight='bold')
ax1.semilogx(w_freq, mag, color='#1f77b4', linewidth=2.5)
ax1.set_ylabel('Gain Magnitude (dB)', fontweight='bold')
ax1.grid(True, which="both", ls="--", alpha=0.6)
ax2.semilogx(w_freq, phase, color='#ff7f0e', linewidth=2.5)
ax2.set_ylabel('Phase (degrees)', fontweight='bold')
ax2.set_xlabel('Frequency (rad/hour)', fontweight='bold')
ax2.grid(True, which="both", ls="--", alpha=0.6)
plt.tight_layout()
plt.show()
max_pole = max(abs(p) for p in sys.poles)
print("-" * 50)
print(f"Maximum Pole Magnitude: {max_pole:.4f}")
if max_pole >= 1.0:
print("CONTROL WARNING: System is UNSTABLE (Poles outside unit circle).")
else:
print("CONTROL VERIFIED: System is strictly stable.")
print("-" * 50)
Linear Tree (Varying-Coefficient Trees)¶
Related Theory: See the Mathematical Definition
1. Factory Parsing
The parser explicitly extracts the local slope feature and enforces n_trees=1 to mathematically prevent overlapping collinearity during the Primal resolution. It then directly instantiates the native composite effect.
elif ttype == 'lt':
# 'slope' names the feature passed on as slope_feature (defaults to the split
# feature itself). Both keys are popped, i.e. removed from params_resolved.
slope_feat = params_resolved.pop('slope', feature_name)
extrap_val = params_resolved.pop('extrapolate', 'linear')
# Guardrail: Force n_trees=1 to prevent overlapping collinearity inside a single feature
params_resolved['n_trees'] = 1
n_trees = 1
max_depth = int(params_resolved.get('max_depth', 5))
max_leaves_raw = params_resolved.get('max_leaves', None)
# None is forwarded unchanged when no leaf cap is given.
max_leaves = int(max_leaves_raw) if max_leaves_raw is not None else None
seed = int(params_resolved.get('seed', 42))
sp_alpha = float(params_resolved.get('sp_alpha', 0.0))
# Normalized (lower-cased, whitespace-stripped) before being forwarded.
split_strategy = params_resolved.get('split_strategy', 'uniform').lower().strip()
others_str = params_resolved.get('others', None)
# Optional '|'-separated list of extra feature names; empty tokens are dropped.
# NOTE(review): unlike the physics-effect parser, 'others' is not coerced with
# str() here, so a non-string value would fail on .split() — confirm inputs.
additional_features = [s.strip() for s in others_str.split('|') if s.strip()] if others_str else None
# Directly append the composite effect! No sub-effect macro hacks needed.
effects_list.append(LinearTreeEffect(
feature_name=feature_name,
slope_feature=slope_feat,
n_trees=n_trees,
max_depth=max_depth,
max_leaves=max_leaves,
lambda_p=lambda_p,
additional_features=additional_features,
seed=seed,
extrapolate=extrap_val,
sparsity_alpha=sp_alpha,
split_strategy=split_strategy
))
2. Feature Map
Natively encapsulates the base TreeEffect (local intercept) and computes the Kronecker product of the slope tree with the LinearEffect to generate the localized gradients.
def build_feature_map(self, x_data: torch.Tensor) -> torch.Tensor:
    """Return ``[intercept-tree basis | slope-tree basis ⊗ linear basis]``.

    The last axis of ``x_data`` holds the tree split features first. The
    single column right after them is fed to the linear (slope) basis.
    """
    split_cols = len(self.tree_features)
    tree_inputs = x_data[..., :split_cols]
    slope_inputs = x_data[..., split_cols:split_cols + 1].squeeze(-1)

    intercept_basis = self.base_tree.transform(tree_inputs)
    slope_tree_basis = self.slope_tree.transform(tree_inputs)
    linear_basis = self.linear.transform(slope_inputs)

    # Row-wise Kronecker product: one local slope per slope-tree leaf.
    local_slopes = TensorProductEffect.kronecker_product_einsum(slope_tree_basis, linear_basis)
    return torch.cat([intercept_basis, local_slopes], dim=-1)
3. Penalty Matrix
Constructs the block-diagonal combination of the anisotropic sparsity-adaptive tree penalty (local intercepts) and the Kronecker tensor penalty (local slopes), assembling them into a single sparse COO tensor.
def build_penalty_matrix(self) -> torch.Tensor:
    """
    Assemble the block-diagonal penalty ``diag(P_intercept, P_slope)`` as sparse COO.

    The top-left block is the base tree's penalty (local intercepts). The
    bottom-right block is the tensor product's penalty (local slopes). Each
    block may arrive dense or sparse, and only its non-zero entries are copied.

    Returns:
        torch.Tensor: A coalesced sparse COO tensor of shape
        ``(n1 + n2, n1 + n2)`` on the device of the intercept penalty.
        Coalescing makes ``.indices()`` / ``.values()`` directly usable.
    """
    P1 = self.base_tree.build_penalty_matrix()
    P2 = self.tensor.build_penalty_matrix()
    n1, n2 = P1.shape[0], P2.shape[0]
    device = P1.device

    def _coo_parts(P: torch.Tensor, offset: int):
        """Return (indices, values) of P's non-zeros, shifted by `offset` on both axes."""
        if P.is_sparse:
            P = P.coalesce()
            idx, vals = P.indices(), P.values()
        else:
            nz = P.nonzero(as_tuple=True)
            idx = torch.stack([nz[0], nz[1]], dim=0)
            vals = P[nz]
        # Move to a common device so the blocks can be concatenated.
        return (idx + offset).to(device), vals.to(device)

    parts = [_coo_parts(P1, 0), _coo_parts(P2, n1)]
    # Empty parts are (2, 0) / (0,) tensors, so an all-zero penalty still
    # yields a valid (n1 + n2) x (n1 + n2) sparse tensor with no entries.
    indices = torch.cat([idx for idx, _ in parts], dim=1)
    values = torch.cat([vals for _, vals in parts], dim=0)
    return torch.sparse_coo_tensor(
        indices, values, size=(n1 + n2, n1 + n2), device=device
    ).coalesce()