schr.utils package

Submodules

Module contents

Utility functions for schr package.

This module provides helper functions for grid generation, FFT operations, absorbing boundaries, visualization, and file I/O.

schr.utils.animate_wavefunction(x: Array, psi_frames: list[Array], dt: float, title: str = '', show_real_imag: bool = True, show_probability: bool = True, interval: int = 50, figsize: tuple[float, float] = (12, 5), units: str = 'a.u.', setup_callback: Callable | None = None, update_callback: Callable | None = None) → FuncAnimation

Create publication-quality animation of time-evolving wavefunction.

Parameters:
  • x – Spatial grid (1D array)

  • psi_frames – List of wavefunction snapshots (complex)

  • dt – Time step between frames (a.u.)

  • title – Animation title

  • show_real_imag – Show Re(\(\psi\)) and Im(\(\psi\))

  • show_probability – Show \(|\psi|^2\)

  • interval – Time between frames in milliseconds

  • figsize – Figure size (width, height) in inches

  • units – Physical units label (default: “a.u.”)

  • setup_callback – Optional callable with signature (fig, axes, x, psi0) -> dict. Called once to set up custom visualization elements; it should return a dict of the objects to update during the animation.

  • update_callback – Optional callable with signature (frame_idx, psi, t, objects) -> None. Called on every frame to update custom elements; objects is the dict returned by setup_callback.

Returns:

Matplotlib FuncAnimation object

Example

>>> x, dx = create_grid_1d(-10, 10, 512)
>>> psi_frames = []  # List of wavefunctions at different times
>>> anim = animate_wavefunction(x, psi_frames, dt=0.01)
>>> plt.show()
>>> # With custom barrier visualization
>>> def setup(fig, axes, x, psi0):
...     if not isinstance(axes, list):
...         axes = [axes]
...     # Shade the barrier on every panel and keep all handles, not just the last one
...     barriers = [ax.axvspan(8, 12, alpha=0.2, color='red') for ax in axes]
...     return {'barriers': barriers}
>>> anim = animate_wavefunction(x, psi_frames, dt=0.01,
...                             setup_callback=setup)
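
The example above leaves psi_frames empty. The following is a minimal sketch of how snapshots might be produced. It evolves a free Gaussian wavepacket by applying the exact kinetic propagator in Fourier space. Several choices are assumptions of this sketch rather than schr's solver: NumPy arrays, atomic units with \(\hbar = m = 1\), and an endpoint-free periodic grid.

```python
import numpy as np

# Periodic grid without the duplicated endpoint (assumption for this sketch)
x = np.linspace(-10.0, 10.0, 512, endpoint=False)
dx = x[1] - x[0]
k = 2 * np.pi * np.fft.fftfreq(x.size, d=dx)  # angular wavenumbers

# Normalized Gaussian wavepacket centred at x0 with mean wavenumber k0
x0, k0, sigma = -3.0, 2.0, 1.0
psi = np.exp(-(x - x0) ** 2 / (2 * sigma**2) + 1j * k0 * x)
psi /= np.sqrt(np.sum(np.abs(psi) ** 2) * dx)

dt, n_frames, steps_per_frame = 0.01, 100, 5
kinetic_phase = np.exp(-1j * (k**2 / 2) * dt)  # exact free-particle step in k-space

psi_frames = [psi.copy()]
for _ in range(n_frames - 1):
    for _ in range(steps_per_frame):
        psi = np.fft.ifft(kinetic_phase * np.fft.fft(psi))
    psi_frames.append(psi.copy())
```

Stored frames are steps_per_frame * dt apart. The matching call would therefore pass dt=dt * steps_per_frame to animate_wavefunction, possibly after converting the arrays with jnp.asarray.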

schr.utils.animate_wavefunction_2d(X: Array, Y: Array, psi_frames: list[Array], times: list[float], figsize: tuple[float, float] = (10, 8), plot_type: str = 'probability', cmap: str = 'inferno', units: str = 'a.u.', scale_factor: float = 1.0, scale_units: str | None = None, interval: int = 50, title: str = '', setup_callback: Callable | None = None, update_callback: Callable | None = None) → FuncAnimation

Animate 2D wavefunction evolution with optional custom setup/update.

Parameters:
  • X – X-coordinate grid (2D array)

  • Y – Y-coordinate grid (2D array)

  • psi_frames – List of wavefunction snapshots (complex 2D arrays)

  • times – List of time values for each frame

  • figsize – Figure size (width, height) in inches

  • plot_type – “probability”, “real”, “imag”, or “phase”

  • cmap – Colormap name (suggested: inferno for probability, RdBu_r for real/imag, twilight for phase)

  • units – Physical units label (default: “a.u.”)

  • scale_factor – Coordinate scaling (e.g., 1000 for nm→μm)

  • scale_units – Scaled units label (e.g., “μm”)

  • interval – Time between frames (ms)

  • title – Animation title

  • setup_callback – Optional callable with signature (fig, axes, X, Y, psi0) -> dict. Called once to set up custom visualization elements; it should return a dict of the objects to update during the animation.

  • update_callback – Optional callable with signature (frame_idx, psi, t, objects) -> None. Called on every frame to update custom elements; objects is the dict returned by setup_callback.

Returns:

Matplotlib FuncAnimation object

Example

>>> # Simple usage
>>> anim = animate_wavefunction_2d(X, Y, psi_frames, times)
>>> # With custom barrier visualization
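>>> # V: potential on the same (X, Y) grid (not defined here); barrier where V > 0
>>> def setup(fig, axes, X, Y, psi0):
...     ax = axes[0]
...     barrier_contour = ax.contour(X, Y, (V > 0).astype(float), levels=[0.5], colors='cyan')
...     return {'barrier': barrier_contour}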
>>> anim = animate_wavefunction_2d(X, Y, psi_frames, times,
...                                setup_callback=setup)

schr.utils.clean_old_simulations(keep_recent: int = 10) → list[Path]

Clean old simulation directories, keeping only the most recent ones.

Parameters:

keep_recent – Number of recent simulations to keep (default: 10).

Returns:

List of paths that were removed.

Example

>>> removed = clean_old_simulations(keep_recent=5)
>>> print(f"Removed {len(removed)} old simulations")
Removed 3 old simulations
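
The following is a minimal sketch of the pruning logic. It rests on three assumptions: "most recent" means the newest modification time; the root directory is passed explicitly instead of being ~/.schr; and the cache and data directories are never deleted. The function name and the RESERVED set are hypothetical.

```python
import shutil
from pathlib import Path

# Hypothetical: directory names used by get_cache_dir()/get_data_dir(), never deleted
RESERVED = {"cache", "data"}


def clean_old_simulations_sketch(root: Path, keep_recent: int = 10) -> list[Path]:
    """Delete all but the `keep_recent` most recently modified simulation dirs under `root`."""
    if keep_recent < 0:
        raise ValueError("keep_recent must be non-negative")
    sim_dirs = [p for p in root.iterdir() if p.is_dir() and p.name not in RESERVED]
    # Newest first, ranked by last-modification time
    sim_dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    removed = []
    for old in sim_dirs[keep_recent:]:
        shutil.rmtree(old)
        removed.append(old)
    return removed
```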

schr.utils.complex_absorbing_potential(coords: Array | tuple[Array, ...], domain: tuple[float, float] | tuple[tuple[float, float], ...], width: float | tuple[float, ...], strength: float = 1.0, order: int = 2, dtype: numpy.dtype = jax.numpy.complex64) → Array

Create complex absorbing potential (CAP).

Generates a purely imaginary potential \(-i\eta f(x)\) that acts as an absorber when added to the Hamiltonian. Its negative imaginary part makes the wavefunction amplitude decay exponentially inside the absorption layer.

Compared with simple multiplicative masks, a CAP acts through the Hamiltonian itself. With a smooth profile, it absorbs waves entering the layer while minimizing reflections.

Parameters:
  • coords – Grid coordinates (see polynomial_absorbing_mask).

  • domain – Domain boundaries (see polynomial_absorbing_mask).

  • width – CAP layer width.

  • strength – CAP strength parameter η. Typical: 0.1-2.0.

  • order – Polynomial order for CAP profile. Typical: 2-4.

  • dtype – JAX complex dtype for the potential.

Returns:

Complex absorbing potential (purely imaginary, with non-positive imaginary part): 0 in the interior and \(-i\eta f(x)\) in the absorption layer.

Example

>>> # Use in Hamiltonian: V_total = V_physical + CAP
>>> X, Y = jnp.meshgrid(jnp.linspace(-5, 5, 256), jnp.linspace(-3, 3, 128))
>>> cap = complex_absorbing_potential((X, Y), ((-5, 5), (-3, 3)), 1.0, strength=0.5)
>>> V_total = V_physical + cap
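
To see why the negative imaginary part absorbs, consider one potential step of a split-operator scheme, \(\psi \leftarrow e^{-iV\,dt}\psi\). With \(V = -i\eta f(x)\), this factor becomes \(e^{-\eta f(x)\,dt}\), a real damping factor no larger than 1. The 1D sketch below assumes the profile \(f = s^{\text{order}}\), where \(s\) is the normalized depth into the layer. The exact profile schr uses is not stated here, and the function name is hypothetical.

```python
import numpy as np


def cap_1d_sketch(x, domain, width, strength=1.0, order=2):
    """Hypothetical 1D CAP: -1j * strength * s**order in the layers, 0 in the interior.

    s is the normalized depth into the absorbing layer: 0 at its inner edge, 1 at the boundary.
    """
    x_min, x_max = domain
    depth_left = (x_min + width - x) / width     # positive only inside the left layer
    depth_right = (x - (x_max - width)) / width  # positive only inside the right layer
    s = np.clip(np.maximum(depth_left, depth_right), 0.0, 1.0)
    return -1j * strength * s**order


x = np.linspace(-10.0, 10.0, 512)
cap = cap_1d_sketch(x, (-10.0, 10.0), width=2.0, strength=0.5)

# One potential sub-step exp(-i V dt); with V = -i*eta*f it equals exp(-eta*f*dt), a pure damping
dt = 0.01
damping = np.exp(-1j * cap * dt)
```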

schr.utils.create_absorbing_boundary(coords: Array | tuple[Array, ...], domain: tuple[float, float] | tuple[tuple[float, float], ...], width: float | tuple[float, ...], method: Literal['polynomial', 'exponential', 'cap', 'mask_potential'] = 'polynomial', **kwargs) → Array

Create absorbing boundary using specified method.

Convenience function that dispatches to the appropriate absorbing boundary implementation.

Parameters:
  • coords – Grid coordinates.

  • domain – Domain boundaries.

  • width – Absorption layer width.

  • method – Absorbing boundary method:
      - “polynomial”: polynomial mask (\(\cos^n\) profile)
      - “exponential”: exponential mask
      - “cap”: complex absorbing potential
      - “mask_potential”: real absorbing potential

  • **kwargs – Additional method-specific parameters:
      - order: polynomial order (polynomial, cap, mask_potential)
      - strength: absorption strength (exponential, cap, mask_potential)
      - dtype: data type

Returns:

Absorbing boundary: a multiplicative mask for “polynomial” and “exponential”, or a potential to add to the Hamiltonian for “cap” and “mask_potential”.

Example

>>> # Polynomial mask (default)
>>> X, Y = jnp.meshgrid(jnp.linspace(-5, 5, 256), jnp.linspace(-3, 3, 128))
>>> mask = create_absorbing_boundary((X, Y), ((-5, 5), (-3, 3)), 1.0)
>>>
>>> # Complex absorbing potential
>>> cap = create_absorbing_boundary(
...     (X, Y), ((-5, 5), (-3, 3)), 1.0,
...     method="cap", strength=0.5
... )

schr.utils.create_absorption_mask(X: Array, Y: Array, x_range: tuple[float, float], y_range: tuple[float, float], width: float = 1000.0, strength: float = 1.0, order: int = 4) → Array

Legacy interface for 2D polynomial absorbing mask.

Provided for backward compatibility with existing code.

Parameters:
  • X – 2D x-coordinate grid.

  • Y – 2D y-coordinate grid.

  • x_range – (x_min, x_max) tuple.

  • y_range – (y_min, y_max) tuple.

  • width – Absorption region width.

  • strength – Unused (kept for compatibility).

  • order – Polynomial order.

Returns:

2D absorption mask.

Example

>>> X, Y = jnp.meshgrid(jnp.linspace(-5, 5, 256), jnp.linspace(-3, 3, 128))
>>> mask = create_absorption_mask(X, Y, (-5, 5), (-3, 3), width=1.0)

schr.utils.create_grid_1d(x_min: float, x_max: float, nx: int, dtype: numpy.dtype = jax.numpy.float32) → tuple[Array, float]

Create 1D spatial grid.

Parameters:
  • x_min – Minimum x coordinate (a.u.).

  • x_max – Maximum x coordinate (a.u.).

  • nx – Number of grid points.

  • dtype – JAX dtype (default: float32).

Returns:

Grid points and spacing (a.u.).

Return type:

Tuple (x, dx)

Raises:

ValueError – If nx < 2 or x_max <= x_min.
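
The nx < 2 check suggests an endpoint-inclusive grid with dx = (x_max - x_min)/(nx - 1), as produced by np.linspace; that reading is an assumption. The following is a minimal NumPy sketch under it (the function name is hypothetical):

```python
import numpy as np


def create_grid_1d_sketch(x_min, x_max, nx, dtype=np.float32):
    """Endpoint-inclusive uniform grid with the validation rules documented above."""
    if nx < 2:
        raise ValueError(f"nx must be >= 2, got {nx}")
    if x_max <= x_min:
        raise ValueError("x_max must be greater than x_min")
    x = np.linspace(x_min, x_max, nx, dtype=dtype)
    dx = (x_max - x_min) / (nx - 1)  # spacing between neighbouring points
    return x, dx


x, dx = create_grid_1d_sketch(-10.0, 10.0, 5)
```

For strictly periodic FFT work, an endpoint-exclusive grid (np.linspace(..., endpoint=False) with dx = (x_max - x_min)/nx) avoids storing the same point twice. Which convention schr uses is not specified here.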

schr.utils.create_grid_2d(x_min: float, x_max: float, nx: int, y_min: float, y_max: float, ny: int, dtype: numpy.dtype = jax.numpy.float32) → tuple[Array, Array, float, float]

Create 2D spatial grid.

Parameters:
  • x_min – Minimum x coordinate (a.u.).

  • x_max – Maximum x coordinate (a.u.).

  • nx – Number of x grid points.

  • y_min – Minimum y coordinate (a.u.).

  • y_max – Maximum y coordinate (a.u.).

  • ny – Number of y grid points.

  • dtype – JAX dtype (default: float32).

Returns:

2D meshgrid arrays (shape: ny × nx) and spacings (a.u.).

Return type:

Tuple (X, Y, dx, dy)

Raises:

ValueError – If grid parameters are invalid.
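
The (ny × nx) shape follows from np.meshgrid's default indexing="xy", which places y along axis 0 and x along axis 1. The following sketch assumes endpoint-inclusive spacing; the function name is hypothetical.

```python
import numpy as np


def create_grid_2d_sketch(x_min, x_max, nx, y_min, y_max, ny, dtype=np.float32):
    """2D grid with arrays of shape (ny, nx); endpoint-inclusive spacing is assumed."""
    for lo, hi, n, name in ((x_min, x_max, nx, "x"), (y_min, y_max, ny, "y")):
        if n < 2 or hi <= lo:
            raise ValueError(f"invalid {name} grid parameters")
    x = np.linspace(x_min, x_max, nx, dtype=dtype)
    y = np.linspace(y_min, y_max, ny, dtype=dtype)
    X, Y = np.meshgrid(x, y, indexing="xy")  # y varies along axis 0, x along axis 1
    dx = (x_max - x_min) / (nx - 1)
    dy = (y_max - y_min) / (ny - 1)
    return X, Y, dx, dy


X, Y, dx, dy = create_grid_2d_sketch(-5.0, 5.0, 256, -3.0, 3.0, 128)
```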

schr.utils.create_grid_3d(x_min: float, x_max: float, nx: int, y_min: float, y_max: float, ny: int, z_min: float, z_max: float, nz: int, dtype: numpy.dtype = jax.numpy.float32) → tuple[Array, Array, Array, float, float, float]

Create 3D spatial grid.

Parameters:
  • x_min – Minimum x coordinate (a.u.).

  • x_max – Maximum x coordinate (a.u.).

  • nx – Number of x grid points.

  • y_min – Minimum y coordinate (a.u.).

  • y_max – Maximum y coordinate (a.u.).

  • ny – Number of y grid points.

  • z_min – Minimum z coordinate (a.u.).

  • z_max – Maximum z coordinate (a.u.).

  • nz – Number of z grid points.

  • dtype – JAX dtype (default: float32).

Returns:

3D meshgrid arrays (shape: nz × ny × nx) and spacings (a.u.).

Return type:

Tuple (X, Y, Z, dx, dy, dz)

Raises:

ValueError – If grid parameters are invalid.
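
For the (nz × ny × nx) layout, np.meshgrid needs indexing="ij" with the axes passed in reverse order (z, y, x); indexing="xy" would instead give shape (ny, nx, nz). The following NumPy sketch assumes endpoint-inclusive spacing; the function name is hypothetical.

```python
import numpy as np


def create_grid_3d_sketch(x_min, x_max, nx, y_min, y_max, ny, z_min, z_max, nz, dtype=np.float32):
    """3D grid with arrays of shape (nz, ny, nx); endpoint-inclusive spacing is assumed."""
    axes, spacings = [], []
    for lo, hi, n in ((x_min, x_max, nx), (y_min, y_max, ny), (z_min, z_max, nz)):
        if n < 2 or hi <= lo:
            raise ValueError("invalid grid parameters")
        axes.append(np.linspace(lo, hi, n, dtype=dtype))
        spacings.append((hi - lo) / (n - 1))
    x, y, z = axes
    # "ij" indexing with (z, y, x) order puts z on axis 0 and x on the last axis
    Z, Y, X = np.meshgrid(z, y, x, indexing="ij")
    return (X, Y, Z, *spacings)


X, Y, Z, dx, dy, dz = create_grid_3d_sketch(0, 1, 2, 0, 2, 3, 0, 3, 4)
```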

schr.utils.exponential_absorbing_mask(coords: Array | tuple[Array, ...], domain: tuple[float, float] | tuple[tuple[float, float], ...], width: float | tuple[float, ...], strength: float = 5.0, dtype: numpy.dtype = jax.numpy.float32) → Array

Create exponential absorbing boundary mask.

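Uses the exponential profile \(\exp[-\text{strength}\,(1 - t)^2]\), where \(t \in [0, 1]\) is the normalized distance from the domain boundary (\(t = 0\) at the boundary, \(t = 1\) at the inner edge of the absorption layer). This gives very strong absorption near the edges.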

Parameters:
  • coords – Grid coordinates (see polynomial_absorbing_mask).

  • domain – Domain boundaries (see polynomial_absorbing_mask).

  • width – Absorption layer width.

  • strength – Absorption strength parameter (higher = stronger). Typical: 3-10.

  • dtype – JAX dtype for the mask.

Returns:

Absorption mask with values in (0, 1]: 1 in the interior, falling to \(e^{-\text{strength}}\) at the boundary.

Example

>>> X, Y = jnp.meshgrid(jnp.linspace(-5, 5, 256), jnp.linspace(-3, 3, 128))
>>> mask = exponential_absorbing_mask((X, Y), ((-5, 5), (-3, 3)), 1.0, strength=7.0)
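
The following 1D sketch implements the profile above. It assumes t is the distance to the nearest boundary divided by width, clipped to [0, 1]. The mask is therefore 1 in the interior and \(e^{-\text{strength}}\), not exactly 0, at the edge. The function name is hypothetical.

```python
import numpy as np


def exponential_mask_1d_sketch(x, domain, width, strength=5.0):
    """exp(-strength * (1 - t)**2) with t = normalized distance from the nearest boundary."""
    x_min, x_max = domain
    t = np.clip(np.minimum(x - x_min, x_max - x) / width, 0.0, 1.0)
    return np.exp(-strength * (1.0 - t) ** 2)


x = np.linspace(-10.0, 10.0, 2001)
mask = exponential_mask_1d_sketch(x, (-10.0, 10.0), width=2.0, strength=7.0)
```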

schr.utils.fft_derivative(f: Array, dx: float, axis: int = -1, order: int = 1) → Array

Compute spectral derivative using FFT.

Achieves spectral accuracy for smooth, periodic functions. In Fourier space, differentiation becomes multiplication:

  • 1st order: multiply by \(ik\)

  • 2nd order: multiply by \(-k^2\)

Parameters:
  • f – Function values on uniform grid.

  • dx – Grid spacing (a.u.).

  • axis – Axis for derivative (default: -1 = last axis).

  • order – Derivative order (1 or 2).

Returns:

\(\partial^n f/\partial x^n\).

Raises:

ValueError – If order not in {1, 2}.
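
The following NumPy sketch implements the multiplication rule above. The function name is hypothetical. It uses angular wavenumbers k = 2π · np.fft.fftfreq(n, d=dx). For brevity, it does not special-case the Nyquist mode for first derivatives on even-length grids.

```python
import numpy as np


def fft_derivative_sketch(f, dx, axis=-1, order=1):
    """Spectral derivative of a periodic, uniformly sampled array along `axis`."""
    if order not in (1, 2):
        raise ValueError(f"order must be 1 or 2, got {order}")
    n = f.shape[axis]
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)  # angular wavenumbers (rad per unit length)
    shape = [1] * f.ndim
    shape[axis] = n
    k = k.reshape(shape)  # broadcast k along the differentiation axis only
    multiplier = 1j * k if order == 1 else -(k**2)
    return np.fft.ifft(multiplier * np.fft.fft(f, axis=axis), axis=axis)


x = np.linspace(0.0, 2 * np.pi, 64, endpoint=False)
dx = x[1] - x[0]
df = fft_derivative_sketch(np.sin(x), dx).real            # approximately cos(x)
d2f = fft_derivative_sketch(np.sin(x), dx, order=2).real  # approximately -sin(x)
```

For real input, the result carries a negligible imaginary part, so .real is taken.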

schr.utils.fftfreq_grid(n: int, d: float, dtype: numpy.dtype = jax.numpy.float32) → Array

Create frequency grid for FFT operations.

Parameters:
  • n – Number of grid points.

  • d – Grid spacing (a.u.).

  • dtype – JAX dtype (default: float32).

Returns:

Frequency array (rad/a.u.).
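
Since the result is in rad/a.u., it is presumably the angular wavenumber \(k = 2\pi m/(n\,d)\) in standard FFT order: zero first, then the positive frequencies, then the negative ones. Under that assumption, a one-line NumPy equivalent is shown below (the function name is hypothetical):

```python
import numpy as np


def fftfreq_grid_sketch(n, d, dtype=np.float32):
    """Angular wavenumbers 2*pi*fftfreq(n, d), in FFT order, cast to `dtype`."""
    return (2 * np.pi * np.fft.fftfreq(n, d=d)).astype(dtype)


k = fftfreq_grid_sketch(8, 0.5)
```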

schr.utils.get_cache_dir(create: bool = True) → Path

Get directory for cached computations.

Parameters:

create – If True, create the directory if it doesn’t exist (default: True).

Returns:

Path to ~/.schr/cache directory.

Example

>>> cache_dir = get_cache_dir()
>>> print(cache_dir)
/Users/username/.schr/cache

schr.utils.get_data_dir(create: bool = True) → Path

Get directory for data files (potentials, initial conditions, etc.).

Parameters:

create – If True, create the directory if it doesn’t exist (default: True).

Returns:

Path to ~/.schr/data directory.

Example

>>> data_dir = get_data_dir()
>>> print(data_dir)
/Users/username/.schr/data

schr.utils.get_output_dir(simulation_name: str, create: bool = True) → Path

Get output directory for a simulation.

Returns the subdirectory ~/.schr/<simulation_name> for the simulation results, creating it if create is True.

Parameters:
  • simulation_name – Name of the simulation (e.g., “double_slit”, “tunneling”).

  • create – If True, create the directory if it doesn’t exist (default: True).

Returns:

Path to the simulation output directory.

Example

>>> output_dir = get_output_dir("double_slit")
>>> print(output_dir)
/Users/username/.schr/double_slit

schr.utils.get_output_path(simulation_name: str, filename: str, timestamp: bool = True, create_dir: bool = True) → Path

Get full path for an output file.

Parameters:
  • simulation_name – Name of the simulation.

  • filename – Name of the output file (e.g., “animation.mp4”, “final_state.npy”).

  • timestamp – If True, insert a timestamp (YYYYmmdd_HHMMSS) between the filename stem and its extension (default: True).

  • create_dir – If True, create the directory if needed (default: True).

Returns:

Full path to the output file.

Example

>>> path = get_output_path("double_slit", "animation.mp4")
>>> print(path)
/Users/username/.schr/double_slit/animation_20251104_153022.mp4
>>> path = get_output_path("double_slit", "animation.mp4", timestamp=False)
>>> print(path)
/Users/username/.schr/double_slit/animation.mp4
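
The examples imply the layout ~/.schr/<simulation_name>/<stem>_<YYYYmmdd_HHMMSS><suffix>. The following pathlib sketch implements that naming rule. The function name is hypothetical, as are the root and now parameters, which are added only so the result is testable.

```python
from datetime import datetime
from pathlib import Path


def get_output_path_sketch(simulation_name, filename, timestamp=True, create_dir=True,
                           root=None, now=None):
    """Build <root>/<simulation_name>/<filename>, optionally timestamping the file stem."""
    root = Path.home() / ".schr" if root is None else Path(root)
    out_dir = root / simulation_name
    if create_dir:
        out_dir.mkdir(parents=True, exist_ok=True)
    name = Path(filename)
    if timestamp:
        stamp = (now if now is not None else datetime.now()).strftime("%Y%m%d_%H%M%S")
        # Insert the stamp between stem and extension: animation.mp4 -> animation_<stamp>.mp4
        name = name.with_name(f"{name.stem}_{stamp}{name.suffix}")
    return out_dir / name
```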

schr.utils.get_schr_home() → Path

Get the Schr home directory (~/.schr).

Creates the directory if it doesn’t exist.

Returns:

Path to ~/.schr directory.

Example

>>> home = get_schr_home()
>>> print(home)
/Users/username/.schr

schr.utils.get_simulation_info(simulation_dir: Path) → dict

Get information about a simulation directory.

Parameters:

simulation_dir – Path to the simulation directory.

Returns:

Dictionary of simulation metadata (in the example: name, size_mb, num_files, created, modified).

Example

>>> info = get_simulation_info(Path("~/.schr/double_slit").expanduser())
>>> print(info)
{
    'name': 'double_slit',
    'size_mb': 125.3,
    'num_files': 126,
    'created': '2025-11-04 15:30:22',
    'modified': '2025-11-04 15:35:45'
}
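
The following sketch builds a dictionary shaped like the one above. It assumes size_mb counts all files recursively in units of 2**20 bytes, and that the timestamps come from the directory's own stat() record. On Linux and macOS, st_ctime is the time of the last metadata change, not the creation time, so "created" is only approximate there. The function name is hypothetical.

```python
from datetime import datetime
from pathlib import Path


def get_simulation_info_sketch(simulation_dir):
    """Summarize a simulation directory: name, total size, file count and timestamps."""
    simulation_dir = Path(simulation_dir).expanduser()
    files = [p for p in simulation_dir.rglob("*") if p.is_file()]
    total_bytes = sum(p.stat().st_size for p in files)
    st = simulation_dir.stat()
    fmt = "%Y-%m-%d %H:%M:%S"
    return {
        "name": simulation_dir.name,
        "size_mb": round(total_bytes / 2**20, 1),
        "num_files": len(files),
        "created": datetime.fromtimestamp(st.st_ctime).strftime(fmt),
        "modified": datetime.fromtimestamp(st.st_mtime).strftime(fmt),
    }
```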

schr.utils.list_simulations() → list[str]

List all simulation directories in ~/.schr.

Returns:

List of simulation directory names.

Example

>>> simulations = list_simulations()
>>> for sim in simulations:
...     print(sim)
double_slit
tunneling
quantum_vortex

schr.utils.mask_absorbing_potential(coords: Array | tuple[Array, ...], domain: tuple[float, float] | tuple[tuple[float, float], ...], width: float | tuple[float, ...], strength: float = 100.0, order: int = 2, dtype: numpy.dtype = jax.numpy.float32) → Array

Create real-valued absorbing potential (mask potential).

Similar to a CAP, but uses a large, positive, real potential instead of an imaginary one. This acts as a “soft wall” that exponentially suppresses the wavefunction inside the layer. It is simpler than a CAP but may cause more reflections.

Parameters:
  • coords – Grid coordinates.

  • domain – Domain boundaries.

  • width – Absorption layer width.

  • strength – Potential height. Very large values (50-1000) recommended.

  • order – Polynomial order for potential profile.

  • dtype – JAX dtype for the potential.

Returns:

Real absorbing potential (positive values in absorption layer).

Example

>>> x = jnp.linspace(-10, 10, 512)
>>> V_absorb = mask_absorbing_potential(x, (-10, 10), width=2.0, strength=500.0)
>>> V_total = V_physical + V_absorb
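
The following 1D sketch builds the potential profile, assuming V = strength · s**order. Here s is the normalized depth into the nearest layer: 0 at its inner edge and 1 at the boundary. The exact profile schr uses is not stated here, and the function name is hypothetical.

```python
import numpy as np


def mask_potential_1d_sketch(x, domain, width, strength=100.0, order=2):
    """Real potential strength * s**order inside the absorbing layers, 0 in the interior."""
    x_min, x_max = domain
    # Normalized depth into the nearest layer: 0 at its inner edge, 1 at the boundary
    s = np.clip(1.0 - np.minimum(x - x_min, x_max - x) / width, 0.0, 1.0)
    return strength * s**order


x = np.linspace(-10.0, 10.0, 512)
V_absorb = mask_potential_1d_sketch(x, (-10.0, 10.0), width=2.0, strength=500.0)
```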

schr.utils.momentum_operator(shape: tuple[int, ...], dx: float, hbar: float = 1.0, dtype: numpy.dtype = jax.numpy.float32) → tuple[Array, ...]

Create momentum operators in Fourier space.

The kinetic energy operator in Fourier space is \(\hat{T} = \frac{\hbar^2 k^2}{2m}\).

Parameters:
  • shape – Grid shape (nx,) for 1D, (ny, nx) for 2D, (nz, ny, nx) for 3D.

  • dx – Grid spacing (a.u., uniform).

  • hbar – Reduced Planck constant (a.u., default: 1.0).

  • dtype – JAX dtype (default: float32).

Returns:

  • 1D: (kx,) of shape (nx,)

  • 2D: (kx, ky) of shapes (ny, nx) each

  • 3D: (kx, ky, kz) of shapes (nz, ny, nx) each

Return type:

Tuple of momentum operators (one per dimension)

Raises:

ValueError – If ndim > 3.
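
The following NumPy sketch shows how such grids can be built. It makes three assumptions: angular wavenumbers come from 2π · fftfreq along each axis; outputs are ordered (kx, ky, kz), with kx varying along the last axis; and values are multiplied by hbar, so p = ħk (in atomic units, p and k coincide). The function name is hypothetical.

```python
import numpy as np


def momentum_operator_sketch(shape, dx, hbar=1.0, dtype=np.float32):
    """Return (kx,), (kx, ky) or (kx, ky, kz) grids, each with the full `shape`."""
    if len(shape) > 3:
        raise ValueError(f"expected at most 3 dimensions, got {len(shape)}")
    # One angular-wavenumber axis per array axis; shape is ordered (..., ny, nx)
    k_axes = [2 * np.pi * np.fft.fftfreq(n, d=dx) for n in shape]
    grids = np.meshgrid(*k_axes, indexing="ij")
    # Reverse so the tuple reads (kx, ky, kz), with kx varying along the last axis
    return tuple((hbar * g).astype(dtype) for g in reversed(grids))


kx, ky = momentum_operator_sketch((4, 8), dx=0.5)
T = (kx**2 + ky**2) / 2  # kinetic energy in Fourier space for hbar = m = 1
```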

schr.utils.plot_wavefunction(x: Array, psi: Array, title: str = '', show_real_imag: bool = True, show_probability: bool = True, figsize: tuple[float, float] = (12, 4), units: str = 'a.u.') → Figure

Create a publication-quality plot of a 1D wavefunction.

Parameters:
  • x – Spatial grid (1D array)

  • psi – Wavefunction values (complex)

  • title – Plot title

  • show_real_imag – Show Re(\(\psi\)) and Im(\(\psi\))

  • show_probability – Show \(|\psi|^2\)

  • figsize – Figure size (width, height) in inches

  • units – Physical units label (default: “a.u.”)

Returns:

Matplotlib Figure object
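
The following minimal matplotlib sketch shows a two-panel layout consistent with the parameters above: real and imaginary parts on the left, \(|\psi|^2\) on the right. The panel arrangement, labels, and function name are assumptions, not schr's actual styling, and the show_* flags are omitted.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_wavefunction_sketch(x, psi, title="", units="a.u.", figsize=(12, 4)):
    """Left panel: Re and Im of psi. Right panel: probability density |psi|^2."""
    fig, (ax_amp, ax_prob) = plt.subplots(1, 2, figsize=figsize)
    ax_amp.plot(x, psi.real, label=r"Re($\psi$)")
    ax_amp.plot(x, psi.imag, label=r"Im($\psi$)")
    ax_amp.set_xlabel(f"x ({units})")
    ax_amp.legend()
    ax_prob.plot(x, np.abs(psi) ** 2, color="k")
    ax_prob.set_xlabel(f"x ({units})")
    ax_prob.set_ylabel(r"$|\psi|^2$")
    fig.suptitle(title)
    fig.tight_layout()
    return fig


x = np.linspace(-10.0, 10.0, 512)
psi = np.exp(-x**2 / 2 + 2j * x)  # Gaussian wavepacket with mean wavenumber 2
fig = plot_wavefunction_sketch(x, psi, title="Gaussian wavepacket")
```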

schr.utils.plot_wavefunction_2d(X: Array, Y: Array, psi: Array, title: str = '', plot_type: str = 'probability', figsize: tuple[float, float] = (8, 7), cmap: str = 'viridis', units: str = 'a.u.', vmin: float | None = None, vmax: float | None = None) → Figure

Create a publication-quality plot of a 2D wavefunction.

Parameters:
  • X – X-coordinate grid (2D array)

  • Y – Y-coordinate grid (2D array)

  • psi – Wavefunction values (complex 2D array)

  • title – Plot title

  • plot_type – “probability”, “real”, “imag”, or “phase”

  • figsize – Figure size (width, height) in inches

  • cmap – Colormap name (e.g., viridis, plasma, inferno, twilight)

  • units – Physical units label (default: “a.u.”)

  • vmin – Minimum value for colormap (auto if None)

  • vmax – Maximum value for colormap (auto if None)

Returns:

Matplotlib Figure object

Example

>>> X, Y, dx, dy = create_grid_2d(-5, 5, 256, -5, 5, 256)
>>> psi = jnp.exp(-(X**2 + Y**2) / 2)
>>> fig = plot_wavefunction_2d(X, Y, psi)
>>> plt.show()

schr.utils.polynomial_absorbing_mask(coords: Array | tuple[Array, ...], domain: tuple[float, float] | tuple[tuple[float, float], ...], width: float | tuple[float, ...], order: int = 4, dtype: numpy.dtype = jax.numpy.float32) → Array

Create polynomial absorbing boundary mask using \(\cos^n\) profile.

Uses the smooth profile \(\cos(\pi t/2)^n\), where \(n\) is the order and \(t \in [0, 1]\) is the normalized depth into the absorption layer (\(t = 0\) at its inner edge, \(t = 1\) at the domain boundary). Higher orders give stronger absorption.

Parameters:
  • coords – Grid coordinates. Single 1D array or tuple of arrays for 2D/3D.

  • domain – Domain boundaries. (min, max) for 1D or tuple of tuples for 2D/3D.

  • width – Absorption layer width (uniform float or per-dimension tuple).

  • order – Polynomial order (higher = stronger absorption). Typical: 2-8.

  • dtype – JAX dtype for the mask.

Returns:

Absorption mask (multiplicative factor from 0 to 1). Returns 1.0 in interior and smoothly decreases to 0 near boundaries.

Raises:

ValueError – If dimensions are inconsistent.

Example

>>> x = jnp.linspace(-10, 10, 512)
>>> mask = polynomial_absorbing_mask(x, (-10, 10), width=2.0, order=4)
>>>
>>> X, Y = jnp.meshgrid(jnp.linspace(-5, 5, 256), jnp.linspace(-3, 3, 128))
>>> mask = polynomial_absorbing_mask(
...     (X, Y),
...     ((-5, 5), (-3, 3)),
...     width=1.0,
...     order=6
... )
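
The following NumPy sketch implements the \(\cos^n\) mask for 1D inputs and for 2D/3D tuples of grids. Combining dimensions by multiplying the per-axis masks is an assumption; taking the minimum would be another reasonable choice. The dtype argument is omitted, and the function name is hypothetical.

```python
import numpy as np


def polynomial_mask_sketch(coords, domain, width, order=4):
    """cos(pi*t/2)**order mask; t = depth into the layer (0 at inner edge, 1 at boundary)."""
    if not isinstance(coords, tuple):  # 1D: wrap the single array and its (min, max) pair
        coords, domain = (coords,), (domain,)
    widths = width if isinstance(width, tuple) else (width,) * len(coords)
    if not (len(coords) == len(domain) == len(widths)):
        raise ValueError("coords, domain and width must have matching dimensions")
    mask = np.ones(np.shape(coords[0]))
    for c, (lo, hi), w in zip(coords, domain, widths):
        dist = np.minimum(c - lo, hi - c)  # distance to the nearest boundary on this axis
        t = np.clip(1.0 - dist / w, 0.0, 1.0)
        mask = mask * np.cos(np.pi * t / 2) ** order
    return mask


x = np.linspace(-10, 10, 512)
m1 = polynomial_mask_sketch(x, (-10, 10), width=2.0, order=4)

X, Y = np.meshgrid(np.linspace(-5, 5, 256), np.linspace(-3, 3, 128))
m2 = polynomial_mask_sketch((X, Y), ((-5, 5), (-3, 3)), width=1.0, order=6)
```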