Auto-ML via Generalized Cross-Validation (GCV)

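Hyperparameter tuning via standard Grid Search is computationally expensive. Every candidate combination of penalties requires a full refit, and the number of combinations grows exponentially with the number of effects. Earlier versions relied heavily on “Multi-Start” Coordinate Descent for all parameters. The framework now fully automates the continuous penalty optimization via StaticTAM.auto_fit().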

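This engine relies on the smart_solve_gcv dispatcher, which implements Golub’s trace trick [Golub et al., 1979]. The Generalized Cross-Validation (GCV) score approximates leave-one-out prediction error from a single fit, using the trace of the hat matrix instead of K-fold retraining. That makes it cheap enough to tune a separate regularization parameter (\(\lambda\)) for every effect in the formula.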
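For reference, the standard GCV criterion for a penalized least-squares fit with \(n\) observations is given below. Here \(\gamma\) inflates the effective degrees of freedom \(\mathrm{tr}\,A(\lambda)\) as described for the gamma argument of auto_fit, and \(\gamma = 1\) recovers the original criterion of Golub et al. This is the textbook form; the exact normalization used inside compute_gcv_score is not shown in this section.

\[
A(\lambda) = \Phi\left(\Phi^T \Phi + P(\lambda)\right)^{-1}\Phi^T,
\qquad
\mathrm{GCV}(\lambda) = \frac{n\,\lVert y - A(\lambda)\,y \rVert^2}{\left(n - \gamma\,\mathrm{tr}\,A(\lambda)\right)^2}
\]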


The Orchestrator and API Entry Point

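The auto_fit method is the user-facing entry point. It ensures the group and date dummy columns are present, balances the groups (method="drop"), converts the frame into training tensors, builds the loss matrix, and hands everything to smart_solve_gcv for Multiple Smoothing Parameter (MSP) estimation. The selected \(\lambda\) values are then written back to each effect's lambda_p.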

src/tam/model/additive.py (MSP-GCV Auto-Fitting Loop)
    def auto_fit(
        self, 
        data_train: pd.DataFrame, 
        alpha_p_bounds: Tuple[float, float] = (-30.0, 6.0),
        number_of_steps: int = 10,
        alpha_p_list: Optional[List[float]] = None,
        gamma: float = 1.4,
        verbose: bool = False
    ) -> 'StaticTAM':
        """
        Automatically trains the model by finding the optimal regularization per effect.
        
        Uses Generalized Cross Validation (GCV) with Multiple Smoothing Parameter (MSP)
        estimation to balance model fit (bias) and complexity (variance) without 
        requiring a validation set. 
        
        It utilizes a Coordinate Descent algorithm starting from the formula's initial values.
        
        Args:
            data_train: Training DataFrame.
            alpha_p_bounds: Bounds on log10(lambda_p) for the local steps (default -30.0 to 6.0).
            number_of_steps: Number of subdivisions of alpha_p_bounds; sets the step size.
            alpha_p_list: Explicit list of log10(lambda_p) coordinates to test. If provided,
                        alpha_p_bounds and number_of_steps are ignored.
            gamma: Inflation factor for the effective degrees of freedom. Values > 1.0
                   (typically 1.4 to 1.5) force smoother models, reducing GCV's
                   tendency to overfit when the data has autocorrelated errors.
            verbose: If True, print solver diagnostics.
            
        Returns:
            self: The trained model instance with optimal lambda_ps applied.
        """
        print("--- Auto-Fitting with MSP-GCV (Multiple Smoothing Parameters) ---")

        data_train = _ensure_dummies(data_train, self.group_col_, self.date_col_)
        
        _, balanced_data = _balance_groups(
            dataset=data_train, group_col=self.group_col_, date_col=self.date_col_, method="drop"
        )
        x_train, y_train, self.unique_groups_ = self._prepare_data(
            balanced_data, target_col=self.target_col_
        )
        
        loss_L = self._build_loss_matrix()
        
        self.coefficients_, best_lambda_ps, gcv_score = smart_solve_gcv(
            x_data=x_train,
            y_data=y_train,
            effects_list=self.effects_list_,
            loss_matrix=loss_L,
            alpha_p_bounds=alpha_p_bounds,
            number_of_steps=number_of_steps,
            alpha_p_list=alpha_p_list,
            gamma=gamma,
            verbose=verbose
        )
        
        print(f"\nFinal GCV Score: {gcv_score:.4f}")
        print("Optimal lambda_ps found per effect:")
        for i, effect in enumerate(self.effects_list_):
            effect.lambda_p = best_lambda_ps[i]
            print(f" - {effect.feature_name}: {best_lambda_ps[i]:.2e} (log10 = {np.log10(best_lambda_ps[i]):.2f})")
        
        return self
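The effect of the gamma argument is easiest to see on a toy problem. The sketch below is an illustration under stated assumptions: it uses plain ridge regression (P = λI), forms the hat matrix explicitly, and assumes the GCV form n·RSS/(n − γ·tr A)². The helpers gcv_ridge and select_log_lambda are hypothetical and not part of the package.

```python
import numpy as np


def gcv_ridge(phi: np.ndarray, y: np.ndarray, lam: float, gamma: float) -> float:
    """GCV score of a ridge fit (assumed penalty P = lam * I), via the explicit hat matrix."""
    n, d = phi.shape
    # Hat matrix A = Phi (Phi^T Phi + lam I)^{-1} Phi^T; fine for a small toy problem.
    a = phi @ np.linalg.solve(phi.T @ phi + lam * np.eye(d), phi.T)
    rss = float(np.sum((y - a @ y) ** 2))
    edf = float(np.trace(a))  # effective degrees of freedom
    return n * rss / (n - gamma * edf) ** 2


def select_log_lambda(phi: np.ndarray, y: np.ndarray, log_grid: np.ndarray, gamma: float) -> float:
    """Return the grid value of log10(lambda) with the lowest GCV score (first on ties)."""
    scores = [gcv_ridge(phi, y, 10.0 ** a, gamma) for a in log_grid]
    return float(log_grid[int(np.argmin(scores))])


rng = np.random.default_rng(0)
n, d = 60, 20
phi = rng.normal(size=(n, d))
y = phi[:, :3] @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=1.0, size=n)
log_grid = np.linspace(-4.0, 4.0, 33)

best_plain = select_log_lambda(phi, y, log_grid, gamma=1.0)
best_inflated = select_log_lambda(phi, y, log_grid, gamma=1.4)
print(best_plain, best_inflated)  # gamma=1.4 never selects a smaller lambda
```

Because a larger γ charges each effective degree of freedom more heavily, the selected λ can only stay the same or move toward stronger smoothing. This is the behavior the gamma docstring describes.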

VRAM Protection and Chunked Covariances

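The smart_solve_gcv function is the exact algebraic engine. GCV needs the trace of the \(n \times n\) hat matrix \(A = \Phi(\Phi^T\Phi + P)^{-1}\Phi^T\). By the cyclic property of the trace, \(\mathrm{tr}\,A = \mathrm{tr}\big[(\Phi^T\Phi + P)^{-1}\Phi^T\Phi\big]\), which involves only \(D \times D\) matrices. The cost is that the covariance matrix \(\Phi^T \Phi\) must be instantiated explicitly and densely. For this reason, feature dimensions above \(D = 7500\) are rejected with a MemoryError that points the user to grid_search_fit().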
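A minimal NumPy sketch of why the cached statistics are sufficient, assuming a single output and the standard GCV form. The names gcv_from_covariances and gcv_explicit are illustrative and are not the package's compute_gcv_score. The first function uses only cov_X = ΦᵀΦ, cov_XY = Φᵀy and Y_sq = yᵀy. The second forms the n × n hat matrix as a reference.

```python
import numpy as np


def gcv_from_covariances(cov_x, cov_xy, y_sq, penalty, n_samples, gamma=1.0):
    """GCV score from sufficient statistics only (no n x n matrices).

    cov_x = Phi^T Phi (D x D), cov_xy = Phi^T y (D,), y_sq = y^T y (scalar).
    """
    system = cov_x + penalty
    beta = np.linalg.solve(system, cov_xy)
    # ||y - Phi beta||^2 expanded in terms of the cached statistics.
    rss = y_sq - 2.0 * beta @ cov_xy + beta @ cov_x @ beta
    # Cyclic trace: tr(Phi S^{-1} Phi^T) = tr(S^{-1} Phi^T Phi), a D x D computation.
    edf = np.trace(np.linalg.solve(system, cov_x))
    return n_samples * rss / (n_samples - gamma * edf) ** 2


def gcv_explicit(phi, y, penalty, gamma=1.0):
    """Reference implementation that forms the n x n hat matrix directly."""
    n = phi.shape[0]
    hat = phi @ np.linalg.solve(phi.T @ phi + penalty, phi.T)
    rss = np.sum((y - hat @ y) ** 2)
    return n * rss / (n - gamma * np.trace(hat)) ** 2


rng = np.random.default_rng(1)
n, d = 200, 8
phi = rng.normal(size=(n, d))
y = phi @ rng.normal(size=d) + 0.3 * rng.normal(size=n)
penalty = 2.0 * np.eye(d)

fast = gcv_from_covariances(phi.T @ phi, phi.T @ y, y @ y, penalty, n, gamma=1.4)
slow = gcv_explicit(phi, y, penalty, gamma=1.4)
print(np.isclose(fast, slow))  # True
```

After the statistics are cached, each evaluation costs one D × D solve, whatever the number of samples.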

To avoid Out-Of-Memory (OOM) faults during the iterative search, the engine reads the free memory reported by hw.get_available_memory(). It estimates each group's footprint as five dense \(D \times D\) float64 matrices and selects one of three memory tiers for the Coordinate Descent:

  1. The Caching Fast-Path (< 30% Limit): If the unpenalized statistics (cov_X, cov_XY and Y_sq) for all spatial groups need less than \(30\%\) of available VRAM, they are computed once and kept resident on the GPU. Every GCV evaluation then works directly on the cached \(D \times D\) matrices. The \(\mathcal{O}(G \times T \times D^2)\) tensor product is never repeated, and no host-device transfers occur during the search phase.

  2. The Iterative Accumulation (< 40% Limit): If a global GPU cache is impossible but a single group's footprint is below \(40\%\) of VRAM, the covariances are computed group by group and parked in host memory, with the GPU cache freed after each group. Each GCV evaluation then copies the groups back to the GPU one at a time, accumulates their scores, and returns the average.

  3. Extreme CPU Offloading (> 40% Fallback): This tier applies when a single group's footprint exceeds the \(40\%\) VRAM safety threshold, which is frequently triggered by massive Neural or Tensor interactions. In that case the orchestrator sets eval_device = torch.device('cpu') before the search begins, so every GCV evaluation and the final per-group solves run in the host system’s RAM. The per-group covariances are still assembled on run_device by _get_chunked_covs and then moved to the host.
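The tier thresholds above can be sketched as a pure function. choose_memory_tier and its tier labels are hypothetical. available_bytes stands in for hw.get_available_memory(), and the D > 7500 guard that runs first is not reproduced.

```python
def choose_memory_tier(total_d: int, n_groups: int, available_bytes: float) -> str:
    """Mirror of the tier thresholds used by smart_solve_gcv (sketch only)."""
    # Heuristic footprint: five dense D x D float64 (8-byte) matrices per group.
    bytes_per_group = total_d * total_d * 8 * 5
    total_bytes = bytes_per_group * n_groups
    if total_bytes < 0.3 * available_bytes:
        return "cache_on_gpu"   # Tier 1: cache every group's covariances once
    if bytes_per_group < 0.4 * available_bytes:
        return "per_group_gpu"  # Tier 2: evaluate GCV group by group on the GPU
    return "cpu"                # Tier 3: evaluate GCV in host RAM
```

For example, with 8 GiB free, 100 groups at D = 500 (about 10 MB each) are cached. At D = 2000 (about 160 MB each), the same 100 groups fall to per-group GPU evaluation. A single group at D = 7000 with only 4 GiB free is routed to the CPU.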

Each effect carries its own regularization weight (e.g., one \(\lambda\) for a spline effect, another for a Fourier effect). The global penalty matrix \(P\) must therefore be rebuilt on every evaluation of the GCV objective.

Before the search, the dispatcher stores each effect's coefficient index range (start, end) and its dense penalty matrix \(S_j\) in a blocks list. Each evaluation of the GCV objective zeroes \(P\) and writes the scaled penalties into their diagonal blocks only, \(P(\alpha) = \mathrm{blockdiag}\big(10^{\alpha_1} S_1, \ldots, 10^{\alpha_m} S_m\big)\) with \(\alpha_j = \log_{10}\lambda_j\). Distinct structural penalties therefore never leak into each other's off-diagonal blocks.
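A NumPy sketch of the block bookkeeping. Torch is swapped for NumPy, and the two penalty matrices are toy examples rather than output of build_penalty_matrix(). build_blocks and assemble_penalty are hypothetical names.

```python
import numpy as np


def build_blocks(penalty_mats):
    """Record (start, end, S_j) for each effect, as smart_solve_gcv does with `blocks`."""
    blocks, c_idx = [], 0
    for mat in penalty_mats:
        k = mat.shape[0]  # number of coefficients of this effect
        blocks.append((c_idx, c_idx + k, mat))
        c_idx += k
    return blocks, c_idx


def assemble_penalty(blocks, total_d, alpha_ps):
    """P = blockdiag(10**alpha_1 * S_1, ..., 10**alpha_m * S_m)."""
    penalty = np.zeros((total_d, total_d))
    for (start, end, mat), alpha in zip(blocks, alpha_ps):
        penalty[start:end, start:end] = mat * (10.0 ** alpha)
    return penalty


# Two toy effects: a 3-coefficient second-difference penalty and a 2-coefficient ridge.
d2 = np.diff(np.eye(3), n=2, axis=0)  # 1 x 3 second-difference operator [1, -2, 1]
s_spline = d2.T @ d2
s_ridge = np.eye(2)
blocks, total_d = build_blocks([s_spline, s_ridge])
P = assemble_penalty(blocks, total_d, alpha_ps=[1.0, -2.0])
```

The off-diagonal blocks of P stay exactly zero, so changing one effect's \(\alpha_j\) only rescales its own block.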


Discrete Coordinate Descent Optimization

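To tune the per-effect parameters across this block-diagonal structure, smart_solve_gcv runs a memory-safe Discrete Coordinate Descent loop over \(\log_{10}(\lambda)\) [Wright, 2015]. "Discrete" means that each coordinate moves on a fixed grid rather than along a gradient.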

  1. Initialization: It reads the starting \(\log_{10}(\lambda)\) of each effect from effects_list. An effect with \(\lambda \le 0\) starts at the lower bound alpha_p_bounds[0].

  2. Iterative Search: For a maximum of 15 cycles, it iterates through each individual effect.

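  3. Local Perturbation: It tests two candidates, original - step_size and original + step_size, each clipped to alpha_p_bounds, while holding all other effects’ penalties constant. Here step_size = (alpha_p_bounds[1] - alpha_p_bounds[0]) / number_of_steps. If alpha_p_list is supplied, its values replace these two candidates.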

  4. Scoring: The system recalculates the exact cyclic trace and GCV score for each perturbation. If a candidate improves the global GCV score, it becomes the new baseline.

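The loop terminates early if a full cycle completes without any parameter lowering the GCV score. The result is a coordinate-wise minimum on the discrete grid, not necessarily the global one. Once the search stops, the penalty is assembled one final time at the selected \(\lambda\) values. solve_linear_system then solves the resulting dense penalized system, built from cov_X, cov_XY and \(P\), for the coefficients (group by group in the chunked tiers).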
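The loop described above can be extracted into a standalone function with a caller-supplied objective. This sketch mirrors the control flow of smart_solve_gcv. discrete_coordinate_descent is a hypothetical name, and the toy quadratic objective stands in for the GCV score.

```python
from typing import Callable, Optional, Sequence, Tuple

import numpy as np


def discrete_coordinate_descent(
    objective: Callable[[np.ndarray], float],
    initial_alpha_ps: Sequence[float],
    alpha_p_bounds: Tuple[float, float] = (-30.0, 6.0),
    number_of_steps: int = 10,
    alpha_p_list: Optional[Sequence[float]] = None,
    max_cycles: int = 15,
):
    """Fixed-step coordinate descent over log10(lambda), mirroring smart_solve_gcv's loop."""
    lo, hi = alpha_p_bounds
    step_size = (hi - lo) / float(number_of_steps)
    alphas = np.array(initial_alpha_ps, dtype=np.float64)
    best = objective(alphas)
    for _ in range(max_cycles):
        improved = False
        for i in range(len(alphas)):
            original = alphas[i]
            best_val = original
            if alpha_p_list is not None:
                candidates = alpha_p_list  # explicit coordinates; bounds ignored
            else:
                candidates = [np.clip(original - step_size, lo, hi),
                              np.clip(original + step_size, lo, hi)]
            for cand in candidates:
                if np.isclose(cand, original):
                    continue  # e.g. a step clipped back onto itself at a bound
                alphas[i] = cand
                score = objective(alphas)
                if score < best:
                    best, best_val, improved = score, cand, True
            alphas[i] = best_val  # keep the best value found for this coordinate
        if not improved:
            break  # a full cycle without improvement
    return alphas, best


def toy_objective(a: np.ndarray) -> float:
    """Separable quadratic whose minimum lies on the default grid at (-1.2, 2.4)."""
    return float((a[0] + 1.2) ** 2 + (a[1] - 2.4) ** 2)


alphas, best = discrete_coordinate_descent(toy_objective, [-30.0, -30.0])
```

The default candidates sit exactly one step_size on either side of the current value, so each coordinate moves at most one grid step per cycle. With number_of_steps=10, crossing the full alpha_p_bounds range takes 10 improving cycles, within the 15-cycle cap. In the toy run, the second coordinate needs 9 steps to get from -30 to 2.4.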

src/tam/model/_dispatcher_gcv.py (Memory-Safe MSP GCV Solver)
def smart_solve_gcv(
    x_data: torch.Tensor,
    y_data: torch.Tensor,
    effects_list: List[BaseEffect],
    loss_matrix: torch.Tensor,
    alpha_p_bounds: Tuple[float, float],
    number_of_steps: int = 10,
    alpha_p_list: Optional[List[float]] = None,
    gamma: float = 1.4,
    verbose: bool = False
) -> Tuple[torch.Tensor, np.ndarray, float]:
    r"""
    Memory-safe Generalized Cross Validation (GCV) solver.
    Dynamically routes matrix inversions and chunking based on available VRAM.
    """
    run_device = x_data.device
    num_samples = x_data.shape[1]
    n_groups = x_data.shape[0]
    
    dummy_x = x_data[:, 0:1, :].to(run_device)
    dummy_phi = build_phi_from_effects(dummy_x, effects_list)
    total_d = dummy_phi.shape[-1]
    del dummy_x, dummy_phi
    
    if total_d > 7500:
        raise MemoryError(
            f"Feature dimension D={total_d} is too massive for Generalized Cross Validation (GCV). "
            "GCV requires computing the exact trace of the dense inverse covariance matrix, "
            "which will cause a severe Out-Of-Memory (OOM) crash on your GPU. "
            "Please use `grid_search_fit()` instead, which utilizes matrix-free Conjugate Gradient routing."
        )

    blocks = []
    c_idx = 0
    for e in effects_list:
        k = e.get_n_coeffs()
        mat = e.build_penalty_matrix()
        if mat.is_sparse: 
            mat = mat.to_dense()
        blocks.append((c_idx, c_idx + k, mat))
        c_idx += k
        
    n_effects = len(effects_list)
        
    # Heuristic footprint: five dense D x D float64 (8-byte) buffers per group.
    bytes_per_group = (total_d * total_d * 8) * 5
    total_bytes = bytes_per_group * n_groups
    available_mem = hw.get_available_memory()
    
    def _get_chunked_covs(x_subset: torch.Tensor, y_subset: torch.Tensor):
        nonlocal run_device
        n_groups_in_subset = x_subset.shape[0]
        num_samples_in_subset = x_subset.shape[1]
        
        cov_x_total = torch.zeros((n_groups_in_subset, total_d, total_d), dtype=torch.get_default_dtype(), device=run_device)
        cov_xy_total = torch.zeros((n_groups_in_subset, total_d, y_subset.shape[-1]), dtype=torch.get_default_dtype(), device=run_device)
        Y_sq_total = torch.zeros(n_groups_in_subset, dtype=torch.get_default_dtype(), device=run_device)
        
        available_bytes = hw.get_available_memory()
        allocatable_bytes = available_bytes * 0.8
        bytes_per_group_full_n = num_samples_in_subset * total_d * 8 * 5.0 
        safe_group_batch = max(1, int(allocatable_bytes // bytes_per_group_full_n)) if bytes_per_group_full_n > 0 else 1
        
        g_start = 0
        while g_start < n_groups_in_subset:
            g_end = min(g_start + safe_group_batch, n_groups_in_subset)
            current_sub_batch_size = g_end - g_start
            
            try:
                x_chunk = x_subset[g_start:g_end, :, :].to(run_device)
                y_chunk = y_subset[g_start:g_end, :, :].to(run_device)
                phi_chunk = build_phi_from_effects(x_chunk, effects_list)
                
                if loss_matrix.shape[0] == 1:
                    L_sqrt = loss_matrix[0, 0].sqrt()
                    phi_weighted = phi_chunk * L_sqrt
                    cov_x_total[g_start:g_end] = phi_weighted.mT @ phi_weighted
                    cov_xy_total[g_start:g_end] = phi_chunk.mT @ (y_chunk * loss_matrix[0, 0])
                    Y_sq_total[g_start:g_end] = torch.sum(torch.abs(y_chunk)**2 * loss_matrix[0, 0], dim=1).squeeze(-1)
                    del phi_weighted, L_sqrt
                else:
                    cov_x_total[g_start:g_end] = phi_chunk.mT @ phi_chunk
                    y_weighted = y_chunk.to(phi_chunk.dtype) @ loss_matrix
                    cov_xy_total[g_start:g_end] = phi_chunk.mT @ y_weighted
                    Y_sq_total[g_start:g_end] = torch.sum(torch.abs(y_chunk)**2, dim=1).squeeze(-1)
                    del y_weighted
                    
                del x_chunk, y_chunk, phi_chunk
                g_start += current_sub_batch_size
                
            except (torch.OutOfMemoryError, MemoryError):
                if safe_group_batch > 1:
                    safe_group_batch, run_device = hw.handle_oom(
                        current_batch=safe_group_batch, 
                        device=run_device, 
                        context="GCV covariance group reduction", 
                        allow_cpu_fallback=False
                    )
                    continue
                else:
                    raise RuntimeError("A single full group exceeds available physical memory during GCV covariance computation.")
                
        return cov_x_total, cov_xy_total, Y_sq_total

    initial_alpha_ps = np.array([
        np.log10(e.lambda_p) if e.lambda_p > 0 else alpha_p_bounds[0] 
        for e in effects_list
    ], dtype=np.float64)
    
    step_size = (alpha_p_bounds[1] - alpha_p_bounds[0]) / float(number_of_steps)

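    # Tier 1: every group's covariances fit in < 30% of free memory, so they are
    # computed once and cached on run_device for the whole search.
    if total_bytes < available_mem * 0.3: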
        if verbose:
            print("[GCV Engine] VRAM footprint < 30%. Caching covariances globally on GPU.")
        cov_X, cov_XY, Y_sq = _get_chunked_covs(x_data, y_data)
        current_penalty = torch.zeros((total_d, total_d), dtype=torch.get_default_dtype(), device=run_device)
        
        def gcv_objective(alpha_ps: np.ndarray) -> float:
            current_penalty.zero_()
            for i, current_alpha in enumerate(alpha_ps):
                start, end, mat = blocks[i]
                current_penalty[start:end, start:end] = mat.to(run_device) * (10.0 ** current_alpha)
                
            score = compute_gcv_score(
                cov_X=cov_X, 
                cov_XY=cov_XY, 
                Y_sq=Y_sq,
                penalty_M_star_M=current_penalty, 
                lambda_p=1.0, 
                n_samples=num_samples,
                gamma=gamma
            )
            return score.item()

        current_alpha_ps = np.copy(initial_alpha_ps)
        best_gcv = gcv_objective(current_alpha_ps)
        
        cycle = 0
        max_cycles = 15
        
        while cycle < max_cycles:
            cycle += 1
            improved_in_cycle = False
            
            for i in range(n_effects):
                original_val = current_alpha_ps[i]
                best_val_for_effect = original_val
                
                if alpha_p_list is not None:
                    candidates = alpha_p_list
                else:
                    candidates = [
                        np.clip(original_val - step_size, alpha_p_bounds[0], alpha_p_bounds[1]),
                        np.clip(original_val + step_size, alpha_p_bounds[0], alpha_p_bounds[1])
                    ]
                    
                for cand in candidates:
                    if np.isclose(cand, original_val):
                        continue
                        
                    current_alpha_ps[i] = cand
                    score = gcv_objective(current_alpha_ps)
                    
                    if score < best_gcv:
                        best_gcv = score
                        best_val_for_effect = cand
                        improved_in_cycle = True
                        
                current_alpha_ps[i] = best_val_for_effect
                
            if not improved_in_cycle:
                break
                
        best_lambda_ps = 10.0 ** current_alpha_ps
        
        current_penalty.zero_()
        for i, a in enumerate(best_lambda_ps):
            start, end, mat = blocks[i]
            current_penalty[start:end, start:end] = mat.to(run_device) * a
            
        coeffs = solve_linear_system(cov_X, cov_XY, current_penalty, num_samples)
        
        return coeffs, best_lambda_ps, best_gcv

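    else:
        # Tiers 2 and 3: covariances are staged on the host group by group; GCV is
        # evaluated on the GPU if one group fits in < 40% of free memory, else on the CPU.
        eval_device = run_device if bytes_per_group < available_mem * 0.4 else torch.device('cpu')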
        if eval_device.type == 'cpu':
            print("Notice: KxK matrix inversions exceed GPU VRAM. Routing solver to CPU.")
        else:
            print("Notice: Total system exceeds VRAM. Processing GCV in group chunks.")
            
        cov_X_list, cov_XY_list, Y_sq_list = [], [], []
        for g in range(n_groups):
            cx, cxy, y_sq = _get_chunked_covs(x_data[g:g+1], y_data[g:g+1])
            cov_X_list.append(cx.cpu())
            cov_XY_list.append(cxy.cpu())
            Y_sq_list.append(y_sq.cpu())
            hw.empty_cache()
            
        current_penalty = torch.zeros((total_d, total_d), dtype=torch.get_default_dtype(), device=eval_device)
        
        def gcv_objective(alpha_ps: np.ndarray) -> float:
            current_penalty.zero_()
            for i, current_alpha in enumerate(alpha_ps):
                start, end, mat = blocks[i]
                current_penalty[start:end, start:end] = mat.to(eval_device) * (10.0 ** current_alpha)
                
            total_score = 0.0
            for g in range(n_groups):
                try:
                    score = compute_gcv_score(
                        cov_X=cov_X_list[g].to(eval_device), 
                        cov_XY=cov_XY_list[g].to(eval_device), 
                        Y_sq=Y_sq_list[g].to(eval_device),
                        penalty_M_star_M=current_penalty, 
                        lambda_p=1.0, 
                        n_samples=num_samples,
                        gamma=gamma
                    )
                    total_score += score.item()
                except (torch.OutOfMemoryError, MemoryError):
                    hw.empty_cache()
                    return float('inf')
            return total_score / n_groups

        current_alpha_ps = np.copy(initial_alpha_ps)
        best_gcv = gcv_objective(current_alpha_ps)
        
        cycle = 0
        max_cycles = 15
        
        while cycle < max_cycles:
            cycle += 1
            improved_in_cycle = False
            
            for i in range(n_effects):
                original_val = current_alpha_ps[i]
                best_val_for_effect = original_val
                
                if alpha_p_list is not None:
                    candidates = alpha_p_list
                else:
                    candidates = [
                        np.clip(original_val - step_size, alpha_p_bounds[0], alpha_p_bounds[1]),
                        np.clip(original_val + step_size, alpha_p_bounds[0], alpha_p_bounds[1])
                    ]
                    
                for cand in candidates:
                    if np.isclose(cand, original_val):
                        continue
                        
                    current_alpha_ps[i] = cand
                    score = gcv_objective(current_alpha_ps)
                    
                    if score < best_gcv:
                        best_gcv = score
                        best_val_for_effect = cand
                        improved_in_cycle = True
                        
                current_alpha_ps[i] = best_val_for_effect
                
            if not improved_in_cycle:
                break
                
        best_lambda_ps = 10.0 ** current_alpha_ps
        
        current_penalty.zero_()
        for i, a in enumerate(best_lambda_ps):
            start, end, mat = blocks[i]
            current_penalty[start:end, start:end] = mat.to(eval_device) * a
        
        group_coeffs = []
        for g in range(n_groups):
            coeffs_g = solve_linear_system(
                cov_X_list[g].to(eval_device), 
                cov_XY_list[g].to(eval_device), 
                current_penalty, 
                num_samples
            )
            group_coeffs.append(coeffs_g.cpu())
            
        coeffs = torch.cat(group_coeffs, dim=0)
        return coeffs, best_lambda_ps, best_gcv