Source code for quantammsim.runners.jax_runner_utils

import numpy as np
import pandas as pd

import json
import hashlib

# again, this only works on startup!
from jax import config, jit
from jax.tree_util import tree_map, tree_reduce
import jax.numpy as jnp

from quantammsim.core_simulator.windowing_utils import (
    raw_fee_like_amounts_to_fee_like_array,
    raw_trades_to_trade_array,
)

from quantammsim.apis.rest_apis.simulator_dtos.simulation_run_dto import (
    LiquidityPoolCoinDto,
    SimulationResultTimestepDto,
)

config.update("jax_enable_x64", True)

import os
import optuna
import logging
from datetime import datetime
from pathlib import Path
import plotly.graph_objects as go
from optuna.visualization import plot_optimization_history, plot_param_importances
import numpy as np


from typing import Dict, Any, Generic, TypeVar, List, Optional, Tuple
from copy import deepcopy
T = TypeVar('T')      # Declare type variable

def create_trial_params(
    trial: Any,  # optuna.Trial, but avoid direct dependency
    param_config: Dict,
    params: Dict,
    run_fingerprint: Dict,
    n_assets: int,
    expand_around=False,
) -> Dict:
    """Build one candidate parameter dict for an Optuna trial.

    Every entry of ``params`` except ``subsidary_params`` is mapped to a
    1-D array whose length is ``value.shape[1]``:

    * ``logit_delta_lamb*`` keys become zeros unless
      ``run_fingerprint["use_alt_lamb"]`` is truthy.
    * ``initial_weights_logits`` becomes ``jnp.zeros(n_assets)``.
    * ``initial_weights`` is passed through unchanged.
    * Any other key is sampled via ``trial.suggest_float``. If its config
      has ``scalar=True``, one value named ``key`` is drawn and repeated.
      Otherwise one value per element is drawn, named ``f"{key}_{i}"``.

    Parameters
    ----------
    trial : optuna.Trial
        Object providing ``suggest_float(name, low, high, log=...)``.
    param_config : dict
        Per-key search settings with optional ``low``, ``high``,
        ``log_scale`` and ``scalar`` entries. Missing entries fall back to
        defaults: ``low=-10.0, high=10.0``, or ``0.1``/``0.1`` when
        ``expand_around`` is set. Both ``log_scale`` and ``scalar``
        default to False.
    params : dict
        Current parameter values. Each non-subsidiary value must have at
        least two dimensions; ``shape[1]`` sets the output length.
    run_fingerprint : dict
        Run configuration (only ``use_alt_lamb`` is read here).
    n_assets : int
        Number of assets (length of ``initial_weights_logits``).
    expand_around : bool, optional
        If True, per-element (non-scalar) bounds become
        ``params[key][0][i] - low`` .. ``params[key][0][i] + high``.
        Scalar keys still use ``low``/``high`` as absolute bounds.

    Returns
    -------
    dict
        Trial parameters dictionary.

    Raises
    ------
    ValueError
        If a parameter value has no ``shape`` or fewer than 2 dimensions.
    """
    if expand_around:
        defaults = {"low": 0.1, "high": 0.1, "log_scale": False, "scalar": False}
    else:
        defaults = {"low": -10.0, "high": 10.0, "log_scale": False, "scalar": False}

    trial_params = {}
    # The forward pass needs subsidary_params, so carry it over verbatim.
    if "subsidary_params" in params:
        trial_params["subsidary_params"] = params["subsidary_params"]

    for key, value in params.items():
        if key == "subsidary_params":
            continue
        if not hasattr(value, "shape") or len(value.shape) < 2:
            raise ValueError(f"Parameter {key} must have at least 2 dimensions")

        length = value.shape[1]
        cfg = {**defaults, **param_config.get(key, {})}

        if key.startswith("logit_delta_lamb") and not run_fingerprint.get(
            "use_alt_lamb", False
        ):
            trial_params[key] = jnp.zeros(length)
        elif key == "initial_weights_logits":
            trial_params[key] = jnp.zeros(n_assets)
        elif key == "initial_weights":
            trial_params[key] = value
        elif cfg["scalar"]:
            shared = trial.suggest_float(
                key, cfg["low"], cfg["high"], log=cfg["log_scale"]
            )
            trial_params[key] = jnp.full(length, shared)
        else:
            samples = []
            for i in range(length):
                if expand_around:
                    centre = float(params[key][0][i])
                    lo, hi = centre - cfg["low"], centre + cfg["high"]
                else:
                    lo, hi = cfg["low"], cfg["high"]
                samples.append(
                    trial.suggest_float(f"{key}_{i}", lo, hi, log=cfg["log_scale"])
                )
            trial_params[key] = jnp.array(samples)

    return trial_params
def generate_evaluation_points(
    start_idx, end_idx, bout_length, n_points, min_spacing, random_key=0
):
    """Generate evaluation start points for optuna-style hyperparameter search.

    Combines ``n_points`` uniformly random offsets with ``n_points`` evenly
    spaced offsets, both in ``[0, end_idx - start_idx - bout_length]``.
    It then de-duplicates and sorts them and shifts them by ``start_idx``.

    If the training period is no longer than ``bout_length`` (no room for
    multiple windows), returns just ``[start_idx]``.

    Randomness comes from a private ``np.random.RandomState(random_key)``.
    It produces the same draws as seeding the global RNG with
    ``random_key``, but leaves the global NumPy RNG state untouched.

    Parameters
    ----------
    start_idx : int
        Start index of the training period
    end_idx : int
        End index of the training period
    bout_length : int
        Length of each evaluation window
    n_points : int
        Desired number of random (and of evenly spaced) evaluation points
    min_spacing : int
        Minimum spacing between evaluation points (currently unused)
    random_key : int
        Random seed for reproducibility

    Returns
    -------
    list
        Sorted, unique evaluation start indices
    """
    available_range = end_idx - start_idx - bout_length

    # Only one evaluation point possible: the start of the training period.
    if available_range <= 0:
        return [start_idx]

    # Local generator instead of np.random.seed(): same sequence, no
    # hidden mutation of global state that other code may depend on.
    rng = np.random.RandomState(random_key)
    random_offsets = rng.randint(0, available_range, n_points)
    even_offsets = np.linspace(0, available_range, n_points, dtype=int)

    # np.unique both de-duplicates and sorts.
    offsets = np.unique(np.concatenate([random_offsets, even_offsets]))
    return [start_idx + p for p in offsets]
def find_best_balanced_solution(values_array, n_objectives=None):
    """Find the solution closest to the ideal point after normalizing objectives.

    Each objective column is min-max scaled to [0, 1]. The index of the row
    with the smallest Euclidean distance to the all-ones point is returned.
    This treats larger values as better for every objective. A column whose
    values are all equal carries no ranking information, so it is mapped to
    0 rather than producing 0/0 = NaN. Previously such NaNs made the result
    always index 0.

    Args:
        values_array: Either a numpy array of shape (n_trials, n_objectives)
            or a list of optuna trials with values attribute
        n_objectives: Number of objectives. Only needed if using list of trials.

    Returns:
        int: Index of the best balanced solution
    """
    if not isinstance(values_array, np.ndarray):
        # Convert list of trials to numpy array
        values_array = np.array([t.values for t in values_array])
    if n_objectives is None:
        n_objectives = values_array.shape[1]

    col_min = values_array.min(axis=0)
    col_range = values_array.max(axis=0) - col_min
    # Guard constant columns against division by zero.
    col_range = np.where(col_range == 0, 1.0, col_range)
    normalized = (values_array - col_min) / col_range

    # Find solution closest to ideal point
    ideal_point = np.ones(n_objectives)
    distances = np.linalg.norm(normalized - ideal_point, axis=1)
    return np.argmin(distances)
def get_best_balanced_solution(study):
    """Pick the most balanced trial from a multi-objective study's Pareto front.

    When the front (``study.best_trials``) holds more than one trial, the
    choice is delegated to ``find_best_balanced_solution``. Otherwise the
    single (first) trial is used.

    Parameters
    ----------
    study : optuna.Study
        A multi-objective study.

    Returns
    -------
    tuple
        ``(params, values, index)`` of the chosen trial, where ``index`` is
        its position within ``study.best_trials``.
    """
    pareto_trials = study.best_trials

    if len(pareto_trials) <= 1:
        chosen = 0
    else:
        chosen = find_best_balanced_solution(pareto_trials, len(study.directions))

    winner = pareto_trials[chosen]
    return winner.params, winner.values, chosen
