"""Parameter utilities for strategy parameterization, serialization, and loading.
This module handles the full lifecycle of strategy parameters:
- **Initialization**: ``init_params`` / ``init_params_singleton`` create parameter dicts
from human-readable initial values (memory days, k per day) by converting to the
internal reparameterized form (logit_lamb, log_k, etc.).
- **Reparameterization**: Functions like ``calc_lamb``, ``calc_alt_lamb``, ``squareplus``,
and their inverses convert between human-interpretable values and the unconstrained
spaces used for gradient-based optimization.
- **Serialization**: ``NumpyEncoder``, ``dict_of_jnp_to_np``, ``dict_of_jnp_to_list``,
``dict_of_np_to_jnp`` handle conversion between JAX arrays, NumPy arrays, and
JSON-serializable Python types.
- **Loading**: ``load_or_init``, ``load``, ``load_manually``, ``retrieve_best`` load
saved training checkpoints with various selection strategies (best train, best test,
best-train-above-test-threshold, etc.).
- **Grid generation**: ``create_product_of_linspaces``, ``generate_params_combinations``
produce parameter grids for heatmap evaluations.
The key reparameterizations are:
- **lambda (λ)**: EWMA decay factor in [0, 1], stored as ``logit_lamb = log(λ/(1-λ))``.
Converted to/from human-readable ``memory_days`` via cubic-root inversion.
- **k**: Weight update aggressiveness, stored as ``log_k = log2(k / memory_days)``.
This decouples scale from memory length.
- **squareplus**: Smooth, non-negative activation ``(x + sqrt(x² + 4)) / 2``,
an algebraic (non-transcendental) replacement for softplus. Used for exponent params.
Notes
-----
The ``memory_days ↔ lambda`` conversion involves solving a cubic equation analytically.
Both NumPy (``memory_days_to_lamb``) and JAX (``jax_memory_days_to_lamb``) versions
exist; the NumPy version includes safe division guards for zero memory days, while the
JAX version relies on ``jnp.where`` for the zero case.
"""
import hashlib
import json
import os
from copy import deepcopy
from itertools import product

import jax.numpy as jnp
import numpy as np
from jax import config
from jax import jit, lax
from jax.scipy.special import expit

from quantammsim.training.hessian_trace import hessian_trace
[docs]
def squareplus(x):
    """Smooth positive rectifier ``(x + sqrt(x**2 + 4)) / 2``.

    An algebraic stand-in for softplus (``log(1 + exp(x))``): it is built
    only from addition, multiplication, squaring and a square root.

    Parameters
    ----------
    x : jnp.ndarray or float
        Input value(s).

    Returns
    -------
    jnp.ndarray
        ``(x + sqrt(x**2 + 4)) / 2`` evaluated element-wise; strictly
        positive in exact arithmetic.

    References
    ----------
    Barron, J.T. (2021). "Squareplus: A Softplus-Like Algebraic Rectifier."
    arXiv:2112.11687.

    See Also
    --------
    inverse_squareplus : Inverse mapping.
    """
    radicand = lax.add(lax.square(x), 4.0)
    shifted = lax.add(x, lax.sqrt(radicand))
    return lax.mul(0.5, shifted)
# Module-level configuration, executed once when this module is imported.
# The x64 flag only takes effect if it is set at startup, i.e. before JAX has
# been used to create any arrays.
config.update("jax_enable_x64", True)
# Turn every NumPy floating-point error category into a raised exception ...
np.seterr(all="raise")
# ... except underflow, which is downgraded to a printed warning.
np.seterr(under="print")
[docs]
def check_run_fingerprint(run_fingerprint):
    """Check that the run fingerprint is not malformed.

    The only rule enforced is that ``weight_interpolation_period`` must not
    exceed ``chunk_period``.

    Parameters
    ----------
    run_fingerprint : dict
        The run fingerprint to validate. Must contain the keys
        ``"weight_interpolation_period"`` and ``"chunk_period"``.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If weight_interpolation_period is greater than chunk_period.
    KeyError
        If either key is missing from ``run_fingerprint``.
    """
    interpolation_period = run_fingerprint["weight_interpolation_period"]
    chunk_period = run_fingerprint["chunk_period"]
    # Raise explicitly rather than via ``assert`` so the check survives
    # ``python -O``; the exception type is kept for existing callers.
    if not interpolation_period <= chunk_period:
        raise AssertionError(
            "weight_interpolation_period "
            f"({interpolation_period}) must be <= chunk_period ({chunk_period})"
        )
[docs]
def default_set_or_get(dictionary, key, default, augment=True):
    """Look up ``key`` and fall back to ``default`` when it has no value.

    A key counts as having no value when it is absent or when its stored
    value is ``None``.

    Parameters
    ----------
    dictionary : dict
        The dictionary to search (and, optionally, to update).
    key : str
        The key to look up.
    default : Any
        The value to return when the key has no value.
    augment : bool, optional
        If True, ``default`` is also written into ``dictionary`` under
        ``key`` when the fallback is used. Default is True.

    Returns
    -------
    Any
        The stored value if it is not ``None``, otherwise ``default``.
    """
    current = dictionary.get(key)
    if current is not None:
        return current
    if augment:
        dictionary[key] = default
    return default
[docs]
def default_set(dictionary, key, default):
    """Store ``default`` under ``key`` unless the key already has a value.

    A key that is absent, or present with the value ``None``, is treated as
    having no value and is overwritten.

    Parameters
    ----------
    dictionary : dict
        The dictionary to update in place.
    key : str
        The key to check.
    default : Any
        The value to store when the key has no value.

    Returns
    -------
    None
    """
    if dictionary.get(key) is None:
        dictionary[key] = default
[docs]
def recursive_default_set(target_dict, default_dict):
    """Recursively fill ``target_dict`` with values from ``default_dict``.

    Leaf defaults are written only where the target has no value (key absent
    or stored value ``None``), the same rule ``default_set`` applies. Nested
    default dicts are merged into the corresponding nested target dict, which
    is created if the target has no value for that key.

    Parameters
    ----------
    target_dict : dict
        The dictionary to update in place.
    default_dict : dict
        The dictionary containing the default values (possibly nested).

    Returns
    -------
    None
    """
    for key, default in default_dict.items():
        if isinstance(default, dict):
            # A key stored as ``None`` is treated as missing here too;
            # recursing into ``None`` would raise AttributeError.
            if target_dict.get(key) is None:
                target_dict[key] = {}
            recursive_default_set(target_dict[key], default)
        elif target_dict.get(key) is None:
            target_dict[key] = default
