Source code for quantammsim.runners.training_evaluator

"""
Training Evaluator: A Meta-Runner for Assessing Training Effectiveness.

Wrap any training approach and evaluate whether it's effective using:

- Walk-Forward Efficiency (Pardo)
- Rademacher Complexity (Paleologo) — requires checkpoint tracking, see below
- OOS performance metrics

Usage:

.. code-block:: python

    from quantammsim.runners.training_evaluator import TrainingEvaluator, compare_trainers

    # Option 1: Wrap existing runner
    evaluator = TrainingEvaluator.from_runner("train_on_historic_data", max_iterations=500)
    results = evaluator.evaluate(run_fingerprint, n_cycles=5)

    # Option 2: Wrap custom function
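    def my_trainer(data_dict, train_start_idx, train_end_idx, pool,
                   run_fingerprint, n_assets, warm_start_params=None,
                   warm_start_weights=None):
        # The wrapper calls this function with keyword arguments, so the
        # parameter names above must match exactly.
        # ... your logic ...
        # The evaluator reads IS/OOS metrics from the returned metadata and
        # raises ValueError when they are missing.
        return params, {
            "epochs": n,
            "best_train_metrics": [train_metrics],
            "best_continuous_test_metrics": [test_metrics],
        }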

    evaluator = TrainingEvaluator.from_function(my_trainer)

    # Option 3: Compare approaches
    comparison = compare_trainers(
        run_fingerprint,
        trainers={
            "sgd": TrainingEvaluator.from_runner("train_on_historic_data"),
            "random": TrainingEvaluator.random_baseline(),
        },
    )

Rademacher Complexity
~~~~~~~~~~~~~~~~~~~~~

Rademacher complexity measures overfitting risk by tracking the "search space"
explored during optimization. To compute Rademacher complexity, the trainer
must return ``checkpoint_returns`` in metadata:

.. code-block:: python

    def my_trainer_with_checkpoints(...):
        checkpoint_returns = []
        for epoch in range(n_epochs):
            params = update(params)
            if epoch % checkpoint_interval == 0:
                returns = evaluate(params)  # Returns array of shape (T,)
                checkpoint_returns.append(returns)

        return params, {
            "epochs_trained": n_epochs,
            "checkpoint_returns": np.stack(checkpoint_returns),  # (n_checkpoints, T)
        }

    evaluator = TrainingEvaluator.from_function(
        my_trainer_with_checkpoints,
        compute_rademacher=True,  # Enable Rademacher computation
    )

The built-in wrapper for ``train_on_historic_data`` supports checkpoint tracking.
Enable it by passing ``compute_rademacher=True`` to ``from_runner()``:

.. code-block:: python

    evaluator = TrainingEvaluator.from_runner(
        "train_on_historic_data",
        compute_rademacher=True,  # Enable checkpoint tracking
        checkpoint_interval=10,   # Optional: checkpoint every N iterations
    )

For ``multi_period_sgd`` or custom trainers, you can implement checkpoint
tracking manually by returning ``checkpoint_returns`` in metadata (as shown
above).
"""

import numpy as np
import jax.numpy as jnp
from jax import jit
from jax.tree_util import Partial
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any, Callable, Union, Generator
from copy import deepcopy
from datetime import datetime
from functools import partial

from quantammsim.runners.default_run_fingerprint import run_fingerprint_defaults
from quantammsim.core_simulator.param_utils import recursive_default_set
from quantammsim.runners.jax_runner_utils import (
    Hashabledict,
    get_unique_tokens,
    create_static_dict,
    get_sig_variations,
)
from quantammsim.utils.post_train_analysis import calculate_period_metrics
from quantammsim.utils.data_processing.historic_data_utils import get_data_dict
from quantammsim.pools.creator import create_pool
from quantammsim.core_simulator.forward_pass import forward_pass_nograd

# Import utilities from robust_walk_forward
from quantammsim.runners.robust_walk_forward import (
    compute_empirical_rademacher,
    compute_rademacher_haircut,
    compute_walk_forward_efficiency,
    WalkForwardCycle,
    generate_walk_forward_cycles,
)


# =============================================================================
# Result Data Classes
# =============================================================================

@dataclass
class CycleEvaluation:
    """Per-cycle record of a walk-forward evaluation.

    One instance describes a single train/test window: the in-sample (IS)
    and out-of-sample (OOS) performance numbers, the robustness diagnostics
    derived from them, and provenance information for that cycle.

    The first seven fields are required; everything else is optional and
    defaults to ``0`` / ``None``.
    """

    # --- Required core metrics -------------------------------------------
    # Zero-based index of this cycle.
    cycle_number: int
    # Annualised Sharpe ratio on the in-sample (training) window.
    is_sharpe: float
    # Cumulative return relative to uniform HODL on the IS window.
    is_returns_over_hodl: float
    # Annualised Sharpe ratio on the out-of-sample (test) window.
    oos_sharpe: float
    # Cumulative return relative to uniform HODL on the OOS window.
    oos_returns_over_hodl: float
    # WFE = OOS metric / IS metric (Pardo).
    walk_forward_efficiency: float
    # IS minus OOS (positive values indicate overfitting).
    is_oos_gap: float

    # --- Training / Rademacher diagnostics -------------------------------
    # Number of gradient updates in this cycle's training run.
    epochs_trained: int = 0
    # Empirical Rademacher complexity from training checkpoints.
    rademacher_complexity: Optional[float] = None
    # OOS Sharpe after subtracting the Rademacher haircut.
    adjusted_oos_sharpe: Optional[float] = None

    # --- Additional risk metrics (from calculate_period_metrics) ---------
    is_calmar: Optional[float] = None
    oos_calmar: Optional[float] = None
    is_sterling: Optional[float] = None
    oos_sterling: Optional[float] = None
    is_ulcer: Optional[float] = None
    oos_ulcer: Optional[float] = None
    is_returns: Optional[float] = None
    oos_returns: Optional[float] = None
    is_daily_log_sharpe: Optional[float] = None
    oos_daily_log_sharpe: Optional[float] = None

    # --- Strategy parameters at the end of training for this cycle -------
    trained_params: Optional[Dict[str, Any]] = None

    # --- Date boundaries of the IS and OOS windows -----------------------
    train_start_date: Optional[str] = None
    train_end_date: Optional[str] = None
    test_start_date: Optional[str] = None
    test_end_date: Optional[str] = None

    # --- Provenance: for debugging and linking to output files -----------
    run_location: Optional[str] = None
    run_fingerprint: Optional[Dict[str, Any]] = None
@dataclass
class EvaluationResult:
    """Summary of a full walk-forward evaluation of one trainer.

    Holds the per-cycle :class:`CycleEvaluation` records together with the
    statistics aggregated over them and the resulting effectiveness verdict.
    """

    # --- Identification --------------------------------------------------
    # Identifier of the trainer wrapper that produced these results.
    trainer_name: str
    # Configuration dict passed to the trainer.
    trainer_config: Dict[str, Any]
    # Per-cycle evaluation results.
    cycles: List[CycleEvaluation]

    # --- Aggregate metrics across cycles ---------------------------------
    # Mean Walk-Forward Efficiency.
    mean_wfe: float
    # Mean OOS Sharpe ratio.
    mean_oos_sharpe: float
    # Standard deviation of the OOS Sharpe ratio.
    std_oos_sharpe: float
    # Minimum OOS Sharpe ratio.
    worst_oos_sharpe: float
    # Mean IS-OOS gap (positive values indicate overfitting).
    mean_is_oos_gap: float

    # --- Rademacher-adjusted aggregates (only when computed) -------------
    # Mean Rademacher complexity across cycles.
    aggregate_rademacher: Optional[float] = None
    # Mean OOS Sharpe minus the mean Rademacher haircut.
    adjusted_mean_oos_sharpe: Optional[float] = None

    # --- Verdict ---------------------------------------------------------
    # Whether the strategy passed the effectiveness criteria.
    is_effective: bool = False
    # Human-readable explanations for the verdict; fresh list per instance.
    effectiveness_reasons: List[str] = field(default_factory=list)
# ============================================================================= # Trainer Wrappers # =============================================================================
class TrainerWrapper:
    """Common interface every trainer adapter implements.

    Subclasses override :meth:`train`, which receives the data and window
    for one walk-forward cycle and returns a ``(params, metadata)`` tuple.
    """

    def __init__(self, name: str = "trainer", config: Optional[Dict] = None):
        """Store the display name and configuration of the trainer.

        Parameters
        ----------
        name : str
            Identifier used when reporting results.
        config : dict, optional
            Trainer configuration; a falsy value is replaced by a new
            empty dict.
        """
        self._name = name
        self._config = config if config else {}

    @property
    def name(self) -> str:
        """Identifier of this trainer."""
        return self._name

    @property
    def config(self) -> Dict[str, Any]:
        """Configuration dict of this trainer."""
        return self._config

    def train(
        self,
        data_dict: dict,
        train_start_idx: int,
        train_end_idx: int,
        pool: Any,
        run_fingerprint: dict,
        n_assets: int,
        warm_start_params: Optional[Dict] = None,
        warm_start_weights: Optional[Any] = None,
        train_start_date: Optional[str] = None,
        train_end_date: Optional[str] = None,
        test_end_date: Optional[str] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Train on one cycle and return ``(params, metadata)``.

        The base implementation is a placeholder and must be overridden.

        Parameters
        ----------
        warm_start_params : dict, optional
            Strategy parameters from the previous cycle, used as the
            initialisation.
        warm_start_weights : array-like, optional
            Final weights from the previous cycle. The pool starts with a
            fresh initial pool value distributed according to these weights
            (simulating continuous operation).

        Raises
        ------
        NotImplementedError
            Always, unless a subclass overrides this method.
        """
        raise NotImplementedError
class FunctionWrapper(TrainerWrapper):
    """Adapt a plain training callable to the :class:`TrainerWrapper` interface.

    The callable is invoked with keyword arguments only and must return a
    ``(params, metadata)`` tuple. It always receives ``data_dict``,
    ``train_start_idx``, ``train_end_idx``, ``pool``, ``run_fingerprint``,
    ``n_assets`` and ``warm_start_params``.

    ``warm_start_weights`` is forwarded when the callable names it or accepts
    ``**kwargs``. ``train_start_date``, ``train_end_date`` and
    ``test_end_date`` are forwarded only when the callable names them
    explicitly, so callables written against the shorter documented
    signature keep working.

    Use via :meth:`TrainingEvaluator.from_function` rather than constructing
    directly.
    """

    def __init__(
        self,
        fn: Callable,
        name: str = "custom",
        config: Optional[Dict] = None,
    ):
        """Store the wrapped callable alongside its name and config."""
        super().__init__(name, config)
        self.fn = fn

    def train(
        self,
        data_dict: dict,
        train_start_idx: int,
        train_end_idx: int,
        pool: Any,
        run_fingerprint: dict,
        n_assets: int,
        warm_start_params: Optional[Dict] = None,
        warm_start_weights: Optional[Any] = None,
        train_start_date: Optional[str] = None,
        train_end_date: Optional[str] = None,
        test_end_date: Optional[str] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Call the wrapped function for one cycle.

        Returns
        -------
        tuple
            Whatever the wrapped callable returns; expected to be
            ``(params, metadata)``.
        """
        import inspect

        call_kwargs = {
            "data_dict": data_dict,
            "train_start_idx": train_start_idx,
            "train_end_idx": train_end_idx,
            "pool": pool,
            "run_fingerprint": run_fingerprint,
            "n_assets": n_assets,
            "warm_start_params": warm_start_params,
        }

        try:
            accepted = inspect.signature(self.fn).parameters
        except (TypeError, ValueError):
            # Signature cannot be introspected (e.g. some builtins): fall
            # back to the historical call with all eight keyword arguments.
            call_kwargs["warm_start_weights"] = warm_start_weights
            return self.fn(**call_kwargs)

        takes_var_kwargs = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values()
        )

        # Previously always passed; now skipped only when the callable could
        # not have accepted it anyway (which used to raise TypeError).
        if takes_var_kwargs or "warm_start_weights" in accepted:
            call_kwargs["warm_start_weights"] = warm_start_weights

        # Date strings were never forwarded before, so only opt-in callables
        # (those that name the parameter) receive them.
        optional_dates = (
            ("train_start_date", train_start_date),
            ("train_end_date", train_end_date),
            ("test_end_date", test_end_date),
        )
        for key, value in optional_dates:
            if key in accepted:
                call_kwargs[key] = value

        return self.fn(**call_kwargs)
class ExistingRunnerWrapper(TrainerWrapper):
    """Thin adapter around the project's existing runners.

    Supported runner names are ``"train_on_historic_data"`` and
    ``"multi_period_sgd"``. For each cycle the run fingerprint is
    deep-copied, its date strings are replaced by the cycle's window, and
    the selected runner is called with the modified copy.
    """

    def __init__(
        self,
        runner_name: str,
        runner_kwargs: Optional[Dict] = None,
        compute_rademacher: bool = False,
        root: str = None,
    ):
        """Configure which runner to call and how.

        Parameters
        ----------
        runner_name : str
            Name of the runner to dispatch to in :meth:`train`.
        runner_kwargs : dict, optional
            Runner options (e.g. ``max_iterations``, ``checkpoint_interval``,
            ``n_periods``); also exposed as the trainer config.
        compute_rademacher : bool
            When True, checkpoint tracking is switched on in the fingerprint
            passed to ``train_on_historic_data``.
        root : str, optional
            Passed through to the runner as ``root``.
        """
        self.runner_name = runner_name
        self.runner_kwargs = runner_kwargs or {}
        self.compute_rademacher = compute_rademacher
        self.root = root
        super().__init__(
            name=f"{runner_name}",
            config=self.runner_kwargs,
        )

    def train(
        self,
        data_dict: dict,
        train_start_idx: int,
        train_end_idx: int,
        pool: Any,
        run_fingerprint: dict,
        n_assets: int,
        warm_start_params: Optional[Dict] = None,
        warm_start_weights: Optional[Any] = None,
        train_start_date: Optional[str] = None,
        train_end_date: Optional[str] = None,
        test_end_date: Optional[str] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Dispatch one cycle to the configured runner.

        The existing runners expect a full run fingerprint with date
        strings, so the date arguments (not the indices) determine which
        data each cycle trains on.

        Raises
        ------
        ValueError
            If ``runner_name`` is not one of the supported runners.
        """
        if self.runner_name == "train_on_historic_data":
            return self._run_train_on_historic_data(
                data_dict,
                train_start_idx,
                train_end_idx,
                pool,
                run_fingerprint,
                n_assets,
                warm_start_params,
                warm_start_weights,
                train_start_date,
                train_end_date,
                test_end_date,
            )
        elif self.runner_name == "multi_period_sgd":
            return self._run_multi_period_sgd(
                data_dict,
                train_start_idx,
                train_end_idx,
                pool,
                run_fingerprint,
                n_assets,
                warm_start_params,
                train_start_date,
                train_end_date,
            )
        else:
            raise ValueError(f"Unknown runner: {self.runner_name}")

    def _run_train_on_historic_data(
        self,
        data_dict: dict,
        train_start_idx: int,
        train_end_idx: int,
        pool: Any,
        run_fingerprint: dict,
        n_assets: int,
        warm_start_params: Optional[Dict],
        warm_start_weights: Optional[Any],
        train_start_date: Optional[str],
        train_end_date: Optional[str],
        test_end_date: Optional[str] = None,
    ) -> Tuple[Dict, Dict]:
        """Adapter for ``train_on_historic_data``.

        Builds a per-cycle copy of the fingerprint (dates, iteration count,
        optional checkpoint tracking) and returns the runner's
        ``(params, metadata)`` result unchanged.
        """
        from datetime import datetime, timedelta
        from quantammsim.runners.jax_runners import train_on_historic_data

        # Work on a copy so the caller's fingerprint is never modified.
        local_fp = deepcopy(run_fingerprint)

        # Point the fingerprint at this cycle's training window.
        if train_start_date is not None:
            local_fp["startDateString"] = train_start_date
        if train_end_date is not None:
            local_fp["endDateString"] = train_end_date

        # Prefer the real OOS end date so the runner computes proper OOS
        # metrics; otherwise fall back to one day after training
        # (not recommended).
        if test_end_date is not None:
            local_fp["endTestDateString"] = test_end_date
        elif train_end_date is not None:
            train_end_dt = datetime.strptime(train_end_date, "%Y-%m-%d %H:%M:%S")
            test_end_dt = train_end_dt + timedelta(days=1)
            local_fp["endTestDateString"] = test_end_dt.strftime("%Y-%m-%d %H:%M:%S")

        # Override iterations if specified.
        if "max_iterations" in self.runner_kwargs:
            local_fp["optimisation_settings"]["n_iterations"] = self.runner_kwargs[
                "max_iterations"
            ]

        # Checkpoint tracking is what produces the data for Rademacher.
        if self.compute_rademacher:
            local_fp["optimisation_settings"]["track_checkpoints"] = True
            local_fp["optimisation_settings"]["checkpoint_interval"] = (
                self.runner_kwargs.get("checkpoint_interval", 10)
            )

        result = train_on_historic_data(
            local_fp,
            iterations_per_print=self.runner_kwargs.get("iterations_per_print", 10000),
            return_training_metadata=True,  # Always get metadata for OOS metrics
            force_init=True,  # Don't load/restart previous runs
            root=self.root,
            warm_start_params=warm_start_params,
            warm_start_weights=warm_start_weights,
        )

        # The runner returns a (params, metadata) tuple; params are passed
        # on as-is, without any squeezing here.
        params, metadata = result
        return params, metadata

    def _run_multi_period_sgd(
        self,
        data_dict: dict,
        train_start_idx: int,
        train_end_idx: int,
        pool: Any,
        run_fingerprint: dict,
        n_assets: int,
        warm_start_params: Optional[Dict],
        train_start_date: Optional[str],
        train_end_date: Optional[str],
    ) -> Tuple[Dict, Dict]:
        """Run multi-period SGD training for a single walk-forward cycle.

        Builds a cycle-specific fingerprint, calls
        ``multi_period_sgd_training`` and reshapes its result into the
        metadata layout used by the other trainers.

        Parameters
        ----------
        data_dict, train_start_idx, train_end_idx, pool, n_assets,
        warm_start_params
            Accepted for interface parity; not used by this adapter.
        run_fingerprint : dict
            Base run configuration (deep-copied and modified per cycle).
        train_start_date, train_end_date : str or None
            Date boundaries (``"%Y-%m-%d %H:%M:%S"``) for this cycle's
            training window. When ``train_end_date`` is None the
            fingerprint's own end/test dates are left untouched.

        Returns
        -------
        params : dict
            ``best_params`` of the multi-period result.
        metadata : dict
            Metrics and summary information. "train" metrics are the means
            across periods; "test" metrics are those of the last period.
        """
        from datetime import datetime, timedelta
        from quantammsim.runners.multi_period_sgd import multi_period_sgd_training

        local_fp = deepcopy(run_fingerprint)

        # Point the fingerprint at this cycle's training window.
        if train_start_date is not None:
            local_fp["startDateString"] = train_start_date
        if train_end_date is not None:
            local_fp["endDateString"] = train_end_date
            # Set a test period just after training end for consistency.
            # Guarded because strptime(None, ...) raises TypeError.
            train_end_dt = datetime.strptime(train_end_date, "%Y-%m-%d %H:%M:%S")
            test_end_dt = train_end_dt + timedelta(days=1)
            local_fp["endTestDateString"] = test_end_dt.strftime("%Y-%m-%d %H:%M:%S")

        result, summary = multi_period_sgd_training(
            local_fp,
            n_periods=self.runner_kwargs.get("n_periods", 4),
            max_epochs=self.runner_kwargs.get("max_epochs", 200),
            aggregation=self.runner_kwargs.get("aggregation", "mean"),
            verbose=False,
            root=self.root,
        )

        params = result.best_params

        # Mean metrics across periods stand in for "train".
        train_metrics = {
            "sharpe": result.mean_sharpe,
            "returns_over_uniform_hodl": result.mean_returns_over_hodl,
            "return": np.mean(result.period_returns) if result.period_returns else 0.0,
        }

        # The last period stands in for "test" (most recent/OOS-like); fall
        # back to the means when per-period lists are empty.
        test_metrics = {
            "sharpe": result.period_sharpes[-1]
            if result.period_sharpes
            else result.mean_sharpe,
            "returns_over_uniform_hodl": result.period_returns_over_hodl[-1]
            if result.period_returns_over_hodl
            else result.mean_returns_over_hodl,
            "return": result.period_returns[-1] if result.period_returns else 0.0,
        }

        metadata = {
            "epochs_trained": result.epochs_trained,
            "final_objective": result.final_objective,
            # New format fields
            "best_train_metrics": [train_metrics],
            "best_continuous_test_metrics": [test_metrics],
            "best_param_idx": 0,
            "best_final_weights": None,
            "best_final_reserves": None,
            # Legacy fields for backward compat
            "final_train_metrics": [train_metrics],
            "final_continuous_test_metrics": [test_metrics],
            "final_weights": None,
            "final_reserves": None,
            # Multi-period specific info
            "n_periods": summary.get("n_periods"),
            "aggregation": summary.get("aggregation"),
            "period_sharpes": result.period_sharpes,
            "worst_sharpe": result.worst_sharpe,
        }

        return params, metadata
class RandomBaselineWrapper(TrainerWrapper):
    """Baseline trainer that perturbs the initial parameters with noise.

    No optimisation is performed. Use it to check whether a real trainer
    beats random chance.
    """

    def __init__(self, seed: int = 42):
        """Remember the base seed; each ``train`` call offsets it by one."""
        super().__init__(name="random_baseline", config={"seed": seed})
        self.seed = seed
        self._call_count = 0

    def train(
        self,
        data_dict: dict,
        train_start_idx: int,
        train_end_idx: int,
        pool: Any,
        run_fingerprint: dict,
        n_assets: int,
        warm_start_params: Optional[Dict] = None,
        warm_start_weights: Optional[Any] = None,
        train_start_date: Optional[str] = None,
        train_end_date: Optional[str] = None,
        test_end_date: Optional[str] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Return noisy initial parameters plus their metrics.

        Warm-start arguments and date strings are ignored. The metadata has
        the same layout as that of the real trainers so the evaluator can
        consume it unchanged.
        """
        # A different seed per call keeps successive cycles distinct while
        # staying reproducible for a given base seed.
        rng = np.random.RandomState(self.seed + self._call_count)
        self._call_count += 1

        initial_keys = (
            "initial_memory_length",
            "initial_memory_length_delta",
            "initial_k_per_day",
            "initial_weights_logits",
            "initial_log_amplitude",
            "initial_raw_width",
            "initial_raw_exponents",
            "initial_pre_exp_scaling",
        )
        initial_params = {key: run_fingerprint[key] for key in initial_keys}

        n_parameter_sets = 1
        params = pool.init_parameters(
            initial_params, run_fingerprint, n_assets, n_parameter_sets
        )

        # Perturb every non-empty array-like entry with N(0, 0.5^2) noise.
        for key, value in params.items():
            if hasattr(value, 'shape') and value.size > 0:
                params[key] = value + rng.randn(*value.shape) * 0.5

        # Drop the leading parameter-set axis from multi-dimensional entries.
        squeezed = {}
        for key, value in params.items():
            if hasattr(value, 'shape') and len(value.shape) > 1:
                squeezed[key] = jnp.squeeze(value, axis=0)
            else:
                squeezed[key] = value
        params = squeezed

        # Metrics are required by TrainingEvaluator, so compute them here.
        train_metrics, test_metrics = self._compute_metrics(
            params,
            data_dict,
            train_start_idx,
            train_end_idx,
            pool,
            run_fingerprint,
            n_assets,
        )

        metadata = {
            "method": "random_baseline",
            "epochs_trained": 0,
            "final_objective": 0.0,
            # New format fields
            "best_train_metrics": [train_metrics],
            "best_continuous_test_metrics": [test_metrics],
            "best_param_idx": 0,
            "best_final_weights": None,
            "best_final_reserves": None,
            # Legacy fields for backward compat
            "final_train_metrics": [train_metrics],
            "final_continuous_test_metrics": [test_metrics],
            "final_weights": None,
            "final_reserves": None,
        }
        return params, metadata

    def _compute_metrics(
        self,
        params: Dict[str, Any],
        data_dict: dict,
        train_start_idx: int,
        train_end_idx: int,
        pool: Any,
        run_fingerprint: dict,
        n_assets: int,
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Evaluate ``params`` on the training window and a following test window.

        A single no-gradient forward pass is run over train + test. The
        training slice is scored with ``calculate_period_metrics`` and the
        whole span with ``calculate_continuous_test_metrics``.

        The test length is derived from ``run_fingerprint["test_fraction"]``
        (default 0.2) relative to the training length; it is not taken from
        the cycle's own test window.

        Returns
        -------
        train_metrics : dict
            Metric dict for the training window.
        test_metrics : dict
            Metric dict for the test window.
        """
        from jax import jit
        from functools import partial as Partial
        from quantammsim.core_simulator.forward_pass import forward_pass_nograd
        from quantammsim.runners.jax_runner_utils import Hashabledict, create_static_dict
        from quantammsim.runners.jax_runners import get_sig_variations
        from quantammsim.utils.post_train_analysis import calculate_period_metrics, calculate_continuous_test_metrics

        sig_variations = get_sig_variations(n_assets)

        # Size the test window so it makes up `test_fraction` of train + test.
        test_fraction = run_fingerprint.get("test_fraction", 0.2)
        n_train = train_end_idx - train_start_idx
        n_test = int(n_train * test_fraction / (1 - test_fraction))
        n_total = n_train + n_test

        optimisation_settings = run_fingerprint.get("optimisation_settings", {})
        overrides = {
            "n_assets": n_assets,
            "return_val": "reserves_and_values",
            "training_data_kind": optimisation_settings.get(
                "training_data_kind", "historic_with_noise"
            ),
        }
        static_dict = create_static_dict(
            run_fingerprint,
            n_total,
            sig_variations,
            overrides=overrides,
        )

        prices = data_dict["prices"]
        forward = Partial(
            forward_pass_nograd,
            prices=prices,
            static_dict=Hashabledict(static_dict),
            pool=pool,
        )
        eval_fn = jit(forward)
        output = eval_fn(params, (train_start_idx, 0))

        values = output["value"]
        reserves = output["reserves"]

        # Training metrics: the first n_train steps of the pass.
        train_metrics = calculate_period_metrics(
            {"value": values[:n_train], "reserves": reserves[:n_train]},
            prices[train_start_idx:train_start_idx + n_train],
        )

        # Continuous test metrics: scored on the full train + test span.
        test_metrics = calculate_continuous_test_metrics(
            {"value": values, "reserves": reserves},
            n_train,
            n_test,
            prices[train_start_idx:train_start_idx + n_total],
        )

        return train_metrics, test_metrics
# ============================================================================= # Main Evaluator # =============================================================================
[docs] class TrainingEvaluator: """ Evaluates whether a training approach is effective. Wraps any trainer and runs walk-forward evaluation to assess effectiveness using WFE and Rademacher metrics. Pruning ------- This evaluator yields CycleEvaluation results via evaluate_iter(), allowing the consumer (e.g., HyperparamTuner) to decide when to prune. The evaluator itself does not prune - it evaluates all cycles unless the consumer stops iterating. This design keeps pruning logic in one place (the Optuna integration) rather than duplicating it here. """
def __init__(
    self,
    trainer: TrainerWrapper,
    n_cycles: int = 5,
    keep_fixed_start: bool = False,
    compute_rademacher: bool = False,
    verbose: bool = True,
    root: str = None,
    wfe_metric: str = "sharpe",
):
    """Store the evaluation settings.

    Parameters
    ----------
    trainer : TrainerWrapper
        Trainer whose effectiveness is being evaluated.
    n_cycles : int
        Number of walk-forward cycles.
    keep_fixed_start : bool
        False (default) selects a rolling window, True an expanding one.
    compute_rademacher : bool
        Off by default because it needs checkpoint tracking.
    verbose : bool
        Print progress.
    root : str, optional
        Root directory for data files.
    wfe_metric : str
        Metric used for WFE and the IS-OOS gap (default: sharpe per Pardo).
    """
    # What to evaluate
    self.trainer = trainer

    # Walk-forward layout
    self.n_cycles = n_cycles
    self.keep_fixed_start = keep_fixed_start

    # Metric options
    self.wfe_metric = wfe_metric
    self.compute_rademacher = compute_rademacher

    # Environment / reporting
    self.root = root
    self.verbose = verbose
# ------------------------------------------------------------------------- # Convenience Constructors # -------------------------------------------------------------------------
@classmethod
def from_runner(
    cls,
    runner_name: str,
    n_cycles: int = 5,
    keep_fixed_start: bool = False,
    verbose: bool = True,
    compute_rademacher: bool = False,
    root: str = None,
    wfe_metric: str = "sharpe",
    **runner_kwargs,
) -> "TrainingEvaluator":
    """Build an evaluator around one of the existing runners.

    Parameters
    ----------
    runner_name : str
        One of: "train_on_historic_data", "multi_period_sgd"
    n_cycles : int
        Number of walk-forward cycles
    keep_fixed_start : bool
        Rolling window when False (default).
    verbose : bool
        Print progress
    compute_rademacher : bool
        Enable Rademacher complexity computation. The flag is given to both
        the runner wrapper (which switches on checkpoint tracking) and the
        evaluator. Default False.
    root : str, optional
        Root directory for data files. If None, uses default data location.
    wfe_metric : str
        Metric to use for WFE and IS-OOS gap computation. Default "sharpe"
        (per Pardo). Can be any metric from calculate_period_metrics
        (sharpe, calmar, sterling, etc.)
    **runner_kwargs
        Arguments passed to the runner (e.g., max_iterations=500)

    Example
    -------
    >>> evaluator = TrainingEvaluator.from_runner(
    ...     "train_on_historic_data",
    ...     max_iterations=500,
    ...     compute_rademacher=True,  # Enable Rademacher complexity
    ... )
    """
    runner_wrapper = ExistingRunnerWrapper(
        runner_name,
        runner_kwargs,
        compute_rademacher=compute_rademacher,
        root=root,
    )
    evaluator_options = {
        "trainer": runner_wrapper,
        "n_cycles": n_cycles,
        "keep_fixed_start": keep_fixed_start,
        "verbose": verbose,
        "compute_rademacher": compute_rademacher,
        "root": root,
        "wfe_metric": wfe_metric,
    }
    return cls(**evaluator_options)
@classmethod
def from_function(
    cls,
    fn: Callable,
    name: str = "custom",
    n_cycles: int = 5,
    keep_fixed_start: bool = False,
    verbose: bool = True,
    root: str = None,
    wfe_metric: str = "sharpe",
    compute_rademacher: bool = False,
    **config,
) -> "TrainingEvaluator":
    """Create evaluator from a custom training function.

    Parameters
    ----------
    fn : Callable
        Function with signature ``fn(data_dict, train_start_idx,
        train_end_idx, pool, run_fingerprint, n_assets, warm_start_params)
        -> (params, metadata)``. It is called with keyword arguments.
    name : str
        Name for this trainer
    n_cycles : int
        Number of walk-forward cycles
    keep_fixed_start : bool
        If True, expanding window (train always starts from beginning).
        If False, rolling window (train window moves forward).
    verbose : bool
        Print progress
    root : str, optional
        Root directory for data files. If None, uses default data location.
    wfe_metric : str
        Metric to use for WFE and IS-OOS gap computation. Default "sharpe".
    compute_rademacher : bool
        Enable Rademacher complexity computation on the evaluator. The
        function must then return ``checkpoint_returns`` in its metadata.
        Default False.
    **config
        Config dict for reporting

    Example
    -------
    >>> def my_trainer(data_dict, train_start_idx, train_end_idx, pool,
    ...                run_fingerprint, n_assets, warm_start_params=None):
    ...     # Your training logic
    ...     return params, {"epochs": 100}
    >>>
    >>> evaluator = TrainingEvaluator.from_function(my_trainer)
    """
    # Previously this flag was swallowed by **config and never reached the
    # evaluator. Keep it in the reported config when set, so reports for
    # callers who already passed compute_rademacher=True are unchanged.
    if compute_rademacher:
        config = {**config, "compute_rademacher": compute_rademacher}

    wrapper = FunctionWrapper(fn, name=name, config=config)
    return cls(
        trainer=wrapper,
        n_cycles=n_cycles,
        keep_fixed_start=keep_fixed_start,
        compute_rademacher=compute_rademacher,
        verbose=verbose,
        root=root,
        wfe_metric=wfe_metric,
    )
@classmethod
def random_baseline(
    cls,
    seed: int = 42,
    n_cycles: int = 5,
    keep_fixed_start: bool = False,
    verbose: bool = True,
    root: str = None,
    wfe_metric: str = "sharpe",
) -> "TrainingEvaluator":
    """Build an evaluator whose trainer returns random parameters.

    Serves as a baseline to verify that a real trainer beats random chance.

    Parameters
    ----------
    seed : int
        Random seed for reproducibility
    n_cycles : int
        Number of walk-forward cycles
    keep_fixed_start : bool
        If True, expanding window. If False, rolling window (default).
    verbose : bool
        Print progress
    root : str, optional
        Root directory for data files. If None, uses default data location.
    wfe_metric : str
        Metric to use for WFE and IS-OOS gap computation. Default "sharpe".
    """
    baseline = RandomBaselineWrapper(seed=seed)
    evaluator_options = {
        "trainer": baseline,
        "n_cycles": n_cycles,
        "keep_fixed_start": keep_fixed_start,
        "verbose": verbose,
        "root": root,
        "wfe_metric": wfe_metric,
    }
    return cls(**evaluator_options)
# ------------------------------------------------------------------------- # Core Evaluation # -------------------------------------------------------------------------
def evaluate_iter(
    self, run_fingerprint: dict
) -> "Generator[CycleEvaluation, None, EvaluationResult]":
    """
    Generator that yields a CycleEvaluation after each cycle completes.

    This allows callers to inspect intermediate results and potentially
    stop early (e.g., for Optuna pruning). If the caller stops iterating
    early, the aggregate result is never computed.

    Parameters
    ----------
    run_fingerprint : dict
        Run configuration. Passed to ``recursive_default_set`` together with
        ``run_fingerprint_defaults`` before use (presumably filling missing
        keys in place -- confirm against that helper).

    Yields
    ------
    CycleEvaluation
        Results from each completed cycle

    Returns
    -------
    EvaluationResult
        Final aggregated results. A generator's return value is carried on
        the ``StopIteration`` exception raised when it is exhausted
        (``StopIteration.value``); it is not an attribute of the generator.

    Raises
    ------
    ValueError
        If training returns ``None`` params, or if the training metadata is
        missing (or has malformed) IS / OOS metric lists.

    Example
    -------
    >>> evaluator = TrainingEvaluator.from_runner("train_on_historic_data")
    >>> gen = evaluator.evaluate_iter(run_fingerprint)
    >>> while True:
    ...     try:
    ...         cycle_eval = next(gen)
    ...     except StopIteration as stop:
    ...         final_result = stop.value  # Only reached if all cycles ran
    ...         break
    ...     print(f"Cycle {cycle_eval.cycle_number}: OOS Sharpe = {cycle_eval.oos_sharpe}")
    ...     if cycle_eval.oos_sharpe < -1.0:
    ...         break  # Stop early if terrible
    """
    recursive_default_set(run_fingerprint, run_fingerprint_defaults)

    if self.verbose:
        print("=" * 70)
        print(f"EVALUATING: {self.trainer.name}")
        print("=" * 70)
        print(f"Config: {self.trainer.config}")
        print(f"Cycles: {self.n_cycles}")
        print(f"Mode: {'Expanding' if self.keep_fixed_start else 'Rolling'}")
        print("=" * 70)

    # Setup
    unique_tokens = get_unique_tokens(run_fingerprint)
    n_assets = len(unique_tokens)
    pool = create_pool(run_fingerprint["rule"])
    assert pool.is_trainable(), "Pool must be trainable"

    # Generate cycles (reuse from robust_walk_forward)
    cycles = generate_walk_forward_cycles(
        start_date=run_fingerprint["startDateString"],
        end_date=run_fingerprint["endTestDateString"],
        n_cycles=self.n_cycles,
        keep_fixed_start=self.keep_fixed_start,
    )

    # Load data for full period
    last_test_end = cycles[-1].test_end_date

    if self.verbose:
        print(f"\nLoading data: {run_fingerprint['startDateString']} → {last_test_end}")

    data_dict = get_data_dict(
        unique_tokens,
        run_fingerprint,
        data_kind=run_fingerprint["optimisation_settings"]["training_data_kind"],
        max_memory_days=run_fingerprint["max_memory_days"],
        start_date_string=run_fingerprint["startDateString"],
        end_time_string=last_test_end,
        do_test_period=False,
        root=self.root,
    )

    if self.verbose:
        print(f"Data loaded: {data_dict['prices'].shape[0]} timesteps")

    # Convert cycle dates to indices
    self._compute_cycle_indices(cycles, run_fingerprint, data_dict, last_test_end)

    # Run evaluation
    cycle_results = []
    prev_params = None
    prev_weights = None  # Track final weights for warm-starting next cycle
    all_checkpoint_returns = []  # For aggregate Rademacher

    for cycle in cycles:
        if self.verbose:
            print(f"\n--- Cycle {cycle.cycle_number} ---")

        # Train - pass test_end_date so runner computes proper OOS metrics
        # Pass warm_start_weights for initial weight distribution (fresh pool value)
        params, metadata = self.trainer.train(
            data_dict=data_dict,
            train_start_idx=cycle.train_start_idx,
            train_end_idx=cycle.train_end_idx,
            pool=pool,
            run_fingerprint=run_fingerprint,
            n_assets=n_assets,
            warm_start_params=prev_params,
            warm_start_weights=prev_weights,
            train_start_date=cycle.train_start_date,
            train_end_date=cycle.train_end_date,
            test_end_date=cycle.test_end_date,
        )

        # Handle training failure (e.g., all inner Optuna trials failed)
        if params is None:
            error_msg = (
                metadata.get("error", "Training returned None params")
                if metadata
                else "Training returned None params"
            )
            raise ValueError(f"Training failed for cycle {cycle.cycle_number}: {error_msg}")

        # A trainer may return params without any metadata. Normalise None to
        # an empty dict so the validation below raises its descriptive
        # ValueError instead of an AttributeError on ``None.get``.
        if metadata is None:
            metadata = {}

        # Get metrics from runner's training metadata (computed by train_on_historic_data)
        # This keeps all metric logic in the runner, not duplicated here
        #
        # Structure from jax_runners.py (new format):
        #   best_train_metrics[param_idx]: dict with sharpe, calmar, etc. (IS metrics)
        #   best_continuous_test_metrics[param_idx]: dict with OOS metrics
        #     - OOS metrics from continuous forward pass (train→test seamlessly)
        #   best_param_idx: index of best param set
        #   best_final_weights: weights at end of training for warm-starting
        #
        # Legacy fields (deprecated, for backward compat):
        #   final_train_metrics, final_continuous_test_metrics, final_weights
        #
        best_idx = metadata.get("best_param_idx", 0)

        # Use new field names, falling back to legacy for backward compatibility
        best_train = metadata.get("best_train_metrics") or metadata.get("final_train_metrics")
        best_oos = (
            metadata.get("best_continuous_test_metrics")
            or metadata.get("final_continuous_test_metrics")
        )

        # Validate that metrics are provided - no silent fallback
        if best_train is None:
            raise ValueError(
                f"Training metadata missing 'best_train_metrics' for cycle {cycle.cycle_number}. "
                f"Ensure trainer returns proper metadata with metrics. "
                f"Available keys: {list(metadata.keys()) if metadata else 'None'}"
            )
        if not isinstance(best_train, list) or len(best_train) <= best_idx:
            raise ValueError(
                f"Invalid best_train_metrics for cycle {cycle.cycle_number}: "
                f"expected list with at least {best_idx + 1} elements, "
                f"got {type(best_train).__name__} with {len(best_train) if isinstance(best_train, list) else 'N/A'} elements"
            )
        is_metrics = best_train[best_idx]

        if best_oos is None:
            raise ValueError(
                f"Training metadata missing 'best_continuous_test_metrics' for cycle {cycle.cycle_number}. "
                f"Ensure trainer returns proper metadata with test metrics. "
                f"Available keys: {list(metadata.keys()) if metadata else 'None'}"
            )
        if not isinstance(best_oos, list) or len(best_oos) <= best_idx:
            raise ValueError(
                f"Invalid best_continuous_test_metrics for cycle {cycle.cycle_number}: "
                f"expected list with at least {best_idx + 1} elements, "
                f"got {type(best_oos).__name__} with {len(best_oos) if isinstance(best_oos, list) else 'N/A'} elements"
            )
        # OOS metrics from calculate_continuous_test_metrics (already unprefixed)
        oos_metrics = best_oos[best_idx]

        # Compute WFE using configured metric (default: sharpe per Pardo)
        is_wfe_metric = is_metrics.get(self.wfe_metric, is_metrics["sharpe"])
        oos_wfe_metric = oos_metrics.get(self.wfe_metric, oos_metrics["sharpe"])

        wfe = compute_walk_forward_efficiency(
            is_wfe_metric,
            oos_wfe_metric,
            cycle.train_end_idx - cycle.train_start_idx,
            cycle.test_end_idx - cycle.test_start_idx,
        )

        # Compute Rademacher if checkpoint data available
        rademacher_complexity = None
        adjusted_oos_sharpe = None

        checkpoint_returns = metadata.get("checkpoint_returns")
        if self.compute_rademacher and checkpoint_returns is not None:
            checkpoint_returns = np.array(checkpoint_returns)
            if checkpoint_returns.size > 0:
                rademacher_complexity = compute_empirical_rademacher(checkpoint_returns)
                test_T = cycle.test_end_idx - cycle.test_start_idx
                adjusted_oos_sharpe, _ = compute_rademacher_haircut(
                    oos_metrics["sharpe"],
                    rademacher_complexity,
                    test_T,
                )
                # Kept for the aggregate Rademacher estimate; the aggregation
                # step filters out empty arrays itself.
                all_checkpoint_returns.append(checkpoint_returns)

        # Convert params to serializable format (JAX arrays -> Python floats/lists)
        serializable_params = {}
        for k, v in params.items():
            if hasattr(v, 'tolist'):
                serializable_params[k] = v.tolist()
            elif isinstance(v, (int, float, str, bool, type(None))):
                serializable_params[k] = v
            else:
                try:
                    serializable_params[k] = float(v)
                except (TypeError, ValueError):
                    # Last resort: keep a string representation
                    serializable_params[k] = str(v)

        cycle_eval = CycleEvaluation(
            cycle_number=cycle.cycle_number,
            is_sharpe=is_metrics["sharpe"],
            is_returns_over_hodl=is_metrics["returns_over_uniform_hodl"],
            oos_sharpe=oos_metrics["sharpe"],
            oos_returns_over_hodl=oos_metrics["returns_over_uniform_hodl"],
            walk_forward_efficiency=wfe,
            is_oos_gap=is_wfe_metric - oos_wfe_metric,  # Uses configured metric
            epochs_trained=metadata.get("epochs_trained", 0),
            rademacher_complexity=rademacher_complexity,
            adjusted_oos_sharpe=adjusted_oos_sharpe,
            # Additional risk metrics
            is_calmar=is_metrics.get("calmar"),
            oos_calmar=oos_metrics.get("calmar"),
            is_sterling=is_metrics.get("sterling"),
            oos_sterling=oos_metrics.get("sterling"),
            is_ulcer=is_metrics.get("ulcer"),
            oos_ulcer=oos_metrics.get("ulcer"),
            is_returns=is_metrics.get("return"),
            oos_returns=oos_metrics.get("return"),
            is_daily_log_sharpe=is_metrics.get("daily_log_sharpe"),
            oos_daily_log_sharpe=oos_metrics.get("daily_log_sharpe"),
            # Trained strategy params and dates
            trained_params=serializable_params,
            train_start_date=cycle.train_start_date,
            train_end_date=cycle.train_end_date,
            test_start_date=cycle.test_start_date,
            test_end_date=cycle.test_end_date,
            # Provenance: for debugging and linking to output files
            run_location=metadata.get("run_location"),
            run_fingerprint=metadata.get("run_fingerprint"),
        )
        cycle_results.append(cycle_eval)

        prev_params = params
        # Capture weights for warm-starting next cycle (new field, fallback to legacy)
        best_weights = metadata.get("best_final_weights")
        prev_weights = best_weights if best_weights is not None else metadata.get("final_weights")

        if self.verbose:
            print(f" IS: sharpe={is_metrics['sharpe']:.4f}")
            print(f" OOS: sharpe={oos_metrics['sharpe']:.4f}")
            print(f" WFE: {wfe:.4f}")
            if rademacher_complexity is not None:
                print(f" Rademacher: R̂={rademacher_complexity:.4f}, adj_sharpe={adjusted_oos_sharpe:.4f}")

        # Yield intermediate result for pruning decisions
        yield cycle_eval

    # Aggregate results
    result = self._aggregate_results(cycle_results, cycles, all_checkpoint_returns)

    if self.verbose:
        self.print_report(result)

    return result
def evaluate(self, run_fingerprint: dict) -> EvaluationResult:
    """
    Run the full walk-forward evaluation and return the final result.

    Drives :meth:`evaluate_iter` to exhaustion, discarding the per-cycle
    values it yields, and returns the value the generator finishes with.

    Parameters
    ----------
    run_fingerprint : dict
        Run configuration

    Returns
    -------
    EvaluationResult
        Comprehensive evaluation results
    """
    cycle_stream = self.evaluate_iter(run_fingerprint)
    while True:
        try:
            next(cycle_stream)
        except StopIteration as finished:
            # The generator's ``return`` value travels on StopIteration.
            return finished.value
def _compute_cycle_indices(
    self,
    cycles: List[WalkForwardCycle],
    run_fingerprint: dict,
    data_dict: dict,
    last_test_end: str,
):
    """Convert cycle dates to data indices (in place on each cycle).

    Each date is mapped by linear interpolation between
    ``data_dict["start_idx"]`` (at ``run_fingerprint["startDateString"]``)
    and ``data_dict["end_idx"]`` (at ``last_test_end``).

    All indices are bounded to [data_dict["start_idx"], data_dict["end_idx"]]
    to handle edge cases from date rounding in generate_walk_forward_cycles.

    Parameters
    ----------
    cycles : List[WalkForwardCycle]
        Cycles whose ``*_date`` strings are read and whose ``*_idx``
        attributes are assigned.
    run_fingerprint : dict
        Run configuration; only ``"startDateString"`` is read here.
    data_dict : dict
        Loaded data; only ``"start_idx"`` and ``"end_idx"`` are read here.
    last_test_end : str
        End date (``"%Y-%m-%d %H:%M:%S"``) of the final test window.

    Raises
    ------
    ValueError
        If ``last_test_end`` is not strictly after the start date.
    """
    date_format = "%Y-%m-%d %H:%M:%S"

    def parse(date_str):
        return datetime.strptime(date_str, date_format)

    # Elapsed time is measured by subtracting naive datetimes rather than
    # comparing ``.timestamp()`` values: ``timestamp()`` interprets a naive
    # datetime in the machine's local timezone, so a DST transition inside
    # the window would shift the interpolated indices.
    start_dt = parse(run_fingerprint["startDateString"])
    total_seconds = (parse(last_test_end) - start_dt).total_seconds()
    if total_seconds <= 0:
        raise ValueError(
            f"last_test_end ({last_test_end}) must be after "
            f"startDateString ({run_fingerprint['startDateString']})"
        )

    min_idx = data_dict["start_idx"]
    max_idx = data_dict["end_idx"]
    data_length = max_idx - min_idx

    def compute_idx(date_str):
        """Compute index from date, bounded to valid range."""
        elapsed = (parse(date_str) - start_dt).total_seconds()
        raw_idx = min_idx + int(data_length * elapsed / total_seconds)
        return max(min_idx, min(raw_idx, max_idx))

    for cycle in cycles:
        cycle.train_start_idx = compute_idx(cycle.train_start_date)
        cycle.train_end_idx = compute_idx(cycle.train_end_date)
        cycle.test_start_idx = compute_idx(cycle.test_start_date)
        cycle.test_end_idx = compute_idx(cycle.test_end_date)

        # Ensure proper ordering: train_start < train_end <= test_start < test_end
        # This handles edge cases where dates round to same index
        if cycle.train_end_idx <= cycle.train_start_idx:
            cycle.train_end_idx = min(cycle.train_start_idx + 1, max_idx)
        if cycle.test_start_idx < cycle.train_end_idx:
            cycle.test_start_idx = cycle.train_end_idx
        if cycle.test_end_idx <= cycle.test_start_idx:
            cycle.test_end_idx = min(cycle.test_start_idx + 1, max_idx)

def _aggregate_results(
    self,
    cycle_results: List[CycleEvaluation],
    cycles: List[WalkForwardCycle],
    all_checkpoint_returns: List[np.ndarray],
) -> EvaluationResult:
    """Aggregate per-cycle results into a single :class:`EvaluationResult`.

    Computes mean/std/worst OOS Sharpe, mean WFE, mean IS–OOS gap, optional
    aggregate Rademacher complexity with haircut, and an effectiveness
    verdict.

    The verdict starts as effective when mean WFE >= 0.5 and is revoked when
    the worst OOS Sharpe is <= 0 or the Rademacher-adjusted mean OOS Sharpe
    is <= 0. The IS–OOS gap is reported as a reason but does not change the
    verdict.

    Parameters
    ----------
    cycle_results : List[CycleEvaluation]
        Per-cycle evaluation results.
    cycles : List[WalkForwardCycle]
        Cycle specifications (used for test-window lengths).
    all_checkpoint_returns : List[np.ndarray]
        Per-cycle checkpoint return matrices for Rademacher computation.
        Empty list if checkpoints were not tracked.

    Returns
    -------
    EvaluationResult
        Aggregated evaluation with summary statistics and verdict.
    """
    oos_sharpes = [c.oos_sharpe for c in cycle_results]
    wfes = [c.walk_forward_efficiency for c in cycle_results]
    gaps = [c.is_oos_gap for c in cycle_results]

    # Let NaN flow through - consumer (hyperparam_tuner) handles bad values
    mean_wfe = np.mean(wfes) if wfes else np.nan
    mean_oos_sharpe = np.mean(oos_sharpes) if oos_sharpes else np.nan
    std_oos_sharpe = np.std(oos_sharpes) if oos_sharpes else np.nan
    worst_oos_sharpe = np.min(oos_sharpes) if oos_sharpes else np.nan
    mean_gap = np.mean(gaps) if gaps else np.nan

    # Compute aggregate Rademacher if checkpoint data available
    aggregate_rademacher = None
    adjusted_mean_oos_sharpe = None

    if self.compute_rademacher and all_checkpoint_returns:
        # Different cycles may have different return lengths
        # Filter out empty arrays first
        non_empty_arrays = [arr for arr in all_checkpoint_returns if arr.size > 0]
        if non_empty_arrays:
            # Every remaining array has size > 0, hence a last axis of length
            # >= 1, so min_len >= 1 and no truncated slice can be empty.
            min_len = min(arr.shape[-1] for arr in non_empty_arrays)
            truncated = [
                (arr.reshape(1, -1) if arr.ndim == 1 else arr)[:, :min_len]
                for arr in non_empty_arrays
            ]
            combined_returns = np.vstack(truncated)
            aggregate_rademacher = compute_empirical_rademacher(combined_returns)

            # Compute haircut on aggregate OOS sharpe
            total_test_T = sum(c.test_end_idx - c.test_start_idx for c in cycles)
            adjusted_mean_oos_sharpe, _ = compute_rademacher_haircut(
                mean_oos_sharpe,
                aggregate_rademacher,
                total_test_T,
            )

    # Effectiveness verdict
    is_effective = False
    reasons = []

    if mean_wfe >= 0.5:
        reasons.append(f"WFE {mean_wfe:.2f} >= 0.5 (good IS→OOS transfer)")
        is_effective = True
    else:
        reasons.append(f"WFE {mean_wfe:.2f} < 0.5 (poor IS→OOS transfer)")

    if worst_oos_sharpe > 0:
        reasons.append(f"Worst OOS Sharpe {worst_oos_sharpe:.2f} > 0")
    else:
        reasons.append(f"Worst OOS Sharpe {worst_oos_sharpe:.2f} <= 0")
        is_effective = False

    # Informational only: the gap does not alter the verdict.
    if mean_gap < 0.5:
        reasons.append(f"IS-OOS gap {mean_gap:.2f} < 0.5 (not overfitting badly)")
    else:
        reasons.append(f"IS-OOS gap {mean_gap:.2f} >= 0.5 (significant overfitting)")

    if aggregate_rademacher is not None:
        reasons.append(f"Rademacher R̂={aggregate_rademacher:.3f}")
        if adjusted_mean_oos_sharpe is not None:
            if adjusted_mean_oos_sharpe > 0:
                reasons.append(f"Adjusted OOS Sharpe {adjusted_mean_oos_sharpe:.2f} > 0")
            else:
                reasons.append(f"Adjusted OOS Sharpe {adjusted_mean_oos_sharpe:.2f} <= 0")
                is_effective = False

    return EvaluationResult(
        trainer_name=self.trainer.name,
        trainer_config=self.trainer.config,
        cycles=cycle_results,
        mean_wfe=mean_wfe,
        mean_oos_sharpe=mean_oos_sharpe,
        std_oos_sharpe=std_oos_sharpe,
        worst_oos_sharpe=worst_oos_sharpe,
        mean_is_oos_gap=mean_gap,
        aggregate_rademacher=aggregate_rademacher,
        adjusted_mean_oos_sharpe=adjusted_mean_oos_sharpe,
        is_effective=is_effective,
        effectiveness_reasons=reasons,
    )
def print_report(self, result: EvaluationResult):
    """Write a human-readable summary of an evaluation to stdout.

    Sections are printed in this order: header, aggregate metrics,
    Rademacher metrics (only when an aggregate value is present), the
    effectiveness verdict with its reasons, and one line per cycle.

    Parameters
    ----------
    result : EvaluationResult
        Completed evaluation result to display.
    """
    rule = "=" * 70

    # Header
    print("\n" + rule)
    print("EVALUATION REPORT")
    print(rule)
    print(f"Trainer: {result.trainer_name}")

    # Aggregate statistics
    print("\n--- Aggregate Metrics ---")
    print(f"Mean WFE: {result.mean_wfe:.4f}")
    print(f"Mean OOS Sharpe: {result.mean_oos_sharpe:.4f} ± {result.std_oos_sharpe:.4f}")
    print(f"Worst OOS Sharpe: {result.worst_oos_sharpe:.4f}")
    print(f"IS-OOS Gap: {result.mean_is_oos_gap:.4f}")

    # Rademacher diagnostics are optional
    if result.aggregate_rademacher is not None:
        print("\n--- Rademacher Metrics ---")
        print(f"Aggregate R̂: {result.aggregate_rademacher:.4f}")
        if result.adjusted_mean_oos_sharpe is not None:
            print(f"Adjusted Sharpe: {result.adjusted_mean_oos_sharpe:.4f}")

    # Verdict and the reasons behind it
    verdict = "YES" if result.is_effective else "NO"
    print("\n--- Verdict ---")
    print(f"Effective: {verdict}")
    for why in result.effectiveness_reasons:
        print(f" • {why}")

    # One line per walk-forward cycle
    print("\n--- Per-Cycle ---")
    for cyc in result.cycles:
        if cyc.rademacher_complexity is None:
            complexity_suffix = ""
        else:
            complexity_suffix = f", R̂={cyc.rademacher_complexity:.3f}"
        summary = (
            f" Cycle {cyc.cycle_number}: "
            f"IS={cyc.is_sharpe:.3f} → OOS={cyc.oos_sharpe:.3f} "
            f"(WFE={cyc.walk_forward_efficiency:.2f}{complexity_suffix})"
        )
        print(summary)

    print(rule)
# ============================================================================= # Comparison Utility # =============================================================================
def compare_trainers(
    run_fingerprint: dict,
    trainers: Dict[str, TrainingEvaluator],
    verbose: bool = True,
) -> Dict[str, EvaluationResult]:
    """
    Evaluate several trainers against the same run configuration.

    Each evaluator's ``evaluate`` is called in the iteration order of
    ``trainers`` with the same ``run_fingerprint`` object. When ``verbose``
    is set, a banner is printed before each evaluation and a summary table
    afterwards.

    Parameters
    ----------
    run_fingerprint : dict
        Run configuration
    trainers : Dict[str, TrainingEvaluator]
        Dictionary of name -> evaluator
    verbose : bool
        Print progress and summary

    Returns
    -------
    Dict[str, EvaluationResult]
        Results keyed by trainer name

    Example
    -------
    >>> results = compare_trainers(
    ...     run_fingerprint,
    ...     trainers={
    ...         "sgd_500": TrainingEvaluator.from_runner(
    ...             "train_on_historic_data", max_iterations=500
    ...         ),
    ...         "sgd_100": TrainingEvaluator.from_runner(
    ...             "train_on_historic_data", max_iterations=100
    ...         ),
    ...         "random": TrainingEvaluator.random_baseline(),
    ...     },
    ... )
    """
    hash_rule = "#" * 70
    outcomes = {}

    for label, trainer_eval in trainers.items():
        if verbose:
            print(f"\n{hash_rule}")
            print(f"# Evaluating: {label}")
            print(hash_rule)
        outcomes[label] = trainer_eval.evaluate(run_fingerprint)

    if not verbose:
        return outcomes

    # Summary table: one row per trainer, in evaluation order.
    double_rule = "=" * 70
    print("\n" + double_rule)
    print("COMPARISON SUMMARY")
    print(double_rule)
    print(f"{'Trainer':<30} {'WFE':>8} {'OOS':>8} {'Worst':>8} {'Gap':>8} {'Eff?':>6}")
    print("-" * 70)

    for label, res in outcomes.items():
        flag = "YES" if res.is_effective else "NO"
        row = (
            f"{label:<30} {res.mean_wfe:>8.3f} {res.mean_oos_sharpe:>8.3f} "
            f"{res.worst_oos_sharpe:>8.3f} {res.mean_is_oos_gap:>8.3f} {flag:>6}"
        )
        print(row)

    print(double_rule)
    return outcomes
# =============================================================================
# Example
# =============================================================================

if __name__ == "__main__":
    # Demo configuration: a two-token (BTC, ETH) pool using the "momentum"
    # rule, with the date range and optimiser settings spelled out below.
    run_fingerprint = dict(
        startDateString="2022-01-01 00:00:00",
        endDateString="2023-06-01 00:00:00",
        endTestDateString="2024-01-01 00:00:00",
        tokens=["BTC", "ETH"],
        rule="momentum",
        chunk_period=1440,
        weight_interpolation_period=1440,
        initial_pool_value=1000000.0,
        fees=0.003,
        return_val="sharpe",
        optimisation_settings=dict(
            training_data_kind="historic",
            optimiser="adam",
            base_lr=0.1,
            n_iterations=200,
        ),
    )

    # Simple usage: evaluate the random baseline over three cycles.
    evaluator = TrainingEvaluator.random_baseline(n_cycles=3)
    result = evaluator.evaluate(run_fingerprint)