"""Mean-reversion channel pool for QuantAMM.

Implements a channel-based strategy whose weight updates are driven by an
EWMA price-gradient signal. Inside the channel (gradients that are small
relative to ``width``) a Gaussian-enveloped cubic term produces a
mean-reverting response; outside the channel the envelope fades and a
power-law, trend-following term ``sign(g) * |g / (2 * pre_exp_scaling)| ** exponents``
takes over (this term is not bounded -- it grows with |g|).

Key parameters: ``width`` (channel half-width), ``amplitude`` (reversion
strength), ``exponents`` (per-asset power-law shaping of the trend term),
``logit_lamb`` (EWMA memory length).
"""
# again, this only works on startup!
# These config updates must run before JAX executes any computation, which is
# why they sit above the remaining imports.
from jax import config

config.update("jax_enable_x64", True)
from jax import default_backend
from jax import local_device_count, devices

DEFAULT_BACKEND = default_backend()
CPU_DEVICE = devices("cpu")[0]
# GPU_DEVICE is the default accelerator device, or the CPU when JAX has no
# accelerator. Only "gpu" is queried by name for the GPU backend: on other
# accelerator backends (e.g. TPU) ``devices("gpu")`` raises, which previously
# made importing this module fail, so use the default backend's device there.
if DEFAULT_BACKEND == "cpu":
    GPU_DEVICE = devices("cpu")[0]
    config.update("jax_platform_name", "cpu")
elif DEFAULT_BACKEND == "gpu":
    GPU_DEVICE = devices("gpu")[0]
    config.update("jax_platform_name", "gpu")
else:
    GPU_DEVICE = devices(DEFAULT_BACKEND)[0]
    config.update("jax_platform_name", DEFAULT_BACKEND)
