schr.utils.visualization module

Matplotlib-based visualization tools for quantum simulations.

schr.utils.visualization.animate_wavefunction(x: Array, psi_frames: list[Array], dt: float, title: str = '', show_real_imag: bool = True, show_probability: bool = True, interval: int = 50, figsize: tuple[float, float] = (12, 5), units: str = 'a.u.', setup_callback: Callable | None = None, update_callback: Callable | None = None) → FuncAnimation

Create a publication-quality animation of a time-evolving wavefunction.

Parameters:
  • x – Spatial grid (1D array)

  • psi_frames – List of wavefunction snapshots (complex)

  • dt – Time step between frames (a.u.)

  • title – Animation title

  • show_real_imag – Show $\mathrm{Re}(\psi)$ and $\mathrm{Im}(\psi)$

  • show_probability – Show $|\psi|^2$

  • interval – Delay between displayed frames in milliseconds (sets the playback speed; independent of dt)

  • figsize – Figure size (width, height) in inches

  • units – Physical units label (default: “a.u.”)

  • setup_callback – Optional function(fig, axes, x, psi0) -> dict. Called once to set up custom visualization elements; should return a dict of the objects to update during the animation.

  • update_callback – Optional function(frame_idx, psi, t, objects) -> None. Called on each frame to update the custom elements; objects is the dict returned by setup_callback.

Returns:

Matplotlib FuncAnimation object

Example

>>> x, dx = create_grid_1d(-10, 10, 512)
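>>> # Schematic frames: a Gaussian packet translated at v = 2 (dispersion ignored).
>>> # In practice, psi_frames holds snapshots saved from a time-evolution solver.
>>> psi_frames = [jnp.exp(-(x - 2 * 0.01 * n)**2 / 2) * jnp.exp(2j * x)
...               for n in range(200)]
>>> anim = animate_wavefunction(x, psi_frames, dt=0.01)
>>> plt.show()
>>> # With custom barrier visualization
>>> def setup(fig, axes, x, psi0):
...     if not isinstance(axes, list):
...         axes = [axes]
...     # Shade the barrier region on every panel and keep all the patches
...     barriers = [ax.axvspan(8, 12, alpha=0.2, color='red') for ax in axes]
...     return {'barriers': barriers}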
>>> anim = animate_wavefunction(x, psi_frames, dt=0.01,
...                             setup_callback=setup)
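psi_frames is simply a list of complex arrays sampled on x, one per time step dt. A minimal NumPy sketch of producing physically exact frames, assuming atomic units (ħ = m = 1), a free particle, and a uniform periodic grid, so that each Fourier mode evolves as $e^{-i k^2 t / 2}$. The helper free_particle_frames is hypothetical and not part of schr.

```python
import numpy as np


def free_particle_frames(x, psi0, dt, n_frames):
    """Hypothetical helper (not part of schr): exact free-particle evolution.

    Assumes atomic units (hbar = m = 1) and a periodic, uniform grid, so each
    Fourier mode exp(i k x) simply acquires the phase exp(-i k**2 t / 2).
    """
    dx = x[1] - x[0]
    k = 2 * np.pi * np.fft.fftfreq(x.size, d=dx)  # angular wavenumbers
    psi0_k = np.fft.fft(psi0)
    return [np.fft.ifft(psi0_k * np.exp(-0.5j * k**2 * (n * dt)))
            for n in range(n_frames)]


# Uniform periodic grid on [-10, 10) and a Gaussian packet with k0 = 2
x = np.linspace(-10, 10, 512, endpoint=False)
dx = x[1] - x[0]
psi0 = np.exp(-x**2 / 2) * np.exp(2j * x)
psi0 /= np.sqrt(np.sum(np.abs(psi0)**2) * dx)  # enforce sum |psi|^2 dx = 1

dt = 0.01
psi_frames = free_particle_frames(x, psi0, dt=dt, n_frames=200)
```

The list can then be passed as animate_wavefunction(x, psi_frames, dt=dt), converting with jnp.asarray first if the package expects JAX arrays. The grid deliberately excludes the right endpoint so the FFT's periodicity assumption holds; whether create_grid_1d builds its grid the same way is not stated here.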
schr.utils.visualization.animate_wavefunction_2d(X: Array, Y: Array, psi_frames: list[Array], times: list[float], figsize: tuple[float, float] = (10, 8), plot_type: str = 'probability', cmap: str = 'inferno', units: str = 'a.u.', scale_factor: float = 1.0, scale_units: str | None = None, interval: int = 50, title: str = '', setup_callback: Callable | None = None, update_callback: Callable | None = None) → FuncAnimation

Animate the evolution of a 2D wavefunction, with optional custom setup and update callbacks.

Parameters:
  • X – X-coordinate grid (2D array)

  • Y – Y-coordinate grid (2D array)

  • psi_frames – List of wavefunction snapshots (complex 2D arrays)

  • times – List of time values for each frame

  • figsize – Figure size (width, height) in inches

  • plot_type – “probability”, “real”, “imag”, or “phase”

  • cmap – Colormap name (default 'inferno'). Suggested: inferno for probability, RdBu_r for real/imag, twilight for phase

  • units – Physical units label (default: “a.u.”)

  • scale_factor – Coordinate scaling (e.g., 1000 for nm→μm)

  • scale_units – Scaled units label (e.g., “μm”)

  • interval – Time between frames (ms)

  • title – Animation title

  • setup_callback – Optional function(fig, axes, X, Y, psi0) -> dict. Called once to set up custom visualization elements; should return a dict of the objects to update during the animation.

  • update_callback – Optional function(frame_idx, psi, t, objects) -> None. Called on each frame to update the custom elements; objects is the dict returned by setup_callback.

Returns:

Matplotlib FuncAnimation object

Example

>>> # Simple usage
>>> anim = animate_wavefunction_2d(X, Y, psi_frames, times)
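>>> # With custom barrier visualization (V: potential sampled on the same X, Y grid)
>>> def setup(fig, axes, X, Y, psi0):
...     ax = axes[0]
...     # Outline the region where V > 0 (0/1 mask contoured at level 0.5)
...     barrier_contour = ax.contour(X, Y, (V > 0).astype(float),
...                                  levels=[0.5], colors='cyan')
...     return {'barrier': barrier_contour}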
>>> anim = animate_wavefunction_2d(X, Y, psi_frames, times,
...                                setup_callback=setup)
schr.utils.visualization.plot_wavefunction(x: Array, psi: Array, title: str = '', show_real_imag: bool = True, show_probability: bool = True, figsize: tuple[float, float] = (12, 4), units: str = 'a.u.') → Figure

Plot a 1D wavefunction at publication quality.

Parameters:
  • x – Spatial grid (1D array)

  • psi – Wavefunction values (complex)

  • title – Plot title

  • show_real_imag – Show $\mathrm{Re}(\psi)$ and $\mathrm{Im}(\psi)$

  • show_probability – Show $|\psi|^2$

  • figsize – Figure size (width, height) in inches

  • units – Physical units label (default: “a.u.”)

Returns:

Matplotlib Figure object
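plot_wavefunction displays $\mathrm{Re}(\psi)$, $\mathrm{Im}(\psi)$ and $|\psi|^2$ on the grid. A NumPy sketch of those three arrays, assuming a uniform grid; normalizing first (an optional step in this sketch, not something the function is documented to do) makes the $|\psi|^2$ curve a probability density with unit area. The packet parameters are illustrative.

```python
import numpy as np

x = np.linspace(-10, 10, 512)
dx = x[1] - x[0]

# Gaussian packet with mean wavenumber k0 = 2 (illustrative parameters)
psi = np.exp(-x**2 / 2) * np.exp(2j * x)

# Normalize so that sum(|psi|^2) * dx = 1 (Riemann sum on a uniform grid)
psi = psi / np.sqrt(np.sum(np.abs(psi)**2) * dx)

re_psi = psi.real        # Re(psi)
im_psi = psi.imag        # Im(psi)
prob = np.abs(psi)**2    # |psi|^2, the probability density

total_probability = np.sum(prob) * dx
```

With this normalization the density peaks near $1/\sqrt{\pi} \approx 0.564$, the analytic value for $\pi^{-1/2} e^{-x^2}$.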

schr.utils.visualization.plot_wavefunction_2d(X: Array, Y: Array, psi: Array, title: str = '', plot_type: str = 'probability', figsize: tuple[float, float] = (8, 7), cmap: str = 'viridis', units: str = 'a.u.', vmin: float | None = None, vmax: float | None = None) → Figure

Plot a 2D wavefunction at publication quality.

Parameters:
  • X – X-coordinate grid (2D array)

  • Y – Y-coordinate grid (2D array)

  • psi – Wavefunction values (complex 2D array)

  • title – Plot title

  • plot_type – “probability”, “real”, “imag”, or “phase”

  • figsize – Figure size (width, height) in inches

  • cmap – Colormap name (viridis, plasma, inferno, twilight)

  • units – Physical units label (default: “a.u.”)

  • vmin – Minimum value for colormap (auto if None)

  • vmax – Maximum value for colormap (auto if None)

Returns:

Matplotlib Figure object

Example

>>> X, Y, dx, dy = create_grid_2d(-5, 5, 256, -5, 5, 256)
>>> psi = jnp.exp(-(X**2 + Y**2) / 2)
>>> fig = plot_wavefunction_2d(X, Y, psi)
>>> plt.show()
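For plot_type "real" or "imag", a diverging colormap reads best when zero sits at its neutral midpoint, which means passing symmetric vmin and vmax. A NumPy sketch of computing such limits, assuming a packet with a plane-wave phase along x (illustrative parameters; the grid uses np.meshgrid in "xy" order, which may differ from create_grid_2d).

```python
import numpy as np

# 2D grid and a Gaussian packet moving along +x (illustrative parameters)
x = np.linspace(-5, 5, 256)
y = np.linspace(-5, 5, 256)
X, Y = np.meshgrid(x, y, indexing="xy")
psi = np.exp(-(X**2 + Y**2) / 2) * np.exp(1j * 3 * X)

# Symmetric limits keep Re(psi) = 0 at the centre of a diverging colormap
vmax = float(np.max(np.abs(psi.real)))
vmin = -vmax
```

These values could then be passed as plot_wavefunction_2d(X, Y, psi, plot_type='real', cmap='RdBu_r', vmin=vmin, vmax=vmax); whether the function already symmetrizes the limits when they are left as None is not stated here.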
schr.utils.visualization.plot_wavefunction_2d_comprehensive(X: Array, Y: Array, psi: Array, title: str = '2D Wavefunction', figsize: tuple[float, float] = (16, 4), units: str = 'a.u.') → Figure

Educational 2D wavefunction visualization showing all components.

Creates a 4-panel visualization of a 2D wavefunction:

  1. Probability density $|\psi|^2$

  2. Real part $\mathrm{Re}(\psi)$

  3. Imaginary part $\mathrm{Im}(\psi)$

  4. Phase $\arg(\psi)$

Parameters:
  • X – X-coordinate grid (2D array)

  • Y – Y-coordinate grid (2D array)

  • psi – Wavefunction values (complex 2D array)

  • title – Overall plot title

  • figsize – Figure size (width, height) in inches

  • units – Physical units label

Returns:

Matplotlib Figure object

Example

>>> X, Y, dx, dy = create_grid_2d(-5, 5, 128, -5, 5, 128)
>>> psi = jnp.exp(-(X**2 + Y**2) / 2)
>>> fig = plot_wavefunction_2d_comprehensive(X, Y, psi)
>>> plt.show()
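The four panels are elementwise functions of psi. A NumPy sketch of the arrays behind them, for a Gaussian carrying the plane-wave phase 2x + y (illustrative). np.angle wraps the phase to [-π, π], which is why a cyclic colormap such as twilight suits that panel. Masking the phase where the density is negligible is an optional choice in this sketch, not necessarily what the function does.

```python
import numpy as np

x = np.linspace(-5, 5, 128)
X, Y = np.meshgrid(x, x)
# Gaussian envelope with the plane-wave phase 2x + y (illustrative)
psi = np.exp(-(X**2 + Y**2) / 2) * np.exp(1j * (2 * X + Y))

panels = {
    "probability": np.abs(psi)**2,  # |psi|^2
    "real": psi.real,               # Re(psi)
    "imag": psi.imag,               # Im(psi)
    "phase": np.angle(psi),         # arg(psi), wrapped to [-pi, pi]
}

# The phase carries no information where |psi| ~ 0, so hide it there (optional)
phase_masked = np.where(panels["probability"] > 1e-6, panels["phase"], np.nan)
```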
schr.utils.visualization.plot_wavefunction_3d(x: Array, psi: Array, title: str = 'Quantum Wavefunction', figsize: tuple[float, float] = (14, 5), units: str = 'a.u.', elev: float = 20, azim: float = -60) → Figure

Plot a 1D wavefunction in 3D, showing its real part, imaginary part, and probability density.

Visualizes $\psi(x) = \mathrm{Re}(\psi) + i\,\mathrm{Im}(\psi)$ as a 3D curve, making the complex nature of the wavefunction intuitive for beginners.

Parameters:
  • x – Spatial grid (1D array)

  • psi – Wavefunction values (complex)

  • title – Plot title

  • figsize – Figure size (width, height) in inches

  • units – Physical units label

  • elev – Elevation angle for 3D view (degrees)

  • azim – Azimuth angle for 3D view (degrees)

Returns:

Matplotlib Figure object with 3D visualization

Example

>>> x = jnp.linspace(-5, 5, 200)
>>> psi = jnp.exp(-x**2/2) * jnp.exp(1j * 2 * x)
>>> fig = plot_wavefunction_3d(x, psi)
>>> plt.show()
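The 3D curve is the set of points $(x, \mathrm{Re}\,\psi, \mathrm{Im}\,\psi)$. For the example packet its distance from the x-axis is $|\psi| = e^{-x^2/2}$, so the curve is a helix inside a Gaussian envelope, completing one turn every $2\pi/k_0 = \pi$ along x. A NumPy sketch of those coordinates, independent of how plot_wavefunction_3d draws them:

```python
import numpy as np

x = np.linspace(-5, 5, 200)
psi = np.exp(-x**2 / 2) * np.exp(1j * 2 * x)

# Points of the 3D curve: (x, Re psi, Im psi)
curve = np.column_stack([x, psi.real, psi.imag])

# Distance from the x-axis equals |psi|, i.e. the Gaussian envelope
radius = np.hypot(curve[:, 1], curve[:, 2])
envelope = np.exp(-x**2 / 2)
```

With plain Matplotlib, fig.add_subplot(projection='3d') returns a 3D axes on which ax.plot(x, psi.real, psi.imag) traces the same curve.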
schr.utils.visualization.plot_wavefunction_components(x: Array, psi: Array, title: str = 'Anatomy of a Wavefunction', figsize: tuple[float, float] = (14, 8), units: str = 'a.u.') → Figure

Educational plot showing the main components of a wavefunction, aimed at beginners.

Creates a comprehensive 4-panel visualization:

  1. Real and imaginary parts

  2. Magnitude $|\psi|$

  3. Probability density $|\psi|^2$

  4. Phase $\arg(\psi)$

Parameters:
  • x – Spatial grid (1D array)

  • psi – Wavefunction values (complex)

  • title – Overall plot title

  • figsize – Figure size (width, height) in inches

  • units – Physical units label

Returns:

Matplotlib Figure object

Example

>>> x = jnp.linspace(-5, 5, 200)
>>> psi = jnp.exp(-x**2/2) * jnp.exp(1j * 2 * x)
>>> fig = plot_wavefunction_components(x, psi)
>>> plt.show()
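The phase panel shows $\arg(\psi)$ wrapped to [-π, π], so a plane-wave factor appears as a sawtooth. A NumPy sketch of the four panel quantities for the example packet, plus one extra step (not something the function is documented to do): unwrapping the phase gives a straight line whose slope is the local wavenumber, here $k_0 = 2$.

```python
import numpy as np

x = np.linspace(-5, 5, 200)
psi = np.exp(-x**2 / 2) * np.exp(1j * 2 * x)

magnitude = np.abs(psi)   # |psi|
density = magnitude**2    # |psi|^2
phase = np.angle(psi)     # arg(psi), wrapped to [-pi, pi]

# Removing the 2*pi jumps turns the sawtooth phase into a straight line
phase_unwrapped = np.unwrap(phase)

# Its slope is the local wavenumber, here the packet's k0 = 2
local_k = np.gradient(phase_unwrapped, x)
```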