"""Hamiltonian implementations for quantum mechanics."""
from collections.abc import Callable
import jax.numpy as jnp
from jax import Array
from schr.core.base import Hamiltonian
from schr.utils.fft import momentum_operator
class ParticleInPotential(Hamiltonian):
    r"""Hamiltonian for a particle in an arbitrary potential.

    .. math::

        \hat{H} = -\frac{\hbar^2}{2m}\nabla^2 + V(\mathbf{r}, t)

    The kinetic term is applied spectrally. The wavefunction is transformed
    with an n-dimensional FFT, multiplied by :attr:`kinetic_operator`, and
    transformed back. The potential term is applied pointwise in real space.

    Attributes:
        potential: Callable ``V(r, t)`` returning potential energy. NOTE: the
            methods of this class invoke it as ``potential(psi, t)``, i.e. the
            wavefunction array is passed as the first argument. The returned
            array must therefore be broadcastable against ``psi``.
        mass: Particle mass (a.u., default: 1.0 = electron mass).
        hbar: Reduced Planck constant (a.u., default: 1.0).
        dx: Grid spacing (a.u.).
        grid_shape: Shape of the spatial grid.
        kinetic_operator: Momentum-space kinetic energy :math:`\hbar^2k^2/(2m)`.
    """

    def __init__(
        self,
        potential: Callable[[Array, float], Array],
        grid_shape: tuple,
        dx: float,
        mass: float = 1.0,
        hbar: float = 1.0,
        dtype: jnp.dtype = jnp.complex64,
    ):
        """Initialize Hamiltonian.

        Args:
            potential: Function V(r, t) returning potential energy (a.u.).
                Called by this class as ``potential(psi, t)``.
            grid_shape: Shape (nx,) for 1D, (ny, nx) for 2D, (nz, ny, nx) for 3D.
            dx: Grid spacing (a.u., uniform).
            mass: Particle mass (a.u., default: 1.0 = electron mass).
            hbar: Reduced Planck constant (a.u., default: 1.0).
            dtype: JAX dtype (default: complex64).
        """
        super().__init__(dtype=dtype)
        self.potential = potential
        self.mass = mass
        self.hbar = hbar
        self.dx = dx
        self.grid_shape = grid_shape
        k_components = momentum_operator(grid_shape, dx, hbar=hbar)
        # |k|^2: sum of squared momentum components over all spatial axes.
        k_squared = sum(k**2 for k in k_components)
        # NOTE(review): momentum_operator receives ``hbar``. If it already
        # returns p = hbar*k, this hbar**2 prefactor double-counts hbar. That
        # is harmless in atomic units (hbar = 1). Confirm against
        # schr.utils.fft.
        self.kinetic_operator = (hbar**2 / (2 * mass)) * k_squared

    def _apply_kinetic(self, psi: Array) -> Array:
        """Return T|psi>, computed by multiplying in momentum space via FFT."""
        psi_k = jnp.fft.fftn(psi)
        return jnp.fft.ifftn(self.kinetic_operator * psi_k)

    def _apply_potential(self, psi: Array, t: float) -> Array:
        """Return V|psi>, the pointwise product of the potential and psi."""
        return self.potential(psi, t) * psi

    def _expectation(self, psi: Array, op_psi: Array) -> Array:
        """Return Re(sum(conj(psi) * op_psi) * dx**psi.ndim), a Riemann-sum <psi|op|psi>."""
        # Volume element of the uniform grid; one factor of dx per array axis.
        dv = self.dx**psi.ndim
        return jnp.real(jnp.sum(jnp.conj(psi) * op_psi) * dv)

    def apply(self, psi: Array, t: float) -> Array:
        r"""Apply :math:`\hat{H}|\psi\rangle = (\hat{T} + \hat{V})|\psi\rangle`.

        The kinetic term is evaluated spectrally (FFT round trip) and the
        potential term pointwise in real space. This is a direct application
        of the Hamiltonian, not a split-operator propagation step.

        Args:
            psi: Wavefunction.
            t: Time (a.u.).

        Returns:
            :math:`\hat{H}|\psi\rangle`.
        """
        return self._apply_kinetic(psi) + self._apply_potential(psi, t)

    def energy(self, psi: Array, t: float) -> float:
        r"""Compute energy expectation :math:`\langle\psi|\hat{H}|\psi\rangle`.

        Args:
            psi: Normalized wavefunction.
            t: Time (a.u.).

        Returns:
            Energy expectation value (a.u.), real part of the Riemann sum.
        """
        return self._expectation(psi, self.apply(psi, t))

    def kinetic_energy(self, psi: Array) -> float:
        r"""Compute kinetic energy expectation :math:`\langle\psi|\hat{T}|\psi\rangle`.

        Args:
            psi: Normalized wavefunction.

        Returns:
            Kinetic energy (a.u.).
        """
        return self._expectation(psi, self._apply_kinetic(psi))

    def potential_energy(self, psi: Array, t: float) -> float:
        r"""Compute potential energy expectation :math:`\langle\psi|\hat{V}|\psi\rangle`.

        Args:
            psi: Normalized wavefunction.
            t: Time (a.u.).

        Returns:
            Potential energy (a.u.).
        """
        return self._expectation(psi, self._apply_potential(psi, t))
