"""Native multi-k periodic RHF + RKS on arbitrary Bravais lattices.
This module implements closed-shell multi-k SCF on arbitrary Bravais
lattices and Monkhorst-Pack k-meshes. The per-iteration Fock build
is delegated to :func:`vibeqc.build_periodic_fock_ewald3d_k`
(EWALD_3D gauge: real-space short-range J/K via
``build_fock_2e_ewald3d_blocks`` plus FFT-Poisson long-range J via
``build_j_long_range_periodic``), then Bloch-summed to each k-point.
The SCF loop reuses the multi-k DIIS machinery from the legacy
Ewald driver.
Note on the GDF aux basis: ``aux_basis`` / ``aux_drop_eta`` are
accepted for forward compatibility with a future Lpq(q)-based path,
but the current per-iteration Fock build does **not** materialise a
``Lpq(q)`` tensor — work routes through the EWALD_3D real-space
blocks instead. See commit c74527b for the gauge-consistency reason
the driver moved off Lpq(q) for multi-k. The module's own Γ-only
fast-path (kmesh = (1,1,1)) still delegates to
:func:`vibeqc.run_rhf_periodic_gamma_gdf`, which does use Lpq.
Convention notes
----------------
* No ``exxdiv='ewald'`` Madelung correction is applied by default, and
neither entry point in this module exposes an ``exxdiv`` keyword. The
Γ fast path inherits :func:`vibeqc.run_rhf_periodic_gamma_gdf`'s
default (``exxdiv=None``).
* Generic Bravais: the driver uses arbitrary lattice vectors via
the existing ``bloch_sum`` and Bloch-block helpers, with no
cubic-only assumption. 1D / 2D / 3D periodicity all supported.
* Closed-shell RHF and RKS only in this module. UHF / UKS multi-k
is a separate module (pending).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from ._vibeqc_core import (
BasisSet,
BlochKMesh,
Functional,
GridOptions,
InitialGuess,
LatticeSumOptions,
PeriodicKSOptions,
PeriodicRHFOptions,
PeriodicSystem,
SCFIteration,
bloch_sum,
build_grid,
build_xc_periodic,
compute_kinetic_lattice,
compute_overlap_lattice,
direct_lattice_cells,
nuclear_repulsion_per_cell,
)
from ._vibeqc_core import (
direct_lattice_cells as _direct_cells,
)
from ._vibeqc_core import (
monkhorst_pack as _mp_native,
)
from .aux_basis import (
build_lpq_bloch_native,
default_aux_for,
make_aux_basis_set,
)
from .kpoints import KPoints
from .linear_dependence import scf_preflight_overlap_check
from .periodic_k_density import (
density_matrices_per_k as _density_from_orbitals,
real_space_density_from_per_k_density as _real_space_density_from_per_k_density,
)
from .smearing import (
closed_shell_periodic_occupations as _closed_shell_periodic_occupations,
)
from .smearing import (
hartree_to_kelvin_temperature as _hartree_to_kelvin_temperature,
)
from .options_dump import dump_active_settings
from .periodic_fock_multi_k import build_periodic_fock_ewald3d_k
from .periodic_grid import build_periodic_becke_grid
from .periodic_rhf_gdf import (
PeriodicRHFGDFResult,
_resolve_fock_mixing,
_resolve_level_shift_warmup_cycles,
run_rhf_periodic_gamma_gdf,
)
from .periodic_rhf_multi_k_ewald import (
_canonical_orthogonalizer_complex,
_diag_in_orth_basis,
_g0_block,
_MultiKPulayDIIS,
_reject_unsupported_python_accelerator,
)
from .progress import ProgressLogger, resolve_progress
# Public API of this module: the two SCF entry points and their result types.
__all__ = [
    "PeriodicKRHFGDFResult",
    "PeriodicKRKSGDFResult",
    "run_krhf_periodic_gdf",
    "run_krks_periodic_gdf",
]
# =====================================================================
# Result types
# =====================================================================
[docs]
@dataclass
class PeriodicKRHFGDFResult:
    """Result of :func:`run_krhf_periodic_gdf`.

    All per-k arrays are length-``nkpts`` Python lists holding
    complex Hermitian (or real, for Γ-only) matrices in AO basis.
    Energies are in Hartree per unit cell.
    """

    # Total energy per cell (multi-k path: E_elec + E_xc + E_nuc).
    energy: float
    # Electronic energy without E_xc (multi-k path:
    # sum_k w_k (Tr[D Hcore] + 0.5 Tr[D F_2e])).
    e_electronic: float
    # Nuclear repulsion per cell.
    e_nuclear: float
    # Number of SCF iterations performed.
    n_iter: int
    # Whether the energy / gradient convergence criteria were met.
    converged: bool
    # Per-k orbital eigenvalues.
    mo_energies: List[np.ndarray]
    # Per-k MO coefficient matrices in the AO basis.
    mo_coeffs: List[np.ndarray]
    # Per-k Fock / KS matrices (multi-k path: after DIIS and FMIXING,
    # without the level shift).
    fock: List[np.ndarray]
    # Per-k overlap matrices S(k).
    overlap: List[np.ndarray]
    # Per-k core Hamiltonians T(k) + V(k).
    hcore: List[np.ndarray]
    # Per-k AO density matrices.
    density: List[np.ndarray]
    # (n_k, 3) Cartesian k-points in bohr^-1.
    kpoints_cart: np.ndarray
    # (n_k,) k-point weights; validated to sum to 1.
    kpoint_weights: np.ndarray
    # Per-iteration SCF records.
    scf_trace: List[SCFIteration] = field(default_factory=list)
    # Functional name, or None for Hartree-Fock.
    functional: Optional[str] = None
    # Exchange-correlation energy (0.0 for Hartree-Fock).
    e_xc: float = 0.0
    # Coulomb (J) energy component.
    e_coulomb: float = 0.0
    # Exact-exchange energy component.
    e_hf_exchange: float = 0.0
    # CRYSTAL FMIXING fraction (weight of the previous Fock matrix).
    fock_mixing: float = 0.0
    # Level shift in Ha.
    level_shift: float = 0.0
    # Number of leading iterations that used the level shift.
    level_shift_warmup_cycles: int = 0
    # Fermi-Dirac k_B T in Ha; 0.0 means no smearing.
    smearing_temperature: float = 0.0
    fermi_level: float = 0.0
    entropy: float = 0.0
    # Multi-k path: energy - smearing_temperature * entropy.
    free_energy: float = 0.0
    # Per-k occupation vectors.
    occupations: List[np.ndarray] = field(default_factory=list)
    # Auxiliary basis name and size.
    aux_basis_name: str = ""
    n_aux: int = 0
    # Label of the code path that produced this result.
    backend: str = "native-multi-k-gdf"

    @property
    def energy_per_cell_ha(self) -> float:
        """Return :attr:`energy` as a plain ``float`` (Ha per cell)."""
        return float(self.energy)
[docs]
@dataclass
class PeriodicKRKSGDFResult(PeriodicKRHFGDFResult):
    """Result of :func:`run_krks_periodic_gdf`.

    Declares no fields of its own; it is field-for-field identical to
    :class:`PeriodicKRHFGDFResult`. The Γ fast path of
    :func:`run_krhf_periodic_gdf` returns this type when a
    ``functional`` is given.
    """
# =====================================================================
# Setup helpers
# =====================================================================
def _options_or_default(options, *, is_ks: bool):
if options is not None:
return options
return PeriodicKSOptions() if is_ks else PeriodicRHFOptions()
@dataclass(frozen=True)
class _GammaKMeshInfo:
    """Single-Γ k-mesh metadata for the native Γ fast path.

    Constructed by :func:`_gamma_kmesh_info` after it has checked that the
    mesh is one k-point at the origin with weight 1.
    """

    # Shape (1, 3): the Γ point in Cartesian coordinates.
    kpoints_cart: np.ndarray
    # Shape (1,): the single weight (1.0).
    weights: np.ndarray
    # Always 1 as constructed by _gamma_kmesh_info.
    input_n_kpoints: int
def _gamma_kmesh_info(
    system: PeriodicSystem,
    kmesh: Union[Sequence[int], KPoints, BlochKMesh],
) -> Optional[_GammaKMeshInfo]:
    """Detect a Γ-only ``kmesh`` and describe it; otherwise return ``None``.

    The KRHF/KRKS entry points accept multi-k-shaped inputs. When the input
    resolves to exactly one k-point at the origin (within 1e-12) carrying
    weight 1 (within 1e-12), the caller delegates to the native Γ-GDF
    driver. Every other input yields ``None`` and the caller runs the real
    multi-k SCF. A plain mesh tuple is validated by
    :func:`_mesh_tuple_for_system`, which may raise ``ValueError``.
    """
    if isinstance(kmesh, KPoints):
        kpoints = np.asarray(kmesh.kpoints_cart, dtype=np.float64).reshape(-1, 3)
        weights = np.asarray(kmesh.weights, dtype=np.float64).reshape(-1)
    elif isinstance(kmesh, BlochKMesh):
        kpoints = np.asarray(kmesh.kpoints, dtype=np.float64).reshape(-1, 3)
        weights = np.asarray(kmesh.weights, dtype=np.float64).reshape(-1)
    elif _mesh_tuple_for_system(system, kmesh) == (1, 1, 1):
        gamma_mesh = _mp_native(system, [1, 1, 1], [0, 0, 0], False)
        kpoints = np.asarray(gamma_mesh.kpoints, dtype=np.float64).reshape(-1, 3)
        weights = np.asarray(gamma_mesh.weights, dtype=np.float64).reshape(-1)
    else:
        # A mesh tuple other than (1, 1, 1) is never Γ-only.
        return None

    single_point = kpoints.shape == (1, 3) and weights.shape == (1,)
    if not single_point:
        return None
    at_origin = np.allclose(kpoints[0], 0.0, atol=1e-12, rtol=0.0)
    if not (
        at_origin
        and np.isclose(float(weights[0]), 1.0, atol=1e-12, rtol=0.0)
    ):
        return None
    return _GammaKMeshInfo(
        kpoints_cart=kpoints.copy(),
        weights=weights.copy(),
        input_n_kpoints=1,
    )
def _wrap_gamma_gdf_result(
gamma: PeriodicRHFGDFResult,
info: _GammaKMeshInfo,
*,
functional: Optional[str],
result_cls,
):
"""Adapt the native Γ-GDF result to the KRHF/KRKS result shape."""
return result_cls(
energy=float(gamma.energy),
e_electronic=float(gamma.e_electronic),
e_nuclear=float(gamma.e_nuclear),
n_iter=int(gamma.n_iter),
converged=bool(gamma.converged),
mo_energies=[np.asarray(gamma.mo_energies)],
mo_coeffs=[np.asarray(gamma.mo_coeffs)],
fock=[np.asarray(gamma.fock)],
overlap=[np.asarray(gamma.overlap)],
hcore=[np.asarray(gamma.hcore)],
density=[np.asarray(gamma.density)],
kpoints_cart=np.asarray(info.kpoints_cart, dtype=np.float64),
kpoint_weights=np.asarray(info.weights, dtype=np.float64),
scf_trace=list(gamma.scf_trace),
functional=functional or str(getattr(gamma, "functional", "") or "") or None,
e_xc=float(getattr(gamma, "e_xc", 0.0)),
e_coulomb=float(getattr(gamma, "e_coulomb", 0.0)),
e_hf_exchange=float(getattr(gamma, "e_hf_exchange", 0.0)),
fock_mixing=float(getattr(gamma, "fock_mixing", 0.0)),
level_shift=float(getattr(gamma, "level_shift", 0.0)),
level_shift_warmup_cycles=int(getattr(gamma, "level_shift_warmup_cycles", 0)),
smearing_temperature=float(getattr(gamma, "smearing_temperature", 0.0)),
fermi_level=float(getattr(gamma, "fermi_level", 0.0)),
entropy=float(getattr(gamma, "entropy", 0.0)),
free_energy=float(getattr(gamma, "free_energy", gamma.energy)),
occupations=[np.asarray(getattr(gamma, "occupations", np.empty(0)))],
aux_basis_name=str(getattr(gamma, "aux_basis_name", "") or ""),
n_aux=int(getattr(gamma, "n_aux", 0)),
backend="native-gamma-gdf-via-k-gdf",
)
def _mesh_tuple_for_system(
system: PeriodicSystem,
mesh: Sequence[int],
) -> Tuple[int, int, int]:
dim = int(system.dim)
if dim not in (1, 2, 3):
raise ValueError(f"PeriodicSystem.dim must be 1, 2, or 3; got {dim}")
arr = list(mesh)
if len(arr) == dim:
arr = arr + [1] * (3 - dim)
elif len(arr) != 3:
raise ValueError(
f"periodic GDF: kmesh tuple must have length {dim} for "
f"dim={dim} systems or length 3; got {arr!r}"
)
out = tuple(int(x) for x in arr)
if any(x < 1 for x in out):
raise ValueError(f"periodic GDF: kmesh entries must be >= 1; got {arr!r}")
return tuple(out[i] if i < dim else 1 for i in range(3))
def _kmesh_to_kpoints_weights(
    system: PeriodicSystem,
    kmesh: Union[Sequence[int], KPoints, BlochKMesh],
) -> Tuple[np.ndarray, np.ndarray]:
    """Normalise ``kmesh`` to ``(kpoints_cart, weights)`` arrays.

    ``kpoints_cart`` is shape ``(n_k, 3)`` in bohr⁻¹.
    ``weights`` is shape ``(n_k,)`` summing to 1.

    A plain mesh tuple is expanded with :func:`_mesh_tuple_for_system`
    and turned into an unshifted Monkhorst-Pack mesh.

    Raises
    ------
    ValueError
        If the mesh has no k-points, if the number of weights differs
        from the number of k-points, or if the weights do not sum to 1.
    """
    if isinstance(kmesh, KPoints):
        kpts = np.asarray(kmesh.kpoints_cart, dtype=np.float64).reshape(-1, 3)
        w = np.asarray(kmesh.weights, dtype=np.float64).reshape(-1)
    elif isinstance(kmesh, BlochKMesh):
        kpts = np.asarray(kmesh.kpoints, dtype=np.float64).reshape(-1, 3)
        w = np.asarray(kmesh.weights, dtype=np.float64).reshape(-1)
    else:
        mesh = _mesh_tuple_for_system(system, kmesh)
        bm = _mp_native(system, list(mesh), [0, 0, 0], False)
        kpts = np.asarray(bm.kpoints, dtype=np.float64).reshape(-1, 3)
        w = np.asarray(bm.weights, dtype=np.float64).reshape(-1)
    n_k = kpts.shape[0]
    if n_k == 0:
        raise ValueError("periodic k-GDF: kmesh has zero k-points")
    # The SCF loop indexes weights[i] for i < n_k. A longer weight array
    # could still sum to 1 and would silently misweight every energy, so
    # reject any count mismatch up front.
    if w.shape[0] != n_k:
        raise ValueError(
            f"periodic k-GDF: got {n_k} k-points but {w.shape[0]} weights"
        )
    # NOTE: np.isclose keeps its default rtol=1e-5 here, so the effective
    # tolerance on the weight sum is ~1e-5, not the 1e-9 atol alone.
    if not np.isclose(float(w.sum()), 1.0, atol=1e-9):
        raise ValueError(
            f"periodic k-GDF: kpoint weights must sum to 1; got {float(w.sum()):.6f}"
        )
    return kpts, w
# =====================================================================
def _occupations_per_k(
    eps_per_k: Sequence[np.ndarray],
    weights: np.ndarray,
    n_elec_per_cell: int,
    smearing_T: float,
    n_occ_each: int,
) -> Tuple[List[np.ndarray], float, float]:
    """Fermi-Dirac or hard-Aufbau occupations across the k-mesh.

    Thin adapter around ``closed_shell_periodic_occupations``. It coerces
    the scalar arguments and reorders them into that helper's positional
    order: ``(eps, weights, n_electrons, n_occ, kT)``.

    Returns ``(occ_per_k, fermi_level, entropy_per_cell)``.
    """
    n_electrons = float(n_elec_per_cell)
    n_occupied = int(n_occ_each)
    kbt = float(smearing_T)
    return _closed_shell_periodic_occupations(
        eps_per_k, weights, n_electrons, n_occupied, kbt
    )
[docs]
def run_krhf_periodic_gdf(
    system: PeriodicSystem,
    basis: BasisSet,
    kmesh: Union[Sequence[int], KPoints, BlochKMesh] = (1, 1, 1),
    options: Optional[Union[PeriodicRHFOptions, PeriodicKSOptions]] = None,
    *,
    functional: Optional[str] = None,
    aux_basis: Optional[str] = None,
    aux_drop_eta: float = 0.0,
    linear_dep_threshold: float = 1e-7,
    gdf_linear_dep_threshold: float = 1e-9,
    apply_modrho: bool = True,
    fock_mixing: Optional[float] = None,
    level_shift_warmup_cycles: Optional[int] = None,
    use_compcell: bool = False,
    compcell_eta: float = 0.2,
    apply_aft_correction: bool = False,
    aft_ft_convention: str = "libcint",
    aft_precision: float = 1e-10,
    rcut_strategy: Optional[object] = None,
    rcut_precision: float = 1e-8,
    progress: Union[bool, ProgressLogger, None] = None,
    verbose: Optional[int] = None,
) -> PeriodicKRHFGDFResult:
    """Run closed-shell periodic HF / KS multi-k SCF via native GDF.

    For a single Γ point this delegates to
    :func:`vibeqc.run_rhf_periodic_gamma_gdf` (kept in lock-step with
    the multi-k path). For any other ``kmesh`` it runs the full
    multi-k loop here. In that loop J/K come from
    :func:`build_periodic_fock_ewald3d_k` (EWALD_3D gauge) and are
    Bloch-summed to every k-point.

    Parameters
    ----------
    system, basis
        Periodic system and AO basis.
    kmesh
        ``(n1, n2, n3)`` Monkhorst-Pack mesh (a length-``dim`` tuple is
        also accepted), a :class:`KPoints` instance, or a
        :class:`BlochKMesh`. Defaults to Γ-only.
    options
        :class:`PeriodicRHFOptions` (HF) or :class:`PeriodicKSOptions`
        (KS).
    functional
        libxc functional name when running KS; ``None`` means HF. On
        the multi-k path ``options.functional`` is used as a fallback.
    aux_basis
        Auxiliary basis name. Defaults to ``default_aux_for(basis.name)``.
        On the multi-k path the aux basis is built only to report
        ``n_aux``; the Fock build does not use it.
    aux_drop_eta
        Auxiliary primitive cull threshold passed to
        :func:`make_aux_basis_set`.
    linear_dep_threshold
        Per-k overlap eigenvalue floor for canonical orthogonalisation.
    gdf_linear_dep_threshold
        Auxiliary metric eigenvalue floor, forwarded to the Γ-only
        driver. On the multi-k path it is only logged.
    apply_modrho
        Auxiliary-basis renormalisation switch (default on), forwarded
        to the Γ-only driver. On the multi-k path it is only logged.
    fock_mixing
        Override the resolver-resolved CRYSTAL FMIXING fraction.
    level_shift_warmup_cycles
        Override the resolver-resolved level-shift warm-up length.
    use_compcell
        Not supported by either path; a warning is logged when set.
    compcell_eta, apply_aft_correction, aft_ft_convention, aft_precision, rcut_strategy, rcut_precision
        Accepted for API compatibility; currently unused.
    progress, verbose
        Live progress logging passthrough.

    Returns
    -------
    PeriodicKRHFGDFResult
        A :class:`PeriodicKRKSGDFResult` (subclass) for KS runs. The Γ
        path uses it when ``functional`` is given. The multi-k path uses
        it when a functional is resolved from ``functional`` or
        ``options.functional``. A result is returned even when the SCF
        does not converge within ``options.max_iter``; check
        ``converged``.

    Raises
    ------
    ValueError
        Odd electron count, multiplicity != 1, negative smearing
        temperature, damping outside ``[0, 1)``, or an invalid ``kmesh``.
    RuntimeError
        If canonical orthogonalisation at some k keeps fewer directions
        than there are occupied orbitals.
    """
    opts = _options_or_default(options, is_ks=functional is not None)
    plog = resolve_progress(progress, verbose=verbose)
    # ---------------- Γ fast path ----------------------------------
    # The Γ fast path delegates to run_rhf_periodic_gamma_gdf. That driver
    # short-circuits J/K to Ewald-3D + molecular-limit-K on dim=3 and has
    # no compcell option. For Γ-only compcell SCF, users should call
    # ``vibeqc.run_pbc_gdf_rhf`` directly; it is the Γ-only driver that
    # uses Lpq for both J and K.
    # Falling through to the multi-k branch with n_k=1 when
    # use_compcell=True was tried. The multi-k loop's per-k weighting
    # does not degenerate cleanly to Γ-only (~5× wrong energy on H2).
    # So ``use_compcell=True`` at ``kmesh=(1,1,1)`` is ignored, and a
    # warning tells users to switch drivers.
    gamma_info = _gamma_kmesh_info(system, kmesh)
    if gamma_info is not None:
        if use_compcell:
            plog.info(
                " WARNING: use_compcell=True at kmesh=(1,1,1) is currently "
                "ignored (the Γ-fastpath delegates to run_rhf_periodic_gamma_gdf "
                "which doesn't support compcell). For Γ-only compcell SCF, "
                "call vibeqc.run_pbc_gdf_rhf(...) directly. Continuing with "
                "the legacy Ewald-3D + molecular-limit-K path."
            )
        gamma = run_rhf_periodic_gamma_gdf(
            system,
            basis,
            opts,
            functional=functional,
            aux_basis=aux_basis,
            aux_drop_eta=aux_drop_eta,
            linear_dep_threshold=linear_dep_threshold,
            gdf_linear_dep_threshold=gdf_linear_dep_threshold,
            apply_modrho=apply_modrho,
            fock_mixing=fock_mixing,
            level_shift_warmup_cycles=level_shift_warmup_cycles,
            progress=plog,
            verbose=verbose,
        )
        result_cls = (
            PeriodicKRKSGDFResult if functional is not None else PeriodicKRHFGDFResult
        )
        return _wrap_gamma_gdf_result(
            gamma,
            gamma_info,
            functional=functional,
            result_cls=result_cls,
        )
    # ---------------- Multi-k branch ------------------------------
    func_name = functional or str(getattr(opts, "functional", "") or "")
    is_ks = bool(func_name)
    func = Functional(func_name, 1) if is_ks else None
    alpha = float(func.hf_exchange_fraction) if func is not None else 1.0
    fock_mixing_value = _resolve_fock_mixing(opts, fock_mixing)
    level_shift = float(getattr(opts, "level_shift", 0.0))
    max_iter = int(opts.max_iter)
    warmup_cycles = _resolve_level_shift_warmup_cycles(
        opts,
        level_shift=level_shift,
        max_iter=max_iter,
        override=level_shift_warmup_cycles,
    )
    smearing_T = float(getattr(opts, "smearing_temperature", 0.0))
    if smearing_T < 0.0:
        raise ValueError("run_krhf_periodic_gdf: smearing_temperature must be >= 0")
    lat_opts: LatticeSumOptions = opts.lattice_opts
    label = f"KRKS {func_name}" if is_ks else "KRHF"
    n_elec = system.n_electrons()
    if n_elec % 2 != 0:
        raise ValueError(
            "run_krhf_periodic_gdf: closed-shell RHF/RKS requires "
            f"even electron count; got {n_elec}"
        )
    if system.multiplicity != 1:
        raise ValueError(
            "run_krhf_periodic_gdf: closed-shell RHF/RKS requires "
            f"multiplicity=1; got {system.multiplicity}"
        )
    n_occ = n_elec // 2
    kpoints_cart, weights = _kmesh_to_kpoints_weights(system, kmesh)
    n_k = kpoints_cart.shape[0]
    # Resolve kmesh to BlochKMesh + lattice cells for Ewald Fock builder.
    if isinstance(kmesh, BlochKMesh):
        kmesh_bloch = kmesh
    elif isinstance(kmesh, KPoints):
        kmesh_bloch = kmesh.to_bloch_kmesh()
    else:
        mesh = _mesh_tuple_for_system(system, kmesh)
        kmesh_bloch = _mp_native(system, list(mesh), [0, 0, 0], False)
    cells = _direct_cells(system, lat_opts.cutoff_bohr)
    aux_name = aux_basis or default_aux_for(basis.name)
    plog.banner(f"run_krhf_periodic_gdf {label} kmesh={n_k} k-points")
    if use_compcell:
        plog.info(
            " WARNING: use_compcell=True is currently ignored on the "
            "multi-k path (J/K come from the EWALD_3D Fock builder)."
        )
    plog.info(
        f"{label} multi-k native GDF / aux={aux_name}, "
        f"cutoff={lat_opts.cutoff_bohr:.2f} bohr"
    )
    plog.info(f"basis: {basis.name} ({basis.nbasis} BFs / {basis.nshells} shells)")
    dim = int(system.dim)
    active_lengths = [
        float(np.linalg.norm(np.asarray(system.lattice, dtype=float)[:, i]))
        for i in range(dim)
    ]
    plog.info(
        f"periodicity: dim={dim}D, active lengths="
        + ", ".join(f"{x:.3f}" for x in active_lengths)
        + " bohr"
    )
    # ``_direct_cells`` is ``direct_lattice_cells``; reuse the result
    # already computed at the same cutoff instead of enumerating again.
    n_int_cells = len(cells)
    n_nuc_cells = len(direct_lattice_cells(system, lat_opts.nuclear_cutoff_bohr))
    plog.info(
        "lattice cells: "
        f"one-electron/GDF cutoff -> {n_int_cells}, "
        f"nuclear cutoff -> {n_nuc_cells}"
    )
    dump_active_settings(
        plog,
        [
            ("PeriodicKSOptions" if is_ks else "PeriodicRHFOptions", opts),
            ("LatticeSumOptions", lat_opts),
            (
                "k-GDF kwargs",
                {
                    "functional": func_name or None,
                    "hf_exchange_fraction": alpha,
                    "fock_mixing": fock_mixing_value,
                    "fmixing_percent": 100.0 * fock_mixing_value,
                    "level_shift": level_shift,
                    "level_shift_warmup_cycles": warmup_cycles,
                    "smearing_temperature": smearing_T,
                    "aux_basis": aux_name,
                    "aux_drop_eta": float(aux_drop_eta),
                    "linear_dep_threshold": float(linear_dep_threshold),
                    "gdf_linear_dep_threshold": float(gdf_linear_dep_threshold),
                    "apply_modrho": bool(apply_modrho),
                    "n_kpoints": n_k,
                },
            ),
        ],
    )
    # ---- Functional + grid ----------------------------------------
    grid = None
    if is_ks:
        grid_options = getattr(opts, "grid", None)
        if grid_options is None:
            grid_options = GridOptions()
        if bool(getattr(opts, "use_periodic_becke", False)):
            grid = build_periodic_becke_grid(
                system,
                grid_options=grid_options,
                image_radius_bohr=float(getattr(opts, "becke_image_radius_bohr", 0.0)),
            )
        else:
            grid = build_grid(system.unit_cell_molecule(), grid_options)
    # ---- Real-space one-electron integrals ------------------------
    with plog.stage(
        "integrals_lattice",
        detail=f"S/T/V at cutoff {lat_opts.cutoff_bohr:.2f} bohr",
    ):
        S_lat = compute_overlap_lattice(basis, system, lat_opts)
        T_lat = compute_kinetic_lattice(basis, system, lat_opts)
        from .periodic_rhf_gdf import _gauge_lat_opts_for_v_ne_and_e_nuc
        from .periodic_v_ne import compute_nuclear_lattice_dispatch

        # Always use Ewald-3D gauge for V_ne and e_nuc — J/K are
        # built via build_periodic_fock_ewald3d_k (Ewald gauge).
        gauge_lat_opts = _gauge_lat_opts_for_v_ne_and_e_nuc(lat_opts, system)
        V_lat = compute_nuclear_lattice_dispatch(basis, system, gauge_lat_opts)
    # Lattice-matrix container used to hand the home-cell density to
    # build_xc_periodic. Built once: every block is overwritten before
    # each use, so rebuilding it per iteration was wasted lattice sums.
    xc_density_set = compute_overlap_lattice(basis, system, lat_opts) if is_ks else None
    n_xc_cells = len(xc_density_set) if xc_density_set is not None else 0
    # ---- Per-k S(k), Hcore(k), canonical orthog X(k) -------------
    S_k: List[np.ndarray] = []
    Hcore_k: List[np.ndarray] = []
    X_k: List[np.ndarray] = []
    n_kept_k: List[int] = []
    for k_idx in range(n_k):
        k_arr = kpoints_cart[k_idx]
        Sk = np.asarray(bloch_sum(S_lat, k_arr))
        Tk = np.asarray(bloch_sum(T_lat, k_arr))
        Vk = np.asarray(bloch_sum(V_lat, k_arr))
        # Enforce exact Hermiticity against round-off in the Bloch sums.
        Sk = 0.5 * (Sk + Sk.conj().T)
        Hk = 0.5 * ((Tk + Vk) + (Tk + Vk).conj().T)
        scf_preflight_overlap_check(
            Sk,
            plog=plog,
            label=f"S(k={k_idx}, k_cart={k_arr.round(4).tolist()})",
            basis=basis,
        )
        Xk, n_kept = _canonical_orthogonalizer_complex(
            Sk,
            linear_dep_threshold,
            normalize_diag_first=True,
        )
        if n_occ > n_kept:
            raise RuntimeError(
                "run_krhf_periodic_gdf: canonical orthogonalisation at "
                f"k = {k_arr} dropped too many directions "
                f"(n_occ={n_occ}, n_kept={n_kept}); loosen "
                "linear_dep_threshold or pick a less redundant basis."
            )
        S_k.append(Sk)
        Hcore_k.append(Hk)
        X_k.append(Xk)
        n_kept_k.append(n_kept)
    e_nuc = float(nuclear_repulsion_per_cell(system, gauge_lat_opts))
    # ---- Aux basis (forward-compat; Fock build itself uses EWALD_3D)
    mol = system.unit_cell_molecule()
    with plog.stage("aux_basis", detail=aux_name):
        aux = make_aux_basis_set(
            mol,
            aux_name=aux_name,
            drop_eta=float(aux_drop_eta),
        )
    plog.info(f"aux basis: {aux_name} ({aux.nbasis} BFs / {aux.nshells} shells)")
    # ---- Initial guess: Hcore diagonalisation per k --------------
    plog.info("initial guess: HCORE (per-k Hcore diagonalisation)")
    C_k: List[np.ndarray] = []
    eps_k: List[np.ndarray] = []
    for i in range(n_k):
        Ci, ei = _diag_in_orth_basis(Hcore_k[i], X_k[i])
        C_k.append(Ci.astype(complex))
        eps_k.append(ei)
    occ_k, fermi_level, entropy = _occupations_per_k(
        eps_k,
        weights,
        n_elec,
        smearing_T,
        n_occ,
    )
    D_k = _density_from_orbitals(C_k, occ_k)
    if smearing_T > 0.0:
        plog.info(
            "smearing: Fermi-Dirac kBT = "
            f"{smearing_T:.6g} Ha "
            f"({_hartree_to_kelvin_temperature(smearing_T):.1f} K)"
        )
    # ---- SCF setup -----------------------------------------------
    damping = float(opts.damping)
    if not (0.0 <= damping < 1.0):
        raise ValueError(
            f"run_krhf_periodic_gdf: damping must be in [0, 1); got {damping}"
        )
    if fock_mixing_value != 0.0:
        plog.info(
            "fock mixing: CRYSTAL FMIXING "
            f"{100.0 * fock_mixing_value:.1f}% "
            "(previous Fock/KS matrix weight, applied per k)"
        )
    _reject_unsupported_python_accelerator(opts, "run_krhf_periodic_gdf")
    use_diis = bool(opts.use_diis)
    diis_start_iter = int(opts.diis_start_iter)
    diis = (
        _MultiKPulayDIIS(max_subspace=int(opts.diis_subspace_size))
        if use_diis
        else None
    )
    if level_shift != 0.0:
        if warmup_cycles > 0:
            cycle_word = "cycle" if warmup_cycles == 1 else "cycles"
            plog.info(
                f"level-shift warm-up: {warmup_cycles} {cycle_word} at "
                f"{level_shift:.3f} Ha (per k); restart unshifted afterwards"
            )
        else:
            plog.info(
                f"level shift: {level_shift:.3f} Ha (per k) "
                "applied at each diagonalization"
            )
    plog.banner(f"SCF ({label} multi-k, native GDF)")
    plog.info(" iter energy (Ha) dE ||[F,DS]|| DIIS")
    scf_trace: List[SCFIteration] = []
    # KS runs get the KS result type, matching the Γ fast path and the
    # declared return type of run_krks_periodic_gdf.
    multi_k_result_cls = PeriodicKRKSGDFResult if is_ks else PeriodicKRHFGDFResult
    result = multi_k_result_cls(
        energy=0.0,
        e_electronic=0.0,
        e_nuclear=float(e_nuc),
        n_iter=0,
        converged=False,
        mo_energies=[e.copy() for e in eps_k],
        mo_coeffs=[C.copy() for C in C_k],
        fock=[np.empty((0, 0), dtype=complex) for _ in range(n_k)],
        overlap=[S.copy() for S in S_k],
        hcore=[H.copy() for H in Hcore_k],
        density=[D.copy() for D in D_k],
        kpoints_cart=kpoints_cart.copy(),
        kpoint_weights=weights.copy(),
        scf_trace=scf_trace,
        functional=func_name or None,
        fock_mixing=fock_mixing_value,
        level_shift=level_shift,
        level_shift_warmup_cycles=warmup_cycles,
        smearing_temperature=smearing_T,
        fermi_level=float(fermi_level),
        entropy=float(entropy),
        occupations=[np.asarray(o, dtype=float) for o in occ_k],
        aux_basis_name=aux_name,
        n_aux=int(aux.nbasis),
    )
    F_prev_k: Optional[List[np.ndarray]] = None
    D_prev_k: List[np.ndarray] = [D.copy() for D in D_k]
    E_prev = 0.0
    for it in range(1, max_iter + 1):
        if warmup_cycles > 0 and it == warmup_cycles + 1:
            if diis is not None:
                diis = _MultiKPulayDIIS(max_subspace=int(opts.diis_subspace_size))
            F_prev_k = None
            plog.info("restart: unshifted Fock with fresh DIIS history (per k)")
        active_level_shift = (
            level_shift
            if (level_shift != 0.0 and (warmup_cycles == 0 or it <= warmup_cycles))
            else 0.0
        )
        diis_active = use_diis and it >= diis_start_iter
        # Density damping (per k, in AO basis).
        if it == 1 or damping == 0.0 or diis_active:
            D_used = [D.copy() for D in D_k]
        else:
            D_used = [
                damping * Dp + (1.0 - damping) * Dn for Dp, Dn in zip(D_prev_k, D_k)
            ]
        # ---- J + K via EWALD_3D gauge (2026-05-20). --------------
        # Fold the exact density used for this iteration. This is
        # essential for smearing and damping: rebuilding from ``C_k`` and
        # hard ``n_occ`` would silently feed an integer-Aufbau density to
        # the Fock builder while reporting fractional occupations.
        D_real = _real_space_density_from_per_k_density(
            D_used,
            kmesh_bloch,
            cells,
        )
        F_k = build_periodic_fock_ewald3d_k(
            basis,
            system,
            D_real,
            omega=0.5,
            k_points_cart=[np.asarray(k) for k in kpoints_cart],
            Hcore_k=None,
            lattice_opts=lat_opts,
            exchange_scale=alpha,
        )
        # XC on gamma-cell density.
        V_xc_real = None
        E_xc = 0.0
        if is_ks:
            D_real_gamma = np.asarray(_g0_block(D_real))
            zero = np.zeros_like(D_real_gamma)
            # Home cell carries the density; every other cell is zeroed.
            for g in range(n_xc_cells):
                xc_density_set.set_block(g, D_real_gamma if g == 0 else zero)
            xc_contrib = build_xc_periodic(
                basis,
                system,
                grid,
                func,
                xc_density_set,
                lat_opts,
            )
            V_xc_real = np.real(bloch_sum(xc_contrib.V_xc, np.zeros(3)))
            V_xc_real = 0.5 * (V_xc_real + V_xc_real.T)
            E_xc = float(xc_contrib.e_xc)
        # Save F_2e before adding V_xc (for energy decomposition).
        F_2e_k = [np.asarray(f).copy() for f in F_k]
        for i in range(n_k):
            Fi = F_k[i] + Hcore_k[i]
            if V_xc_real is not None:
                Fi = Fi + V_xc_real.astype(complex)
            Fi = 0.5 * (Fi + Fi.conj().T)
            F_k[i] = Fi
        # ---- Energy (E_elec = Tr[D.Hcore] + 0.5 Tr[D.F_2e]). -------
        E_elec = 0.0
        for i in range(n_k):
            w = float(weights[i])
            Di = D_used[i]
            Hi = Hcore_k[i]
            Fi = F_2e_k[i]
            E_elec += w * float(
                np.real(np.trace(Di @ Hi)) + 0.5 * np.real(np.trace(Di @ Fi))
            )
        E_total = E_elec + E_xc + float(e_nuc)
        free_energy = E_total - smearing_T * entropy
        # ---- Convergence -------------------------------------------
        grad_k: List[np.ndarray] = []
        grad_norm_sq = 0.0
        for i in range(n_k):
            FDS = F_k[i] @ D_used[i] @ S_k[i]
            err = FDS - FDS.conj().T
            grad_k.append(err)
            grad_norm_sq += float(weights[i]) * float(np.linalg.norm(err) ** 2)
        grad_norm = float(np.sqrt(grad_norm_sq))
        dE = free_energy - E_prev
        scf_trace.append(
            SCFIteration(
                iter=it,
                energy=float(free_energy),
                delta_e=float(dE if it > 1 else 0.0),
                grad_norm=float(grad_norm),
                diis_subspace=(diis.subspace_size if diis is not None else 0),
            )
        )
        plog.iteration(
            it,
            energy=float(free_energy),
            dE=float(dE if it > 1 else 0.0),
            grad=float(grad_norm),
            diis=(diis.subspace_size if diis is not None else 0),
        )
        converged = (
            it > 1
            and (warmup_cycles == 0 or it > warmup_cycles)
            and abs(dE) < float(opts.conv_tol_energy)
            and grad_norm < float(opts.conv_tol_grad)
        )
        # ---- Energy decomposition into E_J and E_K. ----------------
        # Only the final iteration's values survive into the returned
        # result, so the extra J-only Fock build is done only when this
        # iteration is the last one (converged or at max_iter).
        E_coulomb = 0.0
        E_hf_K = 0.0
        if converged or it == max_iter:
            if alpha != 0.0:
                F_J_k = build_periodic_fock_ewald3d_k(
                    basis,
                    system,
                    D_real,
                    omega=0.5,
                    k_points_cart=[np.asarray(k) for k in kpoints_cart],
                    Hcore_k=None,
                    lattice_opts=lat_opts,
                    exchange_scale=0.0,
                )
                E_J_val = 0.0
                E_K_val = 0.0
                for i in range(n_k):
                    w = float(weights[i])
                    J_k = np.asarray(F_J_k[i])
                    # F_2e = J - (alpha / 2) K  =>  K = 2 (J - F_2e) / alpha.
                    K_k = 2.0 * (J_k - F_2e_k[i]) / alpha
                    K_k = 0.5 * (K_k + K_k.conj().T)
                    E_J_val += w * float(np.real(np.trace(D_used[i] @ J_k)))
                    E_K_val += w * float(np.real(np.trace(D_used[i] @ K_k)))
                E_coulomb = 0.5 * E_J_val
                E_hf_K = -0.25 * alpha * E_K_val
            else:
                # No exact exchange: F_2e(k) was built with
                # exchange_scale=0, i.e. it is already the pure Coulomb
                # matrix, so E_J needs no second Fock build.
                E_coulomb = 0.5 * sum(
                    float(weights[i])
                    * float(np.real(np.trace(D_used[i] @ F_2e_k[i])))
                    for i in range(n_k)
                )
        # ---- DIIS, FMIXING, level shift ----------------------------
        if diis is not None:
            # Extrapolate every iteration so the history fills up before
            # DIIS becomes active at diis_start_iter.
            F_diis = diis.extrapolate(F_k, grad_k, weights)
            if diis_active:
                F_k = F_diis
        if fock_mixing_value != 0.0:
            if F_prev_k is not None:
                F_mixed: List[np.ndarray] = []
                for i in range(n_k):
                    Fmix = (1.0 - fock_mixing_value) * F_k[
                        i
                    ] + fock_mixing_value * F_prev_k[i]
                    F_mixed.append(0.5 * (Fmix + Fmix.conj().T))
                F_k = F_mixed
            F_prev_k = [F.copy() for F in F_k]
        # Per-k Saunders–Hillier level shift (only at diagonalization).
        F_diag = []
        for i in range(n_k):
            if active_level_shift != 0.0:
                Fi_shifted = (
                    F_k[i]
                    + active_level_shift * S_k[i]
                    - (active_level_shift / 2.0) * (S_k[i] @ D_used[i] @ S_k[i])
                )
                Fi_shifted = 0.5 * (Fi_shifted + Fi_shifted.conj().T)
                F_diag.append(Fi_shifted)
            else:
                F_diag.append(F_k[i])
        # ---- Diagonalise per k + occupations + density -------------
        C_new: List[np.ndarray] = []
        eps_new: List[np.ndarray] = []
        for i in range(n_k):
            Ci, ei = _diag_in_orth_basis(F_diag[i], X_k[i])
            C_new.append(Ci.astype(complex))
            eps_new.append(ei)
        occ_k, fermi_level, entropy = _occupations_per_k(
            eps_new,
            weights,
            n_elec,
            smearing_T,
            n_occ,
        )
        D_new = _density_from_orbitals(C_new, occ_k)
        D_prev_k = D_used
        D_k = D_new
        C_k = C_new
        eps_k = eps_new
        E_prev = free_energy
        # Update the result placeholder every iteration (rather than
        # only on convergence) so a max_iter abort still yields the last
        # iteration's numbers.
        result.energy = E_total
        result.e_electronic = E_elec
        result.e_xc = E_xc
        result.e_coulomb = E_coulomb
        result.e_hf_exchange = E_hf_K
        result.n_iter = it
        result.mo_energies = [e.copy() for e in eps_new]
        result.mo_coeffs = [C.copy() for C in C_new]
        result.fock = [F.copy() for F in F_k]
        result.density = [D.copy() for D in D_k]
        result.fermi_level = float(fermi_level)
        result.entropy = float(entropy)
        result.free_energy = float(free_energy)
        result.occupations = [np.asarray(o, dtype=float) for o in occ_k]
        if converged:
            result.converged = True
            plog.converged(
                n_iter=result.n_iter,
                energy=result.energy,
                converged=True,
            )
            return result
    result.converged = False
    plog.converged(
        n_iter=result.n_iter,
        energy=result.energy,
        converged=False,
    )
    return result
[docs]
def run_krks_periodic_gdf(
    system: PeriodicSystem,
    basis: BasisSet,
    kmesh: Union[Sequence[int], KPoints, BlochKMesh] = (1, 1, 1),
    options: Optional[PeriodicKSOptions] = None,
    *,
    functional: Optional[str] = None,
    aux_basis: Optional[str] = None,
    aux_drop_eta: float = 0.0,
    gdf_linear_dep_threshold: float = 1e-9,
    apply_modrho: bool = True,
    fock_mixing: Optional[float] = None,
    level_shift_warmup_cycles: Optional[int] = None,
    linear_dep_threshold: float = 1e-7,
    progress: Union[bool, ProgressLogger, None] = None,
    verbose: Optional[int] = None,
) -> PeriodicKRKSGDFResult:
    """Run closed-shell periodic KS-DFT multi-k SCF via native GDF.

    Convenience front end for :func:`run_krhf_periodic_gdf` that
    requires a functional. The functional comes from ``functional`` or,
    failing that, from ``options.functional``. Functional dispatch and
    exact-exchange mixing happen inside the shared SCF loop.

    Raises
    ------
    ValueError
        If no functional is given either way.
    """
    ks_options = _options_or_default(options, is_ks=True)
    chosen_functional = functional or getattr(ks_options, "functional", None)
    if not chosen_functional:
        raise ValueError("run_krks_periodic_gdf requires functional=...")
    passthrough = dict(
        functional=str(chosen_functional),
        aux_basis=aux_basis,
        aux_drop_eta=aux_drop_eta,
        gdf_linear_dep_threshold=gdf_linear_dep_threshold,
        apply_modrho=apply_modrho,
        fock_mixing=fock_mixing,
        level_shift_warmup_cycles=level_shift_warmup_cycles,
        linear_dep_threshold=linear_dep_threshold,
        progress=progress,
        verbose=verbose,
    )
    return run_krhf_periodic_gdf(system, basis, kmesh, ks_options, **passthrough)