Performance Optimization

This guide provides best practices and optimization strategies for achieving maximum performance with schr on both CPUs and GPUs.

Hardware Considerations

GPU vs CPU

When to use GPU:

  • Grid size >= 512 in any dimension

  • Long simulations (thousands of time steps)

  • Parameter sweeps or batch processing

  • 2D/3D systems

When CPU is sufficient:

  • 1D systems with N < 4096

  • Quick prototyping and debugging

  • Systems without NVIDIA GPUs

  • Single-shot simulations
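
To check which backend JAX will actually use at runtime, you can query it directly; forcing CPU execution via the JAX_PLATFORMS environment variable (JAX_PLATFORM_NAME in older releases) is handy when debugging GPU-specific issues:

import jax

# Report the default backend ("gpu", "cpu", or "tpu") and the visible devices
print("Default backend:", jax.default_backend())
print("Devices:", jax.devices())

# To force CPU execution, set JAX_PLATFORMS=cpu in the environment
# before the first import of jax.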

Performance Scaling

Typical speedups (GPU vs CPU):

  Problem Size        Speedup     Notes
  1D, N=1024          2-5×        Small overhead
  1D, N=8192          10-20×      FFT-dominated
  2D, 512×512         20-50×      Good GPU utilization
  2D, 1024×1024       50-100×     Optimal for GPU
  3D, 256×256×256     100-200×    Best GPU advantage

Memory Requirements

RAM/VRAM Usage:

For complex64 arrays (8 bytes per element), a grid with N points per dimension in d dimensions requires

\[\text{Memory} = 8N^d \text{ bytes}\]

  Grid Size           Memory (complex64)    Memory (complex128)
  1D: 8192            64 KB                 128 KB
  2D: 1024×1024       8 MB                  16 MB
  2D: 2048×2048       32 MB                 64 MB
  3D: 256×256×256     128 MB                256 MB
  3D: 512×512×512     1 GB                  2 GB

Add ~2-3× for intermediate arrays during computation.
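
As a quick sanity check before launching a large run, the formula above can be evaluated directly; the helper below is an illustrative sketch, not part of the schr API:

def estimate_memory_bytes(n_points, ndim, itemsize=8):
    """Memory for one grid array; itemsize is 8 (complex64) or 16 (complex128)."""
    return itemsize * n_points**ndim

# 2D, 1024×1024, complex64: about 8 MB for the wavefunction alone
print(f"{estimate_memory_bytes(1024, 2) / 2**20:.0f} MB")

# Multiply by ~2-3 to budget for intermediate arrays during evolution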

JAX Optimization

JIT Compilation

Just-In-Time (JIT) compilation is crucial for performance:

from jax import jit
from schr.qm.solvers import SplitStepFourier

solver = SplitStepFourier(hamiltonian)

# Without JIT (slow)
for step in range(1000):
    psi = solver.step(psi, step * dt, dt)

# With JIT (fast)
@jit
def evolve_step(psi, t, dt):
    return solver.step(psi, t, dt)

for step in range(1000):
    psi = evolve_step(psi, step * dt, dt)

First call is slow (compilation), subsequent calls are fast.
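
To see the compilation overhead directly, time the first call against a later one; the function below is a trivial stand-in for a solver step, and block_until_ready() is used because JAX dispatches work asynchronously:

import time
import jax
import jax.numpy as jnp

@jax.jit
def demo_step(psi):
    # Stand-in for one split-step update
    return jnp.fft.ifft(jnp.fft.fft(psi) * 0.99)

psi = jnp.ones(8192, dtype=jnp.complex64)

t0 = time.time()
demo_step(psi).block_until_ready()   # includes tracing and compilation
print(f"first call:  {time.time() - t0:.4f} s")

t0 = time.time()
demo_step(psi).block_until_ready()   # reuses the cached executable
print(f"second call: {time.time() - t0:.4f} s")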

Static Arguments

Mark static arguments in JIT:

from functools import partial

@partial(jit, static_argnums=(2,))
def evolve_with_absorption(psi, t, apply_absorption):
    psi = solver.step(psi, t, dt)
    if apply_absorption:
        psi = psi * absorption_mask
    return psi
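
Each distinct value of a static argument triggers its own compilation, so this pattern works best for flags that take only a handful of values. Continuing the example above:

# Two compilations (one per value of the static flag), then both are cached
psi = evolve_with_absorption(psi, 0.0, True)
psi = evolve_with_absorption(psi, 0.0, False)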

Vectorization

Use vmap for batch processing:

from jax import vmap
import jax.numpy as jnp

# Evolve multiple initial conditions in parallel
initial_states = jnp.array([psi0_1, psi0_2, psi0_3])

@jit
def evolve_single(psi0):
    psi = psi0
    for step in range(num_steps):
        psi = solver.step(psi, step * dt, dt)
    return psi

# Vectorized version
evolve_batch = vmap(evolve_single)
final_states = evolve_batch(initial_states)
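
Note that the Python for loop inside evolve_single is unrolled during tracing, so compilation time grows with num_steps; jax.lax.fori_loop keeps the loop inside the compiled program instead. A minimal sketch, assuming the same solver, dt, num_steps, and initial_states as above:

import jax

@jax.jit
def evolve_single_loop(psi0):
    def body(step, psi):
        # One time step; the loop itself runs inside XLA
        return solver.step(psi, step * dt, dt)
    return jax.lax.fori_loop(0, num_steps, body, psi0)

final_states = jax.vmap(evolve_single_loop)(initial_states)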

Precision Trade-offs

float32/complex64 vs float64/complex128:

# Default: complex64 (recommended)
x, dx = create_grid_1d(-10, 10, 1024, dtype=jnp.float32)
psi = psi.astype(jnp.complex64)

# High precision: complex128 (2× memory, slower)
x, dx = create_grid_1d(-10, 10, 1024, dtype=jnp.float64)
psi = psi.astype(jnp.complex128)

Trade-offs:

  • complex64: 2× faster, 2× less memory, ~7 digits precision

  • complex128: Higher accuracy, needed for very long simulations

For most quantum simulations, complex64 is sufficient.
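
Note that JAX uses 32-bit types by default and will downcast float64/complex128 requests (with a warning) unless 64-bit support is enabled before arrays are created:

import jax

# Must be set before any arrays are created
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.zeros(4, dtype=jnp.complex128)
print(x.dtype)  # complex128 (would be complex64 without the flag)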

Memory Management

GPU Memory

Check GPU memory:

import jax

# GPU memory info (jax.devices('gpu') raises RuntimeError if no GPU backend)
try:
    device = jax.devices('gpu')[0]
    print(f"Device: {device}")
except RuntimeError:
    print("No GPU backend available")
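
On recent JAX releases, GPU device objects also expose a memory_stats() method with backend-specific counters (the exact keys vary by platform); a rough usage sketch, continuing from the device obtained above:

# memory_stats() returns a dict of backend-specific counters
# (may be empty or None on some platforms)
stats = device.memory_stats()
if stats:
    print(f"bytes in use: {stats.get('bytes_in_use', 'n/a')}")
    print(f"bytes limit:  {stats.get('bytes_limit', 'n/a')}")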

Clear compilation caches:

import jax

# Drop cached JIT executables (jax.clear_backends() is deprecated in newer
# JAX releases); memory held by live arrays is freed when their Python
# references are dropped
jax.clear_caches()

Out of Memory Solutions

If you encounter an out-of-memory error from XLA (typically reported as RESOURCE_EXHAUSTED: Out of memory):

  1. Reduce grid size:

    # Instead of 2048×2048
    X, Y, dx, dy = create_grid_2d(-500, 500, 1024, -500, 500, 1024)
    
  2. Use float32:

    dtype = jnp.float32  # Instead of float64
    
  3. Process in chunks:

    # Split time evolution into batches
    for batch in range(num_batches):
        for step in range(steps_per_batch):
            psi = solver.step(psi, (batch * steps_per_batch + step) * dt, dt)
    
        # Save checkpoint
        jnp.save(f'checkpoint_{batch}.npy', psi)
    
        # Clear cache
        jax.clear_caches()
    
  4. Disable memory preallocation (by default JAX preallocates a large fraction of GPU memory at startup):

    export XLA_PYTHON_CLIENT_PREALLOCATE=false
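
Such environment variables only take effect if they are set before JAX initializes its backend; one option, sketched below, is to set them from Python before the first jax import:

import os

# Must run before jax is imported anywhere in the process
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp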
    

FFT Optimization

Grid Size Selection

FFT is fastest for sizes that are powers of 2:

# Good choices (powers of 2)
good_sizes = [256, 512, 1024, 2048, 4096, 8192]

# Typically slower (not powers of 2)
avoid_sizes = [1000, 1500, 2000, 3000]

Benchmark different sizes:

import time
import jax.numpy as jnp

for n in [1000, 1024, 2000, 2048]:
    x = jnp.ones((n, n), dtype=jnp.complex64)

    # Warmup (also triggers compilation)
    jnp.fft.fft2(x).block_until_ready()

    # Benchmark; block_until_ready() accounts for JAX's asynchronous dispatch
    start = time.time()
    for _ in range(100):
        jnp.fft.fft2(x).block_until_ready()
    elapsed = time.time() - start

    print(f"n={n}: {elapsed/100*1000:.2f} ms/FFT")

In-Place Operations

JAX arrays are immutable, so there are no in-place operations; however, rebinding a name releases the previous buffer, and under jit XLA can reuse buffers across the chain:

# Efficient: each rebinding of psi lets the previous buffer be freed or reused
psi = psi * absorption_mask
psi = jnp.fft.fftn(psi)
psi = psi * kinetic_operator
psi = jnp.fft.ifftn(psi)
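
Wrapping the whole chain in one jitted function lets XLA plan buffer reuse across the steps and removes per-operation dispatch overhead; the sketch below uses placeholder arrays in place of the real absorption mask and kinetic phase factor:

import jax
import jax.numpy as jnp

@jax.jit
def kinetic_substep(psi, absorption_mask, kinetic_phase):
    # One absorb -> FFT -> multiply -> inverse FFT chain, compiled as a unit
    psi = psi * absorption_mask
    psi = jnp.fft.fftn(psi)
    psi = psi * kinetic_phase
    return jnp.fft.ifftn(psi)

# Placeholder arrays, for illustration only
n = 256
psi = jnp.ones((n, n), dtype=jnp.complex64)
mask = jnp.ones((n, n), dtype=jnp.complex64)
phase = jnp.ones((n, n), dtype=jnp.complex64)
psi = kinetic_substep(psi, mask, phase)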

Parallel Processing

Multiple GPUs

For systems with multiple GPUs:

import jax
from jax import pmap

# Distribute across GPUs
devices = jax.devices('gpu')
print(f"Available GPUs: {len(devices)}")

# Parallel map across devices
@pmap
def evolve_parallel(psi0):
    psi = psi0
    for step in range(num_steps):
        psi = solver.step(psi, step * dt, dt)
    return psi

# Split initial conditions across GPUs
initial_states = jnp.array([psi0_1, psi0_2, psi0_3, psi0_4])
final_states = evolve_parallel(initial_states)
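
pmap maps the leading axis of its inputs one-to-one onto devices, so the batch size must match the number of available devices; jax.local_device_count() can be used to size the batch (psi0_batch below is an illustrative placeholder):

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Illustrative placeholder: one 1D initial state per device
psi0_batch = jnp.ones((n_dev, 4096), dtype=jnp.complex64)

print(f"Mapping {psi0_batch.shape[0]} states onto {n_dev} device(s)")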

CPU Parallelization

JAX's CPU backend automatically uses multiple cores; thread usage can be tuned with environment variables:

# Control number of threads
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true"
export OMP_NUM_THREADS=8
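
By default XLA exposes the host as a single CPU device, so pmap cannot spread work across cores; the device count can be raised with the xla_force_host_platform_device_count flag, which must be set before the first jax import. A minimal sketch:

import os

# Expose 8 virtual CPU devices (set before the first jax import)
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax

print(jax.devices("cpu"))  # should now list 8 CPU devices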