[docs]
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that handles NumPy scalar and array types.

    Extends ``json.JSONEncoder`` to serialize ``np.bool_`` as ``bool``,
    ``np.integer`` as ``int``, ``np.floating`` as ``float``, and
    ``np.ndarray`` as nested lists.

    Examples
    --------
    >>> import json, numpy as np
    >>> json.dumps({"val": np.float64(0.5)}, cls=NumpyEncoder)
    '{"val": 0.5}'
    >>> json.dumps({"flag": np.bool_(True)}, cls=NumpyEncoder)
    '{"flag": true}'
    """

    def default(self, o):
        """Convert NumPy types to Python native types.

        Parameters
        ----------
        o : Any
            The object the standard encoder could not serialize.

        Returns
        -------
        Any
            A JSON-serializable equivalent of ``o``.

        Raises
        ------
        TypeError
            Raised by the base class for any type not handled here.
        """
        # np.bool_ is not a subclass of np.integer, so it needs its own case.
        if isinstance(o, np.bool_):
            return bool(o)
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        return super().default(o)
[docs]
def get_run_location(run_fingerprint):
    """Derive a deterministic run identifier from a run fingerprint.

    The fingerprint is serialized to JSON with sorted keys, hashed with
    SHA-256 (flagged as not used for security), and the hex digest is
    prefixed with ``"run_"``.

    Parameters
    ----------
    run_fingerprint : dict
        The run fingerprint; must be serializable by ``json.dumps``.

    Returns
    -------
    str
        ``"run_"`` followed by the 64-character hex digest.
    """
    canonical = json.dumps(run_fingerprint, sort_keys=True)
    digest = hashlib.sha256(canonical.encode("utf-8"), usedforsecurity=False)
    return "run_" + str(digest.hexdigest())
[docs]
def dict_of_jnp_to_np(dictionary):
    """Replace each value with a NumPy array, in place.

    Every value except the one stored under ``"subsidary_params"`` is
    passed through ``np.array``.

    Parameters
    ----------
    dictionary : dict
        The dictionary to convert; it is modified in place.

    Returns
    -------
    dict
        The same dictionary object, after conversion.
    """
    for name, value in dictionary.items():
        if name == "subsidary_params":
            continue
        # Reassigning existing keys does not resize the dict, so this is
        # safe while iterating.
        dictionary[name] = np.array(value)
    return dictionary
[docs]
def dict_of_jnp_to_list(dictionary):
    """Replace each value with plain (nested) Python lists, in place.

    Every value except the one stored under ``"subsidary_params"`` is
    converted with ``np.array(value).tolist()``.

    Parameters
    ----------
    dictionary : dict
        The dictionary to convert; it is modified in place.

    Returns
    -------
    dict
        The same dictionary object, after conversion.
    """
    as_lists = {
        name: np.array(value).tolist()
        for name, value in dictionary.items()
        if name != "subsidary_params"
    }
    dictionary.update(as_lists)
    return dictionary
[docs]
def dict_of_np_to_jnp(dictionary):
    """Replace each value with a JAX array, in place.

    Every value except the one stored under ``"subsidary_params"`` is
    passed through ``jnp.array``.

    Parameters
    ----------
    dictionary : dict
        The dictionary to convert; it is modified in place.

    Returns
    -------
    dict
        The same dictionary object, after conversion.
    """
    convertible = [name for name in dictionary if name != "subsidary_params"]
    for name in convertible:
        dictionary[name] = jnp.array(dictionary[name])
    return dictionary
[docs]
@jit
def lamb_to_memory(lamb):
    """Convert EWMA decay factor lambda to a unitless memory length.

    Computes ``memory = 4 * cbrt(6 * lamb / (1 - lamb)**3)``.

    :func:`lamb_to_memory_days` uses the same cube root but scales it by
    ``2 * chunk_period / 1440`` instead of ``4`` to express the memory in
    days.

    Parameters
    ----------
    lamb : float or jnp.ndarray
        EWMA decay factor; the expression divides by ``(1 - lamb)**3``.

    Returns
    -------
    jnp.ndarray
        Unitless memory scale.

    See Also
    --------
    lamb_to_memory_days : Returns memory in days.
    memory_days_to_lamb : Mapping from days to lambda.
    """
    cubed_ratio = 6 * lamb / ((1 - lamb) ** 3.0)
    return jnp.cbrt(cubed_ratio) * 4.0
[docs]
def memory_days_to_lamb(memory_days, chunk_period=60):
    """Convert a memory length in days to the EWMA decay factor lambda (NumPy).

    Inverts the relation used by :func:`lamb_to_memory_days`. With
    ``s = (1440 * memory_days / (2 * chunk_period))**3 / 6`` the decay factor
    satisfies ``lamb / (1 - lamb)**3 = s``; the closed-form root of that
    cubic is evaluated here.

    Parameters
    ----------
    memory_days : float or np.ndarray
        The memory length in days.
    chunk_period : int, optional
        The chunk period. Default is 60.

    Returns
    -------
    np.ndarray
        The lambda value(s); ``0.0`` wherever ``memory_days == 0``. A scalar
        input yields a 0-d array.
    """
    target = (1440.0 * memory_days / (2.0 * chunk_period)) ** 3 / 6.0
    target_sq = target**2
    target_cu = target**3
    target_qu = target**4

    # Cube-root term shared by both fractions of the closed-form solution.
    root_term = np.cbrt((np.sqrt(3 * (27 * target_qu + 4 * target_cu)) - 9 * target_sq))
    first_denominator = np.cbrt(2) * 3 ** (2.0 / 3.0) * target
    second_numerator = np.broadcast_to(np.cbrt((2 / 3)), root_term.shape)

    # Both denominators vanish when memory_days is zero; the masked divides
    # leave 0 in those slots instead of evaluating 0/0.
    first_fraction = np.divide(
        root_term,
        first_denominator,
        out=np.zeros_like(root_term),
        where=first_denominator != 0,
    )
    second_fraction = np.divide(
        second_numerator,
        root_term,
        out=np.zeros_like(second_numerator),
        where=root_term != 0,
    )

    lamb = np.nan_to_num(
        first_fraction - second_fraction + 1.0, nan=0.0, posinf=1.0, neginf=0.0
    )
    return np.where(memory_days == 0.0, 0.0, lamb)
[docs]
def jax_memory_days_to_lamb(memory_days, chunk_period=60):
    """Convert a memory length in days to the EWMA decay factor lambda (JAX).

    Evaluates the same closed-form cubic root as :func:`memory_days_to_lamb`
    using ``jax.numpy`` operations.

    Parameters
    ----------
    memory_days : float or jnp.ndarray
        The memory length in days.
    chunk_period : int, optional
        The chunk period. Default is 60.

    Returns
    -------
    jnp.ndarray
        The lambda value(s); ``0.0`` wherever ``memory_days == 0``.

    Notes
    -----
    At ``memory_days == 0`` the closed form is ``0 / 0``. Masking only the
    output with ``jnp.where`` still lets NaN leak into gradients, so zero
    inputs are swapped for a harmless placeholder before the formula is
    evaluated and masked again afterwards. The gradient at zero is then
    ``0`` instead of NaN; values for non-zero inputs are unchanged.
    """
    is_zero = memory_days == 0.0
    # Placeholder keeps every intermediate finite in the masked-out slots.
    safe_memory_days = jnp.where(is_zero, 1.0, memory_days)

    scaled_memory_days = (1440.0 * safe_memory_days / (2.0 * chunk_period)) ** 3 / 6.0
    smd = scaled_memory_days
    smd2 = scaled_memory_days**2
    smd3 = scaled_memory_days**3
    smd4 = scaled_memory_days**4

    numerator_1 = jnp.cbrt((jnp.sqrt(3 * (27 * smd4 + 4 * smd3)) - 9 * smd2))
    denominator_1 = jnp.cbrt(2) * 3 ** (2.0 / 3.0) * smd
    numerator_2 = jnp.cbrt((2 / 3))
    denominator_2 = numerator_1

    lamb = numerator_1 / denominator_1 - numerator_2 / denominator_2 + 1.0
    return jnp.where(is_zero, 0.0, lamb)
[docs]
def memory_days_to_logit_lamb(memory_days, chunk_period=60):
    """Convert a memory length in days to ``logit(lambda)``.

    Obtains lambda from :func:`memory_days_to_lamb` and returns
    ``log(lamb / (1 - lamb))``.

    Parameters
    ----------
    memory_days : float
        The memory length in days.
    chunk_period : int, optional
        The chunk period. Default is 60.

    Returns
    -------
    jnp.ndarray
        The logit of the lambda value.
    """
    decay = memory_days_to_lamb(memory_days, chunk_period)
    odds = decay / (1 - decay)
    return jnp.log(odds)
[docs]
@jit
def lamb_to_memory_days(lamb, chunk_period):
    """Convert EWMA decay factor lambda to an effective memory in days.

    Computes ``cbrt(6 * lamb / (1 - lamb)**3) * 2 * chunk_period / 1440``.

    Parameters
    ----------
    lamb : float or jnp.ndarray
        EWMA decay factor; the expression divides by ``(1 - lamb)**3``.
    chunk_period : int
        Chunk period; divided by 1440 here, i.e. presumably expressed in
        minutes (1440 minutes per day).

    Returns
    -------
    jnp.ndarray
        Effective memory in days.

    See Also
    --------
    lamb_to_memory : Unitless version.
    memory_days_to_lamb : Mapping from days to lambda.
    lamb_to_memory_days_clipped : Clipped version with an upper bound.
    """
    cube_root = jnp.cbrt(6 * lamb / ((1 - lamb) ** 3.0))
    return cube_root * 2 * chunk_period / 1440
[docs]
@jit
def logistic_func(x):
    """Standard logistic sigmoid ``sigma(x) = 1 / (1 + exp(-x))``.

    Used in this module to turn the unconstrained ``logit_lamb`` parameter
    into the EWMA decay factor ``lambda``.

    Parameters
    ----------
    x : float or jnp.ndarray
        Unconstrained input value(s).

    Returns
    -------
    jnp.ndarray
        Sigmoid of ``x``, element-wise.

    Notes
    -----
    Evaluated with ``jax.scipy.special.expit``. The textbook form
    ``exp(x) / (1 + exp(x))`` overflows to ``inf / inf = nan`` for large
    positive ``x``; ``expit`` saturates to 1 instead.
    """
    return expit(x)
[docs]
@jit
def jax_logit_lamb_to_lamb(logit_lamb):
    """Map ``logit_lamb`` to lambda by applying :func:`logistic_func`.

    Parameters
    ----------
    logit_lamb : float or jnp.ndarray
        The logit lambda value(s).

    Returns
    -------
    jnp.ndarray
        The lambda value(s), i.e. the logistic sigmoid of the input.
    """
    return logistic_func(logit_lamb)
[docs]
@jit
def lamb_to_memory_days_clipped(lamb, chunk_period, max_memory_days):
    """Convert lambda to memory in days, clipped to ``[0, max_memory_days]``.

    Parameters
    ----------
    lamb : float or jnp.ndarray
        The lambda value(s).
    chunk_period : int
        The chunk period, passed through to :func:`lamb_to_memory_days`.
    max_memory_days : float
        Upper bound applied to the result.

    Returns
    -------
    jnp.ndarray
        The memory in days, limited to the range ``[0, max_memory_days]``.
    """
    unclipped_days = lamb_to_memory_days(lamb, chunk_period)
    return jnp.clip(unclipped_days, 0.0, max_memory_days)
[docs]
def calc_lamb(update_rule_parameter_dict):
    """Compute lambda from the ``"logit_lamb"`` entry of a parameter dict.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Update rule parameters; must hold a non-``None`` value under
        ``"logit_lamb"``.

    Returns
    -------
    jnp.ndarray
        ``logistic_func`` applied to the stored ``logit_lamb``.

    Raises
    ------
    KeyError
        If ``"logit_lamb"`` is missing or stored as ``None``.
    """
    logit_lamb = update_rule_parameter_dict.get("logit_lamb")
    if logit_lamb is None:
        raise KeyError("logit_lamb key not found in update_rule_parameter_dict")
    return logistic_func(logit_lamb)
[docs]
def calc_lamb_from_index(update_rule_parameter_dict, logit_lamb_index):
    """Compute lambda from one indexed element of the ``"logit_lamb"`` entry.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Update rule parameters; must hold a non-``None`` value under
        ``"logit_lamb"``.
    logit_lamb_index : int
        Index into the stored ``logit_lamb`` value.

    Returns
    -------
    jnp.ndarray
        ``logistic_func`` applied to ``logit_lamb[logit_lamb_index]``.

    Raises
    ------
    KeyError
        If ``"logit_lamb"`` is missing or stored as ``None``.
    """
    all_logits = update_rule_parameter_dict.get("logit_lamb")
    if all_logits is None:
        raise KeyError("logit_lamb key not found in update_rule_parameter_dict")
    selected_logit = all_logits[logit_lamb_index]
    return logistic_func(selected_logit)
[docs]
def calc_alt_lamb(update_rule_parameter_dict):
    """Compute the alternative lambda from ``logit_lamb + logit_delta_lamb``.

    The two logits are added and the sum is passed through
    ``logistic_func``.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Update rule parameters; must hold non-``None`` values under
        ``"logit_lamb"`` and ``"logit_delta_lamb"``.

    Returns
    -------
    jnp.ndarray
        The alternative lambda value.

    Raises
    ------
    KeyError
        If ``"logit_lamb"`` or ``"logit_delta_lamb"`` is missing or stored
        as ``None``; ``"logit_lamb"`` is checked first.
    """
    logit_lamb = update_rule_parameter_dict.get("logit_lamb")
    if logit_lamb is None:
        raise KeyError("logit_lamb key not found in update_rule_parameter_dict")

    logit_delta_lamb = update_rule_parameter_dict.get("logit_delta_lamb")
    if logit_delta_lamb is None:
        raise KeyError("logit_delta_lamb key not found in update_rule_parameter_dict")

    return logistic_func(logit_delta_lamb + logit_lamb)
[docs]
def inverse_squareplus(y):
    """Inverse of the squareplus activation (JAX version).

    Solving ``y = (x + sqrt(x**2 + 4)) / 2`` for ``x`` gives
    ``x = (y**2 - 1) / y``, which is what this function evaluates.

    Parameters
    ----------
    y : float or jnp.ndarray
        Value(s) in the range of squareplus; the expression divides by
        ``y``, so it must be non-zero.

    Returns
    -------
    jnp.ndarray
        The value(s) that squareplus maps to ``y``. The input is first
        converted to an array with requested dtype ``float64``.

    See Also
    --------
    squareplus : Forward mapping.
    inverse_squareplus_np : NumPy version for non-JAX contexts.
    """
    y_arr = jnp.asarray(y, dtype=jnp.float64)
    numerator = lax.sub(lax.square(y_arr), 1.0)
    return lax.div(numerator, y_arr)
[docs]
def inverse_squareplus_np(y):
    """Inverse of the squareplus activation without JAX operations.

    Evaluates ``(y**2 - 1) / y`` with ordinary Python/NumPy arithmetic.

    Parameters
    ----------
    y : float or np.ndarray
        Value(s) in the range of squareplus; must be non-zero because the
        expression divides by ``y``.

    Returns
    -------
    float or np.ndarray
        The value(s) that squareplus maps to ``y``.

    See Also
    --------
    inverse_squareplus : JAX version.
    """
    numerator = y**2 - 1.0
    return numerator / y
[docs]
def get_raw_value(value):
    """Return the base-2 logarithm of ``value``.

    This is the inverse of a ``2**raw`` reparameterization: given the
    desired effective value it yields the raw parameter.

    Parameters
    ----------
    value : float
        Desired parameter value; must be positive for a finite result.

    Returns
    -------
    float
        ``log2(value)``.

    See Also
    --------
    get_log_amplitude : Similar, but divides by memory_days first.
    """
    raw = np.log2(value)
    return raw
[docs]
def get_log_amplitude(amplitude, memory_days):
    """Return ``log2(amplitude / memory_days)``.

    This inverts the relation ``amplitude = 2**log_amplitude * memory_days``.

    Parameters
    ----------
    amplitude : float
        Desired amplitude value.
    memory_days : float
        Memory length in days; must be non-zero.

    Returns
    -------
    float
        Raw log_amplitude parameter value.
    """
    per_day_amplitude = amplitude / memory_days
    return np.log2(per_day_amplitude)
[docs]
def init_params_singleton(
    initial_values_dict, n_tokens, n_subsidary_rules=0, chunk_period=60, log_for_k=True
):
    """Initialize a single parameter set from human-readable initial values.

    Each initial value is converted to its internal form and repeated once
    per pool member (``n_tokens + n_subsidary_rules``).

    Parameters
    ----------
    initial_values_dict : dict
        Initial values. Required keys:

        - ``'initial_k_per_day'``: stored as ``log2`` of the value when
          ``log_for_k`` is True, otherwise stored unchanged.
        - ``'initial_memory_length'``: memory in days, converted to
          ``logit_lamb`` via :func:`memory_days_to_lamb`.

        Optional keys (a missing key or ``None`` selects a zero default):

        - ``'initial_memory_length_delta'``: extra memory in days; stored as
          the logit difference between the lambda for
          ``memory_length + delta`` and the lambda for ``memory_length``.
        - ``'initial_weights_logits'``: a scalar (repeated per pool member)
          or a list / NumPy array / JAX array (used as given).
        - ``'initial_log_amplitude'``, ``'initial_raw_width'``,
          ``'initial_raw_exponents'``: stored unchanged.
        - ``'initial_pre_exp_scaling'``: stored as its logit.
    n_tokens : int
        Number of assets in the pool.
    n_subsidary_rules : int, optional
        Number of subsidiary rules. Default is 0.
    chunk_period : int, optional
        Chunk period passed to :func:`memory_days_to_lamb`. Default is 60.
    log_for_k : bool, optional
        If True the result has a ``'log_k'`` entry; if False it has a
        ``'k'`` entry instead. Default is True.

    Returns
    -------
    dict
        Keys: ``'log_k'`` (or ``'k'``), ``'logit_lamb'``,
        ``'logit_delta_lamb'``, ``'initial_weights_logits'``,
        ``'log_amplitude'``, ``'raw_width'``, ``'raw_exponents'``,
        ``'logit_pre_exp_scaling'``, ``'subsidary_params'``. Values are 1-D
        ``jnp.ndarray`` of length ``n_tokens + n_subsidary_rules``, except
        ``'logit_pre_exp_scaling'`` (shape ``(1, n_pool_members)``) and
        ``'subsidary_params'`` (an empty list).

    See Also
    --------
    init_params : Multi-set version with noise injection.
    """
    n_pool_members = n_tokens + n_subsidary_rules

    def tiled(value):
        # One copy of ``value`` per pool member, as a 1-D JAX array.
        return jnp.array([value] * n_pool_members)

    def tiled_or_zero(key):
        value = initial_values_dict.get(key)
        return tiled(0.0 if value is None else value)

    def logit(p):
        return np.log(p / (1.0 - p))

    memory_length = initial_values_dict["initial_memory_length"]
    logit_lamb_np = logit(memory_days_to_lamb(memory_length, chunk_period))

    # The delta is stored in logit space: logit(lambda for memory + delta)
    # minus logit(lambda for memory).
    memory_length_delta = initial_values_dict.get("initial_memory_length_delta")
    if memory_length_delta is None:
        logit_delta_lamb = tiled(0.0)
    else:
        logit_total_np = logit(
            memory_days_to_lamb(memory_length + memory_length_delta, chunk_period)
        )
        logit_delta_lamb = tiled(logit_total_np - logit_lamb_np)

    weights_logits = initial_values_dict.get("initial_weights_logits")
    if weights_logits is None:
        initial_weights_logits = tiled(0.0)
    elif (
        isinstance(weights_logits, (list, np.ndarray, jnp.ndarray))
        and np.ndim(weights_logits) > 0
    ):
        # Per-member values supplied directly. ``np.array`` / ``jnp.array``
        # are functions, not types, so the check must use the array classes.
        initial_weights_logits = jnp.array(weights_logits)
    else:
        initial_weights_logits = tiled(weights_logits)

    pre_exp_scaling = initial_values_dict.get("initial_pre_exp_scaling")
    logit_pre_exp_np = 0.0 if pre_exp_scaling is None else logit(pre_exp_scaling)
    # NOTE(review): this entry is 2-D, shape (1, n_pool_members), unlike the
    # others; kept as in the original in case callers rely on it.
    logit_pre_exp_scaling = jnp.array([[logit_pre_exp_np] * n_pool_members])

    k_per_day = initial_values_dict["initial_k_per_day"]
    if log_for_k:
        params = {"log_k": tiled(np.log2(k_per_day))}
    else:
        params = {"k": tiled(k_per_day)}

    params.update(
        {
            "logit_lamb": tiled(logit_lamb_np),
            "logit_delta_lamb": logit_delta_lamb,
            "initial_weights_logits": initial_weights_logits,
            "log_amplitude": tiled_or_zero("initial_log_amplitude"),
            "raw_width": tiled_or_zero("initial_raw_width"),
            "raw_exponents": tiled_or_zero("initial_raw_exponents"),
            "logit_pre_exp_scaling": logit_pre_exp_scaling,
            "subsidary_params": [],
        }
    )
    return params
[docs]
def fill_in_missing_values_from_init_singleton(
    params,
    initial_values_dict,
    n_tokens,
    n_subsidary_rules=0,
    chunk_period=60,
    log_for_k=True,
):
    """Fill gaps in ``params`` with freshly initialized single-set values.

    Builds a full parameter set with :func:`init_params_singleton` and
    copies over every entry for which ``params`` has no value (key absent
    or value ``None``).

    Parameters
    ----------
    params : dict
        The parameters dictionary to update in place.
    initial_values_dict : dict
        The initial values dictionary.
    n_tokens : int
        The number of tokens.
    n_subsidary_rules : int, optional
        The number of subsidary rules. Default is 0.
    chunk_period : int, optional
        The chunk period. Default is 60.
    log_for_k : bool, optional
        Whether to use log scale for k parameter. Default is True.

    Returns
    -------
    dict
        The same ``params`` object, after filling.
    """
    defaults = init_params_singleton(
        initial_values_dict, n_tokens, n_subsidary_rules, chunk_period, log_for_k
    )
    missing_names = [name for name in defaults if params.get(name) is None]
    for name in missing_names:
        params[name] = defaults[name]
    return params
[docs]
def init_params(
    initial_values_dict,
    n_tokens,
    n_subsidary_rules=0,
    chunk_period=60,
    n_parameter_sets=1,
    noise="gaussian",
):
    """Initialize multiple parameter sets from human-readable initial values.

    Builds ``n_parameter_sets`` identical rows per parameter. When
    ``n_parameter_sets > 1`` and ``noise == "gaussian"``, standard-normal
    noise from ``np.random.randn`` is added to every row except the first.

    Parameters
    ----------
    initial_values_dict : dict
        Initial values. All of the following keys are required here:
        ``'initial_k_per_day'``, ``'initial_memory_length'``,
        ``'initial_memory_length_delta'``, ``'initial_weights_logits'``
        (a scalar, or a list / NumPy array / JAX array of per-member
        values), ``'initial_log_amplitude'``, ``'initial_raw_width'``,
        ``'initial_raw_exponents'``, ``'initial_pre_exp_scaling'``.
    n_tokens : int
        Number of assets in the pool.
    n_subsidary_rules : int, optional
        Number of subsidiary rules. Default is 0.
    chunk_period : int, optional
        Chunk period passed to :func:`memory_days_to_lamb`. Default is 60.
    n_parameter_sets : int, optional
        Number of parameter sets. Default is 1.
    noise : str, optional
        Noise type. Only ``'gaussian'`` adds noise; any other value leaves
        all rows identical. Default is ``'gaussian'``.

    Returns
    -------
    dict
        Parameter dict whose array entries are ``jnp.ndarray`` with
        ``n_parameter_sets`` rows, plus ``'subsidary_params'`` (an empty
        list).

    Raises
    ------
    KeyError
        If any of the required keys is missing.

    See Also
    --------
    init_params_singleton : Single parameter set initialization.
    """
    n_pool_members = n_tokens + n_subsidary_rules

    def tiled(value):
        # Always float64: with integer inputs NumPy would infer an integer
        # dtype and the noise added below would be truncated on assignment.
        return np.array(
            [[value] * n_pool_members] * n_parameter_sets, dtype=np.float64
        )

    def logit(p):
        return np.log(p / (1.0 - p))

    memory_length = initial_values_dict["initial_memory_length"]
    logit_lamb_np = logit(memory_days_to_lamb(memory_length, chunk_period))

    # The delta is stored in logit space: logit(lambda for memory + delta)
    # minus logit(lambda for memory).
    logit_total_np = logit(
        memory_days_to_lamb(
            memory_length + initial_values_dict["initial_memory_length_delta"],
            chunk_period,
        )
    )

    weights_logits = initial_values_dict["initial_weights_logits"]
    if (
        isinstance(weights_logits, (list, np.ndarray, jnp.ndarray))
        and np.ndim(weights_logits) > 0
    ):
        # Per-member values supplied directly: one copy per parameter set.
        # ``np.array`` / ``jnp.array`` are functions, not types, so the
        # check must use the array classes.
        initial_weights_logits = np.array(
            [weights_logits] * n_parameter_sets, dtype=np.float64
        )
    else:
        initial_weights_logits = tiled(weights_logits)

    params = {
        "log_k": tiled(np.log2(initial_values_dict["initial_k_per_day"])),
        "logit_lamb": tiled(logit_lamb_np),
        "logit_delta_lamb": tiled(logit_total_np - logit_lamb_np),
        "initial_weights_logits": initial_weights_logits,
        "log_amplitude": tiled(initial_values_dict["initial_log_amplitude"]),
        "raw_width": tiled(initial_values_dict["initial_raw_width"]),
        "raw_exponents": tiled(initial_values_dict["initial_raw_exponents"]),
        "logit_pre_exp_scaling": tiled(
            logit(initial_values_dict["initial_pre_exp_scaling"])
        ),
        "subsidary_params": [],
    }

    if n_parameter_sets > 1 and noise == "gaussian":
        for key, value in params.items():
            if key != "subsidary_params":
                # First row stays at the exact initial values; later rows
                # are perturbed.
                value[1:] += np.random.randn(*value[1:].shape)

    for key, value in params.items():
        if key != "subsidary_params":
            params[key] = jnp.array(value)
    return params
[docs]
def fill_in_missing_values_from_init(
    params,
    initial_values_dict,
    n_tokens,
    n_subsidary_rules=0,
    chunk_period=60,
    n_parameter_sets=1,
):
    """Fill gaps in ``params`` with freshly initialized multi-set values.

    Builds a full parameter dict with :func:`init_params` and copies over
    every entry for which ``params`` has no value (key absent or value
    ``None``).

    Parameters
    ----------
    params : dict
        The parameters dictionary to update in place.
    initial_values_dict : dict
        The initial values dictionary.
    n_tokens : int
        The number of tokens.
    n_subsidary_rules : int, optional
        The number of subsidary rules. Default is 0.
    chunk_period : int, optional
        The chunk period. Default is 60.
    n_parameter_sets : int, optional
        The number of parameter sets. Default is 1.

    Returns
    -------
    dict
        The same ``params`` object, after filling.
    """
    defaults = init_params(
        initial_values_dict,
        n_tokens,
        n_subsidary_rules,
        chunk_period,
        n_parameter_sets=n_parameter_sets,
    )
    params.update(
        {name: value for name, value in defaults.items() if params.get(name) is None}
    )
    return params
[docs]
def calc_hessian_from_loaded_params(params, partial_fixed_training_step):
    """Evaluate ``hessian_trace`` for a loaded parameter dict.

    Works on a deep copy of ``params``. The bookkeeping entries ``"step"``,
    ``"test_objective"`` and ``"train_objective"`` are removed when they
    hold a non-``None`` value; the remaining values are converted to JAX
    arrays and handed to ``hessian_trace``.

    Parameters
    ----------
    params : dict
        A dictionary of parameters; it is not modified.
    partial_fixed_training_step : callable
        Passed through unchanged as the second argument of
        ``hessian_trace``.

    Returns
    -------
    numpy.ndarray
        The result of ``hessian_trace``, copied into a NumPy array.
    """
    params_local = deepcopy(params)
    for bookkeeping_key in ("step", "test_objective", "train_objective"):
        if params_local.get(bookkeeping_key) is not None:
            del params_local[bookkeeping_key]

    jax_params = dict_of_np_to_jnp(params_local)
    trace = hessian_trace(jax_params, partial_fixed_training_step)
    return np.array(trace.copy())
[docs]
def load_result_array(
    run_location,
    key="objective",
    recalc_hess=False,
    partial_fixed_training_step=None,
):
    """Load results from a JSON file and return the fingerprint and one column.

    The file is expected to hold a JSON string that itself encodes a JSON
    list (it is decoded twice). The first list element is the run
    fingerprint; every following element is a result dict.

    Parameters
    ----------
    run_location : str
        Path to the JSON results file.
    key : str, optional
        Which value to extract from each result. Default is "objective".
    recalc_hess : bool, optional
        If True, and the fingerprint has no ``"hessian_trace"`` entry,
        compute ``"hessian_trace"`` for every result via
        :func:`calc_hessian_from_loaded_params`, printing progress.
        Default is False.
    partial_fixed_training_step : callable, optional
        Training-step callable forwarded to
        :func:`calc_hessian_from_loaded_params`. Required whenever the
        Hessian trace is actually recomputed. Default is None.

    Returns
    -------
    tuple
        ``(run_fingerprint, values)`` where ``run_fingerprint`` is the first
        list element and ``values`` is the list of ``result[key]`` for every
        result.

    Raises
    ------
    FileNotFoundError
        If ``run_location`` is not an existing file.
    ValueError
        If the Hessian trace must be recomputed but no
        ``partial_fixed_training_step`` was supplied.
    """
    if not os.path.isfile(run_location):
        raise FileNotFoundError(f"No results file found at {run_location!r}")

    with open(run_location, encoding="utf-8") as json_file:
        # Double-encoded: the file holds a JSON string containing JSON.
        params = json.loads(json.load(json_file))

    run_fingerprint, results = params[0], params[1:]

    # NOTE(review): as in the original, the skip test looks for
    # "hessian_trace" on the fingerprint entry — confirm that is intended.
    if recalc_hess is True and "hessian_trace" not in run_fingerprint:
        if partial_fixed_training_step is None:
            raise ValueError(
                "recalc_hess=True requires partial_fixed_training_step"
            )
        n_results = len(results)
        # Only result entries are processed; the fingerprint is not a
        # parameter set.
        for i, result in enumerate(results):
            result["hessian_trace"] = calc_hessian_from_loaded_params(
                result, partial_fixed_training_step
            )
            print(
                i,
                "/",
                n_results,
                " ",
                i / n_results,
                "htr: ",
                result["hessian_trace"],
            )

    return run_fingerprint, [result[key] for result in results]
def _extract_objective_values(objectives, metric_key="returns_over_uniform_hodl"):
    """Turn a list of objectives into a 2-D array of numbers.

    Each element of ``objectives`` becomes one row:

    - a non-empty list/tuple of dicts -> ``metric_key`` of each dict
      (``-inf`` where the key is absent);
    - a non-empty list/tuple of anything else -> its items unchanged;
    - a single dict -> one-element row with its ``metric_key`` (or ``-inf``);
    - ``None``, or anything ``float()`` rejects (including an empty
      list/tuple) -> ``[-inf]``;
    - any other value -> ``[float(value)]``.

    Parameters
    ----------
    objectives : list
        Objectives in the number-based or dict-based format.
    metric_key : str
        Key to read from dict objectives. Defaults to
        ``"returns_over_uniform_hodl"``.

    Returns
    -------
    np.ndarray
        ``np.array`` of the collected rows; ``[[-inf]]`` when ``objectives``
        is empty.
    """
    neg_inf = float("-inf")
    if not objectives:
        return np.array([[neg_inf]])

    rows = []
    for entry in objectives:
        if isinstance(entry, (list, tuple)) and len(entry) > 0:
            if isinstance(entry[0], dict):
                # Dict-based format: pull the requested metric from each dict.
                row = [metrics.get(metric_key, float("-inf")) for metrics in entry]
            else:
                # Number-based format: keep the items as they are.
                row = list(entry)
        elif isinstance(entry, dict):
            row = [entry.get(metric_key, float("-inf"))]
        elif entry is None:
            row = [float("-inf")]
        else:
            try:
                row = [float(entry)]
            except (TypeError, ValueError):
                row = [float("-inf")]
        rows.append(row)
    return np.array(rows)
def _is_new_format(params):
    """Report whether the first result stores its train objective as dicts.

    Looks at ``params[1]["train_objective"]``: for a non-empty list/tuple
    the first element is inspected, otherwise the value itself.

    Parameters
    ----------
    params : list
        Loaded entries; index 0 is skipped and index 1 is inspected.

    Returns
    -------
    bool
        True if the inspected object is a dict; False otherwise, including
        when ``params`` has fewer than two entries.
    """
    if len(params) < 2:
        return False
    sample = params[1].get("train_objective")
    is_nonempty_sequence = isinstance(sample, (list, tuple)) and len(sample) > 0
    probe = sample[0] if is_nonempty_sequence else sample
    return isinstance(probe, dict)
[docs]
def get_objective_scalar(obj, metric_key="returns_over_uniform_hodl"):
    """Extract a float from an objective that may be a dict or a number.

    Parameters
    ----------
    obj : float, int, or dict
        Either a scalar or a dict of metrics.
    metric_key : str
        The key to read from dict objectives. Defaults to
        ``"returns_over_uniform_hodl"``.

    Returns
    -------
    float
        The scalar objective value. ``-inf`` when ``metric_key`` is absent
        from a dict objective, or when the selected value cannot be
        converted by ``float()`` (for example ``None``).

    Examples
    --------
    >>> get_objective_scalar(0.1)
    0.1
    >>> get_objective_scalar({"returns_over_uniform_hodl": 0.1, "sharpe": 0.5})
    0.1
    >>> get_objective_scalar({"sharpe": 0.5}, metric_key="sharpe")
    0.5
    """
    if isinstance(obj, dict):
        value = obj.get(metric_key, float("-inf"))
    else:
        value = obj
    # One conversion path for both formats, so a non-numeric metric stored
    # in a dict falls back to -inf the same way a non-numeric scalar does.
    try:
        return float(value)
    except (TypeError, ValueError):
        return float("-inf")
def _get_test_objectives(params, use_continuous=True, metric_key="returns_over_uniform_hodl"):
"""Get test objectives, preferring continuous_test_metrics if available.
Parameters
----------
params : list
List of parameter dicts (including fingerprint at index 0)
use_continuous : bool
If True and continuous_test_metrics exists, use it instead of test_objective
metric_key : str
The metric key to extract (e.g., "return", "sharpe")
Returns
-------
list
Raw objectives list from params
str
The actual metric key to use (may be prefixed with "continuous_test_")
"""
# Check if continuous_test_metrics is available
if use_continuous and len(params) > 1:
first_param = params[1]
if "continuous_test_metrics" in first_param and first_param["continuous_test_metrics"]:
# Use continuous test metrics - keys are prefixed with "continuous_test_"
continuous_key = f"continuous_test_{metric_key}"
return [p.get("continuous_test_metrics", []) for p in params[1:]], continuous_key
# Fall back to test_objective
return [p["test_objective"] for p in params[1:]], metric_key
[docs]
def load_manually(
    run_location,
    load_method="last",
    recalc_hess=False,
    min_test=0.0,
    return_as_iterables=False,
    metric_key="returns_over_uniform_hodl",
    use_continuous_test=True,
):
    """Load and process parameter sets from a JSON results file with custom loading methods.

    The file is expected to hold a JSON string that itself contains JSON
    (double-encoded). Element 0 of the decoded list is the run fingerprint;
    the remaining elements are saved training entries.

    If the list is longer than 1.5x ``optimisation_settings.n_iterations``
    (read from element 0), only element 0 and the entries from the last
    ``step == 0`` entry onwards are kept.

    Parameters
    ----------
    run_location : str
        Path to the JSON results file.
    load_method : str, optional
        Method for selecting parameter sets. One of:
            'last' - Returns the last entry (context is None)
            'best_objective' - Entry with highest overall objective
            'best_train_objective' - Entry with highest training objective
            'best_train_objective_for_each_parameter_set' - For every
                parameter set, the entry where that set's training
                objective is highest (index and context are lists)
            'best_test_objective' - Entry with highest test objective
            'best_objective_of_last', 'best_train_objective_of_last',
            'best_test_objective_of_last' - The last entry, with context
                set to its best parameter set
            'best_train_min_test_objective' - Highest training objective
                among parameter sets whose test objective is >= ``min_test``
            'best_test_min_train_objective' - Highest test objective among
                parameter sets whose training objective is >= ``min_test``
        Defaults to 'last'.
    recalc_hess : bool, optional
        If True and element 0 has no ``"hessian_trace"`` key, recompute the
        Hessian trace for every element and write the file back.
        Defaults to False.
    min_test : float, optional
        Threshold used by the two ``*_min_*`` methods. Defaults to 0.0.
    return_as_iterables : bool, optional
        If True, both return values are wrapped in lists. Defaults to False.
    metric_key : str, optional
        For new format files with metric dicts, specifies which metric to use.
        Options include: "return", "sharpe", "jax_sharpe", "returns_over_hodl",
        "returns_over_uniform_hodl", "annualised_returns", "calmar", "sterling",
        "ulcer". Ignored for old format files with simple numeric objectives.
        Defaults to "returns_over_uniform_hodl".
    use_continuous_test : bool, optional
        If True and continuous_test_metrics is available, use it instead of
        test_objective for test-related load methods. Defaults to True.

    Returns
    -------
    tuple
        Two-element tuple containing:
            - dict: Loaded parameters (the selected entry)
            - int or None: The index of the selected parameter set

    Raises
    ------
    FileNotFoundError
        If ``run_location`` is not an existing file.
    NotImplementedError
        If ``load_method`` is not one of the supported methods.
    """
    if not os.path.isfile(run_location):
        raise FileNotFoundError(f"File not found: {run_location}")

    with open(run_location, encoding="utf-8") as json_file:
        # The payload is double-encoded: a JSON string wrapping the JSON list.
        params = json.loads(json.load(json_file))

    # An over-long file is trimmed to the portion starting at the last
    # entry whose step is 0.
    if len(params) > 1.5 * params[0]["optimisation_settings"]["n_iterations"]:
        last_step_zero_idx = -1
        for i in range(len(params) - 1, 0, -1):
            if params[i].get("step", -1) == 0:
                last_step_zero_idx = i
                break
        if last_step_zero_idx != -1:
            params = [params[0]] + params[last_step_zero_idx:]

    if recalc_hess is True and "hessian_trace" not in params[0]:
        for entry in params:
            entry["hessian_trace"] = calc_hessian_from_loaded_params(entry)
        dumped = json.dumps(params, cls=NumpyEncoder)
        with open(run_location, "w", encoding="utf-8") as json_file:
            json.dump(dumped, json_file)

    entries = params[1:]

    if load_method == "last":
        index = -1
        context = None
    elif load_method == "best_objective":
        objectives = [p["objective"] for p in entries]
        index = np.argmax(np.nanmax(objectives, axis=1)) + 1
        context = np.nanargmax(np.nanmax(objectives, axis=0))
    elif load_method == "best_train_objective":
        raw_objectives = [p["train_objective"] for p in entries]
        objectives = _extract_objective_values(raw_objectives, metric_key)
        index = np.argmax(np.nanmax(objectives, axis=1)) + 1
        context = np.nanargmax(np.nanmax(objectives, axis=0))
    elif load_method == "best_train_objective_for_each_parameter_set":
        raw_objectives = [p["train_objective"] for p in entries]
        objectives = _extract_objective_values(raw_objectives, metric_key)
        index = (np.nanargmax(objectives, axis=0) + 1).tolist()
        context = np.arange(len(objectives[0])).tolist()
    elif load_method == "best_test_objective":
        raw_objectives, actual_key = _get_test_objectives(
            params, use_continuous_test, metric_key
        )
        objectives = _extract_objective_values(raw_objectives, actual_key)
        index = np.argmax(np.nanmax(objectives, axis=1)) + 1
        context = np.nanargmax(np.nanmax(objectives, axis=0))
    elif load_method == "best_objective_of_last":
        objectives = [params[-1]["objective"]]
        index = -1
        # Reduce over the single-entry axis only; a full reduction would
        # collapse to a scalar and always yield context 0.
        context = np.nanargmax(np.nanmax(objectives, axis=0))
    elif load_method == "best_train_objective_of_last":
        raw_objectives = [params[-1]["train_objective"]]
        objectives = _extract_objective_values(raw_objectives, metric_key)
        index = -1
        context = np.nanargmax(np.nanmax(objectives, axis=0))
    elif load_method == "best_test_objective_of_last":
        raw_objectives, actual_key = _get_test_objectives(
            [params[0], params[-1]], use_continuous_test, metric_key
        )
        objectives = _extract_objective_values(raw_objectives, actual_key)
        index = -1
        context = np.nanargmax(np.nanmax(objectives, axis=0))
    elif load_method in (
        "best_train_min_test_objective",
        "best_test_min_train_objective",
    ):
        # Test objectives prefer continuous metrics when available.
        raw_test_objs, test_key = _get_test_objectives(
            params, use_continuous_test, metric_key
        )

        def train_row(i):
            return _extract_objective_values(
                [entries[i]["train_objective"]], metric_key
            )[0]

        def test_row(i):
            return _extract_objective_values([raw_test_objs[i]], test_key)[0]

        if load_method == "best_train_min_test_objective":
            gate_row, rank_row, width_field = test_row, train_row, "train_objective"
        else:
            gate_row, rank_row, width_field = train_row, test_row, "test_objective"

        gate_rows = [gate_row(i) for i in range(len(entries))]
        # The number of parameter sets is read from the first saved entry.
        num_param_sets = len(
            _extract_objective_values([params[1][width_field]], metric_key)[0]
        )
        best_objective, set_with_best_test_index = _select_gated_best(
            entries, gate_rows, rank_row, num_param_sets, min_test
        )
        if return_as_iterables:
            return [best_objective], [set_with_best_test_index]
        return best_objective, set_with_best_test_index
    else:
        raise NotImplementedError(f"Unsupported load_method: {load_method}")

    if return_as_iterables:
        if "for_each_parameter_set" not in load_method:
            return [params[index]], [context]
        return [params[i] for i in index], context
    return params[index], context


def _select_gated_best(entries, gate_rows, rank_row_for, num_sets, threshold):
    """Pick the entry and set index with the best ranked value that passes a gate.

    Parameters
    ----------
    entries : list of dict
        Saved entries (fingerprint excluded).
    gate_rows : list
        Per-entry sequences of the gating values, aligned with ``entries``.
    rank_row_for : callable
        Maps an entry position to that entry's sequence of ranked values.
    num_sets : int
        Number of parameter-set positions to scan per entry.
    threshold : float
        Minimum gate value a parameter set needs to be considered.

    Returns
    -------
    tuple
        ``(entry, set_index)``. Falls back to ``(entries[0], 0)`` when no
        parameter set clears the threshold. Ties go to the later candidate
        because the comparison is ``>=``.
    """
    eligible = [
        pos for pos, row in enumerate(gate_rows) if np.nanmax(row) >= threshold
    ]
    if not eligible:
        eligible = list(range(len(entries)))
    best_entry = entries[eligible[0]]
    best_set = 0
    best_rank = float("-inf")
    for pos in eligible:
        gate_row = gate_rows[pos]
        rank_row = rank_row_for(pos)
        for j in range(num_sets):
            # Positions beyond a row's length count as -inf.
            gate_val = gate_row[j] if j < len(gate_row) else float("-inf")
            rank_val = rank_row[j] if j < len(rank_row) else float("-inf")
            if gate_val >= threshold and rank_val >= best_rank:
                best_entry = entries[pos]
                best_set = j
                best_rank = rank_val
    return best_entry, best_set
[docs]
def retrieve_best(data_location, load_method, re_calc_hess, min_alt_obj=0.0, return_as_iterables=False):
    """Retrieve the best parameters from a training run.

    Loads entries via ``load_manually`` and, for each one, keeps only the
    parameter set identified by the accompanying context index. Training
    metadata (``step``, ``hessian_trace``, ``local_learning_rate``,
    ``iterations_since_improvement``) is removed from the returned params.

    Parameters
    ----------
    data_location : str
        Path to the JSON results file (forwarded to ``load_manually`` as
        ``run_location``).
    load_method : str
        Method for loading parameters, forwarded to ``load_manually``.
        Options include:
        - 'last': Load the most recent checkpoint
        - 'best_train_objective': Load checkpoint with best training objective
        - 'best_test_objective': Load checkpoint with best test objective
    re_calc_hess : bool
        Whether to recalculate hessian information when loading.
    min_alt_obj : float, optional
        Threshold forwarded to ``load_manually`` as ``min_test``.
        Defaults to 0.0.
    return_as_iterables : bool, optional
        If True, returns lists of all loaded params and steps.
        If False, returns only the first (best) params and step.
        Defaults to False.

    Returns
    -------
    params : dict or list of dict
        Best parameter dictionary (or list if return_as_iterables=True).
        Training metadata fields are removed. When the load method supplies
        no context (``None``, as 'last' does), values are returned as stored,
        without selecting a single parameter set.
    steps : int or list of int
        Training step(s) at which the parameters were saved.
    """
    metadata_keys = (
        "step",
        "hessian_trace",
        "local_learning_rate",
        "iterations_since_improvement",
    )
    loaded, contexts = load_manually(
        data_location, load_method, re_calc_hess, min_alt_obj, return_as_iterables=True
    )
    steps = []
    params_list = []
    for param, context in zip(loaded, contexts):
        steps.append(param["step"])
        selected = param.copy()
        for meta_key in metadata_keys:
            # Tolerate entries saved without one of the metadata fields.
            selected.pop(meta_key, None)
        if context is not None:
            # Keep just the chosen parameter set; subsidiary params are
            # passed through untouched.
            for key in selected:
                if key != "subsidary_params":
                    selected[key] = selected[key][context]
        params_list.append(selected)
    if return_as_iterables:
        return params_list, steps
    return params_list[0], steps[0]
[docs]
def load_or_init(
    run_fingerprint,
    initial_values_dict,
    n_tokens,
    n_subsidary_rules,
    recalc_hess=False,
    chunk_period=60,
    force_init=False,
    load_method="last",
    n_parameter_sets=1,
    results_dir="./results/",
    partial_fixed_training_step=None,
):
    """
    Load or initialize parameters for the AMM simulator.

    The results file is looked up at
    ``results_dir + get_run_location(run_fingerprint) + ".json"``. If it
    exists (and ``force_init`` is False) the saved entries are loaded,
    missing values are filled in from ``initial_values_dict``, and one entry
    is selected according to ``load_method``. Otherwise fresh parameters are
    created with ``init_params``.

    Parameters
    ----------
    run_fingerprint
        The fingerprint of the run, passed to ``get_run_location``.
    initial_values_dict : dict
        The initial values dictionary.
    n_tokens : int
        The number of tokens.
    n_subsidary_rules : int
        The number of subsidiary rules.
    recalc_hess : bool, optional
        If True and the first loaded element has no ``"hessian_trace"`` key,
        recompute it for every element and rewrite the file.
        Default is False.
    chunk_period : int, optional
        The chunk period. Default is 60.
    force_init : bool, optional
        Whether to force initialization even if a results file exists.
        Default is False.
    load_method : str, optional
        ``"last"`` or ``"best_objective"``. Default is "last".
    n_parameter_sets : int, optional
        The number of parameter sets. Default is 1.
    results_dir : str, optional
        The directory for results; concatenated as a plain string prefix.
        Default is "./results/".
    partial_fixed_training_step : callable, optional
        Passed to ``calc_hessian_from_loaded_params`` when recalculating.
        Default is None.

    Returns
    -------
    tuple
        A tuple containing:
            params : dict
                The loaded or initialized parameters
            loaded : bool
                Whether the parameters were loaded (True) or initialized (False)

    Raises
    ------
    NotImplementedError
        If a results file is loaded with an unsupported ``load_method``.
    """

    def fresh_params():
        return init_params(
            initial_values_dict,
            n_tokens,
            n_subsidary_rules,
            chunk_period,
            n_parameter_sets=n_parameter_sets,
        )

    def fill_missing(entry):
        return fill_in_missing_values_from_init(
            entry,
            initial_values_dict,
            n_tokens,
            n_subsidary_rules,
            chunk_period,
            n_parameter_sets=n_parameter_sets,
        )

    run_location = results_dir + get_run_location(run_fingerprint) + ".json"
    if force_init:
        print("force init")
        return fresh_params(), False
    if not os.path.isfile(run_location):
        return fresh_params(), False

    print("Loading from: ", run_location)
    print("found file")
    with open(run_location, encoding="utf-8") as json_file:
        # The payload is double-encoded: a JSON string wrapping the JSON data.
        params = json.load(json_file)
        params = json.loads(params)
    print("params")
    print(params)
    if recalc_hess is True:
        if "hessian_trace" not in params[0].keys():
            for i in range(len(params)):
                params[i]["hessian_trace"] = calc_hessian_from_loaded_params(
                    params[i], partial_fixed_training_step
                )
                print(
                    i,
                    "/",
                    len(params),
                    " ",
                    i / len(params),
                    "htr: ",
                    params[i]["hessian_trace"],
                )
            dumped = json.dumps(params, cls=NumpyEncoder)
            with open(run_location, "w", encoding="utf-8") as json_file:
                json.dump(dumped, json_file, indent=4)

    if isinstance(params, list):
        params = [fill_missing(p) for p in params]
    else:
        params = fill_missing(params)

    if load_method == "last":
        index = -1
    elif load_method == "best_objective":
        objectives = [p["objective"] for p in params[1:]]
        # Reduce across parameter sets (axis=1) so the argmax identifies the
        # best saved entry; +1 skips the fingerprint at index 0.
        index = np.argmax(np.nanmax(objectives, axis=1)) + 1
    else:
        raise NotImplementedError

    params = dict_of_np_to_jnp(params[index])
    params["subsidary_params"] = [
        dict_of_np_to_jnp(sp) for sp in params["subsidary_params"]
    ]
    return params, True
[docs]
def load(
    run_location,
    initial_values_dict,
    n_tokens,
    n_subsidary_rules,
    chunk_period=60,
    load_method="last",
    n_parameter_sets=1,
):
    """
    Load parameters from a file and fill in missing values based on initial values.

    The file holds a JSON string that itself contains JSON. After decoding,
    every element is passed through ``fill_in_missing_values_from_init`` and
    one entry is selected by ``load_method``; element 0 is skipped when
    ranking objectives.

    Parameters
    ----------
    run_location : str
        The location of the file containing the parameters.
    initial_values_dict : dict
        A dictionary of initial values.
    n_tokens : int
        The number of tokens.
    n_subsidary_rules : int
        The number of subsidiary rules.
    chunk_period : int, optional
        The chunk period. Default is 60.
    load_method : {'last', 'best_objective', 'best_train_objective'}, optional
        The method to use for loading parameters. Default is 'last'.
    n_parameter_sets : int, optional
        The number of parameter sets. Default is 1.

    Returns
    -------
    tuple
        A tuple containing:
            params : dict
                The loaded parameters
            context : int or None
                The context index for the loaded parameters (None for 'last')

    Raises
    ------
    FileNotFoundError
        If the run_location file does not exist.
    NotImplementedError
        If an unsupported load_method is specified.
    """
    if not os.path.isfile(run_location):
        raise FileNotFoundError(f"File not found: {run_location}")

    with open(run_location, encoding='utf-8') as json_file:
        saved = json.loads(json.load(json_file))

    def fill_missing(entry):
        return fill_in_missing_values_from_init(
            entry,
            initial_values_dict,
            n_tokens,
            n_subsidary_rules,
            chunk_period,
            n_parameter_sets=n_parameter_sets,
        )

    if isinstance(saved, list):
        saved = [fill_missing(entry) for entry in saved]
    else:
        saved = fill_missing(saved)

    if load_method == "last":
        index, context = -1, None
    elif load_method in ("best_objective", "best_train_objective"):
        field = "objective" if load_method == "best_objective" else "train_objective"
        objectives = [entry[field] for entry in saved[1:]]
        # Best entry: max over parameter sets per entry; +1 skips element 0.
        index = np.argmax(np.nanmax(objectives, axis=1)) + 1
        # Best parameter set: max over entries per set.
        context = np.argmax(np.nanmax(objectives, axis=0))
    else:
        raise NotImplementedError

    chosen = dict_of_np_to_jnp(saved[index])
    chosen["subsidary_params"] = [
        dict_of_np_to_jnp(sub) for sub in chosen["subsidary_params"]
    ]
    return chosen, context
[docs]
def make_composite_run_params(
    composite_params,
    list_of_subsidary_pool_run_dicts,
    initial_values_dict,
    n_parameter_sets,
):
    """
    Create composite run parameters for the AMM simulator.

    A deep copy of ``composite_params`` is extended with a
    ``"subsidary_params"`` list holding one freshly initialised parameter
    dict per subsidiary pool. Each pool gets its own deep copy of
    ``initial_values_dict`` with ``"initial_memory_length"`` and
    ``"initial_k_per_day"`` overridden from the pool's run dict; its token
    count is ``len(pool["tokens"])``. The inputs are not modified.

    Parameters
    ----------
    composite_params : dict
        The composite parameters for the AMM simulator.
    list_of_subsidary_pool_run_dicts : list
        A list of dictionaries containing the parameters for each subsidiary pool run.
    initial_values_dict : dict
        The initial values dictionary for the AMM simulator.
    n_parameter_sets : int
        The number of parameter sets.

    Returns
    -------
    dict
        The composite run parameters for the AMM simulator.
    """
    overridden_fields = ("initial_memory_length", "initial_k_per_day")
    composite = deepcopy(composite_params)
    subsidiary = []
    for pool_run in list_of_subsidary_pool_run_dicts:
        pool_initial_values = deepcopy(initial_values_dict)
        for field in overridden_fields:
            pool_initial_values[field] = pool_run[field]
        token_count = len(pool_run["tokens"])
        pool_params = init_params(
            pool_initial_values,
            token_count,
            n_parameter_sets=n_parameter_sets,
        )
        subsidiary.append(pool_params)
    composite["subsidary_params"] = subsidiary
    return composite
[docs]
def create_product_of_linspaces(
    params, keys_ranges, num_points_per_key, inverse_funcs=None
):
    """
    Create a product of linspaces for chosen keys in the params dict.

    For each key in ``keys_ranges`` an evenly spaced grid between its low
    and high values is built, optionally mapped through that key's inverse
    function. One shallow copy of ``params`` is produced for every element
    of the Cartesian product of the grids, with the swept keys overwritten.

    Parameters
    ----------
    params : dict
        The dictionary containing initial parameter values.
    keys_ranges : dict
        Maps each swept key to a ``(low, high)`` pair.
    num_points_per_key : dict
        The number of grid points for each key; keys absent here get 10.
    inverse_funcs : dict, optional
        Functions applied to the whole grid array of the matching key.

    Returns
    -------
    list
        A list of dictionaries with all combinations of linspace values for the chosen keys.
    """
    grids = {}
    for name, bounds in keys_ranges.items():
        lower, upper = bounds
        grid = np.linspace(lower, upper, num_points_per_key.get(name, 10))
        if inverse_funcs and name in inverse_funcs:
            grid = inverse_funcs[name](grid)
        grids[name] = grid

    swept_names = list(keys_ranges.keys())
    combinations = []
    for point in product(*grids.values()):
        candidate = params.copy()
        # point is ordered like keys_ranges, so names and values line up.
        candidate.update(zip(swept_names, point))
        combinations.append(candidate)
    return combinations
[docs]
def create_product_of_arrays(params, keys_arrays):
    """
    Create a product of arrays for chosen keys in the params dict.

    Every element of the Cartesian product of the supplied value sequences
    yields one shallow copy of ``params`` in which the swept keys are set
    to that element's values.

    Parameters
    ----------
    params : dict
        The dictionary containing initial parameter values.
    keys_arrays : dict
        Maps each swept key to the sequence of values it should take.

    Returns
    -------
    list
        A list of dictionaries with all combinations of the supplied values for the chosen keys.
    """

    def with_overrides(point):
        # Pair each swept key with its value for this grid point.
        candidate = params.copy()
        for position, name in enumerate(keys_arrays.keys()):
            candidate[name] = point[position]
        return candidate

    return [with_overrides(point) for point in product(*keys_arrays.values())]
[docs]
def generate_run_fingerprint_combinations(
    run_fingerprint,
    keys_ranges=None,
    num_points_per_key=None,
    inverse_funcs=None,
):
    """
    Generate run fingerprint combinations with specified ranges and scaling.

    Thin wrapper around ``create_product_of_linspaces`` operating on a
    shallow copy of ``run_fingerprint``.

    Parameters
    ----------
    run_fingerprint : dict
        The base run fingerprint.
    keys_ranges : dict, optional
        The dictionary containing high and low values for each key.
        Defaults to ``{"arb_frequency": (log(1), log(100))}``.
    num_points_per_key : dict, optional
        The dictionary containing the number of points for each key.
        Defaults to 10 points for each key in ``keys_ranges``.
    inverse_funcs : dict, optional
        A dictionary of inverse functions for each key. No default is
        substituted here: when left as None the grid values are used as
        generated, so the default ``arb_frequency`` range stays in log space.

    Returns
    -------
    list
        A list of dictionaries with all combinations of run fingerprint values.
    """
    if keys_ranges is None:
        # Example logarithmic range for the arbitrage frequency.
        keys_ranges = {"arb_frequency": (np.log(1), np.log(100))}
    if num_points_per_key is None:
        num_points_per_key = dict.fromkeys(keys_ranges, 10)
    return create_product_of_linspaces(
        run_fingerprint.copy(), keys_ranges, num_points_per_key, inverse_funcs
    )
[docs]
def make_log_range_with_zero(x):
    """
    Compute the exponential of a given value, with a special case for zero.

    Accepts scalars as well as arrays, so it can be used directly as an
    inverse function in ``create_product_of_linspaces``, which passes the
    whole grid array at once.

    Parameters
    ----------
    x : float or array-like
        The input value(s) for which the exponential is to be computed.

    Returns
    -------
    float or np.ndarray
        For a scalar: ``0`` if ``x`` is zero, otherwise ``exp(x)``.
        For an array: an array of the same shape holding ``exp(x)``
        elementwise, with ``0.0`` wherever ``x`` is zero.
    """
    if np.ndim(x) == 0:
        # Scalar inputs keep the original return values exactly.
        if x == 0:
            return 0
        return np.exp(x)
    grid = np.asarray(x)
    # Elementwise version of the scalar rule; a plain `if x == 0` would be
    # ambiguous for a multi-element array.
    return np.where(grid == 0, 0.0, np.exp(grid))
[docs]
def combine_param_combinations(param_combinations, n_parameter_sets):
    """
    Combine single-row jnp arrays in param_combinations into multi-row jnp arrays.

    The input list is cut into consecutive groups of ``n_parameter_sets``
    dictionaries (the final group may be shorter). Within each group the
    values of every key are stacked with ``jnp.stack``; the keys are taken
    from the group's first dictionary. Values under ``"subsidary_params"``
    are grouped and stacked the same way by a nested helper.

    Parameters
    ----------
    param_combinations : list
        List of dictionaries with single-row jnp arrays.
    n_parameter_sets : int
        Number of parameter sets to combine into each dictionary.

    Returns
    -------
    list
        List of dictionaries with multi-row jnp arrays.
    """

    def in_groups(items):
        starts = range(0, len(items), n_parameter_sets)
        return [items[start : start + n_parameter_sets] for start in starts]

    def stack_subsidary(subsidary_params_list):
        stacked_groups = []
        for group in in_groups(subsidary_params_list):
            template = group[0]
            stacked = {}
            # An empty first element yields an empty dict for the group.
            if len(template) > 0:
                stacked = {
                    name: jnp.stack([member[name] for member in group])
                    for name in template.keys()
                }
            stacked_groups.append(stacked)
        return stacked_groups

    combined = []
    for group in in_groups(param_combinations):
        merged = {}
        for name in group[0].keys():
            column = [member[name] for member in group]
            if name == "subsidary_params":
                merged[name] = stack_subsidary(column)
            else:
                merged[name] = jnp.stack(column)
        combined.append(merged)
    return combined
[docs]
def split_param_combinations(param_combinations):
    """
    Split multi-row jnp arrays in param_combinations into single-row jnp arrays.

    The number of rows of each dictionary is the length of the value under
    its first key. Row ``i`` of the output takes element ``i`` of every
    value, except for the keys ``"subsidary_params"`` and
    ``"rule_outputs_dict"``: those values are themselves split row-wise into
    a list of dicts, and that complete list is attached to every output row.

    Parameters
    ----------
    param_combinations : list
        List of dictionaries with multi-row jnp arrays.

    Returns
    -------
    list
        List of dictionaries with single-row jnp arrays.
    """
    nested_names = ("subsidary_params", "rule_outputs_dict")

    def rows_of(stacked):
        names = list(stacked.keys())
        row_count = len(stacked[names[0]])
        return [{name: stacked[name][row] for name in names} for row in range(row_count)]

    singles = []
    for stacked in param_combinations:
        names = list(stacked.keys())
        for row in range(len(stacked[names[0]])):
            singles.append(
                {
                    name: rows_of(stacked[name])
                    if name in nested_names
                    else stacked[name][row]
                    for name in names
                }
            )
    return singles
[docs]
def make_vmap_in_axes_dict(
    input_dict, in_axes, keys_to_recur_on, keys_with_no_vamp=None, n_repeats_of_recurred=0
):
    """Create a ``vmap`` in_axes specification dict matching a parameter dict structure.

    Every key of ``input_dict`` is first mapped to ``in_axes``. Each key in
    ``keys_to_recur_on`` is then replaced by a list of
    ``n_repeats_of_recurred`` spec dicts, each built from the same
    ``input_dict`` with ``"subsidary_params"`` set to None. Finally every key
    in ``keys_with_no_vamp`` is set to None (added if not already present).

    Parameters
    ----------
    input_dict : dict
        Parameter dictionary whose structure to mirror.
    in_axes : int
        Axis to vectorize over (typically 0 for the parameter-set dimension).
    keys_to_recur_on : list of str
        Keys (e.g., ``'subsidary_params'``) that contain nested parameter dicts
        requiring recursive axis specification.
    keys_with_no_vamp : list of str, optional
        Keys that should not be vectorized (axis set to None). Default is
        None, which is treated as an empty list.
    n_repeats_of_recurred : int, optional
        Number of subsidiary parameter dicts. Default is 0.

    Returns
    -------
    dict
        Nested dict matching the structure of ``input_dict`` with integer axes
        or None for each leaf.
    """
    if keys_with_no_vamp is None:
        keys_with_no_vamp = []

    in_axes_dict = dict.fromkeys(input_dict, in_axes)
    for key in keys_to_recur_on:
        # Build a separate spec dict per repeat so the list entries do not
        # alias a single shared object.
        in_axes_dict[key] = [
            make_vmap_in_axes_dict(
                input_dict,
                in_axes,
                [],
                keys_with_no_vamp=["subsidary_params"],
                n_repeats_of_recurred=0,
            )
            for _ in range(n_repeats_of_recurred)
        ]
    for key in keys_with_no_vamp:
        in_axes_dict[key] = None
    return in_axes_dict
[docs]
def generate_params_combinations(
    initial_values_dict,
    n_tokens,
    n_subsidary_rules,
    chunk_period,
    n_parameter_sets,
    k_per_day_range,
    memory_days_range,
    num_points_k_per_day=10,
    num_points_memory_days=10,
):
    """
    Generate parameter combinations with linearly-spaced values of k_per_day and memory_days.

    A grid over ``"initial_k_per_day"`` and ``"initial_memory_length"`` is
    laid over a shallow copy of ``initial_values_dict``; each resulting
    initial-values dict is then turned into a parameter dict via
    ``fill_in_missing_values_from_init_singleton`` starting from an empty
    dict.

    Parameters
    ----------
    initial_values_dict : dict
        The initial values dictionary.
    n_tokens : int
        The number of tokens.
    n_subsidary_rules : int
        The number of subsidary rules.
    chunk_period : int
        The chunk period.
    n_parameter_sets : int
        The number of parameter sets.
    k_per_day_range : tuple
        The range (low, high) for k_per_day.
    memory_days_range : tuple
        The range (low, high) for memory_days.
    num_points_k_per_day : int, optional
        The number of points for k_per_day linspace. Defaults to 10.
    num_points_memory_days : int, optional
        The number of points for memory_days linspace. Defaults to 10.

    Returns
    -------
    tuple
        ``(filled_param_combinations, initial_values_dict_combinations)``:
        the list of parameter dicts and the list of initial-values dicts
        they were generated from, in matching order.
    """
    # (range, number of points) for each swept initial value.
    sweep_spec = {
        "initial_k_per_day": (k_per_day_range, num_points_k_per_day),
        "initial_memory_length": (memory_days_range, num_points_memory_days),
    }
    keys_ranges = {name: spec[0] for name, spec in sweep_spec.items()}
    num_points_per_key = {name: spec[1] for name, spec in sweep_spec.items()}

    initial_values_dict_combinations = create_product_of_linspaces(
        initial_values_dict.copy(), keys_ranges, num_points_per_key
    )

    filled_param_combinations = []
    for candidate_initial_values in initial_values_dict_combinations:
        filled_param_combinations.append(
            fill_in_missing_values_from_init_singleton(
                {},
                candidate_initial_values,
                n_tokens,
                n_subsidary_rules,
                chunk_period,
                n_parameter_sets,
            )
        )
    return filled_param_combinations, initial_values_dict_combinations
[docs]
def process_initial_values(
initial_values_dict, key, n_assets, n_parameter_sets, force_scalar=False
):
"""Extract and broadcast a parameter value to the correct shape.
Handles flexible input formats: scalar (broadcast to all assets and sets),
per-asset vector (broadcast across sets), or full matrix. Used by the
schema-aware initialization path.
Parameters
----------
initial_values_dict : dict
Dictionary containing initial parameter values.
key : str
Parameter name to extract.
n_assets : int
Number of assets (columns).
n_parameter_sets : int
Number of parameter sets / ensemble members (rows).
force_scalar : bool, optional
If True, treat value as a scalar even if it's array-like, producing
shape ``(n_parameter_sets,)`` instead of ``(n_parameter_sets, n_assets)``.
Returns
-------
np.ndarray
Array of shape ``(n_parameter_sets, n_assets)`` or ``(n_parameter_sets,)``
if ``force_scalar=True``.
Raises
------
ValueError
If ``key`` is not in ``initial_values_dict`` or has incompatible shape.
"""
if key in initial_values_dict:
initial_value = initial_values_dict[key]
if isinstance(initial_value, (np.ndarray, jnp.ndarray, list)):
initial_value = np.array(initial_value)
if force_scalar:
return np.array([initial_value] * n_parameter_sets)
elif initial_value.size == n_assets:
return np.array([initial_value] * n_parameter_sets)
elif initial_value.size == 1:
return np.array([[initial_value] * n_assets] * n_parameter_sets)
elif initial_value.shape == (n_parameter_sets, n_assets):
return initial_value
else:
raise ValueError(
f"{key} must be a singleton or a vector of length n_assets or a matrix of shape (n_parameter_sets, n_assets)"
)
else:
if force_scalar:
return np.array([initial_value] * n_parameter_sets)
else:
return np.array([[initial_value] * n_assets] * n_parameter_sets)
else:
raise ValueError(f"initial_values_dict must contain {key}")
def _to_float64_list(value):
"""Convert JAX/numpy array to list of float64."""
if isinstance(value, (jnp.ndarray, np.ndarray)):
return [float(x) for x in np.array(value).flatten()]
elif isinstance(value, (list, tuple)):
return [float(x) for x in value]
else:
return [float(value)]
def _to_bd18_string_list(values):
"""Convert list of floats to list of 18 fixed point integer strings.
Uses string manipulation to avoid overflow from multiplication by 1e18.
Formats each value with 18 decimal places, then removes the decimal point
and strips leading zeros.
"""
result = []
for x in values:
# Format with 18 decimal places, then remove decimal point
formatted = f"{x:.18f}"
# Split on decimal point
if '.' in formatted:
int_part, frac_part = formatted.split('.')
# Pad fractional part to exactly 18 digits if needed
frac_part = frac_part.ljust(18, '0')[:18]
combined = int_part + frac_part
else:
# No decimal point, just append 18 zeros
combined = formatted + '0' * 18
# Strip leading zeros, but keep at least one digit (handle zero case)
stripped = combined.lstrip('0')
result.append(stripped if stripped else '0')
return result
[docs]
def convert_parameter_values(params, run_fingerprint, max_memory_days=None):
    """Convert raw (reparameterized) parameters to human-readable and on-chain formats.

    Applies the inverse reparameterizations (logit → lambda → memory_days, log2 → k,
    squareplus → exponents, etc.) and produces both float64 values and BD18 fixed-point
    string representations suitable for on-chain deployment.

    Parameters
    ----------
    params : dict
        Raw parameter dictionary. Recognised keys are ``'logit_lamb'``,
        ``'log_k'`` (or ``'k'``), ``'raw_exponents'``, ``'raw_width'``,
        ``'log_amplitude'``, ``'logit_pre_exp_scaling'`` and
        ``'raw_pre_exp_scaling'``; keys that are absent are simply skipped.
    run_fingerprint : dict
        Run configuration. ``'chunk_period'`` is read whenever a
        memory-days-scaled parameter (``lamb``, ``k``, ``amplitude``) is converted.
    max_memory_days : float, optional
        Maximum memory days for lambda clipping. If None, uses
        ``run_fingerprint['max_memory_days']`` (default 365).

    Returns
    -------
    dict
        ``{'values': {...}, 'strings': {...}}`` where each inner dict maps
        human-readable parameter names (``'lamb'``, ``'k'``, ``'exponents'``,
        ``'width'``, ``'amplitude'``, ``'pre_exp_scaling'``) to lists.
        ``'values'`` contains float64 lists; ``'strings'`` contains BD18
        (18-decimal fixed-point integer) string representations.

    Raises
    ------
    KeyError
        If a memory-days-scaled parameter is present but ``run_fingerprint``
        has no ``'chunk_period'`` entry.

    Notes
    -----
    BD18 format multiplies the float value by 10^18 and represents as an integer
    string, matching the Solidity ``uint256`` representation used by the on-chain
    QuantAMM contracts. The conversion uses string manipulation to avoid float64
    overflow from direct multiplication by 1e18.

    If both ``'logit_pre_exp_scaling'`` and ``'raw_pre_exp_scaling'`` are present,
    the value derived from ``'raw_pre_exp_scaling'`` is the one kept.
    """
    result = {"values": {}, "strings": {}}

    def _record(name, value):
        # Store one converted parameter in both output representations.
        as_floats = _to_float64_list(value)
        result["values"][name] = as_floats
        result["strings"][name] = _to_bd18_string_list(as_floats)

    if max_memory_days is None:
        max_memory_days = run_fingerprint.get("max_memory_days", 365)

    # lamb / memory_days are shared by several conversions below, so compute
    # them once, and for every parameter that is scaled by memory_days rather
    # than only when 'logit_lamb' or 'log_amplitude' happens to be present.
    lamb = None
    memory_days = None
    if any(
        name in params for name in ("logit_lamb", "log_k", "k", "log_amplitude")
    ):
        lamb = calc_lamb(params)
        memory_days = lamb_to_memory_days_clipped(
            lamb,
            chunk_period=run_fingerprint["chunk_period"],
            max_memory_days=max_memory_days,
        )

    if "logit_lamb" in params:
        _record("lamb", lamb)

    # k is stored per unit of memory: k = 2**log_k * memory_days.
    if "log_k" in params:
        _record("k", 2 ** params["log_k"] * memory_days)
    elif "k" in params:
        _record("k", params["k"] * memory_days)

    if "raw_exponents" in params:
        _record("exponents", squareplus(params["raw_exponents"]))

    if "raw_width" in params:
        _record("width", 2 ** params["raw_width"])

    if "log_amplitude" in params:
        _record("amplitude", (2 ** params["log_amplitude"]) * memory_days)

    if "logit_pre_exp_scaling" in params:
        # Logistic sigmoid written as 1 / (1 + exp(-x)): large positive logits
        # saturate to 1 instead of overflowing to inf / inf = nan.
        _record(
            "pre_exp_scaling",
            1 / (1 + jnp.exp(-params["logit_pre_exp_scaling"])),
        )

    if "raw_pre_exp_scaling" in params:
        _record("pre_exp_scaling", 2 ** params["raw_pre_exp_scaling"])

    return result
# print("-" * 80)
# print("final readouts")
# for readout in result["readouts"]:
# print(
# f"{readout}: { jnp.array_str(result['readouts'][readout][-1], precision=16, suppress_small=False)}"
# )
# print("-" * 80)
# print("final weights")
# print(
# f"{jnp.array_str(result['weights'][-1], precision=16, suppress_small=False)}"
# )
# print("-" * 80)
# print("final prices")
# print(
# f"{jnp.array_str(result['prices'][-1], precision=16, suppress_small=False)}"
# )
# print("=" * 80)