Runners

Core Training Runners

Core training and simulation runners for quantammsim.

This module provides the two primary entry points for using quantammsim:

train_on_historic_data()

Optimise strategy parameters on historical price data using either gradient descent (Adam/AdamW/SGD via Optax) or gradient-free search (Optuna). Supports ensemble training, early stopping with validation holdout, warm-starting from previous walk-forward cycles, checkpointing for Rademacher complexity analysis, and Stochastic Weight Averaging.

do_run_on_historic_data()

Execute a single forward pass (simulation) with fixed parameters and return the full results dict. Used for post-training evaluation, walk-forward OOS testing, and visualisation. Supports injecting real trade data, time-varying fees/gas costs, and LP supply changes.

Both functions accept a run_fingerprint dict as their primary configuration. See Run Fingerprints for the complete reference of available settings.

train_on_historic_data(run_fingerprint, root=None, iterations_per_print=1, force_init=False, price_data=None, verbose=True, run_location=None, return_training_metadata=False, warm_start_params=None, warm_start_weights=None)[source]

Optimise strategy parameters on historical price data.

This is the primary training entry point for quantammsim. It loads (or accepts) price data, constructs the JAX computation graph, and runs either gradient-based (Adam/AdamW/SGD) or gradient-free (Optuna) optimisation according to run_fingerprint["optimisation_settings"]["method"].

Parameters:
  • run_fingerprint (dict) –

    Master configuration dict. Key fields consumed here:

    • tokens, startDateString, endDateString, endTestDateString — data selection

    • rule — pool/strategy type (e.g. "momentum", "mean_reversion_channel")

    • return_val — objective metric (default "daily_log_sharpe")

    • optimisation_settings.method – "gradient_descent" or "optuna"

    • optimisation_settings.optimiser – "adam", "adamw", or "sgd"

    • optimisation_settings.n_iterations — training epochs

    • optimisation_settings.val_fraction — fraction of training window held out for early-stopping validation (0 = disabled)

    • optimisation_settings.use_swa — enable Stochastic Weight Averaging

    • optimisation_settings.track_checkpoints — save periodic parameter snapshots for Rademacher complexity analysis

    See Run Fingerprints for the full reference.

  • root (str, optional) – Root directory for data files and saved results.

  • iterations_per_print (int, optional) – Print training progress every N iterations (default 1).

  • force_init (bool, optional) – If True, ignore cached results and re-initialise parameters.

  • price_data (array-like or DataFrame, optional) – Pre-loaded price data. When None, data is loaded from parquet files based on run_fingerprint date/token settings.

  • verbose (bool, optional) – Print detailed progress information (default True).

  • run_location (str, optional) – Path to a previously-saved run to resume from. When None, a new run is initialised (or auto-detected from the fingerprint hash).

  • return_training_metadata (bool, optional) – If True, return (params, metadata) where metadata contains epochs_trained, final_objective, and checkpoint_returns (a (n_checkpoints, T-1) array for Rademacher complexity, or None if checkpointing was disabled).

  • warm_start_params (dict, optional) – Strategy parameters from a previous walk-forward cycle. Each value is expanded to (n_parameter_sets, ...) shape with added Gaussian noise (scale controlled by optimisation_settings.noise_scale).

  • warm_start_weights (array-like, optional) – Final portfolio weights from a previous cycle, shape (n_assets,). The pool starts with a fresh initial_pool_value distributed according to these weights.

Returns:

  • Gradient descent, return_training_metadata=False: best params dict.

  • Gradient descent, return_training_metadata=True: (params, metadata) tuple.

  • Optuna: list of best trials, or None if none completed.

Return type:

dict or tuple or list or None
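A minimal run_fingerprint covering the fields listed above might look like the sketch below. The keys come from the parameter list; the concrete values (tokens, dates, the date-string format, the iteration count) are illustrative assumptions, and any field not shown is expected to fall back to the defaults described in Run Fingerprints.

```python
# Illustrative run_fingerprint; values are assumptions, not recommended settings.
run_fingerprint = {
    "tokens": ["BTC", "ETH"],
    "rule": "momentum",
    "startDateString": "2023-01-01 00:00:00",  # assumed date-string format
    "endDateString": "2023-06-01 00:00:00",    # end of training window
    "endTestDateString": "2023-09-01 00:00:00",  # end of OOS test window
    "return_val": "daily_log_sharpe",
    "optimisation_settings": {
        "method": "gradient_descent",
        "optimiser": "adam",
        "n_iterations": 500,
        "val_fraction": 0.2,       # hold out 20% of the training window
        "use_swa": False,
        "track_checkpoints": False,
    },
}
```

With the gradient-descent method, calling train_on_historic_data(run_fingerprint, return_training_metadata=True) would then return a (params, metadata) tuple as described under Returns.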

do_run_on_historic_data(run_fingerprint, params={}, root=None, price_data=None, verbose=False, raw_trades=None, fees=None, gas_cost=None, arb_fees=None, fees_df=None, gas_cost_df=None, arb_fees_df=None, lp_supply_df=None, do_test_period=False, low_data_mode=False, preslice_burnin=True)[source]

Execute a forward-pass simulation with fixed parameters.

Runs the full simulation pipeline — price loading, weight calculation, arbitrage, and metric computation — using pre-trained (or manually specified) strategy parameters. This is the primary entry point for post-training evaluation, walk-forward OOS testing, and visualisation.

Parameters:
  • run_fingerprint (dict) – Master configuration dict (same structure as train_on_historic_data()).

  • params (dict or list of dict) – Strategy parameters. A single dict runs one simulation; a list of dicts runs multiple parameter sets in parallel via vmap.

  • root (str, optional) – Root directory for data files.

  • price_data (array-like or DataFrame, optional) – Pre-loaded price data. When None, loaded from parquet files.

  • verbose (bool, optional) – Print progress information (default False).

  • raw_trades (DataFrame, optional) – Real trade data to inject. Columns: unix timestamp (minute), token_in, token_out, amount_in.

  • fees (float, optional) – Swap fee override (e.g. 0.003 for 30 bps).

  • gas_cost (float, optional) – Gas cost override per transaction.

  • arb_fees (float, optional) – Arbitrageur fee override.

  • fees_df (DataFrame, optional) – Time-varying swap fees (columns: unix, fee).

  • gas_cost_df (DataFrame, optional) – Time-varying gas costs (columns: unix, gas_cost).

  • arb_fees_df (DataFrame, optional) – Time-varying arb fees (columns: unix, arb_fee).

  • lp_supply_df (DataFrame, optional) – Time-varying LP supply changes.

  • do_test_period (bool, optional) – If True, also run the OOS test period spanning endDateString to endTestDateString (default False).

  • low_data_mode (bool, optional) – If True, drop raw price arrays from the output dict to reduce memory usage (default False).

  • preslice_burnin (bool, optional) – If True, pre-slice data to max_memory_days of burn-in plus the simulation period (default True). Set False to load all available history.

Returns:

When do_test_period=False: a single results dict with keys including values, reserves, weights, coarse_weights, objective, and per-asset breakdowns.

When do_test_period=True: (train_results, test_results).

For multiple parameter sets, each value in the dict is a list (one entry per parameter set).

Return type:

dict or tuple[dict, dict]
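The time-varying cost inputs are plain DataFrames with the column names listed above. The sketch below builds them with pandas; it assumes the unix column is in milliseconds (matching create_daily_unix_array()), and the gas cost of 1.0 is a placeholder whose units this page does not specify. How the simulator treats times between rows is not documented here.

```python
import pandas as pd

# Assumption: "unix" is in milliseconds; confirm the units your data uses.
DAY_MS = 24 * 60 * 60 * 1000
START_MS = 1672531200000  # 2023-01-01 00:00:00 UTC

# Swap fee steps from 30 bps down to 10 bps two days in.
fees_df = pd.DataFrame({
    "unix": [START_MS, START_MS + 2 * DAY_MS],
    "fee": [0.003, 0.001],
})

# Single-row frames for gas and arb fees (placeholder values).
gas_cost_df = pd.DataFrame({"unix": [START_MS], "gas_cost": [1.0]})
arb_fees_df = pd.DataFrame({"unix": [START_MS], "arb_fee": [0.0]})
```

These would be passed as do_run_on_historic_data(run_fingerprint, params, fees_df=fees_df, gas_cost_df=gas_cost_df, arb_fees_df=arb_fees_df).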

do_run_on_historic_data_with_provided_coarse_weights(run_fingerprint, coarse_weights, params={}, root=None, price_data=None, verbose=False, raw_trades=None, fees=None, gas_cost=None, arb_fees=None, fees_df=None, gas_cost_df=None, arb_fees_df=None, lp_supply_df=None, do_test_period=False, low_data_mode=False)[source]

Execute a simulation using pre-computed coarse weights.

Like do_run_on_historic_data(), but bypasses the weight-calculation step entirely. The caller provides coarse_weights directly, and this function performs only fine-weight interpolation, arbitrage simulation, and metric computation.

This is useful for replaying a trained strategy with externally-computed or manually-specified weight trajectories, or for separating the weight computation from the simulation for profiling or debugging.

Parameters:
  • run_fingerprint (dict) – Master configuration dict.

  • coarse_weights (jnp.ndarray) – Pre-computed coarse weights, shape (n_coarse_steps, n_assets).

  • params (dict or list of dict, optional) – Strategy parameters (used only for initial_reserves and any subsidiary parameters, not for weight computation).

  • root (str, optional) – Root directory for data files.

  • price_data (array-like or DataFrame, optional) – Pre-loaded price data.

  • verbose (bool, optional) – Print progress (default False).

  • raw_trades (DataFrame, optional) – Real trade data to inject.

  • fees (float, optional) – Swap fee override.

  • gas_cost (float, optional) – Gas cost override.

  • arb_fees (float, optional) – Arbitrageur fee override.

  • fees_df (DataFrame, optional) – Time-varying swap fees.

  • gas_cost_df (DataFrame, optional) – Time-varying gas costs.

  • arb_fees_df (DataFrame, optional) – Time-varying arb fees.

  • lp_supply_df (DataFrame, optional) – Time-varying LP supply changes.

  • do_test_period (bool, optional) – Run OOS test period (default False).

  • low_data_mode (bool, optional) – Drop raw arrays from output to save memory (default False).

Returns:

Same structure as do_run_on_historic_data().

Return type:

dict or tuple[dict, dict]
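One way to produce a manually specified weight trajectory is to interpolate between two weight vectors. The helper below, linear_weight_path, is hypothetical; it assumes each row must be non-negative and sum to 1, and that n_coarse_steps must match the number of coarse steps the simulation expects (which this page does not specify). The result can be wrapped with jnp.asarray before being passed as coarse_weights.

```python
import numpy as np

def linear_weight_path(start, end, n_coarse_steps):
    """Linearly interpolate from `start` to `end` weights over n_coarse_steps rows.

    Returns shape (n_coarse_steps, n_assets). Each row is a convex combination
    of two weight vectors that sum to 1, so every row also sums to 1.
    """
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)
    t = np.linspace(0.0, 1.0, n_coarse_steps)[:, None]  # shape (n_coarse_steps, 1)
    return (1.0 - t) * start + t * end

# Move a two-asset pool from 50/50 to 80/20 over five coarse steps.
coarse_weights = linear_weight_path([0.5, 0.5], [0.8, 0.2], n_coarse_steps=5)
```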

Runner Utilities

create_trial_params(trial, param_config, params, run_fingerprint, n_assets, expand_around=False)[source]

Create trial parameters for Optuna optimization.

Parameters:
  • trial (optuna.Trial) – The Optuna trial object.

  • param_config (dict) – Configuration for parameter optimization. Each parameter can have:

      • low (float) – lower bound

      • high (float) – upper bound

      • log_scale (bool) – whether to use log scale

      • scalar (bool) – whether to use the same value for all assets

  • params (dict) – Current parameter values, used for shape information.

  • run_fingerprint (dict) – Run configuration.

  • n_assets (int) – Number of assets.

Returns:

Trial parameters dictionary.

Return type:

Dict

Raises:

ValueError – If parameter shapes are invalid or required config is missing.

Return type:

Dict
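The param_config keys map naturally onto Optuna's trial.suggest_float(name, low, high, log=...). The sketch below is a hypothetical illustration of that mapping only: the helper name, the per-asset naming convention (f"{name}_{i}"), and the parameter name memory_days are assumptions, and the real function additionally uses params for shapes and supports expand_around. A small stand-in trial keeps the example runnable without Optuna.

```python
import numpy as np

def suggest_from_config(trial, param_config, n_assets):
    """Draw one shared value (scalar=True) or one value per asset from `trial`.

    `trial` only needs suggest_float(name, low, high, log=...), which
    optuna.Trial provides.
    """
    out = {}
    for name, cfg in param_config.items():
        log = cfg.get("log_scale", False)
        if cfg.get("scalar", False):
            value = trial.suggest_float(name, cfg["low"], cfg["high"], log=log)
            out[name] = np.full(n_assets, value)  # same value for every asset
        else:
            out[name] = np.array([
                trial.suggest_float(f"{name}_{i}", cfg["low"], cfg["high"], log=log)
                for i in range(n_assets)
            ])
    return out

class MidpointTrial:
    """Stand-in for optuna.Trial that returns the midpoint of each range."""

    def suggest_float(self, name, low, high, log=False):
        if log:
            return float(np.sqrt(low * high))  # geometric midpoint on a log scale
        return (low + high) / 2.0

param_config = {
    "log_k": {"low": -2.0, "high": 4.0, "log_scale": False, "scalar": False},
    "memory_days": {"low": 1.0, "high": 100.0, "log_scale": True, "scalar": True},
}
sample = suggest_from_config(MidpointTrial(), param_config, n_assets=3)
```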

generate_evaluation_points(start_idx, end_idx, bout_length, n_points, min_spacing, random_key=0)[source]

Generate evaluation start points for optuna-style hyperparameter search.

If the training period is exactly equal to bout_length (no room for multiple windows), returns just the start_idx as a single evaluation point.

Parameters:
  • start_idx (int) – Start index of the training period

  • end_idx (int) – End index of the training period

  • bout_length (int) – Length of each evaluation window

  • n_points (int) – Desired number of evaluation points

  • min_spacing (int) – Minimum spacing between evaluation points (currently unused)

  • random_key (int) – Random seed for reproducibility

Returns:

List of evaluation start indices

Return type:

list
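The single-window rule above can be combined with random sampling of window starts. The sketch below is hypothetical: it assumes uniform sampling without replacement, treats end_idx as exclusive so every window [s, s + bout_length) fits inside the period, and omits min_spacing since it is documented as unused. The library's actual sampling scheme may differ.

```python
import numpy as np

def sketch_evaluation_points(start_idx, end_idx, bout_length, n_points, random_key=0):
    """Return up to n_points sorted window starts that keep each window in range."""
    last_start = end_idx - bout_length
    if last_start <= start_idx:
        # Period no longer than one window: evaluate once, at start_idx.
        return [start_idx]
    rng = np.random.default_rng(random_key)
    candidates = np.arange(start_idx, last_start + 1)
    n = min(n_points, len(candidates))
    starts = rng.choice(candidates, size=n, replace=False)
    return sorted(int(s) for s in starts)
```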

find_best_balanced_solution(values_array, n_objectives=None)[source]

Find the solution closest to the ideal point after normalizing objectives.

Parameters:
  • values_array – Either a numpy array of shape (n_trials, n_objectives) or a list of optuna trials with values attribute

  • n_objectives – Number of objectives. Only needed if using list of trials.

Returns:

Index of the best balanced solution

Return type:

int
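A common reading of "closest to the ideal point after normalizing" is min-max scaling each objective to [0, 1] and taking the Euclidean distance to the all-ones point. The sketch below assumes exactly that, plus that every objective is maximised; the library's normalisation or distance may differ.

```python
import numpy as np

def balanced_index(values):
    """Index of the row closest (Euclidean) to the ideal point after min-max scaling.

    Assumes every objective is maximised, so the ideal point is all ones.
    """
    values = np.asarray(values, dtype=float)  # shape (n_trials, n_objectives)
    lo, hi = values.min(axis=0), values.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)  # avoid dividing by zero for constant objectives
    normalised = (values - lo) / span
    distances = np.linalg.norm(normalised - 1.0, axis=1)
    return int(np.argmin(distances))
```

Two extreme solutions that each win one objective lose to a compromise that is good at both.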

get_best_balanced_solution(study)[source]
class OptunaManager(run_fingerprint)[source]

Bases: object

Manages an Optuna hyperparameter optimization study lifecycle.

Encapsulates study creation, execution, early stopping, and result persistence. Configuration is drawn from run_fingerprint["optimisation_settings"]["optuna_settings"].

Parameters:

run_fingerprint (dict) – Run configuration. Must contain optimisation_settings.optuna_settings.

study

The Optuna study, created by setup_study().

Type:

optuna.Study or None

logger

File-backed logger writing to output_dir/optimization.log.

Type:

logging.Logger

__init__(run_fingerprint)[source]
setup_study(multi_objective=False)[source]

Create and configure the Optuna study.

Initialises an Optuna study with TPE sampler (multivariate), median pruner, and optional RDB storage. For multi-objective mode, creates a three-direction maximize study (mean return, worst-case, stability).

Parameters:

multi_objective (bool, optional) – If True, creates a multi-objective study with three maximize directions. Default is False (single-objective maximize).

early_stopping_callback(study, trial)[source]

Callback that implements early stopping using both training and validation metrics.

save_results()[source]

Save study results, including validation metrics.

optimize(objective)[source]

Run the optimization process with error handling and parallel execution.

Delegates to study.optimize with the configured number of trials, timeout, parallel jobs, and early-stopping callback. All exceptions are caught (logged, not re-raised) so that partial results are always saved via save_results.

Parameters:

objective (callable) – Optuna objective function: trial -> float (single-objective) or trial -> tuple[float, ...] (multi-objective).

class Hashabledict[source]

Bases: dict

A hashable dictionary class that enables using dictionaries as dictionary keys.

This class extends the built-in dict class to make dictionaries hashable by implementing the __hash__ and __eq__ methods. The hash is computed based on a sorted tuple of key-value pairs.

__key()

Returns a tuple of sorted key-value pairs representing the dictionary.

__hash__()[source]

Returns an integer hash value for the dictionary.

__eq__(other)[source]

Checks equality between this dictionary and another by comparing their sorted key-value pairs.

Examples

>>> d1 = Hashabledict({'a': 1, 'b': 2})
>>> d2 = Hashabledict({'b': 2, 'a': 1})
>>> hash(d1) == hash(d2)
True
>>> d1 == d2
True
>>> d3 = {d1: 'value'}  # Can use as dictionary key
class NestedHashabledict(*args, **kwargs)[source]

Bases: dict

A hashable dictionary class that enables using dictionaries as dictionary keys. Handles deeply nested dictionaries by recursively converting all nested dicts.

__init__(*args, **kwargs)[source]
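The recursive conversion described above can be sketched as follows. NestedHashableSketch is a hypothetical stand-in, not the library class: it converts nested dicts only at construction time, requires non-dict values to already be hashable (tuples rather than lists) and keys to be mutually comparable for sorting, and, like any hashable mutable mapping, should not be mutated after being used as a key.

```python
class NestedHashableSketch(dict):
    """Dict that converts nested dicts recursively and hashes by sorted items."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Iterate over a snapshot so replacing values is safe.
        for key, value in list(self.items()):
            if isinstance(value, dict) and not isinstance(value, NestedHashableSketch):
                self[key] = NestedHashableSketch(value)

    def __hash__(self):
        # Nested dicts are now hashable too, so the items tuple can be hashed.
        return hash(tuple(sorted(self.items())))
```

Equality is inherited from dict, so two instances with the same contents in a different insertion order compare equal and hash identically.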
get_sig_variations(n_assets)[source]

Compute signature variations for arbitrage.

Returns all possible (asset_in, asset_out) pairs encoded as a tuple of tuples, where each inner tuple has exactly one +1 (asset out) and one -1 (asset in), with zeros elsewhere.

Parameters:

n_assets (int) – Number of assets in the pool.

Returns:

Tuple of tuples representing valid arbitrage directions. Each inner tuple has shape (n_assets,) with values in {-1, 0, 1}.

Return type:

tuple

Example

>>> get_sig_variations(3)
((1, -1, 0), (1, 0, -1), (-1, 1, 0), (0, 1, -1), (-1, 0, 1), (0, -1, 1))
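Each direction is an ordered (asset_out, asset_in) pair, so there are n_assets * (n_assets - 1) of them. The sketch below builds them with itertools.permutations; it happens to reproduce the documented ordering for three assets, but it is an illustration and the library may construct the tuple differently.

```python
from itertools import permutations

def sig_variations_sketch(n_assets):
    """All (asset_out, asset_in) directions as vectors with one +1 and one -1."""
    variations = []
    for out_idx, in_idx in permutations(range(n_assets), 2):
        sig = [0] * n_assets
        sig[out_idx] = 1   # asset out
        sig[in_idx] = -1   # asset in
        variations.append(tuple(sig))
    return tuple(variations)
```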
create_static_dict(run_fingerprint, bout_length, all_sig_variations=None, overrides=None)[source]

Create a static_dict from run_fingerprint for use in forward passes.

This simplifies the previous pattern of manually picking ~30 fields from run_fingerprint to create static_dict. Instead, we start with the full run_fingerprint and:

  1. Exclude training-only fields
  2. Apply necessary transformations (e.g., tokens -> tuple)
  3. Add computed fields (bout_length, all_sig_variations)
  4. Apply any overrides

Parameters:
  • run_fingerprint (dict) – The full run configuration dictionary

  • bout_length (int) – Bout length to use (varies between train/test)

  • all_sig_variations (list, optional) – Pre-computed signature variations for arbitrage

  • overrides (dict, optional) – Additional key-value pairs to override/add

Returns:

Hashable static dictionary for use in JAX forward passes

Return type:

NestedHashabledict

Example

>>> static_dict = create_static_dict(run_fingerprint, bout_length=10080)
>>> # Instead of manually building:
>>> # static_dict = {"chunk_period": rf["chunk_period"], "bout_length": ..., ...}
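The four steps can be sketched with a plain dict. This is a hypothetical illustration: TRAINING_ONLY_KEYS is an assumed exclusion list (the real set of training-only fields is not given here), and the real function returns a NestedHashabledict rather than a plain dict.

```python
TRAINING_ONLY_KEYS = {"optimisation_settings"}  # hypothetical exclusion list

def static_dict_sketch(run_fingerprint, bout_length, all_sig_variations=None, overrides=None):
    """Plain-dict sketch of the create_static_dict steps; does not mutate the input."""
    # 1. Exclude training-only fields.
    static = {k: v for k, v in run_fingerprint.items() if k not in TRAINING_ONLY_KEYS}
    # 2. Transform fields that must be hashable (tokens -> tuple).
    if "tokens" in static:
        static["tokens"] = tuple(static["tokens"])
    # 3. Add computed fields.
    static["bout_length"] = bout_length
    if all_sig_variations is not None:
        static["all_sig_variations"] = all_sig_variations
    # 4. Apply overrides last so they take precedence.
    if overrides:
        static.update(overrides)
    return static
```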
class HashableArrayWrapper(val)[source]

Bases: Generic[T]

Parameters:

val (T)

__init__(val)[source]
Parameters:

val (T)

get_run_location(run_fingerprint)[source]

Generate a unique run location identifier based on the run fingerprint.

This function creates a unique identifier for a simulation run by hashing the run_fingerprint dictionary. The run_fingerprint contains configuration parameters that define the simulation run.

Parameters:

run_fingerprint (dict) – A dictionary containing the configuration parameters for the simulation run. This typically includes parameters like start/end dates, tokens, rules, etc.

Returns:

A string identifier in the format “run_<sha256_hash>” where the hash is generated from the sorted JSON representation of the run_fingerprint.

Return type:

str

Examples

>>> fingerprint = {"startDate": "2023-01-01", "tokens": ["BTC", "ETH"]}
>>> get_run_location(fingerprint)
'run_8d147a1f8b8...'
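The hashing scheme described above can be reproduced with the standard library. The JSON options below (default separators, no handler for non-JSON types) are assumptions, so digests from this sketch are not guaranteed to match the library's run locations; the point is that sorting keys makes the identifier independent of dict insertion order.

```python
import hashlib
import json

def run_location_sketch(run_fingerprint):
    """'run_' followed by the SHA-256 hex digest of the key-sorted JSON encoding."""
    encoded = json.dumps(run_fingerprint, sort_keys=True).encode("utf-8")
    return "run_" + hashlib.sha256(encoded).hexdigest()
```

Note that list order still matters: ["BTC", "ETH"] and ["ETH", "BTC"] give different identifiers.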
nan_rollback(grads, params, old_params)[source]

Handles NaN values in gradients by rolling back to previous parameter values.

This function checks for NaN values in gradients and reverts the corresponding parameters back to their previous values when NaNs are detected. This helps maintain numerical stability during optimization.

Parameters:
  • grads (dict) – Dictionary containing the current gradients

  • params (dict) – Dictionary containing the current parameter values

  • old_params (dict) – Dictionary containing the previous parameter values

Returns:

Updated parameters with NaN values rolled back to previous values

Return type:

dict

Examples

>>> grads = {"log_k": jnp.array([[1.0, jnp.nan], [3.0, 4.0]])}
>>> params = {"log_k": jnp.array([[0.1, 0.2], [0.3, 0.4]])}
>>> old_params = {"log_k": jnp.array([[0.05, 0.15], [0.25, 0.35]])}
>>> rolled_back = nan_rollback(grads, params, old_params)
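The rollback can be sketched element-wise with a masked select. This assumes element-level granularity (the library may instead roll back whole parameter sets) and uses NumPy for portability; jnp.isnan and jnp.where behave the same way, and jax.tree_util.tree_map would generalise the dict comprehension to arbitrary pytrees.

```python
import numpy as np

def nan_rollback_sketch(grads, params, old_params):
    """Where a gradient entry is NaN, keep the previous parameter value."""
    return {
        key: np.where(np.isnan(grads[key]), old_params[key], params[key])
        for key in params
    }

grads = {"log_k": np.array([[1.0, np.nan], [3.0, 4.0]])}
params = {"log_k": np.array([[0.1, 0.2], [0.3, 0.4]])}
old_params = {"log_k": np.array([[0.05, 0.15], [0.25, 0.35]])}
rolled_back = nan_rollback_sketch(grads, params, old_params)
```

With the example inputs above, only the entry whose gradient is NaN reverts (0.2 becomes 0.15).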
has_nan_grads(grad_tree)[source]

Check whether any leaf in a gradient pytree contains NaN values.

JIT-compiled for use inside training loops. Uses tree_reduce to scan all leaves without materializing intermediate structures.

Parameters:

grad_tree (pytree) – JAX pytree of gradient arrays.

Returns:

Scalar boolean: True if any gradient leaf contains a NaN.

Return type:

jnp.ndarray

has_nan_params(params)[source]

Check whether any learnable parameter arrays contain NaN values.

Skips non-learnable keys (initial_weights, initial_weights_logits, subsidary_params) that are not updated by the optimizer.

Parameters:

params (dict) – Parameter dict with arrays of shape (n_parameter_sets, n_assets).

Returns:

True if any learnable parameter contains a NaN.

Return type:

bool

nan_param_reinit(params, grads, pool, initial_params, run_fingerprint, n_tokens, n_parameter_sets)[source]

Reinitialize parameter sets that contain NaN values.

During training, parameters can become NaN from bad update steps even when gradients were finite (e.g., large learning rate + steep curvature). This function detects NaN-contaminated parameter sets and replaces them with freshly initialized (noised) parameters via pool.init_parameters, preserving the remaining healthy sets.

Parameters:
  • params (dict) – Current parameter dict with arrays of shape (n_parameter_sets, ...).

  • grads (dict) – Current gradient dict (unused directly, but passed for API consistency).

  • pool (BaseTFMMPool) – Pool instance, used to call init_parameters for replacement values.

  • initial_params (dict) – Initial values dict passed to pool.init_parameters.

  • run_fingerprint (dict) – Run configuration.

  • n_tokens (int) – Number of assets in the pool.

  • n_parameter_sets (int) – Number of parallel parameter sets.

Returns:

Parameter dict with NaN-contaminated sets replaced by fresh initializations.

Return type:

dict
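The per-set replacement can be sketched as a boolean mask over the leading (n_parameter_sets) axis. In this hypothetical sketch, fresh_params stands in for the output of pool.init_parameters, and all keys are checked for NaNs (unlike has_nan_params(), which skips non-learnable keys).

```python
import numpy as np

def reinit_nan_sets(params, fresh_params):
    """Replace every parameter set (leading axis) that has a NaN in any key.

    `fresh_params` must have the same keys and shapes as `params`.
    """
    n_sets = next(iter(params.values())).shape[0]
    bad = np.zeros(n_sets, dtype=bool)
    for value in params.values():
        # Flatten trailing axes so each set gets one NaN flag per key.
        bad |= np.isnan(value.reshape(n_sets, -1)).any(axis=1)
    out = {}
    for key, value in params.items():
        # Broadcast the per-set mask across the remaining axes.
        mask = bad.reshape((n_sets,) + (1,) * (value.ndim - 1))
        out[key] = np.where(mask, fresh_params[key], value)
    return out
```

Healthy sets pass through unchanged, so their optimisation progress is preserved.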

get_unique_tokens(run_fingerprint)[source]

Gets unique tokens from run fingerprint including subsidiary pools.

Extracts all tokens from the main pool and subsidiary pools in the run fingerprint, removes duplicates, and returns a sorted list of unique tokens.

Parameters:

run_fingerprint (dict) – Dictionary containing run configuration including tokens and subsidiary pools

Returns:

Sorted list of unique token symbols

Return type:

list

Examples

>>> fingerprint = {
...     "tokens": ["BTC", "ETH"],
...     "subsidary_pools": [{"tokens": ["ETH", "DAI"]}]
... }
>>> get_unique_tokens(fingerprint)
['BTC', 'DAI', 'ETH']
split_list(lst, num_splits)[source]

Splits a list into a specified number of roughly equal sublists.

Divides a list into num_splits sublists, distributing any remainder elements evenly among the first sublists.

Parameters:
  • lst (list) – The input list to split

  • num_splits (int) – Number of sublists to create

Returns:

List of sublists

Return type:

list

Examples

>>> split_list([1, 2, 3, 4, 5], 2)
[[1, 2, 3], [4, 5]]
>>> split_list([1, 2, 3, 4, 5, 6], 3)
[[1, 2], [3, 4], [5, 6]]
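The remainder rule can be expressed with divmod: every chunk gets len(lst) // num_splits items, and the first len(lst) % num_splits chunks get one extra. The sketch below reproduces the documented examples; it is an illustration, not necessarily how the library implements it.

```python
def split_list_sketch(lst, num_splits):
    """Split into num_splits contiguous chunks; the first `remainder` chunks get one extra item."""
    base, remainder = divmod(len(lst), num_splits)
    chunks, start = [], 0
    for i in range(num_splits):
        size = base + (1 if i < remainder else 0)
        chunks.append(lst[start:start + size])
        start += size
    return chunks
```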
invert_permutation(perm)[source]

Compute the inverse of a permutation.

Given a permutation array that maps indices to their new positions, returns the inverse permutation that maps the new positions back to their original indices.

Parameters:

perm (numpy.ndarray) – Array representing a permutation of indices

Returns:

The inverse permutation array

Return type:

numpy.ndarray

Examples

>>> perm = np.array([2,0,1])
>>> invert_permutation(perm)
array([1, 2, 0])
permute_list_of_params(list_of_params, seed=0)[source]

Randomly permute a list of parameters using a fixed random seed.

This function takes a list of parameters and returns a new list with the same elements in a randomly permuted order. The permutation is deterministic based on the provided random seed.

Parameters:
  • list_of_params (list) – The list of parameters to permute

  • seed (int, optional) – Random seed to use for reproducible permutations (default: 0)

Returns:

A new list containing the same elements as the input list but in a randomly permuted order

Return type:

list

Examples

>>> params = [1, 2, 3, 4]
>>> permute_list_of_params(params, seed=42)
[3, 1, 4, 2]
>>> permute_list_of_params(params, seed=42)  # Same seed gives same permutation
[3, 1, 4, 2]
unpermute_list_of_params(list_of_params)[source]

Restore the original order of a previously permuted list of parameters.

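This function takes a list that was permuted using permute_list_of_params() and restores it to its original order by applying the inverse permutation with the same random seed. Because it takes no seed argument, it can only undo permutations made with the seed it reconstructs internally; lists permuted with a non-default seed will not be restored correctly.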

Parameters:

list_of_params (list) – The permuted list of parameters to restore to original order

Returns:

A new list containing the same elements as the input list but restored to their original order before permutation

Return type:

list

Examples

>>> params = [1, 2, 3, 4]
>>> permuted = permute_list_of_params(params)  # default seed
>>> unpermute_list_of_params(permuted)  # Restores original order
[1, 2, 3, 4]
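The permute/unpermute round trip can be sketched with a seeded NumPy generator and np.argsort, which for a permutation array computes the same inverse as invert_permutation() (argsort of [2, 0, 1] is [1, 2, 0]). This sketch returns the permutation explicitly instead of regenerating it from a seed, and because the generator differs from the library's, its permuted order will not match the outputs shown above.

```python
import numpy as np

def permute_sketch(items, seed=0):
    """Return (permuted_items, perm) using a seeded NumPy generator."""
    perm = np.random.default_rng(seed).permutation(len(items))
    return [items[i] for i in perm], perm

def unpermute_sketch(permuted, perm):
    """Undo permute_sketch: permuted[j] == items[perm[j]], so use the inverse."""
    inverse = np.argsort(perm)  # satisfies perm[inverse[i]] == i
    return [permuted[i] for i in inverse]
```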
get_trades_and_fees(run_fingerprint, raw_trades, fees_df, gas_cost_df, arb_fees_df, lp_supply_df, do_test_period=False)[source]

Process trade and fee data for a simulation run.

Takes raw trades, fees, gas costs and arbitrage fees and converts them into arrays suitable for simulation. Handles both training and test periods if specified.

Parameters:
  • run_fingerprint (dict) – Dictionary containing run configuration including start/end dates and tokens

  • raw_trades (pd.DataFrame, optional) – DataFrame containing raw trade data

  • fees_df (pd.DataFrame, optional) – DataFrame containing fee data

  • gas_cost_df (pd.DataFrame, optional) – DataFrame containing gas cost data

  • arb_fees_df (pd.DataFrame, optional) – DataFrame containing arbitrage fee data

  • lp_supply_df (pd.DataFrame, optional) – DataFrame containing LP supply data

  • do_test_period (bool, optional) – Whether to process data for a test period after training period (default False)

Returns:

Contains processed arrays for trades, fees, gas costs and arb fees for both training and test periods as applicable

Return type:

dict

create_daily_unix_array(start_date_str, end_date_str)[source]

Creates an array of daily Unix timestamps in milliseconds between two dates.

Parameters:
  • start_date_str (str) – Start date string in format ‘YYYY-MM-DD HH:MM:SS’

  • end_date_str (str) – End date string in format ‘YYYY-MM-DD HH:MM:SS’

Returns:

Array of Unix timestamps in milliseconds for each day between start and end dates

Return type:

list
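A standard-library sketch of the daily timestamp array follows. Two details are assumptions, since this page does not state them: the date strings are interpreted as UTC, and the end date is included.

```python
from datetime import datetime, timedelta, timezone

def daily_unix_ms(start_date_str, end_date_str):
    """Millisecond timestamps at one-day steps from start to end (inclusive, UTC)."""
    fmt = "%Y-%m-%d %H:%M:%S"
    start = datetime.strptime(start_date_str, fmt).replace(tzinfo=timezone.utc)
    end = datetime.strptime(end_date_str, fmt).replace(tzinfo=timezone.utc)
    out = []
    current = start
    while current <= end:
        out.append(int(current.timestamp() * 1000))
        current += timedelta(days=1)
    return out
```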

create_time_step(row, unix_values, tokens, index)[source]

Creates a SimulationResultTimestepDto object for a single time step.

Parameters:
  • row (pd.Series) – Row containing prices, reserves and weights data for this timestep

  • unix_values (list) – List of Unix timestamps in milliseconds

  • tokens (list) – List of token symbols

  • index (int) – Index of current timestep

Returns:

Object containing timestamp and coin data for this timestep

Return type:

SimulationResultTimestepDto

optimized_output_conversion(simulationRunDto, outputDict, tokens)[source]

Converts simulation output dictionary to a list of SimulationResultTimestepDto objects.

Parameters:
  • simulationRunDto (SimulationRunDto) – Object containing simulation run parameters

  • outputDict (dict) – Dictionary containing simulation output data including prices, reserves, and values

  • tokens (list) – List of token symbols used in simulation

Returns:

List of SimulationResultTimestepDto objects containing timestep data

Return type:

list

The function:

  1. Creates Unix timestamps for each day between start and end dates
  2. Downsamples simulation data from minutes to daily frequency
  3. Calculates token weights from reserves, prices and total value
  4. Combines data into timestep DTOs with coin holdings and values

probe_max_n_parameter_sets(run_fingerprint, min_sets=1, max_sets=64, safety_margin=0.9, verbose=True)[source]

Probe to find the maximum n_parameter_sets that fits in GPU memory.

Uses binary search to find the largest n_parameter_sets value that can complete a forward pass without OOM. Returns a dict with the recommended value and diagnostic info.

Parameters:
  • run_fingerprint (dict) – The run fingerprint configuration. Will be modified temporarily during probing.

  • min_sets (int) – Minimum n_parameter_sets to try (default 1).

  • max_sets (int) – Maximum n_parameter_sets to try (default 64).

  • safety_margin (float) – Fraction of max found to use as recommendation (default 0.9). This provides headroom for gradient computation which uses more memory.

  • verbose (bool) – Whether to print progress information.

Returns:

Keys: max_n_parameter_sets (int), recommended_n_parameter_sets (int, with safety margin applied), probed_values (list), success_values (list), failed_values (list).

Return type:

dict

Notes

  • This function temporarily modifies run_fingerprint during probing.

  • JAX caches are cleared between attempts.

  • The forward pass (without gradients) is used for probing, so gradient computation may require ~2x more memory than the probe measured; the safety_margin leaves headroom for this.
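The binary search above can be sketched independently of JAX by abstracting the forward pass into a predicate. Here fits is a hypothetical callable standing in for "run a forward pass with n sets and report whether it avoided OOM"; the sketch assumes memory use is monotone in n, rounds the recommendation down, and omits the success/failed lists the real function returns.

```python
def probe_max_sets_sketch(fits, min_sets=1, max_sets=64, safety_margin=0.9):
    """Binary-search the largest n in [min_sets, max_sets] with fits(n) True.

    Assumes monotonicity: if n sets fit in memory, any smaller n also fits.
    """
    lo, hi = min_sets, max_sets
    best = None
    probed = []
    while lo <= hi:
        mid = (lo + hi) // 2
        probed.append(mid)
        if fits(mid):
            best = mid      # mid fits; try larger values
            lo = mid + 1
        else:
            hi = mid - 1    # mid failed; try smaller values
    recommended = None if best is None else max(min_sets, int(best * safety_margin))
    return {
        "max_n_parameter_sets": best,
        "recommended_n_parameter_sets": recommended,
        "probed_values": probed,
    }
```

Binary search needs only about log2(max_sets - min_sets + 1) probes, which matters when each probe is a full JIT compile plus forward pass.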

allocate_memory_budget(run_fingerprint, available_memory_gb=None, priority='exploration', probe_if_needed=True, max_ensemble_members=1, verbose=True)[source]

Allocate memory budget across hyperparameters based on priority.

Parameters:
  • run_fingerprint (dict) – The run fingerprint configuration.

  • available_memory_gb (float, optional) – Available GPU memory in GB. If None and probe_if_needed=True, will probe to determine capacity.

  • priority (str) – How to allocate memory budget:

      • “exploration”: Maximize n_parameter_sets (find diverse solutions)

      • “robustness”: Balance n_parameter_sets and n_ensemble_members

      • “variance_reduction”: Maximize batch_size (stable gradients)

  • probe_if_needed (bool) – Whether to probe memory if available_memory_gb is not provided.

  • max_ensemble_members (int) – Maximum ensemble members to allocate (default 1 = no ensembling). Set higher (e.g., 4) if you want the “robustness” priority to use ensembles.

  • verbose (bool) – Whether to print allocation info.

Returns:

Recommended settings with keys: n_parameter_sets (int), n_ensemble_members (int), batch_size (int), priority_used (str), probe_result (dict or None).

Return type:

dict

apply_memory_allocation(run_fingerprint, allocation)[source]

Apply memory allocation results to a run_fingerprint.

Parameters:
  • run_fingerprint (dict) – The run fingerprint to modify (will be modified in place).

  • allocation (dict) – Result from allocate_memory_budget().

Returns:

The modified run_fingerprint.

Return type:

dict

auto_configure_memory_params(run_fingerprint, priority='exploration', max_ensemble_members=1, verbose=True)[source]

Convenience function: probe memory and apply allocation in one step.

Parameters:
  • run_fingerprint (dict) – The run fingerprint to configure (will be modified in place).

  • priority (str) – Allocation priority (“exploration”, “robustness”, “variance_reduction”).

  • max_ensemble_members (int) – Maximum ensemble members to allocate (default 1 = no ensembling).

  • verbose (bool) – Whether to print progress info.

Returns:

The modified run_fingerprint with optimal memory settings.

Return type:

dict

Example

>>> run = {...}  # your run_fingerprint
>>> auto_configure_memory_params(run, priority="exploration")
>>> train_on_historic_data(run)
compute_selection_metric(train_metrics, val_metrics=None, continuous_test_metrics=None, method='best_val', metric='sharpe', min_threshold=0.0)[source]

Compute selection metric value for a single iteration/trial.

This is the shared core logic used by both BestParamsTracker (during training) and load_manually (post-training). Returns a value for comparison and the index of the best param set.

Parameters:
  • train_metrics (list of dict) – Training metrics for each param set. Each dict has keys like “sharpe”, “returns_over_uniform_hodl”, etc.

  • val_metrics (list of dict, optional) – Validation metrics for each param set. Required if method=“best_val”.

  • continuous_test_metrics (list of dict, optional) – Continuous test metrics for each param set.

  • method (str) – Selection method. One of SELECTION_METHODS.

  • metric (str) – Which metric to use for comparison (e.g., “sharpe”, “returns_over_uniform_hodl”).

  • min_threshold (float) – Minimum threshold for “best_train_min_test” method.

Returns:

(selection_value, best_param_idx) - value for comparison and index of best param set. Higher selection_value is always better.

Return type:

tuple of (float, int)
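For the "best_val" method, the core of the selection reduces to an argmax over the chosen validation metric. The sketch below covers only that method; the helper name is hypothetical, and it assumes the chosen metric is one where higher is better (a lower-is-better metric would need negating to keep "higher selection_value is always better").

```python
def best_val_selection(val_metrics, metric="sharpe"):
    """Return (selection_value, best_param_idx) from validation metrics only."""
    if val_metrics is None:
        raise ValueError('method "best_val" requires val_metrics')
    scores = [m[metric] for m in val_metrics]
    # Ties resolve to the earliest param set.
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    return scores[best_idx], best_idx
```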

class BestParamsTracker(selection_method='best_val', metric='sharpe', min_threshold=0.0)[source]

Bases: object

Unified tracking of params across training iterations/trials.

Tracks both “last” (most recent iteration) and “best” (by selection method) params along with their associated metrics and continuous outputs.

Used by both SGD and Optuna paths to ensure consistent param selection logic.

Parameters:
  • selection_method (str) – Method for selecting best params. One of SELECTION_METHODS.

  • metric (str) – Which metric to use for selection (e.g., “sharpe”, “returns_over_uniform_hodl”).

  • min_threshold (float) – Minimum threshold for “best_train_min_test” method.

last_*

State from the most recent update() call.

Type:

Various

best_*

State from when selection metric was highest.

Type:

Various

__init__(selection_method='best_val', metric='sharpe', min_threshold=0.0)[source]
Parameters:
  • selection_method (str)

  • metric (str)

  • min_threshold (float)

update(iteration, params, continuous_outputs, train_metrics_list, val_metrics_list=None, continuous_test_metrics_list=None)[source]

Update tracker with current iteration’s state.

Parameters:
  • iteration (int) – Current iteration/trial number.

  • params (dict) – Current parameters (batched over param sets).

  • continuous_outputs (dict) – Output from continuous forward pass. Must have “reserves” and “weights” with shape (n_param_sets, time_steps, …).

  • train_metrics_list (list of dict) – Training metrics for each param set.

  • val_metrics_list (list of dict, optional) – Validation metrics for each param set.

  • continuous_test_metrics_list (list of dict, optional) – Continuous test metrics for each param set.

Returns:

True if this iteration improved the best metric, False otherwise.

Return type:

bool

select_param_set(params_dict, idx, n_param_sets)[source]

Extract single param set from batched params.

Parameters:
  • params_dict (dict) – Batched parameters with shape (n_param_sets, …) for each key.

  • idx (int) – Index of param set to extract.

  • n_param_sets (int) – Total number of param sets.

Returns:

Parameters for single param set with shape (…) for each key.

Return type:

dict
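Extracting one param set amounts to indexing the leading batch axis of each entry. The sketch below is hypothetical; in particular, passing through values whose leading dimension is not n_param_sets (for example scalars or strings) is an assumption about how non-batched entries might be handled.

```python
import numpy as np


def select_param_set_sketch(params_dict, idx, n_param_sets):
    """Hypothetical helper: slice param set `idx` out of batched params.

    Values with a leading axis of length n_param_sets are indexed;
    anything else is assumed to be shared and is passed through unchanged.
    """
    if not 0 <= idx < n_param_sets:
        raise IndexError(f"idx {idx} out of range for {n_param_sets} param sets")
    out = {}
    for key, value in params_dict.items():
        arr = np.asarray(value)
        if arr.ndim >= 1 and arr.shape[0] == n_param_sets:
            # Drop the batch axis: (n_param_sets, ...) -> (...)
            out[key] = arr[idx]
        else:
            out[key] = value
    return out
```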

get_results(n_param_sets, train_bout_length)[source]

Get comprehensive results with both last and best state.

Parameters:
  • n_param_sets (int) – Number of parameter sets (for extracting correct shapes).

  • train_bout_length (int) – Length of training period. Used to extract final reserves/weights at end of training (not end of test) for warm-starting.

Returns:

Comprehensive results including:
  • last_* fields: state from the most recent iteration
  • best_* fields: state from when the selection metric was best
  • selection metadata

Return type:

dict
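The role of train_bout_length can be sketched as picking the pool state at the final training timestep rather than the final simulated timestep. end_of_training_state is a hypothetical helper; it assumes the continuous outputs have shape (n_param_sets, time_steps, n_assets) with the training window occupying the first train_bout_length steps.

```python
import numpy as np


def end_of_training_state(continuous_outputs, idx, train_bout_length):
    """Hypothetical helper: reserves and weights at the last training step.

    Assumes "reserves" and "weights" have shape
    (n_param_sets, time_steps, n_assets) and that timesteps
    [0, train_bout_length) form the training window.
    """
    reserves = np.asarray(continuous_outputs["reserves"])
    weights = np.asarray(continuous_outputs["weights"])
    # Index train_bout_length - 1 is the final training step; later steps
    # belong to the test period and should not seed a warm start.
    t = train_bout_length - 1
    return reserves[idx, t], weights[idx, t]
```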

Run Fingerprint Defaults

Multi-Period SGD

Multi-Period SGD Training for Financial Strategies

This module implements multi-period robust training where we optimize parameters across multiple temporal windows simultaneously with a single forward pass and continuous pool state.

Key design:
  • ONE forward pass spans the entire data period.
  • Evaluation windows for each “period” are dynamically sliced out of that pass.
  • Losses are aggregated across periods, giving a single backward pass.
  • Pool state continuity is automatic (one continuous simulation).

This is NOT walk-forward (no retraining per period), but rather finds ONE set of params that performs well across all temporal windows.

Benefits:
  • Automatic pool state continuity through the continuous forward pass.
  • Single JIT compilation (no recompilation for different bout lengths).
  • Efficient: one forward pass and one backward pass per update step.
  • Encourages robust parameters that work across market regimes.
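The slicing idea can be illustrated without any strategy logic: simulate once, then cut every evaluation window out of the same trajectory. This is a minimal sketch; per_period_returns and the toy value series are hypothetical and stand in for the pool-value output of a real forward pass.

```python
import numpy as np


def per_period_returns(values, period_bounds):
    """Slice one continuous value trajectory into evaluation windows.

    `values` is a hypothetical pool-value series from a single forward pass;
    `period_bounds` holds (rel_start, rel_end) pairs with rel_end exclusive.
    Every window comes from the same simulation, so period k simply
    continues from the state in which period k-1 ended.
    """
    returns = []
    for rel_start, rel_end in period_bounds:
        window = values[rel_start:rel_end]
        # Simple cumulative return over the window.
        returns.append(window[-1] / window[0] - 1.0)
    return np.array(returns)


values = np.array([100.0, 110.0, 121.0, 121.0, 133.1, 146.41])
period_returns = per_period_returns(values, [(0, 3), (3, 6)])
# Aggregating the per-period metrics yields one scalar objective.
objective = float(np.mean(period_returns))
```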

class PeriodSpec(period_id, rel_start, rel_end)[source]

Bases: object

Specification for a single evaluation period within the forward pass.

Defines a contiguous temporal slice of the forward pass output that constitutes one evaluation window. Multiple PeriodSpec instances partition (or overlap-partition) the full simulation into the windows used by multi-period SGD training.

Parameters:
  • period_id (int)

  • rel_start (int)

  • rel_end (int)

period_id

Zero-based ordinal index identifying this period within the sequence of evaluation windows.

Type:

int

rel_start

Start index of this period, relative to the first timestep of the forward pass output (not the raw price array).

Type:

int

rel_end

End index (exclusive) of this period, relative to the first timestep of the forward pass output.

Type:

int

property length: int

Return the number of timesteps in this period.

__init__(period_id, rel_start, rel_end)
Parameters:
  • period_id (int)

  • rel_start (int)

  • rel_end (int)

Return type:

None
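The documented fields and the length property amount to a half-open interval. The sketch below re-creates them for illustration only; making the class frozen (hashable, immutable) is an assumption chosen so instances could sit in static, trace-time structures, and the real PeriodSpec may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PeriodSpecSketch:
    """Illustrative re-creation of PeriodSpec's documented fields."""

    period_id: int
    rel_start: int
    rel_end: int  # exclusive

    @property
    def length(self) -> int:
        # The half-open interval [rel_start, rel_end) has rel_end - rel_start steps.
        return self.rel_end - self.rel_start
```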

class MultiPeriodResult(period_sharpes, period_returns, period_returns_over_hodl, mean_sharpe, std_sharpe, worst_sharpe, mean_returns_over_hodl, epochs_trained, final_objective, best_params=<factory>)[source]

Bases: object

Results from multi-period training.

Collects per-period evaluation metrics and their summary statistics after training a single parameter set across all temporal windows.

period_sharpes

Annualised Sharpe ratio for each evaluation period.

Type:

List[float]

period_returns

Cumulative return for each evaluation period.

Type:

List[float]

period_returns_over_hodl

Cumulative return relative to a uniform hold-all-assets baseline, per evaluation period.

Type:

List[float]

mean_sharpe

Arithmetic mean of period_sharpes.

Type:

float

std_sharpe

Standard deviation of period_sharpes, measuring cross-period consistency.

Type:

float

worst_sharpe

Minimum of period_sharpes (worst single-period performance).

Type:

float

mean_returns_over_hodl

Arithmetic mean of period_returns_over_hodl.

Type:

float

epochs_trained

Total number of gradient update steps executed.

Type:

int

final_objective

Best aggregated objective value observed during training (the value that triggered best_params to be saved).

Type:

float

best_params

Strategy parameters corresponding to final_objective, stored as NumPy arrays. Empty dict if training produced no valid update.

Type:

Dict[str, Any]

__init__(period_sharpes, period_returns, period_returns_over_hodl, mean_sharpe, std_sharpe, worst_sharpe, mean_returns_over_hodl, epochs_trained, final_objective, best_params=<factory>)
Return type:

None
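The summary fields follow directly from period_sharpes. In this sketch the use of population standard deviation (ddof=0) for std_sharpe is an assumption; the documentation does not say which estimator quantammsim uses.

```python
import numpy as np


def summarise_period_sharpes(period_sharpes):
    """Compute MultiPeriodResult-style summary fields from per-period Sharpes.

    std uses ddof=0 (population standard deviation); this is an assumption.
    """
    s = np.asarray(period_sharpes, dtype=float)
    return {
        "mean_sharpe": float(s.mean()),
        "std_sharpe": float(s.std(ddof=0)),
        "worst_sharpe": float(s.min()),
    }
```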

create_multi_period_training_step(base_forward_pass, prices, period_specs, n_assets, return_val, aggregation='mean', softmin_temperature=1.0)[source]

Create a training step function that computes aggregate loss across periods.

This returns a function with signature (params, start_index) -> scalar, compatible with the existing backpropagation factories.

Parameters:
  • base_forward_pass (callable) – Partial forward_pass with full bout_length static_dict

  • prices (jnp.ndarray) – Full price array

  • period_specs (tuple of (rel_start, slice_len)) – For each period: relative start index and length within the forward pass output. Must be a tuple of tuples (static) so the loop unrolls at trace time.

  • n_assets (int) – Number of assets

  • return_val (str) – Metric to compute per period (“sharpe”, “returns”, etc.)

  • aggregation (str) – How to combine period metrics:
      ◦ “mean”: Simple average (all periods contribute equally)
      ◦ “min”: Hard minimum (CAUTION: only the minimum element gets gradients)
      ◦ “softmin”: Soft minimum via negative softmax (recommended for worst-case)
      ◦ “sum”: Sum of all metrics

  • softmin_temperature (float) – Temperature for softmin aggregation. Lower = closer to hard min. Default 1.0 gives moderate smoothing. Use 0.1-0.5 for sharper focus on worst.

Returns:

Function (params, start_index) -> scalar

Return type:

callable
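The factory pattern described above (one forward pass, static per-period slices, one aggregated scalar) can be sketched in JAX. make_multi_period_step, toy_forward and the mean/std score are hypothetical simplifications: the real function works with quantammsim's forward pass and its own metrics such as Sharpe.

```python
import jax
import jax.numpy as jnp


def make_multi_period_step(forward_pass, period_specs, aggregation="mean",
                           softmin_temperature=1.0):
    """Hypothetical analogue of create_multi_period_training_step.

    forward_pass(params, start_index) -> per-timestep returns (1-D array).
    period_specs is a static tuple of (rel_start, slice_len) pairs, so the
    Python loop below unrolls at trace time.
    """

    def step(params, start_index):
        returns = forward_pass(params, start_index)  # one forward pass
        metrics = []
        for rel_start, slice_len in period_specs:
            window = returns[rel_start:rel_start + slice_len]
            # Simplified Sharpe-like score: mean return / std of returns.
            metrics.append(jnp.mean(window) / (jnp.std(window) + 1e-8))
        metrics = jnp.stack(metrics)
        if aggregation == "mean":
            return jnp.mean(metrics)
        if aggregation == "sum":
            return jnp.sum(metrics)
        if aggregation == "min":
            return jnp.min(metrics)
        if aggregation == "softmin":
            weights = jax.nn.softmax(-metrics / softmin_temperature)
            return jnp.sum(metrics * weights)
        raise ValueError(f"unknown aggregation: {aggregation}")

    return step


# Toy usage: a hypothetical forward pass that adds a learnable bias to fixed returns.
base_returns = jnp.sin(jnp.arange(200.0) / 7.0) * 0.01 + 0.001


def toy_forward(params, start_index):
    window = jax.lax.dynamic_slice_in_dim(base_returns, start_index, 120)
    return window + params["bias"]


step = make_multi_period_step(toy_forward, ((0, 40), (40, 40), (80, 40)))
loss = step({"bias": 0.0}, 0)
grads = jax.grad(step)({"bias": 0.0}, 0)  # one backward pass through all periods
```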

Notes

IMPORTANT: Using aggregation=“min” has a gradient flow problem!

With hard min, gradients only flow through the single minimum element. This means:
  • Only 1 of N periods contributes to parameter updates.
  • Gradients are sparse and noisy.
  • Training can be unstable.

Solution: Use “softmin” instead, which computes a soft minimum:

softmin(x) = sum(x * softmax(-x / temperature))

This gives more weight to lower-performing periods while still allowing gradients to flow from all periods. As temperature → 0, softmin → hard min.
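The difference in gradient flow can be checked directly with jax.grad. This sketch implements the softmin formula above; the three period metrics are made-up numbers.

```python
import jax
import jax.numpy as jnp


def softmin(x, temperature=1.0):
    # softmin(x) = sum(x * softmax(-x / temperature))
    weights = jax.nn.softmax(-x / temperature)
    return jnp.sum(x * weights)


period_metrics = jnp.array([1.0, 2.0, 3.0])

# Hard min: the gradient is one-hot on the worst period.
hard_grad = jax.grad(jnp.min)(period_metrics)

# Softmin: every period receives a non-zero gradient.
soft_grad = jax.grad(softmin)(period_metrics, 1.0)

# A low temperature pulls softmin towards the hard minimum.
sharp_value = softmin(period_metrics, 0.05)
```

Here hard_grad is [1, 0, 0] while soft_grad is roughly [0.95, 0.10, -0.05]: the worst period dominates but every period contributes. The gradients sum to 1 because adding a constant to every metric shifts softmin by that same constant.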

generate_period_specs(n_periods, total_length, overlap_fraction=0.0)[source]

Generate period specifications that partition the simulation into evaluation windows.

Divides total_length timesteps into n_periods contiguous windows. When overlap_fraction is zero the windows tile the interval exactly (the last period absorbs any remainder). When positive, each window extends into its successor by overlap_fraction of the base period length, producing correlated but longer evaluation windows useful for smoothing period-boundary effects.

Parameters:
  • n_periods (int) – Number of evaluation windows to generate.

  • total_length (int) – Total number of timesteps available in the forward pass output.

  • overlap_fraction (float, optional) – Fraction of a base period length by which consecutive windows overlap. 0.0 (default) produces a non-overlapping partition; 0.5 means each window shares half its length with the next.

Returns:

Ordered list of PeriodSpec instances covering (possibly overlapping) the full simulation length.

Return type:

List[PeriodSpec]

Examples

>>> specs = generate_period_specs(n_periods=4, total_length=1000)
>>> [(s.rel_start, s.rel_end) for s in specs]
[(0, 250), (250, 500), (500, 750), (750, 1000)]
>>> specs = generate_period_specs(3, 900, overlap_fraction=0.5)
>>> [(s.rel_start, s.rel_end) for s in specs]
[(0, 300), (150, 450), (300, 600)]
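
The non-overlapping case (overlap_fraction=0.0) can be sketched as an integer partition in which the last window absorbs the remainder, matching the first example above. partition_periods is hypothetical, returns plain tuples instead of PeriodSpec objects, and does not reproduce the overlapping case.

```python
from typing import List, Tuple


def partition_periods(n_periods: int, total_length: int) -> List[Tuple[int, int, int]]:
    """Hypothetical sketch of the overlap_fraction=0.0 case.

    Returns (period_id, rel_start, rel_end) triples; the last period absorbs
    any remainder when total_length is not divisible by n_periods.
    """
    if n_periods < 1 or total_length < n_periods:
        raise ValueError("need 1 <= n_periods <= total_length")
    base = total_length // n_periods
    specs = []
    for i in range(n_periods):
        start = i * base
        end = total_length if i == n_periods - 1 else start + base
        specs.append((i, start, end))
    return specs
```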
multi_period_sgd_training(run_fingerprint, n_periods=4, overlap_fraction=0.0, max_epochs=500, aggregation='mean', softmin_temperature=1.0, verbose=True, root=None)[source]

Run multi-period SGD training.

Trains ONE set of parameters that performs well across multiple temporal windows simultaneously.

Parameters:
  • run_fingerprint (dict) – Run configuration

  • n_periods (int) – Number of evaluation periods

  • overlap_fraction (float) – Fraction of overlap between periods (0.0 = no overlap)

  • max_epochs (int) – Maximum training epochs

  • aggregation (str) – How to combine period metrics:
      ◦ “mean”: Simple average (default, all periods equal)
      ◦ “softmin”: Soft minimum (recommended for worst-case optimization)
      ◦ “min”: Hard minimum (NOT recommended: gradient flow issues)
      ◦ “sum”: Sum of metrics

  • softmin_temperature (float) – Temperature for softmin aggregation. Lower = closer to hard min. Default 1.0. Use 0.1-0.5 for sharper worst-case focus.

  • verbose (bool) – Print progress

  • root (str | None)

Returns:

Training result and summary statistics

Return type:

Tuple[MultiPeriodResult, dict]

Metric Extraction

Metric Extraction: Registry-based lookup for cycle evaluation metrics.

This module provides unified metric extraction from CycleEvaluation objects, replacing repetitive if/elif chains with a registry-based approach.

Usage:

from quantammsim.runners.metric_extraction import extract_cycle_metric

# Extract aggregated metrics from cycle evaluations
value = extract_cycle_metric(cycle_evals, "mean_oos_sharpe")
value = extract_cycle_metric(cycle_evals, "worst_wfe")
value = extract_cycle_metric(cycle_evals, "neg_is_oos_gap")
CYCLE_METRICS: Dict[str, str] = {'adjusted_oos_sharpe': 'adjusted_oos_sharpe', 'is_calmar': 'is_calmar', 'is_daily_log_sharpe': 'is_daily_log_sharpe', 'is_oos_gap': 'is_oos_gap', 'is_returns': 'is_returns', 'is_returns_over_hodl': 'is_returns_over_hodl', 'is_sharpe': 'is_sharpe', 'is_sterling': 'is_sterling', 'is_ulcer': 'is_ulcer', 'oos_calmar': 'oos_calmar', 'oos_daily_log_sharpe': 'oos_daily_log_sharpe', 'oos_returns': 'oos_returns', 'oos_returns_over_hodl': 'oos_returns_over_hodl', 'oos_sharpe': 'oos_sharpe', 'oos_sterling': 'oos_sterling', 'oos_ulcer': 'oos_ulcer', 'wfe': 'walk_forward_efficiency'}

Registry mapping short metric names to CycleEvaluation attribute names.

Keys are the tokens recognised in metric spec strings (e.g. "mean_oos_sharpe" → aggregator "mean" + metric "oos_sharpe"). Values are the corresponding attribute on CycleEvaluation.

AGGREGATORS: Dict[str, Callable[[List[float]], float]] = {'mean': <function _mean_agg>, 'worst': <function _worst_agg>}

Aggregation functions keyed by the prefix used in metric spec strings. E.g. "mean_oos_sharpe" dispatches to AGGREGATORS["mean"].
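The two registries split a spec string into an aggregator prefix and a metric key. The sketch uses small illustrative subsets of the registries; treating "worst" as min follows the documented "worst_wfe" behaviour, and using statistics.mean for "mean" is an assumption about _mean_agg.

```python
from statistics import mean
from typing import Callable, Dict, List, Tuple

# Illustrative subsets of the documented registries.
CYCLE_METRICS_SUBSET: Dict[str, str] = {
    "oos_sharpe": "oos_sharpe",
    "wfe": "walk_forward_efficiency",
}
AGGREGATORS_SKETCH: Dict[str, Callable[[List[float]], float]] = {
    "mean": mean,
    "worst": min,
}


def split_metric_spec(spec: str) -> Tuple[str, str]:
    """Split "<aggregator>_<metric>" into (aggregator, attribute name)."""
    aggregator, _, metric = spec.partition("_")
    if aggregator not in AGGREGATORS_SKETCH or metric not in CYCLE_METRICS_SUBSET:
        raise KeyError(f"unrecognised metric spec: {spec!r}")
    return aggregator, CYCLE_METRICS_SUBSET[metric]
```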

extract_cycle_metric(cycle_evals, metric_spec)[source]

Extract aggregated metric from CycleEvaluation list.

Supports metric specifications like:
  • “mean_oos_sharpe”: mean of oos_sharpe across cycles
  • “worst_wfe”: minimum walk_forward_efficiency
  • “neg_is_oos_gap”: negated mean of is_oos_gap (for minimization)
  • “adjusted_mean_oos_sharpe”: mean of adjusted_oos_sharpe

Parameters:
  • cycle_evals (List[CycleEvaluation]) – List of cycle evaluation results

  • metric_spec (str) – Metric specification string

Returns:

Aggregated metric value

Return type:

float

Examples

>>> value = extract_cycle_metric(cycle_evals, "mean_oos_sharpe")
>>> value = extract_cycle_metric(cycle_evals, "worst_wfe")
>>> value = extract_cycle_metric(cycle_evals, "neg_is_oos_gap")
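The four documented spec forms can be handled by rewriting the neg_ and adjusted_ prefixes before the usual aggregator split. FakeCycleEval and extract_sketch are hypothetical: the stand-in class carries only the three attributes the examples need, and the prefix rewriting is one possible reading of the documented behaviour, not the library's code.

```python
from dataclasses import dataclass
from statistics import mean
from typing import List


@dataclass
class FakeCycleEval:
    """Hypothetical stand-in for CycleEvaluation with three attributes."""
    oos_sharpe: float
    adjusted_oos_sharpe: float
    is_oos_gap: float


def extract_sketch(cycle_evals: List[FakeCycleEval], spec: str) -> float:
    """Sketch of the documented forms: mean_/worst_, neg_, adjusted_mean_."""
    negate = spec.startswith("neg_")
    if negate:
        # "neg_X" is documented as the negated mean of X.
        spec = "mean_" + spec[len("neg_"):]
    if spec.startswith("adjusted_"):
        # "adjusted_mean_oos_sharpe" -> mean of adjusted_oos_sharpe
        agg, _, metric = spec[len("adjusted_"):].partition("_")
        metric = "adjusted_" + metric
    else:
        agg, _, metric = spec.partition("_")
    values = [getattr(c, metric) for c in cycle_evals]
    value = {"mean": mean, "worst": min}[agg](values)
    return -value if negate else value
```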
get_metric_from_result(result, metric_name)[source]

Extract a metric from an EvaluationResult object.

Parameters:
  • result (EvaluationResult) – The evaluation result object

  • metric_name (str) – Name of the metric to extract

Returns:

The metric value

Return type:

float