JAX Configuration

quantammsim uses JAX for numerical computation. This guide explains how JAX is configured and how to optimize performance.

Default Configuration

quantammsim automatically configures JAX at import time:

from jax import config
config.update("jax_enable_x64", True)  # Enable 64-bit floats

JAX defaults to 32-bit floats. Enabling 64-bit mode is required for numerical precision in financial calculations.

Backend Selection

JAX automatically detects available hardware:

from jax import default_backend, devices

# Check current backend
print(default_backend())  # "cpu", "gpu", or "tpu"

# List available devices
print(devices("cpu"))
print(devices("gpu"))  # Raises RuntimeError if no GPU is available
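You can combine these calls with an array's own `devices()` method to confirm where the array actually lives. The following is a minimal sketch that uses only public JAX APIs and does not import quantammsim:

```python
import jax
import jax.numpy as jnp

backend = jax.default_backend()   # "cpu", "gpu", or "tpu"
device = jax.devices()[0]         # first device of the default backend
x = jnp.arange(4.0)               # created on the default device

print(f"backend={backend}, device={device}, platform={device.platform}")
print("x lives on:", x.devices())  # set of devices holding x
```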

Note

TPU Support: quantammsim does not currently support TPUs. The codebase assumes either CPU or GPU backends. If you’re running on a TPU-equipped system, force the CPU or GPU backend using the environment variables described below.

Forcing CPU or GPU

At import time, quantammsim detects the default backend and sets the JAX platform and a default device accordingly:

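# In quantammsim modules:
from jax import config, default_backend, devices

DEFAULT_BACKEND = default_backend()

if DEFAULT_BACKEND != "cpu":
    # Any non-CPU backend is assumed to be a GPU (hence no TPU support)
    GPU_DEVICE = devices("gpu")[0]
    config.update("jax_platform_name", "gpu")
else:
    # On CPU-only systems, GPU_DEVICE falls back to the CPU device
    GPU_DEVICE = devices("cpu")[0]
    config.update("jax_platform_name", "cpu")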

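To force a specific backend, set the environment variable before importing quantammsim or any other module that imports JAX. JAX reads the variable at import time, so changing it afterwards has no effect: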

import os

# Force CPU
os.environ["JAX_PLATFORM_NAME"] = "cpu"

# Then import
import quantammsim

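Or to force GPU:

import os
os.environ["JAX_PLATFORM_NAME"] = "gpu"
import quantammsim

Recent JAX releases mark JAX_PLATFORM_NAME as deprecated in favor of JAX_PLATFORMS (for example JAX_PLATFORMS=cpu).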
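JAX reads the variable when it first initializes. One way to check that forcing works is to launch a fresh interpreter with the variable set and ask it which backend it uses. The sketch below assumes JAX is installed for the current Python executable. `backend_with_platform` is a hypothetical helper and is not part of quantammsim:

```python
import os
import subprocess
import sys


def backend_with_platform(platform: str) -> str:
    """Hypothetical helper: start a fresh interpreter with JAX_PLATFORM_NAME set."""
    env = dict(os.environ, JAX_PLATFORM_NAME=platform)
    result = subprocess.run(
        [sys.executable, "-c", "import jax; print(jax.default_backend())"],
        env=env, capture_output=True, text=True, check=True, timeout=300,
    )
    # JAX warnings go to stderr, so stdout holds only the printed backend name
    return result.stdout.strip()


print(backend_with_platform("cpu"))
```

The subprocess matters here. In the current process, JAX may already have been imported, and in that case the variable would be ignored.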

Device Placement

Some computations are explicitly placed on CPU for efficiency:

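from jax import device_put, devices
import jax.numpy as jnp

CPU_DEVICE = devices("cpu")[0]
gpu_array = jnp.ones(1000)  # example array on the default device

# Move data to CPU
cpu_array = device_put(gpu_array, CPU_DEVICE)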

This is used for operations where CPU is faster (e.g., scan operations with complex carry states).
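When a whole block of code should run on the CPU, the `jax.default_device` context manager is an alternative to calling `device_put` on each input. The following is a minimal sketch in plain JAX. It does not necessarily reflect how quantammsim itself places its computations:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Arrays created, and computations run, inside this block default to the CPU
with jax.default_device(cpu):
    y = jnp.cumsum(jnp.ones(1000))

print(y.devices())
```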

Memory Management

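By default, JAX pre-allocates 75% of total GPU memory when the first JAX operation runs. To change this, set these environment variables before importing JAX:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # Allocate GPU memory on demand

# Or keep pre-allocation but reserve only part of the GPU memory
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"  # Pre-allocate 50% of GPU memory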
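These variables only take effect if they are set before JAX initializes, so it can help to set them in one place. The sketch below uses a hypothetical helper, `configure_gpu_memory`. It follows the "before importing JAX" rule conservatively and refuses to run once JAX has been imported. XLA_PYTHON_CLIENT_MEM_FRACTION only matters while pre-allocation is enabled:

```python
import os
import sys
from typing import Optional


def configure_gpu_memory(preallocate: bool = True, mem_fraction: Optional[float] = None) -> None:
    """Hypothetical helper: set JAX's GPU-memory environment variables safely."""
    # Conservative guard: refuse once JAX is imported, since the variables
    # must be in place before JAX sets up its GPU backend.
    if "jax" in sys.modules:
        raise RuntimeError("Set XLA memory options before the first `import jax`.")
    if not preallocate:
        # Allocate on demand; MEM_FRACTION only applies when pre-allocation is on
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    elif mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)


# Keep pre-allocation, but reserve half of the GPU instead of the default 75%
configure_gpu_memory(mem_fraction=0.5)
```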

Compilation and Caching

JIT Compilation

Most quantammsim functions are JIT-compiled:

from jax import jit

@jit
def my_function(x):
    return x * 2

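The first call traces and compiles the function. Later calls with the same input shapes and dtypes reuse the compiled code and are fast, but calling it with a new shape or dtype triggers another compilation.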
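Recompilation is easy to observe. Python side effects inside a jitted function run only while JAX traces it, so a counter shows how often tracing (and therefore compilation) happens. The following is a minimal sketch in which `trace_count` and `scale` are illustrative names:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scale(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return x * 2


scale(jnp.ones(3))  # traced and compiled
scale(jnp.ones(3))  # same shape and dtype: cached executable, no retrace
scale(jnp.ones(4))  # new shape: traced and compiled again
print(trace_count)  # 2
```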

Cache Directory

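JAX keeps compiled functions in an in-memory cache that lasts only for the current process. To reuse compilations across runs, set a persistent on-disk cache: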

import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/path/to/cache"

This speeds up startup when running the same code repeatedly.
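The same setting is also available through `jax.config`. The sketch below assumes a recent JAX release, with option names taken from JAX's persistent-cache documentation. It uses a temporary directory as a stand-in for a real cache path. By default, JAX only persists programs whose compilation took longer than a minimum time (`jax_persistent_cache_min_compile_time_secs`, one second by default). The sketch lowers that threshold so that even small functions are cached:

```python
import tempfile

import jax
import jax.numpy as jnp

# Stand-in for a real, long-lived path such as a project cache directory
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")
jax.config.update("jax_compilation_cache_dir", cache_dir)
# Persist every compiled program, not just slow-to-compile ones
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)


@jax.jit
def f(x):
    return jnp.cumsum(x) * 2


out = f(jnp.arange(10.0)).block_until_ready()
```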

Debugging

Disable JIT for Debugging

from jax import config
config.update("jax_disable_jit", True)

This runs all operations eagerly, which makes debugging easier but execution much slower.
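To disable JIT only around the code you are debugging, JAX also provides the `jax.disable_jit()` context manager. In the following minimal sketch, `print` inside a jitted function shows concrete values instead of abstract tracers (`step` is an illustrative name):

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Under jit, x is an abstract tracer; under disable_jit it is a concrete array
    print("x =", x)
    return jnp.sqrt(x)


with jax.disable_jit():
    out = step(jnp.array([4.0, 9.0]))  # prints the actual input values
```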

Enable Logging

import os

# Dump XLA's compiled HLO modules to a directory (set before importing JAX)
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"
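To see when functions are compiled or recompiled, rather than what XLA produces, you can enable the `jax_log_compiles` option, which logs each compilation. The following is a minimal sketch in which `double` is an illustrative function:

```python
import jax
import jax.numpy as jnp

# Log a message every time JAX compiles a function
jax.config.update("jax_log_compiles", True)


@jax.jit
def double(x):
    return x * 2


a = double(jnp.ones(3))  # first call: compiled, a log message is emitted
b = double(jnp.ones(3))  # cached: no new compilation
c = double(jnp.ones(5))  # new shape: compiled and logged again
```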

Full Traceback

JAX filters its internal frames out of tracebacks by default. For full tracebacks:

import os
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

Performance Tips

  1. Batch Operations

    Use vmap for batched operations instead of Python loops:

    from jax import vmap
    
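    # Instead of (one call per element, returns a Python list):
    results = [func(x) for x in batch]
    
    # Use (one vectorized call, returns a single stacked array):
    results = vmap(func)(batch)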
    
  2. Avoid Python Control Flow

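    Inside jit-compiled code, use JAX control-flow primitives (lax.cond, lax.fori_loop, lax.scan) instead of Python if/for. A Python if on a traced value raises an error, and a Python for loop is unrolled at trace time, which inflates compile time: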

    from jax import lax
    
    # Instead of:
    for i in range(n):
        x = update(x)
    
    # Use:
    x = lax.fori_loop(0, n, lambda i, x: update(x), x)
    
  3. Static Arguments

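    Mark non-array arguments, such as configuration values, as static. JAX treats static arguments as compile-time constants, so they can control Python branching and array shapes, but each new value triggers a recompilation. Static arguments must be hashable, so pass configuration as a tuple, NamedTuple, or frozen dataclass rather than a plain dict:

    from functools import partial
    from jax import jit
    
    @partial(jit, static_argnums=(1,))
    def func(x, config):  # config must be hashable (e.g. a tuple or NamedTuple)
        ...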
    
  4. Minimize Host-Device Transfers

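    Keep data on the accelerator and avoid frequent transfers to the host. Converting an array with float(), .item(), or np.asarray() copies it to host memory and waits for the pending computation to finish.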
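For tip 2, the following is a minimal sketch of replacing a Python `if` with `lax.cond` inside a jitted function. `safe_log_return` is a hypothetical function, and the predicate must be a scalar boolean:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def safe_log_return(p_new, p_old):
    # A Python `if p_old > 0:` would fail here: p_old is a tracer with no concrete value
    return lax.cond(
        p_old > 0,
        lambda: jnp.log(p_new / p_old),  # branch taken when the predicate is True
        lambda: jnp.zeros_like(p_new),   # fallback branch; same shape and dtype
    )


print(safe_log_return(jnp.array(110.0), jnp.array(100.0)))
print(safe_log_return(jnp.array(110.0), jnp.array(0.0)))
```

For elementwise choices over whole arrays, `jnp.where` is the usual alternative, though it evaluates both sides.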
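For tip 3, the following is a minimal sketch of passing configuration as a hashable static argument. `SimConfig` and `run` are hypothetical names. A plain dict passed in the same position would be rejected because dicts are not hashable:

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SimConfig(NamedTuple):  # hypothetical, hashable configuration type
    n_steps: int
    decay: float


@partial(jax.jit, static_argnums=(1,))
def run(x, cfg: SimConfig):
    # cfg.n_steps is a plain Python int at trace time, so it can set the loop bounds
    return jax.lax.fori_loop(0, cfg.n_steps, lambda i, v: v * cfg.decay, x)


cfg = SimConfig(n_steps=3, decay=0.5)
out = run(jnp.array(8.0), cfg)  # 8 * 0.5**3 = 1.0
```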
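For tip 4, the following is a minimal sketch that keeps intermediate results on the device and performs a single transfer at the end, instead of converting to the host at every step. `step` is an illustrative function. In real code, a `lax.scan` would typically replace the Python loop as well:

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def step(x):
    return x * 1.01


x = jnp.ones(4)
history = []
for _ in range(100):
    x = step(x)
    history.append(x)  # device array; nothing is copied to the host here

# Calling float(x) or np.asarray(x) inside the loop would force a host copy
# (and a wait) on every iteration. Instead, stack on the device, transfer once.
history_host = jax.device_get(jnp.stack(history))
print(type(history_host), history_host.shape)
```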

Common Issues

NaN Values

If you encounter NaN values, check:

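  • Division by zero (0/0 gives NaN; x/0 gives inf, which can turn into NaN in later operations)

  • Log or square root of negative numbers

  • Overflow in exponentials (exp of a large argument gives inf)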

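Enable NaN checking. JAX then raises a FloatingPointError at the first operation that produces a NaN, at a noticeable performance cost: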

from jax import config
config.update("jax_debug_nans", True)
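The following is a minimal sketch of the NaN check in action, using plain JAX. The flag is switched off again afterwards because it adds a check to every operation:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)
try:
    jnp.log(jnp.array(-1.0))  # produces NaN, so JAX raises instead of returning it
    caught = None
except FloatingPointError as err:
    caught = err
finally:
    jax.config.update("jax_debug_nans", False)  # restore normal (fast) execution

print("caught:", caught)
```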

Out of Memory

For large simulations:

  1. Reduce batch size

  2. Use gradient checkpointing (jax.checkpoint, also available as jax.remat)

  3. Process data in chunks

  4. Limit GPU memory pre-allocation (see Memory Management above)
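For item 2, the following is a minimal sketch of gradient checkpointing with `jax.checkpoint`. `layer` and `make_loss` are illustrative stand-ins for an expensive simulation step, not quantammsim functions. The checkpointed and plain gradients agree, but the checkpointed version recomputes intermediates during the backward pass instead of storing them:

```python
import jax
import jax.numpy as jnp


def layer(x):
    # Stand-in for an expensive simulation step (illustrative only)
    return jnp.sin(x) * jnp.cos(x)


# jax.checkpoint stops autodiff from storing this function's intermediate
# values for the backward pass; they are recomputed when needed instead.
layer_ckpt = jax.checkpoint(layer)


def make_loss(step_fn, n_steps=10):
    def loss(x):
        for _ in range(n_steps):
            x = step_fn(x)
        return jnp.sum(x)
    return loss


x0 = jnp.ones(3)
g_ckpt = jax.grad(make_loss(layer_ckpt))(x0)  # trades extra compute for memory
g_plain = jax.grad(make_loss(layer))(x0)      # stores all intermediates
```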

Slow Compilation

First-time compilation can be slow. Solutions:

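  1. Use a persistent compilation cache (see Cache Directory above)

  2. Reduce function complexity, for example by replacing long Python loops (which are unrolled at trace time) with lax.fori_loop or lax.scan

  3. Use static_argnums for configuration values, converting dicts to a hashable type (tuple, NamedTuple, frozen dataclass) first, because static arguments must be hashable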
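You can see why smaller traced programs compile faster with `jax.make_jaxpr`, which shows the program JAX hands to XLA. A Python loop is unrolled into one operation per iteration, while `lax.fori_loop` traces its body only once. The following is a minimal sketch with illustrative functions:

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled(x):
    for _ in range(100):  # Python loop: becomes 100 separate multiply operations
        x = x * 1.01
    return x


def looped(x):
    # fori_loop traces the body once; XLA sees a single loop operation
    return lax.fori_loop(0, 100, lambda i, v: v * 1.01, x)


x = jnp.ones(3)
n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_looped = len(jax.make_jaxpr(looped)(x).jaxpr.eqns)
print(n_unrolled, n_looped)
```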

Environment Variables Summary

Variable                          Description
JAX_PLATFORM_NAME                 Force backend: "cpu" or "gpu" ("tpu" also exists, but quantammsim does not support it)
XLA_PYTHON_CLIENT_PREALLOCATE     "false" to disable GPU memory pre-allocation
XLA_PYTHON_CLIENT_MEM_FRACTION    Fraction of GPU memory to pre-allocate (0.0-1.0)
JAX_COMPILATION_CACHE_DIR         Path for persistent compilation cache
JAX_TRACEBACK_FILTERING           "off" for full tracebacks
JAX_DISABLE_JIT                   "1" to disable JIT compilation
XLA_FLAGS                         Extra XLA options, e.g. "--xla_dump_to=/tmp/xla_dump"