PyTorch Math Dispatcher & Matrix-Free Solvers
This chapter shows how the linear algebra theory described in Linear Algebra: Direct vs. Iterative Solvers is implemented as batched PyTorch tensor operations. It also shows how the implementation chooses between solvers according to the memory available on the hardware.
Covariance Accumulation
The _math.py module handles the foundational matrix multiplications. The _compute_weighted_covariances function is responsible for building the Left-Hand Side (cov_X) and Right-Hand Side (cov_XY) of the regularized normal equations [Doumèche et al., 2025].
PyTorch computes these covariance matrices for every temporal or group batch element in a single call. It does this with the Tensor.mT attribute (a transpose of the last two dimensions that works on batched tensors) and the @ operator. Target weighting is applied by right-multiplying the targets by the loss-weighting matrix loss_L_star_L (\(L^*L\)) before the product with \(\Phi^T\). As a result, each output dimension of a multi-output target is weighted according to the loss.
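The following sketch checks that the batched .mT / @ products equal a per-group loop. It uses random data and illustrative sizes, not values from the framework. The shapes follow the docstring of _compute_weighted_covariances below, with a leading group dimension G.

```python
import torch

torch.manual_seed(0)

G, n, D, d_out = 4, 32, 5, 2          # groups, samples per group, features, outputs (illustrative)
phi = torch.randn(G, n, D, dtype=torch.float64)
y = torch.randn(G, n, d_out, dtype=torch.float64)
L_star_L = torch.tensor([[2.0, 0.0], [0.0, 0.5]], dtype=torch.float64)  # illustrative weights

# Batched: .mT transposes the last two dims of every group at once
cov_X = phi.mT @ phi                   # (G, D, D)
cov_XY = phi.mT @ (y @ L_star_L)       # (G, D, d_out)

# Reference: the same products computed one group at a time
cov_X_loop = torch.stack([phi[g].T @ phi[g] for g in range(G)])
cov_XY_loop = torch.stack([phi[g].T @ y[g] @ L_star_L for g in range(G)])
```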
```python
from typing import Tuple

import torch

# TORCH_DEVICE is a module-level torch.device defined elsewhere in _math.py.


def _compute_weighted_covariances(
    phi: torch.Tensor,
    y_data: torch.Tensor,
    loss_L_star_L: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Computes the loss-weighted covariance matrices (LHS and RHS of normal equations).

    Calculates:
    - cov_XY = Phi^T @ (Y @ L*L)          (Weighted Feature-Target Covariance)
    - cov_X  = w * Phi^T @ Phi            (Weighted Feature Covariance), with
      w = L*L[0, 0] when d_out == 1 and w = 1 (unweighted) otherwise

    Args:
        phi: The design matrix. Shape: (..., n_samples, n_coeffs).
        y_data: The target tensor. Shape: (..., n_samples, d_out).
        loss_L_star_L: The loss-weighting matrix. Shape: (d_out, d_out).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (cov_X, cov_XY), with shapes
        (..., n_coeffs, n_coeffs) and (..., n_coeffs, d_out).
    """
    # Ensure all tensors are on the same device
    phi = phi.to(TORCH_DEVICE)
    y_data = y_data.to(TORCH_DEVICE)
    loss_L_star_L = loss_L_star_L.to(TORCH_DEVICE)

    # Cast targets and weights to phi's dtype: @ rejects mixed float32/float64 operands
    y_data_aligned = y_data.to(phi.dtype)
    loss_L_star_L = loss_L_star_L.to(phi.dtype)

    # Compute Weighted Y
    # y_weighted shape: (..., n_samples, d_out)
    y_weighted = y_data_aligned @ loss_L_star_L

    # Compute RHS: cov_XY = Phi^T @ Y_weighted
    # (.mT transposes without conjugating; it equals Phi^H only for real-valued phi)
    cov_XY = phi.mT @ y_weighted

    # Compute LHS: cov_X
    # Note: the current implementation treats single-target and multi-target differently
    if loss_L_star_L.shape[0] == 1:
        # Single-target: scale phi by sqrt(w) so that cov_X = w * Phi^T @ Phi (requires w >= 0)
        L_sqrt = loss_L_star_L[0, 0].sqrt()
        phi_weighted = phi * L_sqrt
        cov_X = phi_weighted.mT @ phi_weighted
    else:
        # Multi-target: currently defaults to the unweighted feature covariance
        # (Phi^T @ Phi) for stability in hierarchical cases.
        cov_X = phi.mT @ phi

    return cov_X, cov_XY
```
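For the single-target branch, the system assembled here and solved by solve_linear_system (next section) can be read as the stationarity condition of a penalized least-squares objective. The objective below is a reconstruction from the code, not a formula quoted from the framework. Its symbols are \(w\) = loss_L_star_L[0, 0], \(T\) = n_samples, and \(P\) = penalty_M_star_M, with \(P\) assumed symmetric. The jitter added later is not included.

```latex
\begin{aligned}
&\min_{\beta}\; \frac{w}{T}\,\lVert y - \Phi\beta \rVert_2^2 + \beta^T P \beta \\
&\Longrightarrow\; -\frac{2w}{T}\,\Phi^T (y - \Phi\beta) + 2P\beta = 0 \\
&\Longrightarrow\; \big(\underbrace{w\,\Phi^T\Phi}_{\texttt{cov\_X}} + T P\big)\,\beta = \underbrace{w\,\Phi^T y}_{\texttt{cov\_XY}}
\end{aligned}
```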
Direct Solvers and the Jitter Application
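When the feature dimension \(D\) is small enough for the dense \(D \times D\) system to fit safely in VRAM, the orchestrator routes the problem to the solve_linear_system function. That function solves the system directly through the hardware manager's hw.safe_solve helper, which relies on torch.linalg.solve, an LU-based solve rather than an explicit matrix inverse [Golub and Van Loan, 1996].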
The framework adds a jitter term, scaled by the sample count, to keep the system matrix positive-definite in both float32 and float64, even when features are highly correlated.

The jitter adds \(\delta I\) to the diagonal, with \(\delta = 10^{-6}\,T\) and \(T\) the number of samples (n_samples). It acts as a baseline ridge penalty [Hoerl and Kennard, 1970]. If cov_X + \(T P\) is positive semi-definite, every eigenvalue of the shifted matrix is at least \(\delta\). Its condition number is therefore bounded by \((\lambda_{\max} + \delta)/\delta\) before the matrix reaches the underlying linear algebra backend (e.g., LAPACK on CPU, cuSOLVER/MAGMA on GPU).
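The effect can be checked on a deliberately singular covariance. In this sketch the data are random and the duplicated column is an illustrative assumption. Only the jitter formula \(\delta = 10^{-6}\,T\) comes from the text above.

```python
import torch

torch.manual_seed(0)

n_samples, D = 50, 4
phi = torch.randn(n_samples, D - 1, dtype=torch.float64)
# Duplicate the first feature: a perfectly collinear design, so phi^T phi is singular
phi = torch.cat([phi, phi[:, :1]], dim=1)
cov_X = phi.mT @ phi

delta = 1e-6 * n_samples                   # same scaling as jitter_scale in solve_linear_system
jittered = cov_X + delta * torch.eye(D, dtype=cov_X.dtype)

eig_raw = torch.linalg.eigvalsh(cov_X)     # smallest eigenvalue is ~0 (round-off only)
eig_jit = torch.linalg.eigvalsh(jittered)  # every eigenvalue shifted up by delta

# Cholesky succeeds (info == 0) because the shifted matrix is numerically positive-definite
_, info = torch.linalg.cholesky_ex(jittered)
cond_jit = (eig_jit.max() / eig_jit.min()).item()
```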
```python
from typing import Union

import torch

# TORCH_DEVICE (a torch.device) and hw (the hardware manager providing safe_solve)
# are module-level names defined elsewhere in the framework.


def solve_linear_system(
    cov_X: torch.Tensor,
    cov_XY: torch.Tensor,
    penalty_M_star_M: torch.Tensor,
    n_samples: Union[int, float]
) -> torch.Tensor:
    r"""
    Solves the regularized linear system (batched).

    Solves for Coefficients (Beta):
        (cov_X + n_samples * P + 1e-6 * n_samples * I) @ Beta = cov_XY

    Args:
        cov_X: The weighted feature covariance matrix. Shape: (..., K, K).
        cov_XY: The weighted feature-target covariance. Shape: (..., K, d_out).
        penalty_M_star_M: The penalty matrix P. Shape: (K, K).
        n_samples: Scaling factor for the penalty term (usually sample count).

    Returns:
        torch.Tensor: Fitted coefficients. Shape: (..., K, d_out).
    """
    cov_X = cov_X.to(TORCH_DEVICE)
    cov_XY = cov_XY.to(device=TORCH_DEVICE, dtype=cov_X.dtype)
    penalty_M_star_M = penalty_M_star_M.to(TORCH_DEVICE)

    if penalty_M_star_M.is_sparse:
        penalty_M_star_M = penalty_M_star_M.to_dense()
    # Match cov_X's dtype so the sum and the solve do not mix precisions
    penalty_M_star_M = penalty_M_star_M.to(cov_X.dtype)

    # LHS construction: the (K, K) penalty broadcasts over the leading batch dims of cov_X
    matrix_to_invert = cov_X + n_samples * penalty_M_star_M

    # Add jitter for numerical stability (ridge shift of the diagonal)
    jitter_scale = 1e-6 * n_samples
    jitter = jitter_scale * torch.eye(
        matrix_to_invert.shape[-1],
        device=TORCH_DEVICE,
        dtype=matrix_to_invert.dtype
    )
    matrix_to_invert = matrix_to_invert + jitter  # broadcasts like the penalty

    # Solve A @ Beta = B without forming A^-1 explicitly
    coeffs = hw.safe_solve(matrix_to_invert, cov_XY)
    return coeffs
```
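The sketch below illustrates the batched solve. torch.linalg.solve stands in for hw.safe_solve, whose internals are not shown here, and an identity matrix serves as an illustrative penalty. The check confirms that one call on a (G, K, K) stack matches solving each group in a loop.

```python
import torch

torch.manual_seed(0)

G, n, K, d_out = 3, 40, 6, 2                   # groups, samples, features, outputs (illustrative)
phi = torch.randn(G, n, K, dtype=torch.float64)
y = torch.randn(G, n, d_out, dtype=torch.float64)
P = torch.eye(K, dtype=torch.float64)           # illustrative ridge penalty (assumption)

cov_X = phi.mT @ phi                            # (G, K, K)
cov_XY = phi.mT @ y                             # (G, K, d_out)

# Same LHS as solve_linear_system: P and the jitter broadcast over the G groups
A = cov_X + n * P + 1e-6 * n * torch.eye(K, dtype=torch.float64)
beta = torch.linalg.solve(A, cov_XY)            # one batched call, shape (G, K, d_out)

# Reference: solve each group separately
beta_loop = torch.stack([torch.linalg.solve(A[g], cov_XY[g]) for g in range(G)])
residual = A @ beta - cov_XY
```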
The Matrix-Free Conjugate Gradient
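A user may design a very large architecture, e.g., crossing a Random Forest with 10,000 leaves against a Fourier series. The dense covariance matrix \(\Phi^T \Phi\) then no longer fits in memory, and allocating it raises an Out-Of-Memory (OOM) error. With \(D = 100{,}000\) features, a single float32 \(D \times D\) matrix already needs \(4 D^2 = 4 \times 10^{10}\) bytes, about 40 GB.
To avoid this, the framework uses solve_sparse_cg, a matrix-free Conjugate Gradient (CG) solver.
CG never allocates the \(D \times D\) system. It builds its iterates in the Krylov subspace \(\operatorname{span}\{r_0, A r_0, A^2 r_0, \dots\}\) [Golub and Van Loan, 1996], so it needs only a few vectors of length \(D\) and a closure compute_Av(v) that evaluates the matrix-vector product \(A v\). Consider a system of the form \(A = w\,\Phi^T\Phi + T P\), with loss weight \(w\), \(T\) samples and penalty \(P\). There, \(A v = w\,\Phi^T(\Phi v) + T P v\) requires only products with \(\Phi\) and \(\Phi^T\). The optimal coefficients \(\hat{\theta}\) are therefore found iteratively without ever materializing the dense matrix in memory. CG requires \(A\) to be symmetric positive-definite.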
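A closure of this kind can be sketched as follows, assuming the single-target system \(A = w\,\Phi^T\Phi + T P\), with \(T\) the number of samples. The factory name make_compute_Av, the chunk size, the loss weight and the diagonal penalty are hypothetical choices for illustration. The closure streams over row chunks of \(\Phi\), so the extra memory it uses grows with the chunk size and \(D\), not with \(D^2\).

```python
import torch

torch.manual_seed(0)

n, D = 1_000, 8
phi = torch.randn(n, D, dtype=torch.float64)
P = torch.diag(torch.linspace(0.1, 1.0, D, dtype=torch.float64))  # illustrative penalty
w = 2.0                                                             # illustrative loss weight


def make_compute_Av(phi, P, w, n_samples, chunk_size=256):
    """Hypothetical factory: returns v -> (w * phi^T phi + n_samples * P) @ v
    without ever forming the (D, D) matrix phi^T phi."""
    def compute_Av(v):
        out = n_samples * (P @ v)
        # Accumulate w * phi^T (phi v) over row chunks; slices are views, so the
        # extra memory is O(chunk_size * v.shape[-1] + D * v.shape[-1])
        for start in range(0, phi.shape[0], chunk_size):
            rows = phi[start:start + chunk_size]
            out = out + w * (rows.mT @ (rows @ v))
        return out
    return compute_Av


compute_Av = make_compute_Av(phi, P, w, n_samples=n)
v = torch.randn(D, 3, dtype=torch.float64)
dense_A = w * (phi.mT @ phi) + n * P      # formed here only to check the closure
```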
```python
from typing import Callable, Optional

import torch


def solve_sparse_cg(
    compute_Av: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    x0: Optional[torch.Tensor] = None,
    tol: float = 1e-4,
    max_iter: int = 1000
) -> torch.Tensor:
    r"""
    Solves the linear system Ax = b using the matrix-free Conjugate Gradient method.

    This solver never explicitly constructs the dense covariance matrix A. Instead,
    it relies on a closure `compute_Av` that evaluates the matrix-vector product
    (A @ v) dynamically. This allows the resolution of massive algorithmic
    structures (like Random Forests) without exhausting VRAM. A must be symmetric
    positive-definite.

    Args:
        compute_Av: A closure evaluating (A @ v) for a given vector v.
        b: The right-hand side target (\Phi^* Y). Shape (K,) or (K, d_out) for a
            single system; (batch, K, d_out) for independent systems along dim 0.
        x0: Initial guess for the coefficients.
        tol: Tolerance threshold for the relative residual norm ||r|| / ||b||.
        max_iter: Maximum number of iterations before stopping.

    Returns:
        torch.Tensor: The optimized coefficient vector.
    """
    if x0 is None:
        x = torch.zeros_like(b)
    else:
        x = x0.clone()

    r = b - compute_Av(x)      # initial residual
    p = r.clone()              # initial search direction

    # Inner products reduce over every dim except the batch dim (dim 0) when batched
    is_batched = b.dim() >= 3
    sum_dims = tuple(range(1, b.dim())) if is_batched else tuple(range(b.dim()))

    rsold = torch.sum(r * r, dim=sum_dims, keepdim=True)
    norm_b = torch.sqrt(torch.sum(b * b, dim=sum_dims, keepdim=True))
    norm_b = torch.clamp(norm_b, min=1e-12)   # avoid division by zero when b == 0

    for _ in range(max_iter):
        Ap = compute_Av(p)
        p_Ap = torch.sum(p * Ap, dim=sum_dims, keepdim=True)

        # Freeze systems whose curvature p^T A p has vanished (converged or degenerate)
        valid_mask = p_Ap > 1e-12

        # Step length along the search direction p
        lambda_p = torch.zeros_like(rsold)
        lambda_p[valid_mask] = rsold[valid_mask] / p_Ap[valid_mask]

        x = x + lambda_p * p
        r = r - lambda_p * Ap
        rsnew = torch.sum(r * r, dim=sum_dims, keepdim=True)

        # Stop once every system satisfies the relative residual tolerance
        if torch.all(torch.sqrt(rsnew) / norm_b < tol):
            break

        # Next search direction, A-conjugate to the previous ones
        beta = torch.zeros_like(rsnew)
        beta[valid_mask] = rsnew[valid_mask] / rsold[valid_mask]
        p = r + beta * p
        rsold = rsnew

    return x
```
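How many iterations max_iter must allow depends on conditioning. The classical convergence bound for CG in exact arithmetic ties the error in the \(A\)-norm to \(\kappa = \lambda_{\max}/\lambda_{\min}\). This is one reason a diagonal ridge shift such as the jitter of the direct path also speeds up iterative solves when applied to \(A\). Note that solve_sparse_cg stops on the relative residual \(\lVert r_k \rVert / \lVert b \rVert\), which is related to the \(A\)-norm error but not the same quantity.

```latex
\lVert x_k - x^\star \rVert_A \le 2\left(\frac{\sqrt{\kappa}-1}{\sqrt{\kappa}+1}\right)^{k} \lVert x_0 - x^\star \rVert_A,
\qquad
\lVert e \rVert_A = \sqrt{e^T A e},
\qquad
k \gtrsim \frac{\sqrt{\kappa}}{2}\,\ln\frac{2}{\varepsilon}
```

For example, with \(\kappa = 10^6\) and a target reduction \(\varepsilon = 10^{-4}\), the bound asks for about \(500 \times \ln(2 \times 10^4) \approx 4{,}950\) iterations in the worst case. That exceeds the default max_iter = 1000, although CG often converges faster than this worst-case estimate.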
The Dispatcher Routing Logic
The choice between the chunked direct solver and the iterative matrix-free Conjugate Gradient is made at run time by the smart_solve function in _dispatcher.py.
Before attempting to allocate the global covariance matrix, smart_solve compares the feature dimension \(D\) against the available device memory using can_fit_dense_matrix.
- Safe VRAM (Direct Solver): if direct inversion is deemed safe, smart_solve calls _run_chunked_direct_solver. This function sizes its work predictively (Group Chunking): it budgets up to 80% of the available memory and derives a safe_group_batch from that budget.
- Exhaustion Risk (CG Fallback): if the feature dimension \(D\) exceeds the safety threshold, smart_solve automatically routes to _run_sparse_cg_solver to prevent VRAM exhaustion. This function defines the compute_Av(v) closure on the fly and accumulates the matrix-vector product over manageable data chunks.
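The memory arithmetic behind such a decision can be sketched as below. can_fit_dense_matrix and the chunk-size computation in _run_chunked_direct_solver are not shown in this chapter, so fits_dense and estimate_group_batch are hypothetical stand-ins. The 80% budget comes from the text. The per-group cost model (one design matrix plus one covariance per group) is an assumption. On a GPU, free_bytes could be read from torch.cuda.mem_get_info().

```python
import math

import torch


def dense_system_bytes(d: int, dtype: torch.dtype, batch_size: int = 1) -> int:
    """Bytes needed to hold `batch_size` dense (d, d) matrices of `dtype`."""
    element_size = torch.empty((), dtype=dtype).element_size()
    return batch_size * d * d * element_size


def fits_dense(d, free_bytes, dtype=torch.float32, batch_size=1, budget=0.8):
    """Hypothetical stand-in for can_fit_dense_matrix: compare against a memory budget."""
    return dense_system_bytes(d, dtype, batch_size) <= budget * free_bytes


def estimate_group_batch(d, n_per_group, free_bytes, dtype=torch.float32, budget=0.8):
    """Hypothetical chunk sizing: groups per chunk, assuming each group holds its
    (n_per_group, d) design matrix plus its (d, d) covariance."""
    element_size = torch.empty((), dtype=dtype).element_size()
    per_group = (n_per_group * d + d * d) * element_size
    return max(1, math.floor(budget * free_bytes / per_group))
```

With a 24 GiB budget base, a float32 system with \(D = 10{,}000\) (0.4 GB) passes the check, while \(D = 100{,}000\) (40 GB) would be routed to CG.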
An OOM error may still occur during extreme scaling. To handle this, both solvers run inside a try/except retry loop. On each failure, the loop asks the hardware manager to reduce the safe_group_batch size and clears the cache, then retries until the computation completes.
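The retry loop can be sketched as follows. The halving policy, the function name run_with_backoff and the step callable are hypothetical. The text only states that the framework reduces safe_group_batch and clears the cache on OOM. Note that torch.cuda.empty_cache() releases cached, unused blocks held by PyTorch's allocator; it does not free tensors that are still referenced.

```python
from typing import Callable, Tuple, Type, TypeVar

import torch

T = TypeVar("T")

# torch.cuda.OutOfMemoryError exists in recent PyTorch releases; fall back to RuntimeError otherwise
CUDA_OOM = getattr(torch.cuda, "OutOfMemoryError", RuntimeError)


def run_with_backoff(
    step: Callable[[int], T],
    group_batch: int,
    oom_errors: Tuple[Type[BaseException], ...] = (CUDA_OOM,),
) -> Tuple[T, int]:
    """Hypothetical helper: call step(group_batch), halving group_batch after each
    OOM until the call succeeds. Returns the result and the batch size that worked."""
    while True:
        try:
            return step(group_batch), group_batch
        except oom_errors:
            if group_batch <= 1:
                raise                      # cannot shrink further: surface the OOM
            group_batch = max(1, group_batch // 2)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()   # release cached blocks before retrying
```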
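```python
from typing import List

import torch

# BaseEffect, build_phi_from_effects, can_fit_dense_matrix, _run_chunked_direct_solver
# and _run_sparse_cg_solver are defined elsewhere in the framework.


def smart_solve(
    x_data: torch.Tensor,
    y_data: torch.Tensor,
    effects_list: List[BaseEffect],
    penalty_matrix: torch.Tensor,
    loss_matrix: torch.Tensor,
    num_samples: int
) -> torch.Tensor:
    r"""
    Dynamically routes the mathematical resolution to the optimal solver.

    The feature dimension D is probed on a single slice along the sample axis
    (dim 1). If a dense (D, D) system fits in device memory, the chunked direct
    solver is used; otherwise the matrix-free Conjugate Gradient solver is used.
    """
    run_device = x_data.device

    # Build the design matrix for a single slice only, to read D cheaply
    dummy_x = x_data[:, 0:1, :].to(run_device)
    dummy_phi = build_phi_from_effects(dummy_x, effects_list)
    total_d = dummy_phi.shape[-1]
    del dummy_x, dummy_phi

    # Hardware check: can a dense (D, D) system be allocated on this device?
    is_safe_for_direct_inversion = can_fit_dense_matrix(total_d, run_device, batch_size=1)

    if is_safe_for_direct_inversion:
        return _run_chunked_direct_solver(
            x_data, y_data, effects_list, penalty_matrix, loss_matrix, num_samples, total_d
        )

    print(f"Notice: Feature dimension D={total_d} is massive.")
    print("Routing to matrix-free Conjugate Gradient (CG) solver to prevent VRAM exhaustion...")
    return _run_sparse_cg_solver(
        x_data, y_data, effects_list, penalty_matrix, loss_matrix, num_samples
    )
```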