class OptunaManager:
    """Manages an Optuna hyperparameter optimization study lifecycle.

    Encapsulates study creation, execution, early stopping, and result
    persistence. Configuration is drawn from
    ``run_fingerprint["optimisation_settings"]["optuna_settings"]``.

    Parameters
    ----------
    run_fingerprint : dict
        Run configuration. Must contain ``optimisation_settings.optuna_settings``.

    Attributes
    ----------
    study : optuna.Study or None
        The Optuna study, created by :meth:`setup_study`.
    multi_objective : bool
        Whether the study has three objectives. Set by :meth:`setup_study`;
        False until then.
    logger : logging.Logger
        File-backed logger writing to ``output_dir/optimization.log``.
    """

    def __init__(self, run_fingerprint):
        self.run_fingerprint = run_fingerprint
        self.optuna_settings = run_fingerprint["optimisation_settings"]["optuna_settings"]
        self.output_dir = Path(self.optuna_settings.get("output_dir", "optuna_studies"))
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.study = None
        # Default so save_results() works even if setup_study() was never called.
        self.multi_objective = False
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Configure logging for the optimization process.

        Raises Optuna's own log threshold to ERROR. Returns a logger, named
        by the current timestamp (second resolution), that writes INFO+
        records to ``output_dir/optimization.log``.
        """
        optuna.logging.set_verbosity(optuna.logging.ERROR)
        logger = logging.getLogger(f"optuna_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
        logger.setLevel(logging.INFO)

        fh = logging.FileHandler(self.output_dir / "optimization.log")
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        fh.setFormatter(formatter)
        logger.addHandler(fh)
        return logger

    def setup_study(self, multi_objective=False):
        """Create and configure the Optuna study.

        Initialises an Optuna study with a multivariate TPE sampler, a median
        pruner, and optional RDB storage. For multi-objective mode, creates a
        three-direction maximize study (mean return, worst-case, stability).

        Parameters
        ----------
        multi_objective : bool, optional
            If True, creates a multi-objective study with three maximize
            directions. Default is False (single-objective maximize).
        """
        self.multi_objective = multi_objective
        study_name = (
            self.optuna_settings["study_name"]
            or f"quantamm_{self.run_fingerprint['rule']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )

        # Setup storage
        storage = None
        if self.optuna_settings["storage"]["url"]:
            storage = optuna.storages.RDBStorage(
                url=self.optuna_settings["storage"]["url"]
            )

        # Custom sampler with multivariate TPE
        sampler = optuna.samplers.TPESampler(
            n_startup_trials=self.optuna_settings["n_startup_trials"], multivariate=True
        )

        # Custom pruner
        pruner = optuna.pruners.MedianPruner(
            n_startup_trials=5,
            n_warmup_steps=20,
            interval_steps=1,
        )

        if multi_objective:
            self.study = optuna.create_study(
                study_name=study_name,
                storage=storage,
                pruner=pruner,
                sampler=sampler,
                directions=["maximize", "maximize", "maximize"],
            )
        else:
            self.study = optuna.create_study(
                study_name=study_name,
                storage=storage,
                sampler=sampler,
                pruner=pruner,
                direction="maximize",
            )

    @staticmethod
    def _validation_plateaued(validation_values, patience, min_improvement):
        """Decide whether validation has stopped improving.

        Compares the best of the last ``patience`` values with the best of
        all *earlier* values. (The previous implementation compared against
        the overall best, which already includes the recent window. Its
        "improvement" could therefore never be positive.)

        Parameters
        ----------
        validation_values : sequence of float
            Validation scores in trial-completion order (higher is better).
        patience : int
            Size of the recent window.
        min_improvement : float
            Required relative improvement of the recent best over the
            earlier best.

        Returns
        -------
        bool
            True when the recent window improves by less than
            ``min_improvement``. Returns False when there is not yet an
            earlier baseline or that baseline is not finite.
        """
        if patience <= 0 or len(validation_values) <= patience:
            return False
        best_before = max(validation_values[:-patience])
        recent_best = max(validation_values[-patience:])
        # A -inf/NaN baseline (e.g. missing validation_value) gives no
        # meaningful relative change; don't stop on it.
        if not np.isfinite(best_before):
            return False
        # Avoid dividing by zero when the earlier best is exactly 0.
        scale = abs(best_before) if best_before != 0 else 1.0
        relative_improvement = (recent_best - best_before) / scale
        return bool(relative_improvement < min_improvement)

    def early_stopping_callback(self, study, trial):
        """Stop the study once validation performance has plateaued.

        Uses the ``validation_value`` user attribute of completed trials
        (missing values count as ``-inf``). Stops the study when the best of
        the last ``patience`` completed trials does not improve on the best
        of the earlier ones by more than ``min_improvement`` (relative).
        """
        early_stopping = self.optuna_settings["early_stopping"]
        if not early_stopping["enabled"]:
            return

        patience = early_stopping["patience"]
        min_improvement = early_stopping["min_improvement"]

        completed_trials = [
            t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
        ]
        validation_values = [
            t.user_attrs.get("validation_value", float("-inf"))
            for t in completed_trials
        ]

        if self._validation_plateaued(validation_values, patience, min_improvement):
            self.logger.info(
                f"Stopping study: No validation improvement > {min_improvement} "
                f"in last {patience} trials"
            )
            study.stop()

    def save_results(self):
        """Persist the study summary and per-trial metrics to disk.

        Creates ``output_dir/study_<timestamp>/`` and writes two files:

        * ``study_results.json``: best value(s)/params and trial counts
          (plus the best balanced Pareto solution in multi-objective mode).
        * ``trial_data.json``: one entry per completed trial with its
          params, objective value(s) and recorded train/validation user
          attributes. Missing attributes are stored as ``-inf``.

        Logs a warning and returns without writing anything if no study
        has been set up.
        """
        if self.study is None:
            self.logger.warning("save_results called before setup_study; nothing to save")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        study_dir = self.output_dir / f"study_{timestamp}"
        study_dir.mkdir(parents=True, exist_ok=True)

        # Check if any trials completed
        completed_trials = [
            t for t in self.study.trials
            if t.state == optuna.trial.TrialState.COMPLETE
        ]
        has_completed_trials = len(completed_trials) > 0

        # Save study statistics
        if self.multi_objective:
            if has_completed_trials:
                pareto_front_trials = self.study.best_trials  # all non-dominated trials
                best_balanced_params, best_balanced_values, best_balanced_idx = (
                    get_best_balanced_solution(self.study)
                )
                stats = {
                    "best_params": [trial.params for trial in pareto_front_trials],
                    "best_values": [trial.values for trial in pareto_front_trials],
                    "n_trials": len(self.study.trials),
                    "n_completed_trials": len(completed_trials),
                    "datetime": timestamp,
                    "run_fingerprint": self.run_fingerprint,
                    "best_balanced_params": best_balanced_params,
                    "best_balanced_values": best_balanced_values,
                    "best_balanced_idx": int(best_balanced_idx),
                }
            else:
                stats = {
                    "best_params": None,
                    "best_values": None,
                    "n_trials": len(self.study.trials),
                    "n_completed_trials": 0,
                    "datetime": timestamp,
                    "run_fingerprint": self.run_fingerprint,
                    "error": "No trials completed successfully",
                }
        else:
            if has_completed_trials:
                stats = {
                    "best_value": float(self.study.best_value),  # Python float for JSON
                    "best_params": {
                        k: float(v) for k, v in self.study.best_params.items()
                    },
                    "n_trials": len(self.study.trials),
                    "n_completed_trials": len(completed_trials),
                    "datetime": timestamp,
                    "run_fingerprint": self.run_fingerprint,
                }
            else:
                stats = {
                    "best_value": None,
                    "best_params": None,
                    "n_trials": len(self.study.trials),
                    "n_completed_trials": 0,
                    "datetime": timestamp,
                    "run_fingerprint": self.run_fingerprint,
                    "error": "No trials completed successfully",
                }

        with open(study_dir / "study_results.json", "w") as f:
            json.dump(stats, f, indent=2)

        # Save trial data with validation metrics
        trial_data = []
        for trial in completed_trials:
            trial_data_entry = {
                "number": trial.number,
                "datetime_start": trial.datetime_start.isoformat(),
                "datetime_complete": trial.datetime_complete.isoformat(),
                "params": {k: float(v) for k, v in trial.params.items()},
            }

            # Handle multi-objective values
            if self.multi_objective:
                trial_data_entry.update(
                    {
                        "mean_return": float(trial.values[0]),
                        "worst_case": float(trial.values[1]),
                        "stability": float(trial.values[2]),
                    }
                )
            else:
                trial_data_entry["train_value"] = float(trial.value)

            # Add user attributes
            for attr in [
                "validation_value",
                "validation_returns_over_hodl",
                "validation_returns_over_uniform_hodl",
                "validation_sharpe",
                "validation_return",
                "train_returns_over_hodl",
                "train_returns_over_uniform_hodl",
                "train_sharpe",
                "train_return",
            ]:
                trial_data_entry[attr] = float(
                    trial.user_attrs.get(attr, float("-inf"))
                )

            trial_data.append(trial_data_entry)

        with open(study_dir / "trial_data.json", "w") as f:
            json.dump(trial_data, f, indent=2)

    def optimize(self, objective):
        """Run the optimization process with error handling and parallel execution.

        Delegates to ``study.optimize`` with the configured number of trials,
        timeout, parallel jobs, and early-stopping callback. All exceptions
        are caught (logged, not re-raised) so that partial results are always
        saved via ``save_results``.

        Parameters
        ----------
        objective : callable
            Optuna objective function: ``trial -> float`` (single-objective)
            or ``trial -> tuple[float, ...]`` (multi-objective).
        """
        try:
            self.study.optimize(
                objective,
                n_trials=self.optuna_settings["n_trials"],
                timeout=self.optuna_settings["timeout"],
                n_jobs=self.optuna_settings["n_jobs"],
                callbacks=[self.early_stopping_callback],
                catch=(Exception,),
            )
        except KeyboardInterrupt:
            self.logger.info("Optimization interrupted by user")
        except Exception as e:
            self.logger.error(f"Optimization failed: {str(e)}")
        finally:
            self.save_results()
class Hashabledict(dict):
    """A hashable dictionary class that enables using dictionaries as dictionary keys.

    The hash is computed from a tuple of ``(key, value)`` pairs ordered by
    key. Nested lists are converted to tuples and nested dicts to sorted
    tuples of pairs, so such values do not break hashing.

    Methods
    -------
    __key()
        Returns a tuple of sorted key-value pairs representing the dictionary.
    __hash__()
        Returns an integer hash value for the dictionary.
    __eq__(other)
        Compares content with any dict (plain dicts are converted first).
        Returns ``NotImplemented`` for non-dicts, so e.g. ``d == 5`` is
        ``False`` instead of raising ``AttributeError``.

    Examples
    --------
    >>> d1 = Hashabledict({'a': 1, 'b': 2})
    >>> d2 = Hashabledict({'b': 2, 'a': 1})
    >>> hash(d1) == hash(d2)
    True
    >>> d1 == d2
    True
    >>> d1 == {'a': 1, 'b': 2}
    True
    >>> d3 = {d1: 'value'}  # Can use as dictionary key
    """

    def __key(self):
        def make_hashable(v):
            if isinstance(v, list):
                return tuple(make_hashable(x) for x in v)
            elif isinstance(v, dict):
                return tuple(sorted((k, make_hashable(val)) for k, val in v.items()))
            return v

        return tuple((k, make_hashable(self[k])) for k in sorted(self))

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        # The private __key is name-mangled, so a plain dict (or any other
        # object) does not have it; convert dicts and defer on anything else.
        if not isinstance(other, dict):
            return NotImplemented
        if not isinstance(other, Hashabledict):
            other = Hashabledict(other)
        return self.__key() == other.__key()
class NestedHashabledict(dict):
    """Dictionary that can be hashed even when it contains nested containers.

    On construction, every nested ``dict`` value, and every ``dict``
    directly inside a ``list`` value, is rewrapped as a NestedHashabledict.
    Hashing and equality use a frozen view in which lists become tuples and
    dicts become key-sorted tuples of pairs. Any ``dict`` compares by that
    frozen view; non-dicts never compare equal.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Rewrap nested containers so the whole tree is hashable-aware.
        for key, value in list(self.items()):
            if isinstance(value, dict):
                self[key] = NestedHashabledict(value)
            elif isinstance(value, list):
                self[key] = [
                    NestedHashabledict(entry) if isinstance(entry, dict) else entry
                    for entry in value
                ]

    def __key(self):
        def freeze(obj):
            if isinstance(obj, list):
                return tuple(freeze(entry) for entry in obj)
            if isinstance(obj, dict):
                return tuple(sorted((k, freeze(v)) for k, v in obj.items()))
            return obj

        return tuple((k, freeze(v)) for k, v in sorted(self.items()))

    def __hash__(self):
        try:
            return hash(self.__key())
        except TypeError as err:
            # Report which top-level entries fail to hash, to ease debugging.
            for k, v in self.items():
                try:
                    hash((k, v))
                except TypeError:
                    print(f"Unhashable item found - Key: {k}, Value type: {type(v)}")
            raise err

    def __eq__(self, other):
        return isinstance(other, dict) and (
            self.__key() == NestedHashabledict(other).__key()
        )
