Online Error Correction (AdaptiveTAM)
Standard Generalized Additive Models (GAMs) assume that the underlying data-generating process is globally stationary. Real-world time series, however, suffer from structural breaks, concept drift, and loss of exchangeability. The underlying mathematics is covered in the theoretical formulations; AdaptiveTAM provides the production-grade, hardware-accelerated implementation.
It uses a sliding-window tensorization approach to re-estimate the covariance matrices on every window, allowing the coefficients to adapt to new regimes without retraining the global model from scratch. This chapter details how the theory is translated into batched PyTorch operations that avoid a Python for loop over individual windows.
Design Pattern: Instance Composition
The software architecture directly reflects the two-stage mathematical formulation of Online WeaKL. The AdaptiveTAM class does not inherit from StaticTAM; instead, it utilizes a Composition pattern.
Architectural Choice: By using composition, the framework prevents namespace collision and guarantees mathematical isolation between the physics-based macroscopic model and the reactive microscopic model. It owns and orchestrates two distinct instances:
self.base_model_: The expert instance, permanently frozen after its initial training.
self.adaptive_model_: A blank, template instance that will be cloned and dynamically retrained on each sliding window to correct the base model’s residuals.
import warnings
from typing import Optional


class AdaptiveTAM:
    r"""
    Online error-correction model that composes a frozen base model with an
    adaptive corrector retrained on sliding windows.
    """

    def __init__(
        self,
        adaptive_formula: str,
        update_interval_periods: int,
        training_window_periods: int,
        steps_per_period: int,
        base_model: Optional[StaticTAM] = None,
        horizon_steps: int = 1,
        default_alpha_p: float = -9.0,
        group_col: Optional[str] = None,
        date_col: Optional[str] = None,
        add_base_effects: bool = False
    ):
        r"""
        Initializes the AdaptiveTAM model.

        Args:
            base_model: A fitted StaticTAM model instance, or None to run the
                adaptive model without a base model.
            adaptive_formula: Formula for the adaptive correction model.
                Features must be columns produced by base_model.decompose_prediction()
                (e.g., 'effect_temp').
            update_interval_periods (int): The number of periods to skip before updating
                the coefficients for a specific group (determines n_windows).
            training_window_periods (int): The historical look-back period used to solve
                the local linear system for an independent group (determines num_samples_train).
            steps_per_period (int): The number of observations a single group experiences
                within one logical period.
                - If the data is monthly and the period is a month: 1.
                - If the data is 30-min, but group_col is 'Time of Day': 1 (since each group only sees one observation per day).
            horizon_steps (int): The forecasting horizon (H) used to prevent target leakage.
                Enforces an information delay by truncating the last (H-1) samples from the
                training buffer of every group simultaneously.
            default_alpha_p: Default regularization strength (log10).
            group_col: Grouping column. Defaults to the base model's group column,
                or to a dummy column if neither is available.
            date_col: Date column. Defaults to the base model's date column,
                or to a dummy column if neither is available.
            add_base_effects: If True, appends an ``l(effect_<feature>)`` term to the
                adaptive formula for every base-model effect not already present.

        Raises:
            ValueError: If a base_model is provided but has not been fitted.
        """
        if base_model is not None and getattr(base_model, 'coefficients_', None) is None:
            raise ValueError("The base_model must be fitted before initializing AdaptiveTAM.")

        # Optionally expose every base-model effect as a feature of the corrector.
        if add_base_effects and base_model is not None:
            for effect in base_model.effects_list_:
                effect_col = f"effect_{effect.feature_name}"
                if effect_col not in adaptive_formula:
                    adaptive_formula += f" + l({effect_col})"

        self.base_model_ = base_model
        self.adaptive_formula_ = adaptive_formula
        self.update_interval_periods_ = update_interval_periods
        self.training_window_periods_ = training_window_periods
        self.steps_per_period_ = steps_per_period
        self.horizon_steps_ = horizon_steps
        self.group_col_ = group_col or getattr(base_model, 'group_col_', "__dummy_group__")
        self.date_col_ = date_col or getattr(base_model, 'date_col_', "__dummy_date__")

        # Composition: the corrector is a separate StaticTAM instance, not a subclass.
        self.adaptive_model_ = StaticTAM(
            formula=adaptive_formula,
            group_col=self.group_col_,
            date_col=self.date_col_,
            default_alpha_p=default_alpha_p
        )

        self.coefficients_ = None
        self.norm_params_ = None
        self.unique_groups_ = None
        self.simulation_data_ = None
        self.predictions_ = None

        # If the formula names no target, default to the base model's residual.
        target_col_bm = getattr(base_model, 'target_col_', self.adaptive_model_.target_col_)
        self.target_col_ = self.adaptive_model_.target_col_ or f'Residual{target_col_bm}'

        if self.base_model_ is not None:
            base_target = getattr(self.base_model_, 'target_col_', None)
            if base_target and self.adaptive_model_.target_col_:
                expected_residual = f'Residual{base_target}'
                if self.adaptive_model_.target_col_ not in [expected_residual, base_target]:
                    warnings.warn(
                        f"Adaptive target '{self.adaptive_model_.target_col_}' does not match the base target "
                        f"'{base_target}' or its expected residual '{expected_residual}'. "
                        "This will cause a cross-target correction.",
                        UserWarning
                    )

        self.last_state_dict_ = None
        self.max_res_ = None
        self.min_res_ = None
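The composition pattern can be illustrated with a minimal, self-contained sketch. The classes `FrozenBase`, `MeanCorrector`, and `TwoStageModel` are hypothetical stand-ins and not part of the framework; the corrector here simply learns the mean residual of a window, which is far simpler than the real adaptive model.

```python
import copy


class FrozenBase:
    """Hypothetical stand-in for a fitted base model: y = slope * x."""

    def __init__(self, slope: float):
        self.slope = slope

    def predict(self, xs):
        return [self.slope * x for x in xs]


class MeanCorrector:
    """Hypothetical corrector: learns the mean residual of one window."""

    def __init__(self):
        self.offset = None  # None means "not fitted"

    def fit(self, residuals):
        self.offset = sum(residuals) / len(residuals)
        return self

    def predict(self, n):
        return [self.offset] * n


class TwoStageModel:
    """Owns both models (composition) instead of inheriting from either."""

    def __init__(self, base: FrozenBase, template: MeanCorrector):
        self.base_model_ = base          # never refitted
        self.adaptive_model_ = template  # blank template, cloned per window

    def correct_window(self, xs_train, ys_train, xs_new):
        base_train = self.base_model_.predict(xs_train)
        residuals = [y - b for y, b in zip(ys_train, base_train)]
        # Clone the template so fitted state never leaks between windows.
        local = copy.deepcopy(self.adaptive_model_).fit(residuals)
        base_new = self.base_model_.predict(xs_new)
        return [b + c for b, c in zip(base_new, local.predict(len(xs_new)))]


model = TwoStageModel(FrozenBase(slope=2.0), MeanCorrector())
# The observed signal has drifted upward by 1.0 relative to the base model.
adapted = model.correct_window([1, 2, 3], [3.0, 5.0, 7.0], [4, 5])
print(adapted)
```

After the call, the template held by the composite is still unfitted and the base model is unchanged, which is the isolation property the design aims for.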
Advanced PyTorch Indexing (4D Vectorization)
One of the major engineering challenges of online learning is extracting the training windows. Extracting 10,000 localized windows one at a time in a Python for loop would be dominated by per-iteration overhead (interpreter dispatch and, on a GPU, kernel launch latency) rather than by useful computation.
To solve this, the _transform_data_adaptive function (located in _data.py) exploits PyTorch’s advanced indexing and broadcasting.
Architectural Choices & Frequency Agnosticism:
The code deliberately moves from a rigid “Calendar Day” perspective to a frequency-agnostic group coordinate system. By working with the abstract quantities periods and steps_per_period, it isolates panel data entities regardless of whether the native sampling rate is monthly, daily, or half-hourly. Window lengths in samples follow directly: learning_size_steps = training_window_periods * steps_per_period and window_size_steps = update_interval_periods * steps_per_period.
Anchor Calculation: The code identifies all valid starting indices (start_indices) by walking backward through time, ensuring that no incomplete window is processed.
Causal Offset Generation (Preventing Target Leakage): It generates two constant tensors: predict_offsets (the indices of the prediction window) and train_offsets (the relative indices of the learning window). To prevent target leakage in multi-step forecasting, the train_offsets are shifted backward by horizon_steps - 1 steps, i.e. by an offset of -(horizon_steps - 1). This truncates the “illegal future” from the training buffer before the tensor is ever evaluated on the GPU.
Vectorized Extraction: By adding the column matrix of start indices (start_indices.view(-1, 1)) to the offset vectors, PyTorch broadcasts a complete 2D index matrix. The framework then extracts all of the causally safe historical data of a group in a single indexing call: x_group[train_indices].
from typing import Dict, List, Tuple

import pandas as pd
import torch


def _transform_data_adaptive(
    data: pd.DataFrame,
    features: List[str],
    group_col: str,
    norm_params: Dict,
    unique_groups: List,
    target_col: str,
    update_interval_periods: int,
    training_window_periods: int,
    steps_per_period: int,
    horizon_steps: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Prepares data for adaptive learning using vectorized sliding window indexing.
    Returns tensors for (X_train, Y_train, X_predict) for each simulation step.

    Args:
        data: Validation/Test DataFrame.
        features: Feature list.
        group_col: Grouping column.
        norm_params: Normalization parameters.
        unique_groups: Group names.
        target_col: Target column.
        update_interval_periods: Prediction window size.
        training_window_periods: Training history size.
        steps_per_period: Samples per period.
        horizon_steps: Forecasting horizon (H).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            (x_stacked, y_stacked, x_to_predict)
    """
    learning_size_steps = training_window_periods * steps_per_period
    window_size_steps = update_interval_periods * steps_per_period

    all_groups_x_train = []
    all_groups_y_train = []
    all_groups_x_predict = []

    if unique_groups is None:
        raise ValueError("`unique_groups` cannot be None.")

    # Index tensors are built on the CPU; only the stacked result moves to TORCH_DEVICE.
    prep_device = 'cpu'

    for group_name in unique_groups:
        if group_name not in norm_params:
            continue
        data_group = data[data[group_col] == group_name].reset_index(drop=True)
        params = norm_params.get(group_name)
        if params is None:
            continue

        total_available_steps = len(data_group)
        required_history = learning_size_steps + (horizon_steps - 1)
        if total_available_steps <= required_history:
            continue

        # Normalize group data
        data_group[features] = normalize(df_to_normalize=data_group[features], params=params)
        x_group = torch.tensor(data_group[features].values, dtype=torch.float32, device=prep_device)
        y_group = torch.tensor(data_group[target_col].values, dtype=torch.float32, device=prep_device).view(-1, 1)

        # Calculate valid start indices (reverse chronological)
        start_indices_list = []
        first_predict_start = total_available_steps - (total_available_steps - learning_size_steps) % window_size_steps
        if first_predict_start == total_available_steps and total_available_steps > learning_size_steps:
            first_predict_start -= window_size_steps

        current_predict_start = first_predict_start
        # Each anchor needs `required_history` samples before it. Comparing against
        # `learning_size_steps` alone would let the shifted training offsets go
        # negative when horizon_steps > 1, and negative indices wrap around to the
        # end of the series.
        while current_predict_start >= required_history:
            if current_predict_start + window_size_steps <= total_available_steps:
                start_indices_list.append(current_predict_start)
            current_predict_start -= window_size_steps

        if not start_indices_list:
            continue

        start_indices_list.reverse()
        start_indices = torch.tensor(start_indices_list, device=prep_device, dtype=torch.long)

        # Vectorized Window Indexing
        train_end_offset = -(horizon_steps - 1) if horizon_steps > 1 else 0
        train_start_offset = train_end_offset - learning_size_steps
        train_offsets = torch.arange(train_start_offset, train_end_offset, device=prep_device)
        predict_offsets = torch.arange(0, window_size_steps, device=prep_device)

        # Broadcasting: (n_windows, 1) + (n_offsets,) -> (n_windows, n_offsets)
        train_indices = start_indices.view(-1, 1) + train_offsets
        predict_indices = start_indices.view(-1, 1) + predict_offsets

        # Gather
        group_x_train = x_group[train_indices]
        group_y_train = y_group[train_indices]
        group_x_predict = x_group[predict_indices]

        all_groups_x_train.append(group_x_train)
        all_groups_y_train.append(group_y_train)
        all_groups_x_predict.append(group_x_predict)

    if not all_groups_x_train:
        raise ValueError("No simulation data could be generated. Check dataset length/window sizes.")

    x_stacked = torch.stack(all_groups_x_train).to(TORCH_DEVICE)
    y_stacked = torch.stack(all_groups_y_train).to(TORCH_DEVICE)
    x_to_predict = torch.stack(all_groups_x_predict).to(TORCH_DEVICE)
    return x_stacked, y_stacked, x_to_predict
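The index arithmetic is easiest to follow on a toy series. The sketch below is a simplified illustration for a single group, with made-up sizes. It enumerates anchors forward rather than walking backward, and it assumes that every anchor must leave room for the full shifted training window (anchor >= learning_steps + horizon_steps - 1) so that no negative index can wrap around.

```python
import torch

# Toy sizes (illustrative assumptions, not framework defaults).
total_steps = 20      # observations available for one group
learning_steps = 6    # training_window_periods * steps_per_period
window_steps = 4      # update_interval_periods * steps_per_period
horizon_steps = 3     # forecasting horizon H

# Anchors: first index of each complete prediction window.
required_history = learning_steps + (horizon_steps - 1)
anchors = [
    s for s in range(learning_steps, total_steps - window_steps + 1, window_steps)
    if s >= required_history
]
start_indices = torch.tensor(anchors, dtype=torch.long)

# Relative offsets shared by every window; training ends H-1 steps early.
train_end = -(horizon_steps - 1)
train_offsets = torch.arange(train_end - learning_steps, train_end)
predict_offsets = torch.arange(0, window_steps)

# Broadcasting: (n_windows, 1) + (n_offsets,) -> (n_windows, n_offsets).
train_indices = start_indices.view(-1, 1) + train_offsets
predict_indices = start_indices.view(-1, 1) + predict_offsets

# One gather per tensor extracts every window at once.
x_group = torch.arange(total_steps, dtype=torch.float32).view(-1, 1)
x_train = x_group[train_indices]      # (n_windows, learning_steps, 1)
x_predict = x_group[predict_indices]  # (n_windows, window_steps, 1)

print(train_indices)
print(predict_indices)
```

For the first window the training rows stop at index 7 while prediction starts at index 10, so the two most recent samples (H - 1 = 2) are withheld from training.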
Dynamic Target and Preparation
The prepare_simulation method is executed before launching the online loop. Its role is to build the sliding-window tensors for the whole dataset and to move the target from the raw signal to the base model’s residuals.
Architectural Choice: It calculates the macroscopic predictions (data_pred) and subtracts them from the true target to isolate the residuals (\(\epsilon_t = Y_t - \hat{Y}_{base,t}\)) upfront. Doing this once over the entire validation history avoids re-evaluating the base model for every sliding window.
def prepare_simulation(self, data: pd.DataFrame) -> 'AdaptiveTAM':
    r"""
    Prepares tensors for the adaptive sliding-window simulation.

    This process:
    1. Computes base model effects (features) and residuals (targets).
    2. Normalizes the adaptive features.
    3. Constructs sliding-window tensors (X_train, Y_train, X_predict).

    Args:
        data: The dataset (validation or test) for the simulation.

    Returns:
        self: The instance with populated ``simulation_data_``.
    """
    if self.base_model_ is not None:
        data_bm = self.base_model_.decompose_prediction(data)
        data_pred = self.base_model_.predict(data)
        target_col_bm = self.base_model_.target_col_
    else:
        data_bm = data.copy()
        data_pred = data.copy()
        target_col_bm = self.target_col_

    data_bm = _ensure_dummies(data_bm, self.group_col_, self.date_col_)
    data_pred = _ensure_dummies(data_pred, self.group_col_, self.date_col_)

    cols_float = data_bm.select_dtypes(include=['float64']).columns
    data_bm[cols_float] = data_bm[cols_float].astype('float32')

    est_col_name = f'Estimated{target_col_bm}'
    if est_col_name not in data_bm.columns:
        data_bm[est_col_name] = data_pred.get(est_col_name, 0.0)

    default_res_col = f'Residual{target_col_bm}'
    data_bm[default_res_col] = data_bm[target_col_bm] - data_bm[est_col_name]

    adaptive_features_config = self.adaptive_model_.features_config_
    if 'features' not in adaptive_features_config or not isinstance(adaptive_features_config['features'], list):
        raise KeyError("Invalid feature configuration in adaptive model.")
    adaptive_features = adaptive_features_config['features']

    required_cols = adaptive_features + [self.group_col_, self.target_col_, self.date_col_]
    _check_features(dataset=data_bm, required_features=required_cols)

    mask, balanced_data = _balance_groups(
        dataset=data_bm, group_col=self.group_col_, date_col=self.date_col_, method="fill"
    )

    data_info = self.adaptive_model_._get_data_info(balanced_data)
    if not self.adaptive_model_.effects_list_ and not self.adaptive_model_.is_grid_search_template_:
        self.adaptive_model_.effects_list_ = create_effects_from_parsed_terms(
            self.adaptive_model_.parsed_terms_,
            token_values={},
            default_alpha_p=self.adaptive_model_.default_alpha_p_,
            data_info=data_info
        )

    self.norm_params_, self.unique_groups_ = _fit_normalization_params(
        data=balanced_data,
        features=adaptive_features,
        group_col=self.group_col_
    )

    hw.empty_cache()
    x_stacked, y_stacked, x_to_predict = _transform_data_adaptive(
        data=balanced_data,
        features=adaptive_features,
        group_col=self.group_col_,
        norm_params=self.norm_params_,
        unique_groups=self.unique_groups_,
        target_col=self.target_col_,
        update_interval_periods=self.update_interval_periods_,
        training_window_periods=self.training_window_periods_,
        steps_per_period=self.steps_per_period_,
        horizon_steps=self.horizon_steps_
    )

    self.simulation_data_ = (
        x_stacked.cpu(),
        y_stacked.cpu(),
        x_to_predict.cpu(),
        balanced_data, data_bm, target_col_bm, mask
    )
    del x_stacked, y_stacked, x_to_predict
    hw.empty_cache()
    return self
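The benefit of computing residuals upfront can be shown with a small counting experiment. This is a minimal sketch with a hypothetical `base_predict` function and made-up window sizes; it only demonstrates that the base model is evaluated once, after which every window is a plain slice of the precomputed residuals.

```python
calls = {"base": 0}


def base_predict(xs):
    """Hypothetical base model; counts how often it is evaluated."""
    calls["base"] += 1
    return [2.0 * x for x in xs]


xs = list(range(10))
ys = [2.0 * x + 0.5 for x in xs]  # truth = base prediction + constant drift

# One pass over the whole history: residual_t = y_t - base_prediction_t.
base_hat = base_predict(xs)
residuals = [y - b for y, b in zip(ys, base_hat)]

# Every sliding window is then a slice of the precomputed residuals.
window_len, stride = 4, 2
windows = [
    residuals[s:s + window_len]
    for s in range(0, len(residuals) - window_len + 1, stride)
]

print(len(windows), calls["base"])
```

Evaluating the base model inside the window loop would instead cost one call per window.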
Batch Execution and Numerical Safety (Clipping)
In the simulation() method, windows are not processed one at a time. The solver resolves the local empirical risk minimization problem for all sliding windows of all groups in large batches.
Architectural Choice 1: Dimensional Flattening:
The generated tensors have four dimensions (n_groups, n_windows, num_samples, features), whereas the core solver (solve_linear_system) operates on 3D batches (batch, samples, features). Instead of rewriting the core linear algebra module, simulation() flattens the first two dimensions: total_items = n_groups * n_windows. Each sliding window thus becomes one independent item of a standard batch, which the GPU can process in parallel.
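The flattening step can be sketched on random data. The sizes below are arbitrary, and the batched Gram matrix merely stands in for the real solver; the point is that flat item k corresponds to window k % n_windows of group k // n_windows, and that the 4D layout is recovered with another view.

```python
import torch

n_groups, n_windows, n_samples, n_features = 3, 5, 8, 2
x_stacked = torch.randn(n_groups, n_windows, n_samples, n_features)

# Merge (group, window) into one batch axis. For a contiguous tensor,
# view() reinterprets the same memory without copying.
total_items = n_groups * n_windows
x_flat = x_stacked.view(total_items, n_samples, -1)

# Any batched 3D operation now applies; here the Gram matrix of each window.
gram = x_flat.transpose(1, 2) @ x_flat  # (total_items, F, F)

# Restore the (group, window) layout after the batched computation.
gram_4d = gram.view(n_groups, n_windows, n_features, n_features)

# Flat item k is window w of group g.
g, w = 1, 3
k = g * n_windows + w
print(x_flat.shape, torch.equal(x_flat[k], x_stacked[g, w]))
```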
Architectural Choice 2: OOM Resilience:
Because total_items can easily reach tens of thousands of matrices, the method relies on the hw.handle_oom fallback mechanism. If memory is exhausted, it catches the torch.OutOfMemoryError (or MemoryError), reduces safe_batch_size by halving, and retries the same batch. Since the call passes allow_cpu_fallback=True and handle_oom also returns the device to use, the computation can continue on another device, so the simulation completes even on constrained hardware.
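The retry logic can be sketched without any GPU. The example below is an illustration only: `process_batch` is a hypothetical worker that simulates memory exhaustion, and the halving rule is written inline rather than through hw.handle_oom, whose internals are not shown here.

```python
def process_batch(items, capacity=3):
    """Hypothetical worker that fails when the batch exceeds `capacity`."""
    if len(items) > capacity:
        raise MemoryError("simulated out-of-memory")
    return [x * x for x in items]


def run_with_backoff(items, batch_size):
    """Process `items` in batches, halving the batch size on MemoryError."""
    results, start = [], 0
    while start < len(items):
        end = min(start + batch_size, len(items))
        try:
            results.extend(process_batch(items[start:end]))
            start = end  # advance only after a successful batch
        except MemoryError:
            if batch_size == 1:
                raise  # nothing left to shrink
            batch_size = max(1, batch_size // 2)
    return results, batch_size


results, final_batch = run_with_backoff(list(range(10)), batch_size=8)
print(final_batch, results)
```

The start index is not advanced on failure, so the failed slice is retried with the smaller batch and no item is skipped or processed twice.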
Architectural Choice 3: Algorithmic Safeguard (Target Clipping):
Adaptive models can diverge if they train on a window dominated by anomalous values (e.g., a sensor failure). As a safeguard, the implementation clips the adaptive model’s estimate to the interval spanned by the minimum and maximum observed values of its target, normally the base residuals (data_bm[self.target_col_].min() and data_bm[self.target_col_].max()). When the adaptive target is the residual, the clipped correction is then added to the base prediction, which prevents the corrector from extrapolating unrealistic deviations during unexpected exogenous shocks.
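In formula form, and assuming the adaptive target is the base residual, the safeguard reads \(\hat{Y}^{adapted}_t = \hat{Y}_{base,t} + \min(\epsilon_{max}, \max(\epsilon_{min}, \hat{\epsilon}_t))\), where \(\epsilon_{min}\) and \(\epsilon_{max}\) are the extreme observed residuals. A minimal sketch with made-up numbers:

```python
def clip(value, lower, upper):
    """Bound `value` to the closed interval [lower, upper]."""
    return max(lower, min(upper, value))


# Residuals observed over the history (illustrative values).
observed_residuals = [-1.5, 0.2, 0.9, 2.0, -0.4]
min_res, max_res = min(observed_residuals), max(observed_residuals)

base_prediction = [10.0, 11.0, 12.0]
raw_correction = [0.5, 7.3, -4.0]  # the last two are runaway corrections

clipped = [clip(c, min_res, max_res) for c in raw_correction]
adapted = [b + c for b, c in zip(base_prediction, clipped)]
print(clipped, adapted)
```

Corrections inside the historical range pass through unchanged; only the runaway values are pulled back to the nearest bound.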
def simulation(self) -> pd.DataFrame:
    r"""
    Executes the sliding-window simulation using scalable batch processing.

    It flattens the group and window dimensions to treat each sliding window
    as an independent linear system. This strictly preserves the mathematical
    regularization scale and prevents VRAM exhaustion when processing deep
    historical data across multiple groups.
    """
    if self.simulation_data_ is None:
        raise RuntimeError("Simulation data is uninitialized. Call 'prepare_simulation()' first.")
    if self.adaptive_model_.is_grid_search_template_:
        raise RuntimeError("Model contains grid search tokens. Call 'grid_search_fit()' first.")

    x_stacked, y_stacked, x_to_predict, balanced_data, data_bm, target_col_bm, mask = self.simulation_data_

    n_groups = x_stacked.shape[0]
    n_windows = x_stacked.shape[1]
    num_samples_train = x_stacked.shape[2]
    window_size_steps = x_to_predict.shape[2]
    total_items = n_groups * n_windows

    x_flat = x_stacked.view(total_items, num_samples_train, -1)
    y_flat = y_stacked.view(total_items, num_samples_train, -1)
    x_pred_flat = x_to_predict.view(total_items, window_size_steps, -1)

    run_device = TORCH_DEVICE
    sobolev_matrix = self.adaptive_model_._build_penalty_matrix().to(run_device)
    loss_L_star_L = self.adaptive_model_._build_loss_matrix().to(run_device)

    dummy_x = x_flat[0:1].to(run_device)
    dummy_phi = self.adaptive_model_._build_design_matrix(dummy_x)
    n_coeffs = dummy_phi.shape[-1]
    del dummy_x, dummy_phi

    safe_batch_size = get_safe_window_batch_size(
        num_samples_per_window=num_samples_train,
        total_d=n_coeffs,
        device=run_device
    )
    safe_batch_size = min(safe_batch_size, total_items)

    all_predictions = []
    start_idx = 0
    while start_idx < total_items:
        end_idx = min(start_idx + safe_batch_size, total_items)
        try:
            batch_x = x_flat[start_idx:end_idx].to(run_device)
            batch_y = y_flat[start_idx:end_idx].to(run_device)
            batch_x_pred = x_pred_flat[start_idx:end_idx].to(run_device)

            phi_batch = self.adaptive_model_._build_design_matrix(batch_x)
            cov_X, cov_XY = _compute_weighted_covariances(phi_batch, batch_y, loss_L_star_L)
            coeffs_batch = solve_linear_system(cov_X, cov_XY, sobolev_matrix, num_samples_train)

            phi_pred_batch = self.adaptive_model_._build_design_matrix(batch_x_pred)
            preds_batch = _predict_from_coeffs(phi_pred_batch, coeffs_batch)
            all_predictions.append(preds_batch.detach().cpu())

            del batch_x, batch_y, batch_x_pred, phi_batch, cov_X, cov_XY, coeffs_batch, phi_pred_batch, preds_batch
            start_idx += safe_batch_size
        except (torch.OutOfMemoryError, MemoryError):
            safe_batch_size, run_device = hw.handle_oom(
                current_batch=safe_batch_size,
                device=run_device,
                context="adaptive simulation batch reduction",
                allow_cpu_fallback=True
            )
            sobolev_matrix = sobolev_matrix.to(run_device)
            loss_L_star_L = loss_L_star_L.to(run_device)
            continue

    hw.empty_cache()

    predictions_flat = torch.cat(all_predictions, dim=0).squeeze(-1)
    predictions_cpu = predictions_flat.view(n_groups, n_windows, window_size_steps)

    data_with_predictions = _reassemble_predictions(
        original_data=balanced_data,
        predictions_stacked=predictions_cpu,
        group_col=self.group_col_,
        unique_groups=self.unique_groups_,
        target_col=self.target_col_
    )

    max_res = np.float32(data_bm[self.target_col_].max())
    min_res = np.float32(data_bm[self.target_col_].min())
    est_col = f'Estimated{self.target_col_}'
    data_with_predictions.loc[data_with_predictions[est_col] >= max_res, est_col] = max_res
    data_with_predictions.loc[data_with_predictions[est_col] <= min_res, est_col] = min_res

    adapted_col = f"AdaptedEstimated{target_col_bm}"
    if self.target_col_ == target_col_bm:
        data_with_predictions[adapted_col] = data_with_predictions[est_col].fillna(0)
    else:
        data_with_predictions[adapted_col] = (
            data_with_predictions[f'Estimated{target_col_bm}'] +
            data_with_predictions[est_col].fillna(0)
        )

    self.predictions_ = _cleanup_dummies(data_with_predictions[mask], self.group_col_, self.date_col_)
    return self.predictions_
Coordinate Descent Search Algorithm
The grid_search_fit() method implements the “Multi-Start” hyperparameter solver described in the theory. It searches over hyperparameters (e.g., training_window_periods and the default \(\lambda\)) one axis at a time, holding the others fixed, and repeats the cycle until a full pass brings no improvement or five cycles have run.
Architectural Choice: Because the global objective function evaluated across overlapping sliding windows creates highly non-convex search topologies, standard analytical Generalized Cross-Validation (GCV) breaks down. Coordinate Descent provides robust navigation through these complex spaces. To reduce the risk of ending in a poor local minimum, the algorithm runs the cyclic axis search from three distinct starting points: Conservative (largest value for regularization tokens, smallest value for the others), Median (middle value of every axis), and Aggressive (the reverse of Conservative). The best result across the three runs is kept.
def grid_search_fit(
    self,
    data_val: pd.DataFrame,
    grid_search_config: dict
) -> 'AdaptiveTAM':
    r"""
    Optimizes the adaptive model using Multi-Start Coordinate Descent.

    Optimizes hyperparameters to minimize the RMSE of the final
    adapted prediction over the simulation period.

    Args:
        data_val: Validation DataFrame for simulation.
        grid_search_config: Dictionary mapping tokens to value lists.

    Returns:
        AdaptiveTAM: A new fitted model instance.
    """
    print("--- Starting Grid Search (Multi-Start Coordinate Descent) ---")
    self.prepare_simulation(data_val)

    search_axes, token_names = self.adaptive_model_._parse_grid_axes(grid_search_config)
    data_info = self.adaptive_model_._get_data_info(data_val)
    if not token_names:
        raise ValueError("No grid tokens found in adaptive formula or config.")
    print(f"Optimizing axes: {token_names}")

    start_points = [
        {"name": "Conservative", "tokens": {t: max(vals) if ('ap' in t or 'lambda_p' in t) else min(vals) for t, vals in search_axes.items()}},
        {"name": "Median", "tokens": {t: vals[len(vals)//2] for t, vals in search_axes.items()}},
        {"name": "Aggressive", "tokens": {t: min(vals) if ('ap' in t or 'lambda_p' in t) else max(vals) for t, vals in search_axes.items()}}
    ]

    min_global_rmse = float('inf')
    optimal_effects_list = None
    current_best_tokens_global = None

    for strategy in start_points:
        print(f"\n=== Strategy: {strategy['name']} Start ===")
        current_best_tokens = strategy["tokens"].copy()
        try:
            start_effects_list = create_effects_from_parsed_terms(
                self.adaptive_model_.parsed_terms_,
                current_best_tokens,
                self.adaptive_model_.default_alpha_p_,
                data_info=data_info
            )
            current_rmse = self._evaluate_adaptive_config(start_effects_list, current_best_tokens)
            current_optimal_effects = start_effects_list
        except Exception:
            continue
        if current_rmse >= float('inf'):
            continue
        print(f" Start RMSE: {current_rmse:.4f}")

        cycle = 0
        while True:
            cycle += 1
            has_improved_in_cycle = False
            for token_name in token_names:
                best_value = current_best_tokens[token_name]
                original_val = best_value
                for value in search_axes[token_name]:
                    if value == original_val:
                        continue
                    tokens_to_test = current_best_tokens.copy()
                    tokens_to_test[token_name] = value
                    try:
                        effects_list = create_effects_from_parsed_terms(
                            self.adaptive_model_.parsed_terms_,
                            tokens_to_test,
                            self.adaptive_model_.default_alpha_p_,
                            data_info=data_info
                        )
                        rmse = self._evaluate_adaptive_config(effects_list, tokens_to_test)
                        if rmse < current_rmse:
                            current_rmse = rmse
                            current_optimal_effects = effects_list
                            best_value = value
                            has_improved_in_cycle = True
                    except Exception:
                        continue
                current_best_tokens[token_name] = best_value
            if not has_improved_in_cycle or cycle >= 5:
                break

        if current_rmse < min_global_rmse:
            print(f" >>> New Global Best found by {strategy['name']}! ({current_rmse:.4f})")
            min_global_rmse = current_rmse
            optimal_effects_list = current_optimal_effects
            current_best_tokens_global = current_best_tokens

    print("-" * 30)
    print(f"Grid Search complete. Best RMSE: {min_global_rmse:.4f}")
    if optimal_effects_list is None:
        raise RuntimeError("Grid search failed.")
    print(f"Best tokens: {current_best_tokens_global}")

    final_model = AdaptiveTAM(
        base_model=self.base_model_,
        adaptive_formula=self.adaptive_formula_,
        update_interval_periods=self.update_interval_periods_,
        training_window_periods=self.training_window_periods_,
        steps_per_period=self.steps_per_period_,
        horizon_steps=self.horizon_steps_
    )
    final_adaptive_model_internal = StaticTAM(
        formula=self.adaptive_model_.formula_,
        group_col=self.base_model_.group_col_ if self.base_model_ is not None else self.group_col_,
        date_col=self.base_model_.date_col_ if self.base_model_ is not None else self.date_col_,
        _internal_effects_list=optimal_effects_list,
        _internal_features_config=self.adaptive_model_.features_config_
    )
    final_model.adaptive_model_ = final_adaptive_model_internal
    final_model.norm_params_ = self.norm_params_
    final_model.unique_groups_ = self.unique_groups_
    final_model.target_col_ = self.target_col_
    final_model.predict_online(data_val)
    return final_model
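Stripped of the model-specific plumbing, the search strategy reduces to the loop sketched below. This is an illustration on a toy objective, not the framework's implementation: the function `coordinate_descent`, the axis names, and the hard-coded starting points are all hypothetical, and the objective is a simple separable function so the result can be checked by hand.

```python
def coordinate_descent(objective, axes, start, max_cycles=5):
    """Cyclic search: vary one axis at a time, keep any strict improvement."""
    best = dict(start)
    best_score = objective(best)
    for _ in range(max_cycles):
        improved = False
        for name, values in axes.items():
            for value in values:
                if value == best[name]:
                    continue
                trial = {**best, name: value}
                score = objective(trial)
                if score < best_score:
                    best, best_score, improved = trial, score, True
        if not improved:
            break  # a full cycle without improvement: local convergence
    return best, best_score


# Toy grid and objective (assumptions for illustration only).
axes = {"alpha_p": [-9, -6, -3, 0], "window": [10, 20, 30]}


def objective(tokens):
    return (tokens["alpha_p"] + 3) ** 2 + abs(tokens["window"] - 20)


starts = [
    {"alpha_p": 0, "window": 10},    # conservative
    {"alpha_p": -3, "window": 20},   # median
    {"alpha_p": -9, "window": 30},   # aggressive
]
runs = [coordinate_descent(objective, axes, s) for s in starts]
best_tokens, best_score = min(runs, key=lambda run: run[1])
print(best_tokens, best_score)
```

On a non-convex objective the three runs may end in different local minima, which is why the best of the three is retained.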
Separation of Concerns: Simulation vs. Inference
To guarantee speed and safety in operational production pipelines, AdaptiveTAM strictly separates the continuous historical simulation from out-of-sample inference.
Architectural Choice (The \(O(1)\) Inference Optimization):
fit(data): Instead of running the sliding-window simulation over the entire dataset, fit() relies on the _save_final_state() method. This method slices only the final available training window from the 4D tensor, solves the primal linear system for that single window, and freezes the resulting coefficients (last_state_dict_) along with the historical clipping bounds.
predict(data): A deterministic, read-only method. It builds the design matrix \(\Phi\) for the new data and multiplies it by the frozen coefficients. Because inference bypasses sliding-window tensorization and solves no linear system, its cost does not depend on the length of the training history (the \(O(1)\) in the title refers to this), and because no target values are read, target leakage at prediction time is ruled out.
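The fit/predict split can be sketched with a one-feature corrector. Everything below is a hypothetical stand-in (the class `FrozenCorrector`, its `last_state_` dictionary, and the closed-form least-squares solve); it only mirrors the idea of solving on the final window once and then predicting from the frozen state.

```python
class FrozenCorrector:
    """Hypothetical one-feature corrector: residual ~ coef * x."""

    def __init__(self):
        self.last_state_ = None

    def fit(self, xs, residuals, window):
        # Solve the least-squares problem on the final window only.
        xw, rw = xs[-window:], residuals[-window:]
        coef = sum(x * r for x, r in zip(xw, rw)) / sum(x * x for x in xw)
        # Freeze the coefficient together with the historical clipping bounds.
        self.last_state_ = {
            "coef": coef,
            "min_res": min(residuals),
            "max_res": max(residuals),
        }
        return self

    def predict(self, xs_new):
        # Read-only: no targets, no window extraction, no solve.
        s = self.last_state_
        return [max(s["min_res"], min(s["max_res"], s["coef"] * x)) for x in xs_new]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
residuals = [0.1, 0.2, 0.3, 2.0, 2.5, 3.0]  # regime change after the third step
model = FrozenCorrector().fit(xs, residuals, window=3)
print(model.predict([2.0, 7.0]))
```

The coefficient reflects only the most recent regime, and the second prediction is capped at the largest residual seen during fitting.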