Performance Optimization
=========================

This guide provides best practices and optimization strategies for achieving maximum performance with **schr** on both CPUs and GPUs.

.. contents:: Table of Contents
   :local:
   :depth: 2

Hardware Considerations
------------------------

GPU vs CPU
~~~~~~~~~~~

**When to use GPU:**

* Grid size >= 512 in any dimension
* Long simulations (thousands of time steps)
* Parameter sweeps or batch processing
* 2D/3D systems

**When CPU is sufficient:**

* 1D systems with N < 4096
* Quick prototyping and debugging
* Systems without NVIDIA GPUs
* Single-shot simulations

Performance Scaling
~~~~~~~~~~~~~~~~~~~~

Typical speedups (GPU vs CPU):

.. list-table::
   :header-rows: 1
   :widths: 30 30 40

   * - Problem Size
     - Speedup
     - Notes
   * - 1D, N=1024
     - 2-5×
     - Small problem; overhead-dominated
   * - 1D, N=8192
     - 10-20×
     - FFT-dominated
   * - 2D, 512×512
     - 20-50×
     - Good GPU utilization
   * - 2D, 1024×1024
     - 50-100×
     - Optimal for GPU
   * - 3D, 256×256×256
     - 100-200×
     - Best GPU advantage

Memory Requirements
~~~~~~~~~~~~~~~~~~~~

**RAM/VRAM usage:**

For complex64 arrays (8 bytes per element) on a :math:`d`-dimensional grid with :math:`N` points per dimension:

.. math::

   \text{Memory} = 8N^d \text{ bytes}

.. list-table::
   :header-rows: 1
   :widths: 30 35 35

   * - Grid Size
     - Memory (complex64)
     - Memory (complex128)
   * - 1D: 8192
     - 64 KB
     - 128 KB
   * - 2D: 1024×1024
     - 8 MB
     - 16 MB
   * - 2D: 2048×2048
     - 32 MB
     - 64 MB
   * - 3D: 256×256×256
     - 128 MB
     - 256 MB
   * - 3D: 512×512×512
     - 1 GB
     - 2 GB

Add ~2-3× for intermediate arrays created during computation.

JAX Optimization
-----------------

JIT Compilation
~~~~~~~~~~~~~~~~

**Just-In-Time (JIT) compilation** is crucial for performance:

.. code-block:: python

   from jax import jit
   from schr.qm.solvers import SplitStepFourier

   solver = SplitStepFourier(hamiltonian)

   # Without JIT (slow)
   for step in range(1000):
       psi = solver.step(psi, step * dt, dt)

   # With JIT (fast)
   @jit
   def evolve_step(psi, t, dt):
       return solver.step(psi, t, dt)

   for step in range(1000):
       psi = evolve_step(psi, step * dt, dt)

**The first call is slow** (it triggers compilation); subsequent calls reuse the compiled code and are fast.

Static Arguments
~~~~~~~~~~~~~~~~~

Mark arguments that control the structure of the computation (such as Python flags) as static so JIT can specialize on them:

.. code-block:: python

   from functools import partial

   from jax import jit

   @partial(jit, static_argnums=(2,))
   def evolve_with_absorption(psi, t, apply_absorption):
       psi = solver.step(psi, t, dt)
       if apply_absorption:
           psi = psi * absorption_mask
       return psi

Vectorization
~~~~~~~~~~~~~~

Use ``vmap`` for batch processing:

.. code-block:: python

   import jax.numpy as jnp
   from jax import jit, vmap

   # Evolve multiple initial conditions in parallel
   initial_states = jnp.array([psi0_1, psi0_2, psi0_3])

   @jit
   def evolve_single(psi0):
       psi = psi0
       for step in range(num_steps):
           psi = solver.step(psi, step * dt, dt)
       return psi

   # Vectorized version
   evolve_batch = vmap(evolve_single)
   final_states = evolve_batch(initial_states)

Precision Trade-offs
~~~~~~~~~~~~~~~~~~~~~

**float32/complex64 vs float64/complex128:**

.. code-block:: python

   import jax.numpy as jnp

   # Default: complex64 (recommended)
   x, dx = create_grid_1d(-10, 10, 1024, dtype=jnp.float32)
   psi = psi.astype(jnp.complex64)

   # High precision: complex128 (2× memory, slower)
   x, dx = create_grid_1d(-10, 10, 1024, dtype=jnp.float64)
   psi = psi.astype(jnp.complex128)

**Trade-offs:**

* **complex64**: ~2× faster, half the memory, ~7 significant digits of precision
* **complex128**: higher accuracy, needed for very long simulations

For most quantum simulations, **complex64 is sufficient**.
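Note that 64-bit types also have to be enabled in JAX itself: by default JAX keeps arrays in 32-bit types even when a 64-bit dtype is requested, so the complex128 lines above only take effect once x64 mode is switched on. A minimal sketch (the array at the end is only for illustration):

.. code-block:: python

   import jax

   # Enable 64-bit (float64/complex128) support globally; without this flag,
   # JAX keeps arrays in 32-bit types even if a 64-bit dtype is requested.
   jax.config.update("jax_enable_x64", True)

   import jax.numpy as jnp

   psi = jnp.zeros(1024, dtype=jnp.complex128)
   print(psi.dtype)  # complex128 only when x64 mode is enabled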
Memory Management
------------------

GPU Memory
~~~~~~~~~~~

**Check GPU memory:**

.. code-block:: python

   import jax

   # GPU device info; jax.devices('gpu') raises RuntimeError
   # if no GPU backend is available
   try:
       device = jax.devices('gpu')[0]
       print(f"Device: {device}")
   except RuntimeError:
       print("No GPU available")

   # Memory usage varies by GPU

**Clear GPU memory:**

.. code-block:: python

   import jax

   # Clear compiled functions and free GPU memory
   jax.clear_backends()

Out of Memory Solutions
~~~~~~~~~~~~~~~~~~~~~~~~

If XLA reports an out-of-memory error:

1. **Reduce the grid size:**

   .. code-block:: python

      # Instead of 2048×2048
      X, Y, dx, dy = create_grid_2d(-500, 500, 1024, -500, 500, 1024)

2. **Use float32:**

   .. code-block:: python

      dtype = jnp.float32  # Instead of float64

3. **Process in chunks:**

   .. code-block:: python

      # Split time evolution into batches
      for batch in range(num_batches):
          for step in range(steps_per_batch):
              psi = solver.step(psi, (batch * steps_per_batch + step) * dt, dt)

          # Save checkpoint
          jnp.save(f'checkpoint_{batch}.npy', psi)

          # Clear cache
          jax.clear_caches()

4. **Disable memory preallocation** so JAX allocates GPU memory on demand instead of reserving most of it up front:

   .. code-block:: bash

      export XLA_PYTHON_CLIENT_PREALLOCATE=false

FFT Optimization
-----------------

Grid Size Selection
~~~~~~~~~~~~~~~~~~~~

**FFT is fastest for sizes that are powers of 2:**

.. code-block:: python

   # Good choices (powers of 2)
   good_sizes = [256, 512, 1024, 2048, 4096, 8192]

   # Slower choices (not powers of 2)
   avoid_sizes = [1000, 1500, 2000, 3000]

**Benchmark different sizes:**

.. code-block:: python

   import time

   import jax.numpy as jnp

   for n in [1000, 1024, 2000, 2048]:
       x = jnp.ones((n, n), dtype=jnp.complex64)

       # Warmup (triggers compilation)
       jnp.fft.fft2(x).block_until_ready()

       # Benchmark; JAX dispatches asynchronously, so block on the last result
       start = time.time()
       for _ in range(100):
           y = jnp.fft.fft2(x)
       y.block_until_ready()
       elapsed = time.time() - start
       print(f"n={n}: {elapsed/100*1000:.2f} ms/FFT")

In-Place Operations
~~~~~~~~~~~~~~~~~~~~

JAX arrays are immutable, so there are no in-place operations; inside a jitted function, however, XLA fuses chained operations and reuses buffers, so rebinding ``psi`` at each step is efficient:

.. code-block:: python

   # Efficient under jit (XLA reuses buffers)
   psi = psi * absorption_mask
   psi = jnp.fft.fftn(psi)
   psi = psi * kinetic_operator
   psi = jnp.fft.ifftn(psi)

Parallel Processing
--------------------

Multiple GPUs
~~~~~~~~~~~~~~

For systems with multiple GPUs:

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax import pmap

   # Distribute across GPUs
   devices = jax.devices('gpu')
   print(f"Available GPUs: {len(devices)}")

   # Parallel map across devices
   @pmap
   def evolve_parallel(psi0):
       psi = psi0
       for step in range(num_steps):
           psi = solver.step(psi, step * dt, dt)
       return psi

   # Split initial conditions across GPUs
   # (the leading axis must match the number of devices)
   initial_states = jnp.array([psi0_1, psi0_2, psi0_3, psi0_4])
   final_states = evolve_parallel(initial_states)

CPU Parallelization
~~~~~~~~~~~~~~~~~~~~

JAX automatically uses multiple CPU cores:

.. code-block:: bash

   # Control the number of threads
   export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true"
   export OMP_NUM_THREADS=8
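``pmap``-style parallelism can also be tried out on a machine without multiple GPUs by asking XLA to expose several virtual CPU devices. A minimal sketch, assuming a CPU-only process and a toy per-device computation standing in for a real solver step (the flag must be set before JAX is imported, and this assignment overwrites any existing ``XLA_FLAGS``):

.. code-block:: python

   import os

   # Expose 4 virtual CPU devices; must be set before importing jax
   os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

   import jax
   import jax.numpy as jnp
   from jax import pmap

   print(jax.devices())  # 4 CPU devices

   @pmap
   def scale(psi):
       # Toy computation standing in for a solver step
       return 2.0 * psi

   # Leading axis must equal the number of devices
   batch = jnp.ones((4, 1024), dtype=jnp.complex64)
   print(scale(batch).shape)  # (4, 1024)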