Hardware Memory Dispatch & Anti-OOM Systems
Navigation:
Theory introduction: See the Intro
Related mathematical theory: See the Mathematical Theory
This chapter describes the hardware-aware engineering of the TAM framework. It shows how the theoretical complexity bounds established in Computational Complexity are enforced at runtime by PyTorch-based memory oracles that prevent Out-Of-Memory (OOM) crashes during large tensor computations.
The Hardware Abstraction Layer (HAL)
The foundation of the framework’s stability is the HardwareManager (instantiated as the hw singleton) located in hardware.py.
Before any mathematical operation begins, this layer probes the host machine and selects the most capable compute backend available, in descending order of preference: NVIDIA CUDA, Apple MPS, Intel XPU, and finally the host CPU.
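A minimal sketch of that preference order follows. It assumes nothing about how HardwareManager is actually written: pick_backend and torch_probes are hypothetical helpers, and keeping the ordering separate from the availability probes lets the ordering be tested on a machine with no accelerator.

```python
from typing import Callable, Dict, Tuple

# Preference order described above; the CPU is the unconditional fallback.
BACKEND_ORDER: Tuple[str, ...] = ("cuda", "mps", "xpu", "cpu")


def pick_backend(probes: Dict[str, Callable[[], bool]]) -> str:
    """Return the first accelerator whose probe reports True, else "cpu"."""
    for name in BACKEND_ORDER[:-1]:
        probe = probes.get(name)
        if probe is not None and probe():
            return name
    return "cpu"  # always available


def torch_probes() -> Dict[str, Callable[[], bool]]:
    """Availability checks built on PyTorch (imported lazily)."""
    import torch

    return {
        "cuda": torch.cuda.is_available,
        # hasattr guards cover PyTorch builds that predate MPS or XPU support
        "mps": lambda: hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        "xpu": lambda: hasattr(torch, "xpu") and torch.xpu.is_available(),
    }
```

With PyTorch installed, `torch.device(pick_backend(torch_probes()))` yields the selected device.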
Crucially, the HardwareManager also manages disaster recovery via the handle_oom method. When an operation exhausts device memory, the caller catches the error and invokes this method, which purges the allocator cache through self.empty_cache() (torch.cuda.empty_cache() on CUDA) and halves the batch size so the computation can be retried. If the batch size is already 1, it either falls back to the CPU (when allow_cpu_fallback is enabled) or raises a MemoryError.
```python
# Module-level imports this HardwareManager method relies on
from typing import Tuple
import warnings

import torch


def handle_oom(
    self,
    current_batch: int,
    context: str,
    allow_cpu_fallback: bool = False
) -> Tuple[int, torch.device]:
    """
    Handles Out-Of-Memory events by clearing cache and reducing batch sizes.

    Args:
        current_batch: The batch size that triggered the OOM.
        context: String describing the operation that failed.
        allow_cpu_fallback: Whether to attempt execution on CPU if batch hits 1.

    Returns:
        Tuple[int, torch.device]: The new safe batch size and compute device.
    """
    # Release cached allocator blocks held by the active backend
    self.empty_cache()

    if current_batch <= 1:
        # The batch cannot be split further: move to CPU if allowed, else give up
        if allow_cpu_fallback and self.backend != "cpu":
            warnings.warn(
                f"[OOM] VRAM exhausted during {context}. "
                "Falling back to CPU execution."
            )
            return current_batch, torch.device("cpu")
        else:
            raise MemoryError(
                f"Memory exhausted during {context} on {self.backend}. "
                "Batch size 1 is too large to process."
            )

    # Halve the workload and retry on the same device
    new_batch = max(1, current_batch // 2)
    warnings.warn(
        f"[OOM] Alert during {context} on {self.backend}: "
        f"Reducing batch from {current_batch} to {new_batch}."
    )
    return new_batch, self.device
```
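The halving rule in handle_oom, new_batch = max(1, current_batch // 2), bounds how many retries a run can need. The helper below (a hypothetical name, not part of hardware.py) replays the rule without any device involved:

```python
from typing import List


def halving_schedule(start_batch: int) -> List[int]:
    """Batch sizes visited by repeated OOMs under handle_oom's halving rule."""
    sizes = [start_batch]
    while sizes[-1] > 1:
        sizes.append(max(1, sizes[-1] // 2))
    return sizes


print(halving_schedule(48))  # [48, 24, 12, 6, 3, 1]
```

A starting batch of \(n\) reaches 1 after \(\lfloor \log_2 n \rfloor\) halvings (nine for \(n = 1000\)); only an OOM at batch size 1 reaches the CPU-fallback or MemoryError branch.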
The Memory Oracle and Safe Chunking
To avoid invoking the OOM handler in the first place, the framework uses _memory.py as a predictive memory oracle.
Before _dispatcher.py allocates the global covariance matrix \(\Phi^T \Phi\), it calls can_fit_dense_matrix and checks the predicted byte footprint against a strict Multi-Tiered Memory Waterfall:
The Dense Inversion Limit: The globally exact \(\mathcal{O}(D^3)\) direct solver is authorized only if the estimated footprint of the dense inversion (the \(D \times D\) Float64 matrix at 8 bytes per entry, multiplied by a 4× safety factor for solver workspace) is below \(90\%\) of available memory and the primal dimension satisfies \(D \le 7500\). If either threshold is breached, the workload is routed to the matrix-free Conjugate Gradient solver.
Standard Group Chunking: For static data processing, the oracle bounds spatial tensor chunks to \(90\%\) of free VRAM (or \(70\%\) of system RAM), maximizing GPU occupancy without triggering PyTorch out-of-memory errors.
Sliding Window Buffer (AdaptiveTAM): Because online learning models require recursive history tracking, the oracle enforces a stricter \(80\%\) VRAM limit (\(60\%\) of CPU RAM) to preserve buffer space for continuous state-space updates.
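The chunking percentages can be read as a lookup from workload type and device to a byte budget. The sketch below assumes the budget is a flat fraction of currently free memory and that a chunk is sized in rows of fixed byte width; TIER_PERCENT, chunk_budget_bytes, and rows_per_chunk are hypothetical names, not the _memory.py API.

```python
from typing import Dict, Tuple

# (GPU %, CPU %) per workload, taken from the tiers listed above.
TIER_PERCENT: Dict[str, Tuple[int, int]] = {
    "static": (90, 70),          # Standard Group Chunking
    "sliding_window": (80, 60),  # AdaptiveTAM sliding-window buffer
}


def chunk_budget_bytes(free_bytes: int, workload: str, on_gpu: bool) -> int:
    """Bytes a chunk may occupy for this workload and device."""
    gpu_pct, cpu_pct = TIER_PERCENT[workload]
    pct = gpu_pct if on_gpu else cpu_pct
    return free_bytes * pct // 100  # integer math avoids float rounding


def rows_per_chunk(free_bytes: int, bytes_per_row: int, workload: str, on_gpu: bool) -> int:
    """How many fixed-width rows fit in the budget (at least one, so work progresses)."""
    return max(1, chunk_budget_bytes(free_bytes, workload, on_gpu) // bytes_per_row)
```

For example, with 1 GB free on a GPU and float64 rows of 2,000 features (16,000 bytes each), the static tier allows 56,250 rows per chunk. The floor of one row means an oversized single row is left to the OOM handler rather than rejected up front.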
```python
def can_fit_dense_matrix(
    total_d: int,
    device: torch.device,
    batch_size: int = 1,
    dtype_size: int = 8,  # 8 bytes for float64
    safety_factor: float = 4.0,
    max_safe_d: int = 7500
) -> bool:
    r"""
    Evaluates if a dense Covariance Matrix inversion can safely execute in VRAM.

    Inverting a matrix via LU or Cholesky decomposition requires allocating the
    base matrix, the target vectors, and substantial temporary workspace memory
    for the linear algebra backend (LAPACK for CPU, cuSOLVER/MAGMA for GPU).

    Args:
        total_d (int): The total feature dimension (D) of the Primal space.
        device (torch.device): The compute device.
        batch_size (int): The number of independent systems being solved simultaneously.
        dtype_size (int): Bytes per element (8 for float64).
        safety_factor (float): Multiplier accounting for backend workspace overhead.
        max_safe_d (int): Hard cap on D for numerical stability and direct-solve time.

    Returns:
        bool: True if the exact direct solver is safe to use; False otherwise.
    """
    # Hard threshold for numerical stability and acceptable direct-inversion compute time
    if total_d > max_safe_d:
        return False

    # Compute the exact byte footprint of a (D x D) dense matrix
    matrix_bytes = batch_size * total_d * total_d * dtype_size

    # Estimate total memory required for the solver operation
    required_bytes = matrix_bytes * safety_factor

    # Query free memory through the HardwareManager singleton.
    # Note: `device` is not consulted here; hw reports memory for its active backend.
    available_bytes = hw.get_available_memory()

    # Require that the operation takes less than 90% of the currently free memory
    return required_bytes < (available_bytes * 0.9)
```
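Once free memory is known, the decision in can_fit_dense_matrix is plain arithmetic. The sketch below re-implements the same rule with available_bytes passed in explicitly instead of queried from hw, and with the 4× safety factor as an integer so the 90% comparison is exact; dense_solve_fits and min_free_bytes are hypothetical names.

```python
def dense_solve_fits(
    total_d: int,
    available_bytes: int,
    batch_size: int = 1,
    dtype_size: int = 8,
    safety_factor: int = 4,
    max_safe_d: int = 7500,
) -> bool:
    """Mirror of can_fit_dense_matrix's rule with free memory passed in explicitly."""
    if total_d > max_safe_d:
        return False
    required = batch_size * total_d * total_d * dtype_size * safety_factor
    # required < 0.9 * available, rearranged into exact integer arithmetic
    return 10 * required < 9 * available_bytes


def min_free_bytes(total_d: int, safety_factor: int = 4, dtype_size: int = 8) -> int:
    """Smallest free-memory figure for which dense_solve_fits is True (for D <= cap)."""
    required = total_d * total_d * dtype_size * safety_factor
    return 10 * required // 9 + 1


# At the D = 7500 cap: 7500^2 * 8 B = 450 MB matrix, x4 workspace = 1.8 GB,
# so just over 2.0 GB (about 1.86 GiB) must be free.
print(min_free_bytes(7500))  # 2000000001
```

Because the footprint grows as \(D^2\), halving \(D\) cuts the requirement by a factor of four; at \(D = 5000\) roughly 0.89 GB of free memory suffices.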
The OOM Safety Net in the Dispatcher
Despite these predictive checks, memory spikes can still occur during tensor decompositions or heavily batched operations. _dispatcher.py therefore wraps the vulnerable linear algebra blocks in try/except retry loops.
If a torch.OutOfMemoryError is caught during the chunked processing of a group, the loop immediately invokes hw.handle_oom() and re-attempts the same calculation with the halved batch size returned by the hardware manager. The halving is bounded: once the batch size reaches 1, handle_oom either switches to the CPU (when allow_cpu_fallback is enabled) or raises a MemoryError instead of retrying again.
Before any chunk is processed, smart_solve measures the feature dimension \(D\) with a one-time-step probe and uses can_fit_dense_matrix to choose between the direct and Conjugate Gradient solvers:
```python
from typing import List

import torch

# BaseEffect, build_phi_from_effects, can_fit_dense_matrix and the two solver
# helpers are framework internals defined elsewhere in the package.


def smart_solve(
    x_data: torch.Tensor,
    y_data: torch.Tensor,
    effects_list: List[BaseEffect],
    penalty_matrix: torch.Tensor,
    loss_matrix: torch.Tensor,
    num_samples: int
) -> torch.Tensor:
    r"""
    Dynamically routes the mathematical resolution to the optimal solver.
    """
    run_device = x_data.device

    # Probe the feature dimension D with a single time step instead of the full data
    dummy_x = x_data[:, 0:1, :].to(run_device)
    dummy_phi = build_phi_from_effects(dummy_x, effects_list)
    total_d = dummy_phi.shape[-1]
    # Release the probe tensors before the solver allocates
    del dummy_x, dummy_phi

    is_safe_for_direct_inversion = can_fit_dense_matrix(total_d, run_device, batch_size=1)

    if is_safe_for_direct_inversion:
        return _run_chunked_direct_solver(
            x_data, y_data, effects_list, penalty_matrix, loss_matrix, num_samples, total_d
        )
    else:
        print(f"Notice: Feature dimension D={total_d} is massive.")
        print("Routing to matrix-free Conjugate Gradient (CG) solver to prevent VRAM exhaustion...")
        return _run_sparse_cg_solver(
            x_data, y_data, effects_list, penalty_matrix, loss_matrix, num_samples
        )
```
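The try/except retry loop used by the dispatcher can be sketched independently of it. This is a minimal version under stated assumptions: run_with_oom_retry and halve_or_fail are hypothetical names, halve_or_fail is a simplified stand-in for hw.handle_oom (no cache purge, no CPU fallback), and the OOM exception type is passed in so the loop can be exercised without a GPU.

```python
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def halve_or_fail(batch: int, context: str) -> int:
    """Simplified stand-in for hw.handle_oom: halve, or give up at batch size 1."""
    if batch <= 1:
        raise MemoryError(f"Memory exhausted during {context}: batch size 1 is too large.")
    return max(1, batch // 2)


def run_with_oom_retry(
    compute: Callable[[int], T],
    batch: int,
    on_oom: Callable[[int, str], int],
    oom_error: Type[BaseException],
    context: str = "chunked solve",
) -> Tuple[T, int]:
    """Call compute(batch), shrinking the batch after each OOM until it succeeds.

    Terminates because on_oom strictly shrinks the batch and raises at size 1.
    """
    while True:
        try:
            return compute(batch), batch
        except oom_error:
            # Leave the except block before recovering, so the traceback (and
            # the tensors held by the failed frame) can be released first.
            pass
        batch = on_oom(batch, context)
```

With PyTorch, pass `oom_error=torch.OutOfMemoryError` (`torch.cuda.OutOfMemoryError` on older releases). The recovery call sits after the except block on purpose: while the except clause is active, the exception's traceback keeps the failed frame's tensors alive, so emptying the cache there would not reclaim them.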
Sparse Routing and In-Place Memory Tricks
For tree-based structures like Random Forests (Random Binning Features), each tree effect contributes one feature per leaf across all of its trees, so the feature dimension \(D\) expands drastically and creates severe matrix parallelization bottlenecks [Wu et al., 2016]. The framework applies low-level PyTorch memory optimizations inside _tree.py to keep these models from exhausting server memory upon instantiation:
In-Place Bounding: The one-hot leaf allocations naturally produce large tensors. Instead of allocating a second, normalized tensor to apply the \(1/\sqrt{B}\) RKHS bound, the framework scales the feature tensor in place with .mul_(self.scale). This avoids materializing a redundant copy of the feature tensor, which for large batches can amount to a gigabyte or more of VRAM.
```python
def build_feature_map(self, x_col: torch.Tensor) -> torch.Tensor:
    r"""
    Projects inputs into the sparse bit-string Primal representation on the input's device.
    Computes the mutually exclusive indicator functions for every tree simultaneously.
    """
    # 1. Shape normalization: produce x_in with shape [*batch_shape, n_features]
    if x_col.dim() == 1:
        # 1D: from the OOD wrapper (univariate) -> [N_ood]
        batch_shape = list(x_col.shape)
        x_in = x_col.unsqueeze(-1)
    elif x_col.dim() == 2:
        if x_col.shape[-1] == len(self.input_features):
            # Ambiguous 2D input: the te() dummy pass [1, 1] or the OOD wrapper [N_ood, Features]?
            if self.split_features is None and x_col.shape == (1, 1) and len(self.input_features) == 1:
                batch_shape = list(x_col.shape)
                x_in = x_col.unsqueeze(-1)
            else:
                batch_shape = list(x_col.shape[:-1])
                x_in = x_col
        else:
            # 2D: from a regular te() pass [Batch, Time]
            batch_shape = list(x_col.shape)
            x_in = x_col.unsqueeze(-1)
    else:
        # 3D+: from _factory.py [Batch, Time, Features]
        batch_shape = list(x_col.shape[:-1])
        x_in = x_col

    # 2. Dummy-pass detection (memory probe): the forest is initialized on first use
    is_dummy = False
    if self.split_features is None:
        is_dummy = batch_shape == [1, 1]
        self._init_forest(x_in, is_dummy)
    if is_dummy:
        # A shape probe only needs the output shape, so skip tree evaluation
        return torch.zeros(
            *batch_shape, self.total_leaves,
            device=x_in.device, dtype=torch.get_default_dtype()
        )

    # 3. Vectorized evaluation: gather each tree's split feature(s), compare to thresholds
    x_expanded = x_in.unsqueeze(-2).unsqueeze(-2)  # [*batch_shape, 1, 1, n_features]
    if self.is_oblivious_binary:
        # One split per depth level; per-depth decisions are combined via depth_multipliers
        index_shape = batch_shape + [self.n_trees, self.max_depth, 1]
        split_feat_expanded = self.split_features.view(
            *([1] * len(batch_shape)), self.n_trees, self.max_depth, 1
        ).expand(*index_shape)
        input_expanded = x_expanded.expand(
            *batch_shape, self.n_trees, self.max_depth, x_in.shape[-1]
        )
        x_splits = torch.gather(input_expanded, dim=-1, index=split_feat_expanded).squeeze(-1)
        split_decisions = (x_splits > self.split_thresholds).to(torch.long)
        leaf_indices = torch.sum(split_decisions * self.depth_multipliers, dim=-1)
    else:
        # One split feature per tree; the leaf index counts the thresholds exceeded
        index_shape = batch_shape + [self.n_trees, 1, 1]
        split_feat_expanded = self.split_features.view(
            *([1] * len(batch_shape)), self.n_trees, 1, 1
        ).expand(*index_shape)
        input_expanded = x_expanded.expand(
            *batch_shape, self.n_trees, 1, x_in.shape[-1]
        )
        x_splits = torch.gather(input_expanded, dim=-1, index=split_feat_expanded).squeeze(-1).squeeze(-1)
        x_splits_expand = x_splits.unsqueeze(-1)
        thresh_expand = self.split_thresholds.view(*([1] * len(batch_shape)), self.n_trees, self.n_splits_per_tree)
        leaf_indices = torch.sum((x_splits_expand > thresh_expand).to(torch.long), dim=-1)
    one_hot_bins = torch.nn.functional.one_hot(leaf_indices, num_classes=self.leaves_per_tree)

    # 4. Final output tensor with shape [*batch_shape, total_leaves]
    phi_tensor = one_hot_bins.view(*batch_shape, self.total_leaves).to(torch.float64)
    # Apply the RKHS normalization bound in place (no second tensor is allocated)
    phi_tensor.mul_(self.scale)
    return phi_tensor
```
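To check the in-place claim, the sketch below rebuilds a tiny feature tensor the way build_feature_map does (one_hot, view, then .to(torch.float64)) and compares storage pointers with Tensor.data_ptr(). It assumes that \(B\) in the \(1/\sqrt{B}\) bound is the number of trees, which the text above does not state; the sizes and leaf indices are made up.

```python
import math

import torch
import torch.nn.functional as F

n_trees, leaves_per_tree = 4, 8           # illustrative sizes
leaf_indices = torch.tensor([[0, 3, 7, 2],
                             [5, 5, 1, 0]])  # [batch=2, n_trees], int64

one_hot = F.one_hot(leaf_indices, num_classes=leaves_per_tree)   # int64 [2, 4, 8]
phi = one_hot.view(2, n_trees * leaves_per_tree).to(torch.float64)  # new float64 tensor

ptr_before = phi.data_ptr()
phi.mul_(1.0 / math.sqrt(n_trees))        # in place: writes into phi's own storage
in_place_kept_storage = phi.data_ptr() == ptr_before

copy = phi * 1.0                          # out of place: a second allocation
copy_is_new = copy.data_ptr() != ptr_before

bytes_saved = phi.numel() * phi.element_size()  # size of the avoided copy: 2 * 32 * 8
row_norms = phi.norm(dim=-1)              # one active leaf per tree, each scaled by 1/sqrt(B)
```

Under the stated assumption, every input activates exactly one leaf per tree, so each row of phi has unit L2 norm after scaling. Note that .to(torch.float64) itself allocates (the one-hot output is int64); mul_ only avoids a further copy on top of it.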
Sparse COO Tensors: Stored as a dense diagonal matrix, a penalty over 7,000 algorithmic leaves would occupy a contiguous \(D \times D\) block (\(7000^2 \times 8\) bytes, about 392 MB in float64) even though only the 7,000 diagonal entries are nonzero. The TreeEffect class therefore constructs its structural penalty exclusively as a sparse coordinate-format object (torch.sparse_coo_tensor).
Only the \(D\) diagonal values and a \(2 \times D\) index tensor are stored, so the penalty's footprint grows as \(\mathcal{O}(D)\) rather than \(\mathcal{O}(D^2)\), and the global linear algebra engine must handle this term with sparse routines instead of dense ones.
```python
import contextlib  # nullcontext stands in when the invariant-check context is unavailable


def build_penalty_matrix(self) -> torch.Tensor:
    r"""
    Constructs the optimal structural penalty sub-matrix.

    If sparsity_alpha > 0, the penalty dynamically adapts to the empirical
    data density of each specific leaf. Starved leaves receive massive penalties
    to prevent overfitting, while dense leaves receive standard shrinkage.
    """
    if self.sparsity_alpha > 0.0 and hasattr(self, 'empirical_counts'):
        # Anisotropic penalty based on empirical per-leaf data counts
        C_i = self.empirical_counts.to(TORCH_DEVICE, dtype=torch.float64)
        C_bar = C_i.mean()
        epsilon = 1.0  # Smoothing constant: keeps the base positive for empty leaves
        # Leaves with C_i + epsilon < C_bar (starved) get factors > 1.0;
        # leaves with C_i + epsilon > C_bar (dense) get factors < 1.0
        penalty_scaling = ((C_i + epsilon) / C_bar) ** (-self.sparsity_alpha)
        diag_vals = self.lambda_p * penalty_scaling
    else:
        # Isotropic penalty (sparsity_alpha == 0 or no empirical counts)
        diag_vals = torch.full(
            (self.total_leaves,), self.lambda_p,
            device=TORCH_DEVICE, dtype=torch.get_default_dtype()
        )

    # Build the diagonal directly as a sparse COO tensor to save VRAM
    indices = torch.arange(self.total_leaves, device=TORCH_DEVICE)
    indices = torch.stack([indices, indices], dim=0)

    # Explicitly disable sparse-tensor invariant checks where the context manager exists
    if hasattr(torch.sparse, 'check_sparse_tensor_invariants'):
        invariant_ctx = torch.sparse.check_sparse_tensor_invariants(False)
    else:
        invariant_ctx = contextlib.nullcontext()
    with invariant_ctx:
        return torch.sparse_coo_tensor(
            indices,
            diag_vals,
            size=(self.total_leaves, self.total_leaves),
            device=TORCH_DEVICE
        )
```
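The sketch below mirrors the two halves of build_penalty_matrix on toy data: the anisotropic scaling \(\lambda_p \left(\frac{C_i + \epsilon}{\bar{C}}\right)^{-\alpha}\) and a diagonal stored as a COO tensor. anisotropic_diag, sparse_diagonal, and coo_bytes are illustrative helpers, the leaf counts are made up, and everything runs on the CPU instead of TORCH_DEVICE.

```python
import torch


def anisotropic_diag(counts: torch.Tensor, lambda_p: float, alpha: float,
                     epsilon: float = 1.0) -> torch.Tensor:
    """Per-leaf penalties lambda_p * ((C_i + eps) / mean(C)) ** (-alpha)."""
    c = counts.to(torch.float64)
    return lambda_p * ((c + epsilon) / c.mean()) ** (-alpha)


def sparse_diagonal(values: torch.Tensor) -> torch.Tensor:
    """D x D diagonal as a COO tensor: D values plus a [2, D] index tensor."""
    d = values.shape[0]
    idx = torch.arange(d)
    return torch.sparse_coo_tensor(torch.stack([idx, idx]), values, size=(d, d)).coalesce()


def coo_bytes(sp: torch.Tensor) -> int:
    """Storage held by a coalesced COO tensor's values and indices."""
    vals, idx = sp.values(), sp.indices()
    return vals.numel() * vals.element_size() + idx.numel() * idx.element_size()


counts = torch.tensor([0, 2, 10, 100])    # mean count is 28
diag = anisotropic_diag(counts, lambda_p=1.0, alpha=1.0)
# factors 28/1, 28/3, 28/11, 28/101: starved leaves are penalized hardest

P = sparse_diagonal(diag)
v = torch.ones(4, dtype=torch.float64)
Pv = torch.sparse.mm(P, v.unsqueeze(1)).squeeze(1)  # equals diag * v

d = 7000
dense_bytes = d * d * 8                   # dense float64 D x D matrix
sparse_bytes = coo_bytes(sparse_diagonal(torch.ones(d, dtype=torch.float64)))
```

For \(D = 7000\) the COO form holds 168,000 bytes (7,000 float64 values plus 14,000 int64 indices) against 392,000,000 bytes for the dense matrix. The .coalesce() call is required before values() and indices() can be read; build_penalty_matrix itself returns the uncoalesced tensor.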