# Fields that are only used during training setup, not in forward passes # These are excluded when creating static_dict from run_fingerprint _TRAINING_ONLY_FIELDS = frozenset({ "optimisation_settings", # Contains lr, optimizer, etc. "startDateString", # Data loading dates "endDateString", "endTestDateString", "subsidary_pools", # Handled separately "bout_offset", # Training sampling config "freq", # Data frequency string })
def get_sig_variations(n_assets: int) -> tuple:
    """
    Compute signature variations for arbitrage.

    Returns all possible (asset_in, asset_out) pairs encoded as a tuple of
    tuples. Each inner tuple has exactly one +1 (asset out) and one -1
    (asset in), with zeros elsewhere.

    The n*(n-1) valid vectors are built directly, rather than filtering all
    3**n sign vectors (which is exponential in ``n_assets``). They are
    returned in the same order as ``itertools.product([1, 0, -1],
    repeat=n_assets)`` would yield them, i.e. lexicographic with
    1 < 0 < -1 at every position. Elements are NumPy integers.

    Parameters
    ----------
    n_assets : int
        Number of assets in the pool.

    Returns
    -------
    tuple
        Tuple of tuples representing valid arbitrage directions.
        Each inner tuple has length n_assets with values in {-1, 0, 1}.
        Empty when ``n_assets < 2``.

    Example
    -------
    >>> get_sig_variations(3)
    ((1, 0, -1), (1, -1, 0), (0, 1, -1), (0, -1, 1), (-1, 1, 0), (-1, 0, 1))
    """
    # Rank of each value in product([1, 0, -1]) order.
    rank = {1: 0, 0: 1, -1: 2}

    rows = []
    for out_idx in range(n_assets):
        for in_idx in range(n_assets):
            if out_idx == in_idx:
                continue
            row = [0] * n_assets
            row[out_idx] = 1
            row[in_idx] = -1
            rows.append(tuple(row))

    # Reproduce the historical enumeration order exactly.
    rows.sort(key=lambda r: tuple(rank[v] for v in r))

    # Round-trip through an array so elements keep NumPy's default int type,
    # matching the previous implementation's output.
    return tuple(map(tuple, np.array(rows)))
def create_static_dict(
    run_fingerprint: dict,
    bout_length: int,
    all_sig_variations: list = None,
    overrides: dict = None,
) -> NestedHashabledict:
    """Build the hashable ``static_dict`` consumed by forward passes.

    Starts from ``run_fingerprint`` minus the keys in
    ``_TRAINING_ONLY_FIELDS``, then:

    1. converts ``tokens`` (if present) to a tuple;
    2. sets ``bout_length``;
    3. sets ``all_sig_variations``, either from the argument or, if that is
       None, computed from ``n_assets`` (an override wins over the
       fingerprint value). It is skipped if neither provides ``n_assets``;
    4. defaults ``run_type`` to ``"normal"``;
    5. applies ``overrides`` last, so they win over everything above.

    Parameters
    ----------
    run_fingerprint : dict
        The full run configuration dictionary.
    bout_length : int
        Bout length to use (varies between train/test).
    all_sig_variations : list, optional
        Pre-computed signature variations for arbitrage.
    overrides : dict, optional
        Additional key-value pairs to override/add.

    Returns
    -------
    NestedHashabledict
        Hashable static dictionary for use in JAX forward passes.

    Example
    -------
    >>> static_dict = create_static_dict(run_fingerprint, bout_length=10080)
    """
    static = {}
    for name, val in run_fingerprint.items():
        if name in _TRAINING_ONLY_FIELDS:
            continue
        static[name] = val

    if "tokens" in static:
        static["tokens"] = tuple(static["tokens"])

    static["bout_length"] = bout_length

    override_map = overrides or {}
    if all_sig_variations is not None:
        static["all_sig_variations"] = all_sig_variations
    else:
        if "n_assets" in override_map:
            n_assets = override_map["n_assets"]
        else:
            n_assets = static.get("n_assets")
        if n_assets is not None:
            static["all_sig_variations"] = get_sig_variations(n_assets)

    static.setdefault("run_type", "normal")

    if overrides:
        static.update(overrides)

    return NestedHashabledict(static)
class HashableArrayWrapper(Generic[T]):
    """Wrap an array so it can be hashed and compared by value.

    Attribute access (anything other than ``val``, ``__hash__`` and
    ``__eq__``) and item get/set are forwarded to the wrapped array.

    The hash covers shape, dtype and raw bytes. Two wrappers compare equal
    only if all three match. Previously only the bytes were hashed, so
    e.g. ``zeros(4, int32)`` and ``zeros(2, int64)`` were considered equal.
    Comparison against a non-wrapper delegates to the array's own ``==``.
    """

    def __init__(self, val: T):
        self.val = val

    def __getattribute__(self, prop):
        if prop == "val" or prop == "__hash__" or prop == "__eq__":
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val

    def __hash__(self):
        arr = self.val
        return hash((tuple(arr.shape), str(arr.dtype), arr.tobytes()))

    def __eq__(self, other):
        if isinstance(other, HashableArrayWrapper):
            mine, theirs = self.val, other.val
            return bool(
                tuple(mine.shape) == tuple(theirs.shape)
                and str(mine.dtype) == str(theirs.dtype)
                and mine.tobytes() == theirs.tobytes()
            )
        # The old code called self.val.__eq__(self, other), which passes an
        # extra argument and raises TypeError; compare the array directly.
        return self.val == other
