Source code for quantammsim.pools.G3M.quantamm.update_rule_estimators.estimators

"""EWMA-based estimator functions for QuantAMM update rules.

Provides high-level routines for computing time-weighted statistics from price
series: EWMA price gradients (``calc_gradients``), momentum sensitivity
factors (``calc_k``), return variances (``calc_return_variances``), and
paired/triple-threat gradient variants. Dispatches to either a sequential
``scan`` (CPU) or convolutional (GPU) backend based on ``DEFAULT_BACKEND``.
"""
# again, this only works on startup!
# JAX configuration flags below only take effect if they are set before JAX
# is otherwise used, so this block must run at import time, before any arrays
# are created.
from jax import config

# Enable 64-bit floating point throughout JAX.
config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)
# config.update('jax_disable_jit', True)
from jax import default_backend
from jax import local_device_count, devices

# Backend JAX selected at startup (e.g. "cpu", "gpu", "tpu"). The estimators
# in this module use it to choose between the scan (CPU) and convolution
# (non-CPU) implementations.
DEFAULT_BACKEND = default_backend()
CPU_DEVICE = devices("cpu")[0]
if DEFAULT_BACKEND != "cpu":
    # Ask for devices of the backend JAX actually selected instead of
    # hard-coding "gpu": on a TPU (or other non-GPU accelerator) host,
    # devices("gpu") raises and the module could not be imported. For the
    # "gpu" backend this is exactly the previous behaviour.
    GPU_DEVICE = devices(DEFAULT_BACKEND)[0]
    config.update("jax_platform_name", DEFAULT_BACKEND)
else:
    # CPU-only: the "accelerator" device is simply the CPU device.
    GPU_DEVICE = CPU_DEVICE
    config.update("jax_platform_name", "cpu")
# DEFAULT_BACKEND = 'gpu'
import jax.numpy as jnp
from jax import jit, vmap
from jax import devices, device_put
from jax.tree_util import Partial
from jax.lax import scan, stop_gradient

import numpy as np

from functools import partial

# Make NumPy floating-point errors (divide, overflow, invalid) raise
# exceptions, while underflow is only printed. NOTE: this changes NumPy's
# global error-handling state as a side effect of importing this module, not
# just the behaviour of code in this file.
np.seterr(all="raise", under="print")

from quantammsim.pools.G3M.quantamm.update_rule_estimators.estimator_primitives import (
    make_ewma_kernel,
    make_a_kernel,
    make_cov_kernel,
    _jax_ewma_at_infinity_via_conv_padded,
    _jax_ewma_at_infinity_via_scan,
    _jax_gradients_at_infinity_via_conv_padded,
    _jax_gradients_at_infinity_via_conv_padded_with_alt_ewma,
    _jax_variance_at_infinity_via_conv,
    _jax_variance_at_infinity_via_scan,
    squareplus,
    _jax_gradients_at_infinity_via_scan,
    _jax_gradients_at_infinity_via_scan_with_alt_ewma,
    _jax_gradients_at_infinity_via_scan_with_readout,
)
from quantammsim.core_simulator.param_utils import (
    memory_days_to_lamb,
    jax_memory_days_to_lamb,
    lamb_to_memory_days_clipped,
    calc_lamb,
    calc_lamb_from_index,
)


