OPERA GPU Tensor Batching
Navigation:
Theory introduction: See the Intro
Related mathematical theory: See the Mathematical Theory
This chapter explores the software engineering behind the OperaTAM module. The core challenge in online aggregation is the inherently sequential time loop. Whereas standard Primal solvers process an entire time-series matrix in one vectorized pass, online learning requires the weights at \(t+1\) to depend on the observations up to and including time \(t\), so the time steps cannot be computed independently of one another.
The Bottleneck of Sequential Python Loops
If the aggregation loop \(t=1, \dots, T\) were written in plain Python for a dataset containing thousands of time series (groups), performance would collapse.
At every time step, and for every group, Python would have to dispatch tiny array operations to the GPU one at a time. Each dispatch pays a fixed interpreter and kernel-launch cost that dwarfs the actual arithmetic, so CPU-GPU launch overhead becomes the bottleneck and the GPU's parallel CUDA cores sit mostly idle.
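To make the difference concrete, here is a CPU-only NumPy sketch (not OperaTAM code) of the same exponentially weighted update written two ways. The function names and the choice of EWA with square loss are illustrative assumptions. The first form performs \(B \times T\) small updates in Python; the second performs only \(T\) updates, each on a whole (B, K) array. On a GPU, the first form would issue kernel launches for every (group, step) pair, while the second issues them once per step for all groups.

```python
import numpy as np

def ewa_per_group(X, Y, eta):
    """Naive form: one Python iteration per (group, time step) pair."""
    B, T, K = X.shape
    preds = np.zeros((B, T))
    for b in range(B):                          # loop over groups...
        cum = np.zeros(K)
        for t in range(T):                      # ...and over time
            logits = -eta * cum
            w = np.exp(logits - logits.max())
            w /= w.sum()
            preds[b, t] = w @ X[b, t]
            cum += (X[b, t] - Y[b, t, 0]) ** 2
    return preds

def ewa_batched(X, Y, eta):
    """Batched form: one iteration per time step, all groups at once."""
    B, T, K = X.shape
    preds = np.zeros((B, T))
    cum = np.zeros((B, K))
    for t in range(T):                          # only the time loop remains
        logits = -eta * cum
        w = np.exp(logits - logits.max(axis=1, keepdims=True))
        w /= w.sum(axis=1, keepdims=True)
        preds[:, t] = np.sum(w * X[:, t, :], axis=1)
        cum += (X[:, t, :] - Y[:, t, :]) ** 2   # (B, K) - (B, 1) broadcasts
    return preds

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 30, 4))                # (B, T, K)
Y = rng.normal(size=(50, 30, 1))                # (B, T, 1)
# Both forms compute the same predictions; only the loop structure differs.
assert np.allclose(ewa_per_group(X, Y, 0.5), ewa_batched(X, Y, 0.5))
```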
3D Tensor Batching & Padding
To saturate the GPU despite this sequential constraint, OperaTAM reshapes the data into a dense 3D tensor before execution.
Architectural Choice (Uniform Temporal Alignment):
During the preparation phase in the predict_online method:
1. The framework uses the internal _balance_groups utility (with method="fill") to pad any asynchronous or incomplete time series with artificial “fake dates”. This padding guarantees that every group has exactly the same temporal length \(T\).
2. The data is stacked into contiguous tensors: X_tensor_3d with shape (Groups, Time, Experts), i.e. \((B \times T \times K)\) where \(B\) is the number of groups, and Y_tensor_3d with shape \((B \times T \times 1)\).
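The padding and stacking steps can be sketched as follows. This is not the internal _balance_groups utility: the helper pad_and_stack, its dictionary input format, the zero fill value, and the choice to pad at the end of each series are hypothetical. A boolean mask records which cells hold real observations, and the groups are stacked into arrays of shape (B, T, K) and (B, T, 1).

```python
import numpy as np

def pad_and_stack(groups, fill=0.0):
    """Hypothetical helper: pad variable-length groups to a common length.

    groups: dict mapping group id -> (X, y), X of shape (T_g, K), y of shape (T_g,).
    Returns X3d (B, T, K), Y3d (B, T, 1), a boolean mask (B, T) marking real
    observations, and the group ids in stacking order.
    """
    ids = list(groups)
    T = max(len(groups[g][1]) for g in ids)     # common temporal length
    K = groups[ids[0]][0].shape[1]
    B = len(ids)
    X3d = np.full((B, T, K), fill)
    Y3d = np.full((B, T, 1), fill)
    mask = np.zeros((B, T), dtype=bool)
    for b, g in enumerate(ids):
        X, y = groups[g]
        n = len(y)
        X3d[b, :n] = X                          # real observations first...
        Y3d[b, :n, 0] = y
        mask[b, :n] = True                      # ...padding ("fake dates") after
    return X3d, Y3d, mask, ids

groups = {
    "store_a": (np.ones((5, 3)), np.arange(5.0)),
    "store_b": (np.ones((3, 3)), np.arange(3.0)),
}
X3d, Y3d, mask, ids = pad_and_stack(groups)
print(X3d.shape, Y3d.shape, mask.sum(axis=1))
```

Appending the padding after the real observations is an assumption of this sketch; with a causal loop, steps placed after the real data cannot affect the weights used at the real steps.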
By aligning the temporal dimension across all groups, the framework turns \(B\) independent sequential loops into a single loop over \(T\) in which every step is one batched operation across all groups. Once the compiled loop has finished simulating the historical timeframe, the framework re-associates the outputs with their original indices, and the _cleanup_dummies step uses a boolean mask to strip away the artificial padding.
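The reverse step can be sketched in the same spirit: after the batched loop, keep only the cells the mask marks as real and attach the original group ids. The helper strip_padding and its long-format output are hypothetical; this illustrates the boolean-mask idea behind _cleanup_dummies, not its actual code.

```python
import numpy as np

def strip_padding(preds, mask, ids):
    """Hypothetical helper: drop padded cells from (B, T) outputs.

    Returns a list of (group_id, time_index, prediction) tuples for real cells.
    """
    b_idx, t_idx = np.nonzero(mask)             # coordinates of real observations
    return [(ids[b], int(t), float(preds[b, t])) for b, t in zip(b_idx, t_idx)]

preds = np.arange(10.0).reshape(2, 5)           # stand-in for the loop's (B, T) output
mask = np.array([[True] * 5, [True, True, True, False, False]])
rows = strip_padding(preds, mask, ["store_a", "store_b"])
print(len(rows), rows[-1])
```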
TorchScript C++ Compilation
Once the 3D tensors are assembled, they are passed to the algorithm-specific loops _mlpol_loop_optimized_3d or _ewa_loop_optimized_3d.
Architectural Choice (Bypassing the GIL):
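To remove Python interpreter (and Global Interpreter Lock, GIL) overhead from the sequential iteration over \(T\), these functions are decorated with @torch.jit.script. PyTorch compiles the entire sequential logic, including the regret tracking, the adaptive learning rates, and the weight normalization, into a single TorchScript graph that is executed by PyTorch's C++ runtime, so the loop never returns to the Python interpreter between time steps. Each step still launches GPU kernels, but every launch now processes all \(B\) groups at once.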
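A minimal illustration of the pattern, assuming only that PyTorch is installed; cumulative_loss_loop is a hypothetical example, not part of OperaTAM. The Python for loop over \(T\) is compiled by @torch.jit.script, while each iteration operates on a whole (B, K) slice, exactly the structure used by the two OperaTAM loops.

```python
import torch


@torch.jit.script
def cumulative_loss_loop(losses: torch.Tensor) -> torch.Tensor:
    """Running sum over the time axis of a (B, T, K) tensor, written as a loop."""
    B = losses.size(0)
    T = losses.size(1)
    K = losses.size(2)
    out = torch.zeros_like(losses)
    running = torch.zeros([B, K], dtype=losses.dtype, device=losses.device)
    for t in range(T):                          # sequential loop, compiled by TorchScript
        running = running + losses[:, t, :]    # one batched op over all groups
        out[:, t, :] = running
    return out


x = torch.rand(4, 6, 3)
# The compiled loop matches PyTorch's built-in cumulative sum along time.
assert torch.allclose(cumulative_loss_loop(x), torch.cumsum(x, dim=1))
```

The compiled intermediate representation can be inspected through cumulative_loss_loop.graph.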
Numerical Stability Engineering
Inside these compiled loops, two safeguards prevent floating-point overflow, underflow and NaN propagation, failures that are especially likely in the single-precision (float32) arithmetic commonly used on GPUs:
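1. Scale Invariance (MLpol): In _mlpol_loop_optimized_3d, the polynomial minimax strategy accumulates regrets and uses their squares to adapt per-expert learning rates. Under square loss, the instantaneous regret grows with the square of the data magnitude, so its square grows with the fourth power; for large industrial targets this quickly overflows float32 and destabilizes the learning rates. To prevent this, the expert predictions and the targets are divided by the same scale_factor (the per-group maximum absolute target value), producing X_scaled and Y_scaled before the loop starts. The mixed predictions are still computed from the unscaled experts, so the outputs stay in the original units.
A torch.where mask replaces the scale factor of any group whose targets are all exactly zero with 1.0, avoiding a division by zero that would otherwise fill that group with NaN or inf values.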
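For reference, the per-group update performed by the MLpol loop can be restated as follows, read off the code for a single group with the group index dropped. Here \(s\) is the scale_factor, \(g_t\) is the loss gradient, and \(M_t\) is the running maximum squared regret stored in max_sq_regrets; these symbols are notation chosen for this summary, not names used by the code. Initial values are \(w_{k,1} = 1/K\), \(R_{k,0} = 0\), \(M_0 = 0\) and \(\eta_{k,0} = 2^{-20}\).

\[
\begin{aligned}
s &= \max_{t} |y_t| \;\; (s := 1 \text{ if } s = 0), \qquad \tilde{x}_{k,t} = x_{k,t}/s, \qquad \tilde{y}_t = y_t/s,\\
\hat{y}_t &= \sum_{k=1}^{K} w_{k,t}\, x_{k,t}, \qquad \tilde{\hat{y}}_t = \sum_{k=1}^{K} w_{k,t}\, \tilde{x}_{k,t} = \hat{y}_t / s,\\
g_t &= 2\,(\tilde{\hat{y}}_t - \tilde{y}_t) \;\text{(square)} \quad\text{or}\quad g_t = \operatorname{sign}(\tilde{\hat{y}}_t - \tilde{y}_t) \;\text{(absolute)},\\
r_{k,t} &= g_t\,(\tilde{\hat{y}}_t - \tilde{x}_{k,t}), \qquad R_{k,t} = R_{k,t-1} + r_{k,t}, \qquad M_t = \max\big(M_{t-1}, \max_{j} r_{j,t}^2\big),\\
\eta_{k,t} &= \big(\eta_{k,t-1}^{-1} + r_{k,t}^2 + M_t - M_{t-1}\big)^{-1},\\
w_{k,t+1} &= \frac{\eta_{k,t}\,(R_{k,t})_+}{\sum_{j=1}^{K} \eta_{j,t}\,(R_{j,t})_+} \qquad \big(w_{k,t+1} = 1/K \text{ if the denominator is } 0\big).
\end{aligned}
\]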
```python
from typing import Tuple

import torch


@torch.jit.script
def _mlpol_loop_optimized_3d(
    experts_tensor: torch.Tensor,
    y_true: torch.Tensor,
    loss_type: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compiled TorchScript loop for MLpol (Polynomial Minimax Strategy).

    Processes all groups in parallel in a 3D tensor format (Batch, Time, Experts)
    to saturate GPU cores. Time steps still run sequentially, but inside
    TorchScript, which eliminates per-step Python overhead.

    Args:
        experts_tensor: Predictions from experts, shape (B, T, K).
        y_true: Ground truth targets, shape (B, T, 1).
        loss_type: The loss function to use ('square' or 'absolute').

    Returns:
        Tuple containing the mixed predictions (B, T) and weights history (B, T, K).
    """
    B, T, K = experts_tensor.shape
    dtype = experts_tensor.dtype
    device = experts_tensor.device

    # Per-group scaling to prevent numerical instability; shape (B, 1, 1)
    scale_factor = torch.max(torch.abs(y_true), dim=1, keepdim=True)[0]
    # Prevent division by zero if a group's target is perfectly zero
    scale_factor = torch.where(
        scale_factor == 0.0,
        torch.tensor(1.0, dtype=dtype, device=device),
        scale_factor
    )
    X_scaled = experts_tensor / scale_factor
    Y_scaled = y_true / scale_factor

    # Pre-allocate outputs to avoid repetitive memory allocation
    weights_history = torch.zeros((B, T, K), dtype=dtype, device=device)
    predictions = torch.zeros((B, T), dtype=dtype, device=device)

    # Initialize state variables for all groups
    w = torch.ones((B, K), dtype=dtype, device=device) / float(K)
    cum_regrets = torch.zeros((B, K), dtype=dtype, device=device)
    # Running maximum of squared regrets (identical across the K columns)
    max_sq_regrets = torch.zeros((B, K), dtype=dtype, device=device)
    # Small, strictly positive initial learning rates (2^-20)
    learning_rates = torch.ones((B, K), dtype=dtype, device=device) / (2.0**20)

    # Compiled sequential time loop
    for t in range(T):
        xt_scaled = X_scaled[:, t, :]
        yt_scaled = Y_scaled[:, t, 0]

        # Record the weights used at step t (computed from steps < t)
        weights_history[:, t, :] = w

        # Original scale for actual prediction tracking
        predictions[:, t] = torch.sum(w * experts_tensor[:, t, :], dim=1)

        # Scaled space for the gradient-based (linearized) regret calculation
        y_hat_scaled = torch.sum(w * xt_scaled, dim=1)
        if loss_type == 'square':
            r = 2.0 * (y_hat_scaled.unsqueeze(1) - yt_scaled.unsqueeze(1)) * (y_hat_scaled.unsqueeze(1) - xt_scaled)
        else:
            r = torch.sign(y_hat_scaled.unsqueeze(1) - yt_scaled.unsqueeze(1)) * (y_hat_scaled.unsqueeze(1) - xt_scaled)

        r_square = r ** 2
        cum_regrets += r

        # Adaptive learning rate adjustments
        max_r_square = torch.max(r_square, dim=1, keepdim=True)[0]
        max_sq_regret_diff = torch.clamp(max_r_square - max_sq_regrets, min=0.0)
        learning_rates = 1.0 / (1.0 / learning_rates + r_square + max_sq_regret_diff)
        max_sq_regrets += max_sq_regret_diff

        # Polynomial weight update: w_k is proportional to eta_k * max(0, R_k)
        relu_regrets = torch.clamp(cum_regrets, min=0.0)
        w_next = learning_rates * relu_regrets
        w_sum = torch.sum(w_next, dim=1, keepdim=True)

        # Normalize weights, falling back to uniform if all regrets are <= 0
        mask = w_sum > 0.0
        w = torch.where(
            mask,
            w_next / w_sum,
            torch.ones((B, K), dtype=dtype, device=device) / float(K)
        )

    return predictions, weights_history
```
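2. Safe Softmax (EWA): In _ewa_loop_optimized_3d, the weights are proportional to \(\exp(-\eta L_{k,t})\), where \(L_{k,t}\) is the cumulative loss of expert \(k\). For large cumulative losses every exponential underflows to 0, and normalizing then yields 0/0 = NaN. The code applies the log-sum-exp (safe softmax) trick: it subtracts the per-group maximum of the scaled losses (scaled_losses - max_val), so the largest exponent is exactly 0 and no exponent can be positive (ruling out overflow). The best expert therefore receives an unnormalized weight of exactly 1, the normalizer is at least 1, and the weight distribution is mathematically unchanged, since \(\exp(z_k - c) / \sum_j \exp(z_j - c) = \exp(z_k) / \sum_j \exp(z_j)\) for any constant \(c\).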
```python
from typing import Tuple

import torch


@torch.jit.script
def _ewa_loop_optimized_3d(
    experts_tensor: torch.Tensor,
    y_true: torch.Tensor,
    loss_type: str,
    eta: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compiled TorchScript loop for EWA (Exponentially Weighted Aggregation).

    Processes all groups simultaneously and prevents exponential underflow/overflow
    using the stable log-sum-exp (safe softmax) numerical trick.

    Args:
        experts_tensor: Predictions from experts, shape (B, T, K).
        y_true: Ground truth targets, shape (B, T, 1).
        loss_type: The loss function to use ('square' or 'absolute').
        eta: Learning rate of the exponential weights.

    Returns:
        Tuple containing the mixed predictions (B, T) and weights history (B, T, K).
    """
    B, T, K = experts_tensor.shape
    dtype = experts_tensor.dtype
    device = experts_tensor.device

    # Pre-allocate outputs
    weights_history = torch.zeros((B, T, K), dtype=dtype, device=device)
    predictions = torch.zeros((B, T), dtype=dtype, device=device)

    # State variables: uniform initial weights, zero cumulative losses
    w = torch.ones((B, K), dtype=dtype, device=device) / float(K)
    cum_losses = torch.zeros((B, K), dtype=dtype, device=device)

    for t in range(T):
        xt = experts_tensor[:, t, :]
        yt = y_true[:, t, 0]

        # Record the weights used at step t (computed from steps < t)
        weights_history[:, t, :] = w
        predictions[:, t] = torch.sum(w * xt, dim=1)

        # Compute expert losses
        if loss_type == 'square':
            losses = (xt - yt.unsqueeze(1)) ** 2
        else:
            losses = torch.abs(xt - yt.unsqueeze(1))
        cum_losses += losses

        # Stable exponential update (safe softmax): shift so the max exponent is 0
        scaled_losses = -eta * cum_losses
        max_val = torch.max(scaled_losses, dim=1, keepdim=True)[0]
        w_unnorm = torch.exp(scaled_losses - max_val)
        w_sum = torch.sum(w_unnorm, dim=1, keepdim=True)

        # Defensive fallback to uniform weights; for finite inputs w_sum >= 1
        mask = w_sum > 0.0
        w = torch.where(
            mask,
            w_unnorm / w_sum,
            torch.ones((B, K), dtype=dtype, device=device) / float(K)
        )

    return predictions, weights_history
```
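A NumPy check of the stabilized update, as a sketch with made-up cumulative losses rather than OperaTAM code: with large losses the naive exponentials all underflow to zero and normalization produces NaN, whereas subtracting the row maximum keeps the leading expert at \(\exp(0) = 1\). When the naive version does not underflow, both give the same weights because the shift cancels in the ratio.

```python
import numpy as np

def naive_weights(cum_losses, eta):
    """Direct exp(-eta * L) normalization; underflows for large losses."""
    w = np.exp(-eta * cum_losses)
    return w / w.sum(axis=1, keepdims=True)

def stable_weights(cum_losses, eta):
    """Safe softmax: shift each row so its largest exponent is exactly 0."""
    z = -eta * cum_losses
    w = np.exp(z - z.max(axis=1, keepdims=True))
    return w / w.sum(axis=1, keepdims=True)

big = np.array([[1000.0, 1001.0, 1005.0]])      # cumulative losses of 3 experts
with np.errstate(invalid="ignore"):
    print(naive_weights(big, eta=1.0))          # [[nan nan nan]]: 0/0 after underflow
print(stable_weights(big, eta=1.0))

small = np.array([[1.0, 2.0, 3.0]])
# Without underflow, the shift does not change the weights.
assert np.allclose(naive_weights(small, 1.0), stable_weights(small, 1.0))
```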
Causal Boundary Enforcement (Horizon Shifting)
In production forecasting, online algorithms must not leak future information.
Architectural Choice (Information Delay):
After the optimized loops have computed the weights, the OperaTAM.predict_online method takes the raw weights_np array. If the user specifies a multi-step forecasting scenario (horizon_steps > 1, written \(H\) below), the algorithm shifts the learned weights forward by \(H - 1\) steps.
The loops record the weight vector for step \(t\) before observing \(y_t\), so weights_np[:, t, :] depends only on targets up to \(t-1\). In an \(H\)-step-ahead forecast, however, only targets up to \(t-H\) are known when the forecast for \(t\) is issued. With shift \(= H - 1\), the historical weights are therefore shifted causally to the right, shifted_weights[:, shift:, :] = weights_np[:, :-shift, :], and the first shift “blind” steps, for which no valid weights exist, are filled with uniform weights (1.0 / len(experts)). This ensures the aggregated prediction at time \(t\) relies only on expert performance that was observable when the forecast was made.
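The shifting rule can be sketched in NumPy as follows. The function name shift_weights and its signature are hypothetical; only the slicing expression and the uniform fill come from the description above.

```python
import numpy as np

def shift_weights(weights_np, horizon_steps):
    """Hypothetical sketch: delay (B, T, K) weights by H - 1 steps."""
    shift = horizon_steps - 1
    if shift <= 0:
        return weights_np.copy()                # one-step-ahead: no delay needed
    B, T, K = weights_np.shape
    shifted = np.full_like(weights_np, 1.0 / K)             # uniform for blind steps
    shifted[:, shift:, :] = weights_np[:, :-shift, :]       # causal shift to the right
    return shifted

w = np.zeros((1, 5, 2))
w[0, :, 0] = np.arange(5) / 10.0                # mark each step so the shift is visible
w[0, :, 1] = 1.0 - w[0, :, 0]
print(shift_weights(w, horizon_steps=3)[0, :, 0])
```

The early return matters: with shift = 0, the slice weights_np[:, :-0, :] is empty, which is why the shift only applies when horizon_steps > 1.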
Production Inference & State Freezing
While predict_online() runs the sequential 3D tensor loop over historical data, production forecasting requires applying a fixed, static rule to future data.
Architectural Choice (The Inference API):
To standardize the deployment pipeline, OperaTAM implements a dedicated predict() method. When fit() (which wraps predict_online()) completes the historical simulation, the final weight vector of each group is saved to self.weights_history_. When a user subsequently calls predict(test_data), the framework bypasses the TorchScript loop entirely. Instead, it extracts the final frozen weights and computes a standard NumPy dot product (experts @ weights), giving fast, deterministic out-of-sample predictions.
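A sketch of frozen-weight inference, assuming the stored weights have shape (B, T, K) and new expert forecasts arrive as a (B, H, K) array. predict_frozen is a hypothetical stand-alone function, not the actual OperaTAM.predict signature. The last weight vector of each group is applied to every future step, so no sequential loop is needed.

```python
import numpy as np

def predict_frozen(weights_history, experts_future):
    """Apply each group's last weight vector to future expert forecasts.

    weights_history: (B, T, K) weights from the historical simulation.
    experts_future:  (B, H, K) expert forecasts for H future steps.
    Returns aggregated forecasts of shape (B, H).
    """
    final_w = weights_history[:, -1, :]                 # (B, K), frozen
    # Batched matrix-vector product: (B, H, K) @ (B, K, 1) -> (B, H, 1)
    return (experts_future @ final_w[:, :, None])[:, :, 0]

weights_history = np.full((2, 10, 3), 1.0 / 3.0)
weights_history[1, -1] = [1.0, 0.0, 0.0]                # group 1 trusts expert 0 only
experts_future = np.arange(24.0).reshape(2, 4, 3)
print(predict_frozen(weights_history, experts_future))
```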