def get_run_location(run_fingerprint):
    """Derive a deterministic run identifier from a run fingerprint.

    The fingerprint is serialised to JSON with sorted keys (so key order
    does not matter), UTF-8 encoded and hashed with SHA-256. The hash is
    flagged as not used for security.

    Parameters
    ----------
    run_fingerprint : dict
        JSON-serialisable run configuration (dates, tokens, rule, ...).

    Returns
    -------
    str
        ``"run_"`` followed by the hex SHA-256 digest.

    Examples
    --------
    >>> get_run_location({"startDate": "2023-01-01", "tokens": ["BTC", "ETH"]})[:4]
    'run_'
    """
    payload = json.dumps(run_fingerprint, sort_keys=True).encode("utf-8")
    digest = hashlib.sha256(payload, usedforsecurity=False).hexdigest()
    return f"run_{digest}"
def nan_rollback(grads, params, old_params):
    """Revert parameter sets whose watched gradients contain NaNs.

    For each of ``"log_k"`` and ``"logit_lamb"`` present in ``grads``, a
    mask is built that is True for every row (all axes but the last) whose
    gradient contains at least one NaN. In *every* leaf of ``params``, those
    rows are replaced by the matching values from ``old_params``. So a NaN
    in one watched gradient rolls back the whole parameter set, not only
    the offending key.

    Parameters
    ----------
    grads : dict
        Current gradients.
    params : dict
        Current parameter values (pytree).
    old_params : dict
        Previous parameter values with the same structure as ``params``.

    Returns
    -------
    dict
        Parameters with NaN-affected rows rolled back. ``params`` is
        returned unchanged if neither watched key is in ``grads``.

    Examples
    --------
    >>> grads = {"log_k": jnp.array([[1.0, jnp.nan], [3.0, 4.0]])}
    >>> params = {"log_k": jnp.array([[0.1, 0.2], [0.3, 0.4]])}
    >>> old_params = {"log_k": jnp.array([[0.05, 0.15], [0.25, 0.35]])}
    >>> rolled_back = nan_rollback(grads, params, old_params)
    """
    for watched in ("log_k", "logit_lamb"):
        if watched not in grads:
            continue
        # keepdims so the (rows, 1) mask broadcasts over each row's entries
        row_has_nan = jnp.any(jnp.isnan(grads[watched]), axis=-1, keepdims=True)
        params = tree_map(
            lambda new, old: jnp.where(row_has_nan, old, new), params, old_params
        )
    return params
@jit
def has_nan_grads(grad_tree):
    """Report whether any leaf of a gradient pytree contains a NaN.

    JIT-compiled. Folds over the leaves with ``tree_reduce``, OR-ing a
    per-leaf "contains NaN" flag into an accumulator that starts at False.

    Parameters
    ----------
    grad_tree : pytree
        JAX pytree of gradient arrays.

    Returns
    -------
    jnp.ndarray
        Scalar boolean: True if any gradient leaf contains a NaN.
    """

    def _fold(seen_nan, leaf):
        return jnp.logical_or(seen_nan, jnp.any(jnp.isnan(leaf)))

    return tree_reduce(_fold, grad_tree, initializer=False)
def has_nan_params(params):
    """Tell whether any learnable parameter array holds a NaN.

    The keys ``initial_weights``, ``initial_weights_logits`` and
    ``subsidary_params`` are ignored, as are values without a ``shape``
    attribute.

    Parameters
    ----------
    params : dict
        Parameter dict, typically arrays of shape
        ``(n_parameter_sets, n_assets)``.

    Returns
    -------
    bool
        True as soon as one learnable array contains a NaN, else False.
    """
    skipped = ("initial_weights", "initial_weights_logits", "subsidary_params")
    return any(
        hasattr(params[name], "shape") and bool(jnp.any(jnp.isnan(params[name])))
        for name in params
        if name not in skipped
    )
def nan_param_reinit(
    params, grads, pool, initial_params, run_fingerprint, n_tokens, n_parameter_sets
):
    """Reinitialize parameter sets that contain NaN values.

    During training, parameters can become NaN from bad update steps even
    when gradients were finite (e.g., large learning rate + steep
    curvature). This function finds every parameter set (index along the
    leading axis) in which any learnable array has a NaN. It replaces all
    learnable arrays for those sets with freshly initialized values from
    ``pool.init_parameters``. Healthy sets are left untouched.

    Learnable keys are those other than ``initial_weights``,
    ``initial_weights_logits`` and ``subsidary_params`` that have a
    non-scalar ``shape``. The number of sets is taken from the leading axis
    of these learnable arrays. (It used to come from ``len()`` of whichever
    value happened to come first in the dict, e.g. ``initial_weights``,
    which could skip sets.) ``pool.init_parameters`` is only called when at
    least one set needs replacing.

    The arrays are JAX arrays updated via ``.at[i].set``. Replacements are
    written back into the passed ``params`` dict, which is also returned.

    Parameters
    ----------
    params : dict
        Current parameter dict with arrays of shape ``(n_parameter_sets, ...)``.
    grads : dict
        Current gradient dict (unused directly, but passed for API consistency).
    pool : BaseTFMMPool
        Pool instance, used to call ``init_parameters`` for replacement values.
    initial_params : dict
        Initial values dict passed to ``pool.init_parameters``.
    run_fingerprint : dict
        Run configuration.
    n_tokens : int
        Number of assets in the pool.
    n_parameter_sets : int
        Number of parallel parameter sets.

    Returns
    -------
    dict
        Parameter dict with NaN-contaminated sets replaced by fresh initializations.
    """
    frozen_keys = ("initial_weights", "initial_weights_logits", "subsidary_params")
    learnable = [
        key
        for key in params
        if key not in frozen_keys
        and hasattr(params[key], "shape")
        and len(params[key].shape) > 0
    ]
    if not learnable:
        return params

    n_sets = params[learnable[0]].shape[0]
    bad_sets = [
        i
        for i in range(n_sets)
        if any(bool(jnp.any(jnp.isnan(params[key][i]))) for key in learnable)
    ]
    if not bad_sets:
        return params

    fresh = pool.init_parameters(
        initial_params, run_fingerprint, n_tokens, n_parameter_sets
    )
    for i in bad_sets:
        # Replace every learnable array for this set, not just the NaN one,
        # so the set stays internally consistent.
        for key in learnable:
            params[key] = params[key].at[i].set(fresh[key][i])
    return params
def get_unique_tokens(run_fingerprint):
    """Gets unique tokens from run fingerprint including subsidiary pools.

    Collects the tokens of the main pool (``run_fingerprint["tokens"]``) and
    of each entry in ``run_fingerprint["subsidary_pools"]``. It removes
    duplicates and returns them sorted. A missing or ``None``
    ``subsidary_pools`` entry is treated as "no subsidiary pools".

    Parameters
    ----------
    run_fingerprint : dict
        Dictionary containing run configuration including tokens and subsidiary pools

    Returns
    -------
    list
        Sorted list of unique token symbols

    Examples
    --------
    >>> fingerprint = {
    ...     "tokens": ["BTC", "ETH"],
    ...     "subsidary_pools": [{"tokens": ["ETH", "DAI"]}]
    ... }
    >>> get_unique_tokens(fingerprint)
    ['BTC', 'DAI', 'ETH']
    """
    subsidiary_pools = run_fingerprint.get("subsidary_pools") or []
    token_lists = [run_fingerprint["tokens"]] + [
        pool["tokens"] for pool in subsidiary_pools
    ]
    return sorted({token for tokens in token_lists for token in tokens})
def split_list(lst, num_splits):
    """Cut a sequence into ``num_splits`` consecutive, near-equal slices.

    Each slice has ``len(lst) // num_splits`` items. The first
    ``len(lst) % num_splits`` slices get one extra item each. If
    ``num_splits`` exceeds the length, the trailing slices are empty.

    Parameters
    ----------
    lst : list
        The input list (any sliceable sequence) to split
    num_splits : int
        Number of sublists to create

    Returns
    -------
    list
        List of sublists

    Examples
    --------
    >>> split_list([1,2,3,4,5], 2)
    [[1, 2, 3], [4, 5]]
    >>> split_list([1,2,3,4,5,6], 3)
    [[1, 2], [3, 4], [5, 6]]
    """
    base, extra = divmod(len(lst), num_splits)

    # Cumulative cut points; the first `extra` chunks are one item longer.
    cuts = [0]
    for chunk_no in range(num_splits):
        cuts.append(cuts[-1] + base + (1 if chunk_no < extra else 0))

    return [lst[cuts[k]:cuts[k + 1]] for k in range(num_splits)]
def invert_permutation(perm):
    """
    Return the inverse of an index permutation.

    If ``perm[i] == j`` then the result satisfies ``result[j] == i``. The
    output is a 1-D array of length ``perm.size`` with ``perm``'s dtype.

    Parameters
    ----------
    perm : numpy.ndarray
        Array representing a permutation of indices

    Returns
    -------
    numpy.ndarray
        The inverse permutation array

    Examples
    --------
    >>> perm = np.array([2,0,1])
    >>> invert_permutation(perm)
    array([1, 2, 0])
    """
    inverse = np.zeros(perm.size, dtype=perm.dtype)
    inverse[perm] = np.arange(perm.size)
    return inverse