def calc_gradients(
    update_rule_parameter_dict,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    use_alt_lamb,
    cap_lamb=True,
    safety_margin_multiplier=5.0,
):
    """Calculate time-weighted price gradients for TFMM strategy implementation.

    Computes gradients of price movements using exponentially weighted moving
    averages (EWMA), with support for both CPU and GPU acceleration via JAX.
    Implements the gradient calculation described in the TFMM litepaper, with
    additional features for alternative memory lengths.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Dictionary containing strategy parameters including:

        - 'logit_lamb': Controls the base memory length
        - 'logit_delta_lamb': Required when ``use_alt_lamb=True``; added to
          'logit_lamb' in logit space to obtain the alternative lambda
        - 'initial_weights_logits': Read on non-CPU backends only; lambda is
          broadcast to its shape
    chunkwise_price_values : ndarray
        Array of shape (time_steps, n_assets) containing price values for
        each asset over time
    chunk_period : float
        Time period between chunks in minutes
    max_memory_days : float
        Maximum allowed memory length in days, used to cap lambda parameters
    use_alt_lamb : bool
        Whether to use an alternative lambda parameter for part of the
        calculation. Enables different parts of update rules to act over
        different memory lengths
    cap_lamb : bool, optional
        Whether to apply maximum memory day restriction to lambda parameters.
        Defaults to True
    safety_margin_multiplier : float, optional
        Multiplier for padding length in GPU/conv path. Higher values ensure
        EWMA convergence but use more memory. Theoretical minimum is ~1.9x
        for 99.9% convergence. Defaults to 5.0

    Returns
    -------
    ndarray
        Array of shape (time_steps-1, n_assets) containing calculated
        gradients. For each asset, represents the time-weighted rate of change
        of prices.

    Raises
    ------
    ValueError
        If ``use_alt_lamb`` is True but 'logit_delta_lamb' is missing or
        ``None``. (Subclass of ``Exception``, which was raised previously.)

    Notes
    -----
    The function implements two calculation paths:

    1. GPU path: Uses convolution operations for efficient parallel computation

       - Pads input data to handle initialization
       - Creates specialized kernels for EWMA calculation
       - Leverages GPU parallelism for efficient computation

    2. CPU path: Uses scan operations for sequential computation

       - More memory efficient
       - Uses a scan operation, so is fundamentally sequential

    The gradient calculation follows the methodology described in the TFMM
    litepaper, using exponential weighting to estimate price trends while
    avoiding look-ahead bias.
    """
    lamb = calc_lamb(update_rule_parameter_dict)
    max_lamb = jax_memory_days_to_lamb(max_memory_days, chunk_period)
    # Apply max_memory_days restriction to lamb (and, below, alt_lamb).
    if cap_lamb:
        lamb = jnp.clip(lamb, min=0.0, max=max_lamb)

    safety_margin_max_memory_days = max_memory_days * safety_margin_multiplier

    # An alternative lambda lets different parts of the update rule act over
    # different memory lengths.
    if use_alt_lamb:
        logit_delta_lamb = update_rule_parameter_dict.get("logit_delta_lamb")
        if logit_delta_lamb is None:
            raise ValueError(
                "use_alt_lamb=True requires 'logit_delta_lamb' in "
                "update_rule_parameter_dict"
            )
        logit_alt_lamb = logit_delta_lamb + update_rule_parameter_dict["logit_lamb"]
        # Logistic sigmoid maps the logit back into (0, 1).
        alt_lamb = jnp.exp(logit_alt_lamb) / (1 + jnp.exp(logit_alt_lamb))
        if cap_lamb:
            alt_lamb = jnp.clip(alt_lamb, min=0.0, max=max_lamb)
    else:
        alt_lamb = lamb

    if DEFAULT_BACKEND != "cpu":
        lamb = jnp.broadcast_to(
            lamb, update_rule_parameter_dict["initial_weights_logits"].shape
        )
        ewma_kernel = make_ewma_kernel(
            lamb, safety_margin_max_memory_days, chunk_period
        )
        a_kernel = make_a_kernel(lamb, safety_margin_max_memory_days, chunk_period)
        # Prepend copies of the first price row so the convolutional EWMA can
        # warm up before the real data starts.
        n_pad = int(safety_margin_max_memory_days * 1440 / chunk_period)
        padded_chunkwise_price_values = jnp.vstack(
            [
                jnp.ones((n_pad, chunkwise_price_values.shape[1]))
                * chunkwise_price_values[0],
                chunkwise_price_values,
            ]
        )
        ewma_padded = _jax_ewma_at_infinity_via_conv_padded(
            padded_chunkwise_price_values, ewma_kernel
        )
        saturated_b = lamb / ((1 - lamb) ** 3)
        if use_alt_lamb:
            alt_ewma_kernel = make_ewma_kernel(
                alt_lamb, safety_margin_max_memory_days, chunk_period
            )
            alt_ewma_padded = _jax_ewma_at_infinity_via_conv_padded(
                padded_chunkwise_price_values, alt_ewma_kernel
            )
            gradients = _jax_gradients_at_infinity_via_conv_padded_with_alt_ewma(
                padded_chunkwise_price_values,
                ewma_padded,
                alt_ewma_padded,
                a_kernel,
                saturated_b,
            )
        else:
            gradients = _jax_gradients_at_infinity_via_conv_padded(
                padded_chunkwise_price_values, ewma_padded, a_kernel, saturated_b
            )
    else:
        # The scan output's first row is dropped so the CPU path has the same
        # length as the conv path.
        if use_alt_lamb:
            gradients = _jax_gradients_at_infinity_via_scan_with_alt_ewma(
                chunkwise_price_values, lamb, alt_lamb
            )[1:]
        else:
            gradients = _jax_gradients_at_infinity_via_scan(
                chunkwise_price_values, lamb
            )[1:]
    return gradients
