Training Pipeline
This guide explains how to train QuantAMM pool strategies using quantammsim’s gradient-based optimization.
Overview
The training pipeline uses JAX for automatic differentiation to optimize strategy parameters. The main entry point is train_on_historic_data:
from quantammsim.runners.jax_runners import train_on_historic_data
train_on_historic_data(
run_fingerprint=run_fingerprint,
iterations_per_print=10,
verbose=True
)
Basic Training Setup
1. Configure the Run Fingerprint
run_fingerprint = {
# Data settings
"startDateString": "2023-01-01 00:00:00",
"endDateString": "2023-12-01 00:00:00",
"endTestDateString": "2024-01-01 00:00:00", # Optional test period
"tokens": ["BTC", "ETH"],
# Strategy
"rule": "momentum",
# Pool settings
"initial_pool_value": 1000000.0,
"fees": 0.003, # 30 bps
"do_arb": True,
# Optimization settings
"optimisation_settings": {
"method": "gradient_descent",
"optimiser": "adam",
"base_lr": 0.1,
"n_iterations": 1000,
"batch_size": 8,
"n_parameter_sets": 4,
},
# Objective (daily_log_sharpe is the default and recommended)
"return_val": "daily_log_sharpe",
}
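Before launching a long run, it can help to check that the date strings are ordered as intended. The helper below is a hypothetical sketch: check_fingerprint_dates is not part of quantammsim. It assumes the "%Y-%m-%d %H:%M:%S" format shown above, and that the optional test period runs from endDateString to endTestDateString.

```python
from datetime import datetime

DATE_FORMAT = "%Y-%m-%d %H:%M:%S"  # assumed from the examples in this guide


def check_fingerprint_dates(fingerprint):
    """Hypothetical helper: return (train_days, test_days) or raise ValueError."""
    start = datetime.strptime(fingerprint["startDateString"], DATE_FORMAT)
    end = datetime.strptime(fingerprint["endDateString"], DATE_FORMAT)
    if end <= start:
        raise ValueError("endDateString must be after startDateString")
    train_days = (end - start).days

    # The test period is optional, so only check it when present.
    test_days = 0
    if fingerprint.get("endTestDateString"):
        end_test = datetime.strptime(fingerprint["endTestDateString"], DATE_FORMAT)
        if end_test <= end:
            raise ValueError("endTestDateString must be after endDateString")
        test_days = (end_test - end).days
    return train_days, test_days


example = {
    "startDateString": "2023-01-01 00:00:00",
    "endDateString": "2023-12-01 00:00:00",
    "endTestDateString": "2024-01-01 00:00:00",
}
train_days, test_days = check_fingerprint_dates(example)
print(train_days, test_days)
```

Running a check like this first surfaces a swapped or mistyped date before any training time is spent.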
2. Run Training
from quantammsim.runners.jax_runners import train_on_historic_data
train_on_historic_data(
run_fingerprint=run_fingerprint,
verbose=True
)
Optimizer Configuration
Available Optimizers
# Adam (recommended)
run_fingerprint["optimisation_settings"]["optimiser"] = "adam"
# AdamW (with weight decay)
run_fingerprint["optimisation_settings"]["optimiser"] = "adamw"
# SGD
run_fingerprint["optimisation_settings"]["optimiser"] = "sgd"
Learning Rate Schedules
Constant Learning Rate:
run_fingerprint["optimisation_settings"].update({
"lr_schedule_type": "constant",
"base_lr": 0.1,
})
Warmup with Cosine Decay:
run_fingerprint["optimisation_settings"].update({
"lr_schedule_type": "warmup_cosine",
"base_lr": 0.1,
"warmup_steps": 100,
"min_lr": 1e-6,
})
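To build intuition for how these settings interact, here is a minimal pure-Python sketch of one common warmup-plus-cosine formulation. It is an assumption, not quantammsim's implementation: the function name, the linear ramp from zero, and the use of n_iterations as the end of the decay are all illustrative choices.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, n_iterations, min_lr):
    """Assumed schedule: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear ramp from 0 at step 0 up to base_lr at step warmup_steps.
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(1, n_iterations - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine


for step in (0, 50, 100, 550, 1000):
    print(step, warmup_cosine_lr(step, 0.1, 100, 1000, 1e-6))
```

Under these assumptions the learning rate peaks at base_lr once warmup ends and reaches min_lr at the final iteration.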
Plateau-Based Decay:
run_fingerprint["optimisation_settings"].update({
"use_plateau_decay": True,
"decay_lr_plateau": 100, # Iterations without improvement
"decay_lr_ratio": 0.8, # Multiply LR by this factor
})
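The plateau rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions rather than the library's code: the objective is treated as a quantity to maximise, "improvement" means strictly exceeding the best value so far, and the patience counter restarts after each decay.

```python
def plateau_decay(objective_history, base_lr, decay_lr_plateau, decay_lr_ratio):
    """Assumed rule: shrink the LR after each run of stale iterations."""
    lr = base_lr
    best = float("-inf")
    stale = 0  # iterations since the best value last improved
    lrs = []
    for value in objective_history:
        if value > best:
            best, stale = value, 0
        else:
            stale += 1
            if stale >= decay_lr_plateau:
                lr *= decay_lr_ratio
                stale = 0  # restart the patience counter after a decay
        lrs.append(lr)
    return lrs


# Toy history with a patience of 3 instead of 100 so the effect is visible.
history = [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3]
lrs = plateau_decay(history, base_lr=0.1, decay_lr_plateau=3, decay_lr_ratio=0.8)
print(lrs)
```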
Gradient Clipping
run_fingerprint["optimisation_settings"].update({
"use_gradient_clipping": True,
"clip_norm": 10.0,
})
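As a rough illustration of what a clip_norm of 10.0 does, the sketch below rescales a gradient vector so that its joint L2 norm does not exceed the threshold. That clip_norm refers to the global norm across all parameters is an assumption here, and the function is a stand-in written for this guide, not a quantammsim call.

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Rescale all gradients so their joint L2 norm is at most clip_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= clip_norm:
        return list(grads)  # already within the limit, leave unchanged
    scale = clip_norm / global_norm
    return [g * scale for g in grads]


small = clip_by_global_norm([3.0, 4.0], clip_norm=10.0)    # norm 5, unchanged
large = clip_by_global_norm([30.0, 40.0], clip_norm=10.0)  # norm 50, rescaled
print(small, large)
```

Rescaling the whole vector preserves the gradient's direction, which is why norm clipping is usually preferred over clipping each component independently.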
Batch Training
Training computes gradients over batches of sampled time periods, for several parameter sets in parallel:
run_fingerprint["optimisation_settings"].update({
"batch_size": 8, # Number of time periods per batch
"n_parameter_sets": 4, # Parallel parameter sets to optimize
"sample_method": "uniform", # How to sample training periods
})
The bout_offset parameter sets the window (in minutes) within which training start points may vary, which adds variety to the training data:
# Train on different starting points within this offset window
run_fingerprint["bout_offset"] = 24 * 60 * 7 # 1 week in minutes
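To make the offset window concrete, the sketch below draws one start offset per batch element, uniformly at minute resolution. The helper sample_start_offsets is hypothetical, and how quantammsim actually draws its start points may differ; only the one-week value of bout_offset is taken from the snippet above.

```python
import random

MINUTES_PER_DAY = 24 * 60


def sample_start_offsets(bout_offset, batch_size, seed=0):
    """Hypothetical helper: draw one start offset (in minutes) per batch item."""
    rng = random.Random(seed)
    # randrange excludes the upper bound, so offsets lie in [0, bout_offset).
    return [rng.randrange(bout_offset) for _ in range(batch_size)]


bout_offset = MINUTES_PER_DAY * 7  # 1 week in minutes, as above
offsets = sample_start_offsets(bout_offset, batch_size=8)
print(bout_offset, [round(o / MINUTES_PER_DAY, 2) for o in offsets])
```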
Initial Parameters
Set starting values for optimization:
run_fingerprint.update({
"initial_memory_length": 10.0, # EWMA memory in days
"initial_k_per_day": 20, # Trading intensity
"initial_weights_logits": 1.0, # Starting weight distribution
"initial_log_amplitude": 0.0, # Signal amplitude
"initial_raw_width": 0.0, # Channel width
"initial_raw_exponents": 0.0, # Power exponents
"initial_pre_exp_scaling": 0.5, # Pre-exponential scaling
})
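The name initial_weights_logits suggests that pool weights are obtained from unconstrained logits via a softmax. Assuming that is the case, and assuming a scalar value is applied to every token, equal logits give equal starting weights, as this small sketch shows:

```python
import math


def softmax(logits):
    """Map unconstrained logits to positive weights that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Assumption: a scalar initial_weights_logits is broadcast to every token.
tokens = ["BTC", "ETH"]
initial_weights_logits = 1.0
weights = softmax([initial_weights_logits] * len(tokens))
print(dict(zip(tokens, weights)))
```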
Training Objectives
Set the objective function:
# Daily log-return Sharpe (default, recommended)
run_fingerprint["return_val"] = "daily_log_sharpe"
# Annualised Sharpe ratio
run_fingerprint["return_val"] = "sharpe"
# Maximize total return
run_fingerprint["return_val"] = "returns"
# Maximize return over holding initial portfolio
run_fingerprint["return_val"] = "returns_over_hodl"
# Maximize Sortino ratio
run_fingerprint["return_val"] = "sortino"
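For orientation, here is a pure-Python sketch of how a Sharpe ratio on daily log returns and a total return could be computed from a series of daily pool values. The definitions are assumptions made for illustration (zero risk-free rate, sample standard deviation, annualisation by the square root of 365); the Metrics Reference remains the authority on what each return_val computes.

```python
import math
import statistics


def daily_log_returns(daily_values):
    """Log return between each pair of consecutive daily pool values."""
    return [math.log(b / a) for a, b in zip(daily_values, daily_values[1:])]


def sharpe_of_daily_log_returns(daily_values, periods_per_year=365):
    """Assumed definition: annualised mean/std of daily log returns."""
    returns = daily_log_returns(daily_values)
    mean = statistics.fmean(returns)
    std = statistics.stdev(returns)  # sample standard deviation
    return math.sqrt(periods_per_year) * mean / std


def total_return(daily_values):
    """Final value over initial value, minus one."""
    return daily_values[-1] / daily_values[0] - 1.0


values = [1_000_000.0, 1_010_000.0, 1_005_000.0, 1_020_000.0, 1_030_000.0]
print(sharpe_of_daily_log_returns(values), total_return(values))
```

Log returns add across days, so their sum equals the log of the total growth factor, which is one reason they are convenient inside a differentiable objective.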
Advanced: Hessian-Based Training
For second-order optimization (experimental):
run_fingerprint["optimisation_settings"]["train_on_hessian_trace"] = True
This incorporates Hessian trace information into training, at a higher computational cost than using first-order gradients alone.
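To show what a Hessian trace is and why it costs more than a gradient, the JAX sketch below computes it exactly for a toy quadratic. How quantammsim combines the trace with its objective is not covered here; toy_loss and hessian_trace are illustrative names, not library functions.

```python
import jax
import jax.numpy as jnp


def toy_loss(params):
    """Toy quadratic whose Hessian is diag(2, 4, 6)."""
    curvature = jnp.array([1.0, 2.0, 3.0])
    return jnp.sum(curvature * params**2)


def hessian_trace(loss_fn, params):
    """Exact trace of the Hessian; cost grows with the number of parameters."""
    return jnp.trace(jax.hessian(loss_fn)(params))


params = jnp.array([0.5, -1.0, 2.0])
trace = hessian_trace(toy_loss, params)
print(float(trace))
```

Forming the full Hessian requires differentiating the gradient with respect to every parameter, which is where the extra cost comes from.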
Backpropagation Module
The training pipeline uses the quantammsim.training.backpropagation module internally. Key functions:
Objective Factories:
batched_objective_factory - Creates a batched loss function
batched_objective_with_hessian_factory - Includes Hessian computation
Update Factories:
update_factory - Basic gradient update
update_factory_with_optax - Uses Optax optimizers (Adam, SGD, etc.)
update_with_hessian_factory_with_optax - Hessian-aware updates
Optimizer Creation:
from quantammsim.training.backpropagation import create_optimizer_chain
# Creates an Optax optimizer chain based on run_fingerprint settings
optimizer = create_optimizer_chain(run_fingerprint)
Straight-Through Estimators
For improved gradient flow through clipping operations:
run_fingerprint.update({
"ste_max_change": True, # STE for weight change clipping
"ste_min_max_weight": True, # STE for min/max weight bounds
})
These let gradients pass through clipping operations, whose gradient is otherwise zero wherever the clip is active.
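The general straight-through pattern can be sketched in JAX as follows. This is the standard stop_gradient construction, shown on a scalar for clarity; it is not taken from quantammsim's source, and the function names are illustrative.

```python
import jax
import jax.numpy as jnp


def hard_clip(x, lo, hi):
    """Ordinary clip: the gradient is zero wherever the clip is active."""
    return jnp.clip(x, lo, hi)


def ste_clip(x, lo, hi):
    """Straight-through clip: clipped forward value, identity gradient."""
    # stop_gradient hides the correction term from autodiff, so the
    # backward pass sees only the leading x (derivative 1).
    return x + jax.lax.stop_gradient(jnp.clip(x, lo, hi) - x)


x = 2.0  # outside the [0, 1] bounds, so the clip is active
forward = ste_clip(x, 0.0, 1.0)
grad_hard = jax.grad(hard_clip)(x, 0.0, 1.0)
grad_ste = jax.grad(ste_clip)(x, 0.0, 1.0)
print(float(forward), float(grad_hard), float(grad_ste))
```

The forward pass still respects the bounds, while the backward pass behaves as if no clipping had occurred, so parameters pinned at a bound keep receiving a training signal.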
Monitoring Training
Training progress is printed at intervals:
train_on_historic_data(
run_fingerprint=run_fingerprint,
iterations_per_print=10, # Print every 10 iterations
verbose=True
)
Output includes:
Current iteration
Training objective value
Learning rate
Best parameters found
Saving and Loading
Training state is saved automatically. To resume from a saved run, point run_location at it and set force_init=False:
train_on_historic_data(
run_fingerprint=run_fingerprint,
run_location="path/to/saved/run",
force_init=False # Don't reinitialize, load existing state
)
Example: Complete Training Script
from quantammsim.runners.jax_runners import train_on_historic_data
run_fingerprint = {
# Data
"startDateString": "2023-01-01 00:00:00",
"endDateString": "2023-10-01 00:00:00",
"endTestDateString": "2024-01-01 00:00:00",
"tokens": ["BTC", "ETH", "SOL"],
# Strategy
"rule": "momentum",
"initial_pool_value": 1000000.0,
"fees": 0.003,
"do_arb": True,
"arb_quality": 1.0,
# Weight calculation
"chunk_period": 1440, # Daily
"maximum_change": 0.001,
# Initial params
"initial_memory_length": 10.0,
"initial_k_per_day": 20,
# Optimization
"optimisation_settings": {
"method": "gradient_descent",
"optimiser": "adam",
"base_lr": 0.1,
"n_iterations": 500,
"batch_size": 8,
"n_parameter_sets": 4,
"lr_schedule_type": "warmup_cosine",
"warmup_steps": 50,
"min_lr": 1e-5,
"use_gradient_clipping": True,
"clip_norm": 10.0,
},
# Objective
"return_val": "daily_log_sharpe",
}
train_on_historic_data(
run_fingerprint=run_fingerprint,
iterations_per_print=10,
verbose=True
)
See Also
Run Fingerprints — Complete run fingerprint reference
Metrics Reference — Available training objectives
Robustness Features — Regularisation techniques (early stopping, SWA, price noise)
Training Pools — Optuna hyperparameter optimization
Walk-Forward Analysis — Walk-forward validation for overfitting detection
Ensemble Training — Ensemble training for implicit regularisation
Hyperparameter Tuning — Meta-optimization of training hyperparameters
train_on_historic_data() — API reference