Core Simulator

Forward Pass

Forward pass simulation pipeline and financial metric calculation.

This module implements the core simulation loop for AMM pool strategies: prices → parameterised weight rule → simulated arbitrage → reserve dynamics → financial metrics.

The forward pass is the innermost computation in the three-level optimization hierarchy: forward pass (per-window) → training loop (gradient descent over windows) → hyperparameter tuner (meta-optimization over training configs). It is JIT-compiled via JAX and fully differentiable, enabling gradient-based optimization of strategy parameters.

Key components:

  • forward_pass / forward_pass_nograd: Entry points that wire pool dynamics to metric calculation. forward_pass propagates gradients; forward_pass_nograd wraps inputs in stop_gradient for evaluation.

  • _calculate_return_value: Dispatch registry mapping ~30 metric names to their implementations, from simple returns to risk-adjusted ratios.

  • Metric helpers (_daily_log_sharpe, _calculate_max_drawdown, etc.): Pure-JAX implementations of financial metrics, designed for differentiability and JIT compatibility.

  • _apply_price_noise: Multiplicative log-normal noise for data augmentation during training.

Notes

All time-series inputs use minute resolution (1 timestep = 1 minute). Duration parameters in metric helpers (e.g., duration=24*60) are in minutes. Annualization assumes 365 calendar days.

The default training metric is daily_log_sharpe (not sharpe). This uses log returns sampled at daily frequency, which is more numerically stable and better aligned with standard financial practice than minute-frequency arithmetic Sharpe.
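As a rough illustration of the idea (not the library's _daily_log_sharpe implementation), the sketch below samples a minute-resolution value series once per day, takes log returns between the daily samples, and annualizes with 365 days. The function name, the use of the population standard deviation, and the omission of a risk-free rate are all assumptions made for this example.

```python
import math
import statistics


def daily_log_sharpe_sketch(values, minutes_per_day=24 * 60, days_per_year=365):
    """Annualized Sharpe ratio of daily log returns (illustrative only)."""
    # Sample the minute-resolution series once per day.
    daily_values = values[::minutes_per_day]
    # Log returns between consecutive daily samples.
    log_returns = [
        math.log(later / earlier)
        for earlier, later in zip(daily_values[:-1], daily_values[1:])
    ]
    mean = statistics.fmean(log_returns)
    std = statistics.pstdev(log_returns)  # assumption: population std
    return math.sqrt(days_per_year) * mean / std


# Four days of minute data: flat within each day, stepping at day boundaries.
values = [100.0] * 1440 + [101.0] * 1440 + [100.5] * 1440 + [102.0] * 1440
sharpe = daily_log_sharpe_sketch(values)
```

Sampling once per day before taking returns means the statistic depends on only one value per day, which is why it is less sensitive to minute-level noise than a minute-frequency Sharpe ratio.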

forward_pass(params, start_index, prices, trades_array=None, fees_array=None, gas_cost_array=None, arb_fees_array=None, pool=None, static_dict={'all_sig_variations': None, 'arb_fees': 0.0, 'arb_frequency': 1, 'bout_length': 1000, 'chunk_period': 60, 'do_trades': False, 'fees': 0.0, 'gas_cost': 0.0, 'initial_pool_value': 1000000.0, 'max_memory_days': 365.0, 'maximum_change': 1.0, 'n_assets': 3, 'return_val': 'reserves', 'rule': 'momentum', 'run_type': 'normal', 'training_data_kind': 'historic', 'use_alt_lamb': False, 'use_pre_exp_scaling': True, 'weight_interpolation_method': 'linear', 'weight_interpolation_period': 60})[source]

Simulates a forward pass of a liquidity pool using specified parameters and market data.

This function models the behavior of a liquidity pool over a given period, considering various factors such as fees, gas costs, and arbitrage fees. It calculates reserves and other metrics based on the provided parameters and market prices.

Parameters:
  • params (dict) – A dictionary containing the parameters for the simulation, such as initial weights and other configuration settings.

  • start_index (array-like) – The starting index for the simulation, used to slice the price data.

  • prices (array-like) – A 2D array of market prices for the assets involved in the simulation.

  • trades_array (array-like, optional) – An array of trades to be considered in the simulation. Defaults to None.

  • fees_array (array-like, optional) – An array of fees to be applied during the simulation. Defaults to None.

  • gas_cost_array (array-like, optional) – An array of gas costs to be considered in the simulation. Defaults to None.

  • arb_fees_array (array-like, optional) – An array of arbitrage fees to be applied during the simulation. Defaults to None.

  • pool (object) – An instance of a pool object that provides methods to calculate reserves based on the inputs. Must be provided.

  • static_dict (dict, optional) – A dictionary of static configuration values for the simulation, such as bout length, number of assets, and return value type. Defaults to a predefined set of values.

Returns:

Depending on the return_val specified in static_dict, the function returns different types of results:

  • “reserves”: A dictionary containing the reserves over time.

  • “sharpe”: The Sharpe ratio of the pool returns.

  • “returns”: The total return over the simulation period.

  • “returns_over_hodl”: The return over a hold strategy.

  • “greatest_draw_down”: The greatest drawdown during the simulation.

  • “alpha”: Not implemented.

  • “value”: The value of the pool over time.

  • “reserves_and_values”: A dictionary containing final reserves, final value, value over time, prices, and reserves.

Return type:

dict or float

Raises:

Notes

  • The function is decorated with jax.jit for performance optimization, with static arguments specified for JIT compilation.

  • The function handles different cases for fees and trades, adjusting the calculation method accordingly:

    1. If any of fees_array, gas_cost_array, arb_fees_array, or trades_array is provided, it uses pool.calculate_reserves_with_dynamic_inputs.

    2. If any of fees, gas_cost, or arb_fees in static_dict is a nonzero scalar value, it uses pool.calculate_reserves_with_fees.

    3. If all fees and costs are zero and no trades are provided, it uses pool.calculate_reserves_zero_fees.

  • The function supports different types of return values, allowing for flexible output based on the simulation needs.

  • The arb_frequency in static_dict can alter the frequency of arbitrage operations, affecting the reserves calculation and the size of the returned arrays.
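The three fee-handling cases above can be sketched as a small selection function. This is only an illustration of the documented branching: the helper name choose_reserve_method is hypothetical, it returns the pool method's name rather than calling it, and the assumption that the cases are checked in the listed order (so arrays take precedence over scalar fees) is inferred from the numbering.

```python
def choose_reserve_method(static_dict, trades_array=None, fees_array=None,
                          gas_cost_array=None, arb_fees_array=None):
    """Return the name of the pool method selected by the three cases."""
    dynamic_inputs = (trades_array, fees_array, gas_cost_array, arb_fees_array)
    # Case 1: any per-timestep array was supplied.
    if any(x is not None for x in dynamic_inputs):
        return "calculate_reserves_with_dynamic_inputs"
    # Case 2: any scalar fee or cost in static_dict is nonzero.
    scalar_keys = ("fees", "gas_cost", "arb_fees")
    if any(static_dict.get(key, 0.0) != 0.0 for key in scalar_keys):
        return "calculate_reserves_with_fees"
    # Case 3: no fees, no costs, no trades.
    return "calculate_reserves_zero_fees"


zero_fee_choice = choose_reserve_method({"fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0})
scalar_fee_choice = choose_reserve_method({"fees": 0.003})
dynamic_choice = choose_reserve_method({"fees": 0.003}, fees_array=[0.001, 0.002])
```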

Examples

>>> forward_pass(params, start_index, prices, pool=my_pool)
{'reserves': array([...])}

forward_pass_nograd(params, start_index, prices, trades_array=None, fees_array=None, gas_cost_array=None, arb_fees_array=None, pool=None, static_dict={'all_sig_variations': None, 'arb_fees': 0.0, 'arb_frequency': 1, 'bout_length': 1000, 'chunk_period': 60, 'do_trades': False, 'fees': 0.0, 'gas_cost': 0.0, 'initial_pool_value': 1000000.0, 'max_memory_days': 365.0, 'maximum_change': 1.0, 'n_assets': 3, 'return_val': 'reserves', 'rule': 'momentum', 'run_type': 'normal', 'training_data_kind': 'historic', 'use_alt_lamb': False, 'use_pre_exp_scaling': True, 'weight_interpolation_method': 'linear', 'weight_interpolation_period': 60})[source]

Simulates a forward pass of a liquidity pool without gradient tracking using specified parameters and market data.

This function models the behavior of a liquidity pool over a given period, similar to forward_pass, but ensures that no gradients are tracked for the input parameters and data. It is useful for scenarios where gradient computation is not required, such as evaluation or inference.

Parameters:
  • params (dict) – A dictionary containing the parameters for the simulation, such as initial weights and other configuration settings.

  • start_index (array-like) – The starting index for the simulation, used to slice the price data.

  • prices (array-like) – A 2D array of market prices for the assets involved in the simulation.

  • trades_array (array-like, optional) – An array of trades to be considered in the simulation. Defaults to None.

  • fees_array (array-like, optional) – An array of fees to be applied during the simulation. Defaults to None.

  • gas_cost_array (array-like, optional) – An array of gas costs to be considered in the simulation. Defaults to None.

  • arb_fees_array (array-like, optional) – An array of arbitrage fees to be applied during the simulation. Defaults to None.

  • pool (object) – An instance of a pool object that provides methods to calculate reserves based on the inputs. Must be provided.

  • static_dict (dict, optional) – A dictionary of static configuration values for the simulation, such as bout length, number of assets, and return value type. Defaults to a predefined set of values.

Returns:

Depending on the return_val specified in static_dict, the function returns different types of results:

  • “reserves”: A dictionary containing the reserves over time.

  • “sharpe”: The Sharpe ratio of the pool returns.

  • “returns”: The total return over the simulation period.

  • “returns_over_hodl”: The return over a hold strategy.

  • “greatest_draw_down”: The greatest drawdown during the simulation.

  • “alpha”: Not implemented.

  • “value”: The value of the pool over time.

  • “reserves_and_values”: A dictionary containing final reserves, final value, value over time, prices, and reserves.

Return type:

dict or float

Raises:

Notes

  • The function is decorated with jax.jit for performance optimization, with static arguments specified for JIT compilation.

  • The function handles different cases for fees and trades, adjusting the calculation method accordingly:

    1. If any of fees_array, gas_cost_array, arb_fees_array, or trades_array is provided, it uses pool.calculate_reserves_with_dynamic_inputs.

    2. If any of fees, gas_cost, or arb_fees in static_dict is a nonzero scalar value, it uses pool.calculate_reserves_with_fees.

    3. If all fees and costs are zero and no trades are provided, it uses pool.calculate_reserves_zero_fees.

  • The function supports different types of return values, allowing for flexible output based on the simulation needs.

  • The arb_frequency in static_dict can alter the frequency of arbitrage operations, affecting the reserves calculation and the size of the returned arrays.

  • The function uses jax.lax.stop_gradient to ensure that no gradients are tracked for the input parameters and data.

Examples

>>> forward_pass_nograd(params, start_index, prices, pool=my_pool)
{'reserves': array([...])}

Parameter Utilities

Parameter utilities for strategy parameterization, serialization, and loading.

This module handles the full lifecycle of strategy parameters:

  • Initialization: init_params / init_params_singleton create parameter dicts from human-readable initial values (memory days, k per day) by converting to the internal reparameterized form (logit_lamb, log_k, etc.).

  • Reparameterization: Functions like calc_lamb, calc_alt_lamb, squareplus, and their inverses convert between human-interpretable values and the unconstrained spaces used for gradient-based optimization.

  • Serialization: NumpyEncoder, dict_of_jnp_to_np, dict_of_jnp_to_list, dict_of_np_to_jnp handle conversion between JAX arrays, NumPy arrays, and JSON-serializable Python types.

  • Loading: load_or_init, load, load_manually, retrieve_best load saved training checkpoints with various selection strategies (best train, best test, best-train-above-test-threshold, etc.).

  • Grid generation: create_product_of_linspaces, generate_params_combinations produce parameter grids for heatmap evaluations.

The key reparameterizations are:

  • lambda (λ): EWMA decay factor in (0, 1), stored as logit_lamb = log(λ/(1-λ)). Converted to and from the human-readable memory_days via a cube-root relationship and its inverse.

  • k: Weight update aggressiveness, stored as log_k = log2(k / memory_days). This decouples scale from memory length.

  • squareplus: Smooth, non-negative activation (x + sqrt(x² + 4)) / 2, an algebraic (non-transcendental) replacement for softplus. Used for exponent params.
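The round trips for λ and k can be written out directly from the formulas above. The sketch below uses plain Python scalars and hypothetical function names; the inverse for k is simply the algebraic inverse of the stated log2 formula and is not taken from the library's code.

```python
import math


def lamb_to_logit(lamb):
    """Map lambda in (0, 1) to the unconstrained logit_lamb."""
    return math.log(lamb / (1.0 - lamb))


def logit_to_lamb(logit_lamb):
    """Inverse of lamb_to_logit (the logistic sigmoid)."""
    return 1.0 / (1.0 + math.exp(-logit_lamb))


def k_to_log_k(k, memory_days):
    """log_k = log2(k / memory_days), as stated above."""
    return math.log2(k / memory_days)


def log_k_to_k(log_k, memory_days):
    """Algebraic inverse of k_to_log_k (assumed, not from the source)."""
    return (2.0 ** log_k) * memory_days


logit_lamb = lamb_to_logit(0.9)
log_k = k_to_log_k(20.0, memory_days=10.0)  # log2(2) = 1
```

Because both raw parameters range over all real numbers, a gradient step can never push λ outside (0, 1) or make k negative.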

Notes

The memory_days-to-lambda conversion involves solving a cubic equation analytically. Both NumPy (memory_days_to_lamb) and JAX (jax_memory_days_to_lamb) versions exist; the NumPy version includes safe division guards for zero memory days, while the JAX version relies on jnp.where for the zero case.

squareplus(x)[source]

Algebraic (non-transcendental) replacement for softplus.

Computes (x + sqrt(x² + 4)) / 2, which maps R → R⁺ smoothly. Unlike softplus (log(1 + exp(x))), squareplus avoids transcendental functions and is thus cheaper to differentiate through and more JIT-friendly.

Parameters:

x (jnp.ndarray or float) – Input value(s).

Returns:

Strictly positive output(s).

Return type:

jnp.ndarray or float

References

Barron, J.T. (2021). “Squareplus: A Softplus-Like Algebraic Rectifier.” arXiv:2112.11687.

See also

inverse_squareplus

Inverse mapping R⁺ → R.
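A scalar sketch of the forward and inverse mappings, written from the formulas given in this section. The library versions operate on JAX arrays; these plain-Python functions reuse the same names purely for illustration.

```python
import math


def squareplus(x):
    """(x + sqrt(x^2 + 4)) / 2, strictly positive for every real x."""
    return (x + math.sqrt(x * x + 4.0)) / 2.0


def inverse_squareplus(y):
    """(y^2 - 1) / y, defined for y > 0."""
    return (y * y - 1.0) / y


# Round trip over a few negative, zero, and positive inputs.
inputs = [-3.0, 0.0, 2.5]
recovered = [inverse_squareplus(squareplus(x)) for x in inputs]
```

Note that squareplus(0) is exactly 1, so a raw exponent parameter of 0 corresponds to an effective exponent of 1.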

check_run_fingerprint(run_fingerprint)[source]

Check that the run fingerprint is not malformed.

Parameters:

run_fingerprint (dict) – The run fingerprint to validate.

Return type:

None

Raises:

AssertionError – If weight_interpolation_period is greater than chunk_period.

default_set_or_get(dictionary, key, default, augment=True)[source]

Retrieves the value for a given key from a dictionary. If the key does not exist, returns the default value and, when augment is True, also stores it in the dictionary under that key.

Parameters:
  • dictionary (dict) – The dictionary to search for the key.

  • key (str) – The key to look up in the dictionary.

  • default (Any) – The default value to set and return if the key is not found.

  • augment (bool, optional) – If True, the default value is added to the dictionary if the key is not found. Default is True.

Returns:

The value associated with the key if it exists, otherwise the default value.

Return type:

Any

default_set(dictionary, key, default)[source]

Sets a default value for a given key in a dictionary if the key does not already exist.

Parameters:
  • dictionary (dict) – The dictionary to update.

  • key (str) – The key to check in the dictionary.

  • default (Any) – The default value to set if the key is not present.

Return type:

None

recursive_default_set(target_dict, default_dict)[source]

Recursively sets default values in a target dictionary based on a default dictionary.

Parameters:
  • target_dict (dict) – The dictionary to update with default values.

  • default_dict (dict) – The dictionary containing the default values.

Return type:

None
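The recursive behavior can be sketched as follows. This is a minimal reading of the description, assuming that existing values are never overwritten and that recursion happens only where both dictionaries hold a nested dict under the same key. The function name and the configuration keys in the usage lines are hypothetical.

```python
import copy


def recursive_default_set_sketch(target_dict, default_dict):
    """Fill missing keys in target_dict from default_dict, in place."""
    for key, default in default_dict.items():
        if key not in target_dict:
            # Copy so later edits to target_dict do not mutate the defaults.
            target_dict[key] = copy.deepcopy(default)
        elif isinstance(default, dict) and isinstance(target_dict[key], dict):
            # Both sides are dicts: descend and fill nested gaps.
            recursive_default_set_sketch(target_dict[key], default)


run_config = {"fees": 0.003, "optimisation": {"lr": 0.01}}
defaults = {"fees": 0.0, "chunk_period": 60, "optimisation": {"lr": 0.1, "steps": 100}}
recursive_default_set_sketch(run_config, defaults)
```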

class NumpyEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source]

Bases: JSONEncoder

JSON encoder that handles NumPy scalar and array types.

Extends json.JSONEncoder to serialize np.integer as int, np.floating as float, and np.ndarray as nested lists. Used when saving training checkpoints and run fingerprints to JSON.

Examples

>>> import json, numpy as np
>>> json.dumps({"val": np.float64(0.5)}, cls=NumpyEncoder)
'{"val": 0.5}'

get_run_location(run_fingerprint)[source]

Get the run location based on the run fingerprint.

Parameters:

run_fingerprint (dict) – The run fingerprint.

Returns:

The run location.

Return type:

str

dict_of_jnp_to_np(dictionary)[source]

Convert dictionary values from jax numpy arrays to numpy arrays.

Parameters:

dictionary (dict) – The dictionary to convert.

Returns:

The converted dictionary.

Return type:

dict

dict_of_jnp_to_list(dictionary)[source]

Convert dictionary values from jax numpy arrays to lists.

Parameters:

dictionary (dict) – The dictionary to convert.

Returns:

The converted dictionary.

Return type:

dict

dict_of_np_to_jnp(dictionary)[source]

Convert dictionary values from numpy arrays to jax numpy arrays.

Parameters:

dictionary (dict) – The dictionary to convert.

Returns:

The converted dictionary.

Return type:

dict

lamb_to_memory(lamb)[source]

Convert EWMA decay factor lambda to the effective memory length (unitless).

The EWMA weighting kernel w_t = lambda^t * (1 - lambda) has a characteristic memory scale that grows with lambda. This function inverts the cubic relationship used in quantammsim’s parameterisation:

\[\text{memory} = 4 \cdot \sqrt[3]{\frac{6 \lambda}{(1 - \lambda)^3}}\]

To convert to days, use lamb_to_memory_days(), which rescales by 2 * chunk_period / 1440.

Parameters:

lamb (float or jnp.ndarray) – EWMA decay factor in (0, 1).

Returns:

Unitless memory scale.

Return type:

float or jnp.ndarray

See also

lamb_to_memory_days

Returns memory in days.

memory_days_to_lamb

Inverse mapping (days -> lambda).
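The formula above can be evaluated directly for scalars. In the sketch below, lamb_to_memory_sketch transcribes the stated expression; memory_to_lamb_bisect is a stand-in numerical inverse used only to check the round trip (the library is documented as solving the cubic analytically, which this does not attempt to reproduce). Both names are hypothetical.

```python
def lamb_to_memory_sketch(lamb):
    """memory = 4 * cbrt(6 * lamb / (1 - lamb)^3), for lamb in (0, 1)."""
    return 4.0 * (6.0 * lamb / (1.0 - lamb) ** 3) ** (1.0 / 3.0)


def memory_to_lamb_bisect(memory, iterations=100):
    """Numerical inverse by bisection (illustrative substitute only)."""
    low, high = 0.0, 1.0 - 1e-12
    for _ in range(iterations):
        mid = 0.5 * (low + high)
        # The forward map is increasing in lamb, so bisection applies.
        if lamb_to_memory_sketch(mid) < memory:
            low = mid
        else:
            high = mid
    return 0.5 * (low + high)


memory = lamb_to_memory_sketch(0.9)
lamb_back = memory_to_lamb_bisect(memory)
```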

memory_days_to_lamb(memory_days, chunk_period=60)[source]

Convert memory days to lambda value.

Parameters:
  • memory_days (float) – The memory days value.

  • chunk_period (int, optional) – The chunk period. Default is 60.

Returns:

The lambda value.

Return type:

float

jax_memory_days_to_lamb(memory_days, chunk_period=60)[source]

Convert memory days to lambda value using JAX operations.

Parameters:
  • memory_days (float) – The memory days value.

  • chunk_period (int, optional) – The chunk period. Default is 60.

Returns:

The lambda value.

Return type:

float

memory_days_to_logit_lamb(memory_days, chunk_period=60)[source]

Convert memory days to logit lambda value.

Parameters:
  • memory_days (float) – The memory days value.

  • chunk_period (int, optional) – The chunk period. Default is 60.

Returns:

The logit lambda value.

Return type:

float

lamb_to_memory_days(lamb, chunk_period)[source]

Convert EWMA decay factor lambda to effective memory in days.

Applies lamb_to_memory() then rescales by 2 * chunk_period / 1440 to convert from unitless memory to calendar days, accounting for the observation frequency.

Parameters:
  • lamb (float or jnp.ndarray) – EWMA decay factor in (0, 1).

  • chunk_period (int) – Time between observations in minutes (e.g., 1440 for daily, 60 for hourly).

Returns:

Effective memory in days.

Return type:

float or jnp.ndarray

See also

lamb_to_memory

Unitless version.

memory_days_to_lamb

Inverse mapping.

lamb_to_memory_days_clipped

Clipped version with max_memory_days bound.

logistic_func(x)[source]

Standard logistic sigmoid: sigma(x) = exp(x) / (1 + exp(x)).

Maps R -> (0, 1). Used to convert the unconstrained logit_lamb parameter to the EWMA decay factor lambda in (0, 1).

Parameters:

x (float or jnp.ndarray) – Unconstrained input value(s).

Returns:

Output in (0, 1).

Return type:

float or jnp.ndarray
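Evaluating exp(x) / (1 + exp(x)) literally overflows for large positive x in floating point. The sketch below shows one common way to avoid that by branching on the sign of x. It is a scalar illustration with a hypothetical name, not the library's code; in JAX, jax.nn.sigmoid is available for the same purpose.

```python
import math


def logistic_stable(x):
    """Logistic sigmoid that never exponentiates a large positive number."""
    if x >= 0:
        # exp(-x) <= 1 here, so there is no overflow.
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, exp(x) < 1, so this form is safe.
    z = math.exp(x)
    return z / (1.0 + z)


midpoint = logistic_stable(0.0)
# In floating point the result saturates to exactly 1.0 or 0.0 at the extremes.
saturated_high = logistic_stable(1000.0)
saturated_low = logistic_stable(-1000.0)
```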

jax_logit_lamb_to_lamb(logit_lamb)[source]

Convert logit lambda to lambda value using JAX operations.

Parameters:

logit_lamb (float) – The logit lambda value.

Returns:

The lambda value between 0 and 1.

Return type:

float

lamb_to_memory_days_clipped(lamb, chunk_period, max_memory_days)[source]

Convert lambda value to memory days, clipped to a maximum value.

Parameters:
  • lamb (float) – The lambda value.

  • chunk_period (int) – The chunk period in minutes.

  • max_memory_days (float) – The maximum allowed memory days.

Returns:

The clipped memory value in days.

Return type:

float

calc_lamb(update_rule_parameter_dict)[source]

Calculate the lambda value from the given update rule parameter dictionary.

Parameters:

update_rule_parameter_dict (dict) – A dictionary containing the update rule parameters. Must include the key “logit_lamb”.

Returns:

The calculated lambda value.

Return type:

float

Raises:

KeyError – If “logit_lamb” key is not found in update_rule_parameter_dict.

calc_lamb_from_index(update_rule_parameter_dict, logit_lamb_index)[source]

Calculate the lambda value from the given update rule parameter dictionary and index.

Parameters:
  • update_rule_parameter_dict (dict) – A dictionary containing the update rule parameters. Must include the key “logit_lamb”.

  • logit_lamb_index (int) – The index of the logit lambda value to calculate.

Returns:

The calculated lambda value.

Return type:

float

Raises:

KeyError – If “logit_lamb” key is not found in update_rule_parameter_dict.

calc_alt_lamb(update_rule_parameter_dict)[source]

Calculate the alternative lambda value based on the provided update rule parameters.

Parameters:

update_rule_parameter_dict (dict) – A dictionary containing the update rule parameters. Must include the keys:

  • “logit_lamb”: The logit lambda value.

  • “logit_delta_lamb”: The logit delta lambda value.

Returns:

The calculated alternative lambda value.

Return type:

float

Raises:

KeyError – If “logit_lamb” or “logit_delta_lamb” is not found in update_rule_parameter_dict.

inverse_squareplus(y)[source]

Inverse of the squareplus activation (JAX version).

Given y = squareplus(x), recovers x = (y² - 1) / y. Used to convert from a desired positive parameter value back to the unconstrained raw parameter for initialization.
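The closed form follows from the squareplus definition by isolating the square root and squaring. Squaring is valid here because 2y - x equals the square root and is therefore positive:

```latex
\[
y = \frac{x + \sqrt{x^2 + 4}}{2}
\;\Longrightarrow\; 2y - x = \sqrt{x^2 + 4}
\;\Longrightarrow\; 4y^2 - 4xy + x^2 = x^2 + 4
\;\Longrightarrow\; x = \frac{y^2 - 1}{y}.
\]
```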

Parameters:

y (float or jnp.ndarray) – Positive input value(s). Must be > 0 (domain of inverse squareplus).

Returns:

Unconstrained value(s) that map to y under squareplus.

Return type:

jnp.ndarray

See also

squareplus

Forward mapping R → R⁺.

inverse_squareplus_np

NumPy version for non-JAX contexts.

inverse_squareplus_np(y)[source]

Inverse of the squareplus activation (NumPy version).

Identical to inverse_squareplus but uses NumPy operations, suitable for use outside JAX-traced contexts (e.g., initialization, post-processing).

Parameters:

y (float or np.ndarray) – Positive input value(s).

Returns:

Unconstrained value(s) that map to y under squareplus.

Return type:

float or np.ndarray

See also

inverse_squareplus

JAX version.

get_raw_value(value)[source]

Convert a desired parameter value to raw (log2) space.

Many parameters (k, width, amplitude) use 2^raw reparameterization so that the raw parameter can take any real value while the effective value is always positive. This function inverts that: raw = log2(value).

Parameters:

value (float) – Desired positive parameter value.

Returns:

Log2 of the input, for use as the raw parameter.

Return type:

float

See also

get_log_amplitude

Similar but divides by memory_days first.

get_log_amplitude(amplitude, memory_days)[source]

Convert desired amplitude to raw log_amplitude parameter.

The effective amplitude is 2^log_amplitude * memory_days, so to achieve a target amplitude: log_amplitude = log2(amplitude / memory_days).

Parameters:
  • amplitude (float) – Desired amplitude value.

  • memory_days (float) – Memory length in days (used to decouple amplitude from memory scale).

Returns:

Raw log_amplitude parameter value.

Return type:

float

init_params_singleton(initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, log_for_k=True)[source]

Initialize a single parameter set from human-readable initial values.

Converts intuitive values (memory_days, k_per_day, etc.) into the internal reparameterized form (logit_lamb, log_k, etc.) as 1-D JAX arrays of length n_tokens + n_subsidary_rules.

Parameters:
  • initial_values_dict (dict) – Human-readable initial values.

    Required keys:

    • 'initial_k_per_day': Weight update aggressiveness

    • 'initial_memory_length': EWMA memory in days

    Optional keys:

    • 'initial_memory_length_delta': Additional memory for alt lambda

    • 'initial_weights_logits': Starting weight logits

    • 'initial_log_amplitude': Channel amplitude (log2 scale)

    • 'initial_raw_width': Channel width (log2 scale)

    • 'initial_raw_exponents': Power exponents (squareplus space)

    • 'initial_pre_exp_scaling': Pre-exponential scaling (logit space)

  • n_tokens (int) – Number of assets in the pool.

  • n_subsidary_rules (int, optional) – Number of subsidiary rules (for composite pools). Default is 0.

  • chunk_period (int, optional) – Time between price observations in minutes. Default is 60.

  • log_for_k (bool, optional) – If True, use log_k parameterization; if False, use linear k. Default is True.

Returns:

Parameter dict with keys: 'log_k' (or 'k'), 'logit_lamb', 'logit_delta_lamb', 'initial_weights_logits', 'log_amplitude', 'raw_width', 'raw_exponents', 'logit_pre_exp_scaling', 'subsidary_params'. All values are 1-D jnp.ndarray.

Return type:

dict

See also

init_params

Multi-set version with noise injection.

fill_in_missing_values_from_init_singleton(params, initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, log_for_k=True)[source]

Fill in missing values in parameters from initial values.

Parameters:
  • params (dict) – The parameters dictionary to update.

  • initial_values_dict (dict) – The initial values dictionary.

  • n_tokens (int) – The number of tokens.

  • n_subsidary_rules (int, optional) – The number of subsidiary rules. Default is 0.

  • chunk_period (int, optional) – The chunk period. Default is 60.

  • log_for_k (bool, optional) – Whether to use log scale for k parameter. Default is True.

Returns:

The updated parameters dictionary.

Return type:

dict

init_params(initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, n_parameter_sets=1, noise='gaussian')[source]

Initialize multiple parameter sets from human-readable initial values.

Creates n_parameter_sets copies of the base parameters. When n_parameter_sets > 1, Gaussian noise is added to all rows except the first (which remains at the exact initial values). This is the legacy ensemble initialization method; for more control, see EnsembleAveragingHook.

Parameters:
  • initial_values_dict (dict) – Human-readable initial values (same format as init_params_singleton).

  • n_tokens (int) – Number of assets in the pool.

  • n_subsidary_rules (int, optional) – Number of subsidiary rules. Default is 0.

  • chunk_period (int, optional) – Time between price observations in minutes. Default is 60.

  • n_parameter_sets (int, optional) – Number of parameter sets (ensemble members). Default is 1.

  • noise (str, optional) – Noise type for diversification. Only 'gaussian' is supported. Default is 'gaussian'.

Returns:

Parameter dict with 2-D arrays of shape (n_parameter_sets, n_pool_members) for each parameter key.

Return type:

dict

See also

init_params_singleton

Single parameter set initialization.
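The "first row exact, remaining rows perturbed" layout described for init_params can be sketched with NumPy. The function name, the noise scale, and the seed are assumptions chosen for the example, and the library itself returns JAX arrays keyed by parameter name rather than a single array.

```python
import numpy as np


def make_parameter_sets_sketch(base_values, n_parameter_sets, noise_scale=0.1, seed=0):
    """Stack copies of base_values; add Gaussian noise to every row but the first."""
    rng = np.random.default_rng(seed)
    base = np.asarray(base_values, dtype=float)
    # Shape (n_parameter_sets, n_pool_members): one row per ensemble member.
    params = np.tile(base, (n_parameter_sets, 1))
    if n_parameter_sets > 1:
        noise = rng.standard_normal(params[1:].shape)
        params[1:] += noise_scale * noise  # row 0 stays at the exact initial values
    return params


base_log_k = [1.0, 1.0, 1.0]
param_sets = make_parameter_sets_sketch(base_log_k, n_parameter_sets=4)
```

Keeping one unperturbed row means the ensemble always contains the exact initial values that were requested.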

fill_in_missing_values_from_init(params, initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, n_parameter_sets=1)[source]

Fill in missing values in parameters from initial values.

Parameters:
  • params (dict) – The parameters dictionary to update.

  • initial_values_dict (dict) – The initial values dictionary.

  • n_tokens (int) – The number of tokens.

  • n_subsidary_rules (int, optional) – The number of subsidiary rules. Default is 0.

  • chunk_period (int, optional) – The chunk period. Default is 60.

  • n_parameter_sets (int, optional) – The number of parameter sets. Default is 1.

Returns:

The updated parameters dictionary.

Return type:

dict

calc_hessian_from_loaded_params(params, partial_fixed_training_step)[source]

Calculate the Hessian matrix from the loaded parameters.

Parameters:
  • params (dict) – A dictionary of parameters.

  • partial_fixed_training_step (callable) – A function representing the partial fixed training step.

Returns:

The Hessian matrix calculated from the loaded parameters.

Return type:

numpy.ndarray

load_result_array(run_location, key='objective', recalc_hess=False)[source]

Load simulation results from a JSON file and return run fingerprint and results array.

Parameters:
  • run_location (str) – Path to the JSON results file.

  • key (str, optional) – Which value to extract from results. Default is “objective”.

  • recalc_hess (bool, optional) – Whether to recalculate Hessian trace values. Default is False.

Returns:

A tuple containing:

  • run_fingerprint (dict) – Configuration details and metadata for the simulation run.

  • values (list) – Array of values extracted from results based on the specified key.

Return type:

tuple

get_objective_scalar(obj, metric_key='returns_over_uniform_hodl')[source]

Extract a scalar value from an objective that may be a dict or number.

Use this when you have a single objective value (after retrieve_best has indexed into the parameter sets) and need a float.

Parameters:
  • obj (float, int, or dict) – The objective value - either a scalar (old format) or a dict of metrics (new format)

  • metric_key (str) – The key to extract from dict objectives. Defaults to “returns_over_uniform_hodl”.

Returns:

The scalar objective value

Return type:

float

Examples

>>> get_objective_scalar(0.1)  # old format
0.1
>>> get_objective_scalar({"return": 0.1, "sharpe": 0.5}, metric_key="return")  # new format
0.1

load_manually(run_location, load_method='last', recalc_hess=False, min_test=0.0, return_as_iterables=False, metric_key='returns_over_uniform_hodl', use_continuous_test=True)[source]

Load and process parameter sets from a JSON results file with custom loading methods.

Parameters:
  • run_location (str) – Path to the JSON results file.

  • load_method (str, optional) – Method for selecting parameter sets. Defaults to ‘last’. One of:

    • ‘last’ – Returns the last parameter set.

    • ‘best_objective’ – Returns the set with the highest overall objective.

    • ‘best_train_objective’ – Returns the set with the highest training objective.

    • ‘best_test_objective’ – Returns the set with the highest test objective.

    • ‘best_train_min_test_objective’ – Returns the set with the highest training objective that meets the minimum test threshold.

  • recalc_hess (bool, optional) – Whether to recalculate Hessian trace values. Defaults to False.

  • min_test (float, optional) – Minimum test objective threshold for ‘best_train_min_test_objective’ method. Defaults to 0.0.

  • metric_key (str, optional) – For new format files with metric dicts, specifies which metric to use. Options include: “return”, “sharpe”, “jax_sharpe”, “returns_over_hodl”, “returns_over_uniform_hodl”, “annualised_returns”, “calmar”, “sterling”, “ulcer”. Ignored for old format files with simple numeric objectives. Defaults to “returns_over_uniform_hodl”.

  • use_continuous_test (bool, optional) – If True and continuous_test_metrics is available, use it instead of test_objective for test-related load methods. Defaults to True.

Returns:

Two-element tuple containing:

  • dict – Loaded parameters.

  • int – The index of the selected parameter set.

Return type:

tuple
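The ‘best_train_min_test_objective’ rule can be illustrated on plain lists of per-checkpoint objectives. This sketch assumes the threshold is inclusive and returns None when no checkpoint qualifies; both choices, and the function name, are assumptions rather than documented behavior of load_manually.

```python
def select_best_train_min_test(train_objectives, test_objectives, min_test=0.0):
    """Index of the highest training objective among entries with test >= min_test."""
    eligible = [
        index for index, test_value in enumerate(test_objectives)
        if test_value >= min_test
    ]
    if not eligible:
        return None  # assumption: signal "nothing qualifies" with None
    return max(eligible, key=lambda index: train_objectives[index])


train = [0.3, 0.9, 0.5]
test = [0.1, -0.2, 0.05]
# Index 1 has the best training objective but fails the test threshold.
selected = select_best_train_min_test(train, test, min_test=0.0)
```

The point of the rule is visible in the example: the checkpoint that looks best in training is skipped because it does not generalize to the test window.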

retrieve_best(data_location, load_method, re_calc_hess, min_alt_obj=0.0, return_as_iterables=False)[source]

Retrieve the best parameters from a training run.

Loads parameters using the specified method and extracts the best parameter set based on the context (index of best performing parameters). Removes training metadata (step, hessian_trace, etc.) from the returned params.

Parameters:
  • data_location (str) – Path to the directory containing saved training results.

  • load_method (str) – Method for loading parameters. Options include:

    • ‘last’ – Load the most recent checkpoint.

    • ‘best_train_objective’ – Load the checkpoint with the best training objective.

    • ‘best_test_objective’ – Load the checkpoint with the best test objective.

  • re_calc_hess (bool) – Whether to recalculate hessian information when loading.

  • min_alt_obj (float, optional) – Minimum alternative objective threshold. Defaults to 0.0.

  • return_as_iterables (bool, optional) – If True, returns lists of all loaded params and steps. If False, returns only the first (best) params and step. Defaults to False.

Returns:

  • params (dict or list of dict) – Best parameter dictionary (or list if return_as_iterables=True). Training metadata fields are removed.

  • steps (int or list of int) – Training step(s) at which the parameters were saved.

load_or_init(run_fingerprint, initial_values_dict, n_tokens, n_subsidary_rules, recalc_hess=False, chunk_period=60, force_init=False, load_method='last', n_parameter_sets=1, results_dir='./results/', partial_fixed_training_step=None)[source]

Load or initialize parameters for the AMM simulator.

Parameters:
  • run_fingerprint (dict) – The fingerprint of the run.

  • initial_values_dict (dict) – The initial values dictionary.

  • n_tokens (int) – The number of tokens.

  • n_subsidary_rules (int) – The number of subsidiary rules.

  • recalc_hess (bool, optional) – Whether to recalculate the Hessian. Default is False.

  • chunk_period (int, optional) – The chunk period. Default is 60.

  • force_init (bool, optional) – Whether to force initialization. Default is False.

  • load_method (str, optional) – The method to use for loading. Default is “last”.

  • n_parameter_sets (int, optional) – The number of parameter sets. Default is 1.

  • results_dir (str, optional) – The directory for results. Default is “./results/”.

  • partial_fixed_training_step (callable, optional) – The partial fixed training step. Default is None.

Returns:

A tuple containing:

  • params (dict) – The loaded or initialized parameters.

  • loaded (bool) – Whether the parameters were loaded (True) or initialized (False).

Return type:

tuple

load(run_location, initial_values_dict, n_tokens, n_subsidary_rules, chunk_period=60, load_method='last', n_parameter_sets=1)[source]

Load parameters from a file and fill in missing values based on initial values.

Parameters:
  • run_location (str) – The location of the file containing the parameters.

  • initial_values_dict (dict) – A dictionary of initial values.

  • n_tokens (int) – The number of tokens.

  • n_subsidary_rules (int) – The number of subsidiary rules.

  • chunk_period (int, optional) – The chunk period. Default is 60.

  • load_method ({'last', 'best_objective', 'best_train_objective'}, optional) – The method to use for loading parameters. Default is ‘last’.

  • n_parameter_sets (int, optional) – The number of parameter sets. Default is 1.

Returns:

A tuple containing:

  • params (dict) – The loaded parameters.

  • context (int or None) – The context index for the loaded parameters.

Return type:

tuple

make_composite_run_params(composite_params, list_of_subsidary_pool_run_dicts, initial_values_dict, n_parameter_sets)[source]

Create composite run parameters for the AMM simulator.

Parameters:
  • composite_params (dict) – The composite parameters for the AMM simulator.

  • list_of_subsidary_pool_run_dicts (list) – A list of dictionaries containing the parameters for each subsidiary pool run.

  • initial_values_dict (dict) – The initial values dictionary for the AMM simulator.

  • n_parameter_sets (int) – The number of parameter sets.

Returns:

The composite run parameters for the AMM simulator.

Return type:

dict

create_product_of_linspaces(params, keys_ranges, num_points_per_key, inverse_funcs=None)[source]

Create a product of linspaces for chosen keys in the params dict.

Parameters:
  • params (dict) – The dictionary containing initial parameter values.

  • keys_ranges (dict) – The dictionary containing high and low values for each key.

  • num_points_per_key (dict) – The dictionary containing the number of points for each key.

  • inverse_funcs (dict, optional) – A dictionary of inverse functions for each key.

Returns:

A list of dictionaries with all combinations of linspace values for the chosen keys.

Return type:

list
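
The grid construction described for create_product_of_linspaces can be sketched roughly as follows. This is a minimal stand-alone illustration, not the library's implementation: the helper names `linspace` and `product_of_linspaces_sketch` are hypothetical, the `(low, high)` tuple layout of `keys_ranges` is an assumption, the way `inverse_funcs` is applied is a guess, and plain Python floats stand in for the array-valued parameters the simulator works with.

```python
import itertools


def linspace(low, high, num):
    """Return `num` evenly spaced floats from low to high inclusive."""
    if num == 1:
        return [float(low)]
    step = (high - low) / (num - 1)
    return [low + i * step for i in range(num)]


def product_of_linspaces_sketch(params, keys_ranges, num_points_per_key,
                                inverse_funcs=None):
    """Hypothetical stand-in: one params copy per point of the Cartesian grid."""
    inverse_funcs = inverse_funcs or {}
    keys = list(keys_ranges)
    axes = []
    for key in keys:
        low, high = keys_ranges[key]  # assumed (low, high) layout
        points = linspace(low, high, num_points_per_key[key])
        transform = inverse_funcs.get(key)
        if transform is not None:
            # Assumption: the inverse function maps each grid point back
            # to the parameter's native scale.
            points = [transform(p) for p in points]
        axes.append(points)

    combos = []
    for values in itertools.product(*axes):
        combo = dict(params)  # shallow copy; keys not on the grid are kept
        combo.update(zip(keys, values))
        combos.append(combo)
    return combos


grid = product_of_linspaces_sketch(
    {"k_per_day": 1.0, "memory_days": 10.0, "fees": 0.0},
    keys_ranges={"k_per_day": (0.0, 2.0), "memory_days": (5.0, 15.0)},
    num_points_per_key={"k_per_day": 3, "memory_days": 2},
)
```

Here `grid` holds 3 × 2 = 6 dictionaries, and the `fees` entry is carried over unchanged in each one.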

create_product_of_arrays(params, keys_arrays)[source]

Create a product of arrays for chosen keys in the params dict.

Parameters:
  • params (dict) – The dictionary containing initial parameter values.

  • keys_arrays (dict) – The dictionary containing the points for each key.

Returns:

A list of dictionaries with all combinations of the supplied array values for the chosen keys.

Return type:

list

generate_run_fingerprint_combinations(run_fingerprint, keys_ranges=None, num_points_per_key=None, inverse_funcs=None)[source]

Generate run fingerprint combinations with specified ranges and scaling.

Parameters:
  • run_fingerprint (dict) – The base run fingerprint.

  • keys_ranges (dict, optional) – The dictionary containing high and low values for each key. Defaults to logarithmic ranges for ‘arb_frequency’.

  • num_points_per_key (dict, optional) – The dictionary containing the number of points for each key. Defaults to 10 points for each key.

  • inverse_funcs (dict, optional) – A dictionary of inverse functions for each key. Defaults to logarithmic scaling for ‘arb_frequency’.

Returns:

A list of dictionaries with all combinations of run fingerprint values.

Return type:

list

make_log_range_with_zero(x)[source]

Compute the exponential of a given value, with a special case for zero.

Parameters:

x (float) – The input value for which the exponential is to be computed.

Returns:

The exponential of the input value x, or zero if x is zero.

Return type:

float
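
The behaviour described for make_log_range_with_zero amounts to a guarded exponential. The sketch below uses the hypothetical name `log_range_with_zero_sketch` and operates on a plain Python float; whether the real function also accepts arrays is not stated here. Presumably the zero special case lets a log-spaced grid still contain an exact zero, though that motivation is an inference.

```python
import math


def log_range_with_zero_sketch(x):
    """Return exp(x), except that an input of exactly zero maps to zero."""
    if x == 0:
        return 0.0
    return math.exp(x)
```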

combine_param_combinations(param_combinations, n_parameter_sets)[source]

Combine single-row jnp arrays in param_combinations into multi-row jnp arrays.

Parameters:
  • param_combinations (list) – List of dictionaries with single-row jnp arrays.

  • n_parameter_sets (int) – Number of parameter sets to combine into each dictionary.

Returns:

List of dictionaries with multi-row jnp arrays.

Return type:

list

split_param_combinations(param_combinations)[source]

Split multi-row jnp arrays in param_combinations into single-row jnp arrays.

Parameters:

param_combinations (list) – List of dictionaries with multi-row jnp arrays.

Returns:

List of dictionaries with single-row jnp arrays.

Return type:

list
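
The relationship between combine_param_combinations and split_param_combinations can be illustrated with a small round trip. This is a sketch under stated assumptions, not the library code: nested Python lists stand in for jnp arrays (a single-row array is written `[[...]]`), consecutive dictionaries are assumed to be grouped together, and the names `combine_sketch` and `split_sketch` are hypothetical.

```python
def combine_sketch(param_combinations, n_parameter_sets):
    """Group consecutive single-row dicts into multi-row dicts."""
    combined = []
    for start in range(0, len(param_combinations), n_parameter_sets):
        group = param_combinations[start:start + n_parameter_sets]
        combined.append({
            # Each value is a one-row matrix [[...]]; concatenate the rows.
            key: [row for combo in group for row in combo[key]]
            for key in group[0]
        })
    return combined


def split_sketch(param_combinations):
    """Inverse operation: emit one single-row dict per row."""
    singles = []
    for combo in param_combinations:
        n_rows = len(next(iter(combo.values())))
        for i in range(n_rows):
            singles.append({key: [rows[i]] for key, rows in combo.items()})
    return singles


singles = [
    {"log_k": [[float(i)]], "logit_lamb": [[0.1 * i, 0.2 * i]]}
    for i in range(4)
]
combined = combine_sketch(singles, n_parameter_sets=2)
```

With four single-row dictionaries and `n_parameter_sets=2`, `combined` contains two dictionaries of two rows each, and `split_sketch(combined)` recovers the original list.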

make_vmap_in_axes_dict(input_dict, in_axes, keys_to_recur_on, keys_with_no_vamp=[], n_repeats_of_recurred=0)[source]

Create a vmap in_axes specification dict matching a parameter dict structure.

Constructs the nested dict/list structure that jax.vmap expects for its in_axes argument when vectorizing over a dict of parameters. Handles recursive structure for subsidiary parameters.

Parameters:
  • input_dict (dict) – Parameter dictionary whose structure to mirror.

  • in_axes (int) – Axis to vectorize over (typically 0 for the parameter-set dimension).

  • keys_to_recur_on (list of str) – Keys (e.g., 'subsidary_params') that contain nested parameter dicts requiring recursive axis specification.

  • keys_with_no_vamp (list of str, optional) – Keys that should not be vectorized (axis set to None). Default is [].

  • n_repeats_of_recurred (int, optional) – Number of subsidiary parameter dicts. Default is 0.

Returns:

Nested dict matching the structure of input_dict with integer axes or None for each leaf.

Return type:

dict
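
To make the shape of the returned structure concrete, here is a simplified sketch of how such an in_axes specification could be built. It is not the library's implementation: the name `in_axes_sketch` is hypothetical, the `n_repeats_of_recurred` argument is omitted (the sketch simply walks whatever list of nested dicts it finds), and `initial_weights` is an invented key used only to show a non-vectorized leaf.

```python
def in_axes_sketch(input_dict, in_axes, keys_to_recur_on=(),
                   keys_with_no_vmap=()):
    """Mirror input_dict, replacing each leaf with an axis index or None."""
    spec = {}
    for key, value in input_dict.items():
        if key in keys_to_recur_on:
            # Assumption: the value is a list of nested parameter dicts.
            spec[key] = [
                in_axes_sketch(sub, in_axes, keys_to_recur_on,
                               keys_with_no_vmap)
                for sub in value
            ]
        elif key in keys_with_no_vmap:
            spec[key] = None  # this leaf is not mapped over
        else:
            spec[key] = in_axes
    return spec


params = {
    "log_k": [[1.0, 2.0]],
    "initial_weights": [0.5, 0.5],  # hypothetical non-vectorized entry
    "subsidary_params": [{"log_k": [[3.0, 4.0]]}],
}
spec = in_axes_sketch(
    params,
    0,
    keys_to_recur_on=["subsidary_params"],
    keys_with_no_vmap=["initial_weights"],
)
```

A structure like `spec` has the same nesting as `params`, which is the form jax.vmap accepts for the in_axes entry that corresponds to a dict-valued argument.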

generate_params_combinations(initial_values_dict, n_tokens, n_subsidary_rules, chunk_period, n_parameter_sets, k_per_day_range, memory_days_range, num_points_k_per_day=10, num_points_memory_days=10)[source]

Generate parameter combinations with linearly-spaced values of k_per_day and memory_days.

Parameters:
  • initial_values_dict (dict) – The initial values dictionary.

  • n_tokens (int) – The number of tokens.

  • n_subsidary_rules (int) – The number of subsidiary rules.

  • chunk_period (int) – The chunk period.

  • n_parameter_sets (int) – The number of parameter sets.

  • k_per_day_range (tuple) – The range (low, high) for k_per_day.

  • memory_days_range (tuple) – The range (low, high) for memory_days.

  • num_points_k_per_day (int, optional) – The number of points for k_per_day linspace. Defaults to 10.

  • num_points_memory_days (int, optional) – The number of points for memory_days linspace. Defaults to 10.

Returns:

A list of dictionaries with all combinations of parameter values.

Return type:

list

process_initial_values(initial_values_dict, key, n_assets, n_parameter_sets, force_scalar=False)[source]

Extract and broadcast a parameter value to the correct shape.

Handles flexible input formats: scalar (broadcast to all assets and sets), per-asset vector (broadcast across sets), or full matrix. Used by the schema-aware initialization path.

Parameters:
  • initial_values_dict (dict) – Dictionary containing initial parameter values.

  • key (str) – Parameter name to extract.

  • n_assets (int) – Number of assets (columns).

  • n_parameter_sets (int) – Number of parameter sets / ensemble members (rows).

  • force_scalar (bool, optional) – If True, treat value as a scalar even if it’s array-like, producing shape (n_parameter_sets,) instead of (n_parameter_sets, n_assets).

Returns:

Array of shape (n_parameter_sets, n_assets) or (n_parameter_sets,) if force_scalar=True.

Return type:

np.ndarray

Raises:

ValueError – If key is not in initial_values_dict or has incompatible shape.

convert_parameter_values(params, run_fingerprint, max_memory_days=None)[source]

Convert raw (reparameterized) parameters to human-readable and on-chain formats.

Applies the inverse reparameterizations (logit → lambda → memory_days, log2 → k, squareplus → exponents, etc.) and produces both float64 values and BD18 fixed-point string representations suitable for on-chain deployment.

Parameters:
  • params (dict) – Raw parameter dictionary (e.g., 'logit_lamb', 'log_k', 'raw_exponents').

  • run_fingerprint (dict) – Run configuration, must include 'chunk_period'.

  • max_memory_days (float, optional) – Maximum memory days for lambda clipping. If None, uses run_fingerprint['max_memory_days'] (default 365).

Returns:

{'values': {...}, 'strings': {...}} where each inner dict maps human-readable parameter names ('lamb', 'k', 'exponents', 'width', 'amplitude', 'pre_exp_scaling') to lists. 'values' contains float64 lists; 'strings' contains BD18 (18-decimal fixed-point integer) string representations.

Return type:

dict

Notes

BD18 format multiplies the float value by 10^18 and represents as an integer string, matching the Solidity uint256 representation used by the on-chain QuantAMM contracts. The conversion uses string manipulation to avoid float64 overflow from direct multiplication by 1e18.
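
As an illustration of the BD18 representation, the sketch below produces an 18-decimal fixed-point integer string from a float. Note the swapped-in technique: it uses Python's `decimal` module rather than the string manipulation the library is described as using, and the function name `to_bd18_string` and the truncation of digits beyond 18 decimal places are assumptions.

```python
from decimal import Decimal


def to_bd18_string(value):
    """Format a number as an 18-decimal fixed-point integer string."""
    # repr() gives the shortest string that round-trips the float, so
    # Decimal sees "0.1" rather than the binary expansion of 0.1.
    scaled = Decimal(repr(float(value))).scaleb(18)
    # Assumption: digits beyond 18 decimal places are truncated.
    return str(int(scaled))


examples = [to_bd18_string(v) for v in (0.1, 1.5, 0.999)]
```

For instance, 1.5 becomes "1500000000000000000", the integer a contract would store for 1.5 with 18 decimals.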

Result Exporter

get_run_location(run_fingerprint)[source]

Generates a unique identifier string based on the provided run fingerprint.

The function takes a dictionary representing the run fingerprint, converts it to a JSON string, and then computes its SHA-256 hash. The resulting hash is used to create a unique identifier string with a “run_” prefix.

Parameters:

run_fingerprint (dict) – A dictionary representing the run fingerprint.

Returns:

A unique identifier string formatted as “run_” followed by a SHA-256 hash.

Return type:

str
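
The fingerprint-to-identifier step can be sketched with the standard library. This is an approximation of the description above, not the actual implementation: the name `run_location_sketch` is hypothetical, and both `sort_keys=True` and the use of the hexadecimal digest are assumptions, so the identifiers it produces need not match those generated by get_run_location.

```python
import hashlib
import json


def run_location_sketch(run_fingerprint):
    """Hash a JSON dump of the fingerprint and prefix it with 'run_'."""
    # Assumption: keys are sorted so that dict ordering does not change
    # the resulting hash.
    payload = json.dumps(run_fingerprint, sort_keys=True)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    return "run_" + digest


location = run_location_sketch({"rule": "momentum", "chunk_period": 60})
```

Because the identifier depends only on the fingerprint contents, rerunning with an identical configuration maps to the same results file.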

append_json(new_data, filename)[source]

Append new data to a JSON file.

This function reads the existing data from a JSON file, appends the new data to it, and then writes the updated data back to the file.

Parameters:
  • new_data (dict) – The new data to be appended to the JSON file.

  • filename (str) – The path to the JSON file.
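
The read, append, write cycle described above can be sketched as follows. The sketch assumes the file's top-level JSON value is a list and that a missing file is treated as an empty list; neither detail is confirmed here, and `append_json_sketch` is a hypothetical name.

```python
import json
import os
import tempfile


def append_json_sketch(new_data, filename):
    """Read a JSON list from `filename`, append `new_data`, write it back."""
    if os.path.exists(filename):
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    else:
        data = []  # assumption: start a new list when the file is missing
    data.append(new_data)
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(data, f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "run_example.json")
    append_json_sketch({"step": 0}, path)
    append_json_sketch({"step": 1}, path)
    with open(path, encoding="utf-8") as f:
        stored = json.load(f)
```

Since the whole file is rewritten on every call, this pattern is simple but its cost grows with the amount of data already saved.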

append_list_json(new_data, filename)[source]

Append new data to a JSON file.

This function reads the existing data from a JSON file, appends the new data to it, and then writes the updated data back to the file.

Parameters:
  • new_data (list) – The new data to be appended to the JSON file.

  • filename (str) – The path to the JSON file.

save_multi_params(run_fingerprint, params, test_objective, train_objective, objective, local_learning_rate, iterations_since_improvement, steps, continuous_test_metrics=None, validation_metrics=None, sorted_tokens=True)[source]

Save multiple parameter sets along with their associated metrics to a JSON file.

Parameters:
  • run_fingerprint (dict) – Dictionary containing run configuration details used to generate unique run location

  • params (list) – List of parameter dictionaries to save

  • test_objective (list) – List of objective values/metrics on test set for each parameter set

  • train_objective (list) – List of objective values/metrics on training set for each parameter set

  • objective (list) – List of overall objective values for each parameter set

  • local_learning_rate (list) – List of learning rates used for each parameter set

  • iterations_since_improvement (list) – List tracking iterations without improvement for each parameter set

  • steps (list) – List of step counts for each parameter set

  • continuous_test_metrics (list, optional) – List of continuous test metrics for each parameter set

  • validation_metrics (list, optional) – List of validation metrics for each parameter set (when using val_fraction > 0)

  • sorted_tokens (bool, optional) – Whether tokens are sorted alphabetically, by default True

Notes

Saves the data to a JSON file at ./results/run_<sha256_hash>.json, where the hash is generated from the run_fingerprint using SHA-256. If the file exists, the new parameter sets are appended to the existing data. JAX arrays are converted to numpy arrays before saving.

save_optuna_results_sgd_format(run_fingerprint, study, n_assets, sorted_tokens=True)[source]

Save optuna study results in the same format as SGD training results.

This allows optuna-optimized parameters to be loaded and analyzed with the same tools used for SGD-trained parameters.

Parameters:
  • run_fingerprint (dict) – Dictionary containing run configuration details

  • study (optuna.Study) – Completed optuna study object

  • n_assets (int) – Number of assets in the pool (needed to reconstruct array params)

  • sorted_tokens (bool, optional) – Whether tokens are sorted alphabetically, by default True

Notes

Saves to ./results/run_<sha256_hash>.json in the same format as save_multi_params, allowing unified result analysis.

save_params(run_fingerprint, params, step, test_objective, train_objective, objective, hess, local_learning_rate, iterations_since_improvement, sorted_tokens=True)[source]

Save optimization parameters and results to a JSON file.

Parameters:
  • run_fingerprint (dict) – Dictionary containing run configuration details

  • params (dict) – Dictionary of optimization parameters

  • step (int) – Current optimization step count

  • test_objective (float) – Objective function value on test data

  • train_objective (float) – Objective function value on training data

  • objective (float) – Overall objective function value

  • hess (float) – Trace of the Hessian matrix

  • local_learning_rate (float) – Current learning rate

  • iterations_since_improvement (int) – Number of iterations without improvement

  • sorted_tokens (bool, optional) – Whether tokens are sorted alphabetically, by default True

Notes

Saves the data to a JSON file at ./results/run_<sha256_hash>.json, where the hash is generated from the run_fingerprint using SHA-256. If the file exists, the new parameter set is appended to the existing data. JAX arrays are converted to numpy arrays before saving.

Windowing Utilities

get_indices(start_index, bout_length, len_prices, key, optimisation_settings)[source]

Get indices for sampling data windows during training.

Parameters:
  • start_index (int) – Starting index position in the data

  • bout_length (int) – Length of each training window/bout

  • len_prices (int) – Total length of the price data

  • key (jax.random.PRNGKey) – JAX random number generator key

  • optimisation_settings (dict) – Dictionary containing optimization settings with keys:

      - batch_size: Number of windows to sample
      - training_data_kind: Type of training data (‘historic’ or ‘mc’)
      - sample_method: Method for sampling windows (‘exponential’ or ‘uniform’)
      - max_mc_version: Maximum MC version number (only used if training_data_kind=‘mc’)

Returns:

  • start_indexes (jnp.ndarray) – Array of sampled starting indices, shape (batch_size, 2) for historic data or (batch_size, 3) for MC data.

  • key (jax.random.PRNGKey) – Updated random number generator key.

Return type:

tuple

raw_trades_to_trade_array(raw_trades, start_date_string, end_date_string, tokens)[source]

Convert raw trade data to a structured trade array.

This function takes raw trade data and converts it into a pandas DataFrame with a continuous range of Unix timestamps. Each row in the DataFrame represents a minute, and trades are mapped to their corresponding timestamps.

Parameters:
  • raw_trades (pandas.DataFrame) – Raw trades, where each row is one trade containing unix_timestamp, token_in (str), token_out (str), and amount_in.

  • start_date_string (str) – The start date and time in format “%Y-%m-%d %H:%M:%S”.

  • end_date_string (str) – The end date and time in format “%Y-%m-%d %H:%M:%S”.

  • tokens (list of str) – The tokens of the run

Returns:

A numpy array with columns ‘token in’, ‘token out’, and ‘amount in’. The index is a continuous range of Unix timestamps from start_unix to end_unix at minute intervals. Timestamps without trades are filled with zeros.

Return type:

numpy array

raw_fee_like_amounts_to_fee_like_array(raw_inputs, start_date_string, end_date_string, names, fill_method='base')[source]

Convert raw fee-like data to a structured fee-like array.

Takes raw fee-like data (fees, gas costs, arb fees) and converts it into a pandas DataFrame with a continuous range of Unix timestamps. Each row represents a minute, with values mapped to their corresponding timestamps.

Parameters:
  • raw_inputs (pandas.DataFrame) – Raw fee-like data, where each row contains unix_timestamp and the fee-like amount with given column name

  • start_date_string (str) – The start date and time in format “%Y-%m-%d %H:%M:%S”

  • end_date_string (str) – The end date and time in format “%Y-%m-%d %H:%M:%S”

  • names (list of str) – The names of the columns in raw_inputs that hold the fee-like amounts

  • fill_method (str) – The method used to fill in missing values. Options:

      - ‘base’: fills rows that have no values with 0
      - ‘ffill’: fills with the last non-zero value

Returns:

Array giving the fee-like values over time. The index is a continuous range of Unix timestamps from start_unix to end_unix at minute intervals. Timestamps without values are filled with zeros.

Return type:

numpy.ndarray
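
The difference between the two fill_method options is easiest to see on a tiny minute grid. The sketch below is a simplified stand-in for the function above and makes several assumptions: timestamps are integers in seconds, a single fee-like column is handled, the end timestamp is inclusive, and plain lists replace the DataFrame input and array output. The name `fee_like_grid_sketch` is hypothetical.

```python
def fee_like_grid_sketch(rows, start_unix, end_unix, fill_method="base",
                         step=60):
    """rows: iterable of (unix_timestamp, amount); one value per minute."""
    n_steps = (end_unix - start_unix) // step + 1  # assumption: end inclusive
    values = [0.0] * n_steps
    for ts, amount in rows:
        idx = (ts - start_unix) // step
        if 0 <= idx < n_steps:
            values[idx] = float(amount)
    if fill_method == "ffill":
        last = 0.0
        for i, v in enumerate(values):
            if v != 0.0:
                last = v          # remember the latest non-zero value
            else:
                values[i] = last  # carry it forward into empty minutes
    return values


rows = [(60, 0.003), (180, 0.01)]
base = fee_like_grid_sketch(rows, 0, 240)
ffill = fee_like_grid_sketch(rows, 0, 240, fill_method="ffill")
```

With ‘base’, minutes without data stay at zero; with ‘ffill’, each empty minute inherits the most recent non-zero value, and minutes before the first observation remain zero.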

filter_coarse_weights_by_data_indices(coarse_weights, data_dict)[source]

Slice coarse (chunk-period) weights to match the time range in data_dict.

Used when pre-computed coarse weights are loaded from a previous run or from on-chain data, and need to be aligned with the current training/evaluation window.

Parameters:
  • coarse_weights (dict) – Must contain 'unix_values' (timestamps) and 'weights' array of shape (T_coarse, n_assets).

  • data_dict (dict) – Must contain 'unix_values', 'start_idx', and 'end_idx'.

Returns:

Shallow copy of coarse_weights with 'weights' sliced to the matching time range.

Return type:

dict

filter_reserves_by_data_indices(reserves, unix_values, data_dict)[source]

Slice a reserves array to match the time range in data_dict.

Looks up the unix timestamps at data_dict["start_idx"] and data_dict["end_idx"] - 1 in unix_values, then returns the corresponding slice of reserves.

Parameters:
  • reserves (np.ndarray) – Full reserves array, shape (T, n_assets).

  • unix_values (np.ndarray) – Unix timestamps corresponding to each row of reserves.

  • data_dict (dict) – Must contain unix_values, start_idx, and end_idx.

Returns:

Sliced reserves matching the data_dict time range.

Return type:

np.ndarray
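
The timestamp lookup described above can be sketched as follows. Python lists stand in for the numpy arrays, the name `filter_reserves_sketch` is hypothetical, and including the row at the end timestamp in the slice is an assumption.

```python
def filter_reserves_sketch(reserves, unix_values, data_dict):
    """Slice `reserves` to the window spanned by data_dict's indices."""
    start_ts = data_dict["unix_values"][data_dict["start_idx"]]
    end_ts = data_dict["unix_values"][data_dict["end_idx"] - 1]
    # list.index raises ValueError when a timestamp is absent.
    first = unix_values.index(start_ts)
    last = unix_values.index(end_ts)
    # Assumption: the row at the end timestamp is part of the window.
    return reserves[first:last + 1]


unix_values = [1000, 1060, 1120, 1180, 1240]
reserves = [[10.0, 5.0], [10.1, 4.9], [10.2, 4.8], [10.3, 4.7], [10.4, 4.6]]
data_dict = {
    "unix_values": [1060, 1120, 1180, 1240],
    "start_idx": 1,
    "end_idx": 3,
}
window = filter_reserves_sketch(reserves, unix_values, data_dict)
```

Matching on timestamps rather than raw positions lets the reserves array and data_dict start at different times, as they do in this example.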

filter_reserves_by_given_timestamp(reserves, unix_values, timestamp)[source]

Extract the reserves row at a specific unix timestamp.

Parameters:
  • reserves (np.ndarray) – Full reserves array, shape (T, n_assets).

  • unix_values (np.ndarray) – Unix timestamps corresponding to each row.

  • timestamp (int) – Unix timestamp (milliseconds) to look up.

Returns:

Reserves at the given timestamp, shape (n_assets,).

Return type:

np.ndarray

Raises:

IndexError – If timestamp is not found in unix_values.