Source code for quantammsim.core_simulator.windowing_utils

import numpy as np
import pandas as pd

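# NOTE: 64-bit mode only takes effect if it is enabled at startup, before any
# JAX arrays are created, so the flag is set here at import time.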
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random


def get_indices(
    start_index,
    bout_length,
    len_prices,
    key,
    optimisation_settings,
):
    """
    Get indices for sampling data windows during training.

    Start positions are drawn from the ``range_ = len_prices - bout_length -
    start_index`` positions beginning at ``start_index``.

    Parameters
    ----------
    start_index : int
        Starting index position in the data.
    bout_length : int
        Length of each training window/bout.
    len_prices : int
        Total length of the price data.
    key : jax.random.PRNGKey
        JAX random number generator key.
    optimisation_settings : dict
        Dictionary containing optimization settings with keys:

        - batch_size: Number of windows to sample.
        - training_data_kind: Type of training data ('historic' or 'mc').
        - sample_method: Method for sampling windows:

          * 'exponential': without replacement, with probability
            proportional to ``exp(5 * i / range_)`` so later start positions
            are favoured.
          * 'uniform': without replacement, equal probability.
          * 'stratified': the range is cut into ``batch_size`` contiguous
            segments and one start position is drawn from each.

        - max_mc_version: Maximum MC version number (only read if
          training_data_kind='mc'); versions are indexed from zero.

    Returns
    -------
    tuple
        - start_indexes : jnp.ndarray
            Integer array of shape (batch_size, 2) for historic data or
            (batch_size, 3) for MC data. Column 0 holds the sampled start
            index; for MC data the last column holds the sampled MC version.
            Every other column is left at zero.
        - key : jax.random.PRNGKey
            Updated random number generator key.

    Raises
    ------
    NotImplementedError
        If ``training_data_kind`` or ``sample_method`` is not one of the
        supported values.
    ValueError
        If the window does not fit into the data
        (``len_prices - bout_length - start_index < 0``).
    """
    # Split once up front; ``subkey`` drives the start-index draw.
    key, subkey = random.split(key)

    batch_size = optimisation_settings["batch_size"]
    training_data_kind = optimisation_settings["training_data_kind"]
    sample_method = optimisation_settings["sample_method"]

    if training_data_kind == "historic":
        n_columns = 2
    elif training_data_kind == "mc":
        # MC versions are indexed from zero
        max_mc_version = optimisation_settings["max_mc_version"]
        n_columns = 3
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError; fail the same way an unknown sample_method does.
        raise NotImplementedError(
            f"Unknown training_data_kind: {training_data_kind!r}"
        )

    sample_shape = (batch_size,)
    range_ = int(len_prices - bout_length - start_index)
    if range_ < 0:
        # The stratified branch would otherwise silently produce start
        # positions before ``start_index``.
        raise ValueError(
            "bout_length does not fit into the data after start_index: "
            f"len_prices={len_prices}, bout_length={bout_length}, "
            f"start_index={start_index}"
        )

    if sample_method == "exponential":
        probs = jnp.exp(5 * jnp.arange(range_) / range_)
        probs = probs / jnp.sum(probs)
        offsets = random.choice(
            subkey, range_, shape=sample_shape, replace=False, p=probs
        )
    elif sample_method == "uniform":
        offsets = random.choice(subkey, range_, shape=sample_shape, replace=False)
    elif sample_method == "stratified":
        # Integer segment boundaries; the final boundary is range_ itself so
        # the last segment reaches the end of the range.
        seg_boundaries = jnp.linspace(0, range_, batch_size + 1).astype(jnp.int64)
        seg_starts = seg_boundaries[:-1]
        seg_sizes = seg_boundaries[1:] - seg_starts
        # Empty segments (range_ < batch_size) get maxval 1, i.e. offset 0.
        offsets = seg_starts + random.randint(
            subkey,
            shape=sample_shape,
            minval=0,
            maxval=jnp.maximum(seg_sizes, 1),
        )
    else:
        raise NotImplementedError(f"Unknown sample_method: {sample_method!r}")

    start_indexes = jnp.zeros((batch_size, n_columns), dtype=jnp.int64)
    start_indexes = start_indexes.at[:, 0].set(start_index + offsets)

    if training_data_kind == "mc":
        key, subkey = random.split(key)
        start_indexes = start_indexes.at[:, -1].set(
            random.choice(subkey, max_mc_version + 1, shape=sample_shape, replace=True)
        )
    return start_indexes, key
def raw_trades_to_trade_array(raw_trades, start_date_string, end_date_string, tokens):
    """
    Convert raw trade data to a structured, minute-resolution trade array.

    A continuous grid of millisecond Unix timestamps is built at one-minute
    intervals from the start to the end date time. Each raw trade whose
    timestamp lies exactly on that grid is written to the matching row; when
    several trades share a timestamp the last one in ``raw_trades`` wins.
    Trades that are not on the grid are ignored.

    Parameters
    ----------
    raw_trades : pandas.DataFrame
        Raw trades, one per row, with columns ``unix`` (millisecond Unix
        timestamp), ``token_in`` (str), ``token_out`` (str) and
        ``amount_in``.
    start_date_string : str
        The start date time in format "%Y-%m-%d %H:%M:%S".
    end_date_string : str
        The end date time in format "%Y-%m-%d %H:%M:%S".
    tokens : list of str
        The tokens of the run. A token's position in this list is the
        numerical index stored in the output.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(n_minutes, 3)`` with columns
        (token-in index, token-out index, amount in). Rows correspond to
        consecutive minutes starting at the start date time; the row for the
        end date time itself is dropped. Minutes without a trade are all
        zeros.

    Notes
    -----
    Tokens that are not in ``tokens`` are mapped to index 0, which makes
    them indistinguishable from the first token in the output.
    """
    datetime_format = "%Y-%m-%d %H:%M:%S"
    # "min" is the supported alias; "T" is deprecated in recent pandas.
    minute_stamps = pd.date_range(
        start=pd.to_datetime(start_date_string, format=datetime_format),
        end=pd.to_datetime(end_date_string, format=datetime_format),
        freq="min",
    )
    # Millisecond Unix timestamps. Dividing timedeltas avoids depending on
    # the integer resolution in which the DatetimeIndex happens to be stored.
    full_index = (minute_stamps - pd.Timestamp("1970-01-01")) // pd.Timedelta(
        milliseconds=1
    )
    # Float from the start: amounts are generally fractional, and assigning
    # them into an integer frame relies on deprecated implicit upcasting.
    full_index_df = pd.DataFrame(
        index=full_index,
        columns=["token_in", "token_out", "amount_in"],
        data=0.0,
        dtype=np.float64,
    )

    # Map token strings to their numerical indexes
    token_to_index = {token: index for index, token in enumerate(tokens)}

    for _, row in raw_trades.iterrows():
        unix_timestamp = row["unix"]
        if unix_timestamp not in full_index_df.index:
            continue
        full_index_df.loc[unix_timestamp] = [
            token_to_index.get(row["token_in"], 0),  # 0 for unknown tokens
            token_to_index.get(row["token_out"], 0),  # 0 for unknown tokens
            row["amount_in"],
        ]

    # Drop the final row so the end date time is exclusive.
    return np.array(full_index_df)[:-1]
def raw_fee_like_amounts_to_fee_like_array(
    raw_inputs, start_date_string, end_date_string, names, fill_method="base"
):
    """
    Convert raw fee-like data to a structured, minute-resolution array.

    Takes raw fee-like data (fees, gas costs, arb fees) and places it on a
    continuous grid of millisecond Unix timestamps at one-minute intervals.
    The grid starts at the start date time and stops one minute before the
    end date time. Raw rows whose timestamp is not on the grid are ignored,
    except that they may seed the forward fill (see ``fill_method``).

    Parameters
    ----------
    raw_inputs : pandas.DataFrame
        Raw fee-like data, where each row contains a ``unix`` column
        (millisecond Unix timestamp) and one column per entry of ``names``.
    start_date_string : str
        The start date time in format "%Y-%m-%d %H:%M:%S".
    end_date_string : str
        The end date time in format "%Y-%m-%d %H:%M:%S".
    names : list of str
        The names of columns in raw_inputs of the fee-like amount.
    fill_method : str
        The method to fill in missing values. Options:

        - 'base': minutes without a value stay 0.
        - 'ffill': minutes without a value (i.e. equal to 0) take the last
          non-zero value before them. Minutes before the first non-zero
          value take the last raw value recorded before the start date time
          (last row of ``raw_inputs`` with ``unix`` below the start), or
          failing that the first raw value at or after the start. If neither
          exists they stay 0.

    Returns
    -------
    numpy.ndarray
        Float64 array giving the fee-like values over time, of shape
        ``(n_minutes,)`` when ``names`` has one entry and
        ``(n_minutes, len(names))`` otherwise.

    Raises
    ------
    NotImplementedError
        If ``fill_method`` is not 'base' or 'ffill'.

    Notes
    -----
    Forward filling is best effort: if it raises ``ValueError``,
    ``KeyError`` or ``TypeError`` a warning is printed and the array is
    returned without any forward filling applied.
    """
    if fill_method not in ("base", "ffill"):
        raise NotImplementedError

    datetime_format = "%Y-%m-%d %H:%M:%S"
    start = pd.to_datetime(start_date_string, format=datetime_format)
    end = pd.to_datetime(end_date_string, format=datetime_format)
    minute_stamps = pd.date_range(start=start, end=end, freq="min")
    # Millisecond Unix timestamps. Dividing timedeltas avoids depending on
    # the integer resolution in which the DatetimeIndex happens to be stored.
    # The end date time itself is excluded.
    full_index = (
        (minute_stamps - pd.Timestamp("1970-01-01")) // pd.Timedelta(milliseconds=1)
    )[:-1]
    full_index_df = pd.DataFrame(
        index=full_index, columns=names, data=0, dtype=np.float64
    )

    # Map raw data onto the minute grid
    for _, row in raw_inputs.iterrows():
        unix_timestamp = int(row["unix"])
        if unix_timestamp in full_index_df.index:
            for name in names:
                full_index_df.loc[unix_timestamp, name] = float(row[name])

    if fill_method == "ffill":
        try:
            # Timestamp.value is in nanoseconds; convert to milliseconds.
            start_unix = start.value // 10**6
            filled_columns = _forward_fill_fee_like_columns(
                full_index_df, raw_inputs, names, start_unix
            )
        except (ValueError, KeyError, TypeError) as e:
            print(f"Warning: Error during ffill processing: {str(e)}")
        else:
            # Only write back once every column succeeded, so an error never
            # leaves some columns filled and others not.
            for name, filled in filled_columns.items():
                full_index_df[name] = filled

    fee_like_array = np.array(full_index_df, dtype=np.float64)
    if len(names) == 1:
        return fee_like_array[:, 0]
    return fee_like_array


def _forward_fill_fee_like_columns(full_index_df, raw_inputs, names, start_unix):
    """Return ``{name: forward-filled column}`` for every entry of ``names``."""
    if "unix" not in raw_inputs.columns:
        raise KeyError("raw_inputs must contain 'unix' column")
    for name in names:
        if name not in raw_inputs.columns:
            raise KeyError(f"raw_inputs missing required column: {name}")

    # Rows whose timestamp cannot be read as a number take no part in seeding.
    valid_unix = pd.to_numeric(raw_inputs["unix"], errors="coerce")
    valid_mask = valid_unix.notna()
    before_start = raw_inputs[valid_mask & (valid_unix < start_unix)]
    from_start = raw_inputs[valid_mask & (valid_unix >= start_unix)]

    filled_columns = {}
    for name in names:
        # Seed for the leading gap: last value before the window, otherwise
        # the first value from the window start onwards.
        initial_value = _numeric_or_none(before_start[name], -1)
        if initial_value is None:
            initial_value = _numeric_or_none(from_start[name], 0)

        column = full_index_df[name]
        # Zeros mark "no value": blank them, carry the last value forward,
        # then seed whatever is still empty at the front.
        filled = column.where(column != 0).ffill()
        filled_columns[name] = filled.fillna(
            0.0 if initial_value is None else initial_value
        )
    return filled_columns


def _numeric_or_none(values, position):
    """Return ``values.iloc[position]`` as a number, or None if unusable."""
    if values.empty:
        return None
    try:
        candidate = pd.to_numeric(values.iloc[position])
    except (ValueError, TypeError):
        return None
    return None if pd.isna(candidate) else candidate
def filter_coarse_weights_by_data_indices(coarse_weights, data_dict):
    """Restrict coarse (chunk-period) weights to the window in ``data_dict``.

    The window is identified by timestamp: the first row kept is the one
    whose coarse timestamp equals ``data_dict["unix_values"][start_idx]``
    and the last row kept is the one matching
    ``data_dict["unix_values"][end_idx - 1]`` (inclusive). When a timestamp
    occurs more than once the first occurrence is used.

    Parameters
    ----------
    coarse_weights : dict
        Must contain ``'unix_values'`` (timestamps) and ``'weights'``, with
        one row of ``'weights'`` per timestamp.
    data_dict : dict
        Must contain ``'unix_values'``, ``'start_idx'``, and ``'end_idx'``.

    Returns
    -------
    dict
        Shallow copy of ``coarse_weights`` in which only ``'weights'`` has
        been replaced by the sliced rows; all other entries, including
        ``'unix_values'``, are carried over unchanged.

    Raises
    ------
    IndexError
        If either boundary timestamp is absent from
        ``coarse_weights["unix_values"]``.
    """
    first_timestamp = data_dict["unix_values"][data_dict["start_idx"]]
    first_row = np.where(coarse_weights["unix_values"] == first_timestamp)[0][0]

    last_timestamp = data_dict["unix_values"][data_dict["end_idx"] - 1]
    last_row = np.where(coarse_weights["unix_values"] == last_timestamp)[0][0]

    windowed = coarse_weights.copy()
    # ``last_row`` is inclusive, hence the +1 on the slice end.
    windowed["weights"] = windowed["weights"][first_row : last_row + 1]
    return windowed
def filter_reserves_by_data_indices(reserves, unix_values, data_dict):
    """Slice a reserves array to match the time range in ``data_dict``.

    Looks up the unix timestamps at ``data_dict["start_idx"]`` and
    ``data_dict["end_idx"] - 1`` in ``unix_values`` (first occurrence of
    each), then returns a copy of the corresponding rows of ``reserves``,
    both ends inclusive.

    Parameters
    ----------
    reserves : np.ndarray
        Full reserves array, with one row per entry of ``unix_values``.
    unix_values : np.ndarray
        Unix timestamps corresponding to each row of ``reserves``.
    data_dict : dict
        Must contain ``unix_values``, ``start_idx``, and ``end_idx``.

    Returns
    -------
    np.ndarray
        Independent copy of the reserves rows matching the data_dict time
        range; modifying it does not affect ``reserves``.

    Raises
    ------
    IndexError
        If either boundary timestamp is absent from ``unix_values``.
    """
    first_timestamp = data_dict["unix_values"][data_dict["start_idx"]]
    reserves_start_index = np.where(unix_values == first_timestamp)[0][0]

    last_timestamp = data_dict["unix_values"][data_dict["end_idx"] - 1]
    reserves_end_index = np.where(unix_values == last_timestamp)[0][0]

    # Slice first, then copy: only the requested rows are duplicated, and the
    # result does not keep a full-size copy of ``reserves`` alive.
    return reserves[reserves_start_index : reserves_end_index + 1].copy()
def filter_reserves_by_given_timestamp(reserves, unix_values, timestamp):
    """Return a copy of the reserves row recorded at ``timestamp``.

    Parameters
    ----------
    reserves : np.ndarray
        Full reserves array, with one row per entry of ``unix_values``.
    unix_values : np.ndarray
        Unix timestamps corresponding to each row.
    timestamp : int
        Unix timestamp to look up; it must be expressed in the same unit as
        ``unix_values``.

    Returns
    -------
    np.ndarray
        Copy of the reserves row at the first position where
        ``unix_values`` equals ``timestamp``.

    Raises
    ------
    IndexError
        If ``timestamp`` is not found in ``unix_values``.
    """
    matching_rows = np.where(unix_values == timestamp)[0]
    # Indexing an empty match array is what raises IndexError on a miss.
    row = matching_rows[0]
    return reserves[row].copy()