def permute_list_of_params(list_of_params, seed=0):
    """
    Shuffle a list deterministically using NumPy's global RNG.

    Reseeds the *global* NumPy RNG with ``seed`` (a side effect that
    persists after the call), draws ``np.random.permutation(len(list))``
    and returns the elements in that order. The input is not modified.

    Parameters
    ----------
    list_of_params : list
        The list of parameters to permute
    seed : int, optional
        Random seed to use for reproducible permutations (default: 0)

    Returns
    -------
    list
        New list with the same elements in permuted order.

    Examples
    --------
    >>> params = [1, 2, 3, 4]
    >>> permute_list_of_params(params, seed=42) == permute_list_of_params(params, seed=42)
    True
    """
    np.random.seed(seed)
    order = np.random.permutation(len(list_of_params))
    return list(map(list_of_params.__getitem__, order))
def unpermute_list_of_params(list_of_params, seed=None):
    """
    Restore the original order of a previously permuted list of parameters.

    Undoes ``permute_list_of_params(lst, seed)`` by regenerating the same
    permutation and applying its inverse.

    ``permute_list_of_params`` leaves the global NumPy RNG *advanced* past
    its draw. So with the default ``seed=None``, this function only inverts
    correctly if the caller has re-seeded the global RNG with the same seed
    beforehand. Pass ``seed`` explicitly to make the inversion
    self-contained. That path uses a private ``RandomState(seed)``, which
    produces the same permutation as seeding the global RNG, and leaves the
    global state untouched.

    Parameters
    ----------
    list_of_params : list
        The permuted list of parameters to restore to original order
    seed : int, optional
        Seed originally passed to ``permute_list_of_params``. If None
        (default), the current global NumPy RNG state is used, as before.

    Returns
    -------
    list
        A new list containing the same elements restored to their
        original order

    Examples
    --------
    >>> params = [1, 2, 3, 4]
    >>> permuted = permute_list_of_params(params, seed=42)
    >>> unpermute_list_of_params(permuted, seed=42)
    [1, 2, 3, 4]
    """
    n = len(list_of_params)
    if seed is None:
        idx = np.random.permutation(n)
    else:
        idx = np.random.RandomState(seed).permutation(n)
    # For a permutation, argsort yields its inverse.
    idx_unpermute = np.argsort(idx)
    return [list_of_params[i] for i in idx_unpermute]
def get_trades_and_fees(
    run_fingerprint,
    raw_trades,
    fees_df,
    gas_cost_df,
    arb_fees_df,
    lp_supply_df,
    do_test_period=False,
):
    """
    Process trade and fee data for a simulation run.

    Takes raw trades, fees, gas costs, arbitrage fees and LP supply and
    converts them into arrays suitable for simulation. The training window
    is ``startDateString`` -> ``endDateString``. When ``do_test_period`` is
    set, every series is also cut for the test window ``endDateString`` ->
    ``endTestDateString``. (Previously ``test_fees_array`` was wrongly cut
    over the training window.) Inputs that are None yield None outputs.

    Parameters
    ----------
    run_fingerprint : dict
        Dictionary containing run configuration including start/end dates and tokens
    raw_trades : pd.DataFrame, optional
        DataFrame containing raw trade data
    fees_df : pd.DataFrame, optional
        DataFrame containing fee data
    gas_cost_df : pd.DataFrame, optional
        DataFrame containing gas cost data
    arb_fees_df : pd.DataFrame, optional
        DataFrame containing arbitrage fee data
    lp_supply_df : pd.DataFrame, optional
        DataFrame containing LP supply data
    do_test_period : bool, optional
        Whether to process data for a test period after training period (default False)

    Returns
    -------
    dict
        Always contains ``train_period_trades``, ``fees_array``,
        ``gas_cost_array``, ``arb_fees_array`` and ``lp_supply_array``.
        With ``do_test_period`` it also contains ``test_period_trades`` and
        the ``test_``-prefixed version of each array.
    """

    def _fee_like(df, column, start_key, end_key):
        # Forward-filled fee-like array over one date window, or None.
        if df is None:
            return None
        return raw_fee_like_amounts_to_fee_like_array(
            df,
            run_fingerprint[start_key],
            run_fingerprint[end_key],
            names=[column],
            fill_method="ffill",
        )

    train_period_trades = None
    test_period_trades = None
    if raw_trades is not None:
        tokens = get_unique_tokens(run_fingerprint)
        train_period_trades = raw_trades_to_trade_array(
            raw_trades,
            start_date_string=run_fingerprint["startDateString"],
            end_date_string=run_fingerprint["endDateString"],
            tokens=tokens,
        )
        if do_test_period:
            test_period_trades = raw_trades_to_trade_array(
                raw_trades,
                start_date_string=run_fingerprint["endDateString"],
                end_date_string=run_fingerprint["endTestDateString"],
                tokens=tokens,
            )

    # (output key, source frame, column name) for every fee-like series.
    fee_like_sources = (
        ("fees_array", fees_df, "fees"),
        ("gas_cost_array", gas_cost_df, "trade_gas_cost_usd"),
        ("arb_fees_array", arb_fees_df, "arb_fees"),
        ("lp_supply_array", lp_supply_df, "lp_supply"),
    )

    train_arrays = {
        out_key: _fee_like(df, column, "startDateString", "endDateString")
        for out_key, df, column in fee_like_sources
    }

    if not do_test_period:
        return {"train_period_trades": train_period_trades, **train_arrays}

    test_arrays = {
        "test_" + out_key: _fee_like(df, column, "endDateString", "endTestDateString")
        for out_key, df, column in fee_like_sources
    }
    return {
        "train_period_trades": train_period_trades,
        "test_period_trades": test_period_trades,
        **train_arrays,
        **test_arrays,
    }
