"""BIPOLE-style periodic RKS driver in CRYSTAL's electrostatic gauge.
The DFT counterpart of :mod:`vibeqc.pbc_bipole`. Uses the same
CRYSTAL-gauge Ewald J-split F²e build for the Hartree term and
adds the libxc XC potential on the periodic Becke grid. For hybrid
functionals a fraction of HF exchange is retained via the native
``build_jk_2e_real_space`` component builder.
This driver keeps the same multi-k SCF scaffold as the RHF BIPOLE
driver: shared Ewald α across V_ne/E_nn/J^LR, analytic V_ne via
AO-pair Fourier transforms, real-space energy accounting, and
the full suite of convergence accelerators (DIIS, ODA, MOM,
level-shift schedule, damping).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Union
import numpy as np
from ._vibeqc_core import (
BasisSet,
BlochKMesh,
CoulombMethod,
EwaldOptions,
Functional,
GridOptions,
InitialGuess,
LatticeMatrixSet,
LatticeSumOptions,
PeriodicKSOptions,
PeriodicSystem,
SCFIteration,
bloch_sum,
build_fock_2e_real_space,
build_grid,
build_jk_2e_real_space,
build_xc_periodic,
compute_kinetic_lattice,
compute_overlap_lattice,
direct_lattice_cells,
ewald_nuclear_repulsion,
nuclear_repulsion_per_cell,
real_space_density_from_kpoints,
)
from .bipole_ext_el_pole import compute_ext_el_spheropole
from .guess import initial_density_closed_shell
from .level_shift_schedule import LevelShiftSchedule
from .mom import select_occupied_by_max_overlap as _mom_select
from .oda import compute_oda_lambda as _compute_oda_lambda
from .oda import oda_mix_densities as _oda_mix
from .pbc_bipole import (
PBCBipoleEnergyComponents,
_bloch_sum_blocks,
_cell_key,
_compute_nuclear_lattice_ewald_reciprocal_ft,
_crystal_ewald_options,
_default_bipole_v_ne_grid_options,
_expand_ibz_kmesh_for_ewald_j,
_lattice_contract,
_lattice_contract_blocks,
)
from .periodic_grid import build_periodic_becke_grid
from .periodic_rhf_multi_k_ewald import (
_canonical_orthogonalizer_complex,
_damp_lattice_matrix,
_diag_in_orth_basis,
)
from .periodic_scf_accelerators import (
DynamicDamping,
MultiKPeriodicSCFAccelerator,
)
from .periodic_v_ne import compute_nuclear_lattice_dispatch
from .progress import ProgressLogger, resolve_progress
from .scf_divergence import check_scf_divergence
from .smearing._support import reject_unsupported_smearing_temperature
from .symmetry_integrals_reduced import (
compute_kinetic_lattice_reduced,
compute_overlap_lattice_reduced,
)
__all__ = [
"PBCBipoleRKSResult",
"run_pbc_bipole_rks",
]
# [docs]
@dataclass
class PBCBipoleRKSResult:
    """Everything :func:`run_pbc_bipole_rks` hands back to its caller.

    The field order is part of the public interface (the dataclass can be
    constructed positionally), so it must not be rearranged.
    """

    # -- Scalar energies of the final SCF state -------------------------
    energy: float  # total energy reported by the driver
    e_electronic: float  # electronic part (total minus nuclear terms)
    e_nuclear: float  # nuclear-repulsion energy per cell
    e_xc: float  # exchange-correlation energy from the DFT grid
    e_coulomb: float  # 0.5 * <D, F2e> two-electron contraction
    e_hf_exchange: float  # HF-exchange share for hybrids, 0.0 otherwise

    # -- Convergence bookkeeping ----------------------------------------
    n_iter: int  # number of SCF iterations performed
    converged: bool  # True when the convergence criteria were met

    # -- Per-k-point quantities (one list entry per k-point) ------------
    mo_energies: List[np.ndarray]
    mo_coeffs: List[np.ndarray]
    fock: List[np.ndarray]
    overlap: List[np.ndarray]
    hcore: List[np.ndarray]

    # -- Real-space density, one block per lattice cell -----------------
    density: LatticeMatrixSet

    # -- Optional diagnostics -------------------------------------------
    scf_trace: List[SCFIteration] = field(default_factory=list)
    e_ext_el_spheropole: Optional[float] = None
    ewald_alpha_bohr_inv: Optional[float] = None
    energy_components: List[PBCBipoleEnergyComponents] = field(default_factory=list)
    functional: str = ""
@dataclass
class _PBCBipoleRKSFockBuild:
    """Private bundle produced by one Fock build of the BIPOLE RKS driver.

    The optional energy pieces stay ``None`` when the build path that
    produced the bundle did not evaluate them.
    """

    # Two-electron Fock contribution, one real-space block per cell.
    f2e_real: LatticeMatrixSet
    # Assembled Fock matrices, one per k-point.
    f_k_list: List[np.ndarray]

    # Optional energy breakdown of the two-electron term.
    e_j_short_range: Optional[float] = None
    e_j_long_range: Optional[float] = None
    e_exchange: Optional[float] = None
    e_j_multipole: Optional[float] = None
def _density_set_gamma_or_lattice(
    template: LatticeMatrixSet,
    D_real: LatticeMatrixSet,
) -> LatticeMatrixSet:
    """Pick the lattice matrix set that is handed to the XC builder.

    A density that spans more than one cell and actually carries blocks is
    returned untouched. In every other case the first block of ``D_real``
    (or an ``nbf x nbf`` zero matrix when it has no blocks) is written into
    block 0 of ``template``, all remaining template blocks are zeroed, and
    ``template`` itself is returned.

    NOTE: the fallback path overwrites ``template`` in place, so callers
    must not pass a set whose contents they still need.
    """
    size = template.nbf
    source_blocks = D_real.blocks

    # Genuine multi-cell density: nothing to wrap.
    if source_blocks and len(D_real.cells) > 1:
        return D_real

    if source_blocks:
        gamma_block = np.asarray(source_blocks[0], dtype=float)
    else:
        gamma_block = np.zeros((size, size))
    filler = np.zeros_like(gamma_block)

    # Home-cell block carries the density; every other cell is empty.
    n_template_blocks = len(template)
    if n_template_blocks > 0:
        template.set_block(0, gamma_block)
    for position in range(1, n_template_blocks):
        template.set_block(position, filler)
    return template
# [docs]
def run_pbc_bipole_rks(
    system: PeriodicSystem,
    basis: BasisSet,
    kmesh: BlochKMesh,
    options=None,
    *,
    functional: Optional[str] = None,
    linear_dep_threshold: float = 1e-7,
    canonical_orth_normalize_diag_first: bool = True,
    level_shift_schedule: Optional[LevelShiftSchedule] = None,
    use_mom: bool = False,
    use_oda: bool = False,
    oda_trust_lambda_max: float = 1.0,
    use_ewald_j_split: Optional[bool] = None,
    ewald_omega: Optional[float] = None,
    ewald_precision: float = 1e-8,
    v_ne_grid_options: Optional[GridOptions] = None,
    use_multipole_far_field: bool = False,
    multipole_l_max: int = 2,
    progress: Union[bool, ProgressLogger, None] = None,
    verbose: Optional[int] = None,
    initial_density: Optional[Sequence[np.ndarray]] = None,
) -> PBCBipoleRKSResult:
    """Multi-k closed-shell RKS via the CRYSTAL-gauge BIPOLE scaffold.

    Parameters
    ----------
    system, basis, kmesh
        As in :func:`run_pbc_bipole_rhf`. ``kmesh.weights`` must sum to 1.
    options
        :class:`PeriodicKSOptions`, or None for a default-constructed one.
        NOTE: ``options.functional`` is overwritten in place when
        ``functional`` is given or when it is empty (falls back to
        ``"pbe"``).
    functional
        XC functional name (overrides ``options.functional`` if given).
    linear_dep_threshold, canonical_orth_normalize_diag_first
        Forwarded to the per-k canonical orthogonaliser.
    level_shift_schedule
        Optional :class:`LevelShiftSchedule`; when given, its ``at(iter)``
        value replaces the static ``options.level_shift``.
    use_mom
        Re-order the new orbitals by maximum overlap with the previous
        occupied set.
    use_oda, oda_trust_lambda_max
        Optimal-damping mixing of old and new density. Mutually exclusive
        with ``options.use_diis``.
    use_ewald_j_split
        ``None`` enables the Ewald J-split exactly when ``system.dim == 3``.
        Forcing it on for other dimensionalities raises.
    ewald_omega
        Ewald splitting parameter; defaults to
        ``crystal_default_ewald_alpha(V_cell)`` for 3D systems.
    ewald_precision
        Tolerance forwarded to the Ewald options and the long-range cache.
    v_ne_grid_options
        When given (or when the system is not 3D), V_ne is built through
        ``compute_nuclear_lattice_dispatch`` instead of the reciprocal-FT
        path.
    use_multipole_far_field, multipole_l_max
        Forwarded to ``resolve_multipole_config``.
    progress, verbose
        Forwarded to ``resolve_progress``.
    initial_density
        Optional warm-start density, one block per lattice cell in the
        order of the density's cell list.

    Returns
    -------
    PBCBipoleRKSResult

    Raises
    ------
    ValueError
        Inconsistent options (ODA together with DIIS, ``max_iter < 1``,
        Ewald J-split on a non-3D system or on a non-uniform full mesh),
        odd electron count, multiplicity != 1, empty or unnormalised
        k-mesh, or an ``initial_density`` with the wrong block count.
    TypeError
        ``level_shift_schedule`` is not a :class:`LevelShiftSchedule`.
    RuntimeError
        Canonical orthogonalisation kept fewer functions than there are
        occupied orbitals at some k-point.
    """
    opts = options if options is not None else PeriodicKSOptions()

    # ---- Fail-fast argument validation -------------------------------
    # Done before any integral work so that a bad call costs nothing.
    use_diis = bool(opts.use_diis)
    if use_oda and use_diis:
        raise ValueError("use_oda and use_diis are mutually exclusive")
    max_iter = int(opts.max_iter)
    if max_iter < 1:
        # Without at least one iteration no energy is ever defined.
        raise ValueError(
            f"run_pbc_bipole_rks: options.max_iter must be >= 1; got {max_iter}"
        )
    if level_shift_schedule is not None and not isinstance(
        level_shift_schedule,
        LevelShiftSchedule,
    ):
        raise TypeError(
            f"level_shift_schedule must be LevelShiftSchedule; "
            f"got {type(level_shift_schedule).__name__}"
        )

    if functional is not None:
        opts.functional = str(functional)
    if not getattr(opts, "functional", None):
        opts.functional = "pbe"
    reject_unsupported_smearing_temperature(
        opts,
        "run_pbc_bipole_rks",
        detail=(
            "BIPOLE smearing is queued for a later smearing milestone; "
            "this driver would otherwise run integer Aufbau occupations."
        ),
    )
    func = Functional(opts.functional, 1)  # spin-unpolarised
    alpha_hf = float(func.hf_exchange_fraction)
    lat_opts: LatticeSumOptions = opts.lattice_opts
    plog = resolve_progress(progress, verbose=verbose)

    use_ewald_j_split = (
        system.dim == 3 if use_ewald_j_split is None else bool(use_ewald_j_split)
    )
    if use_ewald_j_split and system.dim != 3:
        raise ValueError("use_ewald_j_split requires dim=3")

    lat_opts_2e = LatticeSumOptions()
    lat_opts_2e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_2e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_2e.coulomb_method = CoulombMethod.DIRECT_TRUNCATED
    lat_opts_1e = LatticeSumOptions()
    lat_opts_1e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_1e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_1e.coulomb_method = (
        CoulombMethod.EWALD_3D if system.dim == 3 else CoulombMethod.DIRECT_TRUNCATED
    )
    plog.info(
        f"PBC BIPOLE RKS (CRYSTAL-gauge) / cutoff {lat_opts.cutoff_bohr:.2f} bohr"
    )
    plog.info(f" functional = {opts.functional}, hf_exchange_fraction = {alpha_hf}")
    plog.info(
        f" F^2e (J + V_xc{'+ K' if alpha_hf > 0 else ''}) : "
        f"{'EWALD_J_SPLIT' if use_ewald_j_split else lat_opts_2e.coulomb_method.name}"
    )

    n_elec = system.n_electrons()
    if n_elec % 2 != 0:
        raise ValueError(
            f"run_pbc_bipole_rks: closed-shell RKS requires even electron "
            f"count; got {n_elec}"
        )
    if system.multiplicity != 1:
        raise ValueError(
            f"run_pbc_bipole_rks: requires multiplicity=1; got {system.multiplicity}"
        )
    n_occ = n_elec // 2

    # ---- k-mesh ------------------------------------------------------
    # ``kmesh`` may be IBZ-reduced; ``ir_mapping`` (when present) maps the
    # full mesh onto it. The long-range J needs the full mesh.
    _ir_mapping = np.asarray(getattr(kmesh, "ir_mapping", []), dtype=int).reshape(-1)
    k_points = list(kmesh.kpoints)
    weights = np.asarray(kmesh.weights, dtype=float)
    if use_ewald_j_split and _ir_mapping.size > 0:
        kmesh_full = _expand_ibz_kmesh_for_ewald_j(system, kmesh, plog)
        k_points_full = list(kmesh_full.kpoints)
        weights_full = np.asarray(kmesh_full.weights, dtype=float)
    else:
        k_points_full = k_points
        weights_full = weights
    n_k = len(k_points)
    if n_k == 0:
        raise ValueError("kmesh has no k-points")
    if not np.isclose(weights.sum(), 1.0):
        raise ValueError(f"kmesh.weights must sum to 1; got {weights.sum():.6f}")
    if use_ewald_j_split and len(k_points_full) > 1:
        # The uniformity requirement applies to the mesh that actually feeds
        # the reciprocal-space density, i.e. the expanded one. Testing the
        # IBZ weights here would reject every symmetry-reduced mesh even
        # though the error message advertises them as supported.
        uniform_w = 1.0 / float(len(k_points_full))
        if not np.allclose(weights_full, uniform_w, atol=1e-9):
            raise ValueError(
                "use_ewald_j_split at multi-k requires uniform "
                "full-mesh weights or IBZ-reduced mesh with ir_mapping"
            )
    plog.info(
        f"k-mesh: {n_k} k-point{'s' if n_k != 1 else ''}, "
        f"weights sum = {weights.sum():.4f}"
    )

    # ---- Shared Ewald state (CRYSTAL COMMON/VRSMAD pattern) ----------
    ewald_options_1e: Optional[EwaldOptions] = None
    omega_used: Optional[float] = None
    ewald_cell_volume: Optional[float] = None
    ewald_k_max: Optional[float] = None
    if system.dim == 3:
        from .bipole_ext_el_pole import (
            crystal_default_ewald_alpha,
            crystal_ewald_reciprocal_cutoff,
        )

        V_cell = float(abs(np.linalg.det(np.asarray(system.lattice, dtype=float))))
        ewald_cell_volume = V_cell
        omega_used = (
            float(ewald_omega)
            if ewald_omega is not None
            else crystal_default_ewald_alpha(V_cell)
        )
        ewald_k_max = crystal_ewald_reciprocal_cutoff(V_cell)
        ewald_options_1e = _crystal_ewald_options(
            lat_opts_1e,
            alpha_bohr_inv=omega_used,
            tolerance=float(ewald_precision),
            recip_cutoff_bohr_inv=ewald_k_max,
        )

    # ---- DFT grid ----------------------------------------------------
    if getattr(opts, "use_periodic_becke", False):
        grid = build_periodic_becke_grid(
            system,
            grid_options=opts.grid,
            image_radius_bohr=float(getattr(opts, "becke_image_radius_bohr", 10.0)),
        )
    else:
        grid = build_grid(system.unit_cell_molecule(), opts.grid)

    # ---- Real-space one-electron integrals ---------------------------
    with plog.stage(
        "integrals_lattice",
        detail=f"S/T/V at cutoff {lat_opts.cutoff_bohr:.2f} bohr",
    ):
        _use_sym = False  # FIXME: LatticeMatrixSet has no Python constructor
        if _use_sym:
            ops = system.symmetry.operations
            _, S_blocks = compute_overlap_lattice_reduced(
                basis,
                system,
                lat_opts_2e,
                ops,
            )
            S_lat = LatticeMatrixSet()
            S_lat.nbf = basis.nbasis
            S_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            S_lat.blocks = S_blocks
            _, T_blocks = compute_kinetic_lattice_reduced(
                basis,
                system,
                lat_opts_2e,
                ops,
            )
            T_lat = LatticeMatrixSet()
            T_lat.nbf = basis.nbasis
            T_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            T_lat.blocks = T_blocks
        else:
            S_lat = compute_overlap_lattice(basis, system, lat_opts_2e)
            T_lat = compute_kinetic_lattice(basis, system, lat_opts_2e)

        # Scratch set for ``_density_set_gamma_or_lattice``. That helper
        # overwrites its template when the density has a single cell, and
        # S_lat is still needed afterwards (neutralising-background term),
        # so it must never be handed S_lat itself.
        xc_template = compute_overlap_lattice(basis, system, lat_opts_2e)

        v_ne_lr_cache = None
        if (
            system.dim == 3
            and ewald_options_1e is not None
            and v_ne_grid_options is None
        ):
            V_lat, v_ne_lr_cache = _compute_nuclear_lattice_ewald_reciprocal_ft(
                basis,
                system,
                lat_opts_1e,
                ewald_options_1e,
                S_lat,
                precision=ewald_precision,
                K_max=ewald_k_max,
            )
        else:
            v_ne_grid = (
                v_ne_grid_options
                if v_ne_grid_options is not None
                else (_default_bipole_v_ne_grid_options() if system.dim == 3 else None)
            )
            V_lat = compute_nuclear_lattice_dispatch(
                basis,
                system,
                lat_opts_1e,
                grid_options=v_ne_grid,
                ewald_options=ewald_options_1e,
            )
    cells = list(S_lat.cells)
    plog.info(f"n_cells in lattice sum = {len(cells)}")

    # ---- Per-k S(k), Hcore(k), X(k) ---------------------------------
    from .linear_dependence import (
        check_overlap_matrix,
        format_linear_dependence_report,
        raise_if_severe,
        scf_preflight_overlap_check,
    )

    S_k_list: List[np.ndarray] = []
    Hcore_k_list: List[np.ndarray] = []
    X_k_list: List[np.ndarray] = []
    overlap_reports = []
    for k_idx, k in enumerate(k_points):
        k_arr = np.asarray(k, dtype=float).reshape(3)
        S_k = np.asarray(bloch_sum(S_lat, k_arr))
        T_k = np.asarray(bloch_sum(T_lat, k_arr))
        V_k = np.asarray(bloch_sum(V_lat, k_arr))
        H_k = T_k + V_k
        # Remove numerical anti-Hermitian noise from the Bloch sums.
        S_k = 0.5 * (S_k + S_k.conj().T)
        H_k = 0.5 * (H_k + H_k.conj().T)
        overlap_label = f"S(k={k_idx}, k_cart={k_arr.round(4).tolist()})"
        if n_k <= 16:
            report = scf_preflight_overlap_check(
                S_k,
                plog=plog,
                label=overlap_label,
                basis=basis,
            )
        else:
            # Large meshes: check without the logger and only print the
            # k-points that are actually problematic.
            report = check_overlap_matrix(
                S_k,
                basis=basis,
                label=overlap_label,
            )
            if report.severity != "ok":
                prefix = {
                    "warn": "WARN",
                    "error": "ERROR",
                    "critical": "CRITICAL",
                }[report.severity]
                plog.info(
                    f"[{prefix}] overlap [{overlap_label}]: "
                    f"nbf={report.n_basis}, "
                    f"min eig={report.min_eigenvalue:+.2e}, "
                    f"cond={report.condition_number:.2e}"
                )
                plog.write_raw(format_linear_dependence_report(report))
            raise_if_severe(report)
        X_k, n_kept = _canonical_orthogonalizer_complex(
            S_k,
            linear_dep_threshold,
            normalize_diag_first=canonical_orth_normalize_diag_first,
        )
        overlap_reports.append(report)
        if n_occ > n_kept:
            raise RuntimeError(
                f"run_pbc_bipole_rks: canonical orth at k={k_idx} "
                f"dropped too many directions (n_occ={n_occ}, "
                f"n_kept={n_kept})"
            )
        S_k_list.append(S_k)
        Hcore_k_list.append(H_k)
        X_k_list.append(X_k)

    # ---- Nuclear repulsion -------------------------------------------
    if ewald_options_1e is not None:
        e_nuc = float(ewald_nuclear_repulsion(system, ewald_options_1e))
    else:
        e_nuc = float(nuclear_repulsion_per_cell(system, lat_opts_1e))
    plog.info(f"E_nuc per cell = {e_nuc:+.10f} Ha")

    # ---- Initial guess -----------------------------------------------
    C_per_k: List[np.ndarray] = []
    eps_per_k: List[np.ndarray] = []
    for H_k, X_k in zip(Hcore_k_list, X_k_list):
        C_k, eps_k = _diag_in_orth_basis(H_k, X_k)
        C_per_k.append(C_k.astype(complex))
        eps_per_k.append(eps_k)
    n_occ_per_k = [n_occ] * n_k
    D_real = real_space_density_from_kpoints(
        C_per_k,
        n_occ_per_k,
        kmesh,
        cells,
    )
    # Caller-supplied warm-start density takes precedence over the
    # SAD/Hcore guess engine. Block ordering matches the canonical
    # ``direct_lattice_cells(kmesh)`` ordering — same contract as the
    # RHF driver. Used by the NEB driver for within-image density
    # warm-start (HANDOVER_PERIODIC_NEB.md M4).
    if initial_density is not None:
        blocks_in = list(initial_density)
        if len(blocks_in) != len(D_real.cells):
            raise ValueError(
                f"run_pbc_bipole_rks: initial_density has "
                f"{len(blocks_in)} blocks; expected {len(D_real.cells)}"
            )
        for g_idx, block in enumerate(blocks_in):
            D_real.set_block(g_idx, np.asarray(block, dtype=float))
        plog.info("initial guess: caller-supplied density (warm-start)")
        initial_density_is_local = True
        density_from_c_per_k = False
    else:
        guess = getattr(opts, "initial_guess", InitialGuess.HCORE)
        D_engine = initial_density_closed_shell(
            system.unit_cell_molecule(),
            basis,
            n_occ,
            guess,
            is_periodic=True,
        )
        if D_engine is not None:
            plog.info(f"initial guess: {guess.name} (g=0 density from GuessEngine)")
            # The engine density lives in the home cell only.
            for g_idx in range(len(D_real.cells)):
                if (D_real.cells[g_idx].index == np.array([0, 0, 0])).all():
                    D_real.set_block(g_idx, D_engine)
                else:
                    D_real.set_block(
                        g_idx,
                        np.zeros_like(D_engine, dtype=float),
                    )
        else:
            plog.info(f"initial guess: {guess.name} (Hcore-diag per k)")
        initial_density_is_local = D_engine is not None
        density_from_c_per_k = not initial_density_is_local
    D_real_prev: Optional[LatticeMatrixSet] = None

    # ---- SCF aids ----------------------------------------------------
    damping = float(opts.damping)
    damper: Optional[DynamicDamping] = None
    if bool(getattr(opts, "dynamic_damping", False)):
        damper = DynamicDamping(
            initial_alpha=damping,
            alpha_min=float(getattr(opts, "dynamic_damping_min", 0.0)),
            alpha_max=float(getattr(opts, "dynamic_damping_max", 0.95)),
        )
    diis_start_iter = int(opts.diis_start_iter)
    accel: Optional[MultiKPeriodicSCFAccelerator] = (
        MultiKPeriodicSCFAccelerator(opts) if use_diis else None
    )
    level_shift_static = float(getattr(opts, "level_shift", 0.0))
    if use_mom:
        plog.info("MOM: ON")
    C_prev_occ_per_k: Optional[List[np.ndarray]] = None
    if use_oda:
        plog.info(f"ODA: ON (trust λ_max = {oda_trust_lambda_max})")

    # ---- Ewald J-split cache -----------------------------------------
    j_lr_cache = v_ne_lr_cache
    if use_ewald_j_split:
        from .bipole_fock_ewald import (
            _build_j_long_range_cache,
            compute_J_long_range_real_space_blocks,
            compute_rho_hat_from_k_density,
        )

        assert omega_used is not None  # guaranteed by the dim == 3 check
        cells_r_cart_arr = np.array(
            [np.asarray(c.r_cart, dtype=float) for c in cells],
            dtype=float,
        )
        if j_lr_cache is None:
            j_lr_cache = _build_j_long_range_cache(
                basis,
                system,
                cells_r_cart_arr,
                omega_used,
                ewald_precision,
                K_max=ewald_k_max,
            )

    # ---- Multipole far-field config (resolved once before the SCF) ---
    from .bipole_fock_multipole import (
        apply_multipole_far_field,
        resolve_multipole_config,
    )

    _mp_config = resolve_multipole_config(
        system,
        basis,
        lat_opts_2e,
        user_enable=use_multipole_far_field,
        multipole_l_max=multipole_l_max,
    )
    if _mp_config.enabled:
        plog.info(
            f" BIPOLE multipole far-field: ENABLED "
            f"(L_max={_mp_config.L_max}, R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"n_cells={len(_mp_config.cache.cells) if _mp_config.cache else 0})"
        )
    else:
        plog.info(
            f" BIPOLE multipole far-field: off "
            f"(R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"cutoff={lat_opts_2e.cutoff_bohr:.1f} bohr)"
        )

    def _density_matrices_from_coeffs(
        coeffs_per_k: Sequence[np.ndarray],
    ) -> List[np.ndarray]:
        """Closed-shell D(k) = 2 C_occ C_occ^H for every k-point."""
        out: List[np.ndarray] = []
        for C_k_raw in coeffs_per_k:
            C_occ_k = np.asarray(C_k_raw)[:, :n_occ]
            out.append(2.0 * (C_occ_k @ C_occ_k.conj().T))
        return out

    def _build_fock_for_density(
        density: LatticeMatrixSet,
        *,
        coeffs_for_rho: Optional[Sequence[np.ndarray]],
    ) -> tuple[_PBCBipoleRKSFockBuild, float]:
        """Build F²e(g) and F(k) for ``density``; also return its E_xc.

        ``coeffs_for_rho`` is given only when ``density`` was generated
        from those orbital coefficients; the reciprocal-space density is
        then taken from the per-k density matrices instead of ``density``.
        """
        # Defaults shared by both branches. Previously ``e_j_multipole``
        # was bound only on the Ewald branch, so the direct-truncated
        # branch crashed with UnboundLocalError at the return statement.
        e_j_short_range: Optional[float] = None
        e_j_long_range: Optional[float] = None
        e_exchange: Optional[float] = None
        e_j_multipole: Optional[float] = None

        if use_ewald_j_split:
            # Short-range J via erfc-screened ERIs
            F_J_SR_lat = build_fock_2e_real_space(
                basis,
                system,
                lat_opts_2e,
                density,
                0.0,
                float(omega_used),
            )
            cells_lr = F_J_SR_lat.cells

            # Full K is read only by the hybrid-exchange term and by the
            # multipole far-field; skip the build when neither is active.
            K_full_blocks = None
            if alpha_hf > 0.0 or _mp_config.enabled:
                jk_full_lat = build_jk_2e_real_space(
                    basis,
                    system,
                    lat_opts_2e,
                    density,
                    0.0,
                )
                K_full_blocks = [
                    np.asarray(jk_full_lat.K.blocks[c], dtype=float)
                    for c in range(len(jk_full_lat.K.cells))
                ]

            # Long-range J via reciprocal-space sum
            rho_hat_for_LR = None
            if coeffs_for_rho is not None:
                D_k_for_rho = _density_matrices_from_coeffs(coeffs_for_rho)
                if _ir_mapping.size > 0:
                    # Unfold the IBZ density matrices onto the full mesh.
                    D_k_full = [D_k_for_rho[int(idx)] for idx in _ir_mapping]
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_k_full,
                        k_points_full,
                        weights_full,
                        j_lr_cache,
                    )
                else:
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_k_for_rho,
                        k_points,
                        weights,
                        j_lr_cache,
                    )
            F_LR_blocks = compute_J_long_range_real_space_blocks(
                density,
                basis,
                system,
                omega_used,
                precision=ewald_precision,
                cache=j_lr_cache,
                rho_hat=rho_hat_for_LR,
            )

            # Neutralising background
            assert ewald_cell_volume is not None
            j_background_potential = (
                -np.pi
                * float(n_elec)
                / (float(omega_used) * float(omega_used) * float(ewald_cell_volume))
            )
            s_blocks = {
                _cell_key(cell): np.asarray(block, dtype=float)
                for cell, block in zip(S_lat.cells, S_lat.blocks)
            }
            for c, cell in enumerate(cells_lr):
                F_LR_blocks[c] = (
                    F_LR_blocks[c] + j_background_potential * s_blocks[_cell_key(cell)]
                )

            # Energy components
            j_sr_blocks = [
                np.asarray(F_J_SR_lat.blocks[c], dtype=float)
                for c in range(len(cells_lr))
            ]
            e_j_short_range = 0.5 * _lattice_contract_blocks(
                density,
                cells_lr,
                j_sr_blocks,
                operator_name="J_SR",
            )
            e_j_long_range = 0.5 * _lattice_contract_blocks(
                density,
                cells_lr,
                F_LR_blocks,
                operator_name="J_LR",
            )

            # Assemble F²e(g) = J_SR(g) + J_LR(g) - (α/2)·K(g)
            f2e_direct_blocks = []
            for c in range(len(cells_lr)):
                j_blk = j_sr_blocks[c] + F_LR_blocks[c]
                if alpha_hf > 0.0:
                    k_blk = K_full_blocks[c]
                    j_blk = j_blk - 0.5 * alpha_hf * k_blk
                    if e_exchange is None:
                        e_exchange = 0.0
                    d_blk = np.asarray(density.blocks[c], dtype=float)
                    e_exchange = e_exchange - 0.25 * alpha_hf * float(
                        np.sum(d_blk * k_blk)
                    )
                f2e_direct_blocks.append(j_blk)
            # Re-use the short-range container for the assembled operator.
            f2e_real = F_J_SR_lat
            for c in range(len(cells_lr)):
                f2e_real.set_block(c, f2e_direct_blocks[c])

            # ---- Optional: replace J_SR+J_LR with multipole far-field J --
            if _mp_config.enabled:
                mp_result = apply_multipole_far_field(
                    density,
                    basis,
                    system,
                    lat_opts_2e,
                    _mp_config,
                    J_SR_blocks=j_sr_blocks,
                    K_blocks=K_full_blocks,
                    F_LR_blocks=F_LR_blocks,
                    exchange_scale=0.5 * alpha_hf,
                )
                for c in range(len(cells_lr)):
                    f2e_real.set_block(c, mp_result.f2e_blocks[c])
                e_j_multipole = mp_result.e_j_multipole
                if mp_result.n_far_cells > 0:
                    plog.info(
                        f" BIPOLE multipole far-field (L_max={_mp_config.L_max}, "
                        f"R={_mp_config.R_bipole:.1f} bohr): "
                        f"{mp_result.n_far_cells}/{mp_result.n_total_cells} cells replaced, "
                        f"E_J_far = {e_j_multipole:+.6f} Ha"
                    )
        else:
            # Direct truncated lattice sum; no energy breakdown available.
            f2e_real = build_fock_2e_real_space(
                basis,
                system,
                lat_opts_2e,
                density,
                alpha_hf,
                0.0,
            )

        # ---- Build XC potential on the DFT grid ----------------------
        D_xc_set = _density_set_gamma_or_lattice(xc_template, density)
        xc_result = build_xc_periodic(basis, system, grid, D_xc_set, func)

        # ---- Assemble per-k Fock matrices ----------------------------
        f_k_list: List[np.ndarray] = []
        for k_idx, k in enumerate(k_points):
            k_arr = np.asarray(k, dtype=float)
            F2e_k = _bloch_sum_blocks(
                f2e_real.blocks,
                f2e_real.cells,
                k_arr,
            )
            Vxc_k = np.asarray(bloch_sum(xc_result.vxc_lattice, k_arr))
            F_k = F2e_k + Vxc_k + np.asarray(Hcore_k_list[k_idx], dtype=complex)
            F_k = 0.5 * (F_k + F_k.conj().T)
            f_k_list.append(F_k)

        bundle = _PBCBipoleRKSFockBuild(
            f2e_real=f2e_real,
            f_k_list=f_k_list,
            e_j_short_range=e_j_short_range,
            e_j_long_range=e_j_long_range,
            e_exchange=e_exchange,
            e_j_multipole=e_j_multipole,
        )
        # The XC energy comes out of the same grid pass as V_xc; returning
        # it here spares callers a second integration on the same density.
        return bundle, float(xc_result.exc)

    plog.banner("SCF (PBC BIPOLE RKS, direct-space)")
    plog.info(" iter energy (Ha) dE ||[F,DS]|| DIIS")
    scf_trace: List[SCFIteration] = []
    energy_components: List[PBCBipoleEnergyComponents] = []
    E_prev = 0.0
    converged = False
    for iter_idx in range(1, max_iter + 1):
        if damper is not None:
            damping = damper.alpha
        diis_active = use_diis and iter_idx >= diis_start_iter

        # Plain damping is applied only while DIIS is not yet driving.
        d_used_is_damped = iter_idx > 1 and damping > 0.0 and not diis_active
        D_used = D_real
        if d_used_is_damped:
            D_used = _damp_lattice_matrix(D_real, D_real_prev, damping)
        d_used_from_coeffs = (
            density_from_c_per_k
            and not (initial_density_is_local and iter_idx == 1)
            and not d_used_is_damped
        )
        fock_build, e_xc = _build_fock_for_density(
            D_used,
            coeffs_for_rho=(C_per_k if d_used_from_coeffs else None),
        )
        F2e_real = fock_build.f2e_real
        F_k_list = fock_build.f_k_list

        # ---- Energy --------------------------------------------------
        E_kin = _lattice_contract(D_used, T_lat, operator_name="T")
        E_ne = _lattice_contract(D_used, V_lat, operator_name="V_ne")
        E_2e = 0.5 * _lattice_contract(
            D_used,
            F2e_real,
            operator_name="F2e",
        )
        E_elec = E_kin + E_ne + E_2e + e_xc

        # ---- Orbital gradient [F, DS] per k-point --------------------
        grad_norm_sum = 0.0
        error_k_list: List[np.ndarray] = []
        D_k_list: List[np.ndarray] = []
        for idx in range(n_k):
            if initial_density_is_local and iter_idx == 1:
                # No orbitals correspond to a local guess density yet.
                k_arr = np.asarray(k_points[idx], dtype=float)
                D_k = _bloch_sum_blocks(
                    D_used.blocks,
                    D_used.cells,
                    k_arr,
                )
                D_k = 0.5 * (D_k + D_k.conj().T)
            else:
                C_occ = C_per_k[idx][:, :n_occ]
                D_k = 2.0 * (C_occ @ C_occ.conj().T)
            D_k_list.append(D_k)
            FDS = F_k_list[idx] @ D_k @ S_k_list[idx]
            grad = FDS - FDS.conj().T
            error_k_list.append(grad)
            grad_norm_sum += float(weights[idx]) * float(np.linalg.norm(grad))

        E_total = float(E_elec) + e_nuc
        # EXT EL-SPHEROPOLE
        E_sphero = compute_ext_el_spheropole(D_used, basis, system, lat_opts)
        E_total += E_sphero
        dE = E_total - E_prev if iter_idx > 1 else 0.0
        check_scf_divergence(
            "run_pbc_bipole_rks",
            iter_idx,
            E_total,
            grad_norm_sum,
            dE,
        )
        diis_size = accel.subspace_size if accel is not None else 0
        scf_trace.append(
            SCFIteration(
                iter=iter_idx,
                energy=float(E_total),
                delta_e=float(dE),
                grad_norm=float(grad_norm_sum),
                diis_subspace=diis_size,
            )
        )
        plog.iteration(
            iter_idx,
            energy=float(E_total),
            dE=float(dE),
            grad=float(grad_norm_sum),
            diis=diis_size,
        )
        energy_components.append(
            PBCBipoleEnergyComponents(
                iter=int(iter_idx),
                e_total=float(E_total),
                e_electronic=float(E_elec),
                e_kinetic=float(E_kin),
                e_nuclear_attraction=float(E_ne),
                e_two_electron=float(E_2e),
                e_nuclear_repulsion=float(e_nuc),
                e_j_short_range=fock_build.e_j_short_range,
                e_j_long_range=fock_build.e_j_long_range,
                e_exchange=fock_build.e_exchange,
                e_ext_el_spheropole=E_sphero,
            )
        )
        plog.energy_decomposition(
            iter_idx,
            E_kin=float(E_kin),
            E_ne=float(E_ne),
            E_2e=float(E_2e),
            E_elec=float(E_elec),
            E_nuc=float(e_nuc),
        )
        converged = (
            iter_idx > 1
            and abs(dE) < float(opts.conv_tol_energy)
            and grad_norm_sum < float(opts.conv_tol_grad)
        )

        # SCF-accelerator extrapolation. Full
        # {DIIS, KDIIS, EDIIS, EDIIS_DIIS, ADIIS} family wired here;
        # bridged modes route through the M2e stacked-real-block bridge
        # (see ``per_k_to_stacked_real_blocks`` in
        # ``periodic_scf_accelerators.py``).
        if accel is not None:
            density_k_list = [
                _bloch_sum_blocks(D_used.blocks, D_used.cells, np.asarray(k))
                for k in k_points
            ]
            # The accelerator is fed every iteration so its history fills
            # up, but its output is used only once DIIS is active.
            F_ex_list = accel.extrapolate_rhf(
                F_k_list,
                error_k_list=error_k_list,
                density_k_list=density_k_list,
                energy=E_total,
                mo_coeffs_k_list=C_per_k,
                n_occ=n_occ,
                weights=list(weights),
                cells=cells,
                kpoints=list(k_points),
            )
            if diis_active:
                F_k_list = F_ex_list

        # Level shift
        if level_shift_schedule is not None:
            level_shift_b = level_shift_schedule.at(iter_idx)
        else:
            level_shift_b = level_shift_static
        if level_shift_b != 0.0:
            F_for_diag: List[np.ndarray] = []
            for idx in range(n_k):
                D_k = D_k_list[idx]
                S_k = S_k_list[idx]
                F_shift = (
                    F_k_list[idx]
                    + level_shift_b * S_k
                    - (level_shift_b / 2.0) * (S_k @ D_k @ S_k)
                )
                F_shift = 0.5 * (F_shift + F_shift.conj().T)
                F_for_diag.append(F_shift)
        else:
            F_for_diag = F_k_list

        # Diagonalise
        new_C_per_k = []
        new_eps_per_k = []
        for idx in range(n_k):
            C_k, eps_k = _diag_in_orth_basis(
                F_for_diag[idx],
                X_k_list[idx],
            )
            new_C_per_k.append(C_k)
            new_eps_per_k.append(eps_k)

        # MOM: put the orbitals that best overlap the previous occupied
        # space first, followed by the rest in order of orbital energy.
        if use_mom and C_prev_occ_per_k is not None:
            for idx in range(n_k):
                C_k = new_C_per_k[idx]
                eps_k = new_eps_per_k[idx]
                sel = _mom_select(
                    C_k,
                    S_k_list[idx],
                    C_prev_occ_per_k[idx],
                    n_occ,
                    eps_new=eps_k,
                )
                virt_mask = np.ones(C_k.shape[1], dtype=bool)
                virt_mask[sel] = False
                virt_sel = np.where(virt_mask)[0]
                virt_sel = virt_sel[np.argsort(np.real(eps_k[virt_sel]))]
                order = np.concatenate([sel, virt_sel])
                new_C_per_k[idx] = C_k[:, order]
                new_eps_per_k[idx] = eps_k[order]
        C_per_k = new_C_per_k
        eps_per_k = new_eps_per_k
        D_real_new = real_space_density_from_kpoints(
            C_per_k,
            n_occ_per_k,
            kmesh,
            cells,
        )

        # ODA
        if use_oda:
            fock_naive, _ = _build_fock_for_density(
                D_real_new,
                coeffs_for_rho=C_per_k,
            )
            oda_step = _compute_oda_lambda(
                D_used,
                D_real_new,
                F_k_list,
                fock_naive.f_k_list,
                [np.asarray(k) for k in k_points],
                weights,
                trust_lambda_max=oda_trust_lambda_max,
            )
            # The mixed density is taken from ``D_used`` after this call.
            _oda_mix(D_used, D_real_new, oda_step.lam)
            D_real_prev = D_real
            D_real = D_used
            # Only a full step leaves a density that matches C_per_k.
            density_from_c_per_k = oda_step.lam == 1.0
            plog.info(f" ODA: λ = {oda_step.lam:.4f}")
        else:
            D_real_prev = D_used
            D_real = D_real_new
            density_from_c_per_k = True

        if use_mom:
            C_prev_occ_per_k = [
                np.asarray(C_per_k[idx][:, :n_occ]).copy() for idx in range(n_k)
            ]
        if damper is not None:
            damper.update(E_total)
        E_prev = E_total
        if converged:
            break

    plog.converged(n_iter=iter_idx, energy=E_total, converged=converged)

    # ---- Post-loop: recompute energy on final density for consistency
    if converged:
        # Same coefficient/density pairing rule as inside the loop: the
        # orbitals describe D_real only if it was not ODA-mixed.
        fock_build, e_xc = _build_fock_for_density(
            D_real,
            coeffs_for_rho=(C_per_k if density_from_c_per_k else None),
        )
        E_kin_f = _lattice_contract(D_real, T_lat, operator_name="T")
        E_ne_f = _lattice_contract(D_real, V_lat, operator_name="V_ne")
        # Rebinding E_2e / e_xc / fock_build keeps the reported components
        # on the same density as the reported total energy.
        E_2e = 0.5 * _lattice_contract(
            D_real,
            fock_build.f2e_real,
            operator_name="F2e",
        )
        E_elec = E_kin_f + E_ne_f + E_2e + e_xc
        E_total = float(E_elec) + e_nuc
        # Fresh E_total doesn't include spheropole — add it.
        E_sphero_final = compute_ext_el_spheropole(D_real, basis, system, lat_opts)
        E_total += E_sphero_final
    else:
        E_sphero_final = energy_components[-1].e_ext_el_spheropole

    return PBCBipoleRKSResult(
        energy=float(E_total),
        e_electronic=float(E_elec),
        e_nuclear=e_nuc,
        e_ext_el_spheropole=E_sphero_final,
        e_xc=float(e_xc),
        e_coulomb=float(E_2e),
        e_hf_exchange=float(fock_build.e_exchange or 0.0),
        n_iter=iter_idx,
        converged=converged,
        mo_energies=eps_per_k,
        mo_coeffs=C_per_k,
        fock=F_k_list,
        overlap=S_k_list,
        hcore=Hcore_k_list,
        density=D_real,
        scf_trace=scf_trace,
        ewald_alpha_bohr_inv=omega_used,
        energy_components=energy_components,
        functional=str(opts.functional),
    )