Implementing a new AMM

We have seen how to implement a custom update rule for a QuantAMM pool.

What if you have a totally new type of AMM? To implement a new pool type, you need to create a new class that implements abstract methods from the quantammsim.pools.AbstractPool interface.

Let’s walk through an example implementation, looking at how quantammsim.pools.BalancerPool implements these requirements.

Note that pools do not hold any state, they only have methods. This makes them much easier to make work with JAX, which as a semi-functional language is not well-suited to object-oriented programming. The pool classes thus almost act (very informally speaking) as a namespace for methods.

Core Implementation Requirements

Looking at the quantammsim.pools.AbstractPool interface, we need to implement:

  1. Reserve Calculation Methods

    These methods determine how pool reserves change in response to market conditions:

    • calculate_reserves_zero_fees - Optimized calculation without fees

    • calculate_reserves_with_fees - Handles arbitrage with fees charged by the pool (and also gas costs paid by arbitrageurs and even the fees arbitrageurs may pay on a secondary market to liquidate their positions)

    • calculate_reserves_with_dynamic_inputs - Supports time-varying fees, gas costs, and more (including applying a sequence of [not necessarily arbitrage] trades to the pool).

For full funcionality, all three should be implemented, but for quick testing it is often sufficient to implement only the zero-fees case, which often is substantially faster and simpler than the other two.

  1. Configuration Methods

These methods handle pool setup and parameters:

  • init_base_parameters - Initialize pool configuration

  • make_vmap_in_axes - Configure JAX vectorization

  • is_trainable - Determine if weights can be trained

See quantammsim.pools.AbstractPool for the complete interface specification.

The following sections demonstrate how quantammsim.pools.BalancerPool implements these requirements.

Basic Structure

First, let’s look at the class definition and initialization:

class BalancerPool(AbstractPool):
    def __init__(self):
        super().__init__()

Note the empty __init__ method–pools do not hold any state, they only have methods.

Reserve Calculations

The Balancer pool implements the three methods for reserve calculations:

1. Zero Fees Case

The zero fees implementation is the simplest and most performant:

