JAX Configuration
quantammsim uses JAX for numerical computation. This guide explains how JAX is configured and how to optimize performance.
Default Configuration
quantammsim automatically configures JAX at import time:
from jax import config
config.update("jax_enable_x64", True) # Enable 64-bit floats
JAX uses 32-bit floats by default. quantammsim enables 64-bit mode because financial calculations need the extra precision.
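The effect of the flag can be checked directly. The sketch below is standalone and not taken from quantammsim. It enables 64-bit mode, confirms the default float dtype, and shows a tiny increment (for example a per-step fee accrual) being lost in float32 but kept in float64. It assumes the flag is set before any arrays are created.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before creating any arrays
jax.config.update("jax_enable_x64", True)

price = jnp.array(1.0)
print(price.dtype)  # float64 once x64 is enabled

# float32 has ~7 significant digits, so adding 1e-8 to 1.0 changes nothing
lost_f32 = jnp.float32(1.0) + jnp.float32(1e-8) == jnp.float32(1.0)
# float64 has ~16 significant digits, so the increment survives
lost_f64 = jnp.float64(1.0) + jnp.float64(1e-8) == jnp.float64(1.0)
print(bool(lost_f32), bool(lost_f64))  # True False
```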
Backend Selection
JAX automatically detects available hardware:
from jax import default_backend, devices
# Check current backend
print(default_backend())  # "cpu", "gpu", or "tpu"
# List available devices
print(devices("cpu"))
print(devices("gpu"))  # Raises RuntimeError if no GPU backend is available
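Because asking for a missing backend raises an error rather than returning an empty list, code that should run on any machine can probe each platform defensively. This is a minimal sketch; available_devices is a hypothetical helper, not part of quantammsim.

```python
import jax

def available_devices(platforms=("cpu", "gpu")):
    """Hypothetical helper: map each platform name to its devices, skipping missing ones."""
    found = {}
    for platform in platforms:
        try:
            found[platform] = jax.devices(platform)
        except RuntimeError:
            # Raised when the backend is not installed or has no devices
            continue
    return found

for name, devs in available_devices().items():
    print(name, devs)
```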
Note
TPU Support: quantammsim does not currently support TPUs. The codebase assumes either CPU or GPU backends. If you’re running on a TPU-equipped system, force the CPU or GPU backend using the environment variables described below.
Forcing CPU or GPU
At import time, quantammsim detects the default backend and configures itself accordingly:
# In quantammsim modules:
from jax import config, default_backend, devices
DEFAULT_BACKEND = default_backend()
if DEFAULT_BACKEND != "cpu":
    GPU_DEVICE = devices("gpu")[0]
    config.update("jax_platform_name", "gpu")
else:
    # No accelerator: GPU_DEVICE falls back to the CPU device
    GPU_DEVICE = devices("cpu")[0]
    config.update("jax_platform_name", "cpu")
To force a specific backend, set the environment variable before JAX is first imported (directly or through quantammsim):
# Force CPU
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
# Then import
import quantammsim
Or to force GPU:
os.environ["JAX_PLATFORM_NAME"] = "gpu"
import quantammsim
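Recent JAX releases document JAX_PLATFORMS (plural) as the preferred variable for restricting backends. A minimal check that such an override took effect, assuming it runs in a fresh Python process where JAX has not been imported yet:

```python
import os

# Must be set before JAX is imported anywhere in this process
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

print(jax.default_backend())  # "cpu", even on a machine with a GPU
```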
Device Placement
Some computations are explicitly placed on CPU for efficiency:
from jax import device_put, devices
CPU_DEVICE = devices("cpu")[0]
# Move data to CPU (gpu_array is any array currently on the GPU)
cpu_array = device_put(gpu_array, CPU_DEVICE)
This is used for operations where CPU is faster (e.g., scan operations with complex carry states).
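A self-contained sketch of the pattern: an array is explicitly committed to the CPU and a lax.scan with a tuple carry then runs there, because JAX executes computations on the device their committed inputs live on. The running-maximum carry is a hypothetical example, not quantammsim's actual scan.

```python
import jax
import jax.numpy as jnp
from jax import lax

CPU_DEVICE = jax.devices("cpu")[0]

# Created on the default device (the GPU, if one is available)
prices = jnp.linspace(100.0, 110.0, 11)

# Explicitly copy the array to the CPU; computations on it then run there
prices_cpu = jax.device_put(prices, CPU_DEVICE)

def step(carry, price):
    # Hypothetical carry: (running maximum, number of new highs so far)
    running_max, n_highs = carry
    is_high = price > running_max
    new_carry = (jnp.maximum(running_max, price), n_highs + is_high.astype(jnp.int32))
    return new_carry, is_high

init = (prices_cpu[0], jnp.int32(0))
(final_max, n_highs), flags = lax.scan(step, init, prices_cpu[1:])
print(final_max, n_highs)
```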
Memory Management
By default, JAX preallocates a large share of GPU memory (75% in current releases) when the first JAX operation runs. To change this:
# Before importing JAX
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Or limit memory fraction
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" # Use 50% of GPU memory
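Since these variables only work if they are set before JAX loads, it can help to centralise them in one place that refuses to run too late. configure_gpu_memory below is a hypothetical helper sketched for illustration; it is not part of quantammsim or JAX.

```python
import os
import sys

def configure_gpu_memory(preallocate=True, mem_fraction=None):
    """Hypothetical helper: set XLA GPU memory options before JAX is imported."""
    if "jax" in sys.modules:
        # Per the advice above, these variables must be set before importing JAX
        raise RuntimeError("Configure GPU memory before importing JAX")
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)

# Preallocate only half of the GPU memory
configure_gpu_memory(mem_fraction=0.5)

import jax  # imported after configuration on purpose
```

Calling the helper again after this point raises RuntimeError, which surfaces ordering mistakes early instead of silently ignoring the settings.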
Compilation and Caching
JIT Compilation
Most quantammsim functions are JIT-compiled:
from jax import jit
@jit
def my_function(x):
    return x * 2
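The first call traces and compiles the function; later calls with the same argument shapes and dtypes reuse the compiled code and are fast. Calling it with a new shape or dtype triggers another compilation.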
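One way to see when compilation happens is to count traces. Python side effects inside a jitted function run only while JAX traces it, so a counter reveals how often the function is (re)compiled. This is a standalone illustration; double is a placeholder function.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return x * 2

double(jnp.ones(3))  # first call: trace + compile
double(jnp.ones(3))  # same shape and dtype: reuses the compiled code
double(jnp.ones(4))  # new shape: traced and compiled again
print(trace_count)   # 2
```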
Cache Directory
JAX caches compiled functions in memory for the lifetime of a Python process. To reuse compiled code across runs, set a persistent cache directory:
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/path/to/cache"
This speeds up startup when running the same code repeatedly.
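The same cache can also be configured from Python through jax.config. A sketch, assuming a recent JAX release: the temporary directory stands in for /path/to/cache, and the min-compile-time option (which by default skips caching fast compilations) is set to 0 only so that every compilation is written, which is convenient for checking the cache works.

```python
import tempfile

import jax
import jax.numpy as jnp

cache_dir = tempfile.mkdtemp(prefix="jax_cache_")  # stand-in for /path/to/cache

jax.config.update("jax_compilation_cache_dir", cache_dir)
# Cache every compilation, not only slow ones (for demonstration)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

@jax.jit
def f(x):
    return jnp.tanh(x) ** 2

print(f(jnp.arange(3.0)))
```

In production the default threshold is usually preferable, since caching trivial compilations adds disk I/O without saving time.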
Debugging
Disable JIT for Debugging
from jax import config
config.update("jax_disable_jit", True)
This runs all operations eagerly, making debugging easier but much slower.
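Disabling JIT globally slows everything down. JAX also provides the jax.disable_jit() context manager, which limits eager execution to one block. In the sketch below (normalise is a placeholder function), a print inside a jitted function shows an abstract tracer, while the same call inside the context manager prints the concrete value.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalise(weights):
    total = jnp.sum(weights)
    print("total =", total)  # under jit: prints an abstract tracer, at trace time only
    return weights / total

normalise(jnp.array([1.0, 3.0]))  # prints a tracer

# Disable JIT only for this block: the print shows the concrete value 4.0
with jax.disable_jit():
    weights = normalise(jnp.array([1.0, 3.0]))
```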
Enable Logging
# Dump the HLO modules XLA compiles to a directory for inspection
# (set before importing JAX)
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"
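To see when JAX compiles (rather than what XLA produces), the jax_log_compiles option logs a message for each compilation through Python's logging module. A standalone sketch; add_one is a placeholder and the exact log text varies between JAX versions.

```python
import logging

import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.WARNING)
# Log a message each time JAX compiles a function
jax.config.update("jax_log_compiles", True)

@jax.jit
def add_one(x):
    return x + 1

add_one(jnp.ones(2))  # add_one is compiled and logged
add_one(jnp.ones(2))  # same shape: add_one is not recompiled
add_one(jnp.ones(5))  # new shape: add_one is compiled and logged again
```

This is a quick way to spot unexpected recompilation, for example from changing array shapes between calls.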
Full Traceback
JAX simplifies tracebacks by default. For full tracebacks:
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
Performance Tips
Batch Operations
Use vmap for batched operations instead of Python loops:
from jax import vmap
# Instead of:
# results = [func(x) for x in batch]
# Use:
results = vmap(func)(batch)
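A concrete sketch of vmap with in_axes, which controls which arguments are batched. Here portfolio_value is a hypothetical function that values one portfolio; vmap maps it over rows of a weight matrix while sharing one price vector across the batch.

```python
import jax
import jax.numpy as jnp

def portfolio_value(weights, prices):
    # Value of a single portfolio: weighted sum of asset prices
    return jnp.dot(weights, prices)

weights_batch = jnp.array([[0.5, 0.5],
                           [0.2, 0.8],
                           [1.0, 0.0]])
prices = jnp.array([100.0, 200.0])

# in_axes=(0, None): map over rows of weights_batch, reuse prices for every row
values = jax.vmap(portfolio_value, in_axes=(0, None))(weights_batch, prices)
print(values)  # one value per portfolio, shape (3,)
```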
Avoid Python Control Flow
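Inside jitted code, use JAX primitives (lax.cond, lax.fori_loop) instead of Python if/for. A Python if cannot branch on a traced value, and a Python loop is unrolled during tracing, which slows compilation:
from jax import lax
# Instead of:
# for i in range(n):
#     x = update(x)
# Use:
x = lax.fori_loop(0, n, lambda i, x: update(x), x)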
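A runnable sketch of both primitives. The compounding and fee functions are hypothetical examples; note that n_periods is a traced argument, so a Python range() loop over it would fail under jit, while lax.fori_loop accepts it.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def compound(balance, rate, n_periods):
    # n_periods is traced, so `for _ in range(n_periods)` would fail here
    def body(i, b):
        return b * (1.0 + rate)
    return lax.fori_loop(0, n_periods, body, balance)

@jax.jit
def charge_fee(balance, fee):
    # `if balance > fee:` would raise a concretization error under jit
    return lax.cond(balance > fee, lambda b: b - fee, lambda b: b, balance)

print(compound(100.0, 0.1, 2))  # ~121.0
print(charge_fee(100.0, 5.0))   # 95.0
print(charge_fee(3.0, 5.0))     # 3.0
```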
Static Arguments
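Mark non-array arguments as static so they are treated as compile-time constants. Static arguments must be hashable (a plain dict is not), and each distinct value triggers a separate compilation:
from functools import partial
from jax import jit
@partial(jit, static_argnums=(1,))
def func(x, sim_config):
    # sim_config must be hashable, e.g. a tuple or NamedTuple rather than a dict
    ...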
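A sketch of passing configuration as a hashable NamedTuple. SimConfig and its fields are hypothetical, for illustration only. Because cfg is static, ordinary Python control flow on its fields is allowed inside the jitted function, while a plain dict is rejected.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp

class SimConfig(NamedTuple):
    # Hypothetical fields, for illustration only
    n_steps: int
    apply_fee: bool

@partial(jax.jit, static_argnums=(1,))
def run(x, cfg):
    # cfg is static, so Python loops and ifs on its fields are fine
    for _ in range(cfg.n_steps):
        x = x * 1.01
    if cfg.apply_fee:
        x = x - 0.1
    return x

out = run(jnp.ones(2), SimConfig(n_steps=3, apply_fee=True))

# A plain dict is unhashable and is rejected as a static argument
try:
    run(jnp.ones(2), {"n_steps": 3, "apply_fee": True})
except (TypeError, ValueError) as err:
    print("rejected:", type(err).__name__)
```

Each distinct SimConfig value compiles its own version of run, so this suits settings that change rarely rather than per call.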
Minimize Host-Device Transfers
Keep data on the accelerator; avoid frequent transfers to CPU.
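A sketch of the difference, using a hypothetical step function. Converting a result with float() inside the loop would copy it to the host and wait for the device on every iteration; keeping intermediate results as device arrays and transferring them once with jax.device_get avoids that.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(value):
    # Hypothetical update: 1% decay plus a constant inflow
    return value * 0.99 + 1.0

value = jnp.zeros(())

# Avoid: calling float(value) here would force a device-to-host copy each iteration.
# Instead, keep results on the device...
history = []
for _ in range(100):
    value = step(value)
    history.append(value)

# ...and transfer them to the host once, at the end
history_host = jax.device_get(jnp.stack(history))
print(type(history_host), history_host.shape)
```

For long loops, lax.scan goes further by running the whole loop as one compiled program.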
Common Issues
NaN Values
If you encounter NaN values, check:
Division by zero in your calculations
Log of negative numbers
Overflow in exponentials
Enable NaN checking:
from jax import config
config.update("jax_debug_nans", True)
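A sketch of what the flag changes. The hypothetical log_returns function hits the "log of negative numbers" case from the list above: by default JAX silently returns nan, but with jax_debug_nans enabled the same call raises FloatingPointError at the operation that produced the NaN. The flag is reset afterwards because it slows execution.

```python
import jax
import jax.numpy as jnp

def log_returns(prices):
    # A non-positive price makes the ratio negative, and jnp.log then returns NaN
    return jnp.log(prices[1:] / prices[:-1])

prices = jnp.array([100.0, 101.0, -5.0])  # bad data point at the end

print(log_returns(prices))  # last entry is nan; no error by default

jax.config.update("jax_debug_nans", True)
try:
    log_returns(prices)
    nan_caught = False
except FloatingPointError as err:
    nan_caught = True  # the message points at the operation that produced the NaN
    print("caught:", err)
finally:
    jax.config.update("jax_debug_nans", False)
```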
Out of Memory
For large simulations:
Reduce batch size
Use gradient checkpointing (jax.checkpoint) to recompute intermediate values instead of storing them
Process data in chunks
Limit GPU memory pre-allocation (see Memory Management above)
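Processing in chunks can be sketched as follows. The data stays in host memory as a NumPy array, and only one chunk at a time is copied to the device. simulate and run_in_chunks are hypothetical names, and the cumulative log-return computation stands in for a real simulation.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def simulate(paths):
    # Hypothetical per-path computation: cumulative log-returns
    return jnp.cumsum(jnp.log1p(paths), axis=1)

def run_in_chunks(data, chunk_size):
    """Run `simulate` over the rows of a host array, one chunk at a time."""
    outputs = []
    for start in range(0, data.shape[0], chunk_size):
        chunk = jnp.asarray(data[start:start + chunk_size])  # copy one chunk to the device
        # Copy results back to the host so the device buffers can be released
        outputs.append(np.asarray(simulate(chunk)))
    return np.concatenate(outputs, axis=0)

rng = np.random.default_rng(0)
returns = rng.normal(0.0, 0.01, size=(1000, 250)).astype(np.float32)
result = run_in_chunks(returns, chunk_size=256)
print(result.shape)  # (1000, 250)
```

The final chunk here has 232 rows, so simulate compiles once more for that shape; padding the last chunk to chunk_size avoids the extra compilation.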
Slow Compilation
First-time compilation can be slow. Solutions:
Use persistent compilation cache
Reduce function complexity
Mark configuration arguments as static with static_argnums, after converting them to a hashable type such as a tuple
Environment Variables Summary
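| Variable | Description |
|---|---|
| JAX_PLATFORM_NAME | Force backend: "cpu" or "gpu" (JAX also accepts "tpu", which quantammsim does not support) |
| XLA_PYTHON_CLIENT_PREALLOCATE | "false" to disable GPU memory preallocation |
| XLA_PYTHON_CLIENT_MEM_FRACTION | Fraction of GPU memory to preallocate (0.0-1.0) |
| JAX_COMPILATION_CACHE_DIR | Path for persistent compilation cache |
| JAX_TRACEBACK_FILTERING | "off" for full tracebacks |
| JAX_DISABLE_JIT | "1" to disable JIT compilation |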