def create_daily_unix_array(start_date_str, end_date_str):
    """
    Build daily Unix timestamps, in milliseconds, spanning two dates.

    Args:
        start_date_str (str): First date, e.g. 'YYYY-MM-DD HH:MM:SS'.
        end_date_str (str): Last date, same format.

    Returns:
        list: Python ints, one per day on the daily grid anchored at
        ``start_date_str``. ``pd.date_range`` includes the end point when it
        lies on that grid, so ``end_date_str`` is included in that case.
    """
    days = pd.date_range(
        start=start_date_str, end=pd.to_datetime(end_date_str), freq="D"
    )
    # Integer epoch values (nanoseconds at pandas' default resolution) -> ms.
    epoch_values = days.view("int64")
    return (epoch_values // 10**6).tolist()
def create_time_step(row, unix_values, tokens, index):
    """
    Build a ``SimulationResultTimestepDto`` for one row of simulation output.

    Args:
        row (pd.Series): Row whose ``"prices"``, ``"reserves"`` and
            ``"weights"`` sub-entries are indexed by token position.
        unix_values (list): Unix timestamps in milliseconds.
        tokens (list): Token symbols, in the same order as the row columns.
        index (int): Position of this row; selects the timestamp.

    Returns:
        SimulationResultTimestepDto: Timestamp plus one ``LiquidityPoolCoinDto``
        per token, with market value = price * amount.
    """

    def _coin_at(position, symbol):
        # One coin entry populated from the row's per-token columns.
        holding = LiquidityPoolCoinDto()
        holding.coinCode = symbol
        holding.currentPrice = row["prices"][position].item()
        holding.amount = row["reserves"][position].item()
        holding.weight = row["weights"][position].item()
        holding.marketValue = holding.currentPrice * holding.amount
        return holding

    time_step = SimulationResultTimestepDto(unix_values[index], [], 0)
    for position, symbol in enumerate(tokens):
        time_step.coinsHeld.append(_coin_at(position, symbol))
    return time_step
def optimized_output_conversion(simulationRunDto, outputDict, tokens):
    """
    Convert simulation output into a list of SimulationResultTimestepDto objects.

    Args:
        simulationRunDto (SimulationRunDto): Provides ``startDateString`` and
            ``endDateString``.
        outputDict (dict): Simulation output arrays ``"prices"``,
            ``"reserves"`` (both 2-D, rows x tokens) and ``"value"`` (1-D).
        tokens (list): Token symbols used in the simulation.

    Returns:
        list: One SimulationResultTimestepDto per daily row. The list is empty
        when there are no rows.

    Raises:
        ValueError: If the number of daily timestamps differs from the number
        of downsampled simulation rows.

    Steps:
        1. Create Unix timestamps for each day between start and end dates.
        2. Downsample simulation rows by taking every 1440th row (minute ->
           daily, assuming minute-resolution output).
        3. Compute empirical weights as reserves * prices / value.
        4. Build one timestep DTO per daily row.
    """
    logger = logging.getLogger(__name__)
    logger.debug(
        "Converting simulation output: start=%s end=%s tokens=%s",
        simulationRunDto.startDateString,
        simulationRunDto.endDateString,
        tokens,
    )

    unix_values = create_daily_unix_array(
        simulationRunDto.startDateString, simulationRunDto.endDateString
    )

    # Downsample before any arithmetic so weights are computed only for kept rows.
    rows_per_day = 1440
    prices = outputDict["prices"][::rows_per_day]
    reserves = outputDict["reserves"][::rows_per_day]
    values = outputDict["value"][::rows_per_day]
    # Empirical weights rather than calculated weights: calculated weights are
    # not returned in outputDict because they do not exist for every pool type.
    weights = reserves * prices / values[:, np.newaxis]

    combined_df = pd.concat(
        [pd.DataFrame(prices), pd.DataFrame(reserves), pd.DataFrame(weights)],
        axis=1,
        keys=["prices", "reserves", "weights"],
    )
    logger.debug(
        "Daily rows: %d, unix timestamps: %d", len(combined_df), len(unix_values)
    )

    if len(unix_values) != len(combined_df):
        raise ValueError(
            "The length of unix_values does not match the number of rows in "
            f"combined_df ({len(unix_values)} != {len(combined_df)})"
        )

    combined_df = combined_df.reset_index(drop=True)
    # A comprehension (unlike DataFrame.apply(...).tolist()) also handles the
    # zero-row case, returning an empty list.
    return [
        create_time_step(row, unix_values, tokens, position)
        for position, (_, row) in enumerate(combined_df.iterrows())
    ]
# ============================================================================= # Memory Probing Utilities # =============================================================================
def probe_max_n_parameter_sets(
    run_fingerprint: dict,
    min_sets: int = 1,
    max_sets: int = 64,
    safety_margin: float = 0.9,
    verbose: bool = True,
) -> dict:
    """
    Probe to find the maximum n_parameter_sets that fits in GPU memory.

    Uses binary search to find the largest n_parameter_sets value that can
    complete a forward pass without OOM. Returns a dict with the recommended
    value and diagnostic info.

    Parameters
    ----------
    run_fingerprint : dict
        The run fingerprint configuration. A deep copy is probed; the
        caller's dict is not modified.
    min_sets : int
        Minimum n_parameter_sets to try (default 1).
    max_sets : int
        Maximum n_parameter_sets to try (default 64).
    safety_margin : float
        Fraction of the maximum found to use as the recommendation
        (default 0.9). This leaves headroom for gradient computation, which
        uses more memory.
    verbose : bool
        Whether to print progress information.

    Returns
    -------
    dict
        Keys: ``max_n_parameter_sets`` (int),
        ``recommended_n_parameter_sets`` (int, with safety margin applied),
        ``probed_values`` (list), ``success_values`` (list),
        ``failed_values`` (list).

    Notes
    -----
    - JAX caches are cleared between attempts.
    - Each attempt blocks on the forward-pass outputs, so errors raised while
      the device computation runs are caught inside the probe.
    - Errors whose message mentions "resource", "memory" or "oom" count as
      OOM. Other errors are printed (when verbose) and also counted as a
      failed attempt.
    - If no attempt succeeds, ``max_n_parameter_sets`` falls back to
      ``min_sets`` and a warning is logged.
    - The forward pass (without gradients) is used for probing, so gradient
      computation may require ~2x more memory. Hence the safety_margin.
    """
    import gc
    from copy import deepcopy
    from jax import clear_caches, jit, vmap
    from jax.tree_util import Partial, tree_leaves

    from quantammsim.runners.default_run_fingerprint import run_fingerprint_defaults
    from quantammsim.core_simulator.param_utils import recursive_default_set
    from quantammsim.pools.creator import create_pool
    from quantammsim.utils.data_processing.historic_data_utils import get_data_dict
    from quantammsim.core_simulator.forward_pass import forward_pass_nograd

    # Work with a copy to avoid side effects on the caller's fingerprint.
    probe_fingerprint = deepcopy(run_fingerprint)
    recursive_default_set(probe_fingerprint, run_fingerprint_defaults)

    probed_values = []
    success_values = []
    failed_values = []

    def try_forward_pass(n_sets: int) -> bool:
        """Attempt a forward pass with n_sets parameter sets. Returns True if successful."""
        probe_fingerprint["optimisation_settings"]["n_parameter_sets"] = n_sets

        try:
            unique_tokens = get_unique_tokens(probe_fingerprint)
            n_tokens = len(unique_tokens)

            data_dict = get_data_dict(
                unique_tokens,
                probe_fingerprint,
                data_kind=probe_fingerprint["optimisation_settings"]["training_data_kind"],
                max_memory_days=probe_fingerprint["max_memory_days"],
                start_date_string=probe_fingerprint["startDateString"],
                end_time_string=probe_fingerprint["endDateString"],
                start_time_test_string=probe_fingerprint["endDateString"],
                end_time_test_string=probe_fingerprint.get("endTestDateString"),
                do_test_period=False,
            )

            bout_length_window = data_dict["bout_length"] - probe_fingerprint["bout_offset"]
            if bout_length_window <= 0:
                bout_length_window = data_dict["bout_length"] // 2

            rule = probe_fingerprint["rule"]
            pool = create_pool(rule)

            learnable_bounds = probe_fingerprint.get("learnable_bounds_settings", {})
            initial_params = {
                "initial_memory_length": probe_fingerprint["initial_memory_length"],
                "initial_memory_length_delta": probe_fingerprint["initial_memory_length_delta"],
                "initial_k_per_day": probe_fingerprint["initial_k_per_day"],
                "initial_weights_logits": probe_fingerprint["initial_weights_logits"],
                "initial_log_amplitude": probe_fingerprint["initial_log_amplitude"],
                "initial_raw_width": probe_fingerprint["initial_raw_width"],
                "initial_raw_exponents": probe_fingerprint["initial_raw_exponents"],
                "initial_pre_exp_scaling": probe_fingerprint["initial_pre_exp_scaling"],
                "min_weights_per_asset": learnable_bounds.get("min_weights_per_asset"),
                "max_weights_per_asset": learnable_bounds.get("max_weights_per_asset"),
            }

            params = pool.init_parameters(initial_params, probe_fingerprint, n_tokens, n_sets)
            params_in_axes_dict = pool.make_vmap_in_axes(params)

            # all_sig_variations is auto-computed from n_assets inside the helper.
            static_dict = create_static_dict(
                probe_fingerprint,
                bout_length=bout_length_window,
                overrides={
                    "n_assets": n_tokens,
                    "training_data_kind": probe_fingerprint["optimisation_settings"]["training_data_kind"],
                    "do_trades": False,
                },
            )

            partial_forward = Partial(
                forward_pass_nograd,
                prices=data_dict["prices"],
                static_dict=static_dict,
                pool=pool,
            )
            vmapped_forward = jit(
                vmap(partial_forward, in_axes=[params_in_axes_dict, None, None])
            )

            start_index = (data_dict["start_idx"], 0)
            outputs = vmapped_forward(params, start_index, None)
            # JAX dispatches asynchronously: wait on the actual outputs (not an
            # unrelated array) so device-side failures surface in this try block.
            for leaf in tree_leaves(outputs):
                if hasattr(leaf, "block_until_ready"):
                    leaf.block_until_ready()

            return True

        except Exception as e:
            error_str = str(e).lower()
            if "resource" in error_str or "memory" in error_str or "oom" in error_str:
                return False
            # Re-raise non-memory errors
            raise

    # Binary search for max n_parameter_sets
    low, high = min_sets, max_sets
    best = min_sets

    while low <= high:
        mid = (low + high) // 2
        probed_values.append(mid)

        if verbose:
            print(f"Probing n_parameter_sets={mid}...", end=" ")

        # Clear caches before each attempt
        clear_caches()
        gc.collect()

        try:
            success = try_forward_pass(mid)
        except Exception as e:
            # Best-effort probe: a non-memory error is reported and treated as
            # a failed attempt rather than aborting the whole search.
            if verbose:
                print(f"Error: {e}")
            success = False

        if success:
            if verbose:
                print("OK")
            success_values.append(mid)
            best = mid
            low = mid + 1
        else:
            if verbose:
                print("OOM")
            failed_values.append(mid)
            high = mid - 1

        # Clear caches after attempt
        clear_caches()
        gc.collect()

    if not success_values:
        logging.getLogger(__name__).warning(
            "probe_max_n_parameter_sets: no n_parameter_sets in [%d, %d] "
            "succeeded; falling back to min_sets=%d",
            min_sets,
            max_sets,
            min_sets,
        )

    recommended = max(min_sets, int(best * safety_margin))

    result = {
        "max_n_parameter_sets": best,
        "recommended_n_parameter_sets": recommended,
        "probed_values": sorted(probed_values),
        "success_values": sorted(success_values),
        "failed_values": sorted(failed_values),
    }

    if verbose:
        print("\nMemory probe results:")
        print(f"  Max n_parameter_sets: {best}")
        print(f"  Recommended (with {safety_margin:.0%} margin): {recommended}")

    return result
def allocate_memory_budget(
    run_fingerprint: dict,
    available_memory_gb: Optional[float] = None,
    priority: str = "exploration",
    probe_if_needed: bool = True,
    max_ensemble_members: int = 1,
    verbose: bool = True,
) -> dict:
    """
    Allocate memory budget across hyperparameters based on priority.

    Parameters
    ----------
    run_fingerprint : dict
        The run fingerprint configuration.
    available_memory_gb : float, optional
        Available GPU memory in GB. If None and probe_if_needed=True, memory
        is probed via ``probe_max_n_parameter_sets`` to determine capacity.
    priority : str
        How to allocate memory budget:

        - "exploration": Maximize n_parameter_sets (find diverse solutions).
        - "robustness": Split the budget roughly in half between
          n_parameter_sets and n_ensemble_members (when ensembling is
          allowed).
        - "variance_reduction": At most 4 parameter sets; the rest of the
          budget goes to batch_size (stable gradients).
    probe_if_needed : bool
        Whether to probe memory if available_memory_gb is not provided.
    max_ensemble_members : int
        Maximum ensemble members to allocate (default 1 = no ensembling).
        Set higher (e.g., 4) if you want the "robustness" priority to use
        ensembles.
    verbose : bool
        Whether to print allocation info.

    Returns
    -------
    dict
        Recommended settings with keys: ``n_parameter_sets`` (int),
        ``n_ensemble_members`` (int), ``batch_size`` (int),
        ``priority_used`` (str), ``probe_result`` (dict or None).
        All three counts are always >= 1.

    Raises
    ------
    ValueError
        If ``priority`` is not one of the supported values.
    """
    probe_result = None

    if available_memory_gb is None and probe_if_needed:
        # Probe to find capacity
        probe_result = probe_max_n_parameter_sets(
            run_fingerprint,
            verbose=verbose,
            safety_margin=0.85,  # More conservative for allocation
        )
        max_units = probe_result["recommended_n_parameter_sets"]
    elif available_memory_gb is not None:
        # Very rough heuristic: ~4 parameter sets per GB (~0.25 GB per set).
        max_units = int(available_memory_gb * 4)
    else:
        # Default conservative estimate
        max_units = 8

    # Tiny or non-positive budgets (e.g. < 0.25 GB) still need one unit;
    # otherwise "variance_reduction" would divide by zero below.
    max_units = max(1, max_units)

    if priority == "exploration":
        # Maximize exploration with independent param sets
        n_parameter_sets = max_units
        n_ensemble_members = 1
        batch_size = 1  # Small batch, rely on param diversity

    elif priority == "robustness":
        if max_ensemble_members > 1:
            # Half the budget to param sets; the remaining factor to ensembles.
            n_parameter_sets = max(1, max_units // 2)
            n_ensemble_members = min(
                max_ensemble_members, max(1, max_units // n_parameter_sets)
            )
            batch_size = max(1, max_units // (n_parameter_sets * n_ensemble_members))
        else:
            # No ensembling allowed, fall back to exploration-like allocation
            n_parameter_sets = max_units
            n_ensemble_members = 1
            batch_size = 1

    elif priority == "variance_reduction":
        # Fewer param sets, larger batches for stable gradients
        n_parameter_sets = min(4, max_units)
        n_ensemble_members = 1
        batch_size = max(1, max_units // n_parameter_sets)

    else:
        raise ValueError(f"Unknown priority: {priority}. Use 'exploration', 'robustness', or 'variance_reduction'")

    result = {
        "n_parameter_sets": n_parameter_sets,
        "n_ensemble_members": n_ensemble_members,
        "batch_size": batch_size,
        "priority_used": priority,
        "probe_result": probe_result,
    }

    if verbose:
        print(f"\nMemory allocation ({priority} priority):")
        print(f"  n_parameter_sets: {n_parameter_sets}")
        print(f"  n_ensemble_members: {n_ensemble_members}")
        print(f"  batch_size: {batch_size}")
        if probe_result:
            print(f"  (based on probed max: {probe_result['max_n_parameter_sets']})")

    return result
def apply_memory_allocation(run_fingerprint: dict, allocation: dict) -> dict:
    """
    Write an allocation produced by allocate_memory_budget() into a fingerprint.

    ``n_parameter_sets`` and ``batch_size`` go under
    ``run_fingerprint["optimisation_settings"]``. ``n_ensemble_members`` goes
    at the top level. Other allocation keys are ignored.

    Parameters
    ----------
    run_fingerprint : dict
        Fingerprint to update in place.
    allocation : dict
        Result from allocate_memory_budget().

    Returns
    -------
    dict
        The same ``run_fingerprint`` object, after mutation.
    """
    optimisation_settings = run_fingerprint["optimisation_settings"]
    optimisation_settings["n_parameter_sets"] = allocation["n_parameter_sets"]
    optimisation_settings["batch_size"] = allocation["batch_size"]
    run_fingerprint["n_ensemble_members"] = allocation["n_ensemble_members"]
    return run_fingerprint
def auto_configure_memory_params(
    run_fingerprint: dict,
    priority: str = "exploration",
    max_ensemble_members: int = 1,
    verbose: bool = True,
) -> dict:
    """
    Probe memory, compute an allocation, and apply it, all in one call.

    Equivalent to ``allocate_memory_budget(..., probe_if_needed=True)``
    followed by ``apply_memory_allocation``.

    Parameters
    ----------
    run_fingerprint : dict
        The run fingerprint to configure (modified in place).
    priority : str
        Allocation priority ("exploration", "robustness", "variance_reduction").
    max_ensemble_members : int
        Maximum ensemble members to allocate (default 1 = no ensembling).
    verbose : bool
        Whether to print progress info.

    Returns
    -------
    dict
        The modified run_fingerprint.

    Example
    -------
    >>> run = {...}  # your run_fingerprint
    >>> auto_configure_memory_params(run, priority="exploration")
    >>> train_on_historic_data(run)
    """
    budget_options = dict(
        priority=priority,
        probe_if_needed=True,
        max_ensemble_members=max_ensemble_members,
        verbose=verbose,
    )
    chosen_allocation = allocate_memory_budget(run_fingerprint, **budget_options)
    return apply_memory_allocation(run_fingerprint, chosen_allocation)
# =============================================================================
# Best Params Selection and Tracking
# =============================================================================

# Valid selection methods - must match load_manually methods where applicable.
# This list's repr appears verbatim in error messages below.
SELECTION_METHODS = [
    "last",  # Always return last iteration/trial
    "best_train",  # Best training metric
    "best_val",  # Best validation metric (requires val_fraction > 0)
    "best_continuous_test",  # Best continuous test metric (NOT RECOMMENDED - data leakage)
    "best_train_min_test",  # Best train meeting test threshold
]


def _metric_values(metrics_list: List[Dict], metric: str) -> np.ndarray:
    """Return each param set's ``metric`` as a float array, NaN where missing."""
    return np.array([m.get(metric, np.nan) for m in metrics_list], dtype=float)


def _mean_and_argmax(values: np.ndarray) -> Tuple[float, int]:
    """Return (nanmean, nanargmax) of ``values``; (-inf, 0) if empty or all NaN."""
    if values.size == 0 or np.isnan(values).all():
        return -float("inf"), 0
    # Mean across param sets is the selection value (matches SGD behaviour);
    # the argmax picks which param set to report.
    return float(np.nanmean(values)), int(np.nanargmax(values))


def compute_selection_metric(
    train_metrics: List[Dict],
    val_metrics: Optional[List[Dict]] = None,
    continuous_test_metrics: Optional[List[Dict]] = None,
    method: str = "best_val",
    metric: str = "sharpe",
    min_threshold: float = 0.0,
) -> Tuple[float, int]:
    """
    Compute selection metric value for a single iteration/trial.

    This is the shared core logic used by both BestParamsTracker (during
    training) and load_manually (post-training). Returns a value for
    comparison and the index of the best param set.

    Parameters
    ----------
    train_metrics : list of dict
        Training metrics for each param set. Each dict has keys like "sharpe",
        "returns_over_uniform_hodl", etc.
    val_metrics : list of dict, optional
        Validation metrics for each param set. Required if method="best_val".
    continuous_test_metrics : list of dict, optional
        Continuous test metrics for each param set.
    method : str
        Selection method. One of SELECTION_METHODS.
    metric : str
        Which metric to use for comparison (e.g., "sharpe",
        "returns_over_uniform_hodl").
    min_threshold : float
        Minimum test value for the "best_train_min_test" method.

    Returns
    -------
    tuple of (float, int)
        (selection_value, best_param_idx). Higher selection_value is always
        better.

        - "last" returns (inf, 0).
        - The "best_*" methods return the NaN-ignoring mean across param sets
          and the index of the maximum.
        - "best_train_min_test" instead returns the best single training value
          among sets meeting the threshold. If no set meets it, or no test
          metrics are given, it falls back to the best_train result.
        - Returns (-inf, 0) when there is nothing valid to compare (missing
          or all-NaN metrics).

    Raises
    ------
    ValueError
        If ``method`` is unknown, or if "best_val" is used without val_metrics.
    """
    if method not in SELECTION_METHODS:
        raise ValueError(f"Unknown selection method: {method}. Must be one of {SELECTION_METHODS}")

    if method == "last":
        # "Last" always wins - return high value, first param set
        return float("inf"), 0

    if method == "best_train":
        if not train_metrics:
            return -float("inf"), 0
        return _mean_and_argmax(_metric_values(train_metrics, metric))

    if method == "best_val":
        if not val_metrics:
            raise ValueError("best_val method requires val_metrics (set val_fraction > 0)")
        return _mean_and_argmax(_metric_values(val_metrics, metric))

    if method == "best_continuous_test":
        # NOT RECOMMENDED - causes data leakage
        if not continuous_test_metrics:
            return -float("inf"), 0
        return _mean_and_argmax(_metric_values(continuous_test_metrics, metric))

    if method == "best_train_min_test":
        if not train_metrics:
            return -float("inf"), 0
        train_per_set = _metric_values(train_metrics, metric)
        if not continuous_test_metrics:
            # No test metrics - fall back to best_train
            return _mean_and_argmax(train_per_set)

        test_per_set = _metric_values(continuous_test_metrics, metric)
        best_value = -float("inf")
        best_idx = 0
        for i, (train_v, test_v) in enumerate(zip(train_per_set, test_per_set)):
            meets_threshold = not np.isnan(test_v) and test_v >= min_threshold
            if meets_threshold and not np.isnan(train_v) and train_v > best_value:
                best_value = train_v
                best_idx = i

        if best_value == -float("inf"):
            # No param set met the threshold - fall back to best_train
            # (which also handles all-NaN training metrics safely).
            return _mean_and_argmax(train_per_set)
        return float(best_value), best_idx

    # Guards against a method being added to SELECTION_METHODS without a branch.
    raise ValueError(f"Unknown selection method: {method}")
class BestParamsTracker:
    """
    Unified tracking of params across training iterations/trials.

    Tracks both "last" (most recent iteration) and "best" (by selection
    method) params along with their associated metrics and continuous outputs.
    Used by both SGD and Optuna paths to ensure consistent param selection
    logic.

    Parameters
    ----------
    selection_method : str
        Method for selecting best params. One of SELECTION_METHODS. With
        "last", the best_* state always mirrors the most recent update.
    metric : str
        Which metric to use for selection (e.g., "sharpe",
        "returns_over_uniform_hodl").
    min_threshold : float
        Minimum threshold for "best_train_min_test" method.

    Attributes
    ----------
    last_* : Various
        State from the most recent update() call.
    best_* : Various
        State from the update() whose selection value was highest (or the
        most recent one, for selection_method="last").
    """

    def __init__(
        self,
        selection_method: str = "best_val",
        metric: str = "sharpe",
        min_threshold: float = 0.0,
    ):
        if selection_method not in SELECTION_METHODS:
            raise ValueError(f"Unknown selection method: {selection_method}. Must be one of {SELECTION_METHODS}")

        self.selection_method = selection_method
        self.metric = metric
        self.min_threshold = min_threshold

        # "Last" state - always most recent iteration
        self.last_iteration = 0
        self.last_params = None
        self.last_param_idx = 0
        self.last_train_metrics = None
        self.last_val_metrics = None
        self.last_continuous_test_metrics = None
        self.last_continuous_outputs = None  # {"reserves": ..., "weights": ...}

        # "Best" state - based on selection_method
        self.best_iteration = 0
        self.best_params = None
        self.best_param_idx = 0
        self.best_train_metrics = None
        self.best_val_metrics = None
        self.best_continuous_test_metrics = None
        self.best_continuous_outputs = None
        self.best_metric_value = -float("inf")

    @staticmethod
    def _snapshot_outputs(continuous_outputs: Dict) -> Dict:
        """Copy "reserves" and "weights" into fresh NumPy arrays (np.array copies by default)."""
        return {
            "reserves": np.array(continuous_outputs["reserves"]),
            "weights": np.array(continuous_outputs["weights"]),
        }

    def update(
        self,
        iteration: int,
        params: Dict,
        continuous_outputs: Dict,
        train_metrics_list: List[Dict],
        val_metrics_list: Optional[List[Dict]] = None,
        continuous_test_metrics_list: Optional[List[Dict]] = None,
    ) -> bool:
        """
        Update tracker with current iteration's state.

        Parameters
        ----------
        iteration : int
            Current iteration/trial number.
        params : dict
            Current parameters (batched over param sets). Deep-copied.
        continuous_outputs : dict
            Output from continuous forward pass. Must have "reserves" and
            "weights" with shape (n_param_sets, time_steps, ...).
        train_metrics_list : list of dict
            Training metrics for each param set.
        val_metrics_list : list of dict, optional
            Validation metrics for each param set.
        continuous_test_metrics_list : list of dict, optional
            Continuous test metrics for each param set.

        Returns
        -------
        bool
            True if this iteration became the new "best" (always True for
            selection_method="last"), False otherwise.
        """
        # Always update "last" state
        self.last_iteration = iteration
        self.last_params = deepcopy(params)
        self.last_train_metrics = train_metrics_list
        self.last_val_metrics = val_metrics_list
        self.last_continuous_test_metrics = continuous_test_metrics_list
        self.last_continuous_outputs = self._snapshot_outputs(continuous_outputs)

        selection_value, param_idx = compute_selection_metric(
            train_metrics_list,
            val_metrics_list,
            continuous_test_metrics_list,
            method=self.selection_method,
            metric=self.metric,
            min_threshold=self.min_threshold,
        )
        self.last_param_idx = param_idx

        # "last" yields +inf on every call, and inf > inf is False, so without
        # this override "best" would freeze at the first iteration.
        improved = (
            self.selection_method == "last"
            or selection_value > self.best_metric_value
        )
        if not improved:
            return False

        self.best_metric_value = selection_value
        self.best_iteration = iteration
        self.best_param_idx = param_idx
        self.best_params = deepcopy(params)
        self.best_train_metrics = train_metrics_list
        self.best_val_metrics = val_metrics_list
        self.best_continuous_test_metrics = continuous_test_metrics_list
        self.best_continuous_outputs = self._snapshot_outputs(continuous_outputs)
        return True

    def select_param_set(self, params_dict: Dict, idx: int, n_param_sets: int) -> Dict:
        """
        Extract single param set from batched params.

        Parameters
        ----------
        params_dict : dict
            Batched parameters with shape (n_param_sets, ...) for each key.
            "subsidary_params" and entries whose leading dimension does not
            match are passed through unchanged.
        idx : int
            Index of param set to extract (ignored when n_param_sets == 1).
        n_param_sets : int
            Total number of param sets.

        Returns
        -------
        dict
            Parameters for single param set with shape (...) for each key.
        """
        selected = {}
        if n_param_sets == 1:
            # Already a single param set: drop the leading size-1 axis.
            for k, v in params_dict.items():
                if k == "subsidary_params":
                    selected[k] = v
                elif hasattr(v, "shape") and len(v.shape) >= 1 and v.shape[0] == 1:
                    # np.squeeze keeps a 0-d ndarray for shape (1,); other
                    # array types (e.g. JAX) are indexed instead.
                    selected[k] = np.squeeze(v, axis=0) if isinstance(v, np.ndarray) else v[0]
                else:
                    selected[k] = v
            return selected

        for k, v in params_dict.items():
            if k == "subsidary_params":
                selected[k] = v
            elif hasattr(v, "shape") and len(v.shape) >= 1 and v.shape[0] == n_param_sets:
                selected[k] = v[idx]
            else:
                selected[k] = v
        return selected

    def get_results(self, n_param_sets: int, train_bout_length: int) -> Dict:
        """
        Get comprehensive results with both last and best state.

        Parameters
        ----------
        n_param_sets : int
            Number of parameter sets (currently unused; kept for API
            compatibility).
        train_bout_length : int
            Length of training period. The reserves/weights at time index
            ``train_bout_length - 1`` (the last training step) are returned
            for warm-starting.

        Returns
        -------
        dict
            Comprehensive results including:

            - last_* fields: State from most recent iteration.
            - best_* fields: State from the best iteration.
            - Selection metadata.
        """
        # continuous_outputs arrays are (n_param_sets, total_time_steps, ...)
        # where total_time_steps covers train + val + test; index
        # train_bout_length - 1 is the final step of the training period.
        # NOTE(review): train_bout_length < 1 would index from the end of the
        # array; callers are assumed to pass a positive length.
        end_of_train = train_bout_length - 1
        last_final_reserves = None
        last_final_weights = None
        best_final_reserves = None
        best_final_weights = None

        if self.last_continuous_outputs is not None:
            last_final_reserves = self.last_continuous_outputs["reserves"][:, end_of_train, :]
            last_final_weights = self.last_continuous_outputs["weights"][:, end_of_train, :]

        if self.best_continuous_outputs is not None:
            best_final_reserves = self.best_continuous_outputs["reserves"][:, end_of_train, :]
            best_final_weights = self.best_continuous_outputs["weights"][:, end_of_train, :]

        return {
            # Last iteration results
            "last_iteration": self.last_iteration,
            "last_params": self.last_params,
            "last_param_idx": self.last_param_idx,
            "last_train_metrics": self.last_train_metrics,
            "last_val_metrics": self.last_val_metrics,
            "last_continuous_test_metrics": self.last_continuous_test_metrics,
            "last_final_reserves": last_final_reserves,
            "last_final_weights": last_final_weights,
            # Best iteration results
            "best_iteration": self.best_iteration,
            "best_params": self.best_params,
            "best_param_idx": self.best_param_idx,
            "best_metric_value": self.best_metric_value,
            "best_train_metrics": self.best_train_metrics,
            "best_val_metrics": self.best_val_metrics,
            "best_continuous_test_metrics": self.best_continuous_test_metrics,
            "best_final_reserves": best_final_reserves,
            "best_final_weights": best_final_weights,
            # Selection info
            "selection_method": self.selection_method,
            "selection_metric": self.metric,
        }