Weight Calculation Paths
This guide explains the two weight calculation paths available in quantammsim and when to use each.
Overview
QuantAMM pools compute portfolio weights from price data using estimators (e.g., EWMA for momentum detection). There are two computational approaches:
Vectorized Path - Computes all weights at once using convolution operations
Scan Path - Computes weights sequentially, one time step at a time
Both paths produce numerically equivalent results but have different performance characteristics.
Vectorized Path
The vectorized path uses JAX’s convolution operations to compute estimator outputs (like EWMA values) for all time steps simultaneously. This is the traditional approach and is typically faster for simulation and training.
How it works:
Compute all estimator outputs at once using calculate_rule_outputs()
Apply guardrails and interpolation using calculate_fine_weights()
Slice to the relevant bout period
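To make the convolution idea concrete, here is a minimal, self-contained JAX sketch of an EWMA computed for every time step in one operation. It is not quantammsim's calculate_rule_outputs(). The helper name ewma_vectorized, the single price series and the choice to initialise the EWMA state to the first price are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def ewma_vectorized(prices: jax.Array, lamb: float) -> jax.Array:
    """Toy EWMA for all time steps at once (hypothetical helper).

    Reproduces the recurrence
        ewma_t = lamb * ewma_{t-1} + (1 - lamb) * p_t
    with the state initialised to the first price, using one convolution.
    """
    n = prices.shape[0]
    j = jnp.arange(n)
    # Kernel weights applied to p_t, p_{t-1}, p_{t-2}, ...
    kernel = (1.0 - lamb) * lamb**j
    # Causal part of a full convolution: conv[t] = sum_{j<=t} kernel[j] * p[t-j]
    conv = jnp.convolve(prices, kernel, mode="full")[:n]
    # Remaining weight lamb^(t+1) belongs to the initial state (the first price)
    return conv + lamb ** (j + 1) * prices[0]


prices = jnp.array([100.0, 101.0, 99.0, 102.0, 104.0])
print(ewma_vectorized(prices, 0.9))
```

The weights at each step sum to one, since (1 - lamb) times the sum of lamb^j for j = 0..t, plus lamb^(t+1), equals 1. All outputs come from a single array operation, which is why this style parallelises well but materialises every intermediate value.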
Advantages:
Faster execution (typically 1.5-2x faster than scan)
Better GPU utilization through parallelization
Well-optimized by XLA compiler
Disadvantages:
Higher memory usage (stores all intermediate values)
Computation flow differs from production execution
Scan Path
The scan path processes prices sequentially using jax.lax.scan and jax.lax.fori_loop, updating estimator state one step at a time. This mirrors how weights are computed in production (on-chain).
How it works:
Initialize the estimator state from the first price
Warm up the estimator over the burn-in period (using fori_loop)
Compute fine weights for the bout period (using scan)
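The three steps above can be sketched with a toy EWMA standing in for the estimator. This is a minimal illustration under stated assumptions, not quantammsim's implementation: the function ewma_scan_path, its arguments, and treating start_index as the length of the burn-in are hypothetical. Because start_index is passed as a traced argument, the warm-up uses jax.lax.fori_loop, and the bout is cut with jax.lax.dynamic_slice, which requires a static bout_length.

```python
from functools import partial

import jax
import jax.numpy as jnp


def ewma_step(carry, price, lamb):
    """One EWMA update; returns (new carry, per-step output)."""
    new = lamb * carry + (1.0 - lamb) * price
    return new, new


@partial(jax.jit, static_argnames="bout_length")
def ewma_scan_path(prices, lamb, start_index, bout_length):
    # 1. Initialise the estimator state from the first price.
    carry = prices[0]

    # 2. Warm up over the burn-in prices[0:start_index]. start_index is
    #    traced, so the loop bound is only known at run time.
    def warm_up(i, c):
        new, _ = ewma_step(c, prices[i], lamb)
        return new

    carry = jax.lax.fori_loop(0, start_index, warm_up, carry)

    # 3. Step through the bout one price at a time, emitting each output.
    bout = jax.lax.dynamic_slice(prices, (start_index,), (bout_length,))
    _, outputs = jax.lax.scan(lambda c, p: ewma_step(c, p, lamb), carry, bout)
    return outputs


prices = jnp.array([100.0, 101.0, 99.0, 102.0, 104.0, 103.0])
print(ewma_scan_path(prices, 0.9, 2, 4))
```

Only the scalar carry is kept between steps, which is the memory advantage listed below; the price of that is a sequential dependency that XLA cannot parallelise across time.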
Advantages:
Matches production/on-chain execution exactly
Lower memory footprint (only stores current state)
Useful for verifying production behavior
Disadvantages:
Slower execution due to sequential processing
fori_loop has more overhead than vectorized operations
Selecting a Path
Use the weight_calculation_method parameter in your run fingerprint:
```python
# Automatic selection (default) - uses vectorized if available
run_fingerprint["weight_calculation_method"] = "auto"

# Force vectorized path
run_fingerprint["weight_calculation_method"] = "vectorized"

# Force scan path (matches production)
run_fingerprint["weight_calculation_method"] = "scan"
```
Pool Support
Most QuantAMM pools support both paths:
| Pool | Vectorized | Scan |
|---|---|---|
| MomentumPool | Yes | Yes |
| AntiMomentumPool | Yes | Yes |
| PowerChannelPool | Yes | Yes |
| MeanReversionChannelPool | Yes | Yes |
| DifferenceMomentumPool | Yes | Yes |
| MinVariancePool | Yes | No |
| IndexMarketCapPool | Yes | No |
You can check pool support programmatically:
```python
from quantammsim.pools import MomentumPool

pool = MomentumPool()
print(pool.supports_vectorized_path())  # True
print(pool.supports_scan_path())  # True
```
Numerical Equivalence
For pools that support both paths, results are numerically equivalent within floating-point tolerance:
```python
import jax.numpy as jnp
import numpy as np
from quantammsim.runners.jax_runners import do_run_on_historic_data

fingerprint = {
    "rule": "momentum",
    "tokens": ["BTC", "ETH"],
    # ... other settings ...
}
params = {
    "log_k": jnp.array([3.0, 3.0]),
    "logit_lamb": jnp.array([0.0, 0.0]),
    "initial_weights_logits": jnp.array([0.0, 0.0]),
}

# Run with vectorized path
fingerprint["weight_calculation_method"] = "vectorized"
result_vec = do_run_on_historic_data(fingerprint, params)

# Run with scan path
fingerprint["weight_calculation_method"] = "scan"
result_scan = do_run_on_historic_data(fingerprint, params)

# Final values agree to within a relative tolerance of 1e-4
np.testing.assert_allclose(
    result_vec["final_value"],
    result_scan["final_value"],
    rtol=1e-4,
)
```
Performance Comparison
Typical performance ratios (scan time / vectorized time):
Daily chunks (1440 min): ~1.5x slower
Hourly chunks (60 min): ~2x slower
More assets: Ratio increases slightly
Run the performance tests to measure on your hardware:
```bash
pytest tests/performance/test_weight_calculation_timing.py -v -s
```
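To measure a ratio like this outside the test suite, the usual JAX pattern is to compile once, then time repeated calls while forcing completion with block_until_ready(), since JAX dispatches work asynchronously. The sketch below uses jnp.cumsum versus an equivalent jax.lax.scan as stand-ins for the two paths; the helper names are hypothetical, and the printed ratio depends on your hardware and says nothing definitive about quantammsim's own timings.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def running_sum_vectorized(x):
    # Whole-array operation, analogous to the vectorized path.
    return jnp.cumsum(x)


@jax.jit
def running_sum_scan(x):
    # Sequential carry, analogous to the scan path.
    def step(acc, xi):
        acc = acc + xi
        return acc, acc

    _, out = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return out


def time_fn(fn, x, repeats=20):
    """Mean wall-clock seconds per call, excluding compilation."""
    fn(x).block_until_ready()  # compile outside the timed region
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.ones(100_000)
ratio = time_fn(running_sum_scan, x) / time_fn(running_sum_vectorized, x)
print(f"scan time / vectorized time: {ratio:.2f}")
```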
Implementation Details
The scan path implementation uses:
get_initial_rule_state() - initialize the estimator carry state
calculate_rule_output_step() - single-step estimator update
get_initial_guardrail_state() - initialize the weight carry state
calculate_coarse_weight_step() - single-step guardrailed weight
calculate_fine_weights_step() - single-step interpolation block
The burn-in warm-up uses jax.lax.fori_loop which supports dynamic (traced) loop bounds, allowing the warm-up length to vary based on start_index.
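A small illustration of why fori_loop suits the warm-up: under jax.jit an argument such as a start index is a tracer, so a Python for i in range(n) cannot use it as a loop bound (JAX raises a concretization error), whereas jax.lax.fori_loop accepts it. The function sum_prefix is hypothetical and only stands in for the warm-up loop.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_prefix(x, n):
    """Sum x[0:n] where n is a traced (run-time) value."""
    # With a traced upper bound, JAX lowers fori_loop to a while_loop,
    # so one compiled function serves every value of n.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + x[i], jnp.zeros((), x.dtype))


x = jnp.arange(10.0)
print(sum_prefix(x, 3), sum_prefix(x, 7))  # same compiled function, different bounds
```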
Implementing New Pools
When implementing a new pool, you can choose to implement:
Vectorized only - implement calculate_rule_outputs()
Scan only - implement calculate_rule_output_step() and get_initial_rule_state()
Both - implement all of the above for maximum flexibility
The base class provides capability detection:
```python
class MyPool(TFMMBasePool):
    def calculate_rule_outputs(self, params, run_fingerprint, prices, *args, **kwargs):
        # Vectorized implementation (further arguments elided in this guide)
        ...

    def calculate_rule_output_step(self, carry, price, params, run_fingerprint):
        # Single-step implementation
        ...

    def get_initial_rule_state(self, initial_price, params, run_fingerprint):
        # Initial carry state
        ...


pool = MyPool()
pool.supports_vectorized_path()  # True (has calculate_rule_outputs)
pool.supports_scan_path()  # True (has calculate_rule_output_step and get_initial_rule_state)
```
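One plausible way for a base class to implement this detection is to check whether a subclass overrides the relevant methods. The sketch below is hypothetical: SketchBasePool and its logic are assumptions, not TFMMBasePool's actual code. It treats the scan path as available only when both calculate_rule_output_step() and get_initial_rule_state() are overridden, matching the "Scan only" requirement listed above.

```python
class SketchBasePool:
    """Hypothetical base class illustrating override-based capability detection."""

    def calculate_rule_outputs(self, *args, **kwargs):
        raise NotImplementedError

    def calculate_rule_output_step(self, *args, **kwargs):
        raise NotImplementedError

    def get_initial_rule_state(self, *args, **kwargs):
        raise NotImplementedError

    def _overrides(self, name: str) -> bool:
        # True if the concrete class resolves `name` to a different function
        # than the placeholder defined on this base class.
        return getattr(type(self), name) is not getattr(SketchBasePool, name)

    def supports_vectorized_path(self) -> bool:
        return self._overrides("calculate_rule_outputs")

    def supports_scan_path(self) -> bool:
        return self._overrides("calculate_rule_output_step") and self._overrides(
            "get_initial_rule_state"
        )


class VectorizedOnlyPool(SketchBasePool):
    def calculate_rule_outputs(self, params, run_fingerprint, prices):
        return prices  # placeholder body


pool = VectorizedOnlyPool()
print(pool.supports_vectorized_path(), pool.supports_scan_path())
```

Detecting overrides keeps pool authors from having to declare capabilities separately, so the flags cannot drift out of sync with the methods actually implemented.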