Online Error Correction (AdaptiveTAM)

Standard Generalized Additive Models (GAMs) assume that the underlying data-generating process is globally stationary. Real-world time series, however, frequently exhibit structural breaks, concept drift, and a loss of exchangeability. The underlying mathematics are developed in the theoretical formulations; AdaptiveTAM is their production-grade, hardware-aware (GPU-accelerated) implementation.

It uses a sliding-window tensorization approach: the local covariance matrices are recomputed on every window, so the coefficients adapt to new regimes without retraining the global model from scratch. This chapter details how the theory is translated into massively parallel PyTorch operations that avoid Python for loops over individual windows.


Design Pattern: Instance Composition

The software architecture directly reflects the two-stage mathematical formulation of Online WeaKL. The AdaptiveTAM class does not inherit from StaticTAM; it uses composition instead.

Architectural Choice: Composition keeps the two models in separate namespaces, so no inherited attribute or method can collide, and isolates their parameters: fitting the reactive microscopic model never modifies the physics-based macroscopic model. AdaptiveTAM owns and orchestrates two distinct instances:

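  • self.base_model_: The expert instance, permanently frozen after its initial training. It is optional: when it is None, the adaptive model is fitted directly on its own target column.

  • self.adaptive_model_: An unfitted template StaticTAM built from adaptive_formula. It defines the design and penalty matrices, and its coefficients are re-solved independently on each sliding window to correct the base model’s residuals.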
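A minimal sketch of the composition idea in plain Python. The classes `FrozenBase`, `WindowCorrector`, and `TwoStageModel` are hypothetical stand-ins, not the real StaticTAM or AdaptiveTAM API, and the corrector fits a constant bias per window with an explicit loop for readability. The point is structural: the wrapper owns both objects, the base model is immutable, and each window gets a fresh corrector while the template itself stays untouched.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(frozen=True)
class FrozenBase:
    """Hypothetical stand-in for a fitted StaticTAM: y_hat = slope * x."""
    slope: float

    def predict(self, xs: List[float]) -> List[float]:
        return [self.slope * x for x in xs]


@dataclass
class WindowCorrector:
    """Hypothetical stand-in for the adaptive template: fits a constant bias."""
    bias: float = 0.0

    def fit(self, residuals: List[float]) -> "WindowCorrector":
        # Least-squares fit of a constant is the mean of the window residuals.
        return WindowCorrector(bias=sum(residuals) / len(residuals))


@dataclass
class TwoStageModel:
    """Owns both models (composition) instead of subclassing either one."""
    base_model_: FrozenBase
    adaptive_model_: WindowCorrector = field(default_factory=WindowCorrector)

    def correct(self, xs: List[float], ys: List[float], window: int) -> List[float]:
        base = self.base_model_.predict(xs)
        residuals = [y - b for y, b in zip(ys, base)]
        out = []
        for t in range(window, len(xs)):
            # A fresh corrector per window; the template and the base stay untouched.
            local = self.adaptive_model_.fit(residuals[t - window:t])
            out.append(base[t] + local.bias)
        return out


model = TwoStageModel(base_model_=FrozenBase(slope=2.0))
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [3.0, 5.0, 7.0, 9.0, 11.0]  # = 2x + 1: the base misses a constant offset of 1
print(model.correct(xs, ys, window=2))  # [7.0, 9.0, 11.0]
```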

src/tam/model/adaptative.py (Composition Pattern for the Two-Stage Model)
class AdaptiveTAM:
    r"""
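    Two-stage online corrector: a frozen base StaticTAM plus an adaptive
    StaticTAM re-solved on sliding windows of the base model's residuals.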
    """
    
    def __init__(
        self,
        adaptive_formula: str,
        update_interval_periods: int,
        training_window_periods: int,
        steps_per_period: int,
        base_model: Optional[StaticTAM] = None,
        horizon_steps: int = 1,
        default_alpha_p: float = -9.0,
        group_col: Optional[str] = None,
        date_col: Optional[str] = None,
        add_base_effects: bool = False
    ):
        r"""
        Initializes the AdaptiveTAM model.

        Args:
            adaptive_formula: Formula for the adaptive correction model.
                Features must be columns produced by base_model.decompose_prediction()
                (e.g., 'effect_temp').
            update_interval_periods (int): The number of periods between two consecutive
                coefficient updates for each group, i.e. the length of each prediction
                window (determines n_windows).
            training_window_periods (int): The historical look-back period used to solve
                the local linear system of each group independently (determines
                num_samples_train).
            steps_per_period (int): The number of observations a single group experiences
                within one logical period.
                - If the data is monthly and the period is a month: 1.
                - If the data is half-hourly but group_col is 'Time of Day': 1 (each
                  group sees only one observation per day).
            base_model: An optional, already fitted StaticTAM instance. If None, the
                adaptive model is fitted directly on its own target column.
            horizon_steps (int): The forecasting horizon (H) used to prevent target leakage.
                Enforces an information delay by truncating the last (H-1) samples from the
                training buffer of every group simultaneously.
            default_alpha_p: Default regularization strength (log10).
            group_col: Grouping column; defaults to the base model's group_col_
                (or a dummy column when there is no base model).
            date_col: Date column; defaults to the base model's date_col_
                (or a dummy column when there is no base model).
            add_base_effects: If True, appends a linear term l(effect_<feature>) to
                adaptive_formula for every base-model effect not already mentioned in it.

        Raises:
            ValueError: If base_model is provided but has not been fitted.
        """
        if base_model is not None and getattr(base_model, 'coefficients_', None) is None:
            raise ValueError("The base_model must be fitted before initializing AdaptiveTAM.")
        
        if add_base_effects and base_model is not None:
            for effect in base_model.effects_list_:
                effect_col = f"effect_{effect.feature_name}"
                if effect_col not in adaptive_formula:
                    adaptive_formula += f" + l({effect_col})"

        self.base_model_ = base_model
        self.adaptive_formula_ = adaptive_formula
        self.update_interval_periods_ = update_interval_periods
        self.training_window_periods_ = training_window_periods
        self.steps_per_period_ = steps_per_period
        self.horizon_steps_ = horizon_steps
        
        self.group_col_ = group_col or getattr(base_model, 'group_col_', "__dummy_group__")
        self.date_col_ = date_col or getattr(base_model, 'date_col_', "__dummy_date__")

        self.adaptive_model_ = StaticTAM(
            formula=adaptive_formula,
            group_col=self.group_col_,
            date_col=self.date_col_,
            default_alpha_p=default_alpha_p
        )
        
        self.coefficients_ = None
        self.norm_params_ = None
        self.unique_groups_ = None
        
        self.simulation_data_ = None
        self.predictions_ = None
        
        target_col_bm = getattr(base_model, 'target_col_', self.adaptive_model_.target_col_)
        self.target_col_ = self.adaptive_model_.target_col_ or f'Residual{target_col_bm}'

        if self.base_model_ is not None:
            base_target = getattr(self.base_model_, 'target_col_', None)
            if base_target and self.adaptive_model_.target_col_:
                expected_residual = f'Residual{base_target}'
                if self.adaptive_model_.target_col_ not in [expected_residual, base_target]:
                    warnings.warn(
                        f"Adaptive target '{self.adaptive_model_.target_col_}' does not match the base target "
                        f"'{base_target}' or its expected residual '{expected_residual}'. "
                        "This will cause a cross-target correction.",
                        UserWarning
                    )
        
        self.last_state_dict_ = None
        self.max_res_ = None
        self.min_res_ = None
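The conversion from periods to samples is a plain multiplication (`learning_size_steps = training_window_periods * steps_per_period` in `_transform_data_adaptive`, and likewise for the update interval). The hypothetical helper `window_sizes` below reproduces that arithmetic, plus the extra history required by the horizon truncation, for a half-hourly series with and without a 'Time of Day' grouping; the specific window lengths are illustrative.

```python
def window_sizes(training_window_periods: int, update_interval_periods: int,
                 steps_per_period: int, horizon_steps: int = 1) -> dict:
    """Hypothetical helper mirroring the period-to-step conversion."""
    learning_size_steps = training_window_periods * steps_per_period
    window_size_steps = update_interval_periods * steps_per_period
    return {
        "learning_size_steps": learning_size_steps,
        "window_size_steps": window_size_steps,
        # History needed before the first prediction window once the last
        # (horizon_steps - 1) samples are truncated to avoid target leakage.
        "required_history": learning_size_steps + (horizon_steps - 1),
    }


# Half-hourly data grouped by 'Time of Day': each group sees 1 sample per day.
print(window_sizes(365, 7, steps_per_period=1, horizon_steps=2))
# {'learning_size_steps': 365, 'window_size_steps': 7, 'required_history': 366}

# The same half-hourly data without a group column: 48 samples per daily period.
print(window_sizes(365, 7, steps_per_period=48, horizon_steps=48))
# {'learning_size_steps': 17520, 'window_size_steps': 336, 'required_history': 17567}
```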

Advanced PyTorch Indexing (4D Vectorization)

One of the major engineering challenges of online learning is extracting the training windows. Extracting 10,000 localized windows one at a time in a Python for loop would be dominated by interpreter overhead and, on a GPU, by the latency of launching one small kernel per slice.

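To solve this, the _transform_data_adaptive function (located in _data.py) exploits PyTorch’s advanced indexing and broadcasting. It loops only over groups; within each group, all windows are gathered at once on the CPU (prep_device = 'cpu'), and the stacked result is moved to TORCH_DEVICE at the end.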

Architectural Choices & Frequency Agnosticism: The code deliberately replaces a rigid “Calendar Day” perspective with a frequency-agnostic group coordinate system. Window lengths are expressed in abstract periods and converted to samples with steps_per_period, so each panel entity is handled identically whether the native sampling rate is monthly, daily, or half-hourly.

  1. Anchor Calculation: The code identifies all valid starting indices of the prediction windows (start_indices) by walking backward from the end of the series, discards any prediction window that would run past the end, and then restores chronological order.

  2. Causal Offset Generation (Preventing Target Leakage): It generates two constant offset tensors: predict_offsets (the relative indices of the prediction window) and train_offsets (the relative indices of the learning window). To prevent target leakage in multi-step forecasting, train_offsets is shifted backward by horizon_steps - 1 steps, so the last training sample of a window starting at index s is s - horizon_steps. This truncates the “illegal future” from the training buffer before any data is gathered.

  3. Vectorized Extraction: Adding the column vector of start indices (start_indices.view(-1, 1)) to an offset vector broadcasts a complete 2D index matrix of shape (n_windows, window length). A single advanced-indexing (gather) call, x_group[train_indices], then extracts all causally safe training windows at once as a (n_windows, num_samples_train, n_features) tensor.
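The three steps can be reproduced with NumPy, whose integer-array indexing and broadcasting behave like PyTorch’s for this pattern. This is a sketch with made-up sizes, and the anchor list `start_indices` is given directly instead of being computed; the `t_index` array only serves to show which time steps each window touches.

```python
import numpy as np

T, F = 20, 2                 # series length, number of features
L, W, H = 6, 3, 2            # train window, prediction window, horizon (in steps)
x_group = np.arange(T * F, dtype=np.float32).reshape(T, F)
t_index = np.arange(T)       # time index of each row, to inspect the windows

start_indices = np.array([7, 10, 13, 16])  # first step of each prediction window

# Causal offsets: training stops H - 1 steps before the prediction window,
# so step s - 1 (the truncated "illegal future" for H = 2) is never used.
train_end_offset = -(H - 1)
train_offsets = np.arange(train_end_offset - L, train_end_offset)
predict_offsets = np.arange(0, W)

# (n_windows, 1) + (L,) broadcasts to an (n_windows, L) index matrix.
train_indices = start_indices.reshape(-1, 1) + train_offsets
predict_indices = start_indices.reshape(-1, 1) + predict_offsets

x_train = x_group[train_indices]      # (n_windows, L, F) in one gather call
x_predict = x_group[predict_indices]  # (n_windows, W, F)

print(x_train.shape, x_predict.shape)  # (4, 6, 2) (4, 3, 2)
print(t_index[train_indices[0]], t_index[predict_indices[0]])  # [0 1 2 3 4 5] [7 8 9]
```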

src/tam/model/_data.py (Vectorized Sliding Window Extraction Without For Loops)
def _transform_data_adaptive(
    data: pd.DataFrame,
    features: List[str],
    group_col: str,
    norm_params: Dict,
    unique_groups: List,
    target_col: str,
    update_interval_periods: int,
    training_window_periods: int,
    steps_per_period: int,
    horizon_steps: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
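    Prepares data for adaptive learning using vectorized sliding window indexing.

    Builds, for every group and every simulation step, the training inputs,
    training targets, and inputs to predict.

    Args:
        data: Validation/Test DataFrame (balanced across groups).
        features: Feature list.
        group_col: Grouping column.
        norm_params: Per-group normalization parameters.
        unique_groups: Group names.
        target_col: Target column.
        update_interval_periods: Prediction window size (in periods).
        training_window_periods: Training history size (in periods).
        steps_per_period: Samples per period.
        horizon_steps: Forecasting horizon H; the H-1 samples preceding each
            prediction window are excluded from its training buffer.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (x_stacked, y_stacked, x_to_predict) with shapes
        (n_groups, n_windows, num_samples_train, n_features),
        (n_groups, n_windows, num_samples_train, 1) and
        (n_groups, n_windows, window_size_steps, n_features).

    Raises:
        ValueError: If unique_groups is None or no window can be built.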
    """
    learning_size_steps = training_window_periods * steps_per_period
    window_size_steps = update_interval_periods * steps_per_period

    all_groups_x_train = []
    all_groups_y_train = []
    all_groups_x_predict = []
    
    if unique_groups is None:
        raise ValueError("`unique_groups` cannot be None.")
    
    prep_device = 'cpu'

    for group_name in unique_groups:
        if group_name not in norm_params:
            continue
            
        data_group = data[data[group_col] == group_name].reset_index(drop=True)
        
        params = norm_params.get(group_name)
        if params is None: continue
        
        total_available_steps = len(data_group)
        required_history = learning_size_steps + (horizon_steps - 1)
        if total_available_steps <= required_history:
            continue

        # Normalize group data
        data_group[features] = normalize(df_to_normalize=data_group[features], params=params)
        
        x_group = torch.tensor(data_group[features].values, dtype=torch.float32, device=prep_device)
        y_group = torch.tensor(data_group[target_col].values, dtype=torch.float32, device=prep_device).view(-1, 1)

        # Calculate valid start indices of the prediction windows, walking
        # backward from the end of the series (reversed to chronological order below).
        start_indices_list = []
        first_predict_start = total_available_steps - (total_available_steps - learning_size_steps) % window_size_steps
        if first_predict_start == total_available_steps and total_available_steps > learning_size_steps:
            first_predict_start -= window_size_steps
        
        current_predict_start = first_predict_start
        # The earliest window needs learning_size_steps of history plus the
        # (horizon_steps - 1) truncated steps; otherwise train indices become
        # negative and silently wrap around to the end of the series.
        while current_predict_start >= required_history:
            if current_predict_start + window_size_steps <= total_available_steps:
                start_indices_list.append(current_predict_start)
            current_predict_start -= window_size_steps
        
        if not start_indices_list:
            continue
            
        start_indices_list.reverse()
        start_indices = torch.tensor(start_indices_list, device=prep_device, dtype=torch.long)

        #  Vectorized Window Indexing
        train_end_offset = -(horizon_steps - 1) if horizon_steps > 1 else 0
        train_start_offset = train_end_offset - learning_size_steps       
        train_offsets = torch.arange(train_start_offset, train_end_offset, device=prep_device)
        predict_offsets = torch.arange(0, window_size_steps, device=prep_device)

        train_indices = start_indices.view(-1, 1) + train_offsets
        predict_indices = start_indices.view(-1, 1) + predict_offsets

        #  Gather
        group_x_train = x_group[train_indices]
        group_y_train = y_group[train_indices]
        group_x_predict = x_group[predict_indices]
        
        all_groups_x_train.append(group_x_train)
        all_groups_y_train.append(group_y_train)
        all_groups_x_predict.append(group_x_predict)

    if not all_groups_x_train:
        raise ValueError("No simulation data could be generated. Check dataset length/window sizes.")

    x_stacked = torch.stack(all_groups_x_train).to(TORCH_DEVICE)
    y_stacked = torch.stack(all_groups_y_train).to(TORCH_DEVICE)
    x_to_predict = torch.stack(all_groups_x_predict).to(TORCH_DEVICE)

    return x_stacked, y_stacked, x_to_predict
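A pure-Python sketch of the backward anchor walk in the listing above, written as a hypothetical standalone function `compute_start_indices`. It uses `required_history = learning_size_steps + (horizon_steps - 1)` as the lower bound on anchors so that, once the training offsets are shifted back by the horizon, no training index can become negative (negative tensor indices would silently wrap around to the end of the series). Like the caller’s guard, it assumes the series is longer than the required history.

```python
from typing import List


def compute_start_indices(total_steps: int, learning: int, window: int,
                          horizon: int = 1) -> List[int]:
    """Backward anchor walk; assumes total_steps > learning + horizon - 1."""
    required_history = learning + (horizon - 1)
    first = total_steps - (total_steps - learning) % window
    if first == total_steps:
        first -= window
    starts: List[int] = []
    current = first
    while current >= required_history:
        # Keep only prediction windows that fit entirely inside the series.
        if current + window <= total_steps:
            starts.append(current)
        current -= window
    starts.reverse()  # back to chronological order
    return starts


print(compute_start_indices(20, learning=6, window=3))             # [6, 9, 12, 15]
print(compute_start_indices(20, learning=6, window=3, horizon=2))  # [9, 12, 15]
```

The anchors always fall on learning_size_steps + k * window_size_steps, so the last (T - L) mod W steps of the series (steps 18 and 19 in the first call) receive no adaptive prediction.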

Dynamic Target and Preparation

The prepare_simulation method is executed before launching the online loop. It converts the dataset into batched sliding-window tensors and shifts the target space from the raw series to the base model’s residuals.

Architectural Choice: It computes the macroscopic predictions (data_pred) and subtracts them from the true target to isolate the residuals, \(\epsilon_t = Y_t - \hat{Y}^{\text{base}}_t\), upfront. Doing this once, in a single pass over the entire validation history, avoids re-evaluating the base model inside the sliding-window loop, which would otherwise become a major computational bottleneck.
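A small NumPy sketch of why the residuals are computed upfront. `base_predict` is a hypothetical frozen base model that counts its own calls: it is evaluated once on the whole history, and every sliding window then slices the precomputed residual array.

```python
import numpy as np

calls = {"base": 0}


def base_predict(x: np.ndarray) -> np.ndarray:
    """Hypothetical frozen base model; counts how often it is evaluated."""
    calls["base"] += 1
    return 2.0 * x


x = np.linspace(0.0, 9.0, 10)
y = 2.0 * x + np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=float)  # regime shift at t=5

# One vectorized pass over the whole history: epsilon_t = Y_t - Yhat_t^base.
residuals = y - base_predict(x)

# Every sliding window slices the precomputed residuals; no further base calls.
window = 3
targets = [residuals[t - window:t] for t in range(window, len(y))]

print(calls["base"], len(targets))  # 1 7
```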

src/tam/model/adaptative.py (Batch Residual Calculation and Tensorization)
    def prepare_simulation(self, data: pd.DataFrame) -> 'AdaptiveTAM':
        r"""
        Prepares tensors for the adaptive sliding-window simulation.

        This process:
        1. Computes base model effects (features) and residuals (targets).
        2. Normalizes the adaptive features.
        3. Constructs sliding-window tensors (X_train, Y_train, X_predict).

        Args:
            data: The dataset (validation or test) for the simulation.

        Returns:
            self: The instance with populated ``simulation_data_``.
        """
        
        if self.base_model_ is not None:
            data_bm = self.base_model_.decompose_prediction(data) 
            data_pred = self.base_model_.predict(data)
            target_col_bm = self.base_model_.target_col_
        else:
            data_bm = data.copy()
            data_pred = data.copy()
            target_col_bm = self.target_col_
       
        data_bm = _ensure_dummies(data_bm, self.group_col_, self.date_col_)
        data_pred = _ensure_dummies(data_pred, self.group_col_, self.date_col_)
        
        cols_float = data_bm.select_dtypes(include=['float64']).columns
        data_bm[cols_float] = data_bm[cols_float].astype('float32')

        est_col_name = f'Estimated{target_col_bm}'

        if est_col_name not in data_bm.columns:
            data_bm[est_col_name] = data_pred.get(est_col_name, 0.0)

        default_res_col = f'Residual{target_col_bm}'
        data_bm[default_res_col] = data_bm[target_col_bm] - data_bm[est_col_name]
        
        adaptive_features_config = self.adaptive_model_.features_config_
        
        if 'features' not in adaptive_features_config or not isinstance(adaptive_features_config['features'], list):
            raise KeyError("Invalid feature configuration in adaptive model.")
        
        adaptive_features = adaptive_features_config['features']
        
        required_cols = adaptive_features + [self.group_col_, self.target_col_, self.date_col_]
        _check_features(dataset=data_bm, required_features=required_cols)
        mask, balanced_data = _balance_groups(
            dataset=data_bm, group_col=self.group_col_, date_col=self.date_col_, method="fill"
        )
        
        data_info = self.adaptive_model_._get_data_info(balanced_data)

        if not self.adaptive_model_.effects_list_ and not self.adaptive_model_.is_grid_search_template_:
            self.adaptive_model_.effects_list_ = create_effects_from_parsed_terms(
                self.adaptive_model_.parsed_terms_, 
                token_values={}, 
                default_alpha_p=self.adaptive_model_.default_alpha_p_,
                data_info=data_info
            )

        self.norm_params_, self.unique_groups_ = _fit_normalization_params(
            data=balanced_data, 
            features=adaptive_features, 
            group_col=self.group_col_
        )
        
        hw.empty_cache()

        x_stacked, y_stacked, x_to_predict = _transform_data_adaptive(
            data=balanced_data, 
            features=adaptive_features, 
            group_col=self.group_col_,
            norm_params=self.norm_params_,
            unique_groups=self.unique_groups_,
            target_col=self.target_col_,
            update_interval_periods=self.update_interval_periods_,
            training_window_periods=self.training_window_periods_,
            steps_per_period=self.steps_per_period_,
            horizon_steps=self.horizon_steps_
        )
        
        self.simulation_data_ = (
            x_stacked.cpu(), 
            y_stacked.cpu(), 
            x_to_predict.cpu(), 
            balanced_data, data_bm, target_col_bm, mask
        )
        del x_stacked, y_stacked, x_to_predict
        hw.empty_cache()

        return self
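`_transform_data_adaptive` ends with `torch.stack` over groups, which requires every group to produce the same number of windows. This is presumably why `prepare_simulation` first balances the panel with `_balance_groups(..., method="fill")` and keeps the returned `mask` to drop the filled rows from the final predictions. The NumPy sketch below (np.stack has the same identical-shape requirement) shows the failure and a hypothetical zero-fill-plus-mask remedy; the actual fill strategy of `_balance_groups` is not shown in this chapter and may differ.

```python
import numpy as np

# Two groups; group B has 3 fewer time steps than group A.
group_a = np.arange(12, dtype=np.float32)
group_b = np.arange(9, dtype=np.float32)


def windows(series: np.ndarray, size: int) -> np.ndarray:
    """Non-overlapping windows of `size` steps, shape (n_windows, size)."""
    n = len(series) // size
    return series[: n * size].reshape(n, size)


try:
    np.stack([windows(group_a, 3), windows(group_b, 3)])  # (4, 3) vs (3, 3)
    stacked_ok = True
except ValueError:  # np.stack, like torch.stack, needs identical shapes
    stacked_ok = False

# Hypothetical "fill" balancing: pad B to A's length and remember real rows.
padded_b = np.zeros_like(group_a)
padded_b[: len(group_b)] = group_b
mask_b = np.arange(len(group_a)) < len(group_b)

stacked = np.stack([windows(group_a, 3), windows(padded_b, 3)])
print(stacked_ok, stacked.shape, int(mask_b.sum()))  # False (2, 4, 3) 9
```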

Batch Execution and Numerical Safety (Clipping)

In the simulation() method, windows are not processed one at a time. Each batched solver call resolves the exact local empirical risk minimization problem for many sliding windows at once, across all groups.

Architectural Choice 1: Dimensional Flattening: The generated tensors are 4-dimensional, (n_groups, n_windows, num_samples, features), whereas the core solver (solve_linear_system) expects 3D batches (batch, samples, features). Instead of rewriting the core linear algebra module, simulation() flattens the first two dimensions, total_items = n_groups * n_windows. Each window becomes an independent item of a standard batch, which keeps the GPU busy with large, uniform workloads.
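A NumPy sketch of the flatten, solve, reshape pattern. It assumes a plain ridge penalty `lam * I` in place of the Sobolev penalty matrix and a 1/n scaling of the covariances; the exact scaling used by `solve_linear_system` is not shown here. `np.linalg.solve` broadcasts over the leading batch axis, so each of the `n_groups * n_windows` windows is solved as an independent system in a single call.

```python
import numpy as np

rng = np.random.default_rng(0)
n_groups, n_windows, n_samples, n_feat = 3, 4, 50, 5
lam = 1e-2  # assumed ridge strength; a stand-in for the Sobolev penalty matrix

x_stacked = rng.normal(size=(n_groups, n_windows, n_samples, n_feat))
true_w = rng.normal(size=(n_groups, n_windows, n_feat, 1))
y_stacked = x_stacked @ true_w  # noiseless targets, shape (..., n_samples, 1)

# Flatten (n_groups, n_windows) into one batch axis, as simulation() does.
total_items = n_groups * n_windows
x_flat = x_stacked.reshape(total_items, n_samples, n_feat)
y_flat = y_stacked.reshape(total_items, n_samples, 1)

# Batched normal equations (X^T X / n + lam * I) w = X^T y / n, one per item.
cov_x = np.swapaxes(x_flat, 1, 2) @ x_flat / n_samples
cov_xy = np.swapaxes(x_flat, 1, 2) @ y_flat / n_samples
coeffs = np.linalg.solve(cov_x + lam * np.eye(n_feat), cov_xy)  # (total_items, n_feat, 1)

# Restore the (group, window) layout.
coeffs_4d = coeffs.reshape(n_groups, n_windows, n_feat, 1)
print(coeffs.shape, coeffs_4d.shape)  # (12, 5, 1) (3, 4, 5, 1)
```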

Architectural Choice 2: OOM Resilience: Because total_items can easily exceed tens of thousands of matrices, batches are sized by get_safe_window_batch_size and guarded by the hw.handle_oom fallback. When GPU memory is exhausted, the loop catches torch.OutOfMemoryError (or MemoryError), lets hw.handle_oom halve safe_batch_size (with allow_cpu_fallback=True it may also move the computation to the CPU), and retries the same batch, so the simulation completes even on constrained hardware.
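The retry logic can be sketched without a GPU. `run_in_batches` and `fake_kernel` are hypothetical; `MemoryError` stands in for `torch.OutOfMemoryError`, and the halving rule follows the description above rather than the actual body of `hw.handle_oom`, which is not shown (its CPU fallback is not modeled here).

```python
from typing import Callable, List, Tuple


def run_in_batches(total_items: int, batch_size: int,
                   process: Callable[[int, int], List[int]]) -> Tuple[List[int], int]:
    """Process items [0, total_items) in batches, halving the batch size on OOM.

    Returns (results, final_batch_size).
    """
    results: List[int] = []
    start = 0
    while start < total_items:
        end = min(start + batch_size, total_items)
        try:
            results.extend(process(start, end))
            start = end  # advance only after the batch succeeded
        except MemoryError:
            if batch_size == 1:
                raise  # cannot shrink any further
            batch_size //= 2  # retry the same start index with a smaller batch
    return results, batch_size


def fake_kernel(start: int, end: int) -> List[int]:
    """Pretend device that can only hold 16 items at once."""
    if end - start > 16:
        raise MemoryError("simulated out-of-memory")
    return list(range(start, end))


results, final_size = run_in_batches(100, batch_size=64, process=fake_kernel)
print(len(results), final_size)  # 100 16
```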

Architectural Choice 3: Algorithmic Safeguard (Target Clipping): An adaptive model can diverge if it trains on a window dominated by anomalous values (e.g., a sensor failure). As a safeguard, the implementation applies bounded clipping to the residual correction: the predicted correction Estimated{target_col_} is clipped to the minimum and maximum of the residuals observed in the simulation data (data_bm[self.target_col_].min() and .max()) before being added to the base prediction (or used directly, when the adaptive model targets the raw series). This prevents the corrector from extrapolating unrealistic deviations during unexpected exogenous shocks.
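A NumPy sketch of the clipping order, with made-up values: the bounds come from the observed residuals, only the correction is clipped, and a missing correction (NaN) falls back to zero, mirroring the fillna(0) in simulation(). np.clip leaves NaN values untouched, which is why the fallback is applied after clipping.

```python
import numpy as np

observed_residuals = np.array([-1.5, 0.2, 0.8, 2.0, -0.4], dtype=np.float32)
max_res, min_res = observed_residuals.max(), observed_residuals.min()

base_pred = np.array([10.0, 10.0, 10.0, 10.0], dtype=np.float32)
# Corrections from the windows; 9.0 and -7.0 come from an anomalous window,
# NaN marks a step that received no correction.
correction = np.array([0.5, 9.0, -7.0, np.nan], dtype=np.float32)

# Clip only the residual correction, then add it back to the base prediction.
clipped = np.clip(correction, min_res, max_res)
adapted = base_pred + np.nan_to_num(clipped, nan=0.0)
print(adapted)
```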

src/tam/model/adaptative.py (Batch Linear Resolution, OOM Handling, and Safety Clipping)
    def simulation(self) -> pd.DataFrame:
        r"""
        Executes the sliding-window simulation using scalable batch processing.
        
        It flattens the group and window dimensions to treat each sliding window 
        as an independent linear system. This strictly preserves the mathematical 
        regularization scale and prevents VRAM exhaustion when processing deep 
        historical data across multiple groups.
        """
        if self.simulation_data_ is None:
            raise RuntimeError("Simulation data is uninitialized. Call 'prepare_simulation()' first.")
            
        if self.adaptive_model_.is_grid_search_template_:
            raise RuntimeError("Model contains grid search tokens. Call 'grid_search_fit()' first.")

        x_stacked, y_stacked, x_to_predict, balanced_data, data_bm, target_col_bm, mask = self.simulation_data_
        
        n_groups = x_stacked.shape[0]
        n_windows = x_stacked.shape[1]
        num_samples_train = x_stacked.shape[2]
        window_size_steps = x_to_predict.shape[2]
        
        total_items = n_groups * n_windows
        
        x_flat = x_stacked.view(total_items, num_samples_train, -1)
        y_flat = y_stacked.view(total_items, num_samples_train, -1)
        x_pred_flat = x_to_predict.view(total_items, window_size_steps, -1)
        
        run_device = TORCH_DEVICE
        
        sobolev_matrix = self.adaptive_model_._build_penalty_matrix().to(run_device)
        loss_L_star_L = self.adaptive_model_._build_loss_matrix().to(run_device)

        dummy_x = x_flat[0:1].to(run_device)
        dummy_phi = self.adaptive_model_._build_design_matrix(dummy_x)
        n_coeffs = dummy_phi.shape[-1]
        del dummy_x, dummy_phi
        
        safe_batch_size = get_safe_window_batch_size(
            num_samples_per_window=num_samples_train,
            total_d=n_coeffs,
            device=run_device
        )
        safe_batch_size = min(safe_batch_size, total_items)
        all_predictions = []
        
        start_idx = 0
        while start_idx < total_items:
            end_idx = min(start_idx + safe_batch_size, total_items)
            
            try:
                batch_x = x_flat[start_idx:end_idx].to(run_device)
                batch_y = y_flat[start_idx:end_idx].to(run_device)
                batch_x_pred = x_pred_flat[start_idx:end_idx].to(run_device)

                phi_batch = self.adaptive_model_._build_design_matrix(batch_x)

                cov_X, cov_XY = _compute_weighted_covariances(phi_batch, batch_y, loss_L_star_L)
                coeffs_batch = solve_linear_system(cov_X, cov_XY, sobolev_matrix, num_samples_train)
                
                phi_pred_batch = self.adaptive_model_._build_design_matrix(batch_x_pred)
                preds_batch = _predict_from_coeffs(phi_pred_batch, coeffs_batch)
                
                all_predictions.append(preds_batch.detach().cpu())
                
                del batch_x, batch_y, batch_x_pred, phi_batch, cov_X, cov_XY, coeffs_batch, phi_pred_batch, preds_batch
                start_idx += safe_batch_size
                
            except (torch.OutOfMemoryError, MemoryError):
                safe_batch_size, run_device = hw.handle_oom(
                    current_batch=safe_batch_size, 
                    device=run_device, 
                    context="adaptive simulation batch reduction", 
                    allow_cpu_fallback=True
                )
                
                sobolev_matrix = sobolev_matrix.to(run_device)
                loss_L_star_L = loss_L_star_L.to(run_device)
                continue

        hw.empty_cache()

        predictions_flat = torch.cat(all_predictions, dim=0).squeeze(-1)
        predictions_cpu = predictions_flat.view(n_groups, n_windows, window_size_steps)

        data_with_predictions = _reassemble_predictions(
            original_data=balanced_data, 
            predictions_stacked=predictions_cpu,
            group_col=self.group_col_,
            unique_groups=self.unique_groups_, 
            target_col=self.target_col_
        )

        max_res = np.float32(data_bm[self.target_col_].max())
        min_res = np.float32(data_bm[self.target_col_].min())
        est_col = f'Estimated{self.target_col_}'
        
        data_with_predictions.loc[data_with_predictions[est_col] >= max_res, est_col] = max_res
        data_with_predictions.loc[data_with_predictions[est_col] <= min_res, est_col] = min_res
                
        adapted_col = f"AdaptedEstimated{target_col_bm}"
        if self.target_col_ == target_col_bm:
            data_with_predictions[adapted_col] = data_with_predictions[est_col].fillna(0)
        else:
            data_with_predictions[adapted_col] = (
                data_with_predictions[f'Estimated{target_col_bm}'] + 
                data_with_predictions[est_col].fillna(0)
            )
                                                 
        self.predictions_ = _cleanup_dummies(data_with_predictions[mask], self.group_col_, self.date_col_)
        return self.predictions_

Coordinate Descent Search Algorithm

The grid_search_fit() method implements the “Multi-Start” hyperparameter solver described in the theory. It optimizes the grid tokens declared in the adaptive formula and grid_search_config (e.g., regularization strengths \(\lambda\)), moving along a single axis at a time. Each start point stops when a full cycle over all axes brings no improvement, or after at most five cycles.

Architectural Choice: The objective (the RMSE of the adapted prediction aggregated over many overlapping sliding windows) is generally non-convex in the hyperparameters, and analytical criteria such as Generalized Cross-Validation (GCV), which score a single fit, do not apply directly. Coordinate descent over the discrete grid is a pragmatic alternative to exhaustive search: each cycle evaluates a number of configurations proportional to the sum of the axis lengths rather than their product. To reduce the risk of stopping in a poor local minimum, the algorithm runs the cyclic axis search from three initializations (Conservative, Median, Aggressive) and keeps the best result.
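A pure-Python sketch of multi-start coordinate descent over a discrete grid. It simplifies the listing below in one respect: the three starts are the plain min, median, and max corners, whereas grid_search_fit flips min/max for regularization tokens (names containing 'ap' or 'lambda_p'). The toy objective and its token names are hypothetical stand-ins for the simulation RMSE.

```python
from typing import Callable, Dict, List, Tuple

Tokens = Dict[str, float]


def coordinate_descent(objective: Callable[[Tokens], float],
                       axes: Dict[str, List[float]],
                       start: Tokens, max_cycles: int = 5) -> Tuple[Tokens, float]:
    """Greedy axis-by-axis search over a discrete grid from one start point."""
    best, best_score = dict(start), objective(start)
    for _ in range(max_cycles):
        improved = False
        for name, values in axes.items():
            for value in values:
                if value == best[name]:
                    continue
                trial = {**best, name: value}  # change one axis only
                score = objective(trial)
                if score < best_score:
                    best, best_score, improved = trial, score, True
        if not improved:
            break
    return best, best_score


def multi_start(objective: Callable[[Tokens], float],
                axes: Dict[str, List[float]], max_cycles: int = 5) -> Tuple[Tokens, float]:
    """Run coordinate descent from the min, median and max corners; keep the best."""
    starts = [
        {k: min(v) for k, v in axes.items()},
        {k: v[len(v) // 2] for k, v in axes.items()},
        {k: max(v) for k, v in axes.items()},
    ]
    runs = [coordinate_descent(objective, axes, s, max_cycles) for s in starts]
    return min(runs, key=lambda run: run[1])


def toy_rmse(t: Tokens) -> float:
    """Hypothetical objective with its minimum at alpha_p=-3, n_knots=10."""
    return (t["alpha_p"] + 3.0) ** 2 + (t["n_knots"] - 10.0) ** 2 / 25.0


axes = {"alpha_p": [-9.0, -6.0, -3.0, 0.0], "n_knots": [5.0, 10.0, 20.0]}
best_tokens, best_score = multi_start(toy_rmse, axes)
print(best_tokens, best_score)  # {'alpha_p': -3.0, 'n_knots': 10.0} 0.0
```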

src/tam/model/adaptative.py (Coordinate Descent Algorithm for Hyperparameters)
    def grid_search_fit(
            self,
            data_val: pd.DataFrame,
            grid_search_config: dict
        ) -> 'AdaptiveTAM':
            r"""
            Optimizes the adaptive model using Multi-Start Coordinate Descent.

            Optimizes hyperparameters to minimize the RMSE of the final
            adapted prediction over the simulation period.

            Args:
                data_val: Validation DataFrame for simulation.
                grid_search_config: Dictionary mapping tokens to value lists.

            Returns:
                AdaptiveTAM: A new fitted model instance.
            """
            print("--- Starting Grid Search (Multi-Start Coordinate Descent) ---")
            
            self.prepare_simulation(data_val)
            
            search_axes, token_names = self.adaptive_model_._parse_grid_axes(grid_search_config)
            
            data_info = self.adaptive_model_._get_data_info(data_val)

            if not token_names:
                raise ValueError("No grid tokens found in adaptive formula or config.")

            print(f"Optimizing axes: {token_names}")
            
            start_points = [
                {"name": "Conservative", "tokens": {t: max(vals) if ('ap' in t or 'lambda_p' in t) else min(vals) for t, vals in search_axes.items()}},
                {"name": "Median", "tokens": {t: vals[len(vals)//2] for t, vals in search_axes.items()}},
                {"name": "Aggressive", "tokens": {t: min(vals) if ('ap' in t or 'lambda_p' in t) else max(vals) for t, vals in search_axes.items()}}
            ]

            min_global_rmse = float('inf')
            optimal_effects_list = None
            current_best_tokens_global = None

            for strategy in start_points:
                print(f"\n=== Strategy: {strategy['name']} Start ===")
                current_best_tokens = strategy["tokens"].copy()
                
                try:
                    start_effects_list = create_effects_from_parsed_terms(
                        self.adaptive_model_.parsed_terms_,
                        current_best_tokens,
                        self.adaptive_model_.default_alpha_p_,
                        data_info=data_info
                    )
                    current_rmse = self._evaluate_adaptive_config(start_effects_list, current_best_tokens)
                    current_optimal_effects = start_effects_list
                except Exception:
                    continue
                    
                if current_rmse >= float('inf'):
                    continue
                
                print(f"  Start RMSE: {current_rmse:.4f}")

                cycle = 0
                while True:
                    cycle += 1
                    has_improved_in_cycle = False
                    
                    for token_name in token_names:
                        best_value = current_best_tokens[token_name]
                        original_val = best_value
                        
                        for value in search_axes[token_name]:
                            if value == original_val: continue 
                                
                            tokens_to_test = current_best_tokens.copy()
                            tokens_to_test[token_name] = value
                            
                            try:
                                effects_list = create_effects_from_parsed_terms(
                                    self.adaptive_model_.parsed_terms_,
                                    tokens_to_test,
                                    self.adaptive_model_.default_alpha_p_,
                                    data_info=data_info
                                )
                                rmse = self._evaluate_adaptive_config(effects_list, tokens_to_test)

                                if rmse < current_rmse:
                                    current_rmse = rmse
                                    current_optimal_effects = effects_list
                                    best_value = value
                                    has_improved_in_cycle = True
                            except Exception:
                                continue
                        
                        current_best_tokens[token_name] = best_value
                    
                    if not has_improved_in_cycle or cycle >= 5: break

                if current_rmse < min_global_rmse:
                    print(f"  >>> New Global Best found by {strategy['name']}! ({current_rmse:.4f})")
                    min_global_rmse = current_rmse
                    optimal_effects_list = current_optimal_effects
                    current_best_tokens_global = current_best_tokens

            print("-" * 30)
            print(f"Grid Search complete. Best RMSE: {min_global_rmse:.4f}")
            
            if optimal_effects_list is None:
                raise RuntimeError("Grid search failed.")

            print(f"Best tokens: {current_best_tokens_global}")

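            final_model = AdaptiveTAM(
                base_model=self.base_model_,
                adaptive_formula=self.adaptive_formula_,
                update_interval_periods=self.update_interval_periods_,
                training_window_periods=self.training_window_periods_,
                steps_per_period=self.steps_per_period_,
                horizon_steps=self.horizon_steps_,
                # Forward the resolved columns; without a base_model the new
                # instance would otherwise fall back to the dummy columns.
                group_col=self.group_col_,
                date_col=self.date_col_
            )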
            
            final_adaptive_model_internal = StaticTAM(
                formula=self.adaptive_model_.formula_,
                group_col=self.base_model_.group_col_ if self.base_model_ is not None else self.group_col_,
                date_col=self.base_model_.date_col_ if self.base_model_ is not None else self.date_col_,
                _internal_effects_list=optimal_effects_list,
                _internal_features_config=self.adaptive_model_.features_config_
            )
            
            final_model.adaptive_model_ = final_adaptive_model_internal
            final_model.norm_params_ = self.norm_params_
            final_model.unique_groups_ = self.unique_groups_
            final_model.target_col_ = self.target_col_

            final_model.predict_online(data_val)
            
            return final_model

Separation of Concerns: Simulation vs. Inference

To guarantee speed and safety in operational production pipelines, AdaptiveTAM strictly separates the continuous historical simulation from out-of-sample inference.

Architectural Choice (Inference Cost Independent of History Length):

  • fit(data): Instead of running the full sliding-window simulation over the entire dataset, fit() relies on the _save_final_state() method. This method slices only the final available training window from the 4D tensor, solves the exact primal linear system for that single window, and freezes the optimal coefficients (last_state_dict_) along with the historical safety clipping bounds.

  • predict(data): A deterministic, read-only method. It builds the design matrix \(\Phi\) for the new data and multiplies it by the frozen coefficients. Because it bypasses sliding-window tensorization and system resolution entirely, its cost is \(O(1)\) with respect to the length of the history (it grows only with the number of rows to predict), and because it only reads the frozen coefficients, no target value of the data being predicted can influence its output.
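A NumPy sketch of the fit/predict split for one group, assuming an identity feature map (Φ = X) and a plain ridge solve in place of the primal system with its Sobolev penalty. `last_state` plays the role of last_state_dict_. Solving only the final window gives the same coefficients as the last item of a full batched solve, which is what makes the shortcut valid.

```python
import numpy as np

rng = np.random.default_rng(1)
n_windows, n_samples, n_feat, lam = 5, 40, 3, 1e-3  # lam: assumed ridge strength

x_train = rng.normal(size=(n_windows, n_samples, n_feat))  # one group, all windows
y_train = x_train @ rng.normal(size=(n_windows, n_feat, 1))


def solve(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Ridge normal equations; works on one window or a batch of windows."""
    xt = np.swapaxes(x, -1, -2)
    eye = np.eye(x.shape[-1])
    return np.linalg.solve(xt @ x / x.shape[-2] + lam * eye, xt @ y / x.shape[-2])


# fit(): solve only the final window and freeze the result.
last_state = {"coeffs": solve(x_train[-1], y_train[-1])}


def predict(x_new: np.ndarray, state: dict) -> np.ndarray:
    """Read-only inference: one matrix product against the frozen coefficients."""
    return x_new @ state["coeffs"]


x_new = rng.normal(size=(8, n_feat))
print(predict(x_new, last_state).shape)  # (8, 1)
```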