"""Minimum-variance portfolio pool for QuantAMM.
Allocates weights inversely proportional to each asset's EWMA return variance
(diagonal-covariance minimum-variance portfolio). The strategy outputs weights
directly rather than weight changes, producing a risk-parity-like allocation
that tilts toward lower-volatility assets.
Key parameters: ``logit_lamb`` (EWMA decay for variance estimation).
"""
# NOTE: as with the other pools, these JAX settings only work on startup, so
# this block must run at import time before any JAX computation happens.
from jax import config

# Use 64-bit floats throughout the simulator.
config.update("jax_enable_x64", True)

from jax import default_backend
from jax import local_device_count, devices

DEFAULT_BACKEND = default_backend()
CPU_DEVICE = devices("cpu")[0]
# ``GPU_DEVICE`` is the first device of the default backend (the CPU device on
# CPU-only hosts). Looking it up by the detected backend name instead of
# hard-coding "gpu" avoids ``devices("gpu")`` raising a RuntimeError when the
# default backend is some other accelerator (e.g. a TPU); on CPU and GPU hosts
# the resulting device and platform setting are unchanged.
GPU_DEVICE = devices(DEFAULT_BACKEND)[0]
config.update("jax_platform_name", DEFAULT_BACKEND)
import jax.numpy as jnp
from jax import jit, vmap
from jax import devices, device_put
from jax import tree_util
from jax.lax import stop_gradient, dynamic_slice
from quantammsim.pools.G3M.quantamm.TFMM_base_pool import TFMMBasePool
from quantammsim.core_simulator.param_utils import (
memory_days_to_lamb,
lamb_to_memory_days_clipped,
calc_lamb,
)
from quantammsim.pools.G3M.quantamm.update_rule_estimators.estimators import (
calc_return_variances,
)
from typing import Dict, Any, Optional
from functools import partial
from abc import abstractmethod
import numpy as np
# import the fine weight output function which has pre-set argument rule_outputs_are_weights
# as this is True for min variance pools --- the strategy outputs weights themselves, not changes
from quantammsim.pools.G3M.quantamm.weight_calculations.fine_weights import (
calc_fine_weight_output_from_weights,
)
@jit
def _jax_min_variance_weights(variances):
    """Return inverse-variance weights (diagonal-covariance minimum variance).

    Each weight is proportional to ``1 / variance`` and the weights are
    normalised to sum to one along the last axis. Any leading axes (for example
    a time axis) are treated independently.

    Parameters
    ----------
    variances : jnp.ndarray
        Per-asset return variances, with assets on the last axis.

    Returns
    -------
    jnp.ndarray
        Weights with the same shape as ``variances``.

    Notes
    -----
    There is no guard against zero variances: a zero entry gives an infinite
    precision and therefore non-finite weights for that row.
    """
    precisions = 1.0 / variances
    total_precision = precisions.sum(axis=-1, keepdims=True)
    return precisions / total_precision