class FreeParticle(ParticleInPotential):
    r"""Hamiltonian of a free particle, i.e. one with :math:`V \equiv 0`.

    .. math::

        \hat{H} = -\frac{\hbar^2}{2m}\nabla^2
    """

    def __init__(
        self,
        grid_shape: tuple,
        dx: float,
        mass: float = 1.0,
        hbar: float = 1.0,
        dtype: jnp.dtype = jnp.complex64,
    ):
        """Build a free-particle Hamiltonian on a uniform grid.

        Args:
            grid_shape: Shape of spatial grid.
            dx: Grid spacing (a.u.).
            mass: Particle mass (a.u., default: 1.0 = electron mass).
            hbar: Reduced Planck constant (a.u., default: 1.0).
            dtype: JAX dtype (default: complex64).
        """

        def _no_potential(r: Array, t: float) -> Array:
            # float32 zeros shaped like the first argument; time is unused.
            return jnp.zeros_like(r, dtype=jnp.float32)

        super().__init__(
            _no_potential,
            grid_shape,
            dx,
            mass=mass,
            hbar=hbar,
            dtype=dtype,
        )
class HarmonicOscillator(ParticleInPotential):
    r"""Quantum harmonic oscillator Hamiltonian.

    .. math::

        \hat{H} = -\frac{\hbar^2}{2m}\nabla^2 + \frac{1}{2}m\omega^2r^2

    The potential is time-independent and is evaluated from
    ``self.grid_coords``. The ``(r, t)`` arguments it receives are ignored.
    """

    def __init__(
        self,
        omega: float,
        grid_shape: tuple,
        dx: float,
        grid_coords: Array,
        mass: float = 1.0,
        hbar: float = 1.0,
        dtype: jnp.dtype = jnp.complex64,
    ):
        """Initialize harmonic oscillator Hamiltonian.

        Args:
            omega: Angular frequency (a.u.).
            grid_shape: Shape of spatial grid.
            dx: Grid spacing (a.u.).
            grid_coords: Spatial coordinates. Use a single array ``x`` for 1D,
                or a tuple/list of per-axis coordinate arrays such as
                ``(X, Y)`` for 2D or ``(X, Y, Z)`` for 3D.
            mass: Particle mass (a.u., default: 1.0 = electron mass).
            hbar: Reduced Planck constant (a.u., default: 1.0).
            dtype: JAX dtype (default: complex64).
        """
        self.omega = omega
        self.grid_coords = grid_coords
        # Constant prefactor m*omega^2/2, evaluated in the same order as the
        # original inline expression ``0.5 * mass * omega**2``.
        prefactor = 0.5 * mass * omega**2

        def harmonic_potential(r: Array, t: float) -> Array:
            # Coordinates are read through ``self`` on every call, so a later
            # reassignment of ``self.grid_coords`` is honored.
            coords = self.grid_coords
            if isinstance(coords, (tuple, list)):
                # One coordinate array per axis: r^2 = x^2 + y^2 (+ z^2).
                r_squared = sum(coord**2 for coord in coords)
            else:
                r_squared = coords**2
            return prefactor * r_squared

        super().__init__(
            potential=harmonic_potential,
            grid_shape=grid_shape,
            dx=dx,
            mass=mass,
            hbar=hbar,
            dtype=dtype,
        )