"""Phase 15c-3b: multi-k open-shell UKS SCF driver with composed
EWALD_3D Coulomb dispatch.
Open-shell DFT counterpart of:
* :func:`run_uhf_periodic_multi_k_ewald3d` — same per-spin
DIIS / damping / orthogonalisation structure.
* :func:`run_uks_periodic_gamma_ewald3d` — Γ-only UKS Ewald
(15c-3a).
* :func:`run_rks_periodic_multi_k_ewald3d` — multi-k closed-shell
KS Ewald (15c-2).
Per SCF iteration:
F_α(k) = H_core(k) + Bloch_k[J(D_α + D_β, ω) − α·K(D_α)]
+ Bloch_k[V_xc^α(g)]
F_β(k) = H_core(k) + Bloch_k[J(D_α + D_β, ω) − α·K(D_β)]
+ Bloch_k[V_xc^β(g)]
with α = func.hf_exchange_fraction. For pure DFT (α = 0) the K
builds are skipped entirely; for hybrids the per-spin K is computed
via ``build_fock_2e_real_space(..., exchange_scale=1, omega=0)`` and
extracted as ``K(D_σ) = 2·(J_full(D_σ) − F_full(D_σ))`` (mirrors the
UHF multi-k pattern), then scaled by α.
Density flow. Multi-k carries proper LatticeMatrixSets
``D_α_real``, ``D_β_real`` (one-particle, no factor of 2). The
periodic UKS XC kernel :func:`build_xc_periodic_uks` consumes them
directly and returns LatticeMatrixSets ``V_xc^α(g)``, ``V_xc^β(g)``
which are Bloch-summed per k.
Energy formula (mirrors molecular UKS + the multi-k UHF Ewald
convention):
E_elec = E_xc + Σ_k w_k · Re tr[(D_α(k) + D_β(k))·H_core(k)]
+ Σ_k w_k · ½ Re tr[D_α(k)·F_α^{HF}(k)]
+ Σ_k w_k · ½ Re tr[D_β(k)·F_β^{HF}(k)]
where ``F_σ^{HF}(k) = Bloch_k[J − α·K_σ]`` is the Hartree-plus-HF
piece of F_σ (V_xc reported through E_xc rather than a trace).
Scope.
* Multi-k open-shell, ``multiplicity ≥ 1``, integer α/β occupations.
* Pure DFT, hybrid, and HF (α = 1, equivalent to UHF Ewald).
* Per-spin Pulay DIIS with k-weighted Frobenius inner product.
* Saunders-Hillier level shift via ``options.level_shift``.
* Periodic Becke partition selectable via
``options.use_periodic_becke``.
* ⟨S²⟩ diagnostic on the Γ-block (or first) k-point — same shortcut
the multi-k UHF Ewald driver uses.
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
from ._vibeqc_core import (
BasisSet,
BlochKMesh,
CoulombMethod,
Functional,
InitialGuess,
LatticeMatrixSet,
LatticeSumOptions,
PeriodicKSOptions,
PeriodicSystem,
SCFIteration,
bloch_sum,
build_fock_2e_real_space,
build_grid,
build_xc_periodic_uks,
compute_kinetic_lattice,
compute_nuclear_lattice,
compute_overlap_lattice,
nuclear_repulsion_per_cell,
real_space_density_from_kpoints_fractional,
)
from .ewald_j import auto_grid
from .guess import initial_densities_open_shell
from .madelung import (
madelung_energy_correction_for_lat as _madelung_energy_correction_for_lat,
)
from .periodic_density import build_j_long_range_periodic
from .periodic_grid import build_periodic_becke_grid
from .periodic_rhf_multi_k_ewald import (
_canonical_orthogonalizer_complex,
_damp_lattice_matrix,
_diag_in_orth_basis,
_g0_block,
_MultiKPulayDIIS,
)
from .periodic_uhf_ewald import _spin_squared
from .progress import ProgressLogger, resolve_progress
from .scf_divergence import check_scf_divergence
__all__ = [
"PeriodicUKSMultiKEwaldResult",
"run_uks_periodic_multi_k_ewald3d",
]
@dataclass
class PeriodicUKSMultiKEwaldResult:
    """Result of :func:`run_uks_periodic_multi_k_ewald3d`.

    Per-k lists (MO energies/coefficients, Fock, overlap, H_core) follow
    the order of ``kmesh.kpoints`` as passed to the driver. The
    real-space densities are one-particle spin densities (no factor 2).
    """

    # Total energy per cell: electronic + nuclear + Madelung-leak fix.
    energy: float
    # E_xc plus the core and Hartree/exact-exchange trace terms.
    e_electronic: float
    # Nuclear repulsion per cell.
    e_nuclear: float
    # Exchange-correlation energy.
    e_xc: float
    # Hartree energy ½ Σ_σ Σ_k w_k Re tr[D_σ(k)·J(k)]; only evaluated in
    # the post-convergence pass (left at 0.0 when the SCF does not
    # converge).
    e_coulomb: float
    # Exact-exchange energy (HF-with-K trace minus J-only trace); same
    # convergence caveat as ``e_coulomb``.
    e_hf_exchange: float
    n_iter: int
    converged: bool
    # ⟨S²⟩ at the Γ (or first) k-point, and the ideal S(S+1).
    s_squared: float
    s_squared_ideal: float
    # α
    mo_energies_alpha: List[np.ndarray]
    mo_coeffs_alpha: List[np.ndarray]
    fock_alpha: List[np.ndarray]
    density_alpha: LatticeMatrixSet
    # β
    mo_energies_beta: List[np.ndarray]
    mo_coeffs_beta: List[np.ndarray]
    fock_beta: List[np.ndarray]
    density_beta: LatticeMatrixSet
    # Per-k overlap S(k) and core Hamiltonian H_core(k).
    overlap: List[np.ndarray]
    hcore: List[np.ndarray]
    # Functional name as given in the options.
    functional: str = ""
    scf_trace: List[SCFIteration] = field(default_factory=list)
    # Ewald range-separation parameter actually used by the driver.
    omega: float = 0.0
    # FFT grid used for the long-range Coulomb term.
    grid_shape: Tuple[int, int, int] = (0, 0, 0)
def _build_uks_fock_2e_blocks_ewald3d(
    basis: BasisSet,
    system: PeriodicSystem,
    D_alpha_real: LatticeMatrixSet,
    D_beta_real: LatticeMatrixSet,
    omega: float,
    alpha: float,
    lat_opts: LatticeSumOptions,
    grid_shape_t: Tuple[int, int, int],
    origin: Optional[Sequence[float]],
    spacing_bohr: float,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Return per-cell spin 2e Fock blocks F^{2e,σ}(g) = J(D_tot, ω)(g) − α·K(D_σ)(g).

    Mirrors :func:`vibeqc.periodic_uhf_multi_k_ewald._build_uhf_fock_blocks_ewald3d`
    with the per-spin exchange scaled by ``alpha`` (1.0 = HF). For
    ``alpha == 0.0`` (pure DFT) the four full-range real-space builds
    needed for K are skipped and both spins receive the same J blocks.

    Returns
    -------
    (F_alpha_blocks, F_beta_blocks)
        One dense block per cell of ``D_alpha_real``.
    """
    omega_f = float(omega)
    cell_ids = range(len(D_alpha_real.cells))

    # Template LatticeMatrixSet with the overlap's lattice layout,
    # overwritten block by block with D_α + D_β.
    D_total_real = compute_overlap_lattice(basis, system, lat_opts)
    for g in cell_ids:
        D_total_real.set_block(
            g,
            np.asarray(D_alpha_real.blocks[g], dtype=float)
            + np.asarray(D_beta_real.blocks[g], dtype=float),
        )

    # Ewald-split Hartree term: real-space short-range piece
    # (exchange_scale = 0, range parameter ω) plus the long-range piece.
    J_short = build_fock_2e_real_space(
        basis,
        system,
        lat_opts,
        D_total_real,
        0.0,
        omega_f,
    )
    J_long = build_j_long_range_periodic(
        basis,
        system,
        D_total_real,
        omega=omega_f,
        grid_shape=grid_shape_t,
        origin=origin,
        spacing_bohr=spacing_bohr,
        output_cells=list(cell_ids),
    )

    def _exchange_blocks(D_sigma: LatticeMatrixSet) -> List[np.ndarray]:
        # Full-range J(D_σ) and J(D_σ) − ½K(D_σ) give K(D_σ) = 2·(J − F).
        J_full = build_fock_2e_real_space(
            basis,
            system,
            lat_opts,
            D_sigma,
            0.0,
            0.0,
        )
        F_full = build_fock_2e_real_space(
            basis,
            system,
            lat_opts,
            D_sigma,
            1.0,
            0.0,
        )
        return [
            2.0
            * (
                np.asarray(J_full.blocks[g], dtype=float)
                - np.asarray(F_full.blocks[g], dtype=float)
            )
            for g in cell_ids
        ]

    K_alpha = None
    K_beta = None
    if alpha != 0.0:
        K_alpha = _exchange_blocks(D_alpha_real)
        K_beta = _exchange_blocks(D_beta_real)

    J_blocks = [
        np.asarray(J_short.blocks[g], dtype=float)
        + np.asarray(J_long[g], dtype=float)
        for g in cell_ids
    ]

    if K_alpha is None:
        # Pure DFT: both spins share the same J array objects.
        return J_blocks, list(J_blocks)

    return (
        [J - alpha * K for J, K in zip(J_blocks, K_alpha)],
        [J - alpha * K for J, K in zip(J_blocks, K_beta)],
    )