def _pad_with_first_row(values, safety_margin_max_memory_days, chunk_period):
    """Prepend copies of ``values[0]`` spanning ``safety_margin_max_memory_days``.

    The number of prepended rows is
    ``int(safety_margin_max_memory_days * 1440 / chunk_period)``, i.e. the
    margin in days converted to chunks of ``chunk_period`` minutes. The
    padding lets the convolutional EWMA warm up before the real data begins.
    """
    n_pad = int(safety_margin_max_memory_days * 1440 / chunk_period)
    prefix = jnp.ones((n_pad, values.shape[1])) * values[0]
    return jnp.vstack([prefix, values])


def calc_triple_threat_gradients(
    update_rule_parameter_dict,
    logit_lamb_index,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    cap_lamb=True,
):
    """Calculate EWMA price gradients with an indexed lambda and a separate EWMA lambda.

    Variant of ``calc_gradients`` in which the primary lambda is selected
    from ``update_rule_parameter_dict`` via
    ``calc_lamb_from_index(update_rule_parameter_dict, logit_lamb_index)``.
    The EWMA fed into the gradient estimator always uses a second lambda,
    derived from 'logit_lamb_for_ewma'.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Strategy parameters. Must contain:

        - 'logit_lamb_for_ewma': logit of the lambda used for the EWMA
        - whatever ``calc_lamb_from_index`` reads for ``logit_lamb_index``
        - 'initial_weights_logits' (non-CPU backends only): lambda is
          broadcast to its shape
    logit_lamb_index
        Selector forwarded to ``calc_lamb_from_index`` to choose the primary
        lambda.
    chunkwise_price_values : ndarray
        Array of shape (time_steps, n_assets) of price values.
    chunk_period : float
        Time period between chunks in minutes.
    max_memory_days : float
        Maximum allowed memory length in days, used to cap both lambdas.
    cap_lamb : bool, optional
        Whether to apply the max_memory_days cap. Defaults to True.

    Returns
    -------
    ndarray
        Array of shape (time_steps-1, n_assets) of time-weighted price
        gradients.

    Raises
    ------
    ValueError
        If 'logit_lamb_for_ewma' is missing or ``None`` (subclass of
        ``Exception``, which was raised previously).

    Notes
    -----
    On non-CPU backends the input is padded (5 x max_memory_days) and
    convolution kernels are used. On CPU a sequential scan is used and its
    first output row is dropped to match the conv path length.
    """
    lamb = calc_lamb_from_index(update_rule_parameter_dict, logit_lamb_index)
    max_lamb = jax_memory_days_to_lamb(max_memory_days, chunk_period)
    if cap_lamb:
        lamb = jnp.clip(lamb, min=0.0, max=max_lamb)

    safety_margin_max_memory_days = max_memory_days * 5.0

    # The EWMA component acts over its own memory length.
    logit_alt_lamb = update_rule_parameter_dict.get("logit_lamb_for_ewma")
    if logit_alt_lamb is None:
        raise ValueError(
            "calc_triple_threat_gradients requires 'logit_lamb_for_ewma' in "
            "update_rule_parameter_dict"
        )
    alt_lamb = jnp.exp(logit_alt_lamb) / (1 + jnp.exp(logit_alt_lamb))
    if cap_lamb:
        alt_lamb = jnp.clip(alt_lamb, min=0.0, max=max_lamb)

    if DEFAULT_BACKEND != "cpu":
        lamb = jnp.broadcast_to(
            lamb, update_rule_parameter_dict["initial_weights_logits"].shape
        )
        ewma_kernel = make_ewma_kernel(
            lamb, safety_margin_max_memory_days, chunk_period
        )
        a_kernel = make_a_kernel(lamb, safety_margin_max_memory_days, chunk_period)
        padded_chunkwise_price_values = _pad_with_first_row(
            chunkwise_price_values, safety_margin_max_memory_days, chunk_period
        )
        ewma_padded = _jax_ewma_at_infinity_via_conv_padded(
            padded_chunkwise_price_values, ewma_kernel
        )
        saturated_b = lamb / ((1 - lamb) ** 3)
        alt_ewma_kernel = make_ewma_kernel(
            alt_lamb, safety_margin_max_memory_days, chunk_period
        )
        alt_ewma_padded = _jax_ewma_at_infinity_via_conv_padded(
            padded_chunkwise_price_values, alt_ewma_kernel
        )
        return _jax_gradients_at_infinity_via_conv_padded_with_alt_ewma(
            padded_chunkwise_price_values,
            ewma_padded,
            alt_ewma_padded,
            a_kernel,
            saturated_b,
        )

    return _jax_gradients_at_infinity_via_scan_with_alt_ewma(
        chunkwise_price_values, lamb, alt_lamb
    )[1:]


