schr.utils.visualization module
Matplotlib-based plotting and animation tools for 1D and 2D quantum wavefunctions.
- schr.utils.visualization.animate_wavefunction(x: Array, psi_frames: list[Array], dt: float, title: str = '', show_real_imag: bool = True, show_probability: bool = True, interval: int = 50, figsize: tuple[float, float] = (12, 5), units: str = 'a.u.', setup_callback: Callable | None = None, update_callback: Callable | None = None) → FuncAnimation
Create a publication-quality animation of a time-evolving 1D wavefunction.
- Parameters:
x – Spatial grid (1D array)
psi_frames – List of wavefunction snapshots (complex)
dt – Time step between frames (a.u.)
title – Animation title
show_real_imag – Show $\mathrm{Re}(\psi)$ and $\mathrm{Im}(\psi)$
show_probability – Show $|\psi|^2$
interval – Time between frames in milliseconds
figsize – Figure size (width, height) in inches
units – Physical units label (default: “a.u.”)
setup_callback – Optional function setup_callback(fig, axes, x, psi0) -> dict, called once to create custom visualization elements. It should return a dict of the objects to update during the animation.
update_callback – Optional function update_callback(frame_idx, psi, t, objects) -> None, called on every frame to update the custom elements; objects is the dict returned by setup_callback.
- Returns:
Matplotlib FuncAnimation object
Example
>>> import matplotlib.pyplot as plt
>>> x, dx = create_grid_1d(-10, 10, 512)
>>> psi_frames = []  # fill with complex snapshots of psi at successive times
>>> anim = animate_wavefunction(x, psi_frames, dt=0.01)
>>> plt.show()
>>> # With a custom barrier visualization
>>> def setup(fig, axes, x, psi0):
...     if not isinstance(axes, list):
...         axes = [axes]
...     for ax in axes:
...         barrier = ax.axvspan(8, 12, alpha=0.2, color='red')
...     return {'barrier': barrier}
...
>>> anim = animate_wavefunction(x, psi_frames, dt=0.01,
...                             setup_callback=setup)
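The examples above leave psi_frames empty. A minimal sketch for filling it, assuming a free particle in atomic units (ħ = m = 1) on a uniform periodic grid, evolves a Gaussian wave packet exactly in momentum space. The helper free_particle_frames and its parameters are hypothetical and not part of schr; NumPy's linspace stands in for create_grid_1d.

```python
import numpy as np

def free_particle_frames(x, psi0, dt, n_frames):
    """Hypothetical helper: exact free-particle evolution (hbar = m = 1).

    Frame n is psi(t_n) with t_n = n * dt, obtained by multiplying the
    Fourier coefficients of psi0 by exp(-i k^2 t / 2). Assumes a uniform,
    periodic grid.
    """
    dx = x[1] - x[0]
    k = 2 * np.pi * np.fft.fftfreq(x.size, d=dx)  # angular wavenumbers
    psi0_k = np.fft.fft(psi0)
    frames = []
    for n in range(n_frames):
        t = n * dt
        frames.append(np.fft.ifft(psi0_k * np.exp(-0.5j * k**2 * t)))
    return frames

# Gaussian wave packet centred at x = -3 with mean wavenumber k0 = 2
x = np.linspace(-10, 10, 512, endpoint=False)
dx = x[1] - x[0]
psi0 = np.exp(-(x + 3) ** 2) * np.exp(2j * x)
psi0 /= np.sqrt(np.sum(np.abs(psi0) ** 2) * dx)  # sum |psi|^2 dx = 1
psi_frames = free_particle_frames(x, psi0, dt=0.01, n_frames=100)
```

The same dt then goes to animate_wavefunction(x, psi_frames, dt=0.01), so the time shown for frame n is n * dt. Because each Fourier mode only picks up a phase, the norm of every frame stays 1 and the packet centre drifts at velocity k0.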
- schr.utils.visualization.animate_wavefunction_2d(X: Array, Y: Array, psi_frames: list[Array], times: list[float], figsize: tuple[float, float] = (10, 8), plot_type: str = 'probability', cmap: str = 'inferno', units: str = 'a.u.', scale_factor: float = 1.0, scale_units: str | None = None, interval: int = 50, title: str = '', setup_callback: Callable | None = None, update_callback: Callable | None = None) → FuncAnimation
Animate the time evolution of a 2D wavefunction, with optional custom setup and update callbacks.
- Parameters:
X – X-coordinate grid (2D array)
Y – Y-coordinate grid (2D array)
psi_frames – List of wavefunction snapshots (complex 2D arrays)
times – List of time values for each frame
figsize – Figure size (width, height) in inches
plot_type – “probability”, “real”, “imag”, or “phase”
cmap – Colormap name (default “inferno”; suggested choices: inferno for probability, RdBu_r for real/imag, twilight for phase)
units – Physical units label (default: “a.u.”)
scale_factor – Coordinate scaling (e.g., 1000 for nm→μm)
scale_units – Scaled units label (e.g., “μm”)
interval – Time between frames (ms)
title – Animation title
setup_callback – Optional function setup_callback(fig, axes, X, Y, psi0) -> dict, called once to create custom visualization elements. It should return a dict of the objects to update during the animation.
update_callback – Optional function update_callback(frame_idx, psi, t, objects) -> None, called on every frame to update the custom elements; objects is the dict returned by setup_callback.
- Returns:
Matplotlib FuncAnimation object
Example
>>> # Simple usage
>>> anim = animate_wavefunction_2d(X, Y, psi_frames, times)
>>> # With a custom barrier visualization (V is the potential on the X, Y grid)
>>> def setup(fig, axes, X, Y, psi0):
...     ax = axes[0]
...     barrier_contour = ax.contour(X, Y, V > 0, colors='cyan')
...     return {'barrier': barrier_contour}
...
>>> anim = animate_wavefunction_2d(X, Y, psi_frames, times,
...                                setup_callback=setup)
- schr.utils.visualization.plot_wavefunction(x: Array, psi: Array, title: str = '', show_real_imag: bool = True, show_probability: bool = True, figsize: tuple[float, float] = (12, 4), units: str = 'a.u.') → Figure
Plot a 1D wavefunction as a publication-quality figure.
- Parameters:
x – Spatial grid (1D array)
psi – Wavefunction values (complex)
title – Plot title
show_real_imag – Show $\mathrm{Re}(\psi)$ and $\mathrm{Im}(\psi)$
show_probability – Show $|\psi|^2$
figsize – Figure size (width, height) in inches
units – Physical units label (default: “a.u.”)
- Returns:
Matplotlib Figure object
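The docstring does not say whether plot_wavefunction normalises psi, and the probability curve reads as a probability density only if $\sum_j |\psi(x_j)|^2\,\Delta x = 1$ on the grid. A minimal sketch of normalising beforehand, assuming a uniform grid with spacing dx; the helper normalize is hypothetical.

```python
import numpy as np

def normalize(psi, dx):
    """Scale psi so that the discrete norm sum(|psi|^2) * dx equals 1."""
    norm = np.sqrt(np.sum(np.abs(psi) ** 2) * dx)
    return psi / norm

x = np.linspace(-10, 10, 512, endpoint=False)
dx = x[1] - x[0]
psi = normalize(np.exp(-x**2 / 2) * np.exp(2j * x), dx)
prob = np.abs(psi) ** 2  # the quantity shown when show_probability=True
```

For this Gaussian the normalised peak density is $1/\sqrt{\pi} \approx 0.564$, which gives a quick sanity check on the plotted curve.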
- schr.utils.visualization.plot_wavefunction_2d(X: Array, Y: Array, psi: Array, title: str = '', plot_type: str = 'probability', figsize: tuple[float, float] = (8, 7), cmap: str = 'viridis', units: str = 'a.u.', vmin: float | None = None, vmax: float | None = None) → Figure
Plot a 2D wavefunction as a publication-quality figure.
- Parameters:
X – X-coordinate grid (2D array)
Y – Y-coordinate grid (2D array)
psi – Wavefunction values (complex 2D array)
title – Plot title
plot_type – “probability”, “real”, “imag”, or “phase”
figsize – Figure size (width, height) in inches
cmap – Matplotlib colormap name (e.g., viridis, plasma, inferno, twilight)
units – Physical units label (default: “a.u.”)
vmin – Minimum value for colormap (auto if None)
vmax – Maximum value for colormap (auto if None)
- Returns:
Matplotlib Figure object
Example
>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> X, Y, dx, dy = create_grid_2d(-5, 5, 256, -5, 5, 256)
>>> psi = jnp.exp(-(X**2 + Y**2) / 2)
>>> fig = plot_wavefunction_2d(X, Y, psi)
>>> plt.show()
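The docstring lists four plot_type values and says vmin/vmax are chosen automatically when None, but not how. A minimal sketch of one reasonable choice, not the library's implementation: the hypothetical helper field_for_plot maps plot_type to an array plus colour limits, using limits symmetric about zero for real/imag (so a diverging colormap centres on zero) and $[-\pi, \pi]$ for phase (to suit a cyclic colormap).

```python
import numpy as np

def field_for_plot(psi, plot_type="probability"):
    """Hypothetical helper: return (data, vmin, vmax) for a plot_type string."""
    if plot_type == "probability":
        data = np.abs(psi) ** 2
        return data, 0.0, float(data.max())
    if plot_type in ("real", "imag"):
        data = psi.real if plot_type == "real" else psi.imag
        lim = float(np.abs(data).max())  # symmetric limits keep zero at the colormap centre
        return data, -lim, lim
    if plot_type == "phase":
        return np.angle(psi), -np.pi, np.pi  # np.angle already lies in [-pi, pi]
    raise ValueError(f"unknown plot_type: {plot_type!r}")
```

The returned limits can be passed straight to Matplotlib's imshow or pcolormesh as vmin and vmax.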
- schr.utils.visualization.plot_wavefunction_2d_comprehensive(X: Array, Y: Array, psi: Array, title: str = '2D Wavefunction', figsize: tuple[float, float] = (16, 4), units: str = 'a.u.') → Figure
Educational 2D wavefunction visualization showing all components.
Creates a four-panel visualization of a 2D wavefunction:
1. Probability density $|\psi|^2$
2. Real part $\mathrm{Re}(\psi)$
3. Imaginary part $\mathrm{Im}(\psi)$
4. Phase $\arg(\psi)$
- Parameters:
X – X-coordinate grid (2D array)
Y – Y-coordinate grid (2D array)
psi – Wavefunction values (complex 2D array)
title – Overall plot title
figsize – Figure size (width, height) in inches
units – Physical units label
- Returns:
Matplotlib Figure object
Example
>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> X, Y, dx, dy = create_grid_2d(-5, 5, 128, -5, 5, 128)
>>> psi = jnp.exp(-(X**2 + Y**2) / 2)
>>> fig = plot_wavefunction_2d_comprehensive(X, Y, psi)
>>> plt.show()
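With the real Gaussian in the example above, the imaginary-part panel is identically zero and the phase panel is constant. A minimal sketch of a test state that gives all four panels structure multiplies the Gaussian by a plane wave $e^{i(k_x x + k_y y)}$. The grid is built with NumPy's meshgrid as a stand-in for create_grid_2d, whose exact grid convention is an assumption here, and the values of kx and ky are arbitrary.

```python
import numpy as np

# Stand-in for create_grid_2d(-5, 5, 128, -5, 5, 128): a uniform 128 x 128 grid
x = np.linspace(-5, 5, 128, endpoint=False)
y = np.linspace(-5, 5, 128, endpoint=False)
X, Y = np.meshgrid(x, y)

kx, ky = 3.0, 1.5  # illustrative wave-vector components
envelope = np.exp(-(X**2 + Y**2) / 2)
psi = envelope * np.exp(1j * (kx * X + ky * Y))
# |psi|^2 is unchanged by the plane wave, but Im(psi) and arg(psi) now vary
```

Passing this psi to plot_wavefunction_2d_comprehensive should show the same density panel as before, with stripes perpendicular to (kx, ky) in the real, imaginary, and phase panels.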
- schr.utils.visualization.plot_wavefunction_3d(x: Array, psi: Array, title: str = 'Quantum Wavefunction', figsize: tuple[float, float] = (14, 5), units: str = 'a.u.', elev: float = 20, azim: float = -60) → Figure
Plot a 1D wavefunction in 3D, showing its real part, imaginary part, and probability density.
Visualizes $\psi(x) = \mathrm{Re}\,\psi(x) + i\,\mathrm{Im}\,\psi(x)$ as a 3D curve over $x$, making the complex nature of the wavefunction intuitive for beginners.
- Parameters:
x – Spatial grid (1D array)
psi – Wavefunction values (complex)
title – Plot title
figsize – Figure size (width, height) in inches
units – Physical units label
elev – Elevation angle for 3D view (degrees)
azim – Azimuth angle for 3D view (degrees)
- Returns:
Matplotlib Figure object with 3D visualization
Example
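>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> from schr.utils.visualization import plot_wavefunction_3d
>>> x = jnp.linspace(-5, 5, 200)
>>> psi = jnp.exp(-x**2 / 2) * jnp.exp(1j * 2 * x)
>>> fig = plot_wavefunction_3d(x, psi)
>>> plt.show()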
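A minimal sketch of the idea behind plot_wavefunction_3d, assuming it draws the curve $(x, \mathrm{Re}\,\psi, \mathrm{Im}\,\psi)$ on Matplotlib 3D axes. For the example wave packet this curve is a helix whose radius is the envelope $|\psi(x)|$. The function plot_complex_curve, its single-panel layout, and the choice to draw $|\psi|^2$ in the $\mathrm{Im}\,\psi = 0$ plane are illustrative, not the library's code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def plot_complex_curve(x, psi, elev=20, azim=-60):
    """Draw psi(x) as the 3D curve (x, Re psi, Im psi) with |psi|^2 alongside."""
    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot(projection="3d")
    ax.plot(x, psi.real, psi.imag, label=r"$\psi(x)$")
    # Probability density drawn in the Im(psi) = 0 plane for comparison
    ax.plot(x, np.abs(psi) ** 2, np.zeros_like(x), label=r"$|\psi|^2$")
    ax.set_xlabel("x (a.u.)")
    ax.set_ylabel(r"Re $\psi$")
    ax.set_zlabel(r"Im $\psi$")
    ax.view_init(elev=elev, azim=azim)  # same meaning as the elev/azim parameters
    ax.legend()
    return fig

x = np.linspace(-5, 5, 200)
psi = np.exp(-x**2 / 2) * np.exp(1j * 2 * x)
fig = plot_complex_curve(x, psi)
```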
- schr.utils.visualization.plot_wavefunction_components(x: Array, psi: Array, title: str = 'Anatomy of a Wavefunction', figsize: tuple[float, float] = (14, 8), units: str = 'a.u.') → Figure
Educational plot of the components of a 1D wavefunction, aimed at beginners.
Creates a four-panel visualization:
1. Real and imaginary parts
2. Magnitude $|\psi|$
3. Probability density $|\psi|^2$
4. Phase $\arg(\psi)$
- Parameters:
x – Spatial grid (1D array)
psi – Wavefunction values (complex)
title – Overall plot title
figsize – Figure size (width, height) in inches
units – Physical units label
- Returns:
Matplotlib Figure object
Example
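>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> from schr.utils.visualization import plot_wavefunction_components
>>> x = jnp.linspace(-5, 5, 200)
>>> psi = jnp.exp(-x**2 / 2) * jnp.exp(1j * 2 * x)
>>> fig = plot_wavefunction_components(x, psi)
>>> plt.show()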
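The phase panel shows $\arg(\psi)$ wherever it is computed, including regions where $|\psi|$ is negligible and the phase carries no physical weight (and, for numerically evolved states, is mostly noise). Whether plot_wavefunction_components masks those regions is not stated. A minimal sketch of one common choice, using a hypothetical masked_phase helper and an arbitrary relative threshold:

```python
import numpy as np

def masked_phase(psi, rel_threshold=1e-3):
    """Return arg(psi), with NaN where |psi|^2 < rel_threshold * max |psi|^2.

    Matplotlib leaves NaN points out of line plots, so the masked regions
    appear as gaps instead of rapidly wrapping phase.
    """
    prob = np.abs(psi) ** 2
    return np.where(prob >= rel_threshold * prob.max(), np.angle(psi), np.nan)

x = np.linspace(-5, 5, 200)
psi = np.exp(-x**2 / 2) * np.exp(1j * 2 * x)
phase = masked_phase(psi)
```

A relative threshold keeps the mask independent of how psi is normalised; the value 1e-3 is a tuning choice, not a physical constant.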