def _bloch_sum_blocks(
    blocks: Sequence[np.ndarray],
    cells,
    k_cart: np.ndarray,
) -> np.ndarray:
    """Return the complex Bloch sum Σ_g exp(i k·R_g) · blocks[g].

    ``cells[g].r_cart`` supplies the Cartesian lattice vector R_g that
    pairs with ``blocks[g]``; ``k_cart`` must have exactly three
    components. The accumulator starts from a complex zero matrix shaped
    like ``blocks[0]``.
    """
    k_vec = np.asarray(k_cart, dtype=float).reshape(3)
    total = np.zeros_like(blocks[0], dtype=complex)
    for idx in range(len(blocks)):
        R_vec = np.asarray(cells[idx].r_cart, dtype=float)
        total = total + np.exp(1j * float(np.dot(k_vec, R_vec))) * blocks[idx]
    return total
def _bloch_sum_lms_at_k(
    lms: LatticeMatrixSet,
    k_cart: np.ndarray,
) -> np.ndarray:
    """Bloch-sum a LatticeMatrixSet at one Cartesian k-point.

    Thin wrapper over the core ``bloch_sum``: coerces ``k_cart`` to a
    float 3-vector and returns the result as a NumPy array.
    """
    k_vec = np.asarray(k_cart, dtype=float).reshape(3)
    summed = bloch_sum(lms, k_vec)
    return np.asarray(summed)
def _occupied_density_k(C_k: np.ndarray, n_occ: int) -> np.ndarray:
    """Return the one-particle k-space density ``C_occ · C_occ^†``.

    ``C_occ`` is the first ``n_occ`` columns of ``C_k``; ``n_occ == 0``
    yields an all-zero ``nbf × nbf`` matrix.
    """
    C_occ = C_k[:, :n_occ]
    return C_occ @ C_occ.conj().T


def _assemble_uks_fock_k(
    F_HF_alpha_blocks: Sequence[np.ndarray],
    F_HF_beta_blocks: Sequence[np.ndarray],
    xc,
    cells,
    k_points,
    Hcore_k_list: Sequence[np.ndarray],
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Bloch-sum the per-spin 2e blocks and V_xc^σ at every k-point.

    Returns ``(F_alpha_k, F_beta_k, F_HF_alpha_k, F_HF_beta_k)`` with
    ``F_HF_σ(k) = Bloch_k[F^{2e,σ}]`` and
    ``F_σ(k) = ½(A + A^†)``, ``A = H_core(k) + F_HF_σ(k) + Bloch_k[V_xc^σ]``.
    """
    F_alpha_k: List[np.ndarray] = []
    F_beta_k: List[np.ndarray] = []
    F_HF_alpha_k: List[np.ndarray] = []
    F_HF_beta_k: List[np.ndarray] = []
    for k_idx, k_cart in enumerate(k_points):
        k_arr = np.asarray(k_cart)
        F_HF_a = _bloch_sum_blocks(F_HF_alpha_blocks, cells, k_arr)
        F_HF_b = _bloch_sum_blocks(F_HF_beta_blocks, cells, k_arr)
        F_a = Hcore_k_list[k_idx] + F_HF_a + _bloch_sum_lms_at_k(xc.V_alpha, k_arr)
        F_b = Hcore_k_list[k_idx] + F_HF_b + _bloch_sum_lms_at_k(xc.V_beta, k_arr)
        F_alpha_k.append(0.5 * (F_a + F_a.conj().T))
        F_beta_k.append(0.5 * (F_b + F_b.conj().T))
        F_HF_alpha_k.append(F_HF_a)
        F_HF_beta_k.append(F_HF_b)
    return F_alpha_k, F_beta_k, F_HF_alpha_k, F_HF_beta_k


def _level_shift_fock(
    F_k: np.ndarray,
    S_k: np.ndarray,
    D_k: np.ndarray,
    shift: float,
) -> np.ndarray:
    """Saunders-Hillier level shift ``F + b·S − b·S·D·S``, Hermitised."""
    F_shift = F_k + shift * S_k - shift * (S_k @ D_k @ S_k)
    return 0.5 * (F_shift + F_shift.conj().T)


def run_uks_periodic_multi_k_ewald3d(
    system: PeriodicSystem,
    basis: BasisSet,
    kmesh: BlochKMesh,
    options: Optional[PeriodicKSOptions] = None,
    *,
    omega: float = 0.0,
    grid_shape: Optional[Union[Tuple[int, int, int], int]] = None,
    origin: Optional[Sequence[float]] = None,
    spacing_bohr: float = 0.3,
    linear_dep_threshold: float = 1e-7,
    canonical_orth_normalize_diag_first: bool = True,
    auto_optimize_truncation: bool = True,
    progress: Union[bool, ProgressLogger, None] = None,
    verbose: Optional[int] = None,
) -> PeriodicUKSMultiKEwaldResult:
    """Multi-k open-shell periodic Kohn-Sham SCF with EWALD_3D Coulomb.

    Parameters
    ----------
    system, basis, kmesh
        Periodic system, AO basis, k-mesh. ``kmesh.weights`` must sum
        to 1.
    options
        Optional :class:`PeriodicKSOptions`. Reads ``functional``,
        ``grid``, ``use_periodic_becke``, ``becke_image_radius_bohr``,
        ``level_shift``, ``damping``, ``max_iter``, ``conv_tol_*``,
        ``diis_*``, ``initial_guess``, ``lattice_opts``,
        ``ewald_omega`` and the ``quadratic_fallback_*`` settings.
        ``smearing_temperature`` is not implemented for UKS yet; a
        positive value raises ``NotImplementedError`` so metallic UKS
        inputs are never run with silently integer occupations.
    omega
        Ewald range-separation parameter. A positive value is used as
        given; ``<= 0`` (default) takes ``options.ewald_omega`` when that
        is positive, otherwise ``crystal_default_ewald_alpha`` of the
        cell volume.
    grid_shape
        FFT grid for the long-range Coulomb term. ``None`` selects
        ``auto_grid(lattice, spacing_bohr)``; an ``int`` n means n×n×n.
    origin, spacing_bohr
        Forwarded to :func:`build_j_long_range_periodic`;
        ``spacing_bohr`` also drives the automatic grid choice.
    linear_dep_threshold, canonical_orth_normalize_diag_first
        Forwarded to the per-k canonical orthogonaliser.
    auto_optimize_truncation
        When true and ``lattice_opts.coulomb_method`` is EWALD_3D, run
        :func:`optimize_truncation` and adopt its lattice options.
    progress, verbose
        Progress-logging controls (see :func:`resolve_progress`).

    Returns
    -------
    :class:`PeriodicUKSMultiKEwaldResult`. ``e_coulomb`` and
    ``e_hf_exchange`` are only evaluated in the post-convergence pass
    and stay 0.0 when the SCF does not converge.

    Raises
    ------
    ValueError
        Negative ``smearing_temperature``; ``damping`` outside [0, 1);
        ``max_iter < 1``; an (electrons, multiplicity) pair that cannot
        be split into non-negative integer α/β counts; an empty k-mesh
        or weights that do not sum to 1.
    NotImplementedError
        Positive ``smearing_temperature``.
    RuntimeError
        The canonical orthogonaliser keeps fewer directions than the
        number of occupied orbitals at some k-point.
    """
    opts = options if options is not None else PeriodicKSOptions()
    smearing_T = float(getattr(opts, "smearing_temperature", 0.0))
    if smearing_T < 0.0:
        raise ValueError(
            "run_uks_periodic_multi_k_ewald3d: smearing_temperature must be >= 0"
        )
    if smearing_T > 0.0:
        raise NotImplementedError(
            "run_uks_periodic_multi_k_ewald3d: metallic UKS smearing "
            "is not implemented yet. Closed-shell periodic RHF/RKS "
            "multi-k Ewald supports Fermi-Dirac smearing; UKS needs "
            "spin-resolved fractional occupations and chemical "
            "potential handling."
        )
    # Validate SCF controls before any (expensive) integral work.
    damping = float(opts.damping)
    if not (0.0 <= damping < 1.0):
        raise ValueError(
            f"run_uks_periodic_multi_k_ewald3d: damping must be in "
            f"[0, 1); got {damping}"
        )
    max_iter = int(opts.max_iter)
    if max_iter < 1:
        # With zero iterations no energy is ever evaluated and the
        # result assembly would reference unbound locals.
        raise ValueError(
            f"run_uks_periodic_multi_k_ewald3d: max_iter must be >= 1; "
            f"got {max_iter}"
        )
    lat_opts: LatticeSumOptions = opts.lattice_opts
    plog = resolve_progress(progress, verbose=verbose)
    # ω must match the nuclear Ewald α (auto-selected from
    # nuclear_cutoff_bohr in the C++ ewald engine) so the jellium
    # background terms cancel exactly. Mirrors the override block in
    # the sibling Ewald drivers (commit 49f8ae91 / 433d3543). A positive
    # ``omega`` kwarg is used as given; otherwise ``opts.ewald_omega``
    # (when positive) or the crystal default for this cell volume wins.
    if omega <= 0.0:
        _user_omega = getattr(opts, "ewald_omega", None)
        if _user_omega is not None and float(_user_omega) > 0.0:
            omega = float(_user_omega)
        else:
            from .bipole_ext_el_pole import crystal_default_ewald_alpha

            V_cell = float(abs(np.linalg.det(np.asarray(system.lattice, dtype=float))))
            omega = crystal_default_ewald_alpha(V_cell)
    lat = np.asarray(system.lattice, dtype=float)
    if grid_shape is None:
        grid_shape_t = auto_grid(lat, spacing_bohr)
    elif isinstance(grid_shape, int):
        grid_shape_t = (grid_shape, grid_shape, grid_shape)
    else:
        grid_shape_t = tuple(int(x) for x in grid_shape)
    plog.info(
        f"UKS multi-k EWALD_3D / functional={opts.functional!r}, "
        f"omega = {float(omega):.3f}, "
        f"FFT grid {grid_shape_t[0]}x{grid_shape_t[1]}x{grid_shape_t[2]}"
    )
    plog.info(f"basis: {basis.name} ({basis.nbasis} BFs / {basis.nshells} shells)")
    from .options_dump import dump_active_settings

    dump_active_settings(
        plog,
        [
            ("PeriodicKSOptions", opts),
            ("LatticeSumOptions", lat_opts),
            (
                "Driver kwargs",
                {
                    "omega": float(omega),
                    "grid_shape": grid_shape_t,
                    "origin": origin,
                    "spacing_bohr": float(spacing_bohr),
                    "linear_dep_threshold": float(linear_dep_threshold),
                    "canonical_orth_normalize_diag_first": canonical_orth_normalize_diag_first,
                    "auto_optimize_truncation": auto_optimize_truncation,
                },
            ),
        ],
    )
    if plog.level >= 5:
        from .scf_log import format_basis_summary

        plog.write_raw(format_basis_summary(basis))

    # ---- Electron count / spin split ------------------------------------
    n_elec = int(system.n_electrons())
    mult = int(system.multiplicity)
    if mult < 1:
        raise ValueError(
            f"run_uks_periodic_multi_k_ewald3d: multiplicity must be ≥ 1, got {mult}"
        )
    if (n_elec + mult - 1) % 2 != 0 or (n_elec - mult + 1) % 2 != 0:
        raise ValueError(
            f"run_uks_periodic_multi_k_ewald3d: (n_electrons={n_elec}, "
            f"multiplicity={mult}) cannot be split into integer α/β."
        )
    n_alpha = (n_elec + mult - 1) // 2
    n_beta = (n_elec - mult + 1) // 2
    if n_beta < 0:
        # Parity alone does not catch e.g. (1 electron, quartet): that
        # would give n_β = −1 and a garbage occupation vector below.
        raise ValueError(
            f"run_uks_periodic_multi_k_ewald3d: multiplicity={mult} needs "
            f"at least {mult - 1} electrons, got n_electrons={n_elec}."
        )

    # ---- Functional + DFT grid ------------------------------------------
    func = Functional(opts.functional, 2)  # spin-polarized
    alpha = float(func.hf_exchange_fraction)
    if opts.use_periodic_becke:
        grid = build_periodic_becke_grid(
            system,
            grid_options=opts.grid,
            image_radius_bohr=float(opts.becke_image_radius_bohr),
        )
    else:
        grid = build_grid(system.unit_cell_molecule(), opts.grid)
    k_points = list(kmesh.kpoints)
    weights = np.asarray(kmesh.weights, dtype=float)
    n_k = len(k_points)
    if n_k == 0:
        raise ValueError("kmesh has no k-points")
    if not np.isclose(weights.sum(), 1.0):
        raise ValueError(f"kmesh.weights must sum to 1; got {weights.sum():.6f}")
    plog.info(
        f"k-mesh: {n_k} k-point{'s' if n_k != 1 else ''}, "
        f"weights sum = {weights.sum():.4f}; "
        f"n_alpha = {n_alpha}, n_beta = {n_beta}"
    )

    # ---- Auto-optimise lattice truncation (default ON) -------------------
    if auto_optimize_truncation and lat_opts.coulomb_method == CoulombMethod.EWALD_3D:
        from .eigs_preflight import (
            format_truncation_optimization_report,
            optimize_truncation,
        )

        k_points_cart = [np.asarray(k, dtype=float) for k in k_points]
        opt_rep = optimize_truncation(
            system,
            basis,
            lattice_opts=lat_opts,
            k_points_cart=k_points_cart,
        )
        if (
            opt_rep.n_evaluations > 1
            or opt_rep.optimized_lattice_opts.cutoff_bohr != lat_opts.cutoff_bohr
        ):
            plog.write_raw(format_truncation_optimization_report(opt_rep))
        if not opt_rep.converged:
            plog.warning("auto_optimize_truncation did not converge.")
        lat_opts = opt_rep.optimized_lattice_opts

    # ---- One-electron lattice integrals + per-k setup --------------------
    with plog.stage(
        "integrals_lattice", detail=f"S/T/V at cutoff {lat_opts.cutoff_bohr:.2f} bohr"
    ):
        S_lat = compute_overlap_lattice(basis, system, lat_opts)
        T_lat = compute_kinetic_lattice(basis, system, lat_opts)
        from .periodic_v_ne import compute_nuclear_lattice_dispatch

        V_lat = compute_nuclear_lattice_dispatch(basis, system, lat_opts)
    cells = list(S_lat.cells)
    S_k_list: List[np.ndarray] = []
    Hcore_k_list: List[np.ndarray] = []
    X_k_list: List[np.ndarray] = []
    # Per-k linear-dependence preflight; see periodic_rhf_multi_k_ewald
    # for the rationale (Searle et al., ARCHER eCSE04-16, 2017).
    from .linear_dependence import scf_preflight_overlap_check

    for k_idx, k in enumerate(k_points):
        k_arr = np.asarray(k, dtype=float).reshape(3)
        S_k = np.asarray(bloch_sum(S_lat, k_arr))
        T_k = np.asarray(bloch_sum(T_lat, k_arr))
        V_k = np.asarray(bloch_sum(V_lat, k_arr))
        H_k = T_k + V_k
        S_k = 0.5 * (S_k + S_k.conj().T)
        H_k = 0.5 * (H_k + H_k.conj().T)
        scf_preflight_overlap_check(
            S_k,
            plog=plog,
            label=f"S(k={k_idx}, k_cart={k_arr.round(4).tolist()})",
            basis=basis,
        )
        X_k, n_kept = _canonical_orthogonalizer_complex(
            S_k,
            linear_dep_threshold,
            normalize_diag_first=canonical_orth_normalize_diag_first,
        )
        if max(n_alpha, n_beta) > n_kept:
            raise RuntimeError(
                f"run_uks_periodic_multi_k_ewald3d: orth dropped too many "
                f"directions (n_α={n_alpha}, n_β={n_beta}, "
                f"n_kept={n_kept}) at k = {k_arr}"
            )
        S_k_list.append(S_k)
        Hcore_k_list.append(H_k)
        X_k_list.append(X_k)
    e_nuc = float(nuclear_repulsion_per_cell(system, lat_opts))

    # ---- Initial guess --------------------------------------------------
    C_alpha_per_k: List[np.ndarray] = []
    eps_alpha_per_k: List[np.ndarray] = []
    C_beta_per_k: List[np.ndarray] = []
    eps_beta_per_k: List[np.ndarray] = []
    for H_k, X_k in zip(Hcore_k_list, X_k_list):
        # Both spins start from the same H_core eigenvectors: one
        # diagonalisation per k, with independent copies per spin so the
        # α and β arrays never alias each other.
        C_k, eps_k = _diag_in_orth_basis(H_k, X_k)
        C_alpha_per_k.append(C_k.astype(complex))
        eps_alpha_per_k.append(np.array(eps_k, copy=True))
        C_beta_per_k.append(C_k.astype(complex))
        eps_beta_per_k.append(np.array(eps_k, copy=True))

    def _spin_density(C_per_k_local, n_occ_each):
        """One-particle (no factor 2) real-space spin density via the
        C++ fractional-occupation builder. Mirrors the UHF multi-k
        Ewald driver convention."""
        nbf = C_per_k_local[0].shape[1]
        occ_per_k = []
        for _ in range(n_k):
            occ = np.zeros(nbf, dtype=float)
            occ[:n_occ_each] = 1.0
            occ_per_k.append(occ)
        return real_space_density_from_kpoints_fractional(
            C_per_k_local,
            occ_per_k,
            kmesh,
            cells,
        )

    D_alpha_real = _spin_density(C_alpha_per_k, n_alpha)
    D_beta_real = _spin_density(C_beta_per_k, n_beta)
    # Density-mode guesses: overwrite per-spin densities at g=0 with
    # the engine output (proportional spin split — matches molecular
    # UHF behaviour). Closed-shell-like cases (n_α == n_β) reduce to
    # the even split that this driver previously used inline.
    guess = getattr(opts, "initial_guess", InitialGuess.HCORE)
    split = initial_densities_open_shell(
        system.unit_cell_molecule(),
        basis,
        n_alpha,
        n_beta,
        guess,
        is_periodic=True,
    )
    if split is not None:
        plog.info(f"initial guess: {guess.name} (spin-split density via GuessEngine)")
        D_a_sad, D_b_sad = split
        zero_block = np.zeros_like(D_a_sad, dtype=float)
        for g_idx in range(len(D_alpha_real.cells)):
            is_g0 = (D_alpha_real.cells[g_idx].index == np.array([0, 0, 0])).all()
            D_alpha_real.set_block(g_idx, D_a_sad if is_g0 else zero_block)
            D_beta_real.set_block(g_idx, D_b_sad if is_g0 else zero_block)
    else:
        plog.info(
            f"initial guess: {guess.name} "
            "(Hcore-diagonalise at each k, spin-degenerate)"
        )
    D_alpha_prev: Optional[LatticeMatrixSet] = None
    D_beta_prev: Optional[LatticeMatrixSet] = None

    # ---- Convergence accelerators ---------------------------------------
    use_diis = bool(opts.use_diis)
    diis_start_iter = int(opts.diis_start_iter)
    diis_alpha = (
        _MultiKPulayDIIS(max_subspace=int(opts.diis_subspace_size))
        if use_diis
        else None
    )
    diis_beta = (
        _MultiKPulayDIIS(max_subspace=int(opts.diis_subspace_size))
        if use_diis
        else None
    )
    level_shift = float(getattr(opts, "level_shift", 0.0))
    # Phase C1c — quadratic SCF fallback (per-spin per-k Newton step).
    quadratic_fallback_iter = int(getattr(opts, "quadratic_fallback_iter", 0))
    quadratic_fallback_shift = float(getattr(opts, "quadratic_fallback_shift", 0.1))
    quadratic_fallback_max_step = float(
        getattr(opts, "quadratic_fallback_max_step", 0.1)
    )

    # ---- SCF loop -------------------------------------------------------
    scf_trace: List[SCFIteration] = []
    E_prev = 0.0
    F_alpha_k_list: List[np.ndarray] = [np.zeros_like(H) for H in Hcore_k_list]
    F_beta_k_list: List[np.ndarray] = [np.zeros_like(H) for H in Hcore_k_list]
    E_xc = 0.0
    # Only filled by the post-convergence pass.
    E_coulomb_per_cell = 0.0
    E_hf_K_per_cell = 0.0
    plog.banner(f"SCF (UKS multi-k {opts.functional!r}, EWALD_3D)")
    plog.info(" iter energy (Ha) dE ||[F,DS]|| DIIS")
    converged = False
    iter_idx = 0
    for iter_idx in range(1, max_iter + 1):
        diis_active = use_diis and iter_idx >= diis_start_iter
        if iter_idx > 1 and damping > 0.0 and not diis_active:
            D_alpha_used = _damp_lattice_matrix(
                D_alpha_real,
                D_alpha_prev,
                damping,
            )
            D_beta_used = _damp_lattice_matrix(
                D_beta_real,
                D_beta_prev,
                damping,
            )
        else:
            D_alpha_used = D_alpha_real
            D_beta_used = D_beta_real
        # Per-spin 2e Fock blocks F^{2e,σ}(g) = J(D_total) − α·K(D_σ).
        F_HF_alpha_blocks, F_HF_beta_blocks = _build_uks_fock_2e_blocks_ewald3d(
            basis,
            system,
            D_alpha_used,
            D_beta_used,
            omega,
            alpha,
            lat_opts,
            grid_shape_t,
            origin,
            spacing_bohr,
        )
        # Periodic UKS XC: V_xc^σ(g) lattice + scalar E_xc.
        xc = build_xc_periodic_uks(
            basis,
            system,
            grid,
            func,
            D_alpha_used,
            D_beta_used,
            lat_opts,
        )
        E_xc = float(xc.e_xc)
        (
            F_alpha_k_list,
            F_beta_k_list,
            F_HF_alpha_k_list,
            F_HF_beta_k_list,
        ) = _assemble_uks_fock_k(
            F_HF_alpha_blocks,
            F_HF_beta_blocks,
            xc,
            cells,
            k_points,
            Hcore_k_list,
        )
        # Energy + per-k errors.
        # E_elec = E_xc + Σ_k w_k [Re tr((D_α + D_β)·H_k)
        #                          + ½ Re tr(D_α(k)·F_HF_α(k))
        #                          + ½ Re tr(D_β(k)·F_HF_β(k))]
        # The core term carries no ½: F_HF_σ holds only Hartree + scaled
        # K (no H_core). The multi-k UHF Ewald driver instead uses ½ on
        # H_core with the full F in the per-spin traces — same total.
        E_core_trace = 0.0
        E_HF_alpha_trace = 0.0
        E_HF_beta_trace = 0.0
        grad_norm_sum = 0.0
        error_alpha_k_list: List[np.ndarray] = []
        error_beta_k_list: List[np.ndarray] = []
        for idx in range(n_k):
            D_a_k = _occupied_density_k(C_alpha_per_k[idx], n_alpha)
            D_b_k = _occupied_density_k(C_beta_per_k[idx], n_beta)
            w = float(weights[idx])
            E_core_trace += w * np.real(np.trace((D_a_k + D_b_k) @ Hcore_k_list[idx]))
            E_HF_alpha_trace += w * 0.5 * np.real(np.trace(D_a_k @ F_HF_alpha_k_list[idx]))
            E_HF_beta_trace += w * 0.5 * np.real(np.trace(D_b_k @ F_HF_beta_k_list[idx]))
            S_k = S_k_list[idx]
            FDS_a = F_alpha_k_list[idx] @ D_a_k @ S_k
            FDS_b = F_beta_k_list[idx] @ D_b_k @ S_k
            err_a = FDS_a - FDS_a.conj().T
            err_b = FDS_b - FDS_b.conj().T
            error_alpha_k_list.append(err_a)
            error_beta_k_list.append(err_b)
            grad_norm_sum += w * float(
                np.sqrt(np.linalg.norm(err_a) ** 2 + np.linalg.norm(err_b) ** 2)
            )
        E_elec = (
            E_xc
            + float(E_core_trace)
            + float(E_HF_alpha_trace)
            + float(E_HF_beta_trace)
        )
        # Madelung-leak correction (v0.6.1). For UKS, total density
        # is D_α + D_β at the unit cell.
        _D_g0 = np.asarray(_g0_block(D_alpha_real)) + np.asarray(_g0_block(D_beta_real))
        _S_g0 = np.asarray(_g0_block(S_lat))
        E_madelung_fix = _madelung_energy_correction_for_lat(
            _D_g0, _S_g0, system, lat_opts
        )
        E_total = E_elec + e_nuc + E_madelung_fix
        dE = E_total - E_prev
        # Divergence detection (v0.6.2).
        check_scf_divergence(
            "run_uks_periodic_multi_k_ewald3d",
            iter_idx,
            E_total,
            grad_norm_sum,
            dE,
        )
        diis_sub = 0
        if diis_alpha is not None:
            diis_sub = max(diis_sub, diis_alpha.subspace_size)
        if diis_beta is not None:
            diis_sub = max(diis_sub, diis_beta.subspace_size)
        scf_trace.append(
            SCFIteration(
                iter=iter_idx,
                energy=float(E_total),
                delta_e=float(dE if iter_idx > 1 else 0.0),
                grad_norm=float(grad_norm_sum),
                diis_subspace=diis_sub,
            )
        )
        plog.iteration(
            iter_idx,
            energy=float(E_total),
            dE=float(dE if iter_idx > 1 else 0.0),
            grad=float(grad_norm_sum),
            diis=diis_sub,
        )
        converged = (
            iter_idx > 1
            and abs(dE) < float(opts.conv_tol_energy)
            and grad_norm_sum < float(opts.conv_tol_grad)
        )
        # Phase C1c gate.
        in_quadratic_phase = (
            quadratic_fallback_iter > 0 and iter_idx > quadratic_fallback_iter
        )
        new_C_alpha: List[np.ndarray] = []
        new_eps_alpha: List[np.ndarray] = []
        new_C_beta: List[np.ndarray] = []
        new_eps_beta: List[np.ndarray] = []
        if in_quadratic_phase:
            from .quadratic_scf import quadratic_step

            for idx in range(n_k):
                C_a, eps_a = quadratic_step(
                    F_alpha_k_list[idx],
                    C_alpha_per_k[idx],
                    eps_alpha_per_k[idx],
                    n_alpha,
                    shift=quadratic_fallback_shift,
                    max_step=quadratic_fallback_max_step,
                )
                C_b, eps_b = quadratic_step(
                    F_beta_k_list[idx],
                    C_beta_per_k[idx],
                    eps_beta_per_k[idx],
                    n_beta,
                    shift=quadratic_fallback_shift,
                    max_step=quadratic_fallback_max_step,
                )
                new_C_alpha.append(C_a)
                new_eps_alpha.append(eps_a)
                new_C_beta.append(C_b)
                new_eps_beta.append(eps_b)
        else:
            # DIIS extrapolation per spin. ``extrapolate`` is called on
            # every iteration; its output is only used once
            # ``diis_start_iter`` is reached.
            if diis_alpha is not None and diis_beta is not None:
                F_a_ex = diis_alpha.extrapolate(
                    F_alpha_k_list,
                    error_alpha_k_list,
                    weights,
                )
                F_b_ex = diis_beta.extrapolate(
                    F_beta_k_list,
                    error_beta_k_list,
                    weights,
                )
                if diis_active:
                    F_alpha_k_list = F_a_ex
                    F_beta_k_list = F_b_ex
            # Saunders-Hillier level shift per spin per k. Fresh lists are
            # built (instead of assigning into the existing ones) in case
            # the DIIS extrapolator keeps references to the lists it was
            # just given.
            if level_shift != 0.0:
                F_alpha_k_list = [
                    _level_shift_fock(
                        F_alpha_k_list[idx],
                        S_k_list[idx],
                        _occupied_density_k(C_alpha_per_k[idx], n_alpha),
                        level_shift,
                    )
                    for idx in range(n_k)
                ]
                F_beta_k_list = [
                    _level_shift_fock(
                        F_beta_k_list[idx],
                        S_k_list[idx],
                        _occupied_density_k(C_beta_per_k[idx], n_beta),
                        level_shift,
                    )
                    for idx in range(n_k)
                ]
            # Diagonalize per spin per k.
            for idx in range(n_k):
                C_a, eps_a = _diag_in_orth_basis(
                    F_alpha_k_list[idx],
                    X_k_list[idx],
                )
                C_b, eps_b = _diag_in_orth_basis(
                    F_beta_k_list[idx],
                    X_k_list[idx],
                )
                new_C_alpha.append(C_a)
                new_eps_alpha.append(eps_a)
                new_C_beta.append(C_b)
                new_eps_beta.append(eps_b)
        C_alpha_per_k = new_C_alpha
        eps_alpha_per_k = new_eps_alpha
        C_beta_per_k = new_C_beta
        eps_beta_per_k = new_eps_beta
        D_alpha_new = _spin_density(C_alpha_per_k, n_alpha)
        D_beta_new = _spin_density(C_beta_per_k, n_beta)
        D_alpha_prev = D_alpha_used
        D_beta_prev = D_beta_used
        D_alpha_real = D_alpha_new
        D_beta_real = D_beta_new
        E_prev = E_total
        if converged:
            break

    # ---- Final pass on converged D's ------------------------------------
    if converged:
        F_HF_alpha_blocks, F_HF_beta_blocks = _build_uks_fock_2e_blocks_ewald3d(
            basis,
            system,
            D_alpha_real,
            D_beta_real,
            omega,
            alpha,
            lat_opts,
            grid_shape_t,
            origin,
            spacing_bohr,
        )
        # J-only per-spin pair for reporting.
        if alpha != 0.0:
            J_only_alpha_blocks, J_only_beta_blocks = _build_uks_fock_2e_blocks_ewald3d(
                basis,
                system,
                D_alpha_real,
                D_beta_real,
                omega,
                0.0,
                lat_opts,
                grid_shape_t,
                origin,
                spacing_bohr,
            )
        else:
            J_only_alpha_blocks = F_HF_alpha_blocks
            J_only_beta_blocks = F_HF_beta_blocks
        xc = build_xc_periodic_uks(
            basis,
            system,
            grid,
            func,
            D_alpha_real,
            D_beta_real,
            lat_opts,
        )
        E_xc = float(xc.e_xc)
        (
            F_alpha_k_list,
            F_beta_k_list,
            F_HF_alpha_k_list,
            F_HF_beta_k_list,
        ) = _assemble_uks_fock_k(
            F_HF_alpha_blocks,
            F_HF_beta_blocks,
            xc,
            cells,
            k_points,
            Hcore_k_list,
        )
        J_only_alpha_k_list = [
            _bloch_sum_blocks(J_only_alpha_blocks, cells, np.asarray(k))
            for k in k_points
        ]
        J_only_beta_k_list = [
            _bloch_sum_blocks(J_only_beta_blocks, cells, np.asarray(k))
            for k in k_points
        ]
        # Energy at the converged density. D_alpha_real / D_beta_real
        # (which fed J, K, V_xc, E_xc and the Madelung fix) were built
        # from the current C_*_per_k, so the k-space densities in the
        # traces must come from those same orbitals — not from the
        # re-diagonalised ones below — to keep every term of E_elec at a
        # single density.
        E_core_trace = 0.0
        E_HF_alpha_trace = 0.0
        E_HF_beta_trace = 0.0
        E_J_alpha_trace = 0.0
        E_J_beta_trace = 0.0
        for idx in range(n_k):
            D_a_k = _occupied_density_k(C_alpha_per_k[idx], n_alpha)
            D_b_k = _occupied_density_k(C_beta_per_k[idx], n_beta)
            w = float(weights[idx])
            E_core_trace += w * np.real(np.trace((D_a_k + D_b_k) @ Hcore_k_list[idx]))
            E_HF_alpha_trace += (
                w * 0.5 * np.real(np.trace(D_a_k @ F_HF_alpha_k_list[idx]))
            )
            E_HF_beta_trace += (
                w * 0.5 * np.real(np.trace(D_b_k @ F_HF_beta_k_list[idx]))
            )
            E_J_alpha_trace += (
                w * 0.5 * np.real(np.trace(D_a_k @ J_only_alpha_k_list[idx]))
            )
            E_J_beta_trace += (
                w * 0.5 * np.real(np.trace(D_b_k @ J_only_beta_k_list[idx]))
            )
        # Eigenpairs of the converged Fock, reported as the final MOs.
        final_alpha = [
            _diag_in_orth_basis(F_k, X_k)
            for F_k, X_k in zip(F_alpha_k_list, X_k_list)
        ]
        final_beta = [
            _diag_in_orth_basis(F_k, X_k)
            for F_k, X_k in zip(F_beta_k_list, X_k_list)
        ]
        C_alpha_per_k = [C for C, _ in final_alpha]
        eps_alpha_per_k = [eps for _, eps in final_alpha]
        C_beta_per_k = [C for C, _ in final_beta]
        eps_beta_per_k = [eps for _, eps in final_beta]
        E_elec = (
            E_xc
            + float(E_core_trace)
            + float(E_HF_alpha_trace)
            + float(E_HF_beta_trace)
        )
        # Madelung-leak correction (v0.6.1).
        _D_g0_f = np.asarray(_g0_block(D_alpha_real)) + np.asarray(
            _g0_block(D_beta_real)
        )
        _S_g0_f = np.asarray(_g0_block(S_lat))
        E_madelung_fix = _madelung_energy_correction_for_lat(
            _D_g0_f, _S_g0_f, system, lat_opts
        )
        E_total = float(E_elec) + e_nuc + E_madelung_fix
        E_coulomb_per_cell = float(E_J_alpha_trace + E_J_beta_trace)
        # tr(D·F_HF) = tr(D·J) − α·tr(D·K) (with the ½ prefactor inside
        # E_HF_*_trace), so HF_total − J_total = −α · ½ tr(D·K).
        E_hf_K_per_cell = float(
            (E_HF_alpha_trace - E_J_alpha_trace) + (E_HF_beta_trace - E_J_beta_trace)
        )

    # ⟨S²⟩ from the Γ-block (or first) k-point — same shortcut as
    # multi-k UHF Ewald.
    if n_alpha == 0 or n_beta == 0:
        s2 = 0.25 * (n_alpha - n_beta) * (n_alpha - n_beta + 2) + n_beta
    else:
        k0_idx = next(
            (i for i, k in enumerate(k_points) if np.allclose(np.asarray(k), 0.0)),
            0,
        )
        S_real = np.real(S_k_list[k0_idx])
        s2 = _spin_squared(
            n_alpha,
            n_beta,
            np.real(C_alpha_per_k[k0_idx]),
            np.real(C_beta_per_k[k0_idx]),
            S_real,
        )
    plog.converged(n_iter=iter_idx, energy=E_total, converged=converged)
    return PeriodicUKSMultiKEwaldResult(
        energy=E_total,
        e_electronic=float(E_elec),
        e_nuclear=e_nuc,
        e_xc=float(E_xc),
        e_coulomb=float(E_coulomb_per_cell),
        e_hf_exchange=float(E_hf_K_per_cell),
        n_iter=iter_idx,
        converged=converged,
        s_squared=float(s2),
        s_squared_ideal=0.25 * (mult - 1) * (mult + 1),
        mo_energies_alpha=eps_alpha_per_k,
        mo_coeffs_alpha=C_alpha_per_k,
        fock_alpha=F_alpha_k_list,
        density_alpha=D_alpha_real,
        mo_energies_beta=eps_beta_per_k,
        mo_coeffs_beta=C_beta_per_k,
        fock_beta=F_beta_k_list,
        density_beta=D_beta_real,
        overlap=S_k_list,
        hcore=Hcore_k_list,
        functional=str(opts.functional),
        scf_trace=scf_trace,
        omega=float(omega),
        grid_shape=grid_shape_t,
    )