def calc_gradients_with_readout(
    update_rule_parameter_dict,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    use_alt_lamb,
    cap_lamb=True,
):
    """Calculate time-weighted price gradients, also returning intermediate values.

    Computes gradients of price movements using exponentially weighted moving
    averages (EWMA), following the TFMM litepaper methodology, and returns the
    intermediate quantities produced by the scan alongside the gradients.
    This function always uses the sequential scan implementation, regardless
    of ``DEFAULT_BACKEND``.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Dictionary containing strategy parameters read by ``calc_lamb``
        (e.g. 'logit_lamb').
    chunkwise_price_values : ndarray
        Array of shape (time_steps, n_assets) containing price values.
    chunk_period : float
        Time period between chunks in minutes.
    max_memory_days : float
        Maximum allowed memory length in days, used to cap lambda.
    use_alt_lamb : bool
        Accepted for signature compatibility with ``calc_gradients`` but
        currently ignored: only the primary lambda is used.
    cap_lamb : bool, optional
        Whether to apply the max_memory_days cap. Defaults to True.

    Returns
    -------
    dict
        The dict produced by ``_jax_gradients_at_infinity_via_scan_with_readout``
        (gradients plus, e.g., EWMA and running 'a' values). Only the
        ``"gradients"`` entry has its first row dropped, to match the conv
        path; other entries are returned untrimmed.
    """
    lamb = calc_lamb(update_rule_parameter_dict)
    max_lamb = jax_memory_days_to_lamb(max_memory_days, chunk_period)
    if cap_lamb:
        lamb = jnp.clip(lamb, min=0.0, max=max_lamb)

    gradients_dict = _jax_gradients_at_infinity_via_scan_with_readout(
        chunkwise_price_values, lamb
    )
    # Drop the first gradient so the output length matches the conv path.
    gradients_dict["gradients"] = gradients_dict["gradients"][1:]
    return gradients_dict