import jax.numpy as jnp
from jax import jit, vmap
from jax import devices, device_put
from jax import tree_util
from jax.lax import stop_gradient, dynamic_slice
from quantammsim.pools.G3M.quantamm.momentum_pool import (
MomentumPool,
_jax_momentum_weight_update,
)
from quantammsim.core_simulator.param_utils import (
memory_days_to_lamb,
lamb_to_memory_days_clipped,
calc_lamb,
inverse_squareplus_np,
get_raw_value,
get_log_amplitude,
jax_memory_days_to_lamb,
)
from quantammsim.pools.G3M.quantamm.update_rule_estimators.estimators import (
calc_gradients,
calc_k,
squareplus,
)
from quantammsim.pools.G3M.quantamm.update_rule_estimators.estimator_primitives import (
_jax_gradient_scan_function,
)
from quantammsim.core_simulator.param_schema import ParamSpec, OptunaRange
from typing import Dict, Any, Optional
from functools import partial
import numpy as np
# import the fine weight output function which has pre-set argument rule_outputs_are_themselves_weights
# as this is False for momentum pools --- the strategy outputs weight _changes_
from quantammsim.pools.G3M.quantamm.weight_calculations.fine_weights import (
calc_fine_weight_output_from_weight_changes,
)
@jit
def _jax_mean_reversion_channel_weight_update(
price_gradient,
k,
width,
amplitude,
exponents,
inverse_scaling=0.5415,
pre_exp_scaling=0.5,
):
"""
Calculate weight updates using mean reversion channel strategy.
Parameters
----------
price_gradient : jnp.ndarray
Array of price gradients for each asset.
k : float or jnp.ndarray
Scaling factor for weight updates.
width : float or jnp.ndarray
Width parameter for the mean reversion channel.
amplitude : float or jnp.ndarray
Amplitude of the mean reversion effect.
exponents : jnp.ndarray
Exponents for the trend following portion.
inverse_scaling : float, optional
Scaling factor for the channel portion, by default 0.5415.
pre_exp_scaling : float, optional
Scaling factor applied before exponentiation, by default 0.5.
Returns
-------
jnp.ndarray
Array of weight updates for each asset.
Notes
-----
Combines a mean reversion channel component with a trend following component:
1. Channel portion uses a Gaussian envelope and cubic function
2. Trend portion uses power law scaling outside the channel
"""
envelope = jnp.exp(-(price_gradient**2) / (2 * width**2))
scaled_price_gradient = jnp.pi * price_gradient / (3 * width)
channel_portion = (
-amplitude
* envelope
* (scaled_price_gradient - (scaled_price_gradient**3) / 6)
/ inverse_scaling
)
trend_portion = (
(1 - envelope)
* jnp.sign(price_gradient)
* jnp.power(jnp.abs(price_gradient / (2.0 * pre_exp_scaling)), exponents)
)
signal = channel_portion + trend_portion
offset_constants = -(k * signal).sum(axis=-1, keepdims=True) / (jnp.sum(k))
weight_updates = k * (signal + offset_constants)
return weight_updates
class MeanReversionChannelPool(MomentumPool):
    """
    A class for mean reversion channel strategies run as TFMM liquidity pools.

    This class implements a "mean reversion channel" strategy for asset
    allocation within a TFMM framework. It uses price data to generate mean
    reversion channel signals, which are then translated into weight
    adjustments.

    Parameters
    ----------
    None

    Methods
    -------
    calculate_rule_outputs(params, run_fingerprint, prices, additional_oracle_input)
        Calculate the raw weight outputs over a price series.
    calculate_rule_output_step(carry, price, params, run_fingerprint)
        Calculate a single-step weight update from the running gradient
        estimator state.
    init_base_parameters(initial_values_dict, run_fingerprint, n_assets, n_parameter_sets, noise)
        Build the initial (unconstrained-space) parameter dictionary.

    Notes
    -----
    Small price gradients (inside the channel) produce a mean-reverting
    response. Large gradients (outside the channel) produce a power-law,
    trend-following response. The rule outputs are weight *changes*, which
    are later refined into final asset weights, taking into account the
    parameters and constraints defined in the pool setup.
    """

    # Pool-owned parameter schema for MeanReversionChannel
    # Uses sp_* (squareplus-transformed) params
    #
    # Internal param mappings:
    #   sp_k: squareplus(sp_k) = k -> k_per_day (multiplied by memory_days at use)
    #   logit_lamb: logit(lamb) -> memory_length
    #   sp_exponents: squareplus(sp_exponents) = exponents, typically 1-4
    #   sp_pre_exp_scaling: squareplus(sp_pre_exp_scaling) = scaling
    #   sp_amplitude: squareplus(sp_amplitude) = amplitude (multiplied by memory_days at use)
    #   sp_width: squareplus(sp_width) = width (channel width)
    #
    # NOTE(review): several quoted values below look inconsistent. Assume
    # squareplus(x) = (x + sqrt(x**2 + 4)) / 2, which matches the
    # squareplus(0) = 1 quoted here. That formula gives squareplus(-1) ≈ 0.62
    # and squareplus(-3) ≈ 0.30, not the 0.38 / 0.09 quoted. Confirm against
    # estimators.squareplus before relying on these numbers.
    PARAM_SCHEMA = {
        # sp_k: squareplus transformed
        "sp_k": ParamSpec(
            initial=19.5,  # squareplus(19.5) ≈ 20
            optuna=OptunaRange(low=-1.0, high=100.0, log_scale=False, scalar=False),
            description="Squareplus-space k factor",
        ),
        "logit_lamb": ParamSpec(
            initial=4.0,
            optuna=OptunaRange(low=-4.0, high=8.0, log_scale=False, scalar=False),
            description="Logit of decay parameter lambda (memory length)",
        ),
        "logit_delta_lamb": ParamSpec(
            initial=0.0,
            optuna=OptunaRange(low=-5.0, high=5.0, log_scale=False, scalar=False),
            description="Delta in logit space for alternative lambda",
        ),
        # Power channel parameters (squareplus transformed)
        "sp_exponents": ParamSpec(
            initial=0.0,  # squareplus(0) ≈ 1.0
            optuna=OptunaRange(low=-2.0, high=4.0, log_scale=False, scalar=False),
            description="Squareplus-space exponents (gives 0.3-5)",
        ),
        "sp_pre_exp_scaling": ParamSpec(
            initial=-1.0,  # squareplus(-1) ≈ 0.38 (see NOTE(review) above)
            optuna=OptunaRange(low=-3.0, high=2.0, log_scale=False, scalar=False),
            description="Squareplus-space pre-exp scaling (gives 0.09-2.4)",
        ),
        # Mean reversion channel specific (squareplus transformed)
        "sp_amplitude": ParamSpec(
            initial=0.0,  # squareplus(0) ≈ 1.0
            optuna=OptunaRange(low=-3.0, high=4.0, log_scale=False, scalar=False),
            description="Squareplus-space amplitude (gives 0.09-5)",
        ),
        "sp_width": ParamSpec(
            initial=0.0,  # squareplus(0) ≈ 1.0
            optuna=OptunaRange(low=-3.0, high=3.0, log_scale=False, scalar=False),
            description="Squareplus-space channel width (gives 0.09-3.3)",
        ),
        "initial_weights_logits": ParamSpec(
            initial=1.0,
            optuna=OptunaRange(low=-10, high=10, log_scale=False, scalar=False),
            description="Logit-space initial portfolio weights",
            trainable=False,
        ),
    }

    @classmethod
    def get_param_schema(cls) -> dict:
        """Get the full parameter schema for MeanReversionChannelPool."""
        return cls.PARAM_SCHEMA

    def __init__(self):
        """
        Initialize a new MeanReversionChannelPool instance.

        Parameters
        ----------
        None
        """
        super().__init__()

    @staticmethod
    def _channel_update_inputs(
        params: Dict[str, Any],
        run_fingerprint: Dict[str, Any],
        memory_days,
    ) -> tuple:
        """
        Resolve the constrained-space inputs of the channel weight update.

        This logic is shared by ``calculate_rule_outputs`` and
        ``calculate_rule_output_step`` so the two code paths cannot drift
        apart. Each quantity prefers its squareplus (``sp_*``) key and falls
        back to the older parameterisations.

        Parameters
        ----------
        params : Dict[str, Any]
            Strategy parameters.
        run_fingerprint : Dict[str, Any]
            Run settings; ``use_pre_exp_scaling`` is read.
        memory_days : jnp.ndarray
            Memory length in days. It scales both ``k`` and ``amplitude``.

        Returns
        -------
        tuple
            ``(k, width, amplitude, exponents, pre_exp_scaling)``.
        """
        from jax.nn import sigmoid

        # k: prefer sp_k (squareplus), fall back to calc_k (legacy log_k, 2^x)
        if params.get("sp_k") is not None:
            k = squareplus(params.get("sp_k")) * memory_days
        else:
            k = calc_k(params, memory_days)

        # pre_exp_scaling: prefer sp_ (squareplus), then logit_ (sigmoid), then
        # raw_ (2^x). Without use_pre_exp_scaling the fixed default 0.5 is used.
        use_pre_exp_scaling = run_fingerprint["use_pre_exp_scaling"]
        if use_pre_exp_scaling and params.get("sp_pre_exp_scaling") is not None:
            pre_exp_scaling = squareplus(params.get("sp_pre_exp_scaling"))
        elif use_pre_exp_scaling and params.get("logit_pre_exp_scaling") is not None:
            # sigmoid(x) == exp(x) / (1 + exp(x)), but without the inf/inf NaN
            # that the explicit form produces for large x.
            pre_exp_scaling = sigmoid(params.get("logit_pre_exp_scaling"))
        elif use_pre_exp_scaling and params.get("raw_pre_exp_scaling") is not None:
            pre_exp_scaling = 2 ** params.get("raw_pre_exp_scaling")
        else:
            pre_exp_scaling = 0.5

        # exponents: prefer sp_exponents, fall back to raw_exponents (both use squareplus)
        if params.get("sp_exponents") is not None:
            exponents = squareplus(params.get("sp_exponents"))
        else:
            exponents = squareplus(params.get("raw_exponents"))

        # amplitude: prefer sp_amplitude (squareplus), fall back to log_amplitude (2^x)
        if params.get("sp_amplitude") is not None:
            amplitude = squareplus(params.get("sp_amplitude")) * memory_days
        else:
            amplitude = (2 ** params.get("log_amplitude")) * memory_days

        # width: prefer sp_width (squareplus), fall back to raw_width (2^x)
        if params.get("sp_width") is not None:
            width = squareplus(params.get("sp_width"))
        else:
            width = 2 ** params.get("raw_width")

        return k, width, amplitude, exponents, pre_exp_scaling

    @partial(jit, static_argnums=(2,))
    def calculate_rule_outputs(
        self,
        params: Dict[str, Any],
        run_fingerprint: Dict[str, Any],
        prices: jnp.ndarray,
        additional_oracle_input: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Calculate the raw weight outputs based on mean reversion channel signals.

        Parameters
        ----------
        params : Dict[str, Any]
            A dictionary of strategy parameters.
        run_fingerprint : Dict[str, Any]
            Run-specific settings (static under jit). Reads ``chunk_period``,
            ``max_memory_days``, ``use_alt_lamb`` and ``use_pre_exp_scaling``.
        prices : jnp.ndarray
            An array of asset prices over time.
        additional_oracle_input : Optional[jnp.ndarray], optional
            Unused by this strategy.

        Returns
        -------
        jnp.ndarray
            Raw weight outputs (weight *changes*, not weights), one row per
            chunk of ``prices``.

        Notes
        -----
        Steps:
        1. Convert lambda to (clipped) memory days.
        2. Sample prices every ``chunk_period`` steps and compute gradients
           with ``calc_gradients``.
        3. Resolve k, width, amplitude, exponents and pre-exp scaling.
        4. Apply the mean reversion channel weight update formula.
        """
        memory_days = lamb_to_memory_days_clipped(
            calc_lamb(params),
            run_fingerprint["chunk_period"],
            run_fingerprint["max_memory_days"],
        )
        chunkwise_price_values = prices[:: run_fingerprint["chunk_period"]]
        gradients = calc_gradients(
            params,
            chunkwise_price_values,
            run_fingerprint["chunk_period"],
            run_fingerprint["max_memory_days"],
            run_fingerprint["use_alt_lamb"],
            cap_lamb=True,
        )
        k, width, amplitude, exponents, pre_exp_scaling = self._channel_update_inputs(
            params, run_fingerprint, memory_days
        )
        return _jax_mean_reversion_channel_weight_update(
            gradients,
            k,
            width,
            amplitude,
            exponents,
            pre_exp_scaling=pre_exp_scaling,
        )

    def calculate_rule_output_step(
        self,
        carry: Dict[str, jnp.ndarray],
        price: jnp.ndarray,
        params: Dict[str, Any],
        run_fingerprint: Dict[str, Any],
    ) -> tuple:
        """
        Calculate a single step of mean reversion channel weight update.

        This mirrors the production implementation:
        1. Update the gradient estimator state (ewma, running_a).
        2. Compute the gradient from the updated state.
        3. Apply the mean reversion channel weight update formula.

        Parameters
        ----------
        carry : Dict[str, jnp.ndarray]
            Current state with 'ewma' and 'running_a'.
        price : jnp.ndarray
            Current price observation (shape: n_assets,).
        params : Dict[str, Any]
            Pool parameters (logit_lamb, sp_k, sp_amplitude, sp_width,
            sp_exponents, etc.).
        run_fingerprint : Dict[str, Any]
            Simulation settings (chunk_period, max_memory_days,
            use_pre_exp_scaling, etc.).

        Returns
        -------
        tuple
            (new_carry, rule_output)
        """
        # Cap lambda so the memory never exceeds max_memory_days.
        lamb = calc_lamb(params)
        max_lamb = jax_memory_days_to_lamb(
            run_fingerprint["max_memory_days"], run_fingerprint["chunk_period"]
        )
        lamb = jnp.clip(lamb, min=0.0, max=max_lamb)

        # Estimator constants come from MomentumPool.
        G_inf, saturated_b = self._get_estimator_constants(lamb)
        new_carry_list, gradient = _jax_gradient_scan_function(
            [carry["ewma"], carry["running_a"]], price, G_inf, lamb, saturated_b
        )

        memory_days = lamb_to_memory_days_clipped(
            lamb, run_fingerprint["chunk_period"], run_fingerprint["max_memory_days"]
        )
        k, width, amplitude, exponents, pre_exp_scaling = self._channel_update_inputs(
            params, run_fingerprint, memory_days
        )
        rule_output = _jax_mean_reversion_channel_weight_update(
            gradient,
            k,
            width,
            amplitude,
            exponents,
            pre_exp_scaling=pre_exp_scaling,
        )
        new_carry = {
            "ewma": new_carry_list[0],
            "running_a": new_carry_list[1],
        }
        return new_carry, rule_output

    def init_base_parameters(
        self,
        initial_values_dict: Dict[str, Any],
        run_fingerprint: Dict[str, Any],
        n_assets: int,
        n_parameter_sets: int = 1,
        noise: str = "gaussian",
    ) -> Dict[str, Any]:
        """
        Initialize parameters for a mean reversion channel pool.

        Builds the initial unconstrained-space parameters:
        - initial weight logits;
        - memory length (``logit_lamb``, ``logit_delta_lamb``);
        - update aggressiveness (``sp_k``);
        - channel amplitude and width;
        - trend exponents;
        - pre-exponent scaling.

        Parameters
        ----------
        initial_values_dict : Dict[str, Any]
            Initial values. Keys read: ``initial_weights_logits``,
            ``initial_k_per_day``, ``initial_memory_length``,
            ``initial_memory_length_delta``, ``initial_pre_exp_scaling``,
            ``initial_log_amplitude``, ``initial_raw_width``,
            ``initial_raw_exponents``.
        run_fingerprint : Dict[str, Any]
            Run settings; reads ``chunk_period`` and
            ``optimisation_settings["force_scalar"]``.
        n_assets : int
            The number of assets in the pool.
        n_parameter_sets : int, optional
            The number of parameter sets to initialize, by default 1.
        noise : str, optional
            The type of noise passed to ``add_noise``, by default "gaussian".

        Returns
        -------
        Dict[str, Any]
            Parameters for the mean reversion channel pool, after
            ``add_noise``. Before noise is added, array entries built from
            scalar initial values have shape ``(n_parameter_sets, n_assets)``,
            or ``(n_parameter_sets, 1)`` when ``force_scalar`` is set.

        Raises
        ------
        ValueError
            If ``initial_weights_logits`` or ``initial_k_per_day`` is missing
            or has an incompatible shape.
        KeyError
            If any other required initial value is missing.
        """
        force_scalar = run_fingerprint["optimisation_settings"]["force_scalar"]

        def process_initial_values(
            initial_values_dict, key, n_assets, n_parameter_sets, force_scalar=False
        ):
            # Expand an initial value to one row per parameter set. A
            # per-asset vector is used as given; a singleton is repeated
            # across all assets (or kept as one column when force_scalar).
            if key in initial_values_dict:
                initial_value = initial_values_dict[key]
                if isinstance(initial_value, (np.ndarray, jnp.ndarray, list)):
                    initial_value = np.array(initial_value)
                    if force_scalar:
                        return np.array([initial_value] * n_parameter_sets)
                    elif initial_value.size == n_assets:
                        return np.array([initial_value] * n_parameter_sets)
                    elif initial_value.size == 1:
                        return np.array([[initial_value] * n_assets] * n_parameter_sets)
                    elif initial_value.shape == (n_parameter_sets, n_assets):
                        return initial_value
                    else:
                        raise ValueError(
                            f"{key} must be a singleton or a vector of length n_assets or a matrix of shape (n_parameter_sets, n_assets)"
                        )
                else:
                    if force_scalar:
                        return np.array([[initial_value]] * n_parameter_sets)
                    else:
                        return np.array([[initial_value] * n_assets] * n_parameter_sets)
            else:
                raise ValueError(f"initial_values_dict must contain {key}")

        def tile_per_set(value):
            # One row per parameter set: a single column when force_scalar,
            # otherwise the value repeated across all assets.
            n_columns = 1 if force_scalar else n_assets
            return np.array([[value] * n_columns] * n_parameter_sets)

        initial_weights_logits = process_initial_values(
            initial_values_dict,
            "initial_weights_logits",
            n_assets,
            n_parameter_sets,
            force_scalar=False,
        )

        # sp_k: inverse_squareplus so that squareplus(sp_k) == initial_k_per_day
        sp_k = inverse_squareplus_np(
            process_initial_values(
                initial_values_dict,
                "initial_k_per_day",
                n_assets,
                n_parameter_sets,
                force_scalar=force_scalar,
            )
        )

        initial_lamb = memory_days_to_lamb(
            initial_values_dict["initial_memory_length"],
            run_fingerprint["chunk_period"],
        )
        logit_lamb_np = np.log(initial_lamb / (1.0 - initial_lamb))
        logit_lamb = tile_per_set(logit_lamb_np)

        # delta lamb is the logit-space difference needed for lamb + delta lamb
        # to give a memory length of
        # initial_memory_length + initial_memory_length_delta.
        initial_lamb_plus_delta_lamb = memory_days_to_lamb(
            initial_values_dict["initial_memory_length"]
            + initial_values_dict["initial_memory_length_delta"],
            run_fingerprint["chunk_period"],
        )
        logit_lamb_plus_delta_lamb_np = np.log(
            initial_lamb_plus_delta_lamb / (1.0 - initial_lamb_plus_delta_lamb)
        )
        logit_delta_lamb = tile_per_set(logit_lamb_plus_delta_lamb_np - logit_lamb_np)

        # sp_pre_exp_scaling: squareplus(sp_pre_exp_scaling) == initial_pre_exp_scaling
        sp_pre_exp_scaling = tile_per_set(
            inverse_squareplus_np(initial_values_dict["initial_pre_exp_scaling"])
        )

        # sp_amplitude: squareplus(sp_amplitude) == 2 ** initial_log_amplitude,
        # matching the legacy log_amplitude parameterisation's effective value.
        sp_amplitude = tile_per_set(
            inverse_squareplus_np(2 ** initial_values_dict["initial_log_amplitude"])
        )

        # sp_width: squareplus(sp_width) == 2 ** initial_raw_width, matching the
        # legacy raw_width parameterisation's effective value.
        sp_width = tile_per_set(
            inverse_squareplus_np(2 ** initial_values_dict["initial_raw_width"])
        )

        # sp_exponents: initial_raw_exponents is already in squareplus input space.
        sp_exponents = tile_per_set(initial_values_dict["initial_raw_exponents"])

        params = {
            "sp_k": sp_k,
            "logit_lamb": logit_lamb,
            "logit_delta_lamb": logit_delta_lamb,
            "initial_weights_logits": initial_weights_logits,
            "sp_amplitude": sp_amplitude,
            "sp_width": sp_width,
            "sp_exponents": sp_exponents,
            "sp_pre_exp_scaling": sp_pre_exp_scaling,
            "subsidary_params": [],
        }
        params = self.add_noise(params, noise, n_parameter_sets)
        return params

    @classmethod
    def _process_specific_parameters(cls, update_rule_parameters, run_fingerprint):
        """
        Process mean reversion channel specific parameters.

        Converts user-facing values into the squareplus-space ``sp_*`` params.

        Parameters
        ----------
        update_rule_parameters : iterable
            Objects with ``name`` and ``value`` attributes. ``value`` is
            iterated: one entry per asset, or a single entry applied to every
            asset. Recognised names: ``exponent``, ``width``,
            ``pre_exp_scaling``, ``amplitude``, and ``memory_days`` (used only
            for amplitude).
        run_fingerprint : Dict[str, Any]
            Run settings; ``len(run_fingerprint["tokens"])`` is the asset count.

        Returns
        -------
        dict
            Any of ``sp_exponents``, ``sp_width``, ``sp_pre_exp_scaling`` and
            ``sp_amplitude``, each a numpy array with one entry per token.

        Raises
        ------
        ValueError
            If ``amplitude`` is given without ``memory_days``, or if both are
            given per-asset with different lengths.
        """
        sp_keys = {
            "exponent": "sp_exponents",
            "width": "sp_width",
            "pre_exp_scaling": "sp_pre_exp_scaling",
        }

        def to_per_token_sp(values):
            # inverse_squareplus each value; if the count does not match the
            # token count, the first value is used for every token.
            sp_values = [float(inverse_squareplus_np(val)) for val in values]
            n_tokens = len(run_fingerprint["tokens"])
            if len(sp_values) != n_tokens:
                sp_values = [sp_values[0]] * n_tokens
            return np.array(sp_values)

        result = {}
        amplitude_values = None
        memory_days = next(
            (urp.value for urp in update_rule_parameters if urp.name == "memory_days"),
            None,
        )

        for urp in update_rule_parameters:
            if urp.name == "amplitude":
                amplitude_values = urp.value
            elif urp.name in sp_keys:
                result[sp_keys[urp.name]] = to_per_token_sp(urp.value)

        # Process amplitude last. The effective amplitude is
        # squareplus(sp_amplitude) * memory_days, so divide by memory_days first.
        if amplitude_values is not None:
            if memory_days is None:
                raise ValueError("memory_days parameter is required for amplitude calculation")
            amps = [float(amp) for amp in amplitude_values]
            mems = [float(mem) for mem in memory_days]
            # A single value on either side applies to every entry of the
            # other. A bare zip() silently truncated to the shorter list,
            # dropping all but the first per-asset amplitude.
            if len(mems) == 1 and len(amps) > 1:
                mems = mems * len(amps)
            elif len(amps) == 1 and len(mems) > 1:
                amps = amps * len(mems)
            elif len(amps) != len(mems):
                raise ValueError(
                    "amplitude and memory_days must have the same length "
                    f"(or one of them length 1); got {len(amps)} and {len(mems)}"
                )
            result["sp_amplitude"] = to_per_token_sp(
                [amp / mem for amp, mem in zip(amps, mems)]
            )
        return result
# Register the pool class as a JAX pytree node so instances can be passed
# through jit-transformed functions (e.g. as ``self`` in the jitted
# ``calculate_rule_outputs``). The flatten/unflatten helpers are not defined
# on this class, so they come from its base class.
tree_util.register_pytree_node(
    MeanReversionChannelPool,
    flatten_func=MeanReversionChannelPool._tree_flatten,
    unflatten_func=MeanReversionChannelPool._tree_unflatten,
)