"""BIPOLE-style periodic RHF driver in CRYSTAL's electrostatic gauge.
CRYSTAL's 3D periodic HF energy uses one shared Ewald state for the
point-charge tail terms and a separate screened real-space machinery for
the AO two-electron build. This driver mirrors that composition:
* ``V_ne`` and ``E_nn`` use ``EWALD_3D`` with one explicit
``EwaldOptions`` object, matching CRYSTAL's ``COMMON/VRSMAD`` pattern.
The default 3D ``V_ne`` path evaluates the smooth reciprocal piece
analytically with shifted AO-pair Fourier transforms.
* The optional ``use_ewald_j_split`` path builds
``J = J_SR(ω) + J_LR(ω)`` with the same α used by ``V_ne`` / ``E_nn``.
``J_LR`` is represented as real-space blocks for CRYSTAL-style
``TOTENY`` energy contractions, and includes the electron-electron
neutralising-background Fock potential ``-π N_e /(α² V) · S(g)``.
* Exchange remains the full direct-space ``K`` from
``build_fock_2e_real_space``; no Madelung K shift is applied.
* Energies are always evaluated as real-space lattice contractions,
``Σ_g tr[D(g)H(g)] + ½Σ_g tr[D(g)F²e(g)]``, not from a Γ-folded
operator.
V_ne gauge placement
--------------------
CRYSTAL and vibe-qc use the same four-component Ewald decomposition
(real-space erfc, reciprocal-space K≠0 sum, self-energy, jellium
background), but place the G=0 correction differently:
* **CRYSTAL**: the jellium background ``−π Q_n²/(2 β² V)`` is added
to ``PAR(18)`` (nuclear repulsion ``E_nn``). ``V_ne`` includes only
the K≠0 reciprocal sum; the G=0 term is handled implicitly through
the total-energy cancellation.
* **vibe-qc**: the V_ne operator receives an explicit background
``+π Q_n/(α² V) · S(g)``, and E_nn receives the standard
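``−π Q_n²/(2 α² V)`` jellium term. For a neutral cell these two
terms, together with the electron-electron background
``−π N_e/(α² V) · S(g)`` that the Ewald-J path adds to ``J_LR``,
cancel exactly in E_total. Per-component diagnostics (E_ne, E_nuc)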
therefore differ from CRYSTAL's ENECYCLE output by the background
magnitude (~16 Ha for MgO/STO-3G), but the total energy is
invariant.
This is still an algorithmic re-implementation, not a CRYSTAL wrapper,
and no external QC program is imported at runtime. The remaining parity
gap to CRYSTAL's native BIPOLE code is the full Saunders-Dovesi-Roetti
multipole-far-pair branch: CRYSTAL replaces far direct ERIs with
truncated multipole expansions and prints the corresponding EXT
EL-POLE / EXT EL-SPHEROPOLE decomposition. The Ewald-J path here
captures the same gauge and reproduces CYC0 tightly, but dense-k final
parity should still be checked against the CRYSTAL diagnostics before
promoting this driver beyond research-preview status.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
from ._vibeqc_core import (
BasisSet,
BlochKMesh,
CoulombMethod,
EwaldOptions,
GridOptions,
InitialGuess,
LatticeMatrixSet,
LatticeSumOptions,
PeriodicSystem,
SCFIteration,
bloch_sum,
build_fock_2e_real_space,
build_jk_2e_real_space,
compute_kinetic_lattice,
compute_nuclear_erfc_lattice,
compute_overlap_lattice,
direct_lattice_cells,
ewald_nuclear_repulsion,
nuclear_repulsion_per_cell,
real_space_density_from_kpoints,
)
from ._vibeqc_core import (
monkhorst_pack as _native_monkhorst_pack,
)
from .bipole_ext_el_pole import compute_ext_el_spheropole
from .guess import initial_density_closed_shell
from .level_shift_schedule import LevelShiftSchedule
from .mom import select_occupied_by_max_overlap as _mom_select
from .oda import compute_oda_lambda as _compute_oda_lambda
from .oda import oda_mix_densities as _oda_mix
from .periodic_rhf_multi_k_ewald import (
_canonical_orthogonalizer_complex,
_damp_lattice_matrix,
_diag_in_orth_basis,
)
from .periodic_scf_accelerators import (
DynamicDamping,
MultiKPeriodicSCFAccelerator,
)
from .periodic_v_ne import compute_nuclear_lattice_dispatch
from .progress import ProgressLogger, resolve_progress
from .scf_divergence import check_scf_divergence
from .smearing._support import reject_unsupported_smearing_temperature
from .symmetry_integrals_reduced import (
compute_kinetic_lattice_reduced,
compute_overlap_lattice_reduced,
)
# Public API of this module for ``from ... import *``; the
# underscore-prefixed helpers defined below are module-private.
__all__ = [
    "PBCBipoleEnergyComponents",
    "PBCBipoleRHFResult",
    "run_pbc_bipole_rhf",
]
@dataclass
class PBCBipoleEnergyComponents:
    """Energy decomposition recorded for a single SCF iteration.

    Field names follow CRYSTAL's ``ENECYCLE`` energy terms. The seven
    leading fields are mandatory. Every optional sub-term defaults to
    ``None``, so callers can tell "not evaluated" apart from a genuine
    zero.
    """

    # --- always populated ------------------------------------------------
    iter: int  # SCF iteration index (name kept for API compatibility)
    e_total: float
    e_electronic: float
    e_kinetic: float
    e_nuclear_attraction: float
    e_two_electron: float
    e_nuclear_repulsion: float

    # --- optional CRYSTAL-style diagnostics (None when not computed) -----
    e_bielet_zone_ee: Optional[float] = None
    e_ext_el_pole: Optional[float] = None
    e_ext_el_spheropole: Optional[float] = None

    # --- optional two-electron sub-terms (Ewald-J split / multipole) -----
    e_j_short_range: Optional[float] = None
    e_j_long_range: Optional[float] = None
    e_exchange: Optional[float] = None
    e_j_multipole: Optional[float] = None
@dataclass
class PBCBipoleRHFResult:
    """Result of :func:`run_pbc_bipole_rhf`.

    Holds the per-cell ``energy`` / ``e_electronic`` / ``e_nuclear``
    scalars, the per-k matrices (``mo_energies``, ``mo_coeffs``, ``fock``,
    ``overlap``, ``hcore``) and the converged real-space ``density``.
    For 3D BIPOLE runs, ``ewald_alpha_bohr_inv`` records the single Ewald
    alpha shared by V_ne / E_nn / optional J_LR.
    """

    energy: float
    e_electronic: float
    e_nuclear: float
    n_iter: int
    converged: bool
    mo_energies: List[np.ndarray]
    mo_coeffs: List[np.ndarray]
    fock: List[np.ndarray]
    overlap: List[np.ndarray]
    hcore: List[np.ndarray]
    density: LatticeMatrixSet
    # Defaulted fields must follow every required field: ``dataclass``
    # raises TypeError for a non-default field after a default one (this
    # has been enforced since dataclasses were introduced, not just 3.14).
    e_ext_el_spheropole: Optional[float] = None
    # default_factory gives every result its own fresh list.
    scf_trace: List[SCFIteration] = field(default_factory=list)
    ewald_alpha_bohr_inv: Optional[float] = None
    # Per-iteration ENECYCLE-style energy breakdowns.
    energy_components: List[PBCBipoleEnergyComponents] = field(
        default_factory=list,
    )
@dataclass
class _PBCBipoleFockBuild:
"""Internal Fock-build bundle for one density in the BIPOLE driver."""
f2e_real: LatticeMatrixSet
f_k_list: List[np.ndarray]
e_j_short_range: Optional[float] = None
e_j_long_range: Optional[float] = None
e_exchange: Optional[float] = None
e_j_multipole: Optional[float] = None
def _bloch_sum_blocks(
blocks: Sequence[np.ndarray],
cells,
k_cart: np.ndarray,
) -> np.ndarray:
"""F(k) = Σ_g exp(+i k·R_g) F(g). Real → complex result."""
k = np.asarray(k_cart, dtype=float).reshape(3)
F_k = np.zeros_like(blocks[0], dtype=complex)
for g_idx, block in enumerate(blocks):
R_g = np.asarray(cells[g_idx].r_cart, dtype=float)
phase = np.exp(1j * float(np.dot(k, R_g)))
F_k = F_k + phase * np.asarray(block, dtype=float)
return F_k
def _cell_key(cell) -> Tuple[int, int, int]:
return tuple(int(x) for x in np.asarray(cell.index, dtype=int).reshape(3))
def _lattice_contract(
    density: LatticeMatrixSet,
    operator: LatticeMatrixSet,
    *,
    operator_name: str,
) -> float:
    """Real-space periodic trace of ``density`` against a lattice operator.

    Convenience wrapper: unpacks ``operator.cells`` / ``operator.blocks``
    and delegates to :func:`_lattice_contract_blocks`.
    ``operator_name`` only labels error messages.
    """
    op_cells = operator.cells
    op_blocks = operator.blocks
    return _lattice_contract_blocks(
        density, op_cells, op_blocks, operator_name=operator_name
    )
def _lattice_contract_blocks(
    density: LatticeMatrixSet,
    operator_cells,
    operator_blocks,
    *,
    operator_name: str,
) -> float:
    """Contract a real-space density with real-space operator blocks.

    Evaluates ``Σ_g Σ_μν D_μν(g) M_μν(g)``. Density and operator blocks
    are matched by integer cell index, not by list position.

    CRYSTAL's BIPOLE energy path contracts the current real-space
    density against real-space operator blocks, not against a Γ-folded
    Bloch sum. The elementwise form mirrors ``cpp/src/periodic_scf.cpp``
    and avoids accidentally adding cross-cell one-electron terms when
    the initial SAD density lives only at ``g = 0``.

    Operator cells that do not appear in the density are ignored. A
    density cell with no operator block, or a block-shape mismatch,
    raises ``ValueError`` labelled with ``operator_name``.
    """
    by_cell = {}
    for op_cell, op_block in zip(operator_cells, operator_blocks):
        op_key = _cell_key(op_cell)
        by_cell[op_key] = np.asarray(op_block)

    running = 0.0
    for d_cell, d_raw in zip(density.cells, density.blocks):
        key = _cell_key(d_cell)
        if key not in by_cell:
            raise ValueError(
                f"_lattice_contract: {operator_name} is missing cell {key}"
            )
        d_mat = np.asarray(d_raw)
        m_mat = by_cell[key]
        if d_mat.shape != m_mat.shape:
            raise ValueError(
                f"_lattice_contract: shape mismatch for {operator_name} "
                f"at cell {key}: D{d_mat.shape} vs M{m_mat.shape}"
            )
        running += float(np.real(np.sum(d_mat * m_mat)))
    return running
def _crystal_ewald_options(
    lat_opts: LatticeSumOptions,
    *,
    alpha_bohr_inv: Optional[float],
    tolerance: float,
    recip_cutoff_bohr_inv: Optional[float] = None,
) -> EwaldOptions:
    """Create the one Ewald state shared by BIPOLE V_ne / E_nn / J_LR.

    CRYSTAL keeps this state in ``COMMON/VRSMAD``. Sharing it matters
    mostly because finite cutoffs and quadrature depend on alpha, even
    though the complete Ewald sum should not. This helper makes the
    shared state explicit on the Python side.

    ``alpha_bohr_inv`` is written only when not ``None``; otherwise the
    ``EwaldOptions`` default is kept. ``recip_cutoff_bohr_inv`` is written
    only when it is positive. In that case the C++ nuclear Ewald uses this
    K_max instead of deriving one from alpha and tolerance. This keeps
    the nuclear and electronic reciprocal envelopes identical, which
    finite-cutoff G=0 cancellation relies on.
    """
    ewald = EwaldOptions()
    # The real-space Ewald sum reuses the nuclear lattice cutoff.
    ewald.real_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    ewald.tolerance = tolerance
    if alpha_bohr_inv is not None:
        ewald.alpha = float(alpha_bohr_inv)
    fixed_k_max = recip_cutoff_bohr_inv is not None and recip_cutoff_bohr_inv > 0.0
    if fixed_k_max:
        ewald.recip_cutoff_bohr_inv = float(recip_cutoff_bohr_inv)
    return ewald
def _expand_ibz_kmesh_for_ewald_j(
system: PeriodicSystem,
kmesh: BlochKMesh,
plog: ProgressLogger,
) -> BlochKMesh:
"""Expand an IBZ-reduced MP mesh to the full mesh for Ewald-J.
The current Ewald-J long-range density transform needs the full
Bloch-summed AO-pair FT over the whole Monkhorst-Pack mesh. A
symmetry-reduced ``BlochKMesh`` is fine as user input as long as it
carries ``ir_mapping`` metadata; we reconstruct the corresponding
full mesh here so the rest of the SCF loop sees uniform weights.
"""
ir_mapping = np.asarray(
getattr(kmesh, "ir_mapping", []),
dtype=int,
).reshape(-1)
if ir_mapping.size == 0:
return kmesh
mesh = tuple(int(x) for x in getattr(kmesh, "mesh", (1, 1, 1)))
shift = tuple(int(x) for x in getattr(kmesh, "is_shift", (0, 0, 0)))
full_n = int(np.prod(mesh))
current_n = len(list(kmesh.kpoints))
if full_n <= current_n:
return kmesh
expanded = _native_monkhorst_pack(
system,
list(mesh),
list(shift),
False,
)
plog.info(
" k-mesh symmetry expansion: "
f"{current_n} IBZ point{'s' if current_n != 1 else ''} "
f"-> {len(list(expanded.kpoints))} full MP points for Ewald-J "
f"(mesh={mesh}, shift={shift})"
)
return expanded
def _default_bipole_v_ne_grid_options() -> GridOptions:
    """Quadrature grid for the legacy grid-based long-range Ewald V_ne.

    By default the BIPOLE driver evaluates the one-electron Ewald
    potential analytically with AO-pair Fourier transforms, as CRYSTAL
    does. This grid serves the older quadrature path when a caller passes
    ``v_ne_grid_options`` explicitly. It is deliberately tighter than the
    generic XC default (99 radial points, unpruned Lebedev order 41,
    Becke partitioning). That keeps off-diagonal Hcore elements stable
    for CRYSTAL cycle-parity diagnostics.
    """
    settings = (
        ("n_radial", 99),
        ("angular", "lebedev"),
        ("lebedev_order", 41),
        ("angular_pruning", "none"),
        ("partition", "becke"),
    )
    grid = GridOptions()
    for attr_name, value in settings:
        setattr(grid, attr_name, value)
    return grid
def _compute_nuclear_lattice_ewald_reciprocal_ft(
basis: BasisSet,
system: PeriodicSystem,
lat_opts: LatticeSumOptions,
ewald_options: EwaldOptions,
S_lat: LatticeMatrixSet,
*,
cache=None,
precision: float = 1e-8,
K_max: Optional[float] = None,
):
"""Analytic 3D Ewald V_ne blocks via AO-pair Fourier transforms.
The generic ``compute_nuclear_lattice_ewald`` path evaluates the
smooth long-range potential on a molecular quadrature grid. That is
fine for total energies, but BIPOLE CRYSTAL parity is sensitive to
off-diagonal Hcore elements after the first Fock diagonalisation.
Here we evaluate the reciprocal-space part analytically:
``V_lr(g) = -Σ_G kernel(G) ρ_nuc(G) FT_g(G)^* + π Q_n/(α²V) S(g)``.
The short-range erfc piece remains libint-analytic.
Parameters
----------
K_max : float, optional
Explicit reciprocal-space cutoff in bohr⁻¹ for the V_ne
reciprocal-sum cache. When provided together with the shared
Ewald α, guarantees that V_ne and J_LR use the same envelope.
"""
if system.dim != 3:
raise ValueError("analytic Ewald V_ne requires dim=3")
alpha = float(ewald_options.alpha)
if alpha <= 0.0:
raise ValueError("analytic Ewald V_ne requires explicit alpha")
V_short = compute_nuclear_erfc_lattice(
basis,
system,
alpha,
lat_opts,
)
if len(V_short.cells) != len(S_lat.cells):
raise RuntimeError("analytic Ewald V_ne: V_short and S cell lists differ")
if cache is None:
from .bipole_fock_ewald import _build_j_long_range_cache
cells_r_cart_arr = np.array(
[np.asarray(c.r_cart, dtype=float) for c in S_lat.cells],
dtype=float,
)
cache = _build_j_long_range_cache(
basis,
system,
cells_r_cart_arr,
alpha,
precision,
K_max=K_max,
)
atom_pos = np.array(
[[float(x) for x in atom.xyz] for atom in system.unit_cell],
dtype=float,
)
atom_z = np.array(
[float(atom.Z) for atom in system.unit_cell],
dtype=float,
)
phases = np.exp(-1j * (atom_pos @ cache.K_vectors.T))
rho_nuc = atom_z @ phases
weighted = cache.kernel * rho_nuc
a = np.asarray(system.lattice, dtype=float)
V_cell = float(abs(np.linalg.det(a)))
q_nuc = float(atom_z.sum())
background = np.pi * q_nuc / (alpha * alpha * V_cell)
for c, cell in enumerate(V_short.cells):
if _cell_key(cell) != _cell_key(S_lat.cells[c]):
raise RuntimeError(
"analytic Ewald V_ne: cell ordering differs between V_short and S"
)
v_lr = -np.einsum(
"k,mnk->mn",
weighted,
cache.ft_per_cell[c].conj(),
)
block = (
np.asarray(V_short.blocks[c], dtype=float)
+ np.real(v_lr)
+ background * np.asarray(S_lat.blocks[c], dtype=float)
)
V_short.set_block(c, block)
return V_short, cache
[docs]
def run_pbc_bipole_rhf(
system: PeriodicSystem,
basis: BasisSet,
kmesh: BlochKMesh,
options=None,
*,
linear_dep_threshold: float = 1e-7,
canonical_orth_normalize_diag_first: bool = True,
level_shift_schedule: Optional["LevelShiftSchedule"] = None,
use_mom: bool = False,
use_oda: bool = False,
oda_trust_lambda_max: float = 1.0,
use_ewald_j_split: Optional[bool] = None,
ewald_omega: Optional[float] = None,
ewald_precision: float = 1e-8,
v_ne_grid_options: Optional[GridOptions] = None,
use_multipole_diag: bool = False,
use_multipole_far_field: bool = False,
multipole_l_max: int = 2,
progress: Union[bool, ProgressLogger, None] = None,
verbose: Optional[int] = None,
initial_density: Optional[Sequence[np.ndarray]] = None,
) -> PBCBipoleRHFResult:
"""Multi-k closed-shell RHF via the CRYSTAL-gauge BIPOLE scaffold.
Algorithm (matches CRYSTAL BIELET):
1. Real-space one-electron integrals S(g), T(g), V_ne(g) at
``opts.lattice_opts.cutoff_bohr``. For 3D systems V_ne uses
the same Ewald α as E_nn.
2. Bloch-sum to S(k), Hcore(k) per k-point; canonical-orth X(k).
3. Initial guess via ``opts.initial_guess`` (default SAD).
4. SCF iter:
a. Build F^{2e}(g). With ``use_ewald_j_split=True`` this is
``J_SR(g;ω) + J_LR(g;ω) + V_bg·S(g) - ½K_full(g)``.
With the flag off, use the legacy direct-only
``build_fock_2e_real_space`` scaffold.
b. Bloch-sum F^{2e}(g) → F(k); add Hcore(k).
c. Energy: E_elec = Σ_g tr[D(g)Hcore(g)]
+ ½Σ_g tr[D(g)F²e(g)] in real-space block form
(CRYSTAL/TOTENY convention).
d. Optional DIIS extrapolation of F(k) via [F,DS] errors.
e. Optional LEVSHIFT shift on F(k).
f. Diagonalise F(k) → C(k), ε(k).
g. Optional MOM reorder of occupied subspace.
h. Rebuild D_real via real_space_density_from_kpoints.
i. Optional ODA mixing on density.
5. E_total = E_elec + E_nuc.
``use_ewald_j_split`` defaults to ``None``. In that mode the
driver automatically uses the CRYSTAL-gauge Ewald-J split for 3D
systems and keeps the old direct-only path for dim < 3 diagnostic
runs. Pass ``False`` explicitly only when you want the legacy
direct-only F²e scaffold for debugging.
For 3D systems the default ``V_ne`` implementation is analytic:
erfc-screened nuclear attraction from libint plus a reciprocal-space
AO-pair Fourier-transform sum. Passing ``v_ne_grid_options`` opts
into the older grid-quadrature long-range ``V_ne`` path for
diagnostics.
"""
from ._vibeqc_core import PeriodicRHFOptions
opts = options if options is not None else PeriodicRHFOptions()
reject_unsupported_smearing_temperature(
opts,
"run_pbc_bipole_rhf",
detail=(
"BIPOLE smearing is queued for a later smearing milestone; "
"this driver would otherwise run integer Aufbau occupations."
),
)
lat_opts: LatticeSumOptions = opts.lattice_opts
plog = resolve_progress(progress, verbose=verbose)
use_ewald_j_split_auto = use_ewald_j_split is None
use_ewald_j_split = (
system.dim == 3 if use_ewald_j_split_auto else bool(use_ewald_j_split)
)
# CRYSTAL-style gauge separation (per the EWALD_3D / BIPOLE audit):
# V_ne and E_nn use Ewald with one shared alpha. F^{2e} uses the
# direct lattice cell list for J_SR/K; the optional J_LR reciprocal
# sum consumes the same alpha as the one-electron Ewald state.
lat_opts_2e = LatticeSumOptions()
lat_opts_2e.cutoff_bohr = lat_opts.cutoff_bohr
lat_opts_2e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
lat_opts_2e.coulomb_method = CoulombMethod.DIRECT_TRUNCATED
lat_opts_1e = LatticeSumOptions()
lat_opts_1e.cutoff_bohr = lat_opts.cutoff_bohr
lat_opts_1e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
if system.dim == 3:
lat_opts_1e.coulomb_method = CoulombMethod.EWALD_3D
else:
# 1D / 2D: no Ewald_3D path; fall back to direct truncation.
lat_opts_1e.coulomb_method = CoulombMethod.DIRECT_TRUNCATED
plog.info(f"PBC BIPOLE (CRYSTAL-gauge) / cutoff {lat_opts.cutoff_bohr:.2f} bohr")
plog.info(
f" V_ne + E_nn : {lat_opts_1e.coulomb_method.name}"
f" (Ewald gauge for point-charge tails)"
)
plog.info(
f" F^2e (J + K) : "
f"{'EWALD_J_SPLIT' if use_ewald_j_split else lat_opts_2e.coulomb_method.name}"
f"{' (auto)' if use_ewald_j_split_auto else ''}"
f" (direct J_SR/K cell list"
f"{' + reciprocal J_LR' if use_ewald_j_split else ''})"
)
plog.info(f"basis: {basis.name} ({basis.nbasis} BFs / {basis.nshells} shells)")
# Closed-shell sanity.
n_elec = system.n_electrons()
if n_elec % 2 != 0:
raise ValueError(
f"run_pbc_bipole_rhf: closed-shell RHF requires even electron "
f"count; got {n_elec}"
)
if system.multiplicity != 1:
raise ValueError(
f"run_pbc_bipole_rhf: requires multiplicity=1; got {system.multiplicity}"
)
n_occ = n_elec // 2
_kmesh_ibz = kmesh
_ir_mapping = np.asarray(getattr(kmesh, "ir_mapping", []), dtype=int).reshape(-1)
k_points = list(_kmesh_ibz.kpoints)
weights = np.asarray(_kmesh_ibz.weights, dtype=float)
if use_ewald_j_split and _ir_mapping.size > 0:
kmesh_full = _expand_ibz_kmesh_for_ewald_j(system, kmesh, plog)
k_points_full = list(kmesh_full.kpoints)
weights_full = np.asarray(kmesh_full.weights, dtype=float)
else:
k_points_full = k_points
weights_full = weights
n_k = len(k_points)
if n_k == 0:
raise ValueError("kmesh has no k-points")
if not np.isclose(weights.sum(), 1.0):
raise ValueError(f"kmesh.weights must sum to 1; got {weights.sum():.6f}")
plog.info(
f"k-mesh: {n_k} k-point{'s' if n_k != 1 else ''}, "
f"weights sum = {weights.sum():.4f}"
)
# CRYSTAL-style shared Ewald state for all point-charge-tail terms.
# V_ne, E_nn, and the optional reciprocal J^LR build must consume the
# same alpha AND the same K_max — because finite-cutoff G=0
# cancellation requires matched reciprocal envelopes.
ewald_options_1e: Optional[EwaldOptions] = None
omega_used: Optional[float] = None
ewald_cell_volume: Optional[float] = None
ewald_k_max: Optional[float] = None
if system.dim == 3:
from .bipole_ext_el_pole import (
crystal_default_ewald_alpha,
crystal_ewald_reciprocal_cutoff,
)
V_cell = float(
abs(
np.linalg.det(np.asarray(system.lattice, dtype=float)),
)
)
ewald_cell_volume = V_cell
omega_used = (
float(ewald_omega)
if ewald_omega is not None
else crystal_default_ewald_alpha(V_cell)
)
ewald_k_max = crystal_ewald_reciprocal_cutoff(V_cell)
ewald_options_1e = _crystal_ewald_options(
lat_opts_1e,
alpha_bohr_inv=omega_used,
tolerance=float(ewald_precision),
recip_cutoff_bohr_inv=ewald_k_max,
)
plog.info(
f" Ewald state: α = {omega_used:.6f} bohr⁻¹, "
f"cutoff_real = {lat_opts_1e.nuclear_cutoff_bohr:.2f} bohr, "
f"K_max = {ewald_k_max:.2f} bohr⁻¹, "
f"tol = {float(ewald_precision):.0e}"
)
# ---- Real-space one-electron integrals -------------------------------
# S, T use cell-list-only cutoff (lat_opts_2e — they're independent
# of coulomb_method). V_ne uses lat_opts_1e so the EWALD_3D path is
# taken on 3D systems (CRYSTAL-equivalent gauge).
with plog.stage(
"integrals_lattice",
detail=f"S/T/V at cutoff {lat_opts.cutoff_bohr:.2f} bohr",
):
_use_sym = False # FIXME: LatticeMatrixSet has no Python constructor
if _use_sym:
ops = system.symmetry.operations
plog.info(
f"S/T integrals: symmetry-reduced path "
f"(SG {system.symmetry.international_symbol}, "
f"{system.symmetry.order} ops)"
)
_, S_blocks = compute_overlap_lattice_reduced(
basis,
system,
lat_opts_2e,
ops,
)
S_lat = LatticeMatrixSet()
S_lat.nbf = basis.nbasis
S_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
S_lat.blocks = S_blocks
_, T_blocks = compute_kinetic_lattice_reduced(
basis,
system,
lat_opts_2e,
ops,
)
T_lat = LatticeMatrixSet()
T_lat.nbf = basis.nbasis
T_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
T_lat.blocks = T_blocks
else:
S_lat = compute_overlap_lattice(basis, system, lat_opts_2e)
T_lat = compute_kinetic_lattice(basis, system, lat_opts_2e)
v_ne_lr_cache = None
if (
system.dim == 3
and ewald_options_1e is not None
and v_ne_grid_options is None
):
plog.info(
" V_ne Ewald long range: analytic AO-pair FT (shared with J^LR cache)"
)
V_lat, v_ne_lr_cache = _compute_nuclear_lattice_ewald_reciprocal_ft(
basis,
system,
lat_opts_1e,
ewald_options_1e,
S_lat,
precision=ewald_precision,
K_max=ewald_k_max,
)
else:
v_ne_grid = (
v_ne_grid_options
if v_ne_grid_options is not None
else (_default_bipole_v_ne_grid_options() if system.dim == 3 else None)
)
V_lat = compute_nuclear_lattice_dispatch(
basis,
system,
lat_opts_1e,
grid_options=v_ne_grid,
ewald_options=ewald_options_1e,
)
cells = list(S_lat.cells)
plog.info(f"n_cells in lattice sum = {len(cells)}")
# Per-k S(k), Hcore(k), orthogonaliser X(k).
from .linear_dependence import (
check_overlap_matrix,
format_linear_dependence_report,
raise_if_severe,
scf_preflight_overlap_check,
)
S_k_list: List[np.ndarray] = []
T_k_list: List[np.ndarray] = []
V_ne_k_list: List[np.ndarray] = []
Hcore_k_list: List[np.ndarray] = []
X_k_list: List[np.ndarray] = []
overlap_reports = []
for k_idx, k in enumerate(k_points):
k_arr = np.asarray(k, dtype=float).reshape(3)
S_k = np.asarray(bloch_sum(S_lat, k_arr))
T_k = np.asarray(bloch_sum(T_lat, k_arr))
V_k = np.asarray(bloch_sum(V_lat, k_arr))
T_k = 0.5 * (T_k + T_k.conj().T)
V_k = 0.5 * (V_k + V_k.conj().T)
H_k = T_k + V_k
S_k = 0.5 * (S_k + S_k.conj().T)
H_k = 0.5 * (H_k + H_k.conj().T)
overlap_label = f"S(k={k_idx}, k_cart={k_arr.round(4).tolist()})"
if n_k <= 16:
report = scf_preflight_overlap_check(
S_k,
plog=plog,
label=overlap_label,
basis=basis,
)
else:
report = check_overlap_matrix(
S_k,
basis=basis,
label=overlap_label,
)
if report.severity != "ok":
prefix = {
"warn": "WARN",
"error": "ERROR",
"critical": "CRITICAL",
}[report.severity]
cond_str = (
f"{report.condition_number:.2e}"
if np.isfinite(report.condition_number)
else "+inf"
)
plog.info(
f"[{prefix}] overlap [{overlap_label}]: "
f"nbf={report.n_basis}, "
f"min eig={report.min_eigenvalue:+.2e}, "
f"cond={cond_str}, severity={report.severity}"
)
plog.write_raw(format_linear_dependence_report(report))
raise_if_severe(report)
X_k, n_kept = _canonical_orthogonalizer_complex(
S_k,
linear_dep_threshold,
normalize_diag_first=canonical_orth_normalize_diag_first,
)
overlap_reports.append(report)
if n_occ > n_kept:
raise RuntimeError(
f"run_pbc_bipole_rhf: canonical orth at k={k_idx} "
f"dropped too many directions (n_occ={n_occ}, n_kept={n_kept})"
)
S_k_list.append(S_k)
T_k_list.append(T_k)
V_ne_k_list.append(V_k)
Hcore_k_list.append(H_k)
X_k_list.append(X_k)
if n_k > 16:
severity_rank = {"ok": 0, "warn": 1, "error": 2, "critical": 3}
worst = max(
overlap_reports,
key=lambda r: severity_rank.get(r.severity, -1),
)
min_s = min(float(r.min_eigenvalue) for r in overlap_reports)
max_cond = max(float(r.condition_number) for r in overlap_reports)
cond_str = f"{max_cond:.2e}" if np.isfinite(max_cond) else "+inf"
plog.info(
f"overlap [k-mesh summary]: n_k={n_k}, nbf={basis.nbasis}, "
f"min eig={min_s:+.2e}, max cond={cond_str}, "
f"severity={worst.severity}"
)
# ---- Nuclear repulsion per cell --------------------------------------
if ewald_options_1e is not None:
e_nuc = float(ewald_nuclear_repulsion(system, ewald_options_1e))
else:
e_nuc = float(nuclear_repulsion_per_cell(system, lat_opts_1e))
plog.info(f"E_nuc per cell ({lat_opts_1e.coulomb_method.name}) = {e_nuc:+.10f} Ha")
# ---- Initial guess ---------------------------------------------------
C_per_k: List[np.ndarray] = []
eps_per_k: List[np.ndarray] = []
for H_k, X_k in zip(Hcore_k_list, X_k_list):
C_k, eps_k = _diag_in_orth_basis(H_k, X_k)
C_per_k.append(C_k.astype(complex))
eps_per_k.append(eps_k)
n_occ_per_k = [n_occ] * n_k
D_real = real_space_density_from_kpoints(
C_per_k,
n_occ_per_k,
kmesh,
cells,
)
# Caller-supplied warm-start density takes precedence over both the
# SAD/Hcore guess engine and the Hcore-diag fallback. The caller is
# responsible for matching ``initial_density`` blocks against the
# canonical ``direct_lattice_cells(kmesh)`` ordering (which is what
# the SCF's ``D_real`` uses). Used by the NEB driver for within-
# image density warm-start across outer iterations + within FD-
# gradient displaced SCFs (HANDOVER_PERIODIC_NEB.md M4 periodic
# follow-up).
if initial_density is not None:
blocks_in = list(initial_density)
if len(blocks_in) != len(D_real.cells):
raise ValueError(
f"run_pbc_bipole_rhf: initial_density has {len(blocks_in)} "
f"blocks; expected {len(D_real.cells)} (one per cell in "
f"direct_lattice_cells(kmesh))"
)
for g_idx, block in enumerate(blocks_in):
D_real.set_block(g_idx, np.asarray(block, dtype=float))
plog.info("initial guess: caller-supplied density (warm-start)")
initial_density_is_local = True
density_from_c_per_k = False
else:
# SAD override (place SAD density at g=0; zeros elsewhere).
guess = getattr(opts, "initial_guess", InitialGuess.HCORE)
D_engine = initial_density_closed_shell(
system.unit_cell_molecule(),
basis,
n_occ,
guess,
is_periodic=True,
)
if D_engine is not None:
plog.info(f"initial guess: {guess.name} (g=0 density from GuessEngine)")
for g_idx in range(len(D_real.cells)):
if (D_real.cells[g_idx].index == np.array([0, 0, 0])).all():
D_real.set_block(g_idx, D_engine)
else:
D_real.set_block(g_idx, np.zeros_like(D_engine, dtype=float))
else:
plog.info(f"initial guess: {guess.name} (Hcore-diag per k)")
initial_density_is_local = D_engine is not None
density_from_c_per_k = not initial_density_is_local
D_real_prev: Optional[LatticeMatrixSet] = None
# ---- SCF aids: damping, accelerator family, LEVSHIFT, MOM, ODA ------
damping = float(opts.damping)
if not (0.0 <= damping < 1.0):
raise ValueError(f"run_pbc_bipole_rhf: damping must be in [0,1); got {damping}")
damper: Optional[DynamicDamping] = None
if bool(getattr(opts, "dynamic_damping", False)):
damper = DynamicDamping(
initial_alpha=damping,
alpha_min=float(getattr(opts, "dynamic_damping_min", 0.0)),
alpha_max=float(getattr(opts, "dynamic_damping_max", 0.95)),
)
use_diis = bool(opts.use_diis)
diis_start_iter = int(opts.diis_start_iter)
accel: Optional[MultiKPeriodicSCFAccelerator] = (
MultiKPeriodicSCFAccelerator(opts) if use_diis else None
)
level_shift_static = float(getattr(opts, "level_shift", 0.0))
if level_shift_schedule is not None and not isinstance(
level_shift_schedule,
LevelShiftSchedule,
):
raise TypeError(
f"level_shift_schedule must be a LevelShiftSchedule or None; "
f"got {type(level_shift_schedule).__name__}"
)
if level_shift_schedule is not None:
plog.info(f"level_shift_schedule: {level_shift_schedule.as_list()}")
if use_mom:
plog.info("MOM (Maximum Overlap Method): ON")
C_prev_occ_per_k: Optional[List[np.ndarray]] = None
if use_oda and use_diis:
raise ValueError(
"run_pbc_bipole_rhf: use_oda and use_diis are mutually exclusive"
)
if use_oda:
if not (0.0 < oda_trust_lambda_max <= 1.0):
raise ValueError(
f"oda_trust_lambda_max must be in (0, 1]; got {oda_trust_lambda_max}"
)
plog.info(
f"ODA (Optimal Damping): ON (+1 Fock build/iter, "
f"trust λ_max = {oda_trust_lambda_max})"
)
# ---- Optional: Ewald J-split F^2e build (Phase 5 of BIPOLE branch) ---
j_lr_cache = v_ne_lr_cache
if use_ewald_j_split:
# CRYSTAL-equivalent gauge: V_ne + E_nn use Ewald, F^2e uses
# J^SR(direct erfc-screened) + J^LR(analytic reciprocal-sum) − ½K.
# Single shared α between V_ne, E_nn, and J_LR (CRYSTAL's
# COMMON/VRSMAD/ pattern).
#
# Multi-k J^LR uses Bloch-summed shifted-ν AO-pair FTs and a
# k-space ρ̂(K). The operator is materialised as real-space
# blocks below so both diagonalisation and TOTENY-style energy
# accounting see the same long-range J.
if system.dim != 3:
raise ValueError(
f"use_ewald_j_split requires dim=3 (3D periodic). Got dim={system.dim}."
)
if n_k > 1 and _ir_mapping.size == 0:
# Non-uniform weights without ir_mapping: can't expand.
uniform_w = 1.0 / float(n_k)
if not np.allclose(weights, uniform_w, atol=1e-9):
raise ValueError(
"use_ewald_j_split at multi-k requires uniform full-mesh "
"weights or an IBZ-reduced Monkhorst-Pack mesh carrying "
"ir_mapping metadata so the driver can expand it. "
f"Got non-uniform weights = {weights.tolist()}."
)
from .bipole_fock_ewald import (
_build_j_long_range_cache,
compute_J_long_range_real_space_blocks,
compute_rho_hat_from_k_density,
)
assert omega_used is not None
plog.info(
f"Ewald J-split F^2e: ON (CRYSTAL-equivalent gauge); "
f"ω = {omega_used:.4f} bohr⁻¹, precision = {ewald_precision:.0e}"
)
# Pre-build the shifted-ν FT cache once — invariant across SCF
# iters + k-points within an iter. Γ-only needs the same cache
# for real-space energy blocks even though the Fock can be built
# from the k=0 folded matrix.
cells_r_cart_arr = np.array(
[np.asarray(c.r_cart, dtype=float) for c in cells],
dtype=float,
)
if j_lr_cache is None:
j_lr_cache = _build_j_long_range_cache(
basis,
system,
cells_r_cart_arr,
omega_used,
ewald_precision,
K_max=ewald_k_max,
)
elif j_lr_cache.ft_per_cell.shape[0] != len(cells):
raise RuntimeError(
"prebuilt V_ne/J^LR cache has a different cell count "
f"({j_lr_cache.ft_per_cell.shape[0]}) from S_lat "
f"({len(cells)})"
)
plog.info(
f" J^LR cache: {j_lr_cache.K_vectors.shape[0]} K-vectors, "
f"{j_lr_cache.ft_per_cell.shape[0]} lattice cells"
)
def _density_matrices_from_coeffs(
coeffs_per_k: Sequence[np.ndarray],
) -> List[np.ndarray]:
out: List[np.ndarray] = []
for C_k_raw in coeffs_per_k:
C_k_arr = np.asarray(C_k_raw)
C_occ_k = C_k_arr[:, :n_occ]
out.append(2.0 * (C_occ_k @ C_occ_k.conj().T))
return out
def _build_fock_for_density(
density: LatticeMatrixSet,
*,
coeffs_for_rho: Optional[Sequence[np.ndarray]],
) -> _PBCBipoleFockBuild:
"""Build F²e(g) and F(k) for one real-space density.
``coeffs_for_rho`` is supplied only when the density is exactly
represented by the current per-k orbitals. Damped / ODA-mixed
densities are real-space objects with no exact orbital
representation, so their J^LR density transform must use the
real-space blocks instead of stale C(k) coefficients.
"""
if use_ewald_j_split:
F_J_SR_lat = build_fock_2e_real_space(
basis,
system,
lat_opts_2e,
density,
0.0, # exchange_scale = 0 → pure J
float(omega_used), # omega > 0 → erfc kernel
)
jk_full_lat = build_jk_2e_real_space(
basis,
system,
lat_opts_2e,
density,
0.0,
)
F_K_full_lat = jk_full_lat.K
cells_lr = F_J_SR_lat.cells
F2e_direct_blocks = []
K_full_blocks = []
for c in range(len(cells_lr)):
j_sr = np.asarray(F_J_SR_lat.blocks[c], dtype=float)
k_full = np.asarray(F_K_full_lat.blocks[c], dtype=float)
K_full_blocks.append(k_full)
F2e_direct_blocks.append(j_sr - 0.5 * k_full)
rho_hat_for_LR = None
if coeffs_for_rho is not None:
D_k_for_rho = _density_matrices_from_coeffs(coeffs_for_rho)
# If using IBZ mesh, expand D(k) to full mesh for rho_hat
if _ir_mapping.size > 0:
D_k_full = [D_k_for_rho[int(idx)] for idx in _ir_mapping]
rho_hat_for_LR = compute_rho_hat_from_k_density(
D_k_full,
k_points_full,
weights_full,
j_lr_cache,
)
else:
rho_hat_for_LR = compute_rho_hat_from_k_density(
D_k_for_rho,
k_points,
weights,
j_lr_cache,
)
F_LR_blocks = compute_J_long_range_real_space_blocks(
density,
basis,
system,
omega_used,
precision=ewald_precision,
cache=j_lr_cache,
rho_hat=rho_hat_for_LR,
)
assert ewald_cell_volume is not None
j_background_potential = (
-np.pi
* float(n_elec)
/ (float(omega_used) * float(omega_used) * float(ewald_cell_volume))
)
s_blocks = {
_cell_key(cell): np.asarray(block, dtype=float)
for cell, block in zip(S_lat.cells, S_lat.blocks)
}
for c, cell in enumerate(cells_lr):
key = _cell_key(cell)
F_LR_blocks[c] = F_LR_blocks[c] + j_background_potential * s_blocks[key]
e_j_short_range = 0.5 * _lattice_contract(
density,
F_J_SR_lat,
operator_name="J_SR",
)
e_j_long_range = 0.5 * _lattice_contract_blocks(
density,
cells_lr,
F_LR_blocks,
operator_name="J_LR",
)
e_exchange = -0.25 * _lattice_contract_blocks(
density,
cells_lr,
K_full_blocks,
operator_name="K_full",
)
f2e_real = F_J_SR_lat
for c in range(len(cells_lr)):
f2e_real.set_block(
c,
F2e_direct_blocks[c] + F_LR_blocks[c],
)
# ---- Optional: replace J_SR+J_LR with multipole far-field J --
e_j_multipole: Optional[float] = None
if _mp_config.enabled:
from .bipole_fock_multipole import apply_multipole_far_field
mp_result = apply_multipole_far_field(
density,
basis,
system,
lat_opts_2e,
_mp_config,
J_SR_blocks=[
np.asarray(F_J_SR_lat.blocks[c], dtype=float)
for c in range(len(cells_lr))
],
K_blocks=K_full_blocks,
F_LR_blocks=F_LR_blocks,
exchange_scale=0.5,
)
for c in range(len(cells_lr)):
f2e_real.set_block(c, mp_result.f2e_blocks[c])
e_j_multipole = mp_result.e_j_multipole
if mp_result.n_far_cells > 0:
plog.info(
f" BIPOLE multipole far-field (L_max={_mp_config.L_max}, "
f"R={_mp_config.R_bipole:.1f} bohr): "
f"{mp_result.n_far_cells}/{mp_result.n_total_cells} cells replaced, "
f"E_J_far = {e_j_multipole:+.6f} Ha"
)
else:
f2e_real = build_fock_2e_real_space(
basis,
system,
lat_opts_2e,
density,
1.0,
0.0,
)
e_j_short_range = None
e_j_long_range = None
e_exchange = None
e_j_multipole = None
f_k_list: List[np.ndarray] = []
for k_idx, k in enumerate(k_points):
k_arr = np.asarray(k, dtype=float)
F2e_k = _bloch_sum_blocks(
f2e_real.blocks,
f2e_real.cells,
k_arr,
)
F_k = F2e_k + np.asarray(Hcore_k_list[k_idx], dtype=complex)
F_k = 0.5 * (F_k + F_k.conj().T)
f_k_list.append(F_k)
return _PBCBipoleFockBuild(
f2e_real=f2e_real,
f_k_list=f_k_list,
e_j_short_range=e_j_short_range,
e_j_long_range=e_j_long_range,
e_exchange=e_exchange,
e_j_multipole=e_j_multipole,
)
# ---- Multipole far-field config (resolve once before SCF loop) -----
from .bipole_fock_multipole import ( # noqa: E402
BipoleMultipoleConfig,
resolve_multipole_config,
)
_mp_config = resolve_multipole_config(
system,
basis,
lat_opts_2e,
user_enable=use_multipole_far_field,
multipole_l_max=multipole_l_max,
)
if _mp_config.enabled:
plog.info(
f" BIPOLE multipole far-field: ENABLED "
f"(L_max={_mp_config.L_max}, R_bipole={_mp_config.R_bipole:.1f} bohr, "
f"n_cells={len(_mp_config.cache.cells) if _mp_config.cache else 0})"
)
else:
plog.info(
f" BIPOLE multipole far-field: off "
f"(R_bipole={_mp_config.R_bipole:.1f} bohr, "
f"cutoff={lat_opts_2e.cutoff_bohr:.1f} bohr)"
)
# ---- SCF loop --------------------------------------------------------
plog.banner("SCF (PBC BIPOLE, direct-space)")
plog.info(" iter energy (Ha) dE ||[F,DS]|| DIIS")
scf_trace: List[SCFIteration] = []
energy_components: List[PBCBipoleEnergyComponents] = []
E_prev = 0.0
F_k_list: List[np.ndarray] = [np.zeros_like(H) for H in Hcore_k_list]
E_elec = 0.0
converged = False
iter_idx = 0
for iter_idx in range(1, int(opts.max_iter) + 1):
if damper is not None:
damping = damper.alpha
diis_active = use_diis and iter_idx >= diis_start_iter
E_j_short_range: Optional[float] = None
E_j_long_range: Optional[float] = None
E_exchange: Optional[float] = None
E_j_multipole: Optional[float] = None
# Damping (skip when DIIS active).
D_used = D_real
if iter_idx > 1 and damping > 0.0 and not diis_active:
D_used = _damp_lattice_matrix(D_real, D_real_prev, damping)
# --- F^{2e}(g) build.
# Use the k-space ρ̂(K) route only when the real-space density
# is exactly represented by C_per_k. Local SAD, fixed damping,
# and ODA-mixed densities are real-space densities; for those,
# J^LR must be built from the actual density blocks to avoid
# using stale orbitals in the reciprocal-space piece.
d_used_is_damped = iter_idx > 1 and damping > 0.0 and not diis_active
d_used_from_coeffs = (
density_from_c_per_k
and not (initial_density_is_local and iter_idx == 1)
and not d_used_is_damped
)
fock_build = _build_fock_for_density(
D_used,
coeffs_for_rho=(C_per_k if d_used_from_coeffs else None),
)
F2e_real = fock_build.f2e_real
F_k_list = fock_build.f_k_list
E_j_short_range = fock_build.e_j_short_range
E_j_long_range = fock_build.e_j_long_range
E_exchange = fock_build.e_exchange
E_j_multipole = fock_build.e_j_multipole
# --- Per-cell electronic energy + [F,DS] error vectors.
#
# CRYSTAL's energy path contracts the real-space density against
# real-space operator blocks: E = Σ_g D(g)H(g) + ½Σ_g D(g)F²e(g).
# This is essential at CYC0, where SAD is localised at g=0 and
# Γ-folding T/V would incorrectly add cross-cell one-electron
# blocks. k-space D(k) is still needed for error vectors,
# level-shift projection, and the J^LR split path.
E_kin = _lattice_contract(D_used, T_lat, operator_name="T")
E_ne = _lattice_contract(D_used, V_lat, operator_name="V_ne")
E_2e = 0.5 * _lattice_contract(
D_used,
F2e_real,
operator_name="F2e",
)
E_elec = E_kin + E_ne + E_2e
grad_norm_sum = 0.0
error_k_list: List[np.ndarray] = []
D_k_list: List[np.ndarray] = []
for idx in range(n_k):
if initial_density_is_local and iter_idx == 1:
# SAD/PATOM-style local guesses are stored explicitly as
# D(g=0)=D_atom_sum and D(g≠0)=0. Their Bloch sum is the
# same D at every k; using the Hcore-diag C(k) seed here
# would make the energy/error vector inconsistent with
# the Fock matrix that was just built from SAD.
k_arr = np.asarray(k_points[idx], dtype=float)
D_k = _bloch_sum_blocks(D_used.blocks, D_used.cells, k_arr)
D_k = 0.5 * (D_k + D_k.conj().T)
else:
# Multi-k (or legacy): D_k from previous iter's C.
C_k = C_per_k[idx]
C_occ = C_k[:, :n_occ]
D_k = 2.0 * (C_occ @ C_occ.conj().T)
D_k_list.append(D_k)
H_k = Hcore_k_list[idx]
F_k = F_k_list[idx]
w = float(weights[idx])
S_k = S_k_list[idx]
FDS = F_k @ D_k @ S_k
grad = FDS - FDS.conj().T
error_k_list.append(grad)
grad_norm_sum += w * float(np.linalg.norm(grad))
E_total = float(E_elec) + e_nuc
# EXT EL-SPHEROPOLE — CRYSTAL's K=0 Ewald reciprocal-space
# limit term, added to energy only (not the Fock matrix).
E_sphero = compute_ext_el_spheropole(D_used, basis, system, lat_opts)
E_total += E_sphero
dE = E_total - E_prev if iter_idx > 1 else 0.0
check_scf_divergence(
"run_pbc_bipole_rhf",
iter_idx,
E_total,
grad_norm_sum,
dE,
)
scf_trace.append(
SCFIteration(
iter=iter_idx,
energy=float(E_total),
delta_e=float(dE if iter_idx > 1 else 0.0),
grad_norm=float(grad_norm_sum),
diis_subspace=(accel.subspace_size if accel is not None else 0),
)
)
plog.iteration(
iter_idx,
energy=float(E_total),
dE=float(dE if iter_idx > 1 else 0.0),
grad=float(grad_norm_sum),
diis=(accel.subspace_size if accel is not None else 0),
)
energy_components.append(
PBCBipoleEnergyComponents(
iter=int(iter_idx),
e_total=float(E_total),
e_electronic=float(E_elec),
e_kinetic=float(E_kin),
e_nuclear_attraction=float(E_ne),
e_two_electron=float(E_2e),
e_nuclear_repulsion=float(e_nuc),
e_bielet_zone_ee=(None if use_ewald_j_split else float(E_2e)),
e_ext_el_spheropole=E_sphero,
e_j_short_range=E_j_short_range,
e_j_long_range=E_j_long_range,
e_exchange=E_exchange,
e_j_multipole=E_j_multipole,
)
)
plog.energy_decomposition(
iter_idx,
E_kin=float(E_kin),
E_ne=float(E_ne),
E_2e=float(E_2e),
E_elec=float(E_elec),
E_nuc=float(e_nuc),
)
# ---- Multipole far-field diagnostics (if enabled) -----------
if use_multipole_diag and system.dim == 3:
from .bipole_fock_multipole import (
build_j_far_field_multipole,
estimate_bipole_radius,
)
try:
R_bipole = estimate_bipole_radius(
system,
basis,
L_max=multipole_l_max,
)
far_j = build_j_far_field_multipole(
D_used,
basis,
system,
lat_opts_2e,
L_max=multipole_l_max,
R_bipole=R_bipole,
cache=_mp_config.cache if _mp_config.enabled else None,
)
plog.info(
f" BIPOLE far-field (L_max={multipole_l_max}, "
f"R_bipole={R_bipole:.1f} bohr): "
f"E_J_far = {far_j.e_j_far:+.6f} Ha, "
f"n_pairs = {far_j.n_cell_pairs}"
)
except Exception as exc:
plog.info(
f" BIPOLE far-field diagnostic failed: {type(exc).__name__}: {exc}"
)
converged = (
iter_idx > 1
and abs(dE) < float(opts.conv_tol_energy)
and grad_norm_sum < float(opts.conv_tol_grad)
)
# --- SCF-accelerator extrapolation. The full
# {DIIS, KDIIS, EDIIS, EDIIS_DIIS, ADIIS} family + dynamic_damping
# is wired on the multi-k BIPOLE path: DIIS / KDIIS run natively
# per-k (Pulay / orbital-rotation-gradient designs from M2c);
# EDIIS / ADIIS / EDIIS_DIIS bridge through the stacked-real-block
# representation landed in M2e (see
# ``per_k_to_stacked_real_blocks`` in
# ``periodic_scf_accelerators.py``).
if accel is not None:
density_k_list = [
_bloch_sum_blocks(D_used.blocks, D_used.cells, np.asarray(k))
for k in k_points
]
F_ex_list = accel.extrapolate_rhf(
F_k_list,
error_k_list=error_k_list,
density_k_list=density_k_list,
energy=E_total,
mo_coeffs_k_list=C_per_k,
n_occ=n_occ,
weights=list(weights),
cells=cells,
kpoints=list(k_points),
)
if diis_active:
F_k_list = F_ex_list
# --- LEVSHIFT (per-iter schedule or static)
if level_shift_schedule is not None:
level_shift_b = level_shift_schedule.at(iter_idx)
else:
level_shift_b = level_shift_static
if level_shift_b != 0.0:
F_for_diag: List[np.ndarray] = []
for idx in range(n_k):
D_k = D_k_list[idx]
S_k = S_k_list[idx]
F_shift = (
F_k_list[idx]
+ level_shift_b * S_k
- (level_shift_b / 2.0) * (S_k @ D_k @ S_k)
)
F_shift = 0.5 * (F_shift + F_shift.conj().T)
F_for_diag.append(F_shift)
else:
F_for_diag = F_k_list
# --- Diagonalise F(k) → new C(k), ε(k)
new_C_per_k = []
new_eps_per_k = []
for idx in range(n_k):
C_k, eps_k = _diag_in_orth_basis(F_for_diag[idx], X_k_list[idx])
new_C_per_k.append(C_k)
new_eps_per_k.append(eps_k)
# --- MOM reorder (iter ≥ 2 only; falls through to Aufbau at iter 1)
if use_mom and C_prev_occ_per_k is not None:
for idx in range(n_k):
C_k = new_C_per_k[idx]
eps_k = new_eps_per_k[idx]
S_k = S_k_list[idx]
sel = _mom_select(
C_k,
S_k,
C_prev_occ_per_k[idx],
n_occ,
eps_new=eps_k,
)
n_kept_idx = C_k.shape[1]
virt_mask = np.ones(n_kept_idx, dtype=bool)
virt_mask[sel] = False
virt_sel = np.where(virt_mask)[0]
virt_sel = virt_sel[np.argsort(np.real(eps_k[virt_sel]))]
order = np.concatenate([sel, virt_sel])
new_C_per_k[idx] = C_k[:, order]
new_eps_per_k[idx] = eps_k[order]
C_per_k = new_C_per_k
eps_per_k = new_eps_per_k
# --- Rebuild D_real. The initial SAD guess is localised at
# g=0 by construction, but after the first diagonalisation a
# Γ-only Bloch density is periodic over every retained cell
# block (phase = 1), matching CRYSTAL's PDIG_IR convention.
D_real_new = real_space_density_from_kpoints(
C_per_k,
[n_occ] * n_k,
kmesh,
cells,
)
# --- ODA mixing (extra Fock build)
if use_oda:
fock_naive = _build_fock_for_density(
D_real_new,
coeffs_for_rho=C_per_k,
)
oda_step = _compute_oda_lambda(
D_used,
D_real_new,
F_k_list,
fock_naive.f_k_list,
[np.asarray(k) for k in k_points],
weights,
trust_lambda_max=oda_trust_lambda_max,
)
_oda_mix(D_used, D_real_new, oda_step.lam)
D_real_prev = D_real
D_real = D_used
density_from_c_per_k = oda_step.lam == 1.0
plog.info(
f" ODA: λ = {oda_step.lam:.4f} "
f"(g0 = {oda_step.g0:+.3e}, g1 = {oda_step.g1:+.3e})"
)
else:
D_real_prev = D_used
D_real = D_real_new
density_from_c_per_k = True
# Snapshot for next iter's MOM
if use_mom:
C_prev_occ_per_k = [
np.asarray(C_per_k[idx][:, :n_occ]).copy() for idx in range(n_k)
]
if damper is not None:
damper.update(E_total)
E_prev = E_total
if converged:
break
plog.converged(n_iter=iter_idx, energy=E_total, converged=converged)
# ---- Post-loop: recompute energy on final density for consistency
if converged:
_fb = _build_fock_for_density(D_real, coeffs_for_rho=C_per_k)
E_kin_final = _lattice_contract(D_real, T_lat, operator_name="T")
E_ne_final = _lattice_contract(D_real, V_lat, operator_name="V_ne")
E_2e_final = 0.5 * _lattice_contract(
D_real,
_fb.f2e_real,
operator_name="F2e",
)
E_elec = E_kin_final + E_ne_final + E_2e_final
E_total = float(E_elec) + e_nuc
# Fresh E_total doesn't include spheropole — add it.
E_sphero_final = compute_ext_el_spheropole(D_real, basis, system, lat_opts)
E_total += E_sphero_final
else:
# Non-converged: E_total already includes spheropole from the
# last SCF iteration. Store it for the result.
E_sphero_final = energy_components[-1].e_ext_el_spheropole
return PBCBipoleRHFResult(
energy=float(E_total),
e_electronic=float(E_elec),
e_nuclear=e_nuc,
n_iter=iter_idx,
converged=converged,
mo_energies=eps_per_k,
mo_coeffs=C_per_k,
fock=F_k_list,
overlap=S_k_list,
hcore=Hcore_k_list,
density=D_real,
e_ext_el_spheropole=E_sphero_final,
scf_trace=scf_trace,
ewald_alpha_bohr_inv=omega_used,
energy_components=energy_components,
)