The Additive API and Object-Oriented Architecture

This chapter describes the software pipeline of the TAM framework: how high-level user formulas are parsed, routed, and translated into linear algebra operations by the main orchestration modules (utils.py, _base.py, _factory.py, and additive.py). The design is grounded in the Primal resolution theory [Doumèche et al., 2025].

To avoid redundancy, low-level data padding, hardware dispatching, and specific effect implementations are delegated to their respective documentation files.

Global Configurations (utils.py)

Before any model is instantiated, the framework establishes a unified global configuration via utils.py.

To keep the exact Primal inversion numerically stable and reproducible [Doumèche et al., 2025], the framework sets PyTorch and NumPy to 64-bit precision (float64) wherever the hardware supports it, and falls back to float32 otherwise. This module is the single source of truth for the device (TORCH_DEVICE) and the precision target (NUMPY_DTYPE) across the entire architecture.

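import numpy as np
import torch

from .hardware import hw

# Single source of truth for the compute device.
TORCH_DEVICE = hw.device

# Prefer float64; fall back to float32 on hardware without float64 support.
if hw.supports_float64:
    torch.set_default_dtype(torch.float64)
    NUMPY_DTYPE = np.float64
else:
    torch.set_default_dtype(torch.float32)
    NUMPY_DTYPE = np.float32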
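
The effect of the precision target can be seen with the default penalty used later in this chapter (default_alpha_p = -9.0, i.e. \(\lambda_p = 10^{-9}\)). The following is a minimal sketch with plain NumPy scalars, not the framework's solver; the diagonal entry of order 1 is an illustrative assumption.

```python
import numpy as np

# Default penalty from StaticTAM: lambda_p = 10 ** default_alpha_p, default_alpha_p = -9.0.
lambda_p = 10.0 ** -9.0

# Assumption for illustration: a diagonal entry of Phi^T Phi of order 1.
diag_entry = 1.0

# float32 has a machine epsilon of about 1.2e-7, so the penalty is rounded away.
penalized_32 = np.float32(diag_entry) + np.float32(lambda_p)
# float64 has a machine epsilon of about 2.2e-16, so the penalty survives.
penalized_64 = np.float64(diag_entry) + np.float64(lambda_p)

penalty_lost_in_float32 = bool(penalized_32 == np.float32(diag_entry))
penalty_kept_in_float64 = bool(penalized_64 != np.float64(diag_entry))
print(penalty_lost_in_float32, penalty_kept_in_float64)
```

In this toy case a float32 fallback silently turns the regularized system into an unregularized one, which is one reason to prefer float64 when it is available.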

The Foundation: _base.py

The _base.py script acts as the structural foundation for all models in the framework. It defines the abstract class BaseTAM, which orchestrates the standardized control flow.

To ensure a consistent API without duplicating boilerplate tensor operations across advanced Meta-Learners, BaseTAM utilizes the Object-Oriented Template Method Pattern.

  • It explicitly manages state definitions (coefficients_, norm_params_).

  • It standardizes the temporal alignment and handling of missing groups. (For exact tensor padding logic, refer to the Data Pipeline).

  • It defines the continuous optimization problem logically, forcing child classes to implement the specific construction of the design matrix \(\Phi\) and the penalty matrix \(P\) required to stabilize the regularized normal equations [Hoerl and Kennard, 1970].

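from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

import pandas as pd
import torch


class BaseTAM(ABC):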
    r"""
    Abstract Base Class for TAM models.

    Defines the skeleton for training and prediction. Subclasses (e.g., `StaticTAM`)
    must implement the abstract methods to define how design matrices (Phi),
    penalty matrices (P), and loss matrices (L) are constructed.

    """
    
    def __init__(self):
        # --- Fitted Model Attributes ---
        self.coefficients_: Optional[torch.Tensor] = None
        self.norm_params_: Optional[Dict] = None
        self.unique_groups_: Optional[List] = None
        self.effects_list_: Optional[List] = None

        # --- Configuration Attributes ---
        self.features_config_: Optional[Dict] = None
        self.group_col_: Optional[str] = None
        self.target_col_: Optional[str] = None
        self.date_col_: Optional[str] = None

    @abstractmethod
    def _prepare_data(
        self, 
        data: pd.DataFrame, 
        target_col: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List]:
        r"""
        Transforms the raw DataFrame into normalized, 3D-stacked tensors.

        Args:
            data: Input DataFrame (must be pre-balanced).
            target_col: Name of the target column (None for inference).

        Returns:
            Tuple containing:
            - Feature tensor (x_stacked)
            - Target tensor (y_stacked) or None
            - List of unique groups processed
        """
        raise NotImplementedError

    @abstractmethod
    def _build_design_matrix(self, x_data: torch.Tensor) -> torch.Tensor:
        r"""
        Constructs the design matrix Phi from input features.

        Args:
            x_data: Input feature tensor (n_groups, n_samples, n_features).

        Returns:
            Design matrix (n_groups, n_samples, n_total_coeffs).
        """
        raise NotImplementedError

    @abstractmethod
    def _build_penalty_matrix(self) -> torch.Tensor:
        r"""
        Constructs the global regularization matrix P (or M*M).

        Returns:
            Penalty matrix (n_total_coeffs, n_total_coeffs).
        """
        raise NotImplementedError

    @abstractmethod
    def _build_loss_matrix(self) -> torch.Tensor:
        r"""
        Constructs the loss weighting matrix L*L.

        Returns:
            Loss matrix (n_samples, n_samples).
        """
        raise NotImplementedError
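
The Template Method idea can be sketched on a toy problem. This is a simplified illustration using NumPy rather than the framework's 3D torch tensors; the names TemplateModel and LineModel are hypothetical, and the ridge form of the normal equations, \((\Phi^\top \Phi + P)c = \Phi^\top y\), is an assumption about the general shape of the system, not the framework's exact scaling.

```python
from abc import ABC, abstractmethod

import numpy as np


class TemplateModel(ABC):
    """Hypothetical stand-in for BaseTAM: owns the control flow, not the matrices."""

    def __init__(self):
        self.coefficients_ = None

    def fit(self, x, y):
        # Template method: the steps are fixed here, the matrices come from hooks.
        phi = self._build_design_matrix(x)
        penalty = self._build_penalty_matrix()
        # Regularized normal equations (ridge form): (Phi^T Phi + P) c = Phi^T y.
        self.coefficients_ = np.linalg.solve(phi.T @ phi + penalty, phi.T @ y)
        return self

    def predict(self, x):
        if self.coefficients_ is None:
            raise RuntimeError("Model must be fitted first.")
        return self._build_design_matrix(x) @ self.coefficients_

    @abstractmethod
    def _build_design_matrix(self, x):
        raise NotImplementedError

    @abstractmethod
    def _build_penalty_matrix(self):
        raise NotImplementedError


class LineModel(TemplateModel):
    """Child class: supplies an intercept + slope basis and a ridge penalty."""

    def __init__(self, lambda_p):
        super().__init__()
        self.lambda_p = lambda_p

    def _build_design_matrix(self, x):
        return np.column_stack([np.ones_like(x), x])

    def _build_penalty_matrix(self):
        return self.lambda_p * np.eye(2)


x = np.linspace(0.0, 1.0, 20)
y = 2.0 + 3.0 * x
model = LineModel(lambda_p=1e-9).fit(x, y)
print(model.coefficients_)  # close to [2, 3]
```

The child class never touches the solve step, so a new model type only has to describe its matrices.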

The Factory Orchestration (_factory.py & _base_effects.py)

Instead of hardcoding basis functions into the solver, StaticTAM relies on an explicit Dependency Injection architecture. It expects components that conform strictly to the BaseEffect interface (_base_effects.py).

To bridge the user’s R-style formula string to these concrete interface objects, the framework uses the create_effects_from_parsed_terms factory. This function is designed to support architectural Grid Searches. Given a token_values dictionary, the factory substitutes string tokens with concrete hyperparameters (e.g., swapping 'gk_la' for 10), so the solver can rebuild the effect list for each candidate configuration without re-parsing the formula string.

(For the specific mapping of every individual effect, refer to The Spectral Dictionary).

def create_effects_from_parsed_terms(
    parsed_terms: List[Dict],
    token_values: Dict[str, Any],
    default_alpha_p: float,
    include_offset: bool = True,
    data_info: Optional[Dict[str, Any]] = None 
) -> List[BaseEffect]:
    """
    Instantiates a list of Effect objects based on parsed formula terms.

    This function handles:
    - Token substitution for hyperparameters (Dependency Injection from Grid Search).
    - Parsing of specific arguments for each effect type (Splines, Fourier, etc.).
    - Recursive creation of sub-effects for Tensor Products.

    Args:
        parsed_terms: List of term dictionaries returned by the formula parser.
        token_values: Dictionary of concrete values for hyperparameter tokens
                      (e.g., {'gk_la': 10}).
        default_alpha_p: Default log10(lambda_p) if not specified.
        include_offset: Whether to prepend an OffsetEffect (Intercept). 
                        False for recursive calls (e.g., inside 'te()').

    Returns:
        List of instantiated BaseEffect objects.
    """
    effects_list = []
    
    token_name_regex = re.compile(r'([a-zA-Z_][a-zA-Z0-9_]*)')

    if include_offset:
        # The offset (intercept) gets its own penalty; 'ap_offset' is log10(lambda_p).
        offset_ap = token_values.get('ap_offset', default_alpha_p)
        lambda_p = 10**float(offset_ap)
        effects_list.append(OffsetEffect(lambda_p, 'continue'))
    
    for term in parsed_terms:
        feature_name = term['feature']
        ttype = term['type']
        params = term['params'].copy()

        params_resolved = {}
        for key, val in params.items():
            resolved_val = val
            
            if isinstance(val, str):
                if val in token_values:
                    resolved_val = token_values[val]
            
            params_resolved[key] = resolved_val
        
        # 'ap' is log10(lambda_p); fall back to the model-wide default.
        ap_val = params_resolved.get('ap', default_alpha_p)
        try:
            lambda_p = 10**float(ap_val)
        except (ValueError, TypeError) as exc:
            raise ValueError(
                f"Invalid value for 'ap' in term '{feature_name}': {ap_val}"
            ) from exc
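
The token substitution step can be tried in isolation. The sketch below mirrors the resolution loop of the factory on a single hypothetical term; the helper name resolve_params, the parameter names 'k' and 'basis', and the token 'ap_x' are illustrative assumptions, while the 'feature' / 'type' / 'params' keys and the 'gk_la' token follow the excerpt above.

```python
from typing import Any, Dict


def resolve_params(params: Dict[str, Any], token_values: Dict[str, Any]) -> Dict[str, Any]:
    """Replace string values that name a known token with the token's concrete value."""
    return {
        key: token_values[val] if isinstance(val, str) and val in token_values else val
        for key, val in params.items()
    }


# Hypothetical parsed term, shaped like the dictionaries consumed by the factory.
term = {"feature": "x", "type": "s", "params": {"k": "gk_la", "ap": "ap_x", "basis": "cubic"}}

# One grid-search candidate: every token maps to a concrete value.
token_values = {"gk_la": 10, "ap_x": -2.0}

resolved = resolve_params(term["params"], token_values)
lambda_p = 10 ** float(resolved["ap"])  # 'ap' is log10(lambda_p)

print(resolved)  # {'k': 10, 'ap': -2.0, 'basis': 'cubic'}
print(lambda_p)
```

Strings that are not tokens (here 'cubic') pass through unchanged, and the original term is left untouched, so the same parsed terms can be reused for every candidate.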

The Core Solver (additive.py)

The StaticTAM class is the primary engine of the framework. It inherits from BaseTAM and orchestrates the full fitting pipeline. Once the effect blocks are assembled by the factory, StaticTAM delegates the actual matrix inversions to the underlying Math Dispatcher.

Initialization and Dependency Routing

When initialized, StaticTAM parses the formula and attempts to instantiate the effects. A critical engineering choice is the detection of Grid Search tokens: if instantiation fails and the parsed parameters contain string values, the model is flagged as a template (is_grid_search_template_ = True) and the construction of the effects is deferred to the Multi-Start Coordinate Descent engine. A missing category count (n_cat) is not treated as a failure, because it is filled in later by _prepare_data.

    def __init__(
        self,
        formula: str,
        group_col: Optional[str] = None,
        date_col: Optional[str] = None,
        default_alpha_p: float = -9.0,
        _internal_effects_list: Optional[List[BaseEffect]] = None,
        _internal_features_config: Optional[dict] = None 
    ):
        """
        Initializes the StaticTAM model.

        Args:
            formula: R-style formula defining the model structure 
                     (e.g., "Y ~ s(x) + l(t)").
            group_col: Column name used for grouping data (e.g., 'ID').
            date_col: Column name for time indexing.
            default_alpha_p: Default log10(lambda_p) regularization strength.
            _internal_effects_list: (Internal) Used for restoring state during grid search.
            _internal_features_config: (Internal) Used for restoring state during grid search.
        """
        super().__init__()
        
        self.effects_list_ = []
        self.formula_ = formula 
        self.default_alpha_p_ = default_alpha_p
        self.group_col_ = group_col or "__dummy_group__"
        self.date_col_ = date_col or "__dummy_date__"
        
        if _internal_effects_list:
            self.effects_list_ = _internal_effects_list
            self.features_config_ = _internal_features_config
            self.is_grid_search_template_ = False

        elif formula:
            self.target_col_, self.parsed_terms_ = parse_formula_to_terms(formula)
            real_features = self._extract_recursive_features(self.parsed_terms_)
            self.features_config_ = { "features": real_features }
            
            self.is_grid_search_template_ = False
            
            try:
                # Attempt standard instantiation; valid string hyperparams will pass.
                self.effects_list_ = create_effects_from_parsed_terms(
                    self.parsed_terms_, 
                    token_values={}, 
                    default_alpha_p=self.default_alpha_p_
                )
            except Exception as e:
                # Missing categorical counts are filled later in _prepare_data.
                if isinstance(e, ValueError) and "requires 'n_cat'" in str(e):
                    pass
                else:
                    # A failure while string-valued parameters are present means
                    # unresolved grid tokens blocked the instantiation.
                    has_str_vals = any(
                        isinstance(v, str)
                        for t in self.parsed_terms_
                        for v in t['params'].values()
                    )
                    if has_str_vals:
                        self.is_grid_search_template_ = True
                        print("Model initialized with Grid Search tokens. Use 'grid_search_fit()'.")
                    else:
                        raise
        else:
            raise ValueError("`formula` must be provided to initialize StaticTAM.")
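
The string check used in the fallback branch can be isolated as follows. This is a minimal sketch with a hypothetical helper name (has_string_params) and hypothetical terms. Legitimate string-valued parameters would also trigger it, which is why the constructor above consults it only after instantiation has already failed.

```python
from typing import Any, Dict, List


def has_string_params(parsed_terms: List[Dict[str, Any]]) -> bool:
    """True if any term still carries a string-valued parameter."""
    return any(
        isinstance(value, str)
        for term in parsed_terms
        for value in term["params"].values()
    )


# Hypothetical parsed terms: one fully resolved, one with a grid token left in place.
resolved_terms = [{"feature": "x", "type": "s", "params": {"k": 10, "ap": -2.0}}]
template_terms = [{"feature": "x", "type": "s", "params": {"k": "gk_la", "ap": -2.0}}]

print(has_string_params(resolved_terms))  # False
print(has_string_params(template_terms))  # True
```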

Component Decomposition

Because the Primal space concatenates independent blocks, one per effect, the framework can isolate the contribution of each effect. The decompose_prediction method vectorizes this operation: it multiplies each partition of the design matrix by its corresponding coefficients and returns a per-effect breakdown of the forecast.

    def decompose_prediction(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Decomposes the prediction into additive components per feature.

        Args:
            data: DataFrame containing input features.

        Returns:
            DataFrame with original data and additional 'effect_feature' columns.
        """
        data = _ensure_dummies(data, self.group_col_, self.date_col_)

        if self.coefficients_ is None:
            raise RuntimeError("Model must be fitted first.")

        required_cols = self.features_config_['features'] + [self.group_col_, self.date_col_]
        _check_features(dataset=data, required_features=required_cols)
        
        mask, balanced_data = _balance_groups(
            dataset=data, group_col=self.group_col_, date_col=self.date_col_, method="fill"
        )

        x_predict, _, _ = self._prepare_data(balanced_data)
        
        final_decomposed_effects = smart_decompose(x_predict, self.coefficients_, self.effects_list_)

        decomposed_df = _reassemble_decomposed_predictions(
            balanced_data, final_decomposed_effects, self.group_col_, self.unique_groups_
        )

        return _cleanup_dummies(decomposed_df[mask], self.group_col_, self.date_col_)
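
The block structure behind this decomposition can be checked numerically. The sketch below is a NumPy illustration, not the smart_decompose routine itself; the block names and sizes are hypothetical, and a single group with a 2D design matrix is assumed for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 6

# Hypothetical effect blocks: name -> number of coefficients in that block.
block_sizes = {"offset": 1, "s(x)": 4, "l(t)": 1}

# The full design matrix is the column-wise concatenation of the effect blocks.
blocks = {name: rng.normal(size=(n_samples, size)) for name, size in block_sizes.items()}
phi = np.hstack(list(blocks.values()))
coefficients = rng.normal(size=phi.shape[1])

# Each component is its block of columns times the matching slice of coefficients.
components = {}
start = 0
for name, size in block_sizes.items():
    stop = start + size
    components[name] = phi[:, start:stop] @ coefficients[start:stop]
    start = stop

full_prediction = phi @ coefficients
reconstructed = sum(components.values())
print(np.allclose(full_prediction, reconstructed))  # True
```

The components add up to the full prediction because the matrix product is linear in the column blocks.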

Hyperparameter Routing: Continuous vs. Discrete

To keep tuning tractable at Gigadata scale, StaticTAM splits hyperparameter tuning into two distinct methods.

The Continuous Algebraic Solver (GCV)

For continuous structural penalties (the \(\lambda\) regularization weights), an iterative grid search is unnecessary. The auto_fit method routes the training data to the Generalized Cross-Validation (GCV) dispatcher, which computes the optimal Multiple Smoothing Parameters analytically via the cyclic trace trick [Golub et al., 1979].

(For the implementation details and block-diagonal routing logic of this solver, see the GCV Implementation Guide).
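
For a single smoothing parameter, the GCV criterion of Golub et al. can be written as \(\mathrm{GCV}(\lambda) = n\,\lVert (I - H_\lambda) y \rVert^2 / (n - \operatorname{tr} H_\lambda)^2\) with \(H_\lambda = \Phi(\Phi^\top\Phi + \lambda P)^{-1}\Phi^\top\). The sketch below assumes that the "cyclic trace trick" refers to the identity \(\operatorname{tr}(AB) = \operatorname{tr}(BA)\), which moves the trace from an \(n \times n\) matrix to a \(p \times p\) one. It is a NumPy illustration with a hypothetical function name (gcv_score), not the framework's dispatcher.

```python
import numpy as np


def gcv_score(phi, y, penalty, lam):
    """Return (GCV(lam), tr(H)) using only p x p linear systems."""
    n = phi.shape[0]
    gram = phi.T @ phi                      # p x p
    system = gram + lam * penalty           # p x p
    coeffs = np.linalg.solve(system, phi.T @ y)
    # tr(H) = tr(Phi S^-1 Phi^T) = tr(S^-1 Phi^T Phi) by cyclicity of the trace.
    trace_h = np.trace(np.linalg.solve(system, gram))
    residual = y - phi @ coeffs             # equals (I - H) y
    score = n * float(residual @ residual) / (n - trace_h) ** 2
    return score, trace_h


rng = np.random.default_rng(0)
n, p = 50, 4
phi = rng.normal(size=(n, p))
y = phi @ np.array([1.0, -2.0, 0.5, 0.0]) + 0.1 * rng.normal(size=n)
penalty = np.eye(p)
lam = 0.1

score, trace_small = gcv_score(phi, y, penalty, lam)

# Naive route for comparison: build the full n x n hat matrix.
hat = phi @ np.linalg.solve(phi.T @ phi + lam * penalty, phi.T)
trace_large = np.trace(hat)
print(np.isclose(trace_small, trace_large))  # True
```

The \(p \times p\) route never materializes the \(n \times n\) hat matrix, which is what makes the criterion affordable when the number of samples is large.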

The Discrete Architectural Solver (Coordinate Descent)

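While regularization is continuous, topological choices (such as the number of knots in a Spline or the maximum depth of a Tree) are strictly discrete. The grid_search_fit method uses a Multi-Start Coordinate Descent algorithm to resolve these non-differentiable tokens. From each of three starting points it cycles through the parameter axes, changes one token at a time, and keeps a change only if it lowers the validation RMSE. A run stops when a full cycle brings no improvement or after five cycles, and the best configuration across all starts is returned. Because this is a local search, the result is not guaranteed to be the global optimum.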

    def grid_search_fit(
            self, 
            data_train: pd.DataFrame, 
            data_val: pd.DataFrame, 
            grid_search_config: dict
        ):
            """
            Finds optimal hyperparameters via Multi-Start Coordinate Descent.

            Args:
                data_train (pd.DataFrame): Training data.
                data_val (pd.DataFrame): Validation data for scoring.
                grid_search_config (dict): Dictionary mapping tokens to lists of values.

            Returns:
                StaticTAM: A new fitted model with optimal parameters.
            """
            print("--- Starting Grid Search (Multi-Start Coordinate Descent) ---")
            
            data_train = _ensure_dummies(data_train, self.group_col_, self.date_col_)
            data_val = _ensure_dummies(data_val, self.group_col_, self.date_col_)
            
            temp_model = StaticTAM(self.formula_, self.group_col_, self.date_col_)
            temp_model.features_config_ = self.features_config_
            temp_model.target_col_ = self.target_col_
            
            required_cols_tr = self.features_config_['features'] + [self.group_col_, self.target_col_, self.date_col_]
            _check_features(dataset=data_train, required_features=required_cols_tr)
            _, balanced_data_train = _balance_groups(dataset=data_train, group_col=self.group_col_, date_col=self.date_col_, method="drop")
            
            required_cols_val = self.features_config_['features'] + [self.group_col_, self.target_col_, self.date_col_]
            _check_features(dataset=data_val, required_features=required_cols_val)
            _, balanced_data_val = _balance_groups(dataset=data_val, group_col=self.group_col_, date_col=self.date_col_, method="drop")
            
            x_train, y_train, unique_groups = temp_model._prepare_data(balanced_data_train, self.target_col_, ignore_template_check=True)
            x_val, y_val, _ = temp_model._prepare_data(balanced_data_val, self.target_col_, ignore_template_check=True)
            
            num_samples_train = x_train.shape[1]
            loss_L_star_L = torch.eye(1, device=TORCH_DEVICE, dtype=torch.get_default_dtype()) 

            search_axes, token_names = self._parse_grid_axes(grid_search_config)
            data_info = self._get_data_info(balanced_data_train)

            if not token_names:
                print("No grid tokens found. Fitting single configuration.")
                combo = self._build_combo_from_tokens({}, data_info=data_info)
                rmse, coeffs = self._evaluate_combination(
                    combo, x_train, y_train, x_val, y_val, num_samples_train, loss_L_star_L
                )
                optimal_params_combo, optimal_coeffs, min_global_rmse = combo, coeffs, rmse
            else:
                start_points = [
                    {"name": "Conservative", "tokens": {t: max(vals) if ('ap' in t or 'lambda_p' in t) else min(vals) for t, vals in search_axes.items()}},
                    {"name": "Median", "tokens": {t: vals[len(vals)//2] for t, vals in search_axes.items()}},
                    {"name": "Aggressive", "tokens": {t: min(vals) if ('ap' in t or 'lambda_p' in t) else max(vals) for t, vals in search_axes.items()}}
                ]

                global_best_rmse = float('inf')
                global_best_combo = None
                global_best_coeffs = None

                for strategy in start_points:
                    print(f"\n=== Strategy: {strategy['name']} Start ===")
                    current_best_tokens = strategy["tokens"].copy()
                    
                    complete_token_map = {}
                    for tname, tvals in search_axes.items():
                        if tname in current_best_tokens:
                            complete_token_map[tname] = current_best_tokens[tname]
                        else:
                            complete_token_map[tname] = tvals[0]
                    
                    current_best_tokens = complete_token_map.copy()

                    start_combo = self._build_combo_from_tokens(current_best_tokens, data_info=data_info)
                    current_rmse, current_coeffs = self._evaluate_combination(
                        start_combo, x_train, y_train, x_val, y_val, num_samples_train, loss_L_star_L
                    )
                    current_optimal_combo = start_combo

                    if current_rmse >= float('inf'):
                        continue

                    cycle = 0
                    while True:
                        cycle += 1
                        has_improved_in_cycle = False
                        
                        for token_name in token_names:
                            best_val_for_axis = current_best_tokens[token_name]
                            original_val = best_val_for_axis
                            possible_values = search_axes[token_name]
                            
                            for value in possible_values:
                                if value == original_val:
                                    continue
                                tokens_to_test = current_best_tokens.copy()
                                tokens_to_test[token_name] = value
                                
                                combo = self._build_combo_from_tokens(tokens_to_test, data_info=data_info)
                                rmse, coeffs = self._evaluate_combination(
                                    combo, x_train, y_train, x_val, y_val, num_samples_train, loss_L_star_L
                                )

                                if rmse < current_rmse:
                                    current_rmse = rmse
                                    current_optimal_combo = combo
                                    current_coeffs = coeffs
                                    best_val_for_axis = value
                                    has_improved_in_cycle = True
                            
                            current_best_tokens[token_name] = best_val_for_axis
                        
                        print(f"  Cycle {cycle} | Current RMSE: {current_rmse:.4f}")
                        # Stop at a local optimum or after 5 full cycles.
                        if not has_improved_in_cycle or cycle >= 5:
                            break
                    
                    if current_rmse < global_best_rmse:
                        print(f"  >>> New Global Best found by {strategy['name']}! ({current_rmse:.4f})")
                        global_best_rmse = current_rmse
                        global_best_combo = current_optimal_combo
                        global_best_coeffs = current_coeffs

                optimal_params_combo = global_best_combo
                optimal_coeffs = global_best_coeffs
                min_global_rmse = global_best_rmse

            # Fail before reporting if no start point produced a valid configuration.
            if optimal_params_combo is None:
                raise RuntimeError("Grid search failed to find any valid configuration.")

            print("-" * 30)
            print(f"Grid search complete. Optimal Validation RMSE found: {min_global_rmse:.4f}")

            print(f"Best tokens: {optimal_params_combo.get('token_values', 'N/A')}")

            model = self.__class__(
                formula=self.formula_, 
                group_col=self.group_col_,
                date_col=self.date_col_,
                _internal_effects_list=optimal_params_combo['effects_list'],
                _internal_features_config=self.features_config_ 
            )
            
            model.coefficients_ = optimal_coeffs
            model.target_col_ = self.target_col_ 
            model.norm_params_ = temp_model.norm_params_
            model.unique_groups_ = unique_groups

            return model
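
The search loop can be reduced to a toy problem to show its mechanics. The sketch below is a simplified illustration, not the method above: the function names and the toy score are hypothetical, the starting points are simply the first, middle, and last value of each axis (the method above instead reverses the direction for penalty tokens), and a closed-form score stands in for the validation RMSE.

```python
from typing import Callable, Dict, List, Tuple

Tokens = Dict[str, float]


def coordinate_descent(
    axes: Dict[str, List[float]],
    score: Callable[[Tokens], float],
    start: Tokens,
    max_cycles: int = 5,
) -> Tuple[Tokens, float]:
    """Change one token at a time and keep a change only if the score drops."""
    best, best_score = dict(start), score(start)
    for _ in range(max_cycles):
        improved = False
        for name, values in axes.items():
            for value in values:
                if value == best[name]:
                    continue
                trial = {**best, name: value}
                trial_score = score(trial)
                if trial_score < best_score:
                    best, best_score, improved = trial, trial_score, True
        if not improved:
            break  # local optimum: a full cycle brought no improvement
    return best, best_score


def multi_start(axes: Dict[str, List[float]], score: Callable[[Tokens], float]):
    """Run the descent from three start points and keep the best result."""
    starts = [
        {name: values[0] for name, values in axes.items()},
        {name: values[len(values) // 2] for name, values in axes.items()},
        {name: values[-1] for name, values in axes.items()},
    ]
    return min((coordinate_descent(axes, score, s) for s in starts), key=lambda r: r[1])


def toy_score(tokens: Tokens) -> float:
    # Stand-in for the validation RMSE, minimal at k = 7 and ap = -3.
    return (tokens["k"] - 7) ** 2 + (tokens["ap"] + 3.0) ** 2


axes = {"k": [3, 5, 7, 9], "ap": [-6.0, -3.0, 0.0]}
best_tokens, best_score = multi_start(axes, toy_score)
print(best_tokens, best_score)  # {'k': 7, 'ap': -3.0} 0.0
```

Each cycle costs the sum of the axis lengths in evaluations rather than their product, which is the reason to prefer this scheme over an exhaustive grid.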

Separation of Concerns: Simulation vs. Inference

To guarantee safety in operational production pipelines, AdaptiveTAM strictly separates historical learning from out-of-sample inference.

Architectural Choice (The fit / predict Split):

  • fit(data): Runs the full sliding-window predict_online() simulation. At the end of the simulation, it extracts the final historical residuals and trains a global, frozen StaticTAM model (self.static_residual_model_) to act as the permanent correction rule.

  • predict(data): A deterministic, read-only method. It applies the base model and the frozen static residual model to new data. Because it bypasses the sliding-window simulation entirely, inference requires no refitting (its cost is \(O(1)\) in the number of simulation windows) and cannot leak target values.
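
The split can be illustrated with a deliberately small stand-in. The class below is hypothetical and is not AdaptiveTAM: a plain function plays the role of the base model, and the mean historical residual replaces the frozen StaticTAM residual model. It only shows where targets are read and where state is written.

```python
from typing import Callable, List, Optional, Sequence


class ResidualCorrectedModel:
    """Hypothetical stand-in: a base predictor plus a correction frozen at fit time."""

    def __init__(self, base_predict: Callable[[float], float]):
        self.base_predict = base_predict
        self.frozen_correction_: Optional[float] = None

    def fit(self, xs: Sequence[float], ys: Sequence[float]) -> "ResidualCorrectedModel":
        # Learning phase: the only place where targets are read and state is written.
        residuals = [y - self.base_predict(x) for x, y in zip(xs, ys)]
        # Toy correction rule (assumption): the mean historical residual.
        self.frozen_correction_ = sum(residuals) / len(residuals)
        return self

    def predict(self, xs: Sequence[float]) -> List[float]:
        # Inference phase: takes no targets, never refits, leaves state unchanged.
        if self.frozen_correction_ is None:
            raise RuntimeError("Model must be fitted first.")
        return [self.base_predict(x) + self.frozen_correction_ for x in xs]


model = ResidualCorrectedModel(base_predict=lambda x: 2.0 * x)
model.fit(xs=[0.0, 1.0, 2.0], ys=[1.0, 3.0, 5.0])

state_before = model.frozen_correction_
predictions = model.predict([10.0, 20.0])
print(predictions)  # [21.0, 41.0]
```

Since predict has no access to targets in its signature, leakage at inference time is excluded by construction in this sketch.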