Source code for vibeqc.bands

"""Band structures, k-paths, and density-of-states for periodic systems.

Two related products:

* **Band structure.** Eigenvalues of a one-electron Fock matrix sampled
  along a path of k-points (often through high-symmetry points of the
  Brillouin zone). Plot the resulting :class:`BandStructure` to see
  the dispersion of each band.
* **Density of states (DOS).** Eigenvalues collected over a dense
  Monkhorst–Pack k-mesh, broadened with a Gaussian, projected onto an
  energy grid. Plot the resulting :class:`DensityOfStates` to see how
  the electronic states distribute in energy — band gaps and van Hove
  singularities are obvious by inspection.

Both objects are pure-Python dataclasses so callers can post-process,
serialise, or hand-render. The matching matplotlib plotters live in
:mod:`vibeqc.plot`.

Workflow
--------

The current public entry points take a real-space Fock and overlap
lattice set (``LatticeMatrixSet``) and sample at user-supplied
k-points. For the *non-interacting* (Hcore) limit this is straightforward:
build ``Hcore_lattice = T + V`` from
:func:`vibeqc.compute_kinetic_lattice` + :func:`vibeqc.compute_nuclear_lattice`
and pass it in. Convenience wrappers
:func:`band_structure_hcore` and :func:`density_of_states_hcore`
do this for you.

Bands and DOS computed from a converged interacting Fock matrix require
the user to supply ``F_real`` from their own SCF, since vibe-qc does not
yet persist the real-space converged Fock on the
:class:`PeriodicRHFResult` / :class:`PeriodicKSResult` objects.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple

import numpy as np

from ._vibeqc_core import (
    BasisSet,
    BlochKMesh,
    LatticeMatrixSet,
    LatticeSumOptions,
    PeriodicSystem,
    bloch_sum,
    compute_kinetic_lattice,
    compute_nuclear_lattice,
    compute_overlap_lattice,
    diagonalize_bloch,
    monkhorst_pack,
)


__all__ = [
    "BandStructure",
    "DensityOfStates",
    "KPath",
    "kpath_from_segments",
    "band_structure",
    "band_structure_hcore",
    "density_of_states",
    "density_of_states_hcore",
]


# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------

@dataclass
class KPath:
    """Discretised path through reciprocal space.

    Bundles the sampled k-points (Cartesian and fractional), the
    cumulative arc length along the path (the x-axis when plotting) and
    the labelled high-symmetry points (the x-axis tick marks).
    """

    # Sampled k-points in Cartesian coordinates, shape (N, 3), bohr⁻¹.
    kpoints_cart: np.ndarray
    # The same points in reciprocal-lattice fractional coordinates, (N, 3).
    kpoints_frac: np.ndarray
    # Cumulative |Δk_cart| along the path, shape (N,).
    distances: np.ndarray
    # (distance, label) pairs used as tick marks when plotting.
    labels: List[Tuple[float, str]]

    @property
    def n_points(self) -> int:
        """Number of k-points on the path (rows of ``kpoints_cart``)."""
        n_rows = self.kpoints_cart.shape[0]
        return int(n_rows)
@dataclass
class BandStructure:
    """Band energies sampled at every k-point of a path.

    ``energies`` has shape ``(n_points, n_bands)``; at each k the bands
    are sorted in ascending energy. ``e_fermi`` is the reference energy:
    the HOMO eigenvalue over all sampled k-points when
    ``n_electrons_per_cell`` was supplied, otherwise ``None``.
    """

    # Path the eigenvalues were sampled on.
    kpath: KPath
    # Band energies, shape (n_points, n_bands), Hartree.
    energies: np.ndarray
    # HOMO of the sampled bands, Hartree (None when not determined).
    e_fermi: Optional[float] = None
    # Electron count per unit cell used to locate the HOMO, if given.
    n_electrons_per_cell: Optional[int] = None

    @property
    def n_bands(self) -> int:
        """Number of bands (columns of ``energies``)."""
        n_cols = self.energies.shape[1]
        return int(n_cols)

    def shifted_energies(self) -> np.ndarray:
        """Energies relative to ``e_fermi`` (or zero if not set)."""
        reference = 0.0 if self.e_fermi is None else self.e_fermi
        return self.energies - reference
@dataclass
class DensityOfStates:
    """Total density of states on an energy grid.

    ``energies`` is the energy grid (Hartree). ``dos`` is the
    Gaussian-broadened state count per energy unit per unit cell:

        dos(ε) = Σ_{k, n} w_k · g(ε − ε_n(k))

    where the k-point weights ``w_k`` sum to one and ``g`` is a
    unit-area Gaussian of width ``sigma``. Integrated over all energy,
    ``dos`` therefore equals ``n_bands`` (every state contributes one).
    """

    # Energy grid, shape (n_e,), Hartree.
    energies: np.ndarray
    # DOS values on the grid, shape (n_e,), states / Hartree / cell.
    dos: np.ndarray
    # Gaussian broadening width, Hartree.
    sigma: float
    # Reference (HOMO) energy in Hartree, or None when not determined.
    e_fermi: Optional[float] = None

    def shifted_energies(self) -> np.ndarray:
        """Energy grid relative to ``e_fermi`` (or zero if not set)."""
        reference = 0.0 if self.e_fermi is None else self.e_fermi
        return self.energies - reference
# --------------------------------------------------------------------------- # k-path construction # ---------------------------------------------------------------------------
def kpath_from_segments(
    system: PeriodicSystem,
    segments: Sequence[Tuple[Sequence[float], str, Sequence[float], str]],
    *,
    points_per_segment: int = 30,
) -> KPath:
    """Stitch a piecewise-linear k-path through high-symmetry points.

    ``segments`` is a list of
    ``(start_frac, start_label, end_frac, end_label)`` tuples in
    reciprocal-lattice fractional coordinates.

    * When a segment starts where the previous one ended (e.g. the legs
      of Γ → X → M → Γ), the shared point is emitted only once.
    * When a segment starts somewhere else (a discontinuous path such as
      Γ → X followed by K → Γ), its start point is sampled as well, the
      jump contributes *no* arc length, and the two labels at the break
      are merged into a single ``"X|K"`` tick.

    The returned :class:`KPath` carries the Cartesian k-points (bohr⁻¹),
    the matching fractional coordinates, the cumulative arc length along
    the path (used as the x-coordinate when plotting), and the
    ``(distance, label)`` tick marks for the high-symmetry points.

    ``points_per_segment`` is the number of intervals each leg is divided
    into, so a leg contributes ``points_per_segment + 1`` points before
    shared endpoints are removed.

    Raises
    ------
    ValueError
        If ``segments`` is empty or ``points_per_segment`` is less than 1.
    """
    if not segments:
        raise ValueError("kpath_from_segments: at least one segment required")
    if points_per_segment < 1:
        raise ValueError(
            "kpath_from_segments: points_per_segment must be >= 1"
        )

    B = np.asarray(system.reciprocal_lattice())  # 3 × 3, columns = b_1, b_2, b_3

    # Fractional-coordinate tolerance for deciding that a segment starts
    # exactly where the previous one ended.
    join_tol = 1e-9

    all_frac: List[np.ndarray] = []
    distances: List[float] = []
    raw_labels: List[Tuple[float, str]] = []
    last_cart: Optional[np.ndarray] = None
    prev_end: Optional[np.ndarray] = None
    cumulative = 0.0

    for k_a, lbl_a, k_b, lbl_b in segments:
        start = np.asarray(k_a, dtype=float).reshape(3)
        end = np.asarray(k_b, dtype=float).reshape(3)

        continuous = prev_end is not None and bool(
            np.allclose(start, prev_end, rtol=0.0, atol=join_tol)
        )

        # The start tick sits at the current arc length: either the very
        # beginning of the path, a shared endpoint (dropped again by the
        # coalescing pass below when the label repeats), or a break (merged
        # with the previous end label into "A|B").
        raw_labels.append((cumulative, lbl_a))

        # Skip the duplicate point only when the segments really join.
        i_start = 1 if continuous else 0
        for i in range(i_start, points_per_segment + 1):
            t = i / points_per_segment
            kf = (1.0 - t) * start + t * end
            kc = B @ kf
            # i == 0 is the first point of the path or the far side of a
            # break; neither adds arc length.
            if i > 0:
                cumulative += float(np.linalg.norm(kc - last_cart))
            all_frac.append(kf)
            distances.append(cumulative)
            last_cart = kc

        raw_labels.append((cumulative, lbl_b))
        prev_end = end

    frac_arr = np.asarray(all_frac, dtype=float)
    cart_arr = (B @ frac_arr.T).T
    dist_arr = np.asarray(distances, dtype=float)

    # Coalesce labels that sit at the same arc length: identical labels
    # collapse to one tick, different labels are joined with "|" — a common
    # convention in band-structure plots.
    ticks: List[Tuple[float, str]] = []
    for d, lbl in raw_labels:
        if ticks and abs(ticks[-1][0] - d) < 1e-12 and ticks[-1][1] != lbl:
            ticks[-1] = (d, f"{ticks[-1][1]}|{lbl}")
        elif not ticks or ticks[-1] != (d, lbl):
            ticks.append((d, lbl))

    return KPath(
        kpoints_cart=cart_arr,
        kpoints_frac=frac_arr,
        distances=dist_arr,
        labels=ticks,
    )
# ---------------------------------------------------------------------------
# Band structure — sample a Fock matrix at every k along a path
# ---------------------------------------------------------------------------


def _eigenvalues_at_kpoints(
    F_terms: Sequence[LatticeMatrixSet],
    S_real: LatticeMatrixSet,
    kpoints_cart: np.ndarray,
) -> np.ndarray:
    """Band energies at each k-point in ``kpoints_cart``.

    For every k the terms of ``F_terms`` are Bloch-summed and added up to
    form ``F(k)``, ``S_real`` is Bloch-summed to ``S(k)``, and the pair is
    diagonalised. The result has shape ``(n_kpoints, nbf)``.

    ``F_terms`` is a sequence (rather than one ``LatticeMatrixSet``) so
    that ``Hcore = T + V`` never has to be materialised as its own lattice
    matrix set — Bloch summation is linear, so the per-term sums can
    simply be added.
    """

    def hermitian_part(matrix):
        # Average with the conjugate transpose to remove tiny
        # non-Hermitian drift before diagonalising.
        return 0.5 * (matrix + matrix.conj().T)

    band_energies = np.empty((kpoints_cart.shape[0], S_real.nbf), dtype=float)

    for row, kvec in enumerate(kpoints_cart):
        fock_k = bloch_sum(F_terms[0], kvec)
        for extra_term in F_terms[1:]:
            fock_k = fock_k + bloch_sum(extra_term, kvec)
        overlap_k = bloch_sum(S_real, kvec)

        solved = diagonalize_bloch(
            hermitian_part(fock_k), hermitian_part(overlap_k)
        )
        band_energies[row, :] = np.asarray(solved.energies)

    return band_energies
def band_structure(
    F_real: LatticeMatrixSet,
    S_real: LatticeMatrixSet,
    kpath: KPath,
    *,
    n_electrons_per_cell: Optional[int] = None,
) -> BandStructure:
    """Sample a real-space Fock matrix along a k-path.

    The Fock and overlap come from a converged (or non-interacting) SCF
    in real-space-lattice form. For non-interacting bands use
    :func:`band_structure_hcore`.

    If ``n_electrons_per_cell`` is supplied, the highest occupied
    eigenvalue across all sampled k-points is recorded as ``e_fermi`` so
    a plotter can shift the reference. For closed-shell systems
    ``n_electrons_per_cell // 2`` bands are occupied at each k.
    """
    energies = _eigenvalues_at_kpoints([F_real], S_real, kpath.kpoints_cart)

    # Open-shell band-structure interpretation is more involved (alpha vs
    # beta channels), so e_fermi is only set for an even electron count;
    # the energies are still returned so the user can decide what to do.
    n_occ = 0
    if n_electrons_per_cell is not None and n_electrons_per_cell % 2 == 0:
        n_occ = n_electrons_per_cell // 2

    e_fermi: Optional[float] = None
    if n_occ > 0:
        e_fermi = float(energies[:, :n_occ].max())

    return BandStructure(
        kpath=kpath,
        energies=energies,
        e_fermi=e_fermi,
        n_electrons_per_cell=n_electrons_per_cell,
    )
def band_structure_hcore(
    system: PeriodicSystem,
    basis: BasisSet,
    kpath: KPath,
    *,
    lattice_opts: Optional[LatticeSumOptions] = None,
    n_electrons_per_cell: Optional[int] = None,
) -> BandStructure:
    """Non-interacting (Hcore) band structure.

    Returns the eigenvalues of ``T + V`` at every k-point of ``kpath``.
    Useful for system-shape sanity checks before investing in a full SCF.

    ``lattice_opts`` defaults to a fresh ``LatticeSumOptions()``. When an
    even ``n_electrons_per_cell`` is given, ``e_fermi`` is the maximum of
    the lowest ``n_electrons_per_cell // 2`` bands over all sampled
    k-points; otherwise it is ``None``.
    """
    opts = LatticeSumOptions() if lattice_opts is None else lattice_opts

    S = compute_overlap_lattice(basis, system, opts)
    T = compute_kinetic_lattice(basis, system, opts)
    V = compute_nuclear_lattice(basis, system, opts)

    # T and V are passed as separate terms; their Bloch sums are added
    # per k-point instead of building Hcore as a lattice matrix set.
    energies = _eigenvalues_at_kpoints([T, V], S, kpath.kpoints_cart)

    closed_shell = (
        n_electrons_per_cell is not None and n_electrons_per_cell % 2 == 0
    )
    n_occ = n_electrons_per_cell // 2 if closed_shell else 0
    e_fermi: Optional[float] = (
        float(energies[:, :n_occ].max()) if n_occ > 0 else None
    )

    return BandStructure(
        kpath=kpath,
        energies=energies,
        e_fermi=e_fermi,
        n_electrons_per_cell=n_electrons_per_cell,
    )
# --------------------------------------------------------------------------- # Density of states # --------------------------------------------------------------------------- def _gaussian_dos( energies_per_k: np.ndarray, # (n_k, n_bands) weights: np.ndarray, # (n_k,) k-point weights, sum = 1 energy_grid: np.ndarray, # (n_e,) sigma: float, ) -> np.ndarray: """Sum of unit-area Gaussians centred on every (k, band) eigenvalue, weighted by ``weights[k]``.""" n_e = energy_grid.size dos = np.zeros(n_e, dtype=float) inv_2sigma2 = 0.5 / (sigma * sigma) norm = 1.0 / (sigma * np.sqrt(2.0 * np.pi)) for k, w in enumerate(weights): for eps in energies_per_k[k]: dos += w * norm * np.exp(-((energy_grid - eps) ** 2) * inv_2sigma2) return dos
def density_of_states(
    F_real: LatticeMatrixSet,
    S_real: LatticeMatrixSet,
    kmesh: BlochKMesh,
    *,
    sigma: float = 0.01,
    energy_grid: Optional[np.ndarray] = None,
    n_grid: int = 401,
    pad: float = 5.0,
    n_electrons_per_cell: Optional[int] = None,
) -> DensityOfStates:
    """Total DOS computed by Gaussian-broadening every eigenvalue of
    ``F(k)`` over ``kmesh`` onto an energy grid.

    Parameters
    ----------
    F_real, S_real
        Real-space Fock and overlap lattice matrix sets.
    kmesh
        k-point mesh; its ``kpoints`` and ``weights`` are used. Weights
        that do not sum to one are renormalised.
    sigma
        Gaussian width in Hartree; must be strictly positive.
    energy_grid
        Optional explicit energy grid (any one-dimensional array-like).
        Defaults to a uniform ``n_grid``-point grid spanning the
        eigenvalue range extended by ``pad·sigma`` on either side so the
        broadened tails fit.
    n_grid, pad
        Size and padding of the default grid; ignored when
        ``energy_grid`` is given.
    n_electrons_per_cell
        When even and positive, ``e_fermi`` is set to the maximum of the
        lowest ``n_electrons_per_cell // 2`` bands over the mesh.

    Raises
    ------
    ValueError
        If ``sigma`` is not strictly positive.
    """
    # Fail before the diagonalisation rather than producing an inf/NaN or
    # negative DOS afterwards.
    if not sigma > 0.0:
        raise ValueError("density_of_states: sigma must be positive")

    kpoints = np.asarray([np.asarray(k) for k in kmesh.kpoints])
    weights = np.asarray(kmesh.weights, dtype=float)
    if abs(weights.sum() - 1.0) > 1e-8:
        # Defensive: make this work even with a mesh whose weights weren't
        # normalised by the caller.
        weights = weights / weights.sum()

    energies_per_k = _eigenvalues_at_kpoints([F_real], S_real, kpoints)

    if energy_grid is None:
        e_min = energies_per_k.min() - pad * sigma
        e_max = energies_per_k.max() + pad * sigma
        energy_grid = np.linspace(e_min, e_max, n_grid)
    else:
        # Accept lists/tuples as well as arrays; a float ndarray passes
        # through unchanged.
        energy_grid = np.asarray(energy_grid, dtype=float)

    dos = _gaussian_dos(energies_per_k, weights, energy_grid, sigma)

    e_fermi: Optional[float] = None
    if n_electrons_per_cell is not None and n_electrons_per_cell % 2 == 0:
        n_occ = n_electrons_per_cell // 2
        if n_occ > 0:
            e_fermi = float(energies_per_k[:, :n_occ].max())

    return DensityOfStates(
        energies=energy_grid,
        dos=dos,
        sigma=sigma,
        e_fermi=e_fermi,
    )
def density_of_states_hcore(
    system: PeriodicSystem,
    basis: BasisSet,
    mesh: Sequence[int],
    *,
    sigma: float = 0.01,
    n_grid: int = 401,
    pad: float = 5.0,
    lattice_opts: Optional[LatticeSumOptions] = None,
    n_electrons_per_cell: Optional[int] = None,
) -> DensityOfStates:
    """Non-interacting DOS for a system on a Monkhorst–Pack mesh.

    Convenience for quick band-structure overviews: the eigenvalues of
    ``T + V`` on the mesh are Gaussian-broadened onto a uniform
    ``n_grid``-point energy grid spanning the eigenvalue range extended
    by ``pad·sigma`` on either side.

    Parameters
    ----------
    system, basis
        Periodic system and basis set the lattice integrals are built for.
    mesh
        Mesh dimensions handed to ``monkhorst_pack``.
    sigma
        Gaussian width in Hartree; must be strictly positive.
    n_grid, pad
        Size and padding of the energy grid.
    lattice_opts
        Lattice-sum options; defaults to a fresh ``LatticeSumOptions()``.
    n_electrons_per_cell
        When even and positive, ``e_fermi`` is set to the maximum of the
        lowest ``n_electrons_per_cell // 2`` bands over the mesh.

    Raises
    ------
    ValueError
        If ``sigma`` is not strictly positive.
    """
    # Fail before the lattice integrals and diagonalisations rather than
    # producing an inf/NaN or negative DOS afterwards.
    if not sigma > 0.0:
        raise ValueError("density_of_states_hcore: sigma must be positive")

    opts = lattice_opts if lattice_opts is not None else LatticeSumOptions()
    S = compute_overlap_lattice(basis, system, opts)
    T = compute_kinetic_lattice(basis, system, opts)
    V = compute_nuclear_lattice(basis, system, opts)

    km = monkhorst_pack(system, list(mesh))
    kpoints = np.asarray([np.asarray(k) for k in km.kpoints])
    weights = np.asarray(km.weights, dtype=float)
    if abs(weights.sum() - 1.0) > 1e-8:
        # Renormalise so the k-point weights sum to one.
        weights = weights / weights.sum()

    # T and V are passed as separate terms; their Bloch sums are added
    # per k-point instead of building Hcore as a lattice matrix set.
    energies_per_k = _eigenvalues_at_kpoints([T, V], S, kpoints)

    e_min = energies_per_k.min() - pad * sigma
    e_max = energies_per_k.max() + pad * sigma
    energy_grid = np.linspace(e_min, e_max, n_grid)
    dos = _gaussian_dos(energies_per_k, weights, energy_grid, sigma)

    e_fermi: Optional[float] = None
    if n_electrons_per_cell is not None and n_electrons_per_cell % 2 == 0:
        n_occ = n_electrons_per_cell // 2
        if n_occ > 0:
            e_fermi = float(energies_per_k[:, :n_occ].max())

    return DensityOfStates(
        energies=energy_grid,
        dos=dos,
        sigma=sigma,
        e_fermi=e_fermi,
    )