def calculate_reserves_zero_fees(
    self,
    params: Dict[str, Any],
    run_fingerprint: Dict[str, Any],
    prices: jnp.ndarray,
    start_index: jnp.ndarray,
    additional_oracle_input: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """
    Calculate reserves assuming zero fees and perfect arbitrage.

    Uses JAX-accelerated function _jax_calc_balancer_reserve_ratios for efficient
    computation in the theoretical zero-fee case. Simpler than TFMM implementation
    due to constant weights.

    Implementation Notes:
    ---------------------
    1. Uses dynamic_slice for price window
    2. Applies constant weights from calculate_initial_weights
    3. Computes reserve ratios directly
    4. Uses cumprod for reserve calculation
    5. Handles no-arbitrage case via broadcasting

    Parameters
    ----------
    params : Dict[str, Any]
        Pool parameters containing initial_weights_logits or initial_weights
    run_fingerprint : Dict[str, Any]
        Simulation parameters
    prices : jnp.ndarray
        Price history array
    start_index : jnp.ndarray
        Starting index for the calculation window
    additional_oracle_input : Optional[jnp.ndarray]
        Not used in BalancerPool, kept for interface compatibility

    Returns
    -------
    jnp.ndarray
        Calculated reserves over time
    """

    # Get constant weights
    weights = self.calculate_weights(params)

    # Extract relevant price window
    local_prices = dynamic_slice(prices, start_index, (bout_length - 1, n_assets))

    # Calculate initial reserves
    initial_value_per_token = weights * run_fingerprint["initial_pool_value"]
    initial_reserves = initial_value_per_token / local_prices[0]

    if run_fingerprint["do_arb"]:
        # Calculate reserve ratios using vectorized operation
        reserve_ratios = _jax_calc_balancer_reserve_ratios(
            local_prices[:-1], weights, local_prices[1:]
        )
        # Compute reserves through cumulative products
        reserves = jnp.vstack([
            initial_reserves,
            initial_reserves * jnp.cumprod(reserve_ratios, axis=0)
        ])
    else:
        reserves = jnp.broadcast_to(initial_reserves, local_prices.shape)

    return reserves

Slicing the price window

While it might be natural to consider passing in a price array that corresponds exactly the time period covered by the simulation, it can actually be neater for some use cases to pass in a price array that is longer than the simulation period, and then slice the price array to the relevant period within these functions.

This is particularly useful for pools that have dynamic properties that change over time, such as time-varying fees or dynamic weights, as these features very often will depend on earlier prices than those of just the simulation period.

So in the calculate_reserves_zero_fees function, we see that we pass in a start_index parameter, which is used to slice the price array to the relevant period. The length of the price array is given by bout_length, which is a parameter of the run_fingerprint dictionary.

For a base Balancer pool with constant weights, however, we have no dynamic properties (the weights are constant, the fees are fixed at zero here). This means that we could happily pass in a price array that is just the length of the simulation period. But the dynamic slicing of the completed price array is the structure required by the quantammsim.pools.AbstractPool interface, and is the structure that enables time varying properties.

Arbitrage control

The run_fingerprint dictionary contains a do_arb parameter, which controls whether arbitrage is performed on the pool. If arbitrage is not enabled, this function simply returns the initial reserves without any further calculation. In practice, we would set do_arb to True, as this is the only way to get a realistic simulation of the pool. If one is performing a simulation, however, where a trade sequence is provided, it may be useful to set do_arb to False, as this will allow one to see the effect of trades on the pool without the additional complexity of arbitrage. See below the discussion of the calculate_reserves_with_dynamic_inputs function for more details. The do_arb key is set to True by default.

Understanding _jax_calc_balancer_reserve_ratios

Deriving the actual reserve calculations for a particular pool type can be a bit of a dark art. For Balancer pools with fixed weights the core calculation of how reserves change in response to price movements is handled by _jax_calc_balancer_reserve_ratios.

Here we will take a brief foray into the mathematics of the Balancer pool, and how give a gloss on where the logic in _jax_calc_balancer_reserve_ratios comes from. Other pools will have different reserve calculations, but the general approach is the same: derive the reserve calculations from the pool’s trading function by considering how arbitrageurs will act given pool state and external market prices.

The derivations tend to rely on two key ideas:

  1. Invariant Preservation

For a Balancer pool containg \(N\) assets, with weights \(w_1, w_2, ..., w_N\), (where \(w_i\) sum to 1 and are in the range [0, 1]), and reserves \(R_1, R_2, ..., R_N\), the trading function is

\[k = \prod_i^N R_i^{w_i}\]

in the zero fees case. And the value \(k\) of the trading function is invariant under allowed operations on the pool.

  1. Price Matching and Equilibrium

After arbitrage, in the zero fees case, the pool’s marginal prices exactly match the external market prices. The pool’s quoted price for a marginal trade of the \(i\)th asset is proportional to \(\frac{w_i}{R_i}\). So we have that, after arbitrage,

\[\frac{\frac{w_i}{R_i}}{\frac{w_j}{R_j}} = \frac{p_i}{p_j},\]

where \(p_k\) is the price of asset \(k\) on the external market in a particular numeraire.

Combining these ideas, we can derive the reserve ratio formula for a Balancer pool with constant weights,

\[\frac{R_i(t')}{R_i(t)} = \frac{p_i(t)}{p_i(t')} \prod_{j=1}^N \left(\frac{p_j(t')}{p_j(t)}\right)^{w_j}.\]

The full derivation is available in the the Temporal Function Market Making litepaper, Appendix A1.

Note

We have subtly used, under the hood, the result that geometric mean market maker pools hold minimum value when their quoted marginal prices are equal to the external market price. Proving that result is beyond the scope of this tutorial, but it is a well-known result in the AMM literature, and can be derived using the method of Lagrange multipliers.

Note

For different pools and/or when handling the presence of fees and other time varying properties of pools (e.g. that arbitrageurs might have fixed costs and other constraints) the reserve derivations and resulting calculations will be different. The general approach is the same: derive the reserve calculations from the pool’s trading function by considering how arbitrageurs will act given pool state and external market prices.

Now let’s focus on the implementation, _jax_calc_balancer_reserve_ratios():

@jit
def _jax_calc_balancer_reserve_ratios(prev_prices, weights, prices):
    """Calculate reserve ratio changes for constant-weight Balancer pools.

    Parameters
    ----------
    prev_prices : jnp.ndarray
        Previous asset prices
    weights : jnp.ndarray
        Pool weights (must sum to 1)
    prices : jnp.ndarray
        New asset prices

    Returns
    -------
    jnp.ndarray
        Ratio of new reserves to old reserves for each asset
    """
    # Calculate price ratios p'/p for each asset
    price_ratios = prices / prev_prices

    # Calculate the product term ∏(p'/p)^w
    price_product_ratio = jnp.prod(price_ratios**weights)

    # Calculate final reserve ratios
    reserve_ratios = price_product_ratio / price_ratios
    return reserve_ratios
This implementation is:
  • Fully vectorized for parallel computation, computing this for all assets and time steps simultaneously (as we have obtained the ratio between reserves at different times and the result only depends on the weights and the prices, not the prior reserves):

  • JIT-compiled for performance, via the @jit decorator

  • Numerically stable through use of ratios rather than absolute values

  • Handles arbitrary numbers of assets

With no fees arbitrageurs will always trade to exactly match external market prices. With fees, we need more complex calculations to account for the exact structure of the arbitrage trade, as well as other factors like gas costs, as we will see below.

2. With Fees Case

The implementation with fees requires more complex arbitrage calculations:

@partial(jit, static_argnums=2)
def calculate_reserves_with_fees(
    self,
    params: Dict[str, Any],
    run_fingerprint: Dict[str, Any],
    prices: jnp.ndarray,
    start_index: jnp.ndarray,
    additional_oracle_input: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    weights = self.calculate_weights(params)
    bout_length = run_fingerprint["bout_length"]
    n_assets = run_fingerprint["n_assets"]
    local_prices = dynamic_slice(prices, start_index, (bout_length - 1, n_assets))

    if run_fingerprint["arb_frequency"] != 1:
        arb_acted_upon_local_prices = local_prices[
            :: run_fingerprint["arb_frequency"]
        ]
    else:
        arb_acted_upon_local_prices = local_prices

    # calculate initial reserves
    initial_pool_value = run_fingerprint["initial_pool_value"]
    initial_value_per_token = weights * initial_pool_value
    initial_reserves = initial_value_per_token / local_prices[0]

    if run_fingerprint["do_arb"]:
        reserves = _jax_calc_balancer_reserves_with_fees_using_precalcs(
            initial_reserves,
            weights,
            arb_acted_upon_local_prices,
            fees=run_fingerprint["fees"],
            arb_thresh=run_fingerprint["gas_cost"],
            arb_fees=run_fingerprint["arb_fees"],
            all_sig_variations=jnp.array(run_fingerprint["all_sig_variations"]),
        )
    else:
        reserves = jnp.broadcast_to(
            initial_reserves, arb_acted_upon_local_prices.shape
        )

    return reserves

This implementation has a similar structure to the zero-fees case, but with the addition of the fees, arb_thresh, and arb_fees parameters. These parameters are used to account for the exact structure of the arbitrage trade, as well as other factors like gas costs. For a deep dive into this part of the codebase, see _jax_calc_balancer_reserves_with_fees_using_precalcs(). The underlying mathematics is provided in this paper by the team on optimal arbitrage trades in G3Ms in the presence of fees.

3. Dynamic Inputs Case

For time-varying parameters:

@partial(jit, static_argnums=2)
def calculate_reserves_with_dynamic_inputs(
    self,
    params: Dict[str, Any],
    run_fingerprint: Dict[str, Any],
    prices: jnp.ndarray,
    start_index: jnp.ndarray,
    fees_array: jnp.ndarray,
    arb_thresh_array: jnp.ndarray,
    arb_fees_array: jnp.ndarray,
    trade_array: jnp.ndarray,
    additional_oracle_input: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    bout_length = run_fingerprint["bout_length"]
    n_assets = run_fingerprint["n_assets"]

    local_prices = dynamic_slice(prices, start_index, (bout_length - 1, n_assets))
    weights = self.calculate_weights(params)

    if run_fingerprint["arb_frequency"] != 1:
        arb_acted_upon_local_prices = local_prices[
            :: run_fingerprint["arb_frequency"]
        ]
    else:
        arb_acted_upon_local_prices = local_prices

    initial_pool_value = run_fingerprint["initial_pool_value"]
    initial_value_per_token = weights * initial_pool_value
    initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0]

    # any of fees_array, arb_thresh_array, arb_fees_array, trade_array
    # can be singletons, in which case we repeat them for the length of the bout

    # Determine the maximum leading dimension
    max_len = bout_length - 1
    if run_fingerprint["arb_frequency"] != 1:
        max_len = max_len // run_fingerprint["arb_frequency"]
    # Broadcast input arrays to match the maximum leading dimension.
    # If they are singletons, this will just repeat them for the length of the bout.
    # If they are arrays of length bout_length, this will cause no change.
    fees_array_broadcast = jnp.broadcast_to(
        fees_array, (max_len,) + fees_array.shape[1:]
    )
    arb_thresh_array_broadcast = jnp.broadcast_to(
        arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:]
    )
    arb_fees_array_broadcast = jnp.broadcast_to(
        arb_fees_array, (max_len,) + arb_fees_array.shape[1:]
    )
    # if we are doing trades, the trades array must be of the same length as the other arrays
    if run_fingerprint["do_trades"]:
        assert trade_array.shape[0] == max_len
    reserves = _jax_calc_balancer_reserves_with_dynamic_inputs(
        initial_reserves,
        weights,
        arb_acted_upon_local_prices,
        fees_array_broadcast,
        arb_thresh_array_broadcast,
        arb_fees_array_broadcast,
        jnp.array(run_fingerprint["all_sig_variations"]),
        trade_array,
        run_fingerprint["do_trades"],
        run_fingerprint["do_arb"],
    )
    return reserves

This method is more complex still, with the addition of the fees_array, arb_thresh_array, arb_fees_array, and trade_array parameters. The function _jax_calc_balancer_reserves_with_dynamic_inputs() is doing the heavy lifting here. It implements the same core logic as the fees case above, but also contains the logic for time varying fees, arbitrage thresholds, arbitrage fees, and so on, and enabling “exact out given in” trades to be done from the trade array input.

Helper Methods

Finally, we implement required helper methods:

def init_base_parameters(
    self,
    initial_values_dict: Dict[str, Any],
    run_fingerprint: Dict[str, Any],
    n_assets: int,
    n_parameter_sets: int = 1,
    noise: str = "gaussian",
) -> Dict[str, Any]:
    np.random.seed(0)

    # We need to initialise the weights for each parameter set
    # If a vector is provided in the inital values dict, we use
    # that, if only a singleton array is provided we expand it
    # to n_assets and use that vlaue for all assets.
    def process_initial_values(
        initial_values_dict, key, n_assets, n_parameter_sets
    ):
        if key in initial_values_dict:
            initial_value = initial_values_dict[key]
            if isinstance(initial_value, (np.ndarray, jnp.ndarray, list)):
                initial_value = np.array(initial_value)
                if initial_value.size == n_assets:
                    return np.array([initial_value] * n_parameter_sets)
                elif initial_value.size == 1:
                    return np.array([[initial_value] * n_assets] * n_parameter_sets)
                elif initial_value.shape == (n_parameter_sets, n_assets):
                    return initial_value
                else:
                    raise ValueError(
                        f"{key} must be a singleton or a vector of length n_assets"
                         +  "or a matrix of shape (n_parameter_sets, n_assets)"
                    )
            else:
                return np.array([[initial_value] * n_assets] * n_parameter_sets)
        else:
            raise ValueError(f"initial_values_dict must contain {key}")

    initial_weights_logits = process_initial_values(
        initial_values_dict, "initial_weights_logits", n_assets, n_parameter_sets
    )
    params = {
        "initial_weights_logits": initial_weights_logits,
    }
    params = self.add_noise(params, noise, n_parameter_sets)
    return params

def is_trainable(self):
    """Balancer pools have fixed weights and are not trainable."""
    return False

Note

JAX enables very efficient vmapping over the parameters of a pool, and out the box this is enabled via the method make_vmap_in_axes() provided in the base class. If the pool has a particularly complex structure in its parameters, however, (e.g. dicts of dicts where different levels of the hierachy have different vmap axes, for example) it may be necessary to implement a custom method to enable vmapping over that pool’s params dict.

Note

After a pool class is created, it should be registered with JAX as a pytree. For the Balancer pool class, the call looks like this:

jax.tree_util.register_pytree_node(
    BalancerPool, BalancerPool._tree_flatten, BalancerPool._tree_unflatten
)

This can be put directly below the class definition. The methods _tree_flatten() and _tree_unflatten() are provided in the base class. For custom pools that maintain internal state (breaking the standard design pattern for pool classes to be stateless) these methods would perhaps need to be overridden.

Note

If you want to go further and access your pool via the frontend and if your pool has parameters that have both human-readable-but-contrained and hard-to-interpret-but-unconstrained representations, we recommend that you implement _process_specific_parameters() that takes the human-readable parameterisation and converts it into the underlying parameterisation. See the implementation of _process_specific_parameters() for an example of this. See Constrained vs unconstrained parameters for more details on the use of constrained vs unconstrained parameters.

This implementation demonstrates how to create a pool type with:

  • Efficient JAX-accelerated calculations

  • Support for fees and arbitrage

  • Proper handling of dynamic parameters

  • Clear separation of zero-fee and fee-based calculations