"""CoW AMM (Function-Maximising AMM) pool implementation.
Implements the CoW Protocol's FM-AMM, which replaces the continuous
arbitrage of traditional AMMs with batch-auction-based rebalancing.
The pool computes reserves under both perfect and single-arbitrageur
models, blended by an ``arb_quality`` parameter, and supports dynamic
fees, external trades, and noise traders.
"""
from typing import Dict, Any, Optional
# again, this only works on startup!
from jax import config
from jax import default_backend
from jax import devices, tree_util
config.update("jax_enable_x64", True)
DEFAULT_BACKEND = default_backend()
CPU_DEVICE = devices("cpu")[0]
if DEFAULT_BACKEND != "cpu":
GPU_DEVICE = devices("gpu")[0]
config.update("jax_platform_name", "gpu")
else:
GPU_DEVICE = devices("cpu")[0]
config.update("jax_platform_name", "cpu")
import jax.numpy as jnp
from jax import jit, vmap
from jax import device_put
from jax import tree_util
from jax.lax import stop_gradient, dynamic_slice
from typing import Dict, Any, Optional, Callable
from functools import partial
import numpy as np
from quantammsim.pools.base_pool import AbstractPool
from quantammsim.pools.FM_AMM.cow_reserves import (
_jax_calc_cowamm_reserve_ratio_vmapped,
_jax_calc_cowamm_reserves_with_fees,
_jax_calc_cowamm_reserves_one_arb_zero_fees,
_jax_calc_cowamm_reserves_one_arb_with_fees,
_jax_calc_cowamm_reserves_with_dynamic_inputs,
)
[docs]
class CowPool(AbstractPool):
"""
A class representing a CowPool, which is a type of automated market maker (AMM) pool
with specific characteristics.
Methods
-------
__init__():
Initializes the CowPool instance.
calculate_reserves_with_fees(params, run_fingerprint, prices,
start_index, additional_oracle_input=None) -> jnp.ndarray:
Calculates the reserves of the pool considering fees.
calculate_reserves_zero_fees(params, run_fingerprint, prices,
start_index, additional_oracle_input=None) -> jnp.ndarray:
Calculates the reserves of the pool without considering fees.
calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices,
start_index, fees_array, arb_thresh_array, arb_fees_array, trade_array,
additional_oracle_input=None) -> jnp.ndarray:
Calculates the reserves of the pool with dynamic inputs for fees,
arbitrage thresholds, arbitrage fees, and trades.
init_base_parameters(initial_values_dict, run_fingerprint, n_assets,
n_parameter_sets=1, noise="gaussian") -> Dict[str, Any]:
Initializes the base parameters for the pool. Cow pools have no parameters.
calculate_initial_weights(params, *args, **kwargs) -> jnp.ndarray:
Calculates the weights for the assets in the pool. For CowPool,
the weights are always [0.5, 0.5].
make_vmap_in_axes(params, n_repeats_of_recurred=0):
Creates the vmap in_axes for the parameters.
is_trainable() -> bool:
Indicates whether the pool is trainable. For CowPool, this is always False.
"""
[docs]
def __init__(self):
super().__init__()
[docs]
@partial(jit, static_argnums=(2))
def calculate_reserves_with_fees(
self,
params: Dict[str, Any],
run_fingerprint: Dict[str, Any],
prices: jnp.ndarray,
start_index: jnp.ndarray,
additional_oracle_input: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
weights = self.calculate_initial_weights(params)
bout_length = run_fingerprint["bout_length"]
n_assets = run_fingerprint["n_assets"]
local_prices = dynamic_slice(prices, start_index, (bout_length - 1, n_assets))
if run_fingerprint["arb_frequency"] != 1:
arb_acted_upon_local_prices = local_prices[:: run_fingerprint["arb_frequency"]]
else:
arb_acted_upon_local_prices = local_prices
# calculate initial reserves
initial_pool_value = run_fingerprint["initial_pool_value"]
initial_value_per_token = weights * initial_pool_value
initial_reserves = initial_value_per_token / local_prices[0]
if run_fingerprint["do_arb"]:
reserves_with_perfect_arbs = _jax_calc_cowamm_reserves_with_fees(
initial_reserves,
arb_acted_upon_local_prices,
weight=weights[0],
fees=run_fingerprint["fees"],
arb_thresh=run_fingerprint["gas_cost"],
arb_fees=run_fingerprint["arb_fees"],
noise_trader_ratio=run_fingerprint["noise_trader_ratio"],
)
# now we need to calculate the reserves with imperfect arbs
# we do this by taking the perfect arbs and then applying a small
# amount of noise to the weights
reserves_with_one_arb = _jax_calc_cowamm_reserves_one_arb_with_fees(
initial_reserves,
arb_acted_upon_local_prices,
weight=weights[0],
fees=run_fingerprint["fees"],
arb_thresh=run_fingerprint["gas_cost"],
arb_fees=run_fingerprint["arb_fees"],
noise_trader_ratio=run_fingerprint["noise_trader_ratio"],
)
reserves = (
run_fingerprint["arb_quality"] * reserves_with_perfect_arbs
+ (1.0 - run_fingerprint["arb_quality"]) * reserves_with_one_arb
)
else:
reserves = jnp.broadcast_to(initial_reserves, local_prices.shape)
return reserves
[docs]
@partial(jit, static_argnums=(2))
def calculate_reserves_zero_fees(
self,
params: Dict[str, Any],
run_fingerprint: Dict[str, Any],
prices: jnp.ndarray,
start_index: jnp.ndarray,
additional_oracle_input: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
weights = self.calculate_initial_weights(params)
bout_length = run_fingerprint["bout_length"]
n_assets = run_fingerprint["n_assets"]
local_prices = dynamic_slice(prices, start_index, (bout_length - 1, n_assets))
if run_fingerprint["arb_frequency"] != 1:
arb_acted_upon_local_prices = local_prices[
:: run_fingerprint["arb_frequency"]
]
else:
arb_acted_upon_local_prices = local_prices
# calculate initial reserves
initial_pool_value = run_fingerprint["initial_pool_value"]
initial_value_per_token = weights * initial_pool_value
initial_reserves = initial_value_per_token / local_prices[0]
if run_fingerprint["do_arb"]:
reserves_with_perfect_arbs = (
_jax_calc_cowamm_reserves_with_fees(
initial_reserves,
arb_acted_upon_local_prices,
weight=weights[0],
fees=0.0,
arb_thresh=run_fingerprint["gas_cost"],
arb_fees=run_fingerprint["arb_fees"],
)
)
reserves_with_one_arb = _jax_calc_cowamm_reserves_one_arb_zero_fees(
initial_reserves,
arb_acted_upon_local_prices,
weight=weights[0],
arb_thresh=run_fingerprint["gas_cost"],
arb_fees=run_fingerprint["arb_fees"],
)
reserves = (
run_fingerprint["arb_quality"] * reserves_with_perfect_arbs
+ (1.0 - run_fingerprint["arb_quality"]) * reserves_with_one_arb
)
else:
reserves = jnp.broadcast_to(initial_reserves, local_prices.shape)
return reserves
[docs]
def init_base_parameters(
self,
initial_values_dict: Dict[str, Any],
run_fingerprint: Dict[str, Any],
n_assets: int,
n_parameter_sets: int = 1,
noise: str = "gaussian",
) -> Dict[str, Any]:
np.random.seed(0)
# We need to initialise the weights for each parameter set
# If a vector is provided in the inital values dict, we use
# that, if only a singleton array is provided we expand it
# to n_assets and use that vlaue for all assets.
def process_initial_values(
initial_values_dict, key, n_assets, n_parameter_sets
):
if key in initial_values_dict:
initial_value = initial_values_dict[key]
if isinstance(initial_value, (np.ndarray, jnp.ndarray, list)):
initial_value = np.array(initial_value)
if initial_value.size == n_assets:
return np.array([initial_value] * n_parameter_sets)
elif initial_value.size == 1:
return np.array([[initial_value] * n_assets] * n_parameter_sets)
elif initial_value.shape == (n_parameter_sets, n_assets):
return initial_value
else:
raise ValueError(
f"{key} must be a singleton or a vector of length n_assets or a matrix of shape (n_parameter_sets, n_assets)"
)
else:
return np.array([[initial_value] * n_assets] * n_parameter_sets)
else:
raise ValueError(f"initial_values_dict must contain {key}")
initial_weights_logits = process_initial_values(
initial_values_dict, "initial_weights_logits", n_assets, n_parameter_sets
)
params = {
"initial_weights_logits": initial_weights_logits,
}
params = self.add_noise(params, noise, n_parameter_sets)
return params
[docs]
def is_trainable(self):
return False
tree_util.register_pytree_node(
CowPool,
CowPool._tree_flatten,
CowPool._tree_unflatten,
)