def calc_ewma_padded(
    update_rule_parameter_dict,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    cap_lamb=True,
):
    """Calculate padded exponentially weighted moving average (EWMA) of prices.

    Computes EWMA using convolution with padding to handle initialization.
    The padding extends the price series backward using the first price value,
    allowing the EWMA to converge before the actual data begins. Always uses
    the convolution implementation, regardless of ``DEFAULT_BACKEND``.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Dictionary containing 'logit_lamb' parameter controlling memory length.
    chunkwise_price_values : ndarray
        Array of shape (time_steps, n_assets) containing price values.
    chunk_period : float
        Time period between chunks in minutes.
    max_memory_days : float
        Maximum allowed memory length in days, used to cap lambda.
    cap_lamb : bool, optional
        Whether to apply max_memory_days restriction. Defaults to True.

    Returns
    -------
    ndarray
        Padded EWMA array. Note: includes padding prefix, so length is
        greater than input length.

    See Also
    --------
    calc_alt_ewma_padded : Similar but uses alternative lambda parameter.
    calc_gradients : Uses EWMA internally for gradient calculation.
    """
    lamb = calc_lamb(update_rule_parameter_dict)
    max_lamb = memory_days_to_lamb(max_memory_days, chunk_period)
    if cap_lamb:
        lamb = jnp.clip(lamb, min=0.0, max=max_lamb)

    safety_margin_max_memory_days = max_memory_days * 5.0
    ewma_kernel = make_ewma_kernel(lamb, safety_margin_max_memory_days, chunk_period)
    padded_chunkwise_price_values = _pad_with_first_row(
        chunkwise_price_values, safety_margin_max_memory_days, chunk_period
    )
    return _jax_ewma_at_infinity_via_conv_padded(
        padded_chunkwise_price_values, ewma_kernel
    )


def calc_alt_ewma_padded(
    update_rule_parameter_dict,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    cap_lamb=True,
):
    """Calculate padded EWMA using an alternative (secondary) lambda parameter.

    Similar to ``calc_ewma_padded`` but uses a different memory length derived
    from 'logit_delta_lamb'. This allows update rules to use two different
    time scales for different components of the calculation.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Dictionary containing:

        - 'logit_lamb': Base lambda parameter (logit)
        - 'logit_delta_lamb': Delta added to logit_lamb for the alternative
          lambda
    chunkwise_price_values : ndarray
        Array of shape (time_steps, n_assets) containing price values.
    chunk_period : float
        Time period between chunks in minutes.
    max_memory_days : float
        Maximum allowed memory length in days, used to cap lambda.
    cap_lamb : bool, optional
        Whether to apply max_memory_days restriction. Defaults to True.

    Returns
    -------
    ndarray
        Padded EWMA array using alternative lambda. Note: includes padding
        prefix, so length is greater than input length.

    Raises
    ------
    ValueError
        If 'logit_delta_lamb' is missing or ``None`` (subclass of
        ``Exception``, which was raised previously).

    See Also
    --------
    calc_ewma_padded : Uses primary lambda parameter.
    """
    max_lamb = memory_days_to_lamb(max_memory_days, chunk_period)
    safety_margin_max_memory_days = max_memory_days * 5.0

    logit_delta_lamb = update_rule_parameter_dict.get("logit_delta_lamb")
    if logit_delta_lamb is None:
        raise ValueError(
            "calc_alt_ewma_padded requires 'logit_delta_lamb' in "
            "update_rule_parameter_dict"
        )
    logit_alt_lamb = logit_delta_lamb + update_rule_parameter_dict["logit_lamb"]
    alt_lamb = jnp.exp(logit_alt_lamb) / (1 + jnp.exp(logit_alt_lamb))
    if cap_lamb:
        alt_lamb = jnp.clip(alt_lamb, min=0.0, max=max_lamb)

    alt_ewma_kernel = make_ewma_kernel(
        alt_lamb, safety_margin_max_memory_days, chunk_period
    )
    padded_chunkwise_price_values = _pad_with_first_row(
        chunkwise_price_values, safety_margin_max_memory_days, chunk_period
    )
    return _jax_ewma_at_infinity_via_conv_padded(
        padded_chunkwise_price_values, alt_ewma_kernel
    )
