Weight Calculation Paths

This guide explains the two weight calculation paths available in quantammsim and when to use each.

Overview

QuantAMM pools compute portfolio weights from price data using estimators (e.g., EWMA for momentum detection). There are two computational approaches:

  1. Vectorized Path - Computes all weights at once using convolution operations

  2. Scan Path - Computes weights sequentially, one time step at a time

Both paths produce numerically equivalent results but have different performance characteristics.

Vectorized Path

The vectorized path uses JAX’s convolution operations to compute estimator outputs (like EWMA values) for all time steps simultaneously. This is the traditional approach and is typically faster for simulation and training.

How it works:

  1. Compute all estimator outputs at once using calculate_rule_outputs()

  2. Apply guardrails and interpolation using calculate_fine_weights()

  3. Slice to the relevant bout period

Advantages:

  • Faster execution (typically 1.5-2x faster than scan)

  • Better GPU utilization through parallelization

  • Well-optimized by XLA compiler

Disadvantages:

  • Higher memory usage (stores all intermediate values)

  • Computation flow differs from production execution

Scan Path

The scan path processes prices sequentially using jax.lax.scan and jax.lax.fori_loop, updating estimator state one step at a time. This mirrors how weights are computed in production (on-chain).

How it works:

  1. Initialize estimator state from first price

  2. Warm up estimator over burn-in period (using fori_loop)

  3. Compute fine weights for bout period (using scan)

Advantages:

  • Matches production/on-chain execution exactly

  • Lower memory footprint (only stores current state)

  • Useful for verifying production behavior

Disadvantages:

  • Slower execution due to sequential processing

  • fori_loop has more overhead than vectorized operations

Selecting a Path

Use the weight_calculation_method parameter in your run fingerprint:

# Automatic selection (default) - uses vectorized if available
run_fingerprint["weight_calculation_method"] = "auto"

# Force vectorized path
run_fingerprint["weight_calculation_method"] = "vectorized"

# Force scan path (matches production)
run_fingerprint["weight_calculation_method"] = "scan"

Pool Support

Most QuantAMM pools support both paths:

Pool

Vectorized

Scan

MomentumPool

Yes

Yes

AntiMomentumPool

Yes

Yes

PowerChannelPool

Yes

Yes

MeanReversionChannelPool

Yes

Yes

DifferenceMomentumPool

Yes

Yes

MinVariancePool

Yes

No

IndexMarketCapPool

Yes

No

You can check pool support programmatically:

from quantammsim.pools import MomentumPool

pool = MomentumPool()
print(pool.supports_vectorized_path())  # True
print(pool.supports_scan_path())        # True

Numerical Equivalence

For pools that support both paths, results are numerically equivalent within floating-point tolerance:

import numpy as np
from quantammsim.runners.jax_runners import do_run_on_historic_data

fingerprint = {
    "rule": "momentum",
    "tokens": ["BTC", "ETH"],
    # ... other settings ...
}
params = {
    "log_k": jnp.array([3.0, 3.0]),
    "logit_lamb": jnp.array([0.0, 0.0]),
    "initial_weights_logits": jnp.array([0.0, 0.0]),
}

# Run with vectorized path
fingerprint["weight_calculation_method"] = "vectorized"
result_vec = do_run_on_historic_data(fingerprint, params)

# Run with scan path
fingerprint["weight_calculation_method"] = "scan"
result_scan = do_run_on_historic_data(fingerprint, params)

# Results match
np.testing.assert_allclose(
    result_vec["final_value"],
    result_scan["final_value"],
    rtol=1e-4
)

Performance Comparison

Typical performance ratios (scan time / vectorized time):

  • Daily chunks (1440 min): ~1.5x slower

  • Hourly chunks (60 min): ~2x slower

  • More assets: Ratio increases slightly

Run the performance tests to measure on your hardware:

pytest tests/performance/test_weight_calculation_timing.py -v -s

Implementation Details

The scan path implementation uses:

  • get_initial_rule_state() - Initialize estimator carry state

  • calculate_rule_output_step() - Single-step estimator update

  • get_initial_guardrail_state() - Initialize weight carry state

  • calculate_coarse_weight_step() - Single-step guardrailed weight

  • calculate_fine_weights_step() - Single-step interpolation block

The burn-in warm-up uses jax.lax.fori_loop which supports dynamic (traced) loop bounds, allowing the warm-up length to vary based on start_index.

Implementing New Pools

When implementing a new pool, you can choose to implement:

  1. Vectorized only - Implement calculate_rule_outputs()

  2. Scan only - Implement calculate_rule_output_step() and get_initial_rule_state()

  3. Both - Implement all methods for maximum flexibility

The base class provides capability detection:

class MyPool(TFMMBasePool):
    def calculate_rule_outputs(self, params, run_fingerprint, prices, ...):
        # Vectorized implementation
        ...

    def calculate_rule_output_step(self, carry, price, params, run_fingerprint):
        # Single-step implementation
        ...

    def get_initial_rule_state(self, initial_price, params, run_fingerprint):
        # Initial carry state
        ...

pool = MyPool()
pool.supports_vectorized_path()  # True (has calculate_rule_outputs)
pool.supports_scan_path()        # True (has calculate_rule_output_step)