Performance Optimization¶
This guide provides best practices and optimization strategies for achieving maximum performance with schr on both CPUs and GPUs.
Hardware Considerations¶
GPU vs CPU¶
When to use GPU:
Grid size >= 512 in any dimension
Long simulations (thousands of time steps)
Parameter sweeps or batch processing
2D/3D systems
When CPU is sufficient:
1D systems with N < 4096
Quick prototyping and debugging
Systems without NVIDIA GPUs
Single-shot simulations
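To confirm which backend JAX will actually use before committing to a long run, a quick check with standard JAX calls (nothing schr-specific) is:
import jax

# Report the default backend ('gpu', 'cpu', or 'tpu') and the visible devices
print(jax.default_backend())
print(jax.devices())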
Performance Scaling¶
Typical speedups (GPU vs CPU):
| Problem Size | Speedup | Notes |
|---|---|---|
| 1D, N=1024 | 2-5× | Small overhead |
| 1D, N=8192 | 10-20× | FFT-dominated |
| 2D, 512×512 | 20-50× | Good GPU utilization |
| 2D, 1024×1024 | 50-100× | Optimal for GPU |
| 3D, 256×256×256 | 100-200× | Best GPU advantage |
Memory Requirements¶
RAM/VRAM Usage:
For complex64 arrays (8 bytes per element):
| Grid Size | Memory (complex64) | Memory (complex128) |
|---|---|---|
| 1D: 8192 | 64 KB | 128 KB |
| 2D: 1024×1024 | 8 MB | 16 MB |
| 2D: 2048×2048 | 32 MB | 64 MB |
| 3D: 256×256×256 | 128 MB | 256 MB |
| 3D: 512×512×512 | 1 GB | 2 GB |
Add ~2-3× for intermediate arrays during computation.
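A quick back-of-the-envelope estimate before allocating a large grid (plain Python arithmetic, nothing schr-specific):
# Estimate wavefunction memory for a 2D grid, plus a ~3× margin for temporaries
nx, ny = 1024, 1024
bytes_per_sample = 8  # complex64
base_mib = nx * ny * bytes_per_sample / 2**20
print(f"Wavefunction: {base_mib:.0f} MiB, budget with temporaries: ~{3 * base_mib:.0f} MiB")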
JAX Optimization¶
JIT Compilation¶
Just-In-Time (JIT) compilation is crucial for performance:
from jax import jit
from schr.qm.solvers import SplitStepFourier
solver = SplitStepFourier(hamiltonian)
# Without JIT (slow)
for step in range(1000):
    psi = solver.step(psi, step * dt, dt)

# With JIT (fast)
@jit
def evolve_step(psi, t, dt):
    return solver.step(psi, t, dt)

for step in range(1000):
    psi = evolve_step(psi, step * dt, dt)
First call is slow (compilation), subsequent calls are fast.
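To see the effect, time the first and second calls separately (illustrative only; evolve_step is the jitted function defined above):
import time

t0 = time.time()
psi = evolve_step(psi, 0.0, dt).block_until_ready()  # first call: traces and compiles
print(f"First call (with compilation): {time.time() - t0:.3f} s")

t0 = time.time()
psi = evolve_step(psi, dt, dt).block_until_ready()   # later calls reuse the cached executable
print(f"Second call (compiled):        {time.time() - t0:.3f} s")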
Static Arguments¶
Mark static arguments in JIT:
from functools import partial
@partial(jit, static_argnums=(2,))
def evolve_with_absorption(psi, t, apply_absorption):
    psi = solver.step(psi, t, dt)
    if apply_absorption:
        psi = psi * absorption_mask
    return psi
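Because apply_absorption is static, each distinct value compiles its own specialization of the function, so keep the set of static values small. For example:
psi = evolve_with_absorption(psi, t, True)   # compiles the branch with absorption
psi = evolve_with_absorption(psi, t, False)  # triggers a second compilation, then cached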
Vectorization¶
Use vmap for batch processing:
from jax import vmap
# Evolve multiple initial conditions in parallel
initial_states = jnp.array([psi0_1, psi0_2, psi0_3])
@jit
def evolve_single(psi0):
    psi = psi0
    for step in range(num_steps):
        psi = solver.step(psi, step * dt, dt)
    return psi
# Vectorized version
evolve_batch = vmap(evolve_single)
final_states = evolve_batch(initial_states)
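Note that the Python for loop above is unrolled at trace time, which can make compilation slow for thousands of steps. A sketch of the same evolution with jax.lax.fori_loop (assuming solver.step is traceable inside the loop) keeps the compiled program small:
from jax import lax

@jit
def evolve_single_loop(psi0):
    def body(step, psi):
        return solver.step(psi, step * dt, dt)
    # Loop is compiled as a single XLA while-loop instead of being unrolled
    return lax.fori_loop(0, num_steps, body, psi0)

final_states = vmap(evolve_single_loop)(initial_states)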
Precision Trade-offs¶
float32/complex64 vs float64/complex128:
# Default: complex64 (recommended)
x, dx = create_grid_1d(-10, 10, 1024, dtype=jnp.float32)
psi = psi.astype(jnp.complex64)
# High precision: complex128 (2× memory, slower)
x, dx = create_grid_1d(-10, 10, 1024, dtype=jnp.float64)
psi = psi.astype(jnp.complex128)
Trade-offs:
complex64: roughly 2× faster, half the memory, ~7 significant digits
complex128: ~16 significant digits; needed for very long simulations where rounding error accumulates
For most quantum simulations, complex64 is sufficient.
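Note that JAX disables 64-bit types by default; to actually get float64/complex128 you must enable x64 before creating arrays (standard JAX configuration, not specific to schr):
import jax

# Without this, float64/complex128 arrays are silently downcast to 32-bit
jax.config.update("jax_enable_x64", True)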
Memory Management¶
GPU Memory¶
Check GPU memory:
import jax
# GPU memory info
# jax.devices('gpu') raises RuntimeError if no GPU backend is available
try:
    devices = jax.devices('gpu')
except RuntimeError:
    devices = []
if devices:
    device = devices[0]
    print(f"Device: {device}")
    # Detailed memory usage varies by GPU
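Continuing from the snippet above, GPU devices usually expose allocator statistics through memory_stats(); treat this as backend-dependent (it may return None where unsupported):
if devices:
    stats = devices[0].memory_stats()  # dict of allocator counters, or None
    if stats:
        in_use = stats.get("bytes_in_use", 0)
        limit = stats.get("bytes_limit", 0)
        print(f"In use: {in_use / 2**20:.0f} MiB of {limit / 2**20:.0f} MiB")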
Clear GPU memory:
import jax
# Clear compiled functions and free GPU memory
jax.clear_backends()
Out of Memory Solutions¶
If you encounter an out-of-memory error from XLA (e.g. RESOURCE_EXHAUSTED: Out of memory while trying to allocate ...):
Reduce grid size:
# Instead of 2048×2048, use 1024×1024
X, Y, dx, dy = create_grid_2d(-500, 500, 1024, -500, 500, 1024)
Use float32:
dtype = jnp.float32 # Instead of float64
Process in chunks:
# Split time evolution into batches
for batch in range(num_batches):
    for step in range(steps_per_batch):
        psi = solver.step(psi, (batch * steps_per_batch + step) * dt, dt)
    # Save checkpoint
    jnp.save(f'checkpoint_{batch}.npy', psi)
    # Clear cache
    jax.clear_caches()
Disable GPU memory preallocation (by default XLA reserves most of the GPU's memory up front):
export XLA_PYTHON_CLIENT_PREALLOCATE=false
FFT Optimization¶
Grid Size Selection¶
FFT is fastest for sizes that are powers of 2:
# Good choices (powers of 2)
good_sizes = [256, 512, 1024, 2048, 4096, 8192]
# Avoid (prime factors other than 2 make FFTs slower)
avoid_sizes = [1000, 1500, 2000, 3000]
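When the physics fixes only a minimum resolution, rounding the grid size up to the next power of two is a cheap way to stay on the fast path. A small helper (hypothetical, not part of schr):
def next_power_of_two(n: int) -> int:
    # Smallest power of 2 that is >= n
    return 1 << (n - 1).bit_length()

print(next_power_of_two(1000))  # 1024
print(next_power_of_two(3000))  # 4096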
Benchmark different sizes:
import time
import jax.numpy as jnp
for n in [1000, 1024, 2000, 2048]:
    x = jnp.ones((n, n), dtype=jnp.complex64)
    # Warmup (includes compilation)
    _ = jnp.fft.fft2(x).block_until_ready()
    # Benchmark; block_until_ready() accounts for JAX's asynchronous dispatch
    start = time.time()
    for _ in range(100):
        _ = jnp.fft.fft2(x).block_until_ready()
    elapsed = time.time() - start
    print(f"n={n}: {elapsed/100*1000:.2f} ms/FFT")
In-Place Operations¶
JAX arrays are immutable, so there are no true in-place operations; under jit, however, XLA reuses buffers and chained updates like the ones below remain efficient:
# This is efficient (array reuse)
psi = psi * absorption_mask
psi = jnp.fft.fftn(psi)
psi = psi * kinetic_operator
psi = jnp.fft.ifftn(psi)
Parallel Processing¶
Multiple GPUs¶
For systems with multiple GPUs:
import jax
from jax import pmap
# Distribute across GPUs
devices = jax.devices('gpu')
print(f"Available GPUs: {len(devices)}")
# Parallel map across devices
@pmap
def evolve_parallel(psi0):
    psi = psi0
    for step in range(num_steps):
        psi = solver.step(psi, step * dt, dt)
    return psi
# Split initial conditions across GPUs
initial_states = jnp.array([psi0_1, psi0_2, psi0_3, psi0_4])
final_states = evolve_parallel(initial_states)
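pmap maps the leading axis of its input one-to-one onto devices, so the batch dimension must not exceed the number of visible GPUs (the example above assumes four). A quick sanity check before the call:
# Fail early with a clear message instead of a shape error inside pmap
assert initial_states.shape[0] <= jax.local_device_count(), \
    "More initial states than devices; reshape to (n_devices, per_device, ...) and vmap inside"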
CPU Parallelization¶
JAX automatically uses multiple CPU cores:
# Control number of threads
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true"
export OMP_NUM_THREADS=8