def calc_ewma_pair(
    memory_days_1,
    memory_days_2,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    cap_lamb=True,
):
    """Calculate two exponentially weighted moving averages (EWMAs) with different memory lengths.

    Core estimator for Difference Momentum strategies, implementing the
    MACD-like comparison:

    .. math::

        1 - \\frac{E_2(\\mathbf{p}(t))}{E_1(\\mathbf{p}(t))}

    where E₁ and E₂ are EWMAs with memory lengths m₁ and m₂ respectively.
    Typically m₂ > m₁ for trend comparison. This function outputs the EWMAs
    for the two memory lengths used in the above formula. It supports both
    CPU and GPU paths, with GPU being more efficient for larger datasets.

    Parameters
    ----------
    memory_days_1 : float
        Memory length for first EWMA in days. Typically shorter period (m₁),
        making it more responsive to recent price changes. Negative values
        are clamped to 0.
    memory_days_2 : float
        Memory length for second EWMA in days. Typically longer period (m₂),
        providing baseline for trend comparison. Negative values are clamped
        to 0.
    chunkwise_price_values : jnp.ndarray
        Price/oracle values of shape (time, assets). Can be any oracle value,
        not just prices. NOTE(review): the input is passed through
        ``jnp.atleast_2d``, which turns a 1-D array of shape (T,) into (1, T),
        i.e. ONE time step of T assets; pass shape (T, 1) for a single-asset
        series.
    chunk_period : float
        Time period between chunks in minutes. Used to convert between
        memory_days and λ parameters.
    max_memory_days : float
        Maximum allowed memory length in days. Prevents numerical instability
        from extremely long memory periods.
    cap_lamb : bool, optional
        Whether to cap λ values to prevent numerical issues, by default True.

    Returns
    -------
    tuple[jnp.ndarray, jnp.ndarray]
        ``(ewma_1, ewma_2)``. On non-CPU backends each is trimmed to the last
        ``time - 1`` rows (same length as ``calc_gradients``); a single-row
        input therefore yields empty arrays. On the CPU backend the output of
        ``_jax_ewma_at_infinity_via_scan`` is returned untrimmed.
        NOTE(review): ``calc_gradients`` drops the first scan row to match the
        conv path, but this function does not — confirm CPU and non-CPU
        output lengths agree.

    Notes
    -----
    1. Converts memory_days to λ values using
       ``jax_memory_days_to_lamb(memory_days, chunk_period)``.
    2. Non-CPU path: convolution over a series padded with
       5 * max_memory_days worth of copies of the first row.
    3. CPU path: sequential scan operations, no padding.

    See Also
    --------
    DifferenceMomentumPool : Primary user of this function
    calc_gradients : Alternative trend estimation approach
    """
    # Negative memory lengths are meaningless; clamp them to zero.
    memory_days_1 = jnp.maximum(memory_days_1, 0.0)
    memory_days_2 = jnp.maximum(memory_days_2, 0.0)

    lamb_1 = jax_memory_days_to_lamb(memory_days_1, chunk_period)
    lamb_2 = jax_memory_days_to_lamb(memory_days_2, chunk_period)

    if cap_lamb:
        max_lamb = jax_memory_days_to_lamb(max_memory_days, chunk_period)
        lamb_1 = jnp.clip(lamb_1, min=0.0, max=max_lamb)
        lamb_2 = jnp.clip(lamb_2, min=0.0, max=max_lamb)

    chunkwise_price_values = jnp.atleast_2d(chunkwise_price_values)

    if DEFAULT_BACKEND != "cpu":
        safety_margin_max_memory_days = max_memory_days * 5.0
        # Same front-padding as calc_ewma_padded: repeat the first row so the
        # convolutional EWMA has converged by the time real data starts.
        n_pad = int(safety_margin_max_memory_days * 1440 / chunk_period)
        padded_chunkwise_price_values = jnp.vstack(
            [
                jnp.ones((n_pad, chunkwise_price_values.shape[1]))
                * chunkwise_price_values[0],
                chunkwise_price_values,
            ]
        )
        ewma_kernel_1 = make_ewma_kernel(
            lamb_1, safety_margin_max_memory_days, chunk_period
        )
        ewma_kernel_2 = make_ewma_kernel(
            lamb_2, safety_margin_max_memory_days, chunk_period
        )
        ewma_1 = _jax_ewma_at_infinity_via_conv_padded(
            padded_chunkwise_price_values, ewma_kernel_1
        )
        ewma_2 = _jax_ewma_at_infinity_via_conv_padded(
            padded_chunkwise_price_values, ewma_kernel_2
        )

        # Keep only the last (time - 1) rows. An explicit start index is used
        # instead of ``x[-(time - 1):]`` because for a single input row that
        # becomes ``x[-0:] == x[0:]`` and would return the entire padded array.
        n_out = chunkwise_price_values.shape[0] - 1
        ewma_1 = ewma_1[max(ewma_1.shape[0] - n_out, 0):]
        ewma_2 = ewma_2[max(ewma_2.shape[0] - n_out, 0):]
    else:
        # CPU path - no padding needed
        ewma_1 = _jax_ewma_at_infinity_via_scan(chunkwise_price_values, lamb_1)
        ewma_2 = _jax_ewma_at_infinity_via_scan(chunkwise_price_values, lamb_2)
    return ewma_1, ewma_2
def calc_return_variances(
    update_rule_parameter_dict,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    cap_lamb,
):
    """Estimate exponentially-weighted variances of per-period asset returns.

    Simple returns ``(p[t+1] - p[t]) / p[t]`` are formed from the price
    series, and their EWMA-based variance is computed, either by a sequential
    scan (CPU backend) or by convolution over a front-padded series (any
    other backend). Used for minimum-variance style weighting.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Strategy parameters in one of two forms:

        1. 'memory_days_1' present: memory length given directly in days and
           converted with ``jax_memory_days_to_lamb``.
        2. Otherwise: lambda obtained via ``calc_lamb`` (logit
           parameterisation, e.g. 'logit_lamb').
    chunkwise_price_values : ndarray
        Array of shape (time_steps, n_assets) of prices.
    chunk_period : float
        Minutes between consecutive chunks.
    max_memory_days : float
        Memory cap in days; also sets the padding length on the conv path.
    cap_lamb : bool
        If True, clip lambda to the value implied by ``max_memory_days``.

    Returns
    -------
    ndarray
        Time-weighted return variances, one row per return (time_steps-1
        rows on the conv path), one column per asset.

    Notes
    -----
    Conv path: the returns are front-padded with
    ``int(5 * max_memory_days * 1440 / chunk_period)`` copies of the first
    return, an EWMA of the padded returns is computed, and the variance
    kernel is applied; the result is trimmed to the last time_steps-1 rows.
    Scan path: ``_jax_variance_at_infinity_via_scan`` on the raw returns.

    See Also
    --------
    MinVariancePool : Primary user of this function
    calc_gradients : Calculates price gradients using similar EWMA methodology
    calc_ewma_pair : Calculates EWMA pairs for difference momentum strategies
    """
    params = update_rule_parameter_dict
    lamb = (
        jax_memory_days_to_lamb(params["memory_days_1"], chunk_period)
        if "memory_days_1" in params
        else calc_lamb(params)
    )
    if cap_lamb:
        lamb = jnp.clip(
            lamb, min=0.0, max=memory_days_to_lamb(max_memory_days, chunk_period)
        )

    prices = chunkwise_price_values
    returns = jnp.diff(prices, axis=0) / prices[:-1]

    if DEFAULT_BACKEND == "cpu":
        return _jax_variance_at_infinity_via_scan(returns, lamb)

    margin_days = max_memory_days * 5.0
    cov_kernel = make_cov_kernel(lamb, margin_days, chunk_period)
    ewma_kernel = make_ewma_kernel(lamb, margin_days, chunk_period)

    n_pad = int(margin_days * 1440 / chunk_period)
    warmup = jnp.ones((n_pad, returns.shape[1])) * returns[0]
    padded_returns = jnp.vstack([warmup, returns])

    ewma_of_returns = _jax_ewma_at_infinity_via_conv_padded(padded_returns, ewma_kernel)
    variances = _jax_variance_at_infinity_via_conv(
        padded_returns, ewma_of_returns[1:], cov_kernel, lamb
    )
    return variances[-(len(prices) - 1):]
def calc_return_precision_based_weights(
    update_rule_parameter_dict,
    chunkwise_price_values,
    chunk_period,
    max_memory_days,
    cap_lamb,
):
    """Calculate precision-based (inverse variance) portfolio weights.

    Computes weights proportional to the inverse of return variances, giving
    higher weight to assets with lower volatility. This is a simplified
    minimum variance approach.

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Dictionary containing lambda parameters for variance calculation.
    chunkwise_price_values : ndarray
        Array of shape (time_steps, n_assets) containing price values.
    chunk_period : float
        Time period between chunks in minutes.
    max_memory_days : float
        Maximum allowed memory length in days.
    cap_lamb : bool
        Whether to apply max_memory_days restriction to lambda.

    Returns
    -------
    ndarray
        Same shape as the output of ``calc_return_variances`` (time_steps-1
        rows, n_assets columns on the conv path), containing precision-based
        weights normalised to sum to 1 along the last axis.

    Notes
    -----
    NOTE(review): a zero variance gives an infinite precision, and the
    normalisation then produces NaN/inf weights; no guard is applied here.

    See Also
    --------
    calc_return_variances : Calculates the variances used here.
    MinVariancePool : Uses similar inverse-variance weighting logic.
    """
    variances = calc_return_variances(
        update_rule_parameter_dict,
        chunkwise_price_values,
        chunk_period,
        max_memory_days,
        cap_lamb,
    )
    precisions = 1.0 / variances
    return precisions / precisions.sum(axis=-1, keepdims=True)


def calc_k(update_rule_parameter_dict, memory_days):
    """Calculate the 'k' parameter controlling update aggressiveness.

    The 'k' parameter scales weight updates in momentum-based strategies.
    Higher k values lead to more aggressive weight changes in response to
    price movements. Supports three parameterization modes, checked in this
    order (the first key present with a non-``None`` value wins):

    - 'log_k': Log2-scale k, computed as ``(2 ** log_k) * memory_days``
    - 'absolute_k': Direct k value, used as-is
    - 'k': Linear k, computed as ``k * memory_days``

    Parameters
    ----------
    update_rule_parameter_dict : dict
        Dictionary containing at least one of 'log_k', 'absolute_k', 'k'.
    memory_days : float
        Memory length in days, used to scale k in log_k and k modes.

    Returns
    -------
    float or ndarray
        The computed k value(s). Shape matches input parameter shape.

    Raises
    ------
    TypeError
        If none of 'log_k', 'absolute_k' or 'k' is provided (non-``None``).
        Previously this surfaced as an opaque ``None * memory_days``
        TypeError.

    Notes
    -----
    The log_k parameterization is preferred for optimization as it allows k
    to vary over several orders of magnitude with bounded parameter values.
    """
    log_k = update_rule_parameter_dict.get("log_k")
    if log_k is not None:
        return (2**log_k) * memory_days

    absolute_k = update_rule_parameter_dict.get("absolute_k")
    if absolute_k is not None:
        return absolute_k

    k = update_rule_parameter_dict.get("k")
    if k is None:
        raise TypeError(
            "update_rule_parameter_dict must provide one of 'log_k', "
            "'absolute_k' or 'k'"
        )
    return k * memory_days