Core Simulator
Forward Pass
Forward pass simulation pipeline and financial metric calculation.
This module implements the core simulation loop for AMM pool strategies: prices → parameterised weight rule → simulated arbitrage → reserve dynamics → financial metrics.
The forward pass is the innermost computation in the three-level optimization hierarchy: forward pass (per-window) → training loop (gradient descent over windows) → hyperparameter tuner (meta-optimization over training configs). It is JIT-compiled via JAX and fully differentiable, enabling gradient-based optimization of strategy parameters.
Key components:
forward_pass / forward_pass_nograd: Entry points that wire pool dynamics to metric calculation. forward_pass propagates gradients; forward_pass_nograd wraps inputs in stop_gradient for evaluation.
_calculate_return_value: Dispatch registry mapping ~30 metric names to their implementations, from simple returns to risk-adjusted ratios.
Metric helpers (_daily_log_sharpe, _calculate_max_drawdown, etc.): Pure-JAX implementations of financial metrics, designed for differentiability and JIT compatibility.
_apply_price_noise: Multiplicative log-normal noise for data augmentation during training.
Notes
All time-series inputs use minute resolution (1 timestep = 1 minute). Duration parameters
in metric helpers (e.g., duration=24*60) are in minutes. Annualization assumes 365
calendar days.
The default training metric is daily_log_sharpe (not sharpe). This uses log returns
sampled at daily frequency, which is more numerically stable and better aligned with
standard financial practice than minute-frequency arithmetic Sharpe.
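As a rough illustration of the note above, the sketch below computes a daily-sampled log-return Sharpe ratio from a minute-resolution value series. It is not the library's _daily_log_sharpe: the name daily_log_sharpe_sketch, the zero risk-free rate, the population standard deviation, and the sqrt(365) annualization factor are all assumptions made for this example.

```python
import math

MINUTES_PER_DAY = 24 * 60  # 1 timestep = 1 minute


def daily_log_sharpe_sketch(values, duration=MINUTES_PER_DAY, days_per_year=365):
    """Annualized Sharpe ratio of daily log returns (hypothetical helper)."""
    # Sample the minute-resolution series once every `duration` minutes.
    daily = values[::duration]
    # Log returns between consecutive daily samples.
    log_returns = [math.log(b / a) for a, b in zip(daily[:-1], daily[1:])]
    mean = sum(log_returns) / len(log_returns)
    var = sum((r - mean) ** 2 for r in log_returns) / len(log_returns)
    # Scale the per-day ratio to a yearly figure (365 calendar days).
    # Assumes the returns are not all identical, so var > 0.
    return mean / math.sqrt(var) * math.sqrt(days_per_year)


# Four days of minute data: flat within each day, stepping between days.
values = [100.0] * 1440 + [110.0] * 1440 + [105.0] * 1440 + [120.0] * 1440
sharpe = daily_log_sharpe_sketch(values)
```

Sampling at daily frequency means the 5,760 minute observations above reduce to four points and three log returns before any statistics are taken.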
- forward_pass(params, start_index, prices, trades_array=None, fees_array=None, gas_cost_array=None, arb_fees_array=None, pool=None, static_dict={'all_sig_variations': None, 'arb_fees': 0.0, 'arb_frequency': 1, 'bout_length': 1000, 'chunk_period': 60, 'do_trades': False, 'fees': 0.0, 'gas_cost': 0.0, 'initial_pool_value': 1000000.0, 'max_memory_days': 365.0, 'maximum_change': 1.0, 'n_assets': 3, 'return_val': 'reserves', 'rule': 'momentum', 'run_type': 'normal', 'training_data_kind': 'historic', 'use_alt_lamb': False, 'use_pre_exp_scaling': True, 'weight_interpolation_method': 'linear', 'weight_interpolation_period': 60})[source]
Simulates a forward pass of a liquidity pool using specified parameters and market data.
This function models the behavior of a liquidity pool over a given period, considering various factors such as fees, gas costs, and arbitrage fees. It calculates reserves and other metrics based on the provided parameters and market prices.
- Parameters:
params (dict) – A dictionary containing the parameters for the simulation, such as initial weights and other configuration settings.
start_index (array-like) – The starting index for the simulation, used to slice the price data.
prices (array-like) – A 2D array of market prices for the assets involved in the simulation.
trades_array (array-like, optional) – An array of trades to be considered in the simulation. Defaults to None.
fees_array (array-like, optional) – An array of fees to be applied during the simulation. Defaults to None.
gas_cost_array (array-like, optional) – An array of gas costs to be considered in the simulation. Defaults to None.
arb_fees_array (array-like, optional) – An array of arbitrage fees to be applied during the simulation. Defaults to None.
pool (object) – An instance of a pool object that provides methods to calculate reserves based on the inputs. Must be provided.
static_dict (dict, optional) – A dictionary of static configuration values for the simulation, such as bout length, number of assets, and return value type. Defaults to a predefined set of values.
- Returns:
Depending on the return_val specified in static_dict, the function returns different types of results:
“reserves”: A dictionary containing the reserves over time.
“sharpe”: The Sharpe ratio of the pool returns.
“returns”: The total return over the simulation period.
“returns_over_hodl”: The return over a hold strategy.
“greatest_draw_down”: The greatest drawdown during the simulation.
“alpha”: Not implemented.
“value”: The value of the pool over time.
“reserves_and_values”: A dictionary containing final reserves, final value, value over time, prices, and reserves.
- Return type:
- Raises:
ValueError – If the pool is not provided.
NotImplementedError – If the return_val is set to “alpha” or any other unsupported value.
Notes
The function is decorated with jax.jit for performance optimization, with static arguments specified for JIT compilation.
The function handles different cases for fees and trades, adjusting the calculation method accordingly:
If any of fees_array, gas_cost_array, arb_fees_array, or trades_array is provided, it uses pool.calculate_reserves_with_dynamic_inputs.
If any of fees, gas_cost, or arb_fees in static_dict is a nonzero scalar value, it uses pool.calculate_reserves_with_fees.
If all fees and costs are zero and no trades are provided, it uses pool.calculate_reserves_zero_fees.
The function supports different types of return values, allowing for flexible output based on the simulation needs.
The arb_frequency in static_dict can alter the frequency of arbitrage operations, affecting the reserves calculation and the size of the returned arrays.
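The three-way choice of reserve calculation described in these notes can be sketched in plain Python. The helper name choose_reserve_method is hypothetical and returns only the method name as a string; the real function calls the corresponding method on the pool object inside a JIT-compiled function.

```python
def choose_reserve_method(static_dict, trades_array=None, fees_array=None,
                          gas_cost_array=None, arb_fees_array=None):
    """Return the name of the pool method the notes say would be used."""
    # Case 1: any per-timestep array switches to the dynamic-inputs path.
    dynamic_inputs = (fees_array, gas_cost_array, arb_fees_array, trades_array)
    if any(x is not None for x in dynamic_inputs):
        return "calculate_reserves_with_dynamic_inputs"
    # Case 2: any nonzero scalar fee or cost uses the scalar-fees path.
    scalars = (static_dict["fees"], static_dict["gas_cost"], static_dict["arb_fees"])
    if any(s != 0.0 for s in scalars):
        return "calculate_reserves_with_fees"
    # Case 3: everything is zero and no trades were supplied.
    return "calculate_reserves_zero_fees"


zero = {"fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0}
with_fee = {"fees": 0.003, "gas_cost": 0.0, "arb_fees": 0.0}

method_a = choose_reserve_method(zero)
method_b = choose_reserve_method(with_fee)
method_c = choose_reserve_method(zero, fees_array=[0.003, 0.002])
```

Note that the array check comes first, so a supplied fees_array selects the dynamic path even when every scalar in static_dict is zero.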
Examples
>>> forward_pass(params, start_index, prices, pool=my_pool)
{'reserves': array([...])}
- forward_pass_nograd(params, start_index, prices, trades_array=None, fees_array=None, gas_cost_array=None, arb_fees_array=None, pool=None, static_dict={'all_sig_variations': None, 'arb_fees': 0.0, 'arb_frequency': 1, 'bout_length': 1000, 'chunk_period': 60, 'do_trades': False, 'fees': 0.0, 'gas_cost': 0.0, 'initial_pool_value': 1000000.0, 'max_memory_days': 365.0, 'maximum_change': 1.0, 'n_assets': 3, 'return_val': 'reserves', 'rule': 'momentum', 'run_type': 'normal', 'training_data_kind': 'historic', 'use_alt_lamb': False, 'use_pre_exp_scaling': True, 'weight_interpolation_method': 'linear', 'weight_interpolation_period': 60})[source]
Simulates a forward pass of a liquidity pool without gradient tracking using specified parameters and market data.
This function models the behavior of a liquidity pool over a given period, similar to forward_pass, but ensures that no gradients are tracked for the input parameters and data. It is useful for scenarios where gradient computation is not required, such as evaluation or inference.
- Parameters:
params (dict) – A dictionary containing the parameters for the simulation, such as initial weights and other configuration settings.
start_index (array-like) – The starting index for the simulation, used to slice the price data.
prices (array-like) – A 2D array of market prices for the assets involved in the simulation.
trades_array (array-like, optional) – An array of trades to be considered in the simulation. Defaults to None.
fees_array (array-like, optional) – An array of fees to be applied during the simulation. Defaults to None.
gas_cost_array (array-like, optional) – An array of gas costs to be considered in the simulation. Defaults to None.
arb_fees_array (array-like, optional) – An array of arbitrage fees to be applied during the simulation. Defaults to None.
pool (object) – An instance of a pool object that provides methods to calculate reserves based on the inputs. Must be provided.
static_dict (dict, optional) – A dictionary of static configuration values for the simulation, such as bout length, number of assets, and return value type. Defaults to a predefined set of values.
- Returns:
Depending on the return_val specified in static_dict, the function returns different types of results:
“reserves”: A dictionary containing the reserves over time.
“sharpe”: The Sharpe ratio of the pool returns.
“returns”: The total return over the simulation period.
“returns_over_hodl”: The return over a hold strategy.
“greatest_draw_down”: The greatest drawdown during the simulation.
“alpha”: Not implemented.
“value”: The value of the pool over time.
“reserves_and_values”: A dictionary containing final reserves, final value, value over time, prices, and reserves.
- Return type:
- Raises:
ValueError – If the pool is not provided.
NotImplementedError – If the return_val is set to “alpha” or any other unsupported value.
Notes
The function is decorated with jax.jit for performance optimization, with static arguments specified for JIT compilation.
The function handles different cases for fees and trades, adjusting the calculation method accordingly:
If any of fees_array, gas_cost_array, arb_fees_array, or trades_array is provided, it uses pool.calculate_reserves_with_dynamic_inputs.
If any of fees, gas_cost, or arb_fees in static_dict is a nonzero scalar value, it uses pool.calculate_reserves_with_fees.
If all fees and costs are zero and no trades are provided, it uses pool.calculate_reserves_zero_fees.
The function supports different types of return values, allowing for flexible output based on the simulation needs.
The arb_frequency in static_dict can alter the frequency of arbitrage operations, affecting the reserves calculation and the size of the returned arrays.
The function uses jax.lax.stop_gradient to ensure that no gradients are tracked for the input parameters and data.
Examples
>>> forward_pass_nograd(params, start_index, prices, pool=my_pool)
{'reserves': array([...])}
Parameter Utilities
Parameter utilities for strategy parameterization, serialization, and loading.
This module handles the full lifecycle of strategy parameters:
Initialization: init_params / init_params_singleton create parameter dicts from human-readable initial values (memory days, k per day) by converting to the internal reparameterized form (logit_lamb, log_k, etc.).
Reparameterization: Functions like calc_lamb, calc_alt_lamb, squareplus, and their inverses convert between human-interpretable values and the unconstrained spaces used for gradient-based optimization.
Serialization: NumpyEncoder, dict_of_jnp_to_np, dict_of_jnp_to_list, dict_of_np_to_jnp handle conversion between JAX arrays, NumPy arrays, and JSON-serializable Python types.
Loading: load_or_init, load, load_manually, retrieve_best load saved training checkpoints with various selection strategies (best train, best test, best-train-above-test-threshold, etc.).
Grid generation: create_product_of_linspaces, generate_params_combinations produce parameter grids for heatmap evaluations.
The key reparameterizations are:
lambda (λ): EWMA decay factor in (0, 1), stored as logit_lamb = log(λ/(1-λ)). Converted to/from human-readable memory_days via cubic-root inversion.
k: Weight update aggressiveness, stored as log_k = log2(k / memory_days). This decouples scale from memory length.
squareplus: Smooth, non-negative activation (x + sqrt(x² + 4)) / 2, an algebraic (non-transcendental) replacement for softplus. Used for exponent params.
Notes
The memory_days ↔ lambda conversion involves solving a cubic equation analytically.
Both NumPy (memory_days_to_lamb) and JAX (jax_memory_days_to_lamb) versions
exist; the NumPy version includes safe division guards for zero memory days, while the
JAX version relies on jnp.where for the zero case.
- squareplus(x)[source]
Algebraic (non-transcendental) replacement for softplus.
Computes (x + sqrt(x² + 4)) / 2, which maps R → R⁺ smoothly. Unlike softplus (log(1 + exp(x))), squareplus avoids transcendental functions and is thus cheaper to differentiate through and more JIT-friendly.
- Parameters:
x (jnp.ndarray or float) – Input value(s).
- Returns:
Non-negative output(s), always > 0.
- Return type:
jnp.ndarray or float
References
Barron, J.T. (2021). “Squareplus: A Softplus-Like Algebraic Rectifier.” arXiv:2112.11687.
See also
inverse_squareplus – Inverse mapping R⁺ → R.
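A minimal scalar sketch of the formula above, written with the standard library rather than JAX so it runs anywhere. The names squareplus_sketch and softplus are local to this example; softplus is included only to show the transcendental function that squareplus replaces.

```python
import math


def squareplus_sketch(x):
    """(x + sqrt(x^2 + 4)) / 2, strictly positive for every real x."""
    return (x + math.sqrt(x * x + 4.0)) / 2.0


def softplus(x):
    """log(1 + exp(x)), shown only for comparison."""
    return math.log1p(math.exp(x))


# Both functions are positive everywhere and grow like x for large x;
# squareplus needs only a square root, softplus needs exp and log.
for x in (-5.0, 0.0, 5.0):
    print(x, squareplus_sketch(x), softplus(x))
```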
- check_run_fingerprint(run_fingerprint)[source]
Check that the run fingerprint is not malformed.
- Parameters:
run_fingerprint (dict) – The run fingerprint to validate.
- Return type:
None
- Raises:
AssertionError – If weight_interpolation_period is greater than chunk_period.
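The documented check can be sketched as a single assertion. This covers only the condition listed under Raises; the function name check_run_fingerprint_sketch is hypothetical, and the real validator may check more than this.

```python
def check_run_fingerprint_sketch(run_fingerprint):
    """Assert that weights finish interpolating within one chunk (sketch)."""
    assert (
        run_fingerprint["weight_interpolation_period"]
        <= run_fingerprint["chunk_period"]
    ), "weight_interpolation_period must not exceed chunk_period"


# Passes silently: interpolation period equals the chunk period.
check_run_fingerprint_sketch({"weight_interpolation_period": 60, "chunk_period": 60})

# Fails: interpolation period is longer than the chunk period.
try:
    check_run_fingerprint_sketch({"weight_interpolation_period": 120, "chunk_period": 60})
    rejected = False
except AssertionError:
    rejected = True
```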
- default_set_or_get(dictionary, key, default, augment=True)[source]
Retrieves the value for a given key from a dictionary. If the key does not exist, it sets the key to a default value and returns the default value.
- Parameters:
dictionary (dict) – The dictionary to search for the key.
key (str) – The key to look up in the dictionary.
default (Any) – The default value to set and return if the key is not found.
augment (bool, optional) – If True, the default value is added to the dictionary if the key is not found. Default is True.
- Returns:
The value associated with the key if it exists, otherwise the default value.
- Return type:
Any
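The behaviour described above, including the effect of augment, can be sketched as follows. This is an illustrative reimplementation under the stated semantics, not the library source; default_set_or_get_sketch is a name chosen for this example.

```python
def default_set_or_get_sketch(dictionary, key, default, augment=True):
    """Return dictionary[key], falling back to (and optionally storing) default."""
    if key in dictionary:
        return dictionary[key]
    if augment:
        # Record the default so later lookups see the same value.
        dictionary[key] = default
    return default


cfg = {"chunk_period": 60}
a = default_set_or_get_sketch(cfg, "chunk_period", 1440)         # existing value wins
b = default_set_or_get_sketch(cfg, "bout_length", 1000)          # stored and returned
c = default_set_or_get_sketch(cfg, "fees", 0.0, augment=False)   # returned, not stored
```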
- default_set(dictionary, key, default)[source]
Sets a default value for a given key in a dictionary if the key does not already exist.
- recursive_default_set(target_dict, default_dict)[source]
Recursively sets default values in a target dictionary based on a default dictionary.
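One way to picture the recursive fill is the sketch below. It assumes in-place mutation and that existing values in the target always take precedence; both are assumptions of this example, as is the name recursive_default_set_sketch.

```python
import copy


def recursive_default_set_sketch(target_dict, default_dict):
    """Fill missing keys in target_dict from default_dict, descending into sub-dicts."""
    for key, default in default_dict.items():
        if key not in target_dict:
            # Copy so later edits to the target never alter the defaults.
            target_dict[key] = copy.deepcopy(default)
        elif isinstance(default, dict) and isinstance(target_dict[key], dict):
            recursive_default_set_sketch(target_dict[key], default)
        # Otherwise the key is already set; leave the existing value alone.


defaults = {"chunk_period": 60, "optimisation": {"lr": 0.01, "n_iterations": 100}}
run = {"optimisation": {"lr": 0.1}}
recursive_default_set_sketch(run, defaults)
```

After the call, run keeps its own lr of 0.1 and gains chunk_period and n_iterations from the defaults.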
- class NumpyEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source]
Bases: JSONEncoder
JSON encoder that handles NumPy scalar and array types.
Extends json.JSONEncoder to serialize np.integer as int, np.floating as float, and np.ndarray as nested lists. Used when saving training checkpoints and run fingerprints to JSON.
Examples
>>> import json, numpy as np
>>> json.dumps({"val": np.float64(0.5)}, cls=NumpyEncoder)
'{"val": 0.5}'
- dict_of_jnp_to_np(dictionary)[source]
Convert dictionary values from jax numpy arrays to numpy arrays.
- dict_of_np_to_jnp(dictionary)[source]
Convert dictionary values from numpy arrays to jax numpy arrays.
- lamb_to_memory(lamb)[source]
Convert EWMA decay factor lambda to the effective memory length (unitless).
The EWMA weighting kernel w_t = lambda^t * (1 - lambda) has a characteristic memory scale that grows with lambda. This function inverts the cubic relationship used in quantammsim’s parameterisation:
\[\text{memory} = 4 \cdot \sqrt[3]{\frac{6 \lambda}{(1 - \lambda)^3}}\]
To convert to days, use lamb_to_memory_days(), which divides by 2 * chunk_period / 1440.
- Parameters:
lamb (float or jnp.ndarray) – EWMA decay factor in (0, 1).
- Returns:
Unitless memory scale.
- Return type:
float or jnp.ndarray
See also
lamb_to_memory_days – Returns memory in days.
memory_days_to_lamb – Inverse mapping (days -> lambda).
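The formula above can be evaluated directly, and its inverse checked numerically. In the sketch below, lamb_to_memory_sketch follows the stated expression for the unitless memory; memory_to_lamb_bisect is a hypothetical helper that inverts it by bisection, which is a stand-in for, not a copy of, the analytic cubic solution the library uses. The conversion to days is deliberately left out.

```python
def lamb_to_memory_sketch(lamb):
    """memory = 4 * cbrt(6 * lamb / (1 - lamb)**3), for lamb in [0, 1)."""
    return 4.0 * (6.0 * lamb / (1.0 - lamb) ** 3) ** (1.0 / 3.0)


def memory_to_lamb_bisect(memory, tol=1e-12):
    """Invert lamb_to_memory_sketch numerically (bisection, not the analytic cubic)."""
    lo, hi = 0.0, 1.0 - 1e-12
    # The forward map is strictly increasing in lamb, so bisection converges.
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if lamb_to_memory_sketch(mid) < memory:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


lamb = 0.9
memory = lamb_to_memory_sketch(lamb)       # unitless memory scale
recovered = memory_to_lamb_bisect(memory)  # should be close to 0.9
```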
- jax_memory_days_to_lamb(memory_days, chunk_period=60)[source]
Convert memory days to lambda value using JAX operations.
- memory_days_to_logit_lamb(memory_days, chunk_period=60)[source]
Convert memory days to logit lambda value.
- lamb_to_memory_days(lamb, chunk_period)[source]
Convert EWMA decay factor lambda to effective memory in days.
Applies lamb_to_memory() then rescales by 2 * chunk_period / 1440 to convert from unitless memory to calendar days, accounting for the observation frequency.
- Parameters:
- Returns:
Effective memory in days.
- Return type:
float or jnp.ndarray
See also
lamb_to_memory – Unitless version.
memory_days_to_lamb – Inverse mapping.
lamb_to_memory_days_clipped – Clipped version with max_memory_days bound.
- logistic_func(x)[source]
Standard logistic sigmoid: sigma(x) = exp(x) / (1 + exp(x)).
Maps R -> (0, 1). Used to convert the unconstrained logit_lamb parameter to the EWMA decay factor lambda in (0, 1).
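A scalar sketch of the logit and logistic pair, showing that a decay factor survives the round trip through the unconstrained space. The names logistic_sketch and logit are local to this example, and the two-branch form is a common overflow-safe way to evaluate the sigmoid rather than a claim about the library's implementation.

```python
import math


def logistic_sketch(x):
    """exp(x) / (1 + exp(x)), evaluated without overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def logit(lamb):
    """Inverse of the logistic function on (0, 1)."""
    return math.log(lamb / (1.0 - lamb))


logit_lamb = logit(0.95)            # unconstrained value an optimizer would update
lamb = logistic_sketch(logit_lamb)  # back to a decay factor in (0, 1)
```

Because any real logit_lamb maps into (0, 1), gradient steps on the raw parameter can never push the decay factor out of its valid range.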
- jax_logit_lamb_to_lamb(logit_lamb)[source]
Convert logit lambda to lambda value using JAX operations.
- lamb_to_memory_days_clipped(lamb, chunk_period, max_memory_days)[source]
Convert lambda value to memory days, clipped to a maximum value.
- calc_lamb(update_rule_parameter_dict)[source]
Calculate the lambda value from the given update rule parameter dictionary.
- calc_lamb_from_index(update_rule_parameter_dict, logit_lamb_index)[source]
Calculate the lambda value from the given update rule parameter dictionary and index.
- Parameters:
- Returns:
The calculated lambda value.
- Return type:
- Raises:
KeyError – If “logit_lamb” key is not found in update_rule_parameter_dict.
- calc_alt_lamb(update_rule_parameter_dict)[source]
Calculate the alternative lambda value based on the provided update rule parameters.
- Parameters:
update_rule_parameter_dict (dict) – A dictionary containing the update rule parameters. Must include the keys “logit_lamb” (the logit lambda value) and “logit_delta_lamb” (the logit delta lambda value).
- Returns:
The calculated alternative lambda value.
- Return type:
- Raises:
KeyError – If “logit_lamb” or “logit_delta_lamb” is not found in update_rule_parameter_dict.
- inverse_squareplus(y)[source]
Inverse of the squareplus activation (JAX version).
Given y = squareplus(x), recovers x = (y² - 1) / y. Used to convert from a desired positive parameter value back to the unconstrained raw parameter for initialization.
- Parameters:
y (float or jnp.ndarray) – Positive input value(s). Must be > 0 (domain of inverse squareplus).
- Returns:
Unconstrained value(s) that map to y under squareplus.
- Return type:
jnp.ndarray
See also
squareplus – Forward mapping R → R⁺.
inverse_squareplus_np – NumPy version for non-JAX contexts.
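The inverse follows from squaring 2y - x = sqrt(x² + 4), which gives 4y² - 4xy = 4 and hence x = (y² - 1) / y. The scalar sketch below checks the round trip in both directions; the function names are local to this example and use the standard library rather than JAX.

```python
import math


def squareplus_sketch(x):
    """Forward map R -> R+: (x + sqrt(x^2 + 4)) / 2."""
    return (x + math.sqrt(x * x + 4.0)) / 2.0


def inverse_squareplus_sketch(y):
    """Inverse map R+ -> R: (y^2 - 1) / y, valid for y > 0."""
    return (y * y - 1.0) / y


# Initialization use case: pick the raw value whose effective exponent is 2.0.
raw_exponent = inverse_squareplus_sketch(2.0)
effective_exponent = squareplus_sketch(raw_exponent)

# Round trip starting from unconstrained values.
round_trip_errors = [
    abs(inverse_squareplus_sketch(squareplus_sketch(x)) - x)
    for x in (-3.0, 0.0, 2.5)
]
```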
- inverse_squareplus_np(y)[source]
Inverse of the squareplus activation (NumPy version).
Identical to inverse_squareplus but uses NumPy operations, suitable for use outside JAX-traced contexts (e.g., initialization, post-processing).
- Parameters:
y (float or np.ndarray) – Positive input value(s).
- Returns:
Unconstrained value(s) that map to y under squareplus.
- Return type:
float or np.ndarray
See also
inverse_squareplus – JAX version.
- get_raw_value(value)[source]
Convert a desired parameter value to raw (log2) space.
Many parameters (k, width, amplitude) use 2^raw reparameterization so that the raw parameter can take any real value while the effective value is always positive. This function inverts that: raw = log2(value).
- Parameters:
value (float) – Desired positive parameter value.
- Returns:
Log2 of the input, for use as the raw parameter.
- Return type:
See also
get_log_amplitude – Similar but divides by memory_days first.
- get_log_amplitude(amplitude, memory_days)[source]
Convert desired amplitude to raw log_amplitude parameter.
The effective amplitude is 2^log_amplitude * memory_days, so to achieve a target amplitude: log_amplitude = log2(amplitude / memory_days).
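The two log2 conversions documented for get_raw_value and get_log_amplitude can be checked with a few lines of arithmetic. The sketch functions below restate the formulas from the docstrings using the standard library; the _sketch names and the sample values are illustrative.

```python
import math


def get_raw_value_sketch(value):
    """raw = log2(value), so that 2**raw recovers the positive value."""
    return math.log2(value)


def get_log_amplitude_sketch(amplitude, memory_days):
    """log_amplitude = log2(amplitude / memory_days)."""
    return math.log2(amplitude / memory_days)


# A target width of 8 corresponds to a raw parameter of 3.
raw_width = get_raw_value_sketch(8.0)
effective_width = 2.0 ** raw_width

# A target amplitude of 40 with 10 memory days corresponds to log_amplitude 2.
memory_days = 10.0
log_amplitude = get_log_amplitude_sketch(40.0, memory_days)
effective_amplitude = 2.0 ** log_amplitude * memory_days
```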
- init_params_singleton(initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, log_for_k=True)[source]
Initialize a single parameter set from human-readable initial values.
Converts intuitive values (memory_days, k_per_day, etc.) into the internal reparameterized form (logit_lamb, log_k, etc.) as 1-D JAX arrays of length n_tokens + n_subsidary_rules.
- Parameters:
initial_values_dict (dict) – Human-readable initial values.
Required keys:
'initial_k_per_day': Weight update aggressiveness
'initial_memory_length': EWMA memory in days
Optional keys:
'initial_memory_length_delta': Additional memory for alt lambda
'initial_weights_logits': Starting weight logits
'initial_log_amplitude': Channel amplitude (log2 scale)
'initial_raw_width': Channel width (log2 scale)
'initial_raw_exponents': Power exponents (squareplus space)
'initial_pre_exp_scaling': Pre-exponential scaling (logit space)
n_tokens (int) – Number of assets in the pool.
n_subsidary_rules (int, optional) – Number of subsidiary rules (for composite pools). Default is 0.
chunk_period (int, optional) – Time between price observations in minutes. Default is 60.
log_for_k (bool, optional) – If True, use log_k parameterization; if False, use linear k. Default is True.
- Returns:
Parameter dict with keys: 'log_k' (or 'k'), 'logit_lamb', 'logit_delta_lamb', 'initial_weights_logits', 'log_amplitude', 'raw_width', 'raw_exponents', 'logit_pre_exp_scaling', 'subsidary_params'. All values are 1-D jnp.ndarray.
- Return type:
See also
init_params – Multi-set version with noise injection.
- fill_in_missing_values_from_init_singleton(params, initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, log_for_k=True)[source]
Fill in missing values in parameters from initial values.
- Parameters:
params (dict) – The parameters dictionary to update.
initial_values_dict (dict) – The initial values dictionary.
n_tokens (int) – The number of tokens.
n_subsidary_rules (int, optional) – The number of subsidiary rules. Default is 0.
chunk_period (int, optional) – The chunk period. Default is 60.
log_for_k (bool, optional) – Whether to use log scale for k parameter. Default is True.
- Returns:
The updated parameters dictionary.
- Return type:
- init_params(initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, n_parameter_sets=1, noise='gaussian')[source]
Initialize multiple parameter sets from human-readable initial values.
Creates n_parameter_sets copies of the base parameters. When n_parameter_sets > 1, Gaussian noise is added to all rows except the first (which remains at the exact initial values). This is the legacy ensemble initialization method; for more control, see EnsembleAveragingHook.
- Parameters:
initial_values_dict (dict) – Human-readable initial values (same format as init_params_singleton).
n_tokens (int) – Number of assets in the pool.
n_subsidary_rules (int, optional) – Number of subsidiary rules. Default is 0.
chunk_period (int, optional) – Time between price observations in minutes. Default is 60.
n_parameter_sets (int, optional) – Number of parameter sets (ensemble members). Default is 1.
noise (str, optional) – Noise type for diversification. Only 'gaussian' is supported. Default is 'gaussian'.
- Returns:
Parameter dict with 2-D arrays of shape (n_parameter_sets, n_pool_members) for each parameter key.
- Return type:
See also
init_params_singleton – Single parameter set initialization.
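The ensemble layout described for init_params (first row exact, remaining rows perturbed) can be pictured with plain lists. The helper make_ensemble_sketch, its noise_scale and seed arguments, and the sample values are assumptions of this sketch; the documentation does not state the noise magnitude the library uses.

```python
import random


def make_ensemble_sketch(base_row, n_parameter_sets, noise_scale=0.1, seed=0):
    """Stack n_parameter_sets rows; only rows after the first receive Gaussian noise."""
    rng = random.Random(seed)
    rows = [list(base_row)]  # row 0 stays at the exact initial values
    for _ in range(n_parameter_sets - 1):
        rows.append([v + rng.gauss(0.0, noise_scale) for v in base_row])
    return rows


base_log_k = [4.0, 4.0, 4.0]  # one value per pool member
ensemble = make_ensemble_sketch(base_log_k, n_parameter_sets=4)
# ensemble has shape (4, 3): (n_parameter_sets, n_pool_members)
```

Keeping the first row unperturbed means one ensemble member always reproduces the run that the human-readable initial values describe.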
- fill_in_missing_values_from_init(params, initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, n_parameter_sets=1)[source]
Fill in missing values in parameters from initial values.
- Parameters:
params (dict) – The parameters dictionary to update.
initial_values_dict (dict) – The initial values dictionary.
n_tokens (int) – The number of tokens.
n_subsidary_rules (int, optional) – The number of subsidiary rules. Default is 0.
chunk_period (int, optional) – The chunk period. Default is 60.
n_parameter_sets (int, optional) – The number of parameter sets. Default is 1.
- Returns:
The updated parameters dictionary.
- Return type:
- calc_hessian_from_loaded_params(params, partial_fixed_training_step)[source]
Calculate the Hessian matrix from the loaded parameters.
- Parameters:
params (dict) – A dictionary of parameters.
partial_fixed_training_step (callable) – A function representing the partial fixed training step.
- Returns:
The Hessian matrix calculated from the loaded parameters.
- Return type:
- load_result_array(run_location, key='objective', recalc_hess=False)[source]
Load simulation results from a JSON file and return run fingerprint and results array.
- Parameters:
- Returns:
- A tuple containing:
run_fingerprint (dict) – Configuration details and metadata for the simulation run.
values (list) – Array of values extracted from results based on the specified key.
- Return type:
- get_objective_scalar(obj, metric_key='returns_over_uniform_hodl')[source]
Extract a scalar value from an objective that may be a dict or number.
Use this when you have a single objective value (after retrieve_best has indexed into the parameter sets) and need a float.
- Parameters:
- Returns:
The scalar objective value
- Return type:
Examples
>>> get_objective_scalar(0.1)  # old format
0.1
>>> get_objective_scalar({"return": 0.1, "sharpe": 0.5})  # new format
0.1
- load_manually(run_location, load_method='last', recalc_hess=False, min_test=0.0, return_as_iterables=False, metric_key='returns_over_uniform_hodl', use_continuous_test=True)[source]
Load and process parameter sets from a JSON results file with custom loading methods.
- Parameters:
run_location (str) – Path to the JSON results file.
load_method (str, optional) – Method for selecting parameter sets. Defaults to ‘last’. One of:
‘last’ – Returns the last parameter set
‘best_objective’ – Returns set with highest overall objective
‘best_train_objective’ – Returns set with highest training objective
‘best_test_objective’ – Returns set with highest test objective
‘best_train_min_test_objective’ – Returns set with highest training objective that meets minimum test threshold
recalc_hess (bool, optional) – Whether to recalculate Hessian trace values. Defaults to False.
min_test (float, optional) – Minimum test objective threshold for ‘best_train_min_test_objective’ method. Defaults to 0.0.
metric_key (str, optional) – For new format files with metric dicts, specifies which metric to use. Options include: “return”, “sharpe”, “jax_sharpe”, “returns_over_hodl”, “returns_over_uniform_hodl”, “annualised_returns”, “calmar”, “sterling”, “ulcer”. Ignored for old format files with simple numeric objectives. Defaults to “returns_over_uniform_hodl”.
use_continuous_test (bool, optional) – If True and continuous_test_metrics is available, use it instead of test_objective for test-related load methods. Defaults to True.
- Returns:
Two-element tuple containing the loaded parameters (dict) and the index of the selected parameter set (int).
- Return type:
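The selection strategies listed for load_method can be sketched over two parallel lists of objectives. This is an illustration of the documented intent only: select_index_sketch is hypothetical, the ‘best_objective’ option is omitted because its definition is not spelled out here, and treating the threshold as inclusive (>= min_test) is an assumption.

```python
def select_index_sketch(train, test, load_method="last", min_test=0.0):
    """Pick the index of a saved parameter set from per-checkpoint objectives."""
    indices = range(len(train))
    if load_method == "last":
        return len(train) - 1
    if load_method == "best_train_objective":
        return max(indices, key=lambda i: train[i])
    if load_method == "best_test_objective":
        return max(indices, key=lambda i: test[i])
    if load_method == "best_train_min_test_objective":
        # Keep only checkpoints whose test objective clears the threshold.
        eligible = [i for i in indices if test[i] >= min_test]
        if not eligible:
            raise ValueError("no checkpoint meets the minimum test objective")
        return max(eligible, key=lambda i: train[i])
    raise NotImplementedError(load_method)


train = [0.1, 0.5, 0.3]
test = [0.2, -0.1, 0.05]
picked = {
    method: select_index_sketch(train, test, method)
    for method in ("last", "best_train_objective", "best_test_objective",
                   "best_train_min_test_objective")
}
```

In this example the checkpoint with the best training objective (index 1) has a negative test objective, so the thresholded method falls back to index 2.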
- retrieve_best(data_location, load_method, re_calc_hess, min_alt_obj=0.0, return_as_iterables=False)[source]
Retrieve the best parameters from a training run.
Loads parameters using the specified method and extracts the best parameter set based on the context (index of best performing parameters). Removes training metadata (step, hessian_trace, etc.) from the returned params.
- Parameters:
data_location (str) – Path to the directory containing saved training results.
load_method (str) – Method for loading parameters. Options include:
‘last’: Load the most recent checkpoint
‘best_train_objective’: Load checkpoint with best training objective
‘best_test_objective’: Load checkpoint with best test objective
re_calc_hess (bool) – Whether to recalculate hessian information when loading.
min_alt_obj (float, optional) – Minimum alternative objective threshold. Defaults to 0.0.
return_as_iterables (bool, optional) – If True, returns lists of all loaded params and steps. If False, returns only the first (best) params and step. Defaults to False.
- Returns:
params (dict or list of dict) – Best parameter dictionary (or list if return_as_iterables=True). Training metadata fields are removed.
steps (int or list of int) – Training step(s) at which the parameters were saved.
- load_or_init(run_fingerprint, initial_values_dict, n_tokens, n_subsidary_rules, recalc_hess=False, chunk_period=60, force_init=False, load_method='last', n_parameter_sets=1, results_dir='./results/', partial_fixed_training_step=None)[source]
Load or initialize parameters for the AMM simulator.
- Parameters:
run_fingerprint (str) – The fingerprint of the run.
initial_values_dict (dict) – The initial values dictionary.
n_tokens (int) – The number of tokens.
n_subsidary_rules (int) – The number of subsidiary rules.
recalc_hess (bool, optional) – Whether to recalculate the Hessian. Default is False.
chunk_period (int, optional) – The chunk period. Default is 60.
force_init (bool, optional) – Whether to force initialization. Default is False.
load_method (str, optional) – The method to use for loading. Default is “last”.
n_parameter_sets (int, optional) – The number of parameter sets. Default is 1.
results_dir (str, optional) – The directory for results. Default is “./results/”.
partial_fixed_training_step (callable, optional) – The partial fixed training step. Default is None.
- Returns:
- A tuple containing:
params (dict) – The loaded or initialized parameters.
loaded (bool) – Whether the parameters were loaded (True) or initialized (False).
- Return type:
- load(run_location, initial_values_dict, n_tokens, n_subsidary_rules, chunk_period=60, load_method='last', n_parameter_sets=1)[source]
Load parameters from a file and fill in missing values based on initial values.
- Parameters:
run_location (str) – The location of the file containing the parameters.
initial_values_dict (dict) – A dictionary of initial values.
n_tokens (int) – The number of tokens.
n_subsidary_rules (int) – The number of subsidiary rules.
chunk_period (int, optional) – The chunk period. Default is 60.
load_method ({'last', 'best_objective', 'best_train_objective'}, optional) – The method to use for loading parameters. Default is ‘last’.
n_parameter_sets (int, optional) – The number of parameter sets. Default is 1.
- Returns:
- A tuple containing:
params (dict) – The loaded parameters.
context (int or None) – The context index for the loaded parameters.
- Return type:
- Raises:
FileNotFoundError – If the run_location file does not exist.
NotImplementedError – If an unsupported load_method is specified.
- make_composite_run_params(composite_params, list_of_subsidary_pool_run_dicts, initial_values_dict, n_parameter_sets)[source]
Create composite run parameters for the AMM simulator.
- Parameters:
composite_params (dict) – The composite parameters for the AMM simulator.
list_of_subsidary_pool_run_dicts (list) – A list of dictionaries containing the parameters for each subsidiary pool run.
initial_values_dict (dict) – The initial values dictionary for the AMM simulator.
n_parameter_sets (int) – The number of parameter sets.
- Returns:
The composite run parameters for the AMM simulator.
- Return type:
- create_product_of_linspaces(params, keys_ranges, num_points_per_key, inverse_funcs=None)[source]
Create a product of linspaces for chosen keys in the params dict.
- Parameters:
params (dict) – The dictionary containing initial parameter values.
keys_ranges (dict) – The dictionary containing high and low values for each key.
num_points_per_key (dict) – The dictionary containing the number of points for each key.
inverse_funcs (dict, optional) – A dictionary of inverse functions for each key.
- Returns:
A list of dictionaries with all combinations of linspace values for the chosen keys.
- Return type:
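The grid construction described for create_product_of_linspaces amounts to a Cartesian product of per-key linspaces merged into copies of the base dict. The sketch below uses plain floats and the standard library; the (low, high) tuple layout for keys_ranges is an assumption, inverse_funcs is left out, and both function names are local to this example.

```python
import itertools


def linspace(low, high, num):
    """Evenly spaced values from low to high inclusive (plain-Python helper)."""
    if num == 1:
        return [low]
    step = (high - low) / (num - 1)
    return [low + i * step for i in range(num)]


def product_of_linspaces_sketch(params, keys_ranges, num_points_per_key):
    """Return one params copy per combination of grid values for the chosen keys."""
    keys = list(keys_ranges)
    grids = [linspace(*keys_ranges[k], num_points_per_key[k]) for k in keys]
    combos = []
    for values in itertools.product(*grids):
        combo = dict(params)             # untouched keys are carried over
        combo.update(zip(keys, values))  # chosen keys take the grid values
        combos.append(combo)
    return combos


base = {"log_k": 0.0, "logit_lamb": 1.0, "fixed": 7}
grid = product_of_linspaces_sketch(
    base,
    keys_ranges={"log_k": (0.0, 2.0), "logit_lamb": (-1.0, 1.0)},
    num_points_per_key={"log_k": 3, "logit_lamb": 2},
)
```

With 3 points for log_k and 2 for logit_lamb, the result holds 6 dicts, which is the layout a heatmap evaluation would iterate over.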
- create_product_of_arrays(params, keys_arrays)[source]
Create a product of arrays for chosen keys in the params dict.
- generate_run_fingerprint_combinations(run_fingerprint, keys_ranges=None, num_points_per_key=None, inverse_funcs=None)[source]
Generate run fingerprint combinations with specified ranges and scaling.
- Parameters:
run_fingerprint (dict) – The base run fingerprint.
keys_ranges (dict, optional) – The dictionary containing high and low values for each key. Defaults to logarithmic ranges for ‘arb_frequency’.
num_points_per_key (dict, optional) – The dictionary containing the number of points for each key. Defaults to 10 points for each key.
inverse_funcs (dict, optional) – A dictionary of inverse functions for each key. Defaults to logarithmic scaling for ‘arb_frequency’.
- Returns:
A list of dictionaries with all combinations of run fingerprint values.
- Return type:
- make_log_range_with_zero(x)[source]
Compute the exponential of a given value, with a special case for zero.
- combine_param_combinations(param_combinations, n_parameter_sets)[source]
Combine single-row jnp arrays in param_combinations into multi-row jnp arrays.
- split_param_combinations(param_combinations)[source]
Split multi-row jnp arrays in param_combinations into single-row jnp arrays.
- make_vmap_in_axes_dict(input_dict, in_axes, keys_to_recur_on, keys_with_no_vamp=[], n_repeats_of_recurred=0)[source]
Create a vmap in_axes specification dict matching a parameter dict structure.
Constructs the nested dict/list structure that jax.vmap expects for its in_axes argument when vectorizing over a dict of parameters. Handles recursive structure for subsidiary parameters.
- Parameters:
input_dict (dict) – Parameter dictionary whose structure to mirror.
in_axes (int) – Axis to vectorize over (typically 0 for the parameter-set dimension).
keys_to_recur_on (list of str) – Keys (e.g., 'subsidary_params') that contain nested parameter dicts requiring recursive axis specification.
keys_with_no_vamp (list of str, optional) – Keys that should not be vectorized (axis set to None). Default is [].
n_repeats_of_recurred (int, optional) – Number of subsidiary parameter dicts. Default is 0.
- Returns:
Nested dict matching the structure of input_dict with integer axes or None for each leaf.
- Return type:
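As a rough illustration of the structure described above, the following pure-Python sketch builds an axis specification that mirrors a parameter dict. The helper name sketch_in_axes_dict, the key 'initial_weights', and the assumption that nested parameters are stored as a list of dicts are all hypothetical; the library's own handling (including n_repeats_of_recurred) may differ.

```python
def sketch_in_axes_dict(params, in_axes=0, keys_to_recur_on=(), keys_not_vmapped=()):
    """Mirror a parameter dict with one vmap axis (or None) per leaf.

    Hypothetical helper: assumes nested parameters are a list of dicts.
    """
    spec = {}
    for key, value in params.items():
        if key in keys_to_recur_on:
            # Recurse into each nested parameter dict so the spec mirrors it.
            spec[key] = [
                sketch_in_axes_dict(sub, in_axes, keys_to_recur_on, keys_not_vmapped)
                for sub in value
            ]
        elif key in keys_not_vmapped:
            # None means "do not vectorize": the same value is reused for every set.
            spec[key] = None
        else:
            spec[key] = in_axes
    return spec


params = {
    "log_k": [[1.0, 2.0], [3.0, 4.0]],   # one row per parameter set
    "initial_weights": [0.5, 0.5],       # hypothetical key shared by all sets
    "subsidary_params": [{"log_k": [[0.1, 0.2], [0.3, 0.4]]}],
}
spec = sketch_in_axes_dict(
    params,
    keys_to_recur_on=["subsidary_params"],
    keys_not_vmapped=["initial_weights"],
)
```

A dict of this shape would be supplied as the entry for the parameter argument in the in_axes tuple given to jax.vmap, since in_axes must mirror the structure of the argument it describes.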
- generate_params_combinations(initial_values_dict, n_tokens, n_subsidary_rules, chunk_period, n_parameter_sets, k_per_day_range, memory_days_range, num_points_k_per_day=10, num_points_memory_days=10)[source]
Generate parameter combinations with linearly-spaced values of k_per_day and memory_days.
- Parameters:
initial_values_dict (dict) – The initial values dictionary.
n_tokens (int) – The number of tokens.
n_subsidary_rules (int) – The number of subsidiary rules.
chunk_period (int) – The chunk period.
n_parameter_sets (int) – The number of parameter sets.
k_per_day_range (tuple) – The range (low, high) for k_per_day.
memory_days_range (tuple) – The range (low, high) for memory_days.
num_points_k_per_day (int, optional) – The number of points for k_per_day linspace. Defaults to 10.
num_points_memory_days (int, optional) – The number of points for memory_days linspace. Defaults to 10.
- Returns:
A list of dictionaries with all combinations of parameter values.
- Return type:
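The grid construction described for generate_params_combinations can be pictured with the sketch below, which takes the Cartesian product of two linearly spaced ranges. It is a simplified illustration only: the helper names are hypothetical, and the real function presumably also converts each pair into the library's parameter dict format, which is not reproduced here.

```python
from itertools import product


def linspace(low, high, num):
    """Return num evenly spaced floats from low to high inclusive."""
    if num == 1:
        return [float(low)]
    step = (high - low) / (num - 1)
    return [low + i * step for i in range(num)]


def sketch_grid(k_per_day_range, memory_days_range, num_k=10, num_mem=10):
    """Hypothetical helper: every (k_per_day, memory_days) pair on the grid."""
    ks = linspace(*k_per_day_range, num_k)
    mems = linspace(*memory_days_range, num_mem)
    return [{"k_per_day": k, "memory_days": m} for k, m in product(ks, mems)]


grid = sketch_grid((1.0, 3.0), (10.0, 30.0), num_k=3, num_mem=2)
```

With the defaults of 10 points per axis, a grid like this contains 100 combinations.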
- process_initial_values(initial_values_dict, key, n_assets, n_parameter_sets, force_scalar=False)[source]
Extract and broadcast a parameter value to the correct shape.
Handles flexible input formats: scalar (broadcast to all assets and sets), per-asset vector (broadcast across sets), or full matrix. Used by the schema-aware initialization path.
- Parameters:
initial_values_dict (dict) – Dictionary containing initial parameter values.
key (str) – Parameter name to extract.
n_assets (int) – Number of assets (columns).
n_parameter_sets (int) – Number of parameter sets / ensemble members (rows).
force_scalar (bool, optional) – If True, treat value as a scalar even if it’s array-like, producing shape (n_parameter_sets,) instead of (n_parameter_sets, n_assets).
- Returns:
Array of shape (n_parameter_sets, n_assets) or (n_parameter_sets,) if force_scalar=True.
- Return type:
np.ndarray
- Raises:
ValueError – If key is not in initial_values_dict or has incompatible shape.
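The broadcasting rules above can be sketched with NumPy as follows. The helper sketch_broadcast is hypothetical and operates on the value directly rather than on a dict and key. Taking the first element when force_scalar=True is an assumption, not documented behaviour.

```python
import numpy as np


def sketch_broadcast(value, n_assets, n_parameter_sets, force_scalar=False):
    """Hypothetical helper mirroring the documented input formats."""
    arr = np.asarray(value, dtype=float)
    if force_scalar:
        # Assumption: use the first element as the scalar, one copy per set.
        return np.full((n_parameter_sets,), float(arr.reshape(-1)[0]))
    if arr.ndim == 0 or arr.shape == (n_assets,):
        # Scalar or per-asset vector: repeat across parameter sets.
        return np.broadcast_to(arr, (n_parameter_sets, n_assets)).copy()
    if arr.shape == (n_parameter_sets, n_assets):
        # Full matrix: already the right shape.
        return arr.copy()
    raise ValueError(f"incompatible shape {arr.shape}")


from_scalar = sketch_broadcast(0.5, n_assets=3, n_parameter_sets=2)
from_vector = sketch_broadcast([1.0, 2.0, 3.0], n_assets=3, n_parameter_sets=2)
as_scalar = sketch_broadcast([7.0, 8.0], 3, 2, force_scalar=True)
```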
- convert_parameter_values(params, run_fingerprint, max_memory_days=None)[source]
Convert raw (reparameterized) parameters to human-readable and on-chain formats.
Applies the inverse reparameterizations (logit → lambda → memory_days, log2 → k, squareplus → exponents, etc.) and produces both float64 values and BD18 fixed-point string representations suitable for on-chain deployment.
- Parameters:
params (dict) – Raw parameter dictionary (e.g., 'logit_lamb', 'log_k', 'raw_exponents').
run_fingerprint (dict) – Run configuration, must include 'chunk_period'.
max_memory_days (float, optional) – Maximum memory days for lambda clipping. If None, uses run_fingerprint['max_memory_days'] (default 365).
- Returns:
{'values': {...}, 'strings': {...}} where each inner dict maps human-readable parameter names ('lamb', 'k', 'exponents', 'width', 'amplitude', 'pre_exp_scaling') to lists. 'values' contains float64 lists; 'strings' contains BD18 (18-decimal fixed-point integer) string representations.
- Return type:
Notes
BD18 format multiplies the float value by 10^18 and represents it as an integer string, matching the Solidity uint256 representation used by the on-chain QuantAMM contracts. The conversion uses string manipulation to avoid float64 overflow from direct multiplication by 1e18.
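To make the BD18 format concrete, here is a minimal sketch of an 18-decimal fixed-point conversion. It uses Python's decimal module rather than the string manipulation the note mentions, so it is a swapped-in technique and not the library's implementation. The function name to_bd18_string is hypothetical.

```python
from decimal import Decimal


def to_bd18_string(value):
    """Hypothetical helper: scale a float by 10**18 and return an integer string."""
    # repr(float) is the shortest decimal string that round-trips the float,
    # so Decimal(repr(...)) avoids binary noise such as 0.1000000000000000055.
    scaled = Decimal(repr(float(value))) * (10 ** 18)
    # int() truncates any digits beyond the 18th decimal place.
    return str(int(scaled))


half = to_bd18_string(0.5)
tenth = to_bd18_string(0.1)
one = to_bd18_string(1.0)
```

Working in exact decimal arithmetic keeps the trailing digits at zero for inputs like 0.1, which a direct float multiplication followed by int() does not guarantee.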
Result Exporter
- get_run_location(run_fingerprint)[source]
Generates a unique identifier string based on the provided run fingerprint.
The function takes a dictionary representing the run fingerprint, converts it to a JSON string, and then computes its SHA-256 hash. The resulting hash is used to create a unique identifier string with a “run_” prefix.
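The identifier scheme described here can be sketched with the standard library. The helper name is hypothetical, and sorting the keys before hashing is an assumption made so that key order does not change the result. The source does not say whether the actual function does this.

```python
import hashlib
import json


def sketch_run_location(run_fingerprint):
    """Hypothetical helper: 'run_' plus the SHA-256 hex digest of the JSON form."""
    # Assumption: sort_keys=True makes the hash independent of dict ordering.
    payload = json.dumps(run_fingerprint, sort_keys=True)
    return "run_" + hashlib.sha256(payload.encode("utf-8")).hexdigest()


loc_a = sketch_run_location({"rule": "momentum", "chunk_period": 60})
loc_b = sketch_run_location({"chunk_period": 60, "rule": "momentum"})
loc_c = sketch_run_location({"rule": "momentum", "chunk_period": 120})
```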
- append_json(new_data, filename)[source]
Append new data to a JSON file.
This function reads the existing data from a JSON file, appends the new data to it, and then writes the updated data back to the file.
- Parameters:
- Raises:
FileNotFoundError – If the specified file does not exist.
json.JSONDecodeError – If the file contains invalid JSON.
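The read, append, write cycle described above might look like the sketch below. It assumes the file holds a top-level JSON list, which the source does not state. The helper name is hypothetical.

```python
import json
import os
import tempfile


def sketch_append_json(new_data, filename):
    """Hypothetical helper: append one record to a JSON file holding a list."""
    # open() raises FileNotFoundError if the file is missing;
    # json.load raises json.JSONDecodeError on invalid content.
    with open(filename, "r") as f:
        existing = json.load(f)
    existing.append(new_data)
    with open(filename, "w") as f:
        json.dump(existing, f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "results.json")
    with open(path, "w") as f:
        json.dump([], f)
    sketch_append_json({"step": 1}, path)
    sketch_append_json({"step": 2}, path)
    with open(path) as f:
        stored = json.load(f)
```

Because the whole file is rewritten on each call, the cost of an append grows with the amount of data already stored.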
- append_list_json(new_data, filename)[source]
Append new data to a JSON file.
This function reads the existing data from a JSON file, appends the new data to it, and then writes the updated data back to the file.
- Parameters:
- Raises:
FileNotFoundError – If the specified file does not exist.
json.JSONDecodeError – If the file contains invalid JSON.
- save_multi_params(run_fingerprint, params, test_objective, train_objective, objective, local_learning_rate, iterations_since_improvement, steps, continuous_test_metrics=None, validation_metrics=None, sorted_tokens=True)[source]
Save multiple parameter sets along with their associated metrics to a JSON file.
- Parameters:
run_fingerprint (dict) – Dictionary containing run configuration details used to generate unique run location
params (list) – List of parameter dictionaries to save
test_objective (list) – List of objective values/metrics on test set for each parameter set
train_objective (list) – List of objective values/metrics on training set for each parameter set
objective (list) – List of overall objective values for each parameter set
local_learning_rate (list) – List of learning rates used for each parameter set
iterations_since_improvement (list) – List tracking iterations without improvement for each parameter set
steps (list) – List of step counts for each parameter set
continuous_test_metrics (list, optional) – List of continuous test metrics for each parameter set
validation_metrics (list, optional) – List of validation metrics for each parameter set (when using val_fraction > 0)
sorted_tokens (bool, optional) – Whether tokens are sorted alphabetically, by default True
Notes
Saves the data to a JSON file at ./results/run_<sha256_hash>.json, where the hash is generated from the run_fingerprint using SHA-256. If the file exists, appends the new parameter sets to the existing data. Converts JAX arrays to numpy arrays before saving.
- save_optuna_results_sgd_format(run_fingerprint, study, n_assets, sorted_tokens=True)[source]
Save optuna study results in the same format as SGD training results.
This allows optuna-optimized parameters to be loaded and analyzed with the same tools used for SGD-trained parameters.
- Parameters:
Notes
Saves to ./results/run_<sha256_hash>.json in the same format as save_multi_params, allowing unified result analysis.
- save_params(run_fingerprint, params, step, test_objective, train_objective, objective, hess, local_learning_rate, iterations_since_improvement, sorted_tokens=True)[source]
Save optimization parameters and results to a JSON file.
- Parameters:
run_fingerprint (dict) – Dictionary containing run configuration details
params (dict) – Dictionary of optimization parameters
step (int) – Current optimization step count
test_objective (float) – Objective function value on test data
train_objective (float) – Objective function value on training data
objective (float) – Overall objective function value
hess (float) – Trace of the Hessian matrix
local_learning_rate (float) – Current learning rate
iterations_since_improvement (int) – Number of iterations without improvement
sorted_tokens (bool, optional) – Whether tokens are sorted alphabetically, by default True
Notes
Saves the data to a JSON file at ./results/run_<sha256_hash>.json, where the hash is generated from the run_fingerprint using SHA-256. If the file exists, appends the new parameter set to the existing data. Converts JAX arrays to numpy arrays before saving.
Windowing Utilities
- get_indices(start_index, bout_length, len_prices, key, optimisation_settings)[source]
Get indices for sampling data windows during training.
- Parameters:
start_index (int) – Starting index position in the data
bout_length (int) – Length of each training window/bout
len_prices (int) – Total length of the price data
key (jax.random.PRNGKey) – JAX random number generator key
optimisation_settings (dict) – Dictionary containing optimization settings with keys:
- batch_size: Number of windows to sample
- training_data_kind: Type of training data (‘historic’ or ‘mc’)
- sample_method: Method for sampling windows (‘exponential’ or ‘uniform’)
- max_mc_version: Maximum MC version number (only used if training_data_kind=‘mc’)
- Returns:
start_indexes (jnp.ndarray) – Array of sampled starting indices, shape (batch_size, 2) for historic data or (batch_size, 3) for MC data
key (jax.random.PRNGKey) – Updated random number generator key
- Return type:
- raw_trades_to_trade_array(raw_trades, start_date_string, end_date_string, tokens)[source]
Convert raw trade data to a structured trade array.
This function takes raw trade data and converts it into a pandas DataFrame with a continuous range of Unix timestamps. Each row in the DataFrame represents a minute, and trades are mapped to their corresponding timestamps.
- Parameters:
- Returns:
A numpy array with columns ‘token in’, ‘token out’, and ‘amount in’. The index is a continuous range of Unix timestamps from start_unix to end_unix at minute intervals. Timestamps without trades are filled with zeros.
- Return type:
numpy array
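The minute-grid layout described above can be sketched as follows. The input format (tuples of a Unix timestamp in seconds, token-in index, token-out index, amount in) and the helper name are assumptions made for illustration. The real function takes raw trade data and date strings instead.

```python
import numpy as np


def sketch_trade_array(trades, start_unix, end_unix):
    """Hypothetical helper: one row per minute, zeros where no trade occurred."""
    n_minutes = (end_unix - start_unix) // 60 + 1
    out = np.zeros((n_minutes, 3))  # columns: token in, token out, amount in
    for unix_ts, token_in, token_out, amount_in in trades:
        row = (unix_ts - start_unix) // 60
        if 0 <= row < n_minutes:
            # A later trade in the same minute overwrites an earlier one here.
            out[row] = (token_in, token_out, amount_in)
    return out


start = 1_700_000_040
trade_array = sketch_trade_array(
    [(start + 60, 0, 1, 25.0), (start + 180, 1, 0, 4.0)],
    start_unix=start,
    end_unix=start + 180,
)
```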
- raw_fee_like_amounts_to_fee_like_array(raw_inputs, start_date_string, end_date_string, names, fill_method='base')[source]
Convert raw fee-like data to a structured fee-like array.
Takes raw fee-like data (fees, gas costs, arb fees) and converts it into a pandas DataFrame with a continuous range of Unix timestamps. Each row represents a minute, with fee-like values mapped to their corresponding timestamps.
- Parameters:
raw_inputs (pandas.DataFrame) – Raw fee-like data, where each row contains unix_timestamp and the fee-like amount with given column name
start_date_string (str) – The start date time in format “%Y-%m-%d %H:%M:%S”
end_date_string (str) – The end date time in format “%Y-%m-%d %H:%M:%S”
names (list of str) – The names of columns in raw_inputs of the fee-like amount
fill_method (str) – The method to fill in missing values. Options:
- ‘base’: fills rows which have no values with 0
- ‘ffill’: fills with the last non-zero value
- Returns:
Array giving the fee-like values over time. The index is a continuous range of Unix timestamps from start_unix to end_unix at minute intervals. Timestamps without values are filled with zeros.
- Return type:
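The difference between the two fill_method options can be shown on a single per-minute column. This sketch is illustrative only: the helper name is hypothetical, missing minutes are represented as None, and starting the forward fill from 0 before any value has been seen is an assumption.

```python
def sketch_fill(values, fill_method="base"):
    """Hypothetical helper: fill missing per-minute values (None) in a list."""
    filled = []
    last_nonzero = 0.0  # assumption: nothing to carry forward before the first value
    for v in values:
        if v is None:
            # 'base' writes 0; 'ffill' repeats the last non-zero value seen.
            filled.append(last_nonzero if fill_method == "ffill" else 0.0)
        else:
            filled.append(v)
            if v != 0.0:
                last_nonzero = v
    return filled


raw = [0.003, None, None, 0.001, None]
base_filled = sketch_fill(raw, "base")
ffill_filled = sketch_fill(raw, "ffill")
```

Forward filling suits quantities that persist until changed, such as a fee level, while 'base' suits quantities that only exist at the minute they occur.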
- filter_coarse_weights_by_data_indices(coarse_weights, data_dict)[source]
Slice coarse (chunk-period) weights to match the time range in data_dict.
Used when pre-computed coarse weights are loaded from a previous run or from on-chain data, and need to be aligned with the current training/evaluation window.
- Parameters:
- Returns:
Shallow copy of coarse_weights with 'weights' sliced to the matching time range.
- Return type:
- filter_reserves_by_data_indices(reserves, unix_values, data_dict)[source]
Slice a reserves array to match the time range in data_dict.
Looks up the unix timestamps at data_dict["start_idx"] and data_dict["end_idx"] - 1 in unix_values, then returns the corresponding slice of reserves.
- Parameters:
reserves (np.ndarray) – Full reserves array, shape (T, n_assets).
unix_values (np.ndarray) – Unix timestamps corresponding to each row of reserves.
data_dict (dict) – Must contain unix_values, start_idx, and end_idx.
- Returns:
Sliced reserves matching the data_dict time range.
- Return type:
np.ndarray
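One way to read the lookup described above is sketched below with NumPy. It assumes the boundary timestamps are read from data_dict["unix_values"], that unix_values is sorted, and that both timestamps are present in it. The helper name is hypothetical, and the library may locate the rows differently.

```python
import numpy as np


def sketch_filter_reserves(reserves, unix_values, data_dict):
    """Hypothetical helper: slice reserves to the window given by data_dict."""
    start_ts = data_dict["unix_values"][data_dict["start_idx"]]
    end_ts = data_dict["unix_values"][data_dict["end_idx"] - 1]
    # searchsorted returns the row position of each timestamp in a sorted array.
    start_row = int(np.searchsorted(unix_values, start_ts))
    end_row = int(np.searchsorted(unix_values, end_ts))
    return reserves[start_row:end_row + 1]  # inclusive of the end timestamp


unix_values = np.arange(10) * 60_000               # one row per minute, in ms
reserves = np.arange(20, dtype=float).reshape(10, 2)
data_dict = {"unix_values": unix_values[2:], "start_idx": 1, "end_idx": 4}
window = sketch_filter_reserves(reserves, unix_values, data_dict)
```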
- filter_reserves_by_given_timestamp(reserves, unix_values, timestamp)[source]
Extract the reserves row at a specific unix timestamp.
- Parameters:
reserves (np.ndarray) – Full reserves array, shape (T, n_assets).
unix_values (np.ndarray) – Unix timestamps corresponding to each row.
timestamp (int) – Unix timestamp (milliseconds) to look up.
- Returns:
Reserves at the given timestamp, shape (n_assets,).
- Return type:
np.ndarray
- Raises:
IndexError – If timestamp is not found in unix_values.
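A minimal sketch of an exact-match timestamp lookup with the documented IndexError behaviour is given below. The helper name is hypothetical and the matching strategy (exact equality, first match wins) is an assumption.

```python
import numpy as np


def sketch_reserves_at(reserves, unix_values, timestamp):
    """Hypothetical helper: return the reserves row whose timestamp matches."""
    matches = np.flatnonzero(np.asarray(unix_values) == timestamp)
    if matches.size == 0:
        raise IndexError(f"timestamp {timestamp} not found in unix_values")
    return reserves[matches[0]]


unix_ms = np.array([0, 60_000, 120_000])
pool_reserves = np.array([[100.0, 5.0], [101.0, 4.9], [99.5, 5.1]])
row = sketch_reserves_at(pool_reserves, unix_ms, 60_000)
```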