class MinVariancePool(TFMMBasePool):
    """
    A min variance strategy run as a TFMM (Temporal Function Market Making)
    liquidity pool, extending the TFMMBasePool class.

    Target weights are inversely proportional to each asset's estimated return
    variance (a diagonal-covariance minimum-variance portfolio). Unlike
    momentum-style rules, the rule output *is* the target weight vector, not a
    weight change.

    Parameters
    ----------
    None

    Methods
    -------
    calculate_rule_outputs(params, run_fingerprint, prices, additional_oracle_input)
        Compute raw target weights from return variances.
    calculate_fine_weights(rule_output, initial_weights, run_fingerprint, params)
        Refine the raw target weights into final asset weights.
    init_base_parameters(initial_values_dict, run_fingerprint, n_assets, n_parameter_sets, noise)
        Build the initial (noised) parameter dictionary.
    process_parameters(update_rule_parameters, run_fingerprint)
        Convert web-interface parameter input into per-asset arrays.

    Notes
    -----
    ``calculate_weights`` (which orchestrates the steps above) is not defined
    in this class; presumably it is inherited from ``TFMMBasePool``.
    """

    def __init__(self):
        """
        Initialize a new MinVariancePool instance.
        """
        super().__init__()

    @partial(jit, static_argnums=(2))
    def calculate_rule_outputs(
        self,
        params: Dict[str, Any],
        run_fingerprint: Dict[str, Any],
        prices: jnp.ndarray,
        additional_oracle_input: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Calculate the raw weight outputs based on minimum variance optimization.

        Parameters
        ----------
        params : Dict[str, Any]
            Strategy parameters, passed unchanged to ``calc_return_variances``.
        run_fingerprint : Dict[str, Any]
            Run-specific settings; ``"chunk_period"`` and ``"max_memory_days"``
            are read. It is a static argument of the jitted method.
        prices : jnp.ndarray
            Asset prices over time (time on the first axis).
        additional_oracle_input : Optional[jnp.ndarray], optional
            Not used by this pool.

        Returns
        -------
        jnp.ndarray
            Target minimum-variance weights (weights, not weight changes).

        Notes
        -----
        1. Subsample ``prices`` every ``chunk_period`` steps.
        2. Estimate return variances with ``calc_return_variances`` (EWMA, with
           ``cap_lamb=True``).
        3. Convert the variances to normalised inverse-variance weights.
        """
        chunk_period = run_fingerprint["chunk_period"]
        chunkwise_price_values = prices[::chunk_period]
        variances = calc_return_variances(
            params,
            chunkwise_price_values,
            chunk_period,
            run_fingerprint["max_memory_days"],
            cap_lamb=True,
        )
        return _jax_min_variance_weights(variances)

    @partial(jit, static_argnums=(3))
    def calculate_fine_weights(
        self,
        rule_output: jnp.ndarray,
        initial_weights: jnp.ndarray,
        run_fingerprint: Dict[str, Any],
        params: Dict[str, Any],
    ) -> jnp.ndarray:
        """
        Refine the raw target weights into final weights for the min variance pool.

        Parameters
        ----------
        rule_output : jnp.ndarray
            Target weights from ``calculate_rule_outputs`` (weights, not changes).
        initial_weights : jnp.ndarray
            Initial weights of the assets in the pool.
        run_fingerprint : Dict[str, Any]
            Run-specific settings; a static argument of the jitted method.
        params : Dict[str, Any]
            Pool parameters.

        Returns
        -------
        jnp.ndarray
            Refined weights for each asset over the simulated period.

        Notes
        -----
        This delegates to ``calc_fine_weight_output_from_weights``, the variant
        of the fine-weight routine for rules whose outputs are weights rather
        than weight changes. Details such as interpolation and change limits are
        presumably handled there; they are not visible from this class.
        """
        return calc_fine_weight_output_from_weights(
            rule_output, initial_weights, run_fingerprint, params
        )

    def init_base_parameters(
        self,
        initial_values_dict: Dict[str, Any],
        run_fingerprint: Dict[str, Any],
        n_assets: int,
        n_parameter_sets: int = 1,
        noise: str = "gaussian",
    ) -> Dict[str, Any]:
        """
        Initialize parameters for the min variance pool.

        Parameters
        ----------
        initial_values_dict : Dict[str, Any]
            Must contain ``"initial_weights_logits"``, ``"initial_memory_length"``
            and ``"initial_memory_length_delta"``. Each value may be one of:

            - a scalar or single-element value, repeated for every asset and
              parameter set;
            - a length-``n_assets`` vector, repeated for every parameter set;
            - a ``(n_parameter_sets, n_assets)`` matrix, used as given.
        run_fingerprint : Dict[str, Any]
            Run-specific settings. This method does not read it; presumably it
            is kept so every pool shares the same interface.
        n_assets : int
            The number of assets in the pool.
        n_parameter_sets : int, optional
            The number of parameter sets to initialize, by default 1.
        noise : str, optional
            Noise type passed to ``self.add_noise``, by default "gaussian".

        Returns
        -------
        Dict[str, Any]
            The result of ``self.add_noise`` applied to a dict with these keys:

            - ``"memory_days_1"``: ``initial_memory_length``;
            - ``"memory_days_2"``: ``initial_memory_length +
              initial_memory_length_delta``;
            - ``"initial_weights_logits"``;
            - ``"subsidary_params"``: an empty list.

            The array values have shape ``(n_parameter_sets, n_assets)``.

        Raises
        ------
        ValueError
            If a required key is missing or a value has an unsupported shape.
        """
        # Seeds NumPy's *global* RNG (a process-wide side effect), presumably so
        # the noise added by ``add_noise`` is reproducible.
        np.random.seed(0)

        def process_initial_values(
            initial_values_dict, key, n_assets, n_parameter_sets
        ):
            """Broadcast ``initial_values_dict[key]`` to (n_parameter_sets, n_assets).

            Always returns a fresh, writable NumPy array.
            """
            if key not in initial_values_dict:
                raise ValueError(f"initial_values_dict must contain {key}")
            # np.array copies, and also accepts scalars, lists, tuples and
            # jax arrays uniformly.
            value = np.array(initial_values_dict[key])
            # Check the full matrix shape first. With n_parameter_sets == 1, a
            # (1, n_assets) matrix also has size n_assets; testing the size
            # first would wrap it in a spurious extra dimension.
            if value.shape == (n_parameter_sets, n_assets):
                return value
            if value.size in (1, n_assets):
                # reshape(1, -1) turns a singleton ([x], [[x]] or a 0-d array)
                # or a vector into a single row that broadcasts cleanly.
                # .copy() makes the broadcast view writable and independent.
                return np.broadcast_to(
                    value.reshape(1, -1), (n_parameter_sets, n_assets)
                ).copy()
            raise ValueError(
                f"{key} must be a singleton or a vector of length n_assets or a matrix of shape (n_parameter_sets, n_assets)"
            )

        initial_weights_logits = process_initial_values(
            initial_values_dict, "initial_weights_logits", n_assets, n_parameter_sets
        )
        # The variance lookback is parameterised directly by memory lengths in
        # days. No lambda/logit-lambda entries are produced here.
        memory_days_1 = process_initial_values(
            initial_values_dict, "initial_memory_length", n_assets, n_parameter_sets
        )
        memory_days_delta = process_initial_values(
            initial_values_dict, "initial_memory_length_delta", n_assets, n_parameter_sets
        )
        memory_days_2 = memory_days_1 + memory_days_delta

        params = {
            "memory_days_1": memory_days_1,
            "memory_days_2": memory_days_2,
            "initial_weights_logits": initial_weights_logits,
            "subsidary_params": [],
        }
        return self.add_noise(params, noise, n_parameter_sets)

    @classmethod
    def process_parameters(cls, update_rule_parameters, run_fingerprint):
        """Process Min Variance pool parameters from web interface input.

        Each parameter's values are expanded to one entry per token in
        ``run_fingerprint["tokens"]``. If the number of supplied values does not
        match the number of tokens, the first value is repeated for every token.

        The first ``memory_days`` parameter is stored as both ``memory_days_1``
        and ``memory_days_2``. Every other parameter is stored under its own
        name and is processed afterwards, so it can overwrite those keys.

        Returns
        -------
        dict
            Mapping from parameter name to a NumPy array of per-asset values.
        """
        # Previously ``n_assets`` was referenced without being defined, which
        # raised NameError for any parameter other than ``memory_days``.
        n_assets = len(run_fingerprint["tokens"])

        def per_asset(values):
            # Expand to exactly one value per asset, falling back to
            # repeating the first value when the lengths differ.
            values = list(values)
            if len(values) != n_assets:
                values = [values[0]] * n_assets
            return np.array(values)

        result = {}
        # Find memory_days parameter (first occurrence only)
        for urp in update_rule_parameters:
            if urp.name == "memory_days":
                memory_days = per_asset(urp.value)
                # Set both memory_days parameters to the same value
                result["memory_days_1"] = memory_days  # for variance calculation
                result["memory_days_2"] = memory_days  # for weight smoothing
                break
        # Process any remaining parameters in default way
        for urp in update_rule_parameters:
            if urp.name != "memory_days":  # already processed above
                result[urp.name] = per_asset(urp.value)
        return result
# Register MinVariancePool as a JAX pytree node. Its jitted methods take
# ``self`` as a non-static argument, which JAX needs to be able to flatten.
# ``_tree_flatten`` / ``_tree_unflatten`` are not defined in this class and are
# presumably inherited from TFMMBasePool.
tree_util.register_pytree_node(
    nodetype=MinVariancePool,
    flatten_func=MinVariancePool._tree_flatten,
    unflatten_func=MinVariancePool._tree_unflatten,
)