"""Core training and simulation runners for quantammsim.
This module provides the two primary entry points for using quantammsim:
:func:`train_on_historic_data`
Optimise strategy parameters on historical price data using either
gradient descent (Adam/AdamW/SGD via Optax) or gradient-free search
(Optuna). Supports ensemble training, early stopping with validation
holdout, warm-starting from previous walk-forward cycles, checkpointing
for Rademacher complexity analysis, and Stochastic Weight Averaging.
:func:`do_run_on_historic_data`
Execute a single forward pass (simulation) with fixed parameters and
return the full results dict. Used for post-training evaluation,
walk-forward OOS testing, and visualisation. Supports injecting
real trade data, time-varying fees/gas costs, and LP supply changes.
Both functions accept a ``run_fingerprint`` dict as their primary
configuration. See :doc:`/user_guide/run_fingerprints` for the complete
reference of available settings.
"""
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import math
import gc
import os
import optuna
from jax.tree_util import Partial
from jax import jit, vmap, random
from jax import clear_caches
from jax.tree_util import tree_map
from quantammsim.utils.data_processing.historic_data_utils import (
get_data_dict,
)
from quantammsim.core_simulator.forward_pass import (
forward_pass,
forward_pass_nograd,
_calculate_return_value,
)
from quantammsim.core_simulator.windowing_utils import get_indices, filter_coarse_weights_by_data_indices
from quantammsim.training.backpropagation import (
update_from_partial_training_step_factory,
update_from_partial_training_step_factory_with_optax,
create_opt_state_in_axes_dict,
create_optimizer_chain,
)
from quantammsim.core_simulator.param_utils import (
recursive_default_set,
check_run_fingerprint,
memory_days_to_logit_lamb,
retrieve_best,
process_initial_values,
get_run_location,
)
from quantammsim.core_simulator.result_exporter import (
save_multi_params,
save_optuna_results_sgd_format,
)
from quantammsim.runners.jax_runner_utils import (
nan_param_reinit,
has_nan_grads,
Hashabledict,
NestedHashabledict,
HashableArrayWrapper,
get_trades_and_fees,
get_unique_tokens,
OptunaManager,
generate_evaluation_points,
create_trial_params,
create_static_dict,
get_sig_variations,
BestParamsTracker,
SELECTION_METHODS,
)
from quantammsim.pools.creator import create_pool
from quantammsim.runners.default_run_fingerprint import run_fingerprint_defaults
from quantammsim.utils.post_train_analysis import (
calculate_period_metrics,
calculate_continuous_test_metrics,
)
import jax.numpy as jnp
[docs]
def train_on_historic_data(
run_fingerprint,
root=None,
iterations_per_print=1,
force_init=False,
price_data=None,
verbose=True,
run_location=None,
return_training_metadata=False,
warm_start_params=None,
warm_start_weights=None,
):
"""Optimise strategy parameters on historical price data.
This is the primary training entry point for quantammsim. It loads (or
accepts) price data, constructs the JAX computation graph, and runs either
gradient-based (Adam/AdamW/SGD) or gradient-free (Optuna) optimisation
according to ``run_fingerprint["optimisation_settings"]["method"]``.
Parameters
----------
run_fingerprint : dict
Master configuration dict. Key fields consumed here:
- ``tokens``, ``startDateString``, ``endDateString``,
``endTestDateString`` — data selection
- ``rule`` — pool/strategy type (e.g. ``"momentum"``,
``"mean_reversion_channel"``)
- ``return_val`` — objective metric (default ``"daily_log_sharpe"``)
- ``optimisation_settings.method`` — ``"gradient_descent"`` or
``"optuna"``
- ``optimisation_settings.optimiser`` — ``"adam"``, ``"adamw"``,
or ``"sgd"``
- ``optimisation_settings.n_iterations`` — training epochs
- ``optimisation_settings.val_fraction`` — fraction of training
window held out for early-stopping validation (0 = disabled)
- ``optimisation_settings.use_swa`` — enable Stochastic Weight
Averaging
- ``optimisation_settings.track_checkpoints`` — save periodic
parameter snapshots for Rademacher complexity analysis
See :doc:`/user_guide/run_fingerprints` for the full reference.
root : str, optional
Root directory for data files and saved results.
iterations_per_print : int, optional
Print training progress every *N* iterations (default 1).
force_init : bool, optional
If True, ignore cached results and re-initialise parameters.
price_data : array-like or DataFrame, optional
Pre-loaded price data. When None, data is loaded from parquet
files based on ``run_fingerprint`` date/token settings.
verbose : bool, optional
Print detailed progress information (default True).
run_location : str, optional
Path to a previously-saved run to resume from. When None, a new
run is initialised (or auto-detected from the fingerprint hash).
return_training_metadata : bool, optional
If True, return ``(params, metadata)`` where *metadata* contains
``epochs_trained``, ``final_objective``, and ``checkpoint_returns``
(a ``(n_checkpoints, T-1)`` array for Rademacher complexity, or
None if checkpointing was disabled).
warm_start_params : dict, optional
Strategy parameters from a previous walk-forward cycle. Each
value is expanded to ``(n_parameter_sets, ...)`` shape with
added Gaussian noise (scale controlled by
``optimisation_settings.noise_scale``).
warm_start_weights : array-like, optional
Final portfolio weights from a previous cycle, shape
``(n_assets,)``. The pool starts with a fresh
``initial_pool_value`` distributed according to these weights.
Returns
-------
dict or tuple or list or None
- **Gradient descent**, ``return_training_metadata=False``:
best params dict.
- **Gradient descent**, ``return_training_metadata=True``:
``(params, metadata)`` tuple.
- **Optuna**: list of best trials, or None if none completed.
"""
recursive_default_set(run_fingerprint, run_fingerprint_defaults)
check_run_fingerprint(run_fingerprint)
if verbose:
print("Run Fingerprint: ", run_fingerprint)
rule = run_fingerprint["rule"]
chunk_period = run_fingerprint["chunk_period"]
weight_interpolation_period = run_fingerprint["weight_interpolation_period"]
use_alt_lamb = run_fingerprint["use_alt_lamb"]
use_pre_exp_scaling = run_fingerprint["use_pre_exp_scaling"]
fees = run_fingerprint["fees"]
arb_fees = run_fingerprint["arb_fees"]
gas_cost = run_fingerprint["gas_cost"]
n_parameter_sets = run_fingerprint["optimisation_settings"]["n_parameter_sets"]
weight_interpolation_method = run_fingerprint["weight_interpolation_method"]
arb_frequency = run_fingerprint["arb_frequency"]
random_key = random.key(
run_fingerprint["optimisation_settings"]["initial_random_key"]
)
learnable_bounds = run_fingerprint.get("learnable_bounds_settings", {})
initial_params = {
"initial_memory_length": run_fingerprint["initial_memory_length"],
"initial_memory_length_delta": run_fingerprint["initial_memory_length_delta"],
"initial_k_per_day": run_fingerprint["initial_k_per_day"],
"initial_weights_logits": run_fingerprint["initial_weights_logits"],
"initial_log_amplitude": run_fingerprint["initial_log_amplitude"],
"initial_raw_width": run_fingerprint["initial_raw_width"],
"initial_raw_exponents": run_fingerprint["initial_raw_exponents"],
"initial_pre_exp_scaling": run_fingerprint["initial_pre_exp_scaling"],
"min_weights_per_asset": learnable_bounds.get("min_weights_per_asset"),
"max_weights_per_asset": learnable_bounds.get("max_weights_per_asset"),
}
unique_tokens = get_unique_tokens(run_fingerprint)
n_tokens = len(unique_tokens)
n_assets = n_tokens
all_sig_variations = get_sig_variations(n_assets)
np.random.seed(0)
max_memory_days = run_fingerprint["max_memory_days"]
if price_data is None and verbose:
print(f"[Data] Loading {run_fingerprint['optimisation_settings']['training_data_kind']} data...")
data_dict = get_data_dict(
unique_tokens,
run_fingerprint,
data_kind=run_fingerprint["optimisation_settings"]["training_data_kind"],
root=root,
max_memory_days=max_memory_days,
start_date_string=run_fingerprint["startDateString"],
end_time_string=run_fingerprint["endDateString"],
start_time_test_string=run_fingerprint["endDateString"],
end_time_test_string=run_fingerprint["endTestDateString"],
max_mc_version=run_fingerprint["optimisation_settings"]["max_mc_version"],
price_data=price_data,
do_test_period=True,
)
max_memory_days = data_dict["max_memory_days"]
# Validation holdout setup
# If val_fraction > 0, carve out validation window from end of training
val_fraction = run_fingerprint["optimisation_settings"].get("val_fraction", 0.0)
# Validate val_fraction
if val_fraction < 0 or val_fraction >= 1.0:
raise ValueError(
f"val_fraction must be in [0, 1), got {val_fraction}. "
f"Use 0 for no validation holdout, or a value like 0.2 for 20% validation."
)
if val_fraction > 0:
# Store original bout_length for reference (used for continuous forward pass and test slicing)
original_bout_length = data_dict["bout_length"]
# Calculate validation and effective training lengths
val_length = int(original_bout_length * val_fraction)
effective_train_length = original_bout_length - val_length
# Ensure validation window is meaningful (at least 1 day of data for minute frequency)
min_val_length = run_fingerprint.get("chunk_period", 1440) # Default 1 day
if val_length < min_val_length:
raise ValueError(
f"val_fraction={val_fraction} results in val_length={val_length} steps, "
f"which is less than minimum {min_val_length} steps (1 chunk_period). "
f"Increase val_fraction or use a longer training period."
)
# Override data_dict["bout_length"] to be the effective training length
# This ensures training sampling and forward passes use the correct (reduced) length
data_dict["bout_length"] = effective_train_length
val_start_idx = data_dict["start_idx"] + effective_train_length
# Ensure we have room for random sampling in the training region
bout_length_window = effective_train_length - run_fingerprint["bout_offset"]
if bout_length_window <= 0:
raise ValueError(
f"val_fraction={val_fraction} is too large. "
f"effective_train_length ({effective_train_length}) must be > bout_offset ({run_fingerprint['bout_offset']}). "
f"Either reduce val_fraction or increase bout_length or reduce bout_offset."
)
if verbose:
# Convert steps to days for readability (assuming minute data)
steps_per_day = 1440
print(f"[Setup] Validation holdout: {val_fraction*100:.0f}%")
print(f" Train: {effective_train_length:,} steps (~{effective_train_length/steps_per_day:.1f} days)")
print(f" Val: {val_length:,} steps (~{val_length/steps_per_day:.1f} days)")
print(f" Test: {data_dict.get('bout_length_test', 0):,} steps (~{data_dict.get('bout_length_test', 0)/steps_per_day:.1f} days)")
else:
# No validation holdout - use full training window
# Early stopping will use test data (not recommended but backwards compatible)
original_bout_length = data_dict["bout_length"] # No difference when no validation
bout_length_window = data_dict["bout_length"] - run_fingerprint["bout_offset"]
val_length = 0
val_start_idx = None
assert bout_length_window > 0
# Determine the end index for sampling (must not overlap with validation)
if val_fraction > 0:
# Sampling must stay within effective training region
sampling_end_idx = val_start_idx
else:
# No validation - use original behavior
sampling_end_idx = data_dict["end_idx"]
if run_location is None:
run_location = './results/' + get_run_location(run_fingerprint) + ".json"
# Check for cached results (skip if force_init=True)
if not force_init and os.path.isfile(run_location):
if verbose:
print(f"[Cache] Loading cached results from: {run_location}")
params, step = retrieve_best(run_location, "best_train_objective", False, None)
loaded = True
else:
if force_init and os.path.isfile(run_location) and verbose:
print(f"[Cache] force_init=True, ignoring cached file")
loaded = False
# Create pool
pool = create_pool(rule)
# pool must be trainable
assert pool.is_trainable(), "The selected pool must be trainable for this operation"
if not loaded:
# Check if we should warm-start from previous cycle params
if warm_start_params is not None:
# Use warm_start_params as initialization for strategy parameters
# (lamb, k, etc.). Pool starts with fresh initial_pool_value but
# distributed according to warm_start_weights if provided.
params = {}
for key, value in warm_start_params.items():
if key == "subsidary_params":
params[key] = value if value is not None else []
continue
# Skip initial_reserves - we compute fresh reserves below
if key == "initial_reserves":
continue
if hasattr(value, 'copy'):
params[key] = jnp.array(value.copy())
else:
params[key] = jnp.array(value) if not isinstance(value, (list, type(None))) else value
# Ensure params have correct shape for n_parameter_sets
# warm_start_params are single param set (shape: (n_assets,) or scalar)
# need to expand to (n_parameter_sets, ...) format
# Step 1: Stack to (n_parameter_sets, ...) shape
for key, value in list(params.items()):
if key == "subsidary_params" or value is None:
continue
# Convert to array if not already (handles scalars from optuna make_scalar=True)
arr_value = np.array(value)
if arr_value.ndim == 0:
# Scalar: expand to (n_parameter_sets, 1)
params[key] = np.stack([arr_value.reshape(1)] * n_parameter_sets, axis=0)
else:
# Array: expand to (n_parameter_sets, ...)
params[key] = np.stack([arr_value] * n_parameter_sets, axis=0)
# Step 2: Add noise using existing pool method (reuse single source of truth)
noise_scale = run_fingerprint["optimisation_settings"].get("noise_scale", 0.1)
params = pool.add_noise(params, "gaussian", n_parameter_sets, noise_scale)
# Initialize reserves with fresh initial_pool_value
# If warm_start_weights provided, distribute according to those weights
# Otherwise use equal weights
initial_pool_value = run_fingerprint["initial_pool_value"]
start_prices = data_dict["prices"][data_dict["start_idx"]]
n_assets_local = len(start_prices)
if warm_start_weights is not None:
# Validate warm_start_weights before using
weights = jnp.array(warm_start_weights)
weights_sum = jnp.sum(weights)
if jnp.any(jnp.isnan(weights)):
if verbose:
print("[Warm-start] Warning: weights contain NaN, using equal weights")
warm_start_weights = None
elif weights_sum <= 0:
if verbose:
print("[Warm-start] Warning: weights sum <= 0, using equal weights")
warm_start_weights = None
if warm_start_weights is not None:
# Use previous cycle's ending weights to distribute fresh pool value
weights = jnp.array(warm_start_weights)
# Normalize weights to sum to 1 (safety check)
weights = weights / (jnp.sum(weights) + 1e-10)
# Compute reserves: value_per_asset = weight * total_value, reserves = value / price
value_per_asset = weights * initial_pool_value
fresh_reserves = value_per_asset / start_prices
if verbose:
weights_str = ", ".join([f"{w:.2%}" for w in np.array(weights)])
print(f"[Warm-start] Using previous params + weights [{weights_str}]")
else:
# Equal weight initial reserves
value_per_asset = initial_pool_value / n_assets_local
fresh_reserves = value_per_asset / start_prices
if verbose:
print(f"[Warm-start] Using previous params with equal weights")
params["initial_reserves"] = jnp.stack([fresh_reserves] * n_parameter_sets, axis=0)
offset = 0
else:
parameter_init_method = run_fingerprint["optimisation_settings"].get(
"parameter_init_method", "gaussian"
)
params = pool.init_parameters(
initial_params, run_fingerprint, n_tokens, n_parameter_sets,
noise=parameter_init_method,
)
offset = 0
else:
offset = step + 1
if verbose:
print(f"[Cache] Resuming from step {offset}")
for key in ["step", "test_objective", "train_objective", "hessian_trace", "local_learning_rate", "iterations_since_improvement", "objective", "continuous_test_metrics", "validation_metrics"]:
if key in params:
params.pop(key)
if run_fingerprint["optimisation_settings"]["method"] == "optuna":
n_parameter_sets = 1
for key, value in params.items():
params[key] = process_initial_values(
params, key, n_assets, n_parameter_sets, force_scalar=True
)
params["subsidary_params"] = []
# noise_scale controls initialization diversity for param sets 1+
# Default 0.1 maintains backward compatibility
noise_scale = run_fingerprint["optimisation_settings"].get("noise_scale", 0.1)
params = pool.add_noise(params, "gaussian", n_parameter_sets, noise_scale=noise_scale)
params_in_axes_dict = pool.make_vmap_in_axes(params)
# Create static dict using helper - overrides for training-specific values
base_static_dict = create_static_dict(
run_fingerprint,
bout_length=bout_length_window,
all_sig_variations=all_sig_variations,
overrides={
"n_assets": n_assets,
"training_data_kind": run_fingerprint["optimisation_settings"]["training_data_kind"],
"do_trades": False,
},
)
partial_training_step = Partial(
forward_pass,
prices=data_dict["prices"],
static_dict=Hashabledict(base_static_dict),
pool=pool,
)
partial_forward_pass_nograd_batch = Partial(
forward_pass_nograd,
prices=data_dict["prices"],
static_dict=Hashabledict(base_static_dict),
pool=pool,
)
# Note: Validation and test metrics are now computed by slicing from the continuous
# forward pass (which covers train + validation + test) rather than running separate
# passes. This ensures metrics reflect continuous simulation state.
returns_train_static_dict = base_static_dict.copy()
returns_train_static_dict["return_val"] = "returns"
returns_train_static_dict["bout_length"] = data_dict["bout_length"]
partial_forward_pass_nograd_batch_returns_train = Partial(
forward_pass_nograd,
static_dict=Hashabledict(returns_train_static_dict),
pool=pool,
)
# Create continuous forward pass that covers train + validation + test period
# Use original_bout_length to include validation period when val_fraction > 0
continuous_static_dict = base_static_dict.copy()
continuous_static_dict["return_val"] = "reserves_and_values"
continuous_static_dict["bout_length"] = original_bout_length + data_dict["bout_length_test"]
partial_forward_pass_nograd_batch_continuous = Partial(
forward_pass_nograd,
static_dict=Hashabledict(continuous_static_dict),
pool=pool,
)
nograd_in_axes = [params_in_axes_dict, None, None]
partial_forward_pass_nograd_returns_train = jit(
vmap(
partial_forward_pass_nograd_batch_returns_train,
in_axes=nograd_in_axes,
)
)
partial_forward_pass_nograd_continuous = jit(
vmap(
partial_forward_pass_nograd_batch_continuous,
in_axes=nograd_in_axes,
)
)
partial_fixed_training_step = Partial(
partial_training_step, start_index=(data_dict["start_idx"], 0)
)
local_learning_rate = run_fingerprint["optimisation_settings"]["base_lr"]
iterations_since_improvement = 0
max_iterations_with_no_improvement = run_fingerprint["optimisation_settings"][
"decay_lr_plateau"
]
decay_lr_ratio = run_fingerprint["optimisation_settings"]["decay_lr_ratio"]
min_lr = run_fingerprint["optimisation_settings"]["min_lr"]
# Early stopping settings
# If val_fraction > 0, early stopping uses validation metrics (recommended)
# If val_fraction == 0, early stopping uses test metrics (data leakage - not recommended)
use_early_stopping = run_fingerprint["optimisation_settings"].get("early_stopping", False)
early_stopping_patience = run_fingerprint["optimisation_settings"].get("early_stopping_patience", 200)
# This metric is used for TWO purposes:
# 1. Early stopping: determines when to stop training (if use_early_stopping=True)
# 2. Param selection: determines which params to return (if val_fraction > 0)
# The name "early_stopping_metric" is historical - it's really a "selection_metric"
selection_metric = run_fingerprint["optimisation_settings"].get("early_stopping_metric", "sharpe")
# Validate selection metric
# All metrics are normalized so higher = better (see forward_pass.py _calculate_* functions)
# These must match keys returned by calculate_period_metrics in post_train_analysis.py
valid_metrics = [
"sharpe", "daily_log_sharpe", "return", "returns_over_hodl",
"returns_over_uniform_hodl", "calmar", "sterling", "ulcer",
]
if (use_early_stopping or val_fraction > 0) and selection_metric not in valid_metrics:
raise ValueError(
f"early_stopping_metric '{selection_metric}' is not valid. "
f"Must be one of: {valid_metrics}"
)
metric_direction = 1 # All metrics: higher = better
# Early stopping state (only used when use_early_stopping=True)
# Early stopping only controls WHEN to stop, not WHAT params to return.
# Final param selection is handled by BestParamsTracker.
best_early_stopping_metric = float("inf") if metric_direction == -1 else -float("inf")
iterations_since_early_stopping_improvement = 0
use_validation_for_early_stopping = val_fraction > 0
warned_about_nan = False # Track if we've already warned about NaN metrics
# Initialize BestParamsTracker for unified param selection
# Selection method depends on whether validation is enabled
tracker_selection_method = "best_val" if val_fraction > 0 else "best_train"
params_tracker = BestParamsTracker(
selection_method=tracker_selection_method,
metric=selection_metric,
min_threshold=0.0,
)
# SWA settings
use_swa = run_fingerprint["optimisation_settings"].get("use_swa", False)
swa_start_frac = run_fingerprint["optimisation_settings"].get("swa_start_frac", 0.75)
swa_freq = run_fingerprint["optimisation_settings"].get("swa_freq", 10)
swa_params_list = [] # Will collect parameters for averaging
n_iterations = run_fingerprint["optimisation_settings"]["n_iterations"]
# Checkpoint tracking for Rademacher complexity
track_checkpoints = run_fingerprint["optimisation_settings"].get("track_checkpoints", False)
checkpoint_interval = run_fingerprint["optimisation_settings"].get("checkpoint_interval", 10)
checkpoint_returns_list = [] # Will collect returns at each checkpoint for Rademacher
# Warn about SWA + validation conflict
if use_swa and val_fraction > 0:
import warnings
warnings.warn(
"Both SWA and validation holdout are enabled. "
"Validation-based param selection will take precedence over SWA. "
"To use SWA, set val_fraction=0.",
UserWarning
)
if run_fingerprint["optimisation_settings"]["method"] == "gradient_descent":
if run_fingerprint["optimisation_settings"]["optimiser"] in ["adam", "adamw"]:
import optax
# Create Adam optimizer with the specified learning rate
optimizer = create_optimizer_chain(run_fingerprint)
# Initialize optimizer state for each parameter set
# For multiple parameter sets, each needs its own optimizer state
# if n_parameter_sets > 1:
# Use vmap to vectorize optimizer initialization over parameter sets
init_optimizer = lambda params: optimizer.init(params)
batched_init = vmap(init_optimizer, in_axes=[params_in_axes_dict])
opt_state = batched_init(params)
# else:
# opt_state = optimizer.init(params)
opt_state_in_axes_dict = create_opt_state_in_axes_dict(opt_state)
# Use optax-based update function
update_batch = update_from_partial_training_step_factory_with_optax(
partial_training_step,
optimizer,
run_fingerprint["optimisation_settings"]["train_on_hessian_trace"],
partial_fixed_training_step,
)
update = jit(
vmap(
update_batch,
in_axes=[params_in_axes_dict, None, None, opt_state_in_axes_dict],
)
)
elif run_fingerprint["optimisation_settings"]["optimiser"] == "sgd":
update_batch = update_from_partial_training_step_factory(
partial_training_step,
run_fingerprint["optimisation_settings"]["train_on_hessian_trace"],
partial_fixed_training_step,
)
update = jit(
vmap(
update_batch,
in_axes=[params_in_axes_dict, None, None],
)
)
elif run_fingerprint["optimisation_settings"]["optimiser"] != "sgd":
raise NotImplementedError
paramSteps = []
trainingSteps = []
continuousTestSteps = []
validationSteps = [] # Collect validation metrics when val_fraction > 0
objectiveSteps = []
learningRateSteps = []
interationsSinceImprovementSteps = []
stepSteps = []
train_prices = data_dict["prices"][data_dict["start_idx"]:data_dict["start_idx"] + data_dict["bout_length"]]
continuous_prices = data_dict["prices"][data_dict["start_idx"]:data_dict["start_idx"] + original_bout_length + data_dict["bout_length_test"]]
val_prices = data_dict["prices"][
data_dict["start_idx"] + data_dict["bout_length"]:
data_dict["start_idx"] + original_bout_length
]
for i in range(run_fingerprint["optimisation_settings"]["n_iterations"] + 1):
step = i + offset
start_indexes, random_key = get_indices(
start_index=data_dict["start_idx"],
bout_length=bout_length_window,
len_prices=sampling_end_idx, # Limited to not overlap with validation
key=random_key,
optimisation_settings=run_fingerprint["optimisation_settings"],
)
if run_fingerprint["optimisation_settings"]["optimiser"] in [
"adam",
"adamw",
]:
# Adam update with state maintenance
params, objective_value, old_params, grads, opt_state = update(
params, start_indexes, local_learning_rate, opt_state
)
else:
# Regular SGD update
params, objective_value, old_params, grads = update(
params, start_indexes, local_learning_rate
)
params = nan_param_reinit(
params,
grads,
pool,
initial_params,
run_fingerprint,
n_tokens,
n_parameter_sets,
)
# Run continuous forward pass covering train + validation (when val_fraction > 0) + test period
# This is vmapped over parameter sets, so outputs have shape:
# - value: (n_parameter_sets, time_steps)
# - reserves: (n_parameter_sets, time_steps, n_assets)
continuous_outputs = partial_forward_pass_nograd_continuous(
params,
(data_dict["start_idx"], 0),
data_dict["prices"],
)
# Process each parameter set individually
# (metric functions expect single parameter set, not batched)
train_metrics_list = []
continuous_test_metrics_list = []
for param_idx in range(n_parameter_sets):
# Extract outputs for this parameter set
# After indexing: value (time_steps,), reserves (time_steps, n_assets)
param_value = continuous_outputs["value"][param_idx]
param_reserves = continuous_outputs["reserves"][param_idx]
# Slice train period (uses effective_train_length when val_fraction > 0)
train_dict = {
"value": param_value[:data_dict["bout_length"]],
"reserves": param_reserves[:data_dict["bout_length"]],
}
# Create continuous dict for test metrics
# continuous_test_metrics computes metrics on test slice from continuous simulation
param_continuous_dict = {
"value": param_value,
"reserves": param_reserves,
}
# Calculate metrics
train_metrics = calculate_period_metrics(train_dict, train_prices)
continuous_test_metrics = calculate_continuous_test_metrics(
param_continuous_dict,
original_bout_length, # Use original length as train/test boundary
data_dict["bout_length_test"],
continuous_prices
)
train_metrics_list.append(train_metrics)
continuous_test_metrics_list.append(continuous_test_metrics)
# Compute validation metrics if val_fraction > 0 (for early stopping and saving)
if val_fraction > 0:
val_metrics_list = []
for param_idx in range(n_parameter_sets):
# Validation period: from effective_train_length to original_bout_length
val_dict = {
"value": continuous_outputs["value"][param_idx, data_dict["bout_length"]:original_bout_length],
"reserves": continuous_outputs["reserves"][param_idx, data_dict["bout_length"]:original_bout_length, :],
}
val_metrics = calculate_period_metrics(val_dict, val_prices)
val_metrics_list.append(val_metrics)
# Collect validation metrics for saving
validationSteps.append(val_metrics_list)
# Compute current_val_metric for early stopping
val_metrics_per_set = np.array([
t.get(selection_metric, np.nan) for t in val_metrics_list
])
current_val_metric = np.nanmean(val_metrics_per_set)
else:
val_metrics_list = None
current_val_metric = None
# Update BestParamsTracker - handles both best_train and best_val selection
tracker_improved = params_tracker.update(
iteration=step,
params=params,
continuous_outputs=continuous_outputs,
train_metrics_list=train_metrics_list,
val_metrics_list=val_metrics_list,
continuous_test_metrics_list=continuous_test_metrics_list,
)
# Track iterations since improvement for learning rate decay
# This uses the tracker's improvement signal
if tracker_improved:
iterations_since_improvement = 0
else:
iterations_since_improvement += 1
if iterations_since_improvement > max_iterations_with_no_improvement:
local_learning_rate = local_learning_rate * decay_lr_ratio
iterations_since_improvement = 0
if local_learning_rate < min_lr:
local_learning_rate = min_lr
# Save step data for checkpointing
paramSteps.append(deepcopy(params))
trainingSteps.append(train_metrics_list)
continuousTestSteps.append(continuous_test_metrics_list)
objectiveSteps.append(np.array(objective_value.copy()))
learningRateSteps.append(deepcopy(local_learning_rate))
interationsSinceImprovementSteps.append(iterations_since_improvement)
stepSteps.append(step)
# Early stopping based on validation or test metrics
# Note: Early stopping only controls WHEN to stop training.
# Final param selection is handled by params_tracker.
if use_early_stopping:
if use_validation_for_early_stopping and val_metrics_list:
# Reuse current_val_metric computed above (same value)
current_early_stopping_metric = current_val_metric
metric_source = "validation"
# Warn on first occurrence of NaN (not just iteration 0)
if np.isnan(current_early_stopping_metric) and not warned_about_nan:
import warnings
warnings.warn(
f"Validation {selection_metric} is NaN at iteration {i}. "
f"Early stopping may not work correctly. "
f"Check that validation period has sufficient data.",
UserWarning
)
warned_about_nan = True
elif continuous_test_metrics_list:
# Fallback to continuous test metrics (not recommended - causes data leakage)
# Note: When using test metrics for early stopping, param SELECTION still uses
# training-best (since val_fraction=0). This is intentional - we don't want to
# select params based on test performance, only use it as a stopping heuristic.
# Use nanmean to ignore NaN param sets
current_early_stopping_metric = np.nanmean([
t.get(selection_metric, np.nan) for t in continuous_test_metrics_list
])
metric_source = "continuous_test"
else:
current_early_stopping_metric = -float("inf")
metric_source = "none"
# Track early stopping metric for patience countdown
metric_improved = (current_early_stopping_metric * metric_direction) > (best_early_stopping_metric * metric_direction)
if metric_improved:
best_early_stopping_metric = current_early_stopping_metric
iterations_since_early_stopping_improvement = 0
else:
iterations_since_early_stopping_improvement += 1
if iterations_since_early_stopping_improvement >= early_stopping_patience:
if verbose:
print(f"\n[Early stopping] No {metric_source} {selection_metric} improvement for {early_stopping_patience} iterations")
print(f" Stopped at iteration {step}, best {selection_metric}={best_early_stopping_metric:+.4f}")
# Just break - param selection happens at the end using params_tracker
break
# SWA: collect parameters after swa_start_frac of training
if use_swa and i >= int(n_iterations * swa_start_frac) and i % swa_freq == 0:
swa_params_list.append(deepcopy(params))
# Checkpoint tracking for Rademacher complexity
# Save DAILY EXCESS returns (vs uniform HODL) at checkpoint intervals
# Daily aggregation gives more meaningful Rademacher values
if track_checkpoints and i % checkpoint_interval == 0:
# Extract values and prices from the training period
# continuous_outputs["value"] has shape (n_parameter_sets, time_steps)
train_values = continuous_outputs["value"][:, :data_dict["bout_length"]]
train_prices = data_dict["prices"][data_dict["start_idx"]:data_dict["start_idx"] + data_dict["bout_length"]]
# Compute uniform HODL benchmark (equal weight, no rebalancing)
# Price ratio for each asset: p_t / p_0
price_ratios = train_prices / (train_prices[0:1] + 1e-10) # (T, n_assets)
# Uniform HODL value = initial_value * mean(price_ratios across assets)
uniform_hodl_value = price_ratios.mean(axis=-1) # (T,)
# Compute log returns for model and benchmark
# Shape: (n_parameter_sets, bout_length - 1)
model_log_returns = jnp.diff(jnp.log(train_values + 1e-10), axis=-1)
hodl_log_returns = jnp.diff(jnp.log(uniform_hodl_value + 1e-10)) # (bout_length - 1,)
# Excess returns = model returns - benchmark returns
excess_returns = model_log_returns - hodl_log_returns[None, :]
# Take mean across parameter sets (they're independent runs)
# Shape: (bout_length - 1,)
checkpoint_excess_returns = np.array(excess_returns.mean(axis=0))
# Aggregate to daily resolution (1440 minutes per day)
# This gives more meaningful Rademacher values
minutes_per_day = 1440
n_full_days = len(checkpoint_excess_returns) // minutes_per_day
if n_full_days > 0:
# Sum minute returns to get daily returns (log returns are additive)
daily_excess = checkpoint_excess_returns[:n_full_days * minutes_per_day]
daily_excess = daily_excess.reshape(n_full_days, minutes_per_day).sum(axis=1)
# Only save if no NaN values (training didn't explode)
if not np.isnan(daily_excess).any():
checkpoint_returns_list.append(daily_excess)
if step % iterations_per_print == 0:
if verbose:
# Format metrics for display
obj_val = float(np.mean(objective_value)) if hasattr(objective_value, '__len__') else float(objective_value)
print(f"\n[Iter {step}] objective={obj_val:.4f}")
# Training metrics (in-sample)
if train_metrics_list:
train_sharpes = [t.get("sharpe", np.nan) for t in train_metrics_list]
train_rohs = [t.get("returns_over_uniform_hodl", np.nan) for t in train_metrics_list]
print(f" Train (IS): sharpe={np.nanmean(train_sharpes):+.4f} ret_over_hodl={np.nanmean(train_rohs):+.4f}")
# Validation metrics (if using validation holdout)
if val_fraction > 0 and val_metrics_list:
val_sharpe = np.nanmean([t.get("sharpe", np.nan) for t in val_metrics_list])
val_roh = np.nanmean([t.get("returns_over_uniform_hodl", np.nan) for t in val_metrics_list])
print(f" Val: sharpe={val_sharpe:+.4f} ret_over_hodl={val_roh:+.4f}")
if use_early_stopping:
print(f" Early stop: {selection_metric}={current_early_stopping_metric:+.4f} "
f"(best={best_early_stopping_metric:+.4f}, wait={iterations_since_early_stopping_improvement}/{early_stopping_patience})")
# Continuous test metrics (out-of-sample, from continuous forward pass)
if continuous_test_metrics_list:
test_sharpes = [t.get("sharpe", np.nan) for t in continuous_test_metrics_list]
test_rohs = [t.get("returns_over_uniform_hodl", np.nan) for t in continuous_test_metrics_list]
print(f" Test (OOS): sharpe={np.nanmean(test_sharpes):+.4f} ret_over_hodl={np.nanmean(test_rohs):+.4f}")
save_multi_params(
deepcopy(run_fingerprint),
paramSteps,
continuousTestSteps, # Used as test_objective for backward compat
trainingSteps,
objectiveSteps,
learningRateSteps,
interationsSinceImprovementSteps,
stepSteps,
continuousTestSteps,
validation_metrics=validationSteps if validationSteps else None,
sorted_tokens=True,
)
paramSteps = []
trainingSteps = []
continuousTestSteps = []
validationSteps = []
objectiveSteps = []
learningRateSteps = []
interationsSinceImprovementSteps = []
stepSteps = []
# Get results from tracker (includes both last and best state)
tracker_results = params_tracker.get_results(n_parameter_sets, original_bout_length)
if verbose:
obj_val = float(np.mean(objective_value)) if hasattr(objective_value, '__len__') else float(objective_value)
print(f"\n{'='*60}")
print(f"TRAINING COMPLETE - {i + 1} iterations")
print(f"{'='*60}")
print(f"Final objective: {obj_val:.4f}")
print(f"Selection: method={tracker_results['selection_method']}, metric={tracker_results['selection_metric']}")
# Build training metadata for analysis and evaluation
# Includes both "last" (final iteration) and "best" (by selection method) results
training_metadata = {
"method": "gradient_descent",
"epochs_trained": i + 1, # Actual iterations completed
"final_objective": float(np.array(objective_value).mean()),
# Last iteration metrics (for all param sets)
"last_train_metrics": tracker_results["last_train_metrics"],
"last_continuous_test_metrics": tracker_results["last_continuous_test_metrics"],
"last_val_metrics": tracker_results["last_val_metrics"],
"last_param_idx": tracker_results["last_param_idx"],
"last_final_reserves": tracker_results["last_final_reserves"][tracker_results["last_param_idx"]] if tracker_results["last_final_reserves"] is not None else None,
"last_final_weights": tracker_results["last_final_weights"][tracker_results["last_param_idx"]] if tracker_results["last_final_weights"] is not None else None,
# Best iteration metrics (by selection method)
"best_train_metrics": tracker_results["best_train_metrics"],
"best_continuous_test_metrics": tracker_results["best_continuous_test_metrics"],
"best_val_metrics": tracker_results["best_val_metrics"],
"best_param_idx": tracker_results["best_param_idx"],
"best_iteration": tracker_results["best_iteration"],
"best_metric_value": tracker_results["best_metric_value"],
"best_final_reserves": tracker_results["best_final_reserves"][tracker_results["best_param_idx"]] if tracker_results["best_final_reserves"] is not None else None,
"best_final_weights": tracker_results["best_final_weights"][tracker_results["best_param_idx"]] if tracker_results["best_final_weights"] is not None else None,
# Selection info
"selection_method": tracker_results["selection_method"],
"selection_metric": tracker_results["selection_metric"],
# Legacy field names (for backward compatibility)
# TODO: Deprecate these in favor of best_* fields
"final_train_metrics": tracker_results["best_train_metrics"],
"final_continuous_test_metrics": tracker_results["best_continuous_test_metrics"],
"final_weights": tracker_results["best_final_weights"][tracker_results["best_param_idx"]] if tracker_results["best_final_weights"] is not None else None,
"final_reserves": tracker_results["best_final_reserves"][tracker_results["best_param_idx"]] if tracker_results["best_final_reserves"] is not None else None,
# Provenance
"run_location": run_location,
"run_fingerprint": deepcopy(run_fingerprint),
}
if track_checkpoints and checkpoint_returns_list:
training_metadata["checkpoint_returns"] = np.stack(checkpoint_returns_list, axis=0)
else:
training_metadata["checkpoint_returns"] = None
# SWA: Stochastic Weight Averaging (only if no validation data)
# SWA averages params across TIME (different training iterations), not across param sets.
# After SWA averaging, we still have n_parameter_sets param sets - we then select the
# best one based on the tracker's best_param_idx.
if use_swa and len(swa_params_list) > 0 and val_fraction == 0:
if verbose:
print(f"Applying SWA: averaging {len(swa_params_list)} parameter snapshots across time")
swa_params = {}
for key in swa_params_list[0].keys():
if key == "subsidary_params":
swa_params[key] = swa_params_list[-1][key]
else:
stacked = jnp.stack([p[key] for p in swa_params_list], axis=0)
swa_params[key] = jnp.mean(stacked, axis=0)
# Select param set using tracker's best_param_idx
selected_params = params_tracker.select_param_set(swa_params, tracker_results["best_param_idx"], n_parameter_sets)
if return_training_metadata:
return selected_params, training_metadata
return selected_params
# Return best params from tracker
best_params = tracker_results["best_params"]
best_idx = tracker_results["best_param_idx"]
if verbose:
# Print best iteration results
print(f"\nBest iteration: {tracker_results['best_iteration']} (param_set={best_idx})")
print(f" Selection {tracker_results['selection_metric']}: {tracker_results['best_metric_value']:+.4f}")
# Best train metrics
if tracker_results["best_train_metrics"]:
best_train = tracker_results["best_train_metrics"][best_idx]
print(f" Train (IS): sharpe={best_train.get('sharpe', np.nan):+.4f} "
f"ret_over_hodl={best_train.get('returns_over_uniform_hodl', np.nan):+.4f}")
# Best validation metrics (if used)
if tracker_results["best_val_metrics"] and tracker_results["best_val_metrics"][best_idx]:
best_val = tracker_results["best_val_metrics"][best_idx]
print(f" Val: sharpe={best_val.get('sharpe', np.nan):+.4f} "
f"ret_over_hodl={best_val.get('returns_over_uniform_hodl', np.nan):+.4f}")
# Best continuous test metrics (OOS)
if tracker_results["best_continuous_test_metrics"]:
best_test = tracker_results["best_continuous_test_metrics"][best_idx]
print(f" Test (OOS): sharpe={best_test.get('sharpe', np.nan):+.4f} "
f"ret_over_hodl={best_test.get('returns_over_uniform_hodl', np.nan):+.4f}")
# Compare with last iteration if different
if tracker_results["best_iteration"] != i:
print(f"\nLast iteration: {i}")
if tracker_results["last_train_metrics"]:
last_train = tracker_results["last_train_metrics"][tracker_results["last_param_idx"]]
print(f" Train (IS): sharpe={last_train.get('sharpe', np.nan):+.4f} "
f"ret_over_hodl={last_train.get('returns_over_uniform_hodl', np.nan):+.4f}")
if tracker_results["last_continuous_test_metrics"]:
last_test = tracker_results["last_continuous_test_metrics"][tracker_results["last_param_idx"]]
print(f" Test (OOS): sharpe={last_test.get('sharpe', np.nan):+.4f} "
f"ret_over_hodl={last_test.get('returns_over_uniform_hodl', np.nan):+.4f}")
print(f"{'='*60}")
selected_params = params_tracker.select_param_set(best_params, best_idx, n_parameter_sets)
if return_training_metadata:
return selected_params, training_metadata
return selected_params
elif run_fingerprint["optimisation_settings"]["method"] == "optuna":
n_evaluation_points = 20
min_spacing = data_dict["bout_length"] // 2 # E
run_fingerprint["optimisation_settings"]["n_parameter_sets"] = 1
# assert run_fingerprint["optimisation_settings"]["n_parameter_sets"] == 1, \
# "Optuna only supports single parameter sets"
# Generate and store evaluation points
if "evaluation_starts" not in run_fingerprint:
evaluation_starts = generate_evaluation_points(
data_dict["start_idx"],
data_dict["end_idx"],
bout_length_window,
n_evaluation_points,
min_spacing,
run_fingerprint["optimisation_settings"]["initial_random_key"],
)
run_fingerprint["evaluation_starts"] = [int(e) for e in evaluation_starts]
else:
evaluation_starts = run_fingerprint["evaluation_starts"]
reserves_values_train_static_dict = base_static_dict.copy()
reserves_values_train_static_dict["return_val"] = "reserves_and_values"
reserves_values_train_static_dict["bout_length"] = data_dict["bout_length"]
partial_forward_pass_nograd_batch_reserves_values_train = jit(
Partial(
forward_pass_nograd,
static_dict=Hashabledict(reserves_values_train_static_dict),
pool=pool,
)
)
reserves_values_test_static_dict = base_static_dict.copy()
reserves_values_test_static_dict["return_val"] = "reserves_and_values"
reserves_values_test_static_dict["bout_length"] = data_dict["bout_length_test"]
partial_forward_pass_nograd_batch_reserves_values_test = jit(
Partial(
forward_pass_nograd,
static_dict=Hashabledict(reserves_values_test_static_dict),
pool=pool,
)
)
# Continuous forward pass covering train + test for proper continuous metrics
continuous_optuna_static_dict = base_static_dict.copy()
continuous_optuna_static_dict["return_val"] = "reserves_and_values"
continuous_optuna_static_dict["bout_length"] = original_bout_length + data_dict["bout_length_test"]
partial_forward_pass_continuous_optuna = jit(
Partial(
forward_pass_nograd,
static_dict=Hashabledict(continuous_optuna_static_dict),
pool=pool,
)
)
# Initialize Optuna manager
optuna_manager = OptunaManager(run_fingerprint)
optuna_manager.setup_study(
multi_objective=run_fingerprint["optimisation_settings"]["optuna_settings"][
"multi_objective"
]
)
run_fingerprint["optimisation_settings"]["optuna_settings"]["parameter_config"][
"logit_lamb"
] = {
"low": float(
memory_days_to_logit_lamb(
0.5, chunk_period=base_static_dict["chunk_period"]
)
),
"high": float(
memory_days_to_logit_lamb(
base_static_dict["max_memory_days"],
chunk_period=base_static_dict["chunk_period"],
)
),
"log_scale": False,
}
# Get optuna-specific settings
optuna_settings = run_fingerprint["optimisation_settings"]["optuna_settings"]
expand_around = optuna_settings.get("expand_around", True)
overfitting_penalty = optuna_settings.get("overfitting_penalty", 0.0)
# Create objective with parameter configuration and validation
def objective(trial):
    """Optuna objective: score one trial's parameters on the training data.

    Builds trial parameters from the configured search space, runs the
    training forward pass once, and scores the configured
    ``run_fingerprint["return_val"]`` metric on a window starting at each
    entry of ``evaluation_starts``. Whole-trajectory training metrics,
    validation metrics (or continuous test metrics when ``val_fraction``
    is 0) and continuous test metrics are logged and attached to the
    trial as user attributes for later reporting.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial handle used to sample parameters and to store user
        attributes.

    Returns
    -------
    tuple or scalar
        In multi-objective mode, ``(mean, min, -std)`` of the
        per-evaluation-point training objectives. Otherwise the mean
        training objective, reduced by
        ``overfitting_penalty * (train - validation)`` when the penalty
        is enabled and training outperforms validation.

    Raises
    ------
    Exception
        Any exception raised during evaluation is logged (message and
        full traceback) and then re-raised unchanged.
    """
    try:
        param_config = run_fingerprint["optimisation_settings"][
            "optuna_settings"
        ]["parameter_config"]
        if run_fingerprint["optimisation_settings"]["optuna_settings"][
            "make_scalar"
        ]:
            # Set scalar=True for all parameter configurations
            for param_key in param_config:
                param_config[param_key]["scalar"] = True
        trial_params = create_trial_params(
            trial, param_config, params, run_fingerprint, n_assets, expand_around=expand_around
        )

        # Training evaluation: one forward pass over the training period.
        train_outputs = partial_forward_pass_nograd_batch_reserves_values_train(
            trial_params,
            (data_dict["start_idx"], 0),
            data_dict["prices"],
        )

        # Calculate objectives for each evaluation point through slicing
        train_objectives = []
        for start_offset in evaluation_starts:
            # Indices relative to the start of the training trajectory.
            start_idx = start_offset - data_dict["start_idx"]
            end_idx = start_idx + data_dict["bout_length"]
            # Slice the relevant portions of the full trajectory.
            # Prices are taken from the forward-pass output so that they
            # share the trajectory-relative indexing of reserves/value;
            # data_dict["prices"] is indexed in absolute dataset
            # coordinates and would be misaligned with these offsets.
            train_value = _calculate_return_value(
                run_fingerprint["return_val"],
                train_outputs["reserves"][start_idx:end_idx],
                train_outputs["prices"][start_idx:end_idx],
                train_outputs["value"][start_idx:end_idx],
                initial_reserves=train_outputs["reserves"][start_idx],
            )
            train_objectives.append(train_value)
        mean_train_value = jnp.sum(jnp.array(train_objectives)) / len(train_objectives)

        # Whole-trajectory training metrics (for logging / trial attributes).
        train_value = _calculate_return_value(
            run_fingerprint["return_val"],
            train_outputs["reserves"],
            train_outputs["prices"],
            train_outputs["value"],
            initial_reserves=train_outputs["reserves"][0],
        )
        train_sharpe = _calculate_return_value(
            "sharpe",
            train_outputs["reserves"],
            train_outputs["prices"],
            train_outputs["value"],
        )
        train_return = (
            train_outputs["value"][-1] / train_outputs["value"][0] - 1.0
        )
        train_returns_over_hodl = _calculate_return_value(
            "returns_over_hodl",
            train_outputs["reserves"],
            train_outputs["prices"],
            train_outputs["value"],
            initial_reserves=train_outputs["reserves"][0],
        )
        train_returns_over_uniform_hodl = _calculate_return_value(
            "returns_over_uniform_hodl",
            train_outputs["reserves"],
            train_outputs["prices"],
            train_outputs["value"],
            initial_reserves=train_outputs["reserves"][0],
        )

        # Test period evaluation using continuous forward pass
        # This ensures test metrics reflect continuous simulation from training
        continuous_outputs = partial_forward_pass_continuous_optuna(
            trial_params,
            (data_dict["start_idx"], 0),
            data_dict["prices"],
        )
        # Calculate continuous test metrics first (always needed)
        continuous_prices = data_dict["prices"][
            data_dict["start_idx"]:data_dict["start_idx"] + original_bout_length + data_dict["bout_length_test"]
        ]
        continuous_test_dict = {
            "value": continuous_outputs["value"],
            "reserves": continuous_outputs["reserves"],
        }
        continuous_test_metrics = calculate_continuous_test_metrics(
            continuous_test_dict,
            original_bout_length,
            data_dict["bout_length_test"],
            continuous_prices,
        )

        # Calculate validation metrics
        train_length = data_dict["bout_length"]
        if val_fraction > 0:
            # Validation period exists between train and test
            validation_reserves = continuous_outputs["reserves"][train_length:original_bout_length]
            validation_value_arr = continuous_outputs["value"][train_length:original_bout_length]
            validation_prices = continuous_outputs["prices"][train_length:original_bout_length]
            validation_value = _calculate_return_value(
                run_fingerprint["return_val"],
                validation_reserves,
                validation_prices,
                validation_value_arr,
                initial_reserves=validation_reserves[0],
            )
            validation_sharpe = _calculate_return_value(
                "sharpe",
                validation_reserves,
                validation_prices,
                validation_value_arr,
            )
            validation_return = (
                validation_value_arr[-1] / validation_value_arr[0]
                - 1.0
            )
            validation_returns_over_hodl = _calculate_return_value(
                "returns_over_hodl",
                validation_reserves,
                validation_prices,
                validation_value_arr,
                initial_reserves=validation_reserves[0],
            )
            validation_returns_over_uniform_hodl = _calculate_return_value(
                "returns_over_uniform_hodl",
                validation_reserves,
                validation_prices,
                validation_value_arr,
                initial_reserves=validation_reserves[0],
            )
        else:
            # No validation period - use continuous test metrics
            validation_value = continuous_test_metrics.get(run_fingerprint["return_val"], continuous_test_metrics["sharpe"])
            validation_sharpe = continuous_test_metrics["sharpe"]
            validation_return = continuous_test_metrics["return"]
            validation_returns_over_hodl = continuous_test_metrics["returns_over_hodl"]
            validation_returns_over_uniform_hodl = continuous_test_metrics["returns_over_uniform_hodl"]

        # Log both training and validation metrics
        optuna_manager.logger.info(
            f"Training {trial.number}, Return over HODL: {train_returns_over_hodl}"
        )
        optuna_manager.logger.info(
            f"Training {trial.number}, Return: {train_return}"
        )
        optuna_manager.logger.info(
            f"Training {trial.number}, Sharpe: {train_sharpe}"
        )
        optuna_manager.logger.info(
            f"Training {trial.number}, {run_fingerprint['return_val']}: {train_value}"
        )
        optuna_manager.logger.info(
            f"Validation {trial.number}, Return over HODL: {validation_returns_over_hodl}"
        )
        optuna_manager.logger.info(
            f"Validation {trial.number}, Return: {validation_return}"
        )
        optuna_manager.logger.info(
            f"Validation {trial.number}, Sharpe: {validation_sharpe}"
        )
        optuna_manager.logger.info(
            f"Validation {trial.number}, {run_fingerprint['return_val']}: {validation_value}"
        )
        for i, value in enumerate(train_objectives):
            optuna_manager.logger.info(
                f"Training {trial.number}, Evaluation point {i}: {value}"
            )
        optuna_manager.logger.info(
            f"Training {trial.number}, Mean value: {mean_train_value}"
        )

        # Store validation value as a trial attribute
        trial.set_user_attr("validation_value", validation_value)
        trial.set_user_attr(
            "validation_returns_over_hodl", validation_returns_over_hodl
        )
        trial.set_user_attr("validation_returns_over_uniform_hodl", validation_returns_over_uniform_hodl)
        trial.set_user_attr("validation_sharpe", validation_sharpe)
        trial.set_user_attr("validation_return", validation_return)
        trial.set_user_attr("train_value", train_value)
        trial.set_user_attr("train_returns_over_hodl", train_returns_over_hodl)
        trial.set_user_attr("train_returns_over_uniform_hodl", train_returns_over_uniform_hodl)
        trial.set_user_attr("train_sharpe", train_sharpe)
        trial.set_user_attr("train_return", train_return)
        trial.set_user_attr("train_objectives", train_objectives)
        trial.set_user_attr("mean_train_value", mean_train_value)
        # Store continuous test metrics (same ones as train/val)
        trial.set_user_attr("continuous_test_sharpe", continuous_test_metrics["sharpe"])
        trial.set_user_attr("continuous_test_return", continuous_test_metrics["return"])
        trial.set_user_attr("continuous_test_returns_over_hodl", continuous_test_metrics["returns_over_hodl"])
        trial.set_user_attr("continuous_test_returns_over_uniform_hodl", continuous_test_metrics["returns_over_uniform_hodl"])

        if run_fingerprint["optimisation_settings"]["optuna_settings"][
            "multi_objective"
        ]:
            return (
                np.mean(train_objectives),  # mean_return
                np.min(train_objectives),  # worst_case
                -np.std(train_objectives),  # stability
            )
        else:
            # Apply overfitting penalty if configured
            # Penalty is proportional to (train - validation) gap when train > validation
            if overfitting_penalty > 0:
                train_val_gap = float(mean_train_value) - float(validation_value)
                if train_val_gap > 0:  # Only penalize if training better than validation
                    penalty = overfitting_penalty * train_val_gap
                    penalized_value = float(mean_train_value) - penalty
                    trial.set_user_attr("overfitting_penalty_applied", float(penalty))
                    return penalized_value
            return mean_train_value  # Optimize on training value
    except Exception as e:
        import traceback
        optuna_manager.logger.error(f"Trial {trial.number} failed: {str(e)}")
        optuna_manager.logger.error(f"Full traceback:\n{traceback.format_exc()}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
# Run optimization
optuna_manager.optimize(objective)
# Check if any trials completed successfully
completed_trials = [
t for t in optuna_manager.study.trials
if t.state == optuna.trial.TrialState.COMPLETE
]
# Save results in SGD-compatible format for unified downstream analysis
if completed_trials:
sgd_format_path = save_optuna_results_sgd_format(
run_fingerprint=run_fingerprint,
study=optuna_manager.study,
n_assets=n_assets,
sorted_tokens=True,
)
if verbose:
print(f"Saved SGD-compatible results to: {sgd_format_path}")
if verbose:
n_total = len(optuna_manager.study.trials)
n_completed = len(completed_trials)
n_pruned = len([t for t in optuna_manager.study.trials if t.state == optuna.trial.TrialState.PRUNED])
n_failed = n_total - n_completed - n_pruned
print(f"\n{'='*60}")
print(f"OPTUNA OPTIMIZATION COMPLETE")
print(f"{'='*60}")
print(f"Trials: {n_completed} completed, {n_pruned} pruned, {n_failed} failed (of {n_total} total)")
if not completed_trials:
print("\nWARNING: No trials completed successfully!")
elif run_fingerprint["optimisation_settings"]["optuna_settings"]["multi_objective"]:
print(f"\nPareto front ({len(optuna_manager.study.best_trials)} trials):")
for i, trial in enumerate(optuna_manager.study.best_trials[:5]): # Show top 5
train_val = trial.values[0] if trial.values else 0
test_val = trial.user_attrs.get('validation_value', 0)
print(f" [{i+1}] Train={train_val:+.4f} Test={test_val:+.4f} (trial #{trial.number})")
if len(optuna_manager.study.best_trials) > 5:
print(f" ... and {len(optuna_manager.study.best_trials) - 5} more")
else:
best = optuna_manager.study.best_trial
train_sharpe = best.user_attrs.get('train_sharpe', best.value)
test_sharpe = best.user_attrs.get('validation_value', 0)
train_roh = best.user_attrs.get('train_returns_over_hodl', 0)
print(f"\nBest trial: #{best.number}")
print(f" Train (IS): sharpe={train_sharpe:+.4f} ret_over_hodl={train_roh:+.4f}")
print(f" Test (OOS): sharpe={test_sharpe:+.4f}")
print(f"{'='*60}")
if completed_trials:
# Convert best trial params to dict format like gradient descent returns
from quantammsim.core_simulator.result_exporter import _optuna_params_to_arrays
best_trial = optuna_manager.study.best_trial
last_trial = completed_trials[-1] # Most recent trial
best_params = _optuna_params_to_arrays(best_trial.params, n_assets)
best_params["subsidary_params"] = []
if "initial_weights_logits" not in best_params:
best_params["initial_weights_logits"] = jnp.zeros(n_assets)
last_params = _optuna_params_to_arrays(last_trial.params, n_assets)
last_params["subsidary_params"] = []
if "initial_weights_logits" not in last_params:
last_params["initial_weights_logits"] = jnp.zeros(n_assets)
if return_training_metadata:
# Run continuous forward passes for both best and last trials
best_continuous_outputs = partial_forward_pass_continuous_optuna(
best_params,
(data_dict["start_idx"], 0),
data_dict["prices"],
)
last_continuous_outputs = partial_forward_pass_continuous_optuna(
last_params,
(data_dict["start_idx"], 0),
data_dict["prices"],
)
# Extract final state at end of TRAINING period (for warm-starting)
# Use bout_length - 1 to get state at end of training
train_length = data_dict["bout_length"]
best_final_reserves = np.array(best_continuous_outputs["reserves"][train_length - 1])
best_final_weights = np.array(best_continuous_outputs["weights"][train_length - 1])
last_final_reserves = np.array(last_continuous_outputs["reserves"][train_length - 1])
last_final_weights = np.array(last_continuous_outputs["weights"][train_length - 1])
# Build train metrics for best trial
best_train_metrics = {
"sharpe": float(best_trial.user_attrs.get("train_sharpe", 0)),
"returns": float(best_trial.user_attrs.get("train_return", 0)),
"returns_over_hodl": float(best_trial.user_attrs.get("train_returns_over_hodl", 0)),
"returns_over_uniform_hodl": float(best_trial.user_attrs.get("train_returns_over_uniform_hodl", 0)),
run_fingerprint["return_val"]: float(best_trial.user_attrs.get("train_value", 0)),
}
# Build train metrics for last trial
last_train_metrics = {
"sharpe": float(last_trial.user_attrs.get("train_sharpe", 0)),
"returns": float(last_trial.user_attrs.get("train_return", 0)),
"returns_over_hodl": float(last_trial.user_attrs.get("train_returns_over_hodl", 0)),
"returns_over_uniform_hodl": float(last_trial.user_attrs.get("train_returns_over_uniform_hodl", 0)),
run_fingerprint["return_val"]: float(last_trial.user_attrs.get("train_value", 0)),
}
# Compute continuous_test_metrics for best trial
continuous_prices = data_dict["prices"][
data_dict["start_idx"]:data_dict["start_idx"] + original_bout_length + data_dict["bout_length_test"]
]
best_continuous_dict = {
"value": best_continuous_outputs["value"],
"reserves": best_continuous_outputs["reserves"],
}
best_continuous_test_metrics = calculate_continuous_test_metrics(
best_continuous_dict,
original_bout_length,
data_dict["bout_length_test"],
continuous_prices
)
# Compute continuous_test_metrics for last trial
last_continuous_dict = {
"value": last_continuous_outputs["value"],
"reserves": last_continuous_outputs["reserves"],
}
last_continuous_test_metrics = calculate_continuous_test_metrics(
last_continuous_dict,
original_bout_length,
data_dict["bout_length_test"],
continuous_prices
)
# Return unified metadata matching gradient_descent format
metadata = {
"method": "optuna",
"epochs_trained": len(completed_trials),
# Last trial metrics
"last_train_metrics": [last_train_metrics],
"last_continuous_test_metrics": [last_continuous_test_metrics],
"last_val_metrics": None, # Optuna doesn't have validation holdout
"last_param_idx": 0,
"last_final_reserves": last_final_reserves,
"last_final_weights": last_final_weights,
# Best trial metrics
"best_train_metrics": [best_train_metrics],
"best_continuous_test_metrics": [best_continuous_test_metrics],
"best_val_metrics": None, # Optuna doesn't have validation holdout
"best_param_idx": 0,
"best_iteration": best_trial.number,
"best_metric_value": float(best_trial.value) if best_trial.value is not None else 0.0,
"best_final_reserves": best_final_reserves,
"best_final_weights": best_final_weights,
# Selection info
"selection_method": "best_train", # Optuna optimizes on training objective
"selection_metric": run_fingerprint["return_val"],
# Legacy fields (for backward compat)
"final_train_metrics": [best_train_metrics],
"final_continuous_test_metrics": [best_continuous_test_metrics],
"final_objective": float(best_trial.value) if best_trial.value is not None else 0.0,
"final_weights": best_final_weights,
"final_reserves": best_final_reserves,
# Provenance
"run_location": run_location,
"run_fingerprint": deepcopy(run_fingerprint),
"checkpoint_returns": None,
# Optuna-specific extras
"n_trials": len(completed_trials),
"best_value": float(best_trial.value) if best_trial.value is not None else None,
}
if verbose:
# Print continuous test metrics (computed from actual forward pass)
print(f"\nContinuous test metrics (from forward pass):")
print(f" Best trial #{best_trial.number}:")
print(f" Train (IS): sharpe={best_train_metrics.get('sharpe', 0):+.4f} "
f"ret_over_hodl={best_train_metrics.get('returns_over_hodl', 0):+.4f}")
print(f" Test (OOS): sharpe={best_continuous_test_metrics.get('sharpe', 0):+.4f} "
f"ret_over_hodl={best_continuous_test_metrics.get('returns_over_uniform_hodl', 0):+.4f}")
if best_trial.number != last_trial.number:
print(f" Last trial #{last_trial.number}:")
print(f" Train (IS): sharpe={last_train_metrics.get('sharpe', 0):+.4f} "
f"ret_over_hodl={last_train_metrics.get('returns_over_hodl', 0):+.4f}")
print(f" Test (OOS): sharpe={last_continuous_test_metrics.get('sharpe', 0):+.4f} "
f"ret_over_hodl={last_continuous_test_metrics.get('returns_over_uniform_hodl', 0):+.4f}")
return best_params, metadata
return best_params
else:
if return_training_metadata:
return None, {
"method": "optuna",
"n_trials": 0,
"error": "No trials completed",
"epochs_trained": 0,
# Last trial metrics (none available)
"last_train_metrics": None,
"last_continuous_test_metrics": None,
"last_final_reserves": None,
"last_final_weights": None,
# Best trial metrics (none available)
"best_train_metrics": None,
"best_continuous_test_metrics": None,
"best_final_reserves": None,
"best_final_weights": None,
# Selection info
"selection_method": "best_train",
"selection_metric": run_fingerprint.get("return_val", "sharpe"),
"best_param_idx": 0,
# Legacy fields (for backward compat)
"final_objective": float("-inf"),
"final_train_metrics": None,
"final_continuous_test_metrics": None,
"final_reserves": None,
"final_weights": None,
# Provenance
"run_location": run_location,
"run_fingerprint": deepcopy(run_fingerprint),
"checkpoint_returns": None,
}
return None
else:
raise NotImplementedError
[docs]
def _condense_output_for_low_data_mode(output_dict):
    """Swap the bulky per-step arrays of a forward-pass result for endpoints.

    Operates in place: ``prices`` and ``reserves`` are replaced by
    ``final_prices``, ``initial_prices`` and ``initial_reserves``, and the
    ``value`` entry is removed altogether.
    """
    prices = output_dict["prices"]
    reserves = output_dict["reserves"]
    output_dict["final_prices"] = prices[-1]
    output_dict["initial_reserves"] = reserves[0]
    output_dict["initial_prices"] = prices[0]
    for bulky_key in ("prices", "reserves", "value"):
        del output_dict[bulky_key]


def do_run_on_historic_data(
    run_fingerprint,
    params=None,
    root=None,
    price_data=None,
    verbose=False,
    raw_trades=None,
    fees=None,
    gas_cost=None,
    arb_fees=None,
    fees_df=None,
    gas_cost_df=None,
    arb_fees_df=None,
    lp_supply_df=None,
    do_test_period=False,
    low_data_mode=False,
    preslice_burnin=True,
):
    """Execute a forward-pass simulation with fixed parameters.

    Loads (or accepts) price data, builds the pool and the static
    configuration, and evaluates ``forward_pass_nograd`` with
    ``return_val="reserves_and_values"`` once per parameter set. This is
    the primary entry point for post-training evaluation, walk-forward
    OOS testing, and visualisation.

    Parameters
    ----------
    run_fingerprint : dict
        Master configuration dict (same structure as
        :func:`train_on_historic_data`). Missing settings are filled in
        place from ``run_fingerprint_defaults``, so the caller's dict is
        modified.
    params : dict or list of dict, optional
        Strategy parameters. A single dict runs one simulation; a list
        of dicts is evaluated one parameter set after another. ``None``
        (the default) is treated as an empty dict.
    root : str, optional
        Root directory for data files.
    price_data : array-like or DataFrame, optional
        Pre-loaded price data, forwarded to ``get_data_dict``.
    verbose : bool, optional
        Print progress information (default False).
    raw_trades : DataFrame, optional
        Real trade data to inject. Columns: unix timestamp (minute),
        token_in, token_out, amount_in. When None, ``do_trades`` is
        forced to False for this run.
    fees : float, optional
        Swap fee override (e.g. 0.003 for 30 bps). Falls back to
        ``run_fingerprint["fees"]``.
    gas_cost : float, optional
        Gas cost override. Falls back to ``run_fingerprint["gas_cost"]``.
    arb_fees : float, optional
        Arbitrageur fee override. Falls back to
        ``run_fingerprint["arb_fees"]``.
    fees_df : DataFrame, optional
        Time-varying swap fees (columns: unix, fee).
    gas_cost_df : DataFrame, optional
        Time-varying gas costs (columns: unix, gas_cost).
    arb_fees_df : DataFrame, optional
        Time-varying arb fees (columns: unix, arb_fee).
    lp_supply_df : DataFrame, optional
        Time-varying LP supply changes.
    do_test_period : bool, optional
        If True, also run the OOS test period defined by
        ``endDateString`` to ``endTestDateString`` (default False).
    low_data_mode : bool, optional
        If True, each result keeps only ``initial_prices``,
        ``final_prices`` and ``initial_reserves`` in place of the full
        ``prices`` / ``reserves`` arrays, and its ``value`` entry is
        dropped, to reduce memory usage (default False).
    preslice_burnin : bool, optional
        Forwarded to ``get_data_dict`` (default True). Per the original
        documentation, True pre-slices data to ``max_memory_days`` of
        burn-in plus the simulation period and False loads all available
        history.

    Returns
    -------
    dict, list of dict, or tuple
        When ``do_test_period=False``: the training-period result.
        When ``do_test_period=True``: ``(train_results, test_results)``.
        For exactly one parameter set each result is the forward-pass
        output dict, and the training dict additionally carries the
        loaded ``data_dict`` under the key ``"data_dict"``. For any
        other number of parameter sets each result is a list with one
        output dict per parameter set (no ``data_dict`` attached).

    Raises
    ------
    NotImplementedError
        If ``training_data_kind`` is ``"mc"``.
    """
    # A None sentinel replaces the former shared mutable ``{}`` default.
    if params is None:
        params = {}

    # Set default values for run_fingerprint and its optimisation_settings
    recursive_default_set(run_fingerprint, run_fingerprint_defaults)

    optimisation_settings = run_fingerprint["optimisation_settings"]
    training_data_kind = optimisation_settings["training_data_kind"]
    rule = run_fingerprint["rule"]

    unique_tokens = get_unique_tokens(run_fingerprint)
    n_assets = len(run_fingerprint["tokens"])
    all_sig_variations = get_sig_variations(n_assets)
    max_memory_days = run_fingerprint["max_memory_days"]

    # Resets NumPy's global RNG state (kept from the original behaviour).
    np.random.seed(0)

    dynamic_inputs_dict = get_trades_and_fees(
        run_fingerprint,
        raw_trades,
        fees_df,
        gas_cost_df,
        arb_fees_df,
        lp_supply_df,
        do_test_period=do_test_period,
    )

    if price_data is None and verbose:
        print("loading data")
    data_dict = get_data_dict(
        unique_tokens,
        run_fingerprint,
        data_kind=training_data_kind,
        root=root,
        max_memory_days=max_memory_days,
        start_date_string=run_fingerprint["startDateString"],
        end_time_string=run_fingerprint["endDateString"],
        start_time_test_string=run_fingerprint["endDateString"],
        end_time_test_string=run_fingerprint["endTestDateString"],
        max_mc_version=optimisation_settings["max_mc_version"],
        price_data=price_data,
        do_test_period=do_test_period,
        preslice_burnin=preslice_burnin,
    )
    max_memory_days = data_dict["max_memory_days"]
    if verbose:
        print("max_memory_days: ", max_memory_days)

    if training_data_kind == "mc":
        # TODO: Handle MC data for post-training analysis
        raise NotImplementedError

    pool = create_pool(rule)

    # Static configuration shared by the train and test runners.
    base_static_dict = create_static_dict(
        run_fingerprint,
        bout_length=data_dict["bout_length"],
        all_sig_variations=all_sig_variations,
        overrides={
            "n_assets": n_assets,
            "training_data_kind": training_data_kind,
            # Function arguments take precedence over the fingerprint values
            "fees": fees if fees is not None else run_fingerprint["fees"],
            "arb_fees": (
                arb_fees if arb_fees is not None else run_fingerprint["arb_fees"]
            ),
            "gas_cost": (
                gas_cost if gas_cost is not None else run_fingerprint["gas_cost"]
            ),
            "do_trades": (
                False if raw_trades is None else run_fingerprint["do_trades"]
            ),
            # Include date strings for run-time use
            "startDateString": run_fingerprint["startDateString"],
            "endDateString": run_fingerprint["endDateString"],
            "endTestDateString": run_fingerprint["endTestDateString"],
        },
    )

    def _build_runner(bout_length):
        # One jitted forward pass per period; only bout_length differs.
        period_static_dict = base_static_dict.copy()
        period_static_dict["return_val"] = "reserves_and_values"
        period_static_dict["bout_length"] = bout_length
        return jit(
            Partial(
                forward_pass_nograd,
                static_dict=Hashabledict(period_static_dict),
                pool=pool,
            )
        )

    run_train = _build_runner(data_dict["bout_length"])
    if do_test_period:
        run_test = _build_runner(data_dict["bout_length_test"])

    # Ensure params is a list
    if isinstance(params, dict):
        params = [params]
    total_params = len(params)
    # Report progress roughly every 10% of the parameter sets
    update_every = max(math.floor(total_params / 10), 1)

    output_dicts = []
    if do_test_period:
        output_dicts_test = []

    for i, param in enumerate(params):
        if verbose and i % update_every == 0:
            tqdm.write(f"Processed {i+1} out of {total_params} parameters.")

        # Training-period forward pass
        output_dict = run_train(
            param,
            (data_dict["start_idx"], 0),
            data_dict["prices"],
            dynamic_inputs_dict["train_period_trades"],
            dynamic_inputs_dict["fees_array"],
            dynamic_inputs_dict["gas_cost_array"],
            dynamic_inputs_dict["arb_fees_array"],
        )
        if low_data_mode:
            _condense_output_for_low_data_mode(output_dict)
        output_dicts.append(output_dict)

        # Test-period forward pass, if requested
        if do_test_period:
            output_dict_test = run_test(
                param,
                (data_dict["start_idx_test"], 0),
                data_dict["prices"],
                dynamic_inputs_dict["test_period_trades"],
                dynamic_inputs_dict["test_fees_array"],
                dynamic_inputs_dict["test_gas_cost_array"],
                dynamic_inputs_dict["test_arb_fees_array"],
            )
            if low_data_mode:
                _condense_output_for_low_data_mode(output_dict_test)
            output_dicts_test.append(output_dict_test)

    # A single parameter set is returned as a bare dict instead of a list;
    # only in that case is the loaded data attached to the training result.
    if total_params == 1:
        output_dicts = output_dicts[0]
        output_dicts["data_dict"] = data_dict
        if do_test_period:
            output_dicts_test = output_dicts_test[0]

    # Two collection passes, as in the original implementation.
    gc.collect()
    gc.collect()
    # Clear any cached JAX computations to free memory
    clear_caches()

    if do_test_period:
        return output_dicts, output_dicts_test
    return output_dicts
def do_run_on_historic_data_with_provided_coarse_weights(
    run_fingerprint,
    coarse_weights,
    params=None,
    root=None,
    price_data=None,
    verbose=False,
    raw_trades=None,
    fees=None,
    gas_cost=None,
    arb_fees=None,
    fees_df=None,
    gas_cost_df=None,
    arb_fees_df=None,
    lp_supply_df=None,
    do_test_period=False,
    low_data_mode=False,
):
    """Execute a simulation using pre-computed coarse weights.

    Like :func:`do_run_on_historic_data`, but bypasses the strategy's
    weight-calculation step. The caller provides ``coarse_weights``
    directly; this function linearly interpolates them to fine weights,
    computes reserves with
    ``_jax_calc_quantAMM_reserves_with_dynamic_inputs`` and derives the
    pool value over time.

    Parameters
    ----------
    run_fingerprint : dict
        Master configuration dict. Missing settings are filled in place
        from ``run_fingerprint_defaults``.
    coarse_weights : mapping-like
        Pre-computed coarse weights. It is passed through
        ``filter_coarse_weights_by_data_indices`` and the result must
        expose the weight rows under the key ``"weights"``
        (presumably shape ``(n_coarse_steps, n_assets)`` - confirm
        against the caller).
    params : dict, optional
        Must contain ``"initial_reserves"``; it is also forwarded to
        ``_jax_calc_coarse_weights``. A list of dicts is *not* supported
        here. ``None`` (the default) is treated as an empty dict, which
        fails on the ``"initial_reserves"`` lookup.
    root : str, optional
        Root directory for data files.
    price_data : array-like or DataFrame, optional
        Pre-loaded price data.
    verbose : bool, optional
        Print progress and a diagnostic if the interpolated weights do
        not reproduce the supplied coarse weights (default False).
    raw_trades : DataFrame, optional
        Real trade data, forwarded to ``get_trades_and_fees``.
    fees : float, optional
        Swap fee override.
    gas_cost : float, optional
        Gas cost override.
    arb_fees : float, optional
        Arbitrageur fee override.
    fees_df : DataFrame, optional
        Time-varying swap fees.
    gas_cost_df : DataFrame, optional
        Time-varying gas costs.
    arb_fees_df : DataFrame, optional
        Time-varying arb fees.
    lp_supply_df : DataFrame, optional
        Time-varying LP supply changes.
    do_test_period : bool, optional
        Forwarded to the data/fee loaders only; no separate test-period
        result is produced (default False).
    low_data_mode : bool, optional
        Accepted for signature parity; currently has no effect.

    Returns
    -------
    dict
        A single dict with keys ``final_reserves``, ``final_value``,
        ``value``, ``prices``, ``reserves``, ``weights``,
        ``coarse_weight_changes``, ``data_dict`` and ``unix_values``.

    Raises
    ------
    NotImplementedError
        If ``training_data_kind`` is ``"mc"``.
    KeyError
        If ``params`` lacks ``"initial_reserves"``.
    """
    from quantammsim.pools.G3M.quantamm.weight_calculations.fine_weights import (
        _jax_calc_coarse_weights,
        _jax_fine_weights_from_actual_starts_and_diffs,
    )
    from quantammsim.pools.G3M.quantamm.quantamm_reserves import (
        _jax_calc_quantAMM_reserves_with_dynamic_inputs,
    )

    # A None sentinel replaces the former shared mutable ``{}`` default.
    if params is None:
        params = {}

    # Set default values for run_fingerprint and its optimisation_settings
    recursive_default_set(run_fingerprint, run_fingerprint_defaults)

    optimisation_settings = run_fingerprint["optimisation_settings"]
    training_data_kind = optimisation_settings["training_data_kind"]
    chunk_period = run_fingerprint["chunk_period"]
    arb_frequency = run_fingerprint["arb_frequency"]
    rule = run_fingerprint["rule"]

    unique_tokens = get_unique_tokens(run_fingerprint)
    n_assets = len(run_fingerprint["tokens"])
    all_sig_variations = get_sig_variations(n_assets)
    max_memory_days = run_fingerprint["max_memory_days"]

    # Resets NumPy's global RNG state (kept from the original behaviour).
    np.random.seed(0)

    dynamic_inputs_dict = get_trades_and_fees(
        run_fingerprint,
        raw_trades,
        fees_df,
        gas_cost_df,
        arb_fees_df,
        lp_supply_df,
        do_test_period=do_test_period,
    )

    if price_data is None and verbose:
        print("loading data")
    data_dict = get_data_dict(
        unique_tokens,
        run_fingerprint,
        data_kind=training_data_kind,
        root=root,
        max_memory_days=max_memory_days,
        start_date_string=run_fingerprint["startDateString"],
        end_time_string=run_fingerprint["endDateString"],
        start_time_test_string=run_fingerprint["endDateString"],
        end_time_test_string=run_fingerprint["endTestDateString"],
        max_mc_version=optimisation_settings["max_mc_version"],
        price_data=price_data,
        do_test_period=do_test_period,
    )
    max_memory_days = data_dict["max_memory_days"]
    if verbose:
        print("max_memory_days: ", max_memory_days)

    if training_data_kind == "mc":
        # TODO: Handle MC data for post-training analysis
        raise NotImplementedError

    # The pool object itself is not used below; the call is retained so that
    # whatever create_pool does for this rule still happens as before.
    # NOTE(review): confirm whether this call can be dropped.
    create_pool(rule)

    static_dict = create_static_dict(
        run_fingerprint,
        bout_length=data_dict["bout_length"],
        all_sig_variations=all_sig_variations,
        overrides={
            "n_assets": n_assets,
            "training_data_kind": training_data_kind,
            # Function arguments take precedence over the fingerprint values
            "fees": fees if fees is not None else run_fingerprint["fees"],
            "arb_fees": (
                arb_fees if arb_fees is not None else run_fingerprint["arb_fees"]
            ),
            "gas_cost": (
                gas_cost if gas_cost is not None else run_fingerprint["gas_cost"]
            ),
            "do_trades": (
                False if raw_trades is None else run_fingerprint["do_trades"]
            ),
            # Include date strings for run-time use
            "startDateString": run_fingerprint["startDateString"],
            "endDateString": run_fingerprint["endDateString"],
            "endTestDateString": run_fingerprint["endTestDateString"],
        },
    ).copy()
    minimum_weight = static_dict.get("minimum_weight")
    bout_length = data_dict["bout_length"]

    # Restrict the supplied coarse weights to the loaded data window.
    coarse_weights = filter_coarse_weights_by_data_indices(coarse_weights, data_dict)
    coarse_rows = coarse_weights["weights"]
    initial_weights = coarse_rows[0]

    # Repeat the last row so that diff() yields one change per coarse step,
    # the final change being zero.
    coarse_weights_padded = jnp.vstack([coarse_rows, coarse_rows[-1]])
    coarse_weight_changes = jnp.diff(coarse_weights_padded, axis=0)

    actual_starts, scaled_diffs, _target_weights = _jax_calc_coarse_weights(
        coarse_weight_changes,
        initial_weights,
        minimum_weight,
        params,
        run_fingerprint["max_memory_days"],
        chunk_period,
        chunk_period,
        1.0,
        False,
    )
    weights = _jax_fine_weights_from_actual_starts_and_diffs(
        actual_starts,
        scaled_diffs,
        initial_weights,
        interpol_num=chunk_period + 1,
        num=chunk_period + 1,
        maximum_change=1.0,
        method="linear",
    )

    # Undo the padding: drop the trailing (chunk_period - 1) fine steps.
    # The former slice ``weights[: -chunk_period + 1]`` degenerated to
    # ``weights[:0]`` (an empty array) for chunk_period == 1, where nothing
    # should be removed.
    trailing_pad = chunk_period - 1
    if trailing_pad > 0:
        weights = weights[:-trailing_pad]

    # Fine weights sampled at the coarse timesteps should reproduce the input.
    coarse_timestep_weights = weights[::chunk_period]
    weights = weights[:-1]
    weights_match = jnp.allclose(coarse_timestep_weights, coarse_rows, rtol=1e-10)
    if verbose and not bool(weights_match):
        print("warning: interpolated weights do not match provided coarse weights")

    start_index = data_dict["start_idx"]
    end_index = data_dict["end_idx"] - 1
    local_prices = data_dict["prices"][start_index:end_index]
    local_unix_values = data_dict["unix_values"][start_index:end_index]

    fees_array = dynamic_inputs_dict.get("fees_array")
    arb_thresh_array = dynamic_inputs_dict.get("gas_cost_array")
    arb_fees_array = dynamic_inputs_dict.get("arb_fees_array")
    # NOTE(review): do_run_on_historic_data reads "train_period_trades";
    # confirm that "trades" is a key get_trades_and_fees actually provides.
    trade_array = dynamic_inputs_dict.get("trades")
    lp_supply_array = dynamic_inputs_dict.get("lp_supply_array")

    # Fall back to the scalar settings when no time-varying input exists.
    if fees_array is None:
        fees_array = jnp.array([static_dict["fees"]])
    if arb_thresh_array is None:
        arb_thresh_array = jnp.array([static_dict["gas_cost"]])
    if arb_fees_array is None:
        arb_fees_array = jnp.array([static_dict["arb_fees"]])

    initial_reserves = params["initial_reserves"]

    # Any of the dynamic inputs may be a singleton, in which case it is
    # repeated for the length of the bout. Determine that length first.
    max_len = bout_length - 1
    if arb_frequency != 1:
        max_len = max_len // arb_frequency

    fees_array = fees_array[:max_len]
    arb_thresh_array = arb_thresh_array[:max_len]
    # NOTE(review): the arbitrage threshold (gas cost) is forced to zero here,
    # so gas_cost / gas_cost_df have no effect on this simulation - confirm
    # this is intentional.
    arb_thresh_array = arb_thresh_array * 0.0
    arb_fees_array = arb_fees_array[:max_len]
    if lp_supply_array is not None:
        lp_supply_array = lp_supply_array[:max_len]
    if trade_array is not None:
        trade_array = trade_array[:max_len]

    # Broadcast to the common leading dimension: singletons are repeated,
    # arrays that already have max_len rows are left unchanged.
    fees_array_broadcast = jnp.broadcast_to(
        fees_array, (max_len,) + fees_array.shape[1:]
    )
    arb_thresh_array_broadcast = jnp.broadcast_to(
        arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:]
    )
    arb_fees_array_broadcast = jnp.broadcast_to(
        arb_fees_array, (max_len,) + arb_fees_array.shape[1:]
    )
    # Without LP supply data, use a constant supply of 1.0
    if lp_supply_array is None:
        lp_supply_array = jnp.array(1.0)
    lp_supply_array_broadcast = jnp.broadcast_to(
        lp_supply_array, (max_len,) + lp_supply_array.shape[1:]
    )

    # If we are doing trades, the trades array must match the other arrays.
    if run_fingerprint["do_trades"]:
        assert trade_array.shape[0] == max_len

    # NOTE(review): the eighth positional argument is a literal None while
    # ``trade_array`` is never forwarded - verify against the signature of
    # _jax_calc_quantAMM_reserves_with_dynamic_inputs whether trades should
    # be passed here.
    reserves = _jax_calc_quantAMM_reserves_with_dynamic_inputs(
        initial_reserves,
        weights,
        local_prices,
        fees_array_broadcast,
        arb_thresh_array_broadcast,
        arb_fees_array_broadcast,
        jnp.array(static_dict["all_sig_variations"]),
        None,
        run_fingerprint["do_trades"],
        run_fingerprint["do_arb"],
        run_fingerprint["noise_trader_ratio"],
        lp_supply_array_broadcast,
    )

    # Pool value at each step: reserves priced at the local prices, summed
    # over the last axis.
    value_over_time = jnp.sum(jnp.multiply(reserves, local_prices), axis=-1)

    return {
        "final_reserves": reserves[-1],
        "final_value": (reserves[-1] * local_prices[-1]).sum(),
        "value": value_over_time,
        "prices": local_prices,
        "reserves": reserves,
        "weights": weights,
        "coarse_weight_changes": coarse_weight_changes,
        "data_dict": data_dict,
        "unix_values": local_unix_values,
    }