Runners
Core Training Runners
Core training and simulation runners for quantammsim.
This module provides the two primary entry points for using quantammsim:
train_on_historic_data(): Optimise strategy parameters on historical price data using either gradient descent (Adam/AdamW/SGD via Optax) or gradient-free search (Optuna). Supports ensemble training, early stopping with a validation holdout, warm-starting from previous walk-forward cycles, checkpointing for Rademacher complexity analysis, and Stochastic Weight Averaging.
do_run_on_historic_data(): Execute a single forward pass (simulation) with fixed parameters and return the full results dict. Used for post-training evaluation, walk-forward OOS testing, and visualisation. Supports injecting real trade data, time-varying fees and gas costs, and LP supply changes.
Both functions accept a run_fingerprint dict as their primary
configuration. See Run Fingerprints for the complete
reference of available settings.
- train_on_historic_data(run_fingerprint, root=None, iterations_per_print=1, force_init=False, price_data=None, verbose=True, run_location=None, return_training_metadata=False, warm_start_params=None, warm_start_weights=None)[source]
Optimise strategy parameters on historical price data.
This is the primary training entry point for quantammsim. It loads (or accepts) price data, constructs the JAX computation graph, and runs either gradient-based (Adam/AdamW/SGD) or gradient-free (Optuna) optimisation according to run_fingerprint["optimisation_settings"]["method"].
- Parameters:
run_fingerprint (dict) –
Master configuration dict. Key fields consumed here:
  - tokens, startDateString, endDateString, endTestDateString: data selection
  - rule: pool/strategy type (e.g. "momentum", "mean_reversion_channel")
  - return_val: objective metric (default "daily_log_sharpe")
  - optimisation_settings.method: "gradient_descent" or "optuna"
  - optimisation_settings.optimiser: "adam", "adamw", or "sgd"
  - optimisation_settings.n_iterations: number of training epochs
  - optimisation_settings.val_fraction: fraction of the training window held out for early-stopping validation (0 = disabled)
  - optimisation_settings.use_swa: enable Stochastic Weight Averaging
  - optimisation_settings.track_checkpoints: save periodic parameter snapshots for Rademacher complexity analysis
See Run Fingerprints for the full reference.
root (str, optional) – Root directory for data files and saved results.
iterations_per_print (int, optional) – Print training progress every N iterations (default 1).
force_init (bool, optional) – If True, ignore cached results and re-initialise parameters.
price_data (array-like or DataFrame, optional) – Pre-loaded price data. When None, data is loaded from parquet files based on the run_fingerprint date/token settings.
verbose (bool, optional) – Print detailed progress information (default True).
run_location (str, optional) – Path to a previously-saved run to resume from. When None, a new run is initialised (or auto-detected from the fingerprint hash).
return_training_metadata (bool, optional) – If True, return (params, metadata), where metadata contains epochs_trained, final_objective, and checkpoint_returns (an (n_checkpoints, T-1) array for Rademacher complexity, or None if checkpointing was disabled).
warm_start_params (dict, optional) – Strategy parameters from a previous walk-forward cycle. Each value is expanded to (n_parameter_sets, ...) shape with added Gaussian noise (scale controlled by optimisation_settings.noise_scale).
warm_start_weights (array-like, optional) – Final portfolio weights from a previous cycle, shape (n_assets,). The pool starts with a fresh initial_pool_value distributed according to these weights.
- Returns:
Gradient descent, return_training_metadata=False: best params dict.
Gradient descent, return_training_metadata=True: (params, metadata) tuple.
Optuna: list of best trials, or None if none completed.
- Return type:
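The following is a minimal sketch of a run_fingerprint that covers the fields listed above. The nesting of optimisation_settings follows the dotted names in the parameter description. Every concrete value (tokens, dates, iteration counts, noise scale) is a hypothetical placeholder, and the date-string format is an assumption. Consult Run Fingerprints for the authoritative keys and defaults.

```python
# Hypothetical run_fingerprint: values are placeholders, not library defaults.
run_fingerprint = {
    "tokens": ["BTC", "ETH"],
    "rule": "momentum",
    "startDateString": "2023-01-01 00:00:00",  # assumed date-string format
    "endDateString": "2023-06-01 00:00:00",
    "endTestDateString": "2023-09-01 00:00:00",
    "return_val": "daily_log_sharpe",
    "optimisation_settings": {
        "method": "gradient_descent",  # or "optuna"
        "optimiser": "adam",           # "adam", "adamw", or "sgd"
        "n_iterations": 500,
        "val_fraction": 0.2,           # 0 disables early-stopping validation
        "use_swa": False,
        "track_checkpoints": False,
        "noise_scale": 0.1,            # only relevant when warm_start_params is given
    },
}

# Training and then evaluating on the OOS test period would then look like:
#   params = train_on_historic_data(run_fingerprint)
#   train_res, test_res = do_run_on_historic_data(
#       run_fingerprint, params, do_test_period=True)
```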
- do_run_on_historic_data(run_fingerprint, params={}, root=None, price_data=None, verbose=False, raw_trades=None, fees=None, gas_cost=None, arb_fees=None, fees_df=None, gas_cost_df=None, arb_fees_df=None, lp_supply_df=None, do_test_period=False, low_data_mode=False, preslice_burnin=True)[source]
Execute a forward-pass simulation with fixed parameters.
Runs the full simulation pipeline — price loading, weight calculation, arbitrage, and metric computation — using pre-trained (or manually specified) strategy parameters. This is the primary entry point for post-training evaluation, walk-forward OOS testing, and visualisation.
- Parameters:
run_fingerprint (dict) – Master configuration dict (same structure as for train_on_historic_data()).
params (dict or list of dict) – Strategy parameters. A single dict runs one simulation; a list of dicts runs multiple parameter sets in parallel via vmap.
root (str, optional) – Root directory for data files.
price_data (array-like or DataFrame, optional) – Pre-loaded price data. When None, loaded from parquet files.
verbose (bool, optional) – Print progress information (default False).
raw_trades (DataFrame, optional) – Real trade data to inject. Columns: unix timestamp (minute), token_in, token_out, amount_in.
fees (float, optional) – Swap fee override (e.g. 0.003 for 30 bps).
gas_cost (float, optional) – Gas cost override per transaction.
arb_fees (float, optional) – Arbitrageur fee override.
fees_df (DataFrame, optional) – Time-varying swap fees (columns: unix, fee).
gas_cost_df (DataFrame, optional) – Time-varying gas costs (columns: unix, gas_cost).
arb_fees_df (DataFrame, optional) – Time-varying arb fees (columns: unix, arb_fee).
lp_supply_df (DataFrame, optional) – Time-varying LP supply changes.
do_test_period (bool, optional) – If True, also run the OOS test period from endDateString to endTestDateString (default False).
low_data_mode (bool, optional) – If True, drop raw price arrays from the output dict to reduce memory usage (default False).
preslice_burnin (bool, optional) – If True, pre-slice data to max_memory_days of burn-in plus the simulation period (default True). Set False to load all available history.
- Returns:
When do_test_period=False: a single results dict with keys including values, reserves, weights, coarse_weights, objective, and per-asset breakdowns.
When do_test_period=True: (train_results, test_results).
For multiple parameter sets, each value in the dict is a list (one entry per parameter set).
- Return type:
- do_run_on_historic_data_with_provided_coarse_weights(run_fingerprint, coarse_weights, params={}, root=None, price_data=None, verbose=False, raw_trades=None, fees=None, gas_cost=None, arb_fees=None, fees_df=None, gas_cost_df=None, arb_fees_df=None, lp_supply_df=None, do_test_period=False, low_data_mode=False)[source]
Execute a simulation using pre-computed coarse weights.
Like do_run_on_historic_data(), but bypasses the weight-calculation step entirely. The caller provides coarse_weights directly, and this function performs only fine-weight interpolation, arbitrage simulation, and metric computation.
This is useful for replaying a trained strategy with externally computed or manually specified weight trajectories, or for separating weight computation from simulation for profiling or debugging.
- Parameters:
run_fingerprint (dict) – Master configuration dict.
coarse_weights (jnp.ndarray) – Pre-computed coarse weights, shape (n_coarse_steps, n_assets).
params (dict or list of dict, optional) – Strategy parameters (used only for initial_reserves and any subsidiary parameters, not for weight computation).
root (str, optional) – Root directory for data files.
price_data (array-like or DataFrame, optional) – Pre-loaded price data.
verbose (bool, optional) – Print progress (default False).
raw_trades (DataFrame, optional) – Real trade data to inject.
fees (float, optional) – Swap fee override.
gas_cost (float, optional) – Gas cost override.
arb_fees (float, optional) – Arbitrageur fee override.
fees_df (DataFrame, optional) – Time-varying swap fees.
gas_cost_df (DataFrame, optional) – Time-varying gas costs.
arb_fees_df (DataFrame, optional) – Time-varying arb fees.
lp_supply_df (DataFrame, optional) – Time-varying LP supply changes.
do_test_period (bool, optional) – Run OOS test period (default False).
low_data_mode (bool, optional) – Drop raw arrays from output to save memory (default False).
- Returns:
Same structure as do_run_on_historic_data().
- Return type:
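One way to build a coarse_weights argument of shape (n_coarse_steps, n_assets) is a linear glide between two allocations. This is an illustrative sketch only. The number of coarse steps the simulator expects depends on the run's date range and chunking, so n_coarse_steps below is a hypothetical placeholder. The rows are kept non-negative and summing to 1 on the assumption that the simulator expects normalised weights.

```python
import numpy as np

def linear_coarse_weights(w_start, w_end, n_coarse_steps):
    """Linearly interpolate between two weight vectors.

    Returns shape (n_coarse_steps, n_assets). Each row is a convex combination
    of w_start and w_end, so rows stay non-negative and sum to 1 whenever both
    endpoints do.
    """
    w_start = np.asarray(w_start, dtype=float)
    w_end = np.asarray(w_end, dtype=float)
    t = np.linspace(0.0, 1.0, n_coarse_steps)[:, None]  # shape (n_coarse_steps, 1)
    return (1.0 - t) * w_start + t * w_end

# Hypothetical: glide from 50/50 to 80/20 over 30 coarse steps.
coarse = linear_coarse_weights([0.5, 0.5], [0.8, 0.2], n_coarse_steps=30)
# jax.numpy.asarray(coarse) would convert this to a jnp.ndarray if required.
```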
Runner Utilities
- create_trial_params(trial, param_config, params, run_fingerprint, n_assets, expand_around=False)[source]
Create trial parameters for Optuna optimization.
Parameters:
- trial : optuna.Trial
The Optuna trial object.
- param_config : dict
Configuration for parameter optimization. Each parameter can have:
  - low: float, lower bound
  - high: float, upper bound
  - log_scale: bool, whether to use log scale
  - scalar: bool, whether to use the same value for all assets
- params : dict
Current parameter values, used for shape information.
- run_fingerprint : dict
Run configuration.
- n_assets : int
Number of assets.
Returns:
- dict
Trial parameters dictionary
Raises:
- ValueError
If parameter shapes are invalid or required config is missing
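The param_config layout described above can be illustrated with a small sketch. sample_params is a hypothetical helper, not the library's create_trial_params. The parameter name log_k is borrowed from the nan_rollback example, and hypothetical_scalar is invented for illustration. The sketch relies only on trial.suggest_float(name, low, high, log=...), which optuna.Trial provides. It assumes that a non-scalar parameter gets one independent draw per asset.

```python
import numpy as np

def sample_params(trial, param_config, n_assets):
    """Hypothetical sketch: turn param_config bounds into sampled arrays."""
    sampled = {}
    for name, cfg in param_config.items():
        log = cfg.get("log_scale", False)
        if cfg.get("scalar", False):
            # One draw shared by every asset.
            value = trial.suggest_float(name, cfg["low"], cfg["high"], log=log)
            sampled[name] = np.full(n_assets, value)
        else:
            # One independently named draw per asset.
            sampled[name] = np.array([
                trial.suggest_float(f"{name}_{i}", cfg["low"], cfg["high"], log=log)
                for i in range(n_assets)
            ])
    return sampled

class MidpointTrial:
    """Stand-in for optuna.Trial that always suggests the midpoint of [low, high]."""
    def suggest_float(self, name, low, high, log=False):
        return (low + high) / 2

param_config = {
    "log_k": {"low": -2.0, "high": 3.0, "log_scale": False, "scalar": False},
    "hypothetical_scalar": {"low": 1.0, "high": 200.0, "log_scale": True, "scalar": True},
}
sampled = sample_params(MidpointTrial(), param_config, n_assets=3)
```

In real use, an optuna.Trial passed in place of MidpointTrial would draw values inside each [low, high] range.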
- generate_evaluation_points(start_idx, end_idx, bout_length, n_points, min_spacing, random_key=0)[source]
Generate evaluation start points for optuna-style hyperparameter search.
If the training period is exactly equal to bout_length (no room for multiple windows), returns just the start_idx as a single evaluation point.
- Parameters:
start_idx (int) – Start index of the training period
end_idx (int) – End index of the training period
bout_length (int) – Length of each evaluation window
n_points (int) – Desired number of evaluation points
min_spacing (int) – Minimum spacing between evaluation points (currently unused)
random_key (int) – Random seed for reproducibility
- Returns:
List of evaluation start indices
- Return type:
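The following sketch draws evaluation start points in a way that is consistent with the description above. Each window [start, start + bout_length) must fit inside [start_idx, end_idx), and a training period exactly one bout long collapses to the single point start_idx. The sampling scheme is an assumption, not the library's exact algorithm: it uses unique random offsets, sorts them, and ignores min_spacing, which the docstring says is currently unused.

```python
import numpy as np

def sketch_evaluation_points(start_idx, end_idx, bout_length, n_points, random_key=0):
    """Hypothetical sketch: random, sorted, unique window start indices."""
    last_valid_start = end_idx - bout_length  # latest start whose window still fits
    if last_valid_start <= start_idx:
        # No room for more than one window.
        return [start_idx]
    rng = np.random.default_rng(random_key)
    candidates = np.arange(start_idx, last_valid_start + 1)
    n = min(n_points, len(candidates))
    picks = rng.choice(candidates, size=n, replace=False)
    return sorted(int(p) for p in picks)
```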
- find_best_balanced_solution(values_array, n_objectives=None)[source]
Find the solution closest to the ideal point after normalizing objectives.
- Parameters:
values_array – Either a NumPy array of shape (n_trials, n_objectives) or a list of Optuna trials with a values attribute
n_objectives – Number of objectives. Only needed if using list of trials.
- Returns:
Index of the best balanced solution
- Return type:
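This sketch assumes that all objectives are maximised (as in the multi-objective study created by OptunaManager.setup_study) and that normalisation is per-objective min-max scaling. Under those assumptions, the ideal point becomes a vector of ones, and the balanced solution is the row with the smallest Euclidean distance to it. The sketch handles only the array input, and the library's exact normalisation is an assumption.

```python
import numpy as np

def sketch_best_balanced(values_array):
    """Index of the row closest to the ideal point after min-max scaling."""
    v = np.asarray(values_array, dtype=float)    # (n_trials, n_objectives)
    lo, hi = v.min(axis=0), v.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)        # avoid division by zero
    normed = (v - lo) / span                      # each objective now in [0, 1]
    dist = np.linalg.norm(1.0 - normed, axis=1)   # distance to the ideal (all ones)
    return int(np.argmin(dist))
```

For example, given [[1.0, 0.0], [0.0, 1.0], [0.8, 0.8]], the third row wins. It is not best on either objective alone, but it is closest to being best on both.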
- class OptunaManager(run_fingerprint)[source]
Bases: object
Manages an Optuna hyperparameter optimization study lifecycle.
Encapsulates study creation, execution, early stopping, and result persistence. Configuration is drawn from run_fingerprint["optimisation_settings"]["optuna_settings"].
- Parameters:
run_fingerprint (dict) – Run configuration. Must contain optimisation_settings.optuna_settings.
- study
The Optuna study, created by setup_study().
- Type:
optuna.Study or None
- logger
File-backed logger writing to output_dir/optimization.log.
- Type:
- setup_study(multi_objective=False)[source]
Create and configure the Optuna study.
Initialises an Optuna study with TPE sampler (multivariate), median pruner, and optional RDB storage. For multi-objective mode, creates a three-direction maximize study (mean return, worst-case, stability).
- Parameters:
multi_objective (bool, optional) – If True, creates a multi-objective study with three maximize directions. Default is False (single-objective maximize).
- early_stopping_callback(study, trial)[source]
Callback that implements early stopping using both training and validation metrics.
- optimize(objective)[source]
Run the optimization process with error handling and parallel execution.
Delegates to study.optimize with the configured number of trials, timeout, parallel jobs, and early-stopping callback. All exceptions are caught (logged, not re-raised) so that partial results are always saved via save_results.
- Parameters:
objective (callable) – Optuna objective function: trial -> float (single-objective) or trial -> tuple[float, ...] (multi-objective).
- class Hashabledict[source]
Bases: dict
A hashable dictionary class that enables using dictionaries as dictionary keys.
This class extends the built-in dict class to make dictionaries hashable by implementing the __hash__ and __eq__ methods. The hash is computed based on a sorted tuple of key-value pairs.
- __key()
Returns a tuple of sorted key-value pairs representing the dictionary.
- __eq__(other)[source]
Checks equality between this dictionary and another by comparing their sorted key-value pairs.
Examples
>>> d1 = Hashabledict({'a': 1, 'b': 2})
>>> d2 = Hashabledict({'b': 2, 'a': 1})
>>> hash(d1) == hash(d2)
True
>>> d1 == d2
True
>>> d3 = {d1: 'value'}  # Can use as dictionary key
- class NestedHashabledict(*args, **kwargs)[source]
Bases: dict
A hashable dictionary class that enables using dictionaries as dictionary keys. Handles deeply nested dictionaries by recursively converting all nested dicts.
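Both classes can be sketched in a few lines. The class names carry a Sketch suffix to make clear that they are illustrations, not the library's classes. The sketch assumes that all values are hashable once nested dicts have been converted (lists, for example, would still fail). As with any hash derived from contents, a dict must not be mutated after it has been used as a key.

```python
class HashabledictSketch(dict):
    """dict hashed by its sorted (key, value) pairs; do not mutate once used as a key."""

    def _key(self):
        return tuple(sorted(self.items()))

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        if not isinstance(other, dict):
            return NotImplemented
        return self._key() == tuple(sorted(other.items()))


class NestedHashabledictSketch(HashabledictSketch):
    """Recursively converts nested dicts so the whole tree is hashable."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in list(self.items()):
            if isinstance(value, dict) and not isinstance(value, NestedHashabledictSketch):
                self[key] = NestedHashabledictSketch(value)


d1 = NestedHashabledictSketch({"x": 1, "cfg": {"b": 2, "a": 1}})
d2 = NestedHashabledictSketch({"cfg": {"a": 1, "b": 2}, "x": 1})
lookup = {d1: "value"}  # usable as a dictionary key
```

Because the inner dict is converted too, d1 and d2 hash identically even though their nested keys were inserted in different orders.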
- get_sig_variations(n_assets)[source]
Compute signature variations for arbitrage.
Returns all possible (asset_in, asset_out) pairs encoded as a tuple of tuples, where each inner tuple has exactly one +1 (asset out) and one -1 (asset in), with zeros elsewhere.
- Parameters:
n_assets (int) – Number of assets in the pool.
- Returns:
Tuple of tuples representing valid arbitrage directions. Each inner tuple has shape (n_assets,) with values in {-1, 0, 1}.
- Return type:
Example
>>> get_sig_variations(3)
((1, -1, 0), (1, 0, -1), (-1, 1, 0), (0, 1, -1), (-1, 0, 1), (0, -1, 1))
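The same set of directions can be enumerated with itertools.permutations over ordered (out, in) index pairs. For three assets, this sketch reproduces the example above, including its order. Whether the library builds the tuples this way is an assumption.

```python
from itertools import permutations

def sketch_sig_variations(n_assets):
    """All tuples with one +1 (asset out), one -1 (asset in), zeros elsewhere."""
    variations = []
    for out_idx, in_idx in permutations(range(n_assets), 2):
        sig = [0] * n_assets
        sig[out_idx] = 1    # asset leaving the pool
        sig[in_idx] = -1    # asset entering the pool
        variations.append(tuple(sig))
    return tuple(variations)
```

There are n_assets * (n_assets - 1) directions, so a 4-asset pool yields 12.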
- create_static_dict(run_fingerprint, bout_length, all_sig_variations=None, overrides=None)[source]
Create a static_dict from run_fingerprint for use in forward passes.
This simplifies the previous pattern of manually picking ~30 fields from run_fingerprint to create static_dict. Instead, we start with the full run_fingerprint and:
1. Exclude training-only fields
2. Apply necessary transformations (e.g., tokens -> tuple)
3. Add computed fields (bout_length, all_sig_variations)
4. Apply any overrides
- Parameters:
- Returns:
Hashable static dictionary for use in JAX forward passes
- Return type:
Example
>>> static_dict = create_static_dict(run_fingerprint, bout_length=10080)
>>> # Instead of manually building:
>>> # static_dict = {"chunk_period": rf["chunk_period"], "bout_length": ..., ...}
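The four steps above can be sketched as follows. The set of training-only keys is a hypothetical placeholder, because the real exclusion list is not documented here. The sketch returns a plain dict, whereas the library's result is hashable. What the sketch does show is the order of operations: exclude, transform, add computed fields, then apply overrides last so that they win.

```python
# Hypothetical exclusion list; the library's actual training-only keys may differ.
TRAINING_ONLY_KEYS = {"optimisation_settings"}

def sketch_static_dict(run_fingerprint, bout_length, all_sig_variations=None, overrides=None):
    # 1. Exclude training-only fields.
    static = {k: v for k, v in run_fingerprint.items() if k not in TRAINING_ONLY_KEYS}
    # 2. Transform: lists are unhashable, so tokens become a tuple.
    if "tokens" in static:
        static["tokens"] = tuple(static["tokens"])
    # 3. Add computed fields.
    static["bout_length"] = bout_length
    static["all_sig_variations"] = all_sig_variations
    # 4. Apply overrides last.
    if overrides:
        static.update(overrides)
    return static
```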
- get_run_location(run_fingerprint)[source]
Generate a unique run location identifier based on the run fingerprint.
This function creates a unique identifier for a simulation run by hashing the run_fingerprint dictionary. The run_fingerprint contains configuration parameters that define the simulation run.
- Parameters:
run_fingerprint (dict) – A dictionary containing the configuration parameters for the simulation run. This typically includes parameters like start/end dates, tokens, rules, etc.
- Returns:
A string identifier in the format “run_<sha256_hash>” where the hash is generated from the sorted JSON representation of the run_fingerprint.
- Return type:
Examples
>>> fingerprint = {"startDate": "2023-01-01", "tokens": ["BTC", "ETH"]}
>>> get_run_location(fingerprint)
'run_8d147a1f8b8...'
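Because the identifier is a SHA-256 hash of the sorted JSON representation, the idea can be reproduced with the standard library. The exact json.dumps options (separators, and the handling of non-JSON types such as arrays) are assumptions here, so this sketch will not necessarily produce the library's digest. It does show why dict insertion order does not change the run location.

```python
import hashlib
import json

def sketch_run_location(run_fingerprint):
    # sort_keys=True makes the serialisation independent of insertion order.
    payload = json.dumps(run_fingerprint, sort_keys=True)
    return "run_" + hashlib.sha256(payload.encode("utf-8")).hexdigest()
```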
- nan_rollback(grads, params, old_params)[source]
Handles NaN values in gradients by rolling back to previous parameter values.
This function checks for NaN values in gradients and reverts the corresponding parameters back to their previous values when NaNs are detected. This helps maintain numerical stability during optimization.
- Parameters:
- Returns:
Updated parameters with NaN values rolled back to previous values
- Return type:
Examples
>>> grads = {"log_k": jnp.array([[1.0, jnp.nan], [3.0, 4.0]])}
>>> params = {"log_k": jnp.array([[0.1, 0.2], [0.3, 0.4]])}
>>> old_params = {"log_k": jnp.array([[0.05, 0.15], [0.25, 0.35]])}
>>> rolled_back = nan_rollback(grads, params, old_params)
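The following NumPy sketch of the rollback assumes it acts element-wise: wherever a gradient entry is NaN, the matching parameter entry reverts to its old value. The library operates on JAX arrays and may instead roll back at a coarser granularity (for example, whole parameter sets), so the element-wise rule is an assumption. The sketch reuses the values from the example above.

```python
import numpy as np

def sketch_nan_rollback(grads, params, old_params):
    """Element-wise: keep params where grads are finite, else restore old_params."""
    rolled = {}
    for key, value in params.items():
        if key in grads:
            rolled[key] = np.where(np.isnan(grads[key]), old_params[key], value)
        else:
            rolled[key] = value  # no gradient for this key: leave untouched
    return rolled

grads = {"log_k": np.array([[1.0, np.nan], [3.0, 4.0]])}
params = {"log_k": np.array([[0.1, 0.2], [0.3, 0.4]])}
old_params = {"log_k": np.array([[0.05, 0.15], [0.25, 0.35]])}
rolled_back = sketch_nan_rollback(grads, params, old_params)
```

Under this rule, only the entry at [0, 1] reverts (0.2 becomes 0.15).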
- has_nan_grads(grad_tree)[source]
Check whether any leaf in a gradient pytree contains NaN values.
JIT-compiled for use inside training loops. Uses tree_reduce to scan all leaves without materializing intermediate structures.
- Parameters:
grad_tree (pytree) – JAX pytree of gradient arrays.
- Returns:
Scalar boolean: True if any gradient leaf contains a NaN.
- Return type:
jnp.ndarray
- has_nan_params(params)[source]
Check whether any learnable parameter arrays contain NaN values.
Skips non-learnable keys (initial_weights, initial_weights_logits, subsidary_params) that are not updated by the optimizer.
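Both checks can be sketched over nested dicts of NumPy arrays with a small recursive walk. The library uses JAX's tree_reduce under jit instead, so this plain-Python version is only an illustration. The excluded keys are the ones listed for has_nan_params.

```python
import numpy as np

NON_LEARNABLE = {"initial_weights", "initial_weights_logits", "subsidary_params"}

def _leaves(tree):
    """Yield array leaves from nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        for v in tree.values():
            yield from _leaves(v)
    elif isinstance(tree, (list, tuple)):
        for v in tree:
            yield from _leaves(v)
    else:
        yield np.asarray(tree)

def sketch_has_nan_grads(grad_tree):
    return any(np.isnan(leaf).any() for leaf in _leaves(grad_tree))

def sketch_has_nan_params(params):
    # Ignore keys the optimiser never updates.
    learnable = {k: v for k, v in params.items() if k not in NON_LEARNABLE}
    return sketch_has_nan_grads(learnable)
```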
- nan_param_reinit(params, grads, pool, initial_params, run_fingerprint, n_tokens, n_parameter_sets)[source]
Reinitialize parameter sets that contain NaN values.
During training, parameters can become NaN from bad update steps even when gradients were finite (e.g., large learning rate + steep curvature). This function detects NaN-contaminated parameter sets and replaces them with freshly initialized (noised) parameters via pool.init_parameters, preserving the remaining healthy sets.
- Parameters:
params (dict) – Current parameter dict with arrays of shape (n_parameter_sets, ...).
grads (dict) – Current gradient dict (unused directly, but passed for API consistency).
pool (BaseTFMMPool) – Pool instance, used to call init_parameters for replacement values.
initial_params (dict) – Initial values dict passed to pool.init_parameters.
run_fingerprint (dict) – Run configuration.
n_tokens (int) – Number of assets in the pool.
n_parameter_sets (int) – Number of parallel parameter sets.
- Returns:
Parameter dict with NaN-contaminated sets replaced by fresh initializations.
- Return type:
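The per-set replacement can be sketched with a boolean mask over the leading n_parameter_sets axis. Here, fresh_params stands in for whatever pool.init_parameters would return, with the same shapes. The sketch demonstrates only the masking, not the library's initialisation or noise. It assumes that every array has n_parameter_sets as its leading axis, and the key name memory is hypothetical.

```python
import numpy as np

def sketch_replace_nan_sets(params, fresh_params):
    """Replace every parameter set (leading axis) that has a NaN in any array."""
    n_sets = next(iter(params.values())).shape[0]
    bad = np.zeros(n_sets, dtype=bool)
    for value in params.values():
        # A set is bad if any of its entries, in any array, is NaN.
        bad |= np.isnan(value.reshape(n_sets, -1)).any(axis=1)
    out = {}
    for key, value in params.items():
        mask = bad.reshape((n_sets,) + (1,) * (value.ndim - 1))  # broadcast over trailing axes
        out[key] = np.where(mask, fresh_params[key], value)
    return out

params = {
    "log_k": np.array([[0.1, 0.2], [np.nan, 0.4], [0.5, 0.6]]),
    "memory": np.array([1.0, 2.0, 3.0]),
}
fresh = {"log_k": np.zeros((3, 2)), "memory": np.full(3, -1.0)}
repaired = sketch_replace_nan_sets(params, fresh)
```

Set 1 has a NaN only in log_k, but its memory entry is replaced too, because the whole set is reinitialised.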
- get_unique_tokens(run_fingerprint)[source]
Gets unique tokens from run fingerprint including subsidiary pools.
Extracts all tokens from the main pool and subsidiary pools in the run fingerprint, removes duplicates, and returns a sorted list of unique tokens.
- Parameters:
run_fingerprint (dict) – Dictionary containing run configuration including tokens and subsidiary pools
- Returns:
Sorted list of unique token symbols
- Return type:
Examples
>>> fingerprint = {
...     "tokens": ["BTC", "ETH"],
...     "subsidary_pools": [{"tokens": ["ETH", "DAI"]}]
... }
>>> get_unique_tokens(fingerprint)
['BTC', 'DAI', 'ETH']
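The behaviour in the example can be reproduced with a set union followed by sorting. The key names tokens and subsidary_pools come from the example. Treating a missing subsidary_pools entry as an empty list is an assumption of this sketch.

```python
def sketch_unique_tokens(run_fingerprint):
    tokens = set(run_fingerprint["tokens"])
    for pool in run_fingerprint.get("subsidary_pools", []):
        tokens.update(pool["tokens"])  # duplicates collapse in the set
    return sorted(tokens)
```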
- split_list(lst, num_splits)[source]
Splits a list into a specified number of roughly equal sublists.
Divides a list into num_splits sublists, distributing any remainder elements evenly among the first sublists.
- Parameters:
- Returns:
List of sublists
- Return type:
Examples
>>> split_list([1,2,3,4,5], 2)
[[1, 2, 3], [4, 5]]
>>> split_list([1,2,3,4,5,6], 3)
[[1, 2], [3, 4], [5, 6]]
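The remainder-first distribution shown in the examples follows from divmod. Each sublist gets len(lst) // num_splits elements, and the first len(lst) % num_splits sublists get one extra. This sketch matches the documented outputs but is not necessarily the library's code.

```python
def sketch_split_list(lst, num_splits):
    base, extra = divmod(len(lst), num_splits)
    out, start = [], 0
    for i in range(num_splits):
        size = base + (1 if i < extra else 0)  # first `extra` sublists get one more
        out.append(lst[start:start + size])
        start += size
    return out
```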
- invert_permutation(perm)[source]
Compute the inverse of a permutation.
Given a permutation array that maps indices to their new positions, returns the inverse permutation that maps the new positions back to their original indices.
- Parameters:
perm (numpy.ndarray) – Array representing a permutation of indices
- Returns:
The inverse permutation array
- Return type:
Examples
>>> perm = np.array([2,0,1])
>>> invert_permutation(perm)
array([1, 2, 0])
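For a NumPy permutation array, the inverse is the array inv with inv[perm[i]] = i, which is also what np.argsort(perm) returns. The following is a sketch, not necessarily the library's implementation:

```python
import numpy as np

def sketch_invert_permutation(perm):
    perm = np.asarray(perm)
    inv = np.empty_like(perm)
    inv[perm] = np.arange(perm.size)  # position perm[i] maps back to i
    return inv
```

Applying a permutation and then its inverse recovers the identity: perm[inv] equals np.arange(len(perm)).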
- permute_list_of_params(list_of_params, seed=0)[source]
Randomly permute a list of parameters using a fixed random seed.
This function takes a list of parameters and returns a new list with the same elements in a randomly permuted order. The permutation is deterministic based on the provided random seed.
- Parameters:
- Returns:
A new list containing the same elements as the input list but in a randomly permuted order
- Return type:
Examples
>>> params = [1, 2, 3, 4]
>>> permute_list_of_params(params, seed=42)
[3, 1, 4, 2]
>>> permute_list_of_params(params, seed=42)  # Same seed gives same permutation
[3, 1, 4, 2]
- unpermute_list_of_params(list_of_params)[source]
Restore the original order of a previously permuted list of parameters.
This function takes a list that was permuted using permute_list_of_params() and restores it to its original order by applying the inverse permutation. Because it takes no seed argument, it can only invert permutations produced with the default seed.
- Parameters:
list_of_params (list) – The permuted list of parameters to restore to original order
- Returns:
A new list containing the same elements as the input list but restored to their original order before permutation
- Return type:
Examples
>>> params = [1, 2, 3, 4]
>>> permuted = permute_list_of_params(params)  # default seed
>>> unpermute_list_of_params(permuted)  # Restores original order
[1, 2, 3, 4]
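The permute/unpermute pair can be sketched with a seeded NumPy generator and the inverse permutation. The choice of generator (np.random.default_rng) is an assumption, so the orderings produced here will not match the library's [3, 1, 4, 2] example. The point is the round trip. The sketch's unpermute takes an explicit seed, which must match the seed used to permute.

```python
import numpy as np

def sketch_permute(list_of_params, seed=0):
    perm = np.random.default_rng(seed).permutation(len(list_of_params))
    return [list_of_params[i] for i in perm]

def sketch_unpermute(list_of_params, seed=0):
    # Regenerate the same permutation, then apply its inverse.
    perm = np.random.default_rng(seed).permutation(len(list_of_params))
    inv = np.argsort(perm)
    return [list_of_params[i] for i in inv]
```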
- get_trades_and_fees(run_fingerprint, raw_trades, fees_df, gas_cost_df, arb_fees_df, lp_supply_df, do_test_period=False)[source]
Process trade and fee data for a simulation run.
Takes raw trades, fees, gas costs and arbitrage fees and converts them into arrays suitable for simulation. Handles both training and test periods if specified.
- Parameters:
run_fingerprint (dict) – Dictionary containing run configuration including start/end dates and tokens
raw_trades (pd.DataFrame, optional) – DataFrame containing raw trade data
fees_df (pd.DataFrame, optional) – DataFrame containing fee data
gas_cost_df (pd.DataFrame, optional) – DataFrame containing gas cost data
arb_fees_df (pd.DataFrame, optional) – DataFrame containing arbitrage fee data
lp_supply_df (pd.DataFrame, optional) – DataFrame containing LP supply data
do_test_period (bool, optional) – Whether to process data for a test period after training period (default False)
- Returns:
Contains processed arrays for trades, fees, gas costs and arb fees for both training and test periods as applicable
- Return type:
- create_daily_unix_array(start_date_str, end_date_str)[source]
Creates an array of daily Unix timestamps in milliseconds between two dates.
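The following sketch uses NumPy datetime64 arithmetic. It assumes date strings in "YYYY-MM-DD" form and an inclusive end date. Both assumptions should be checked against the library before relying on the exact length of the result.

```python
import numpy as np

def sketch_daily_unix_ms(start_date_str, end_date_str):
    start = np.datetime64(start_date_str, "D")
    end = np.datetime64(end_date_str, "D")
    one_day = np.timedelta64(1, "D")
    days = np.arange(start, end + one_day, one_day)          # inclusive of end (assumed)
    return days.astype("datetime64[ms]").astype(np.int64)   # ms since the Unix epoch
```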
- create_time_step(row, unix_values, tokens, index)[source]
Creates a SimulationResultTimestepDto object for a single time step.
- Parameters:
- Returns:
Object containing timestamp and coin data for this timestep
- Return type:
SimulationResultTimestepDto
- optimized_output_conversion(simulationRunDto, outputDict, tokens)[source]
Converts simulation output dictionary to a list of SimulationResultTimestepDto objects.
- Parameters:
- Returns:
List of SimulationResultTimestepDto objects containing timestep data
- Return type:
The function:
1. Creates Unix timestamps for each day between start and end dates
2. Downsamples simulation data from minutes to daily frequency
3. Calculates token weights from reserves, prices and total value
4. Combines data into timestep DTOs with coin holdings and values
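Steps 2 and 3 can be sketched with NumPy under the following assumptions:
- Minute-resolution arrays have shape (T, n_assets) and start at midnight.
- Daily downsampling takes every 1440th row.
- Each token's weight is its share of pool value, reserves * prices / total value.

The variable names are illustrative and are not the field names used in outputDict.

```python
import numpy as np

MINUTES_PER_DAY = 1440

def sketch_daily_weights(reserves, prices):
    """Downsample minute data to daily rows and compute value weights."""
    daily_reserves = reserves[::MINUTES_PER_DAY]     # (n_days, n_assets)
    daily_prices = prices[::MINUTES_PER_DAY]
    holdings_value = daily_reserves * daily_prices   # value held in each token
    total_value = holdings_value.sum(axis=1, keepdims=True)
    return holdings_value / total_value              # each row sums to 1
```

For example, reserves of [1, 100] at prices [300, 1] give holdings worth 300 and 100, so the weights are [0.75, 0.25].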
- probe_max_n_parameter_sets(run_fingerprint, min_sets=1, max_sets=64, safety_margin=0.9, verbose=True)[source]
Probe to find the maximum n_parameter_sets that fits in GPU memory.
Uses binary search to find the largest n_parameter_sets value that can complete a forward pass without OOM. Returns a dict with the recommended value and diagnostic info.
- Parameters:
run_fingerprint (dict) – The run fingerprint configuration. Will be modified temporarily during probing.
min_sets (int) – Minimum n_parameter_sets to try (default 1).
max_sets (int) – Maximum n_parameter_sets to try (default 64).
safety_margin (float) – Fraction of max found to use as recommendation (default 0.9). This provides headroom for gradient computation which uses more memory.
verbose (bool) – Whether to print progress information.
- Returns:
Keys: max_n_parameter_sets (int), recommended_n_parameter_sets (int, with safety margin applied), probed_values (list), success_values (list), failed_values (list).
- Return type:
Notes
This function temporarily modifies run_fingerprint during probing.
JAX caches are cleared between attempts.
The forward pass (without gradients) is used for probing, so gradient computation may require ~2x more memory. Hence the safety_margin.
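The binary search described above can be sketched with a caller-supplied fits(n) predicate. The predicate stands in for the check "a forward pass with n_parameter_sets=n completes without OOM". The sketch simplifies in two ways: it assumes feasibility is monotone in n, and it omits clearing JAX caches and restoring run_fingerprint between attempts, which the real function does. The returned keys mirror the documented ones.

```python
import math

def sketch_probe(fits, min_sets=1, max_sets=64, safety_margin=0.9):
    """Largest n in [min_sets, max_sets] with fits(n) True, assuming monotonicity."""
    lo, hi, best = min_sets, max_sets, None
    probed, ok, failed = [], [], []
    while lo <= hi:
        mid = (lo + hi) // 2
        probed.append(mid)
        if fits(mid):
            ok.append(mid)
            best = mid
            lo = mid + 1   # it fit: try larger
        else:
            failed.append(mid)
            hi = mid - 1   # OOM: try smaller
    recommended = None if best is None else max(min_sets, math.floor(best * safety_margin))
    return {
        "max_n_parameter_sets": best,
        "recommended_n_parameter_sets": recommended,
        "probed_values": probed,
        "success_values": ok,
        "failed_values": failed,
    }
```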
- allocate_memory_budget(run_fingerprint, available_memory_gb=None, priority='exploration', probe_if_needed=True, max_ensemble_members=1, verbose=True)[source]
Allocate memory budget across hyperparameters based on priority.
- Parameters:
run_fingerprint (dict) – The run fingerprint configuration.
available_memory_gb (float, optional) – Available GPU memory in GB. If None and probe_if_needed=True, will probe to determine capacity.
priority (str) – How to allocate the memory budget:
  - “exploration”: maximize n_parameter_sets (find diverse solutions)
  - “robustness”: balance n_parameter_sets and n_ensemble_members
  - “variance_reduction”: maximize batch_size (stable gradients)
probe_if_needed (bool) – Whether to probe memory if available_memory_gb is not provided.
max_ensemble_members (int) – Maximum ensemble members to allocate (default 1 = no ensembling). Set higher (e.g., 4) if you want the “robustness” priority to use ensembles.
verbose (bool) – Whether to print allocation info.
- Returns:
Recommended settings with keys: n_parameter_sets (int), n_ensemble_members (int), batch_size (int), priority_used (str), probe_result (dict or None).
- Return type:
- apply_memory_allocation(run_fingerprint, allocation)[source]
Apply memory allocation results to a run_fingerprint.
- auto_configure_memory_params(run_fingerprint, priority='exploration', max_ensemble_members=1, verbose=True)[source]
Convenience function: probe memory and apply allocation in one step.
- Parameters:
run_fingerprint (dict) – The run fingerprint to configure (will be modified in place).
priority (str) – Allocation priority (“exploration”, “robustness”, “variance_reduction”).
max_ensemble_members (int) – Maximum ensemble members to allocate (default 1 = no ensembling).
verbose (bool) – Whether to print progress info.
- Returns:
The modified run_fingerprint with optimal memory settings.
- Return type:
Example
>>> run = {...}  # your run_fingerprint
>>> auto_configure_memory_params(run, priority="exploration")
>>> train_on_historic_data(run)
- compute_selection_metric(train_metrics, val_metrics=None, continuous_test_metrics=None, method='best_val', metric='sharpe', min_threshold=0.0)[source]
Compute selection metric value for a single iteration/trial.
This is the shared core logic used by both BestParamsTracker (during training) and load_manually (post-training). Returns a value for comparison and the index of the best param set.
- Parameters:
train_metrics (list of dict) – Training metrics for each param set. Each dict has keys like “sharpe”, “returns_over_uniform_hodl”, etc.
val_metrics (list of dict, optional) – Validation metrics for each param set. Required if method="best_val".
continuous_test_metrics (list of dict, optional) – Continuous test metrics for each param set.
method (str) – Selection method. One of SELECTION_METHODS.
metric (str) – Which metric to use for comparison (e.g., “sharpe”, “returns_over_uniform_hodl”).
min_threshold (float) – Minimum threshold for “best_train_min_test” method.
- Returns:
(selection_value, best_param_idx) - value for comparison and index of best param set. Higher selection_value is always better.
- Return type:
- class BestParamsTracker(selection_method='best_val', metric='sharpe', min_threshold=0.0)[source]
Bases: object
Unified tracking of params across training iterations/trials.
Tracks both “last” (most recent iteration) and “best” (by selection method) params along with their associated metrics and continuous outputs.
Used by both SGD and Optuna paths to ensure consistent param selection logic.
- Parameters:
- last_*
State from the most recent update() call.
- Type:
Various
- best_*
State from when selection metric was highest.
- Type:
Various
- update(iteration, params, continuous_outputs, train_metrics_list, val_metrics_list=None, continuous_test_metrics_list=None)[source]
Update tracker with current iteration’s state.
- Parameters:
iteration (int) – Current iteration/trial number.
params (dict) – Current parameters (batched over param sets).
continuous_outputs (dict) – Output from continuous forward pass. Must have “reserves” and “weights” with shape (n_param_sets, time_steps, …).
train_metrics_list (list of dict) – Training metrics for each param set.
val_metrics_list (list of dict, optional) – Validation metrics for each param set.
continuous_test_metrics_list (list of dict, optional) – Continuous test metrics for each param set.
- Returns:
True if this iteration improved the best metric, False otherwise.
- Return type:
- select_param_set(params_dict, idx, n_param_sets)[source]
Extract single param set from batched params.
- get_results(n_param_sets, train_bout_length)[source]
Get comprehensive results with both last and best state.
- Parameters:
- Returns:
Comprehensive results including:
- last_* fields: state from the most recent iteration
- best_* fields: state from when the selection metric was best
- selection metadata
- Return type:
Run Fingerprint Defaults
Multi-Period SGD
Multi-Period SGD Training for Financial Strategies
This module implements multi-period robust training where we optimize parameters across multiple temporal windows simultaneously with a single forward pass and continuous pool state.
Key design:
- ONE forward pass spanning the entire data period
- Dynamically slice out evaluation windows for each “period”
- Aggregate losses across periods -> single backward pass
- Pool state continuity is automatic (one continuous simulation)
This is NOT walk-forward (no retraining per period), but rather finds ONE set of params that performs well across all temporal windows.
Benefits:
- Automatic pool state continuity through the continuous forward pass
- Single JIT compilation (no recompilation for different bout lengths)
- Efficient: one forward pass and one backward pass per update step
- Encourages robust parameters that work across market regimes
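The key design can be illustrated with a NumPy sketch. A single continuous value path is sliced into (rel_start, slice_len) windows, which is the format that create_multi_period_training_step takes. Each window is scored, and the scores are aggregated into one objective. Three parts of the sketch are assumptions:
- The per-window score is an unannualised mean/std ratio of log returns, standing in for the library's return_val metrics.
- The value path is synthetic.
- Gradients are not shown.

```python
import numpy as np

def per_period_scores(values, period_specs):
    """Slice one continuous value series into windows and score each one."""
    scores = []
    for rel_start, slice_len in period_specs:
        window = values[rel_start:rel_start + slice_len]
        log_returns = np.diff(np.log(window))
        # Unannualised Sharpe-like ratio (assumption, not the library's metric).
        scores.append(log_returns.mean() / (log_returns.std() + 1e-12))
    return np.array(scores)

# One synthetic continuous value path, evaluated on three back-to-back windows.
rng = np.random.default_rng(0)
values = 100.0 * np.exp(np.cumsum(rng.normal(0.0, 0.01, 3000)))
scores = per_period_scores(values, ((0, 1000), (1000, 1000), (2000, 1000)))
objective = scores.mean()  # "mean" aggregation; other rules change only this line
```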
- class PeriodSpec(period_id, rel_start, rel_end)[source]
Bases: object
Specification for a single evaluation period within the forward pass.
Defines a contiguous temporal slice of the forward pass output that constitutes one evaluation window. Multiple PeriodSpec instances partition (or overlap-partition) the full simulation into the windows used by multi-period SGD training.
- period_id
Zero-based ordinal index identifying this period within the sequence of evaluation windows.
- Type:
- rel_start
Start index of this period, relative to the first timestep of the forward pass output (not the raw price array).
- Type:
- rel_end
End index (exclusive) of this period, relative to the first timestep of the forward pass output.
- Type:
- class MultiPeriodResult(period_sharpes, period_returns, period_returns_over_hodl, mean_sharpe, std_sharpe, worst_sharpe, mean_returns_over_hodl, epochs_trained, final_objective, best_params=<factory>)[source]
Bases: object
Results from multi-period training.
Collects per-period evaluation metrics and their summary statistics after training a single parameter set across all temporal windows.
- Parameters:
- period_returns_over_hodl
Cumulative return relative to a uniform hold-all-assets baseline, per evaluation period.
- Type:
List[float]
- final_objective
Best aggregated objective value observed during training (the value that triggered best_params to be saved).
- Type:
- best_params
Strategy parameters corresponding to final_objective, stored as NumPy arrays. Empty dict if training produced no valid update.
- Type:
Dict[str, Any]
- __init__(period_sharpes, period_returns, period_returns_over_hodl, mean_sharpe, std_sharpe, worst_sharpe, mean_returns_over_hodl, epochs_trained, final_objective, best_params=<factory>)
- create_multi_period_training_step(base_forward_pass, prices, period_specs, n_assets, return_val, aggregation='mean', softmin_temperature=1.0)[source]
Create a training step function that computes aggregate loss across periods.
This returns a function with signature (params, start_index) -> scalar, compatible with the existing backpropagation factories.
- Parameters:
base_forward_pass (callable) – Partial forward_pass with full bout_length static_dict
prices (jnp.ndarray) – Full price array
period_specs (tuple of (rel_start, slice_len)) – For each period: relative start index and length within forward pass output. Must be tuple of tuples (static) so loop unrolls at trace time.
n_assets (int) – Number of assets
return_val (str) – Metric to compute per period (“sharpe”, “returns”, etc.)
aggregation (str) – How to combine period metrics:
  - "mean": simple average (all periods contribute equally)
  - "min": hard minimum (CAUTION: only the minimum element receives gradients)
  - "softmin": soft minimum via negative softmax (recommended for worst-case optimisation)
  - "sum": sum of all metrics
softmin_temperature (float) – Temperature for softmin aggregation. Lower values are closer to the hard minimum. The default of 1.0 gives moderate smoothing; use 0.1-0.5 to focus more sharply on the worst period.
- Returns:
Function (params, start_index) -> scalar
- Return type:
callable
Notes
IMPORTANT: Using aggregation=“min” has a gradient-flow problem.
With a hard min, gradients flow only through the single minimum element. As a result, only 1 of N periods contributes to each parameter update, gradients are sparse and noisy, and training can be unstable.
- Solution: Use “softmin” instead, which computes a soft minimum:
softmin(x) = sum(x * softmax(-x / temperature))
This gives more weight to lower-performing periods while still allowing gradients to flow from all periods. As temperature → 0, softmin → hard min.
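The aggregation options, including the softmin formula above, can be sketched in plain Python. softmin_weights, softmin and aggregate are hypothetical helpers written for illustration, not the library's JAX implementation. Because every softmin weight is strictly positive, every period still influences the result, unlike the hard min where only one element matters.

```python
import math
from typing import List, Sequence


def softmin_weights(values: Sequence[float], temperature: float = 1.0) -> List[float]:
    """softmax(-x / temperature), computed in a numerically stable way."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [-v / temperature for v in values]
    m = max(scaled)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def softmin(values: Sequence[float], temperature: float = 1.0) -> float:
    """sum(x * softmax(-x / temperature)): a smooth stand-in for min(x)."""
    weights = softmin_weights(values, temperature)
    return sum(v * w for v, w in zip(values, weights))


def aggregate(values: Sequence[float], aggregation: str = "mean",
              softmin_temperature: float = 1.0) -> float:
    """Combine per-period metrics using one of the documented modes."""
    if aggregation == "mean":
        return sum(values) / len(values)
    if aggregation == "sum":
        return sum(values)
    if aggregation == "min":
        # Under autodiff only the minimum element would receive a gradient.
        return min(values)
    if aggregation == "softmin":
        return softmin(values, softmin_temperature)
    raise ValueError(f"unknown aggregation: {aggregation!r}")
```

With values [0.5, 1.0, 2.0], a temperature of 0.01 gives a result equal to the hard minimum 0.5 up to rounding, while a very large temperature approaches the mean of about 1.1667.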
- generate_period_specs(n_periods, total_length, overlap_fraction=0.0)
Generate period specifications that partition the simulation into evaluation windows.
Divides total_length timesteps into n_periods contiguous windows. When overlap_fraction is zero the windows tile the interval exactly (the last period absorbs any remainder). When positive, each window extends into its successor by overlap_fraction of the base period length, producing correlated but longer evaluation windows useful for smoothing period-boundary effects.
- Parameters:
n_periods (int) – Number of evaluation windows to generate.
total_length (int) – Total number of timesteps available in the forward pass output.
overlap_fraction (float, optional) – Fraction of a base period length by which consecutive windows overlap. 0.0 (default) produces a non-overlapping partition; 0.5 means each window shares half its length with the next.
- Returns:
Ordered list of PeriodSpec instances covering (possibly overlapping) the full simulation length.
- Return type:
List[PeriodSpec]
Examples
>>> specs = generate_period_specs(n_periods=4, total_length=1000)
>>> [(s.rel_start, s.rel_end) for s in specs]
[(0, 250), (250, 500), (500, 750), (750, 1000)]
>>> specs = generate_period_specs(3, 900, overlap_fraction=0.5)
>>> [(s.rel_start, s.rel_end) for s in specs]
[(0, 300), (150, 450), (300, 600)]
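For the non-overlapping case (overlap_fraction=0.0), the partition described above can be sketched as follows. Window and partition_periods are hypothetical stand-ins for PeriodSpec and generate_period_specs, and floor division for the base length is an assumption consistent with the last period absorbing the remainder. to_static_specs is also hypothetical: it shows one way to turn rel_start/rel_end windows into the static (rel_start, slice_len) tuple of tuples that create_multi_period_training_step expects.

```python
from typing import List, NamedTuple, Tuple


class Window(NamedTuple):
    """Hypothetical stand-in for PeriodSpec (half-open [rel_start, rel_end))."""
    rel_start: int
    rel_end: int


def partition_periods(n_periods: int, total_length: int) -> List[Window]:
    """Non-overlapping partition only (the overlap_fraction == 0 case).

    Assumption: base length = total_length // n_periods, with the final
    window stretched to absorb any remainder.
    """
    if n_periods < 1 or total_length < n_periods:
        raise ValueError("need 1 <= n_periods <= total_length")
    base = total_length // n_periods
    windows = []
    for i in range(n_periods):
        start = i * base
        end = total_length if i == n_periods - 1 else start + base
        windows.append(Window(start, end))
    return windows


def to_static_specs(windows: List[Window]) -> Tuple[Tuple[int, int], ...]:
    """Convert windows to hashable (rel_start, slice_len) pairs."""
    return tuple((w.rel_start, w.rel_end - w.rel_start) for w in windows)
```

partition_periods(4, 1000) reproduces the first example above, while partition_periods(3, 1000) stretches the final window to (666, 1000).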
- multi_period_sgd_training(run_fingerprint, n_periods=4, overlap_fraction=0.0, max_epochs=500, aggregation='mean', softmin_temperature=1.0, verbose=True, root=None)
Run multi-period SGD training.
Trains ONE set of parameters that performs well across multiple temporal windows simultaneously.
- Parameters:
run_fingerprint (dict) – Run configuration
n_periods (int) – Number of evaluation periods
overlap_fraction (float) – Fraction of overlap between periods (0.0 = no overlap)
max_epochs (int) – Maximum training epochs
aggregation (str) – How to combine period metrics: “mean” (simple average; default, all periods weighted equally), “softmin” (soft minimum; recommended for worst-case optimisation), “min” (hard minimum; NOT recommended because of gradient-flow issues) or “sum” (sum of metrics).
softmin_temperature (float) – Temperature for softmin aggregation. Lower values move the result closer to the hard minimum. Default 1.0; use 0.1-0.5 for a sharper worst-case focus.
verbose (bool) – Print progress
root (str | None)
- Returns:
Training result and summary statistics
- Return type:
Tuple[MultiPeriodResult, dict]
Metric Extraction
Metric Extraction: Registry-based lookup for cycle evaluation metrics.
This module provides unified metric extraction from CycleEvaluation objects, replacing repetitive if/elif chains with a registry-based approach.
Usage:
from quantammsim.runners.metric_extraction import extract_cycle_metric
# Extract aggregated metrics from cycle evaluations
value = extract_cycle_metric(cycle_evals, "mean_oos_sharpe")
value = extract_cycle_metric(cycle_evals, "worst_wfe")
value = extract_cycle_metric(cycle_evals, "neg_is_oos_gap")
- CYCLE_METRICS: Dict[str, str] = {'adjusted_oos_sharpe': 'adjusted_oos_sharpe', 'is_calmar': 'is_calmar', 'is_daily_log_sharpe': 'is_daily_log_sharpe', 'is_oos_gap': 'is_oos_gap', 'is_returns': 'is_returns', 'is_returns_over_hodl': 'is_returns_over_hodl', 'is_sharpe': 'is_sharpe', 'is_sterling': 'is_sterling', 'is_ulcer': 'is_ulcer', 'oos_calmar': 'oos_calmar', 'oos_daily_log_sharpe': 'oos_daily_log_sharpe', 'oos_returns': 'oos_returns', 'oos_returns_over_hodl': 'oos_returns_over_hodl', 'oos_sharpe': 'oos_sharpe', 'oos_sterling': 'oos_sterling', 'oos_ulcer': 'oos_ulcer', 'wfe': 'walk_forward_efficiency'}
Registry mapping short metric names to CycleEvaluation attribute names.
Keys are the tokens recognised in metric spec strings (e.g. "mean_oos_sharpe" → aggregator "mean" + metric "oos_sharpe"). Values are the corresponding attribute on CycleEvaluation.
- AGGREGATORS: Dict[str, Callable[[List[float]], float]] = {'mean': <function _mean_agg>, 'worst': <function _worst_agg>}
Aggregation functions keyed by the prefix used in metric spec strings. E.g. "mean_oos_sharpe" dispatches to AGGREGATORS["mean"].
- extract_cycle_metric(cycle_evals, metric_spec)
Extract an aggregated metric from a list of CycleEvaluation objects.
Supports metric specifications such as “mean_oos_sharpe” (mean of oos_sharpe across cycles), “worst_wfe” (minimum walk_forward_efficiency), “neg_is_oos_gap” (negated mean of is_oos_gap, for minimisation) and “adjusted_mean_oos_sharpe” (mean of adjusted_oos_sharpe).
- Parameters:
cycle_evals (List[CycleEvaluation]) – List of cycle evaluation results
metric_spec (str) – Metric specification string
- Returns:
Aggregated metric value
- Return type:
float
Examples
>>> value = extract_cycle_metric(cycle_evals, "mean_oos_sharpe")
>>> value = extract_cycle_metric(cycle_evals, "worst_wfe")
>>> value = extract_cycle_metric(cycle_evals, "neg_is_oos_gap")
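The metric spec grammar listed above can be sketched as a small registry-driven parser. Cycle, METRICS, AGGS and extract are hypothetical stand-ins for CycleEvaluation, CYCLE_METRICS, AGGREGATORS and extract_cycle_metric. The parsing rules (a neg_ prefix negates, an adjusted_ prefix is moved onto the metric name, and a missing aggregator defaults to the mean) are inferred from the documented examples rather than taken from the library source.

```python
from dataclasses import dataclass
from statistics import fmean
from typing import Callable, Dict, List


@dataclass
class Cycle:
    """Hypothetical minimal stand-in for CycleEvaluation."""
    oos_sharpe: float
    adjusted_oos_sharpe: float
    walk_forward_efficiency: float
    is_oos_gap: float


# Subset of the documented CYCLE_METRICS registry: short name -> attribute.
METRICS: Dict[str, str] = {
    "oos_sharpe": "oos_sharpe",
    "adjusted_oos_sharpe": "adjusted_oos_sharpe",
    "wfe": "walk_forward_efficiency",
    "is_oos_gap": "is_oos_gap",
}
# Aggregators keyed by spec prefix, mirroring AGGREGATORS.
AGGS: Dict[str, Callable[[List[float]], float]] = {"mean": fmean, "worst": min}


def extract(cycles: List[Cycle], spec: str) -> float:
    sign = 1.0
    if spec.startswith("neg_"):  # negate, e.g. to minimise a gap
        sign, spec = -1.0, spec[len("neg_"):]
    prefix = ""
    if spec.startswith("adjusted_"):  # "adjusted_mean_x" -> mean of "adjusted_x"
        prefix, spec = "adjusted_", spec[len("adjusted_"):]
    agg_name, _, rest = spec.partition("_")
    if agg_name not in AGGS:  # no aggregator prefix: default to the mean
        agg_name, rest = "mean", spec
    attr = METRICS[prefix + rest]
    values = [getattr(c, attr) for c in cycles]
    return sign * AGGS[agg_name](values)
```

With two cycles whose oos_sharpe values are 1.0 and 2.0, extract(cycles, "mean_oos_sharpe") returns 1.5.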
- get_metric_from_result(result, metric_name)
Extract a metric from an EvaluationResult object.
- Parameters:
result (EvaluationResult) – The evaluation result object
metric_name (str) – Name of the metric to extract
- Returns:
The metric value
- Return type: