# Source code for vibeqc.pbc_bipole_rks

"""BIPOLE-style periodic RKS driver in CRYSTAL's electrostatic gauge.

The DFT counterpart of :mod:`vibeqc.pbc_bipole`. Uses the same
CRYSTAL-gauge Ewald J-split F²e build for the Hartree term and
adds the libxc XC potential on the periodic Becke grid. For hybrid
functionals a fraction of HF exchange is retained via the native
``build_jk_2e_real_space`` component builder.

This driver keeps the same multi-k SCF scaffold as the RHF BIPOLE
driver: shared Ewald α across V_ne/E_nn/J^LR, analytic V_ne via
AO-pair Fourier transforms, real-space energy accounting, and
the full suite of convergence accelerators (DIIS, ODA, MOM,
level-shift schedule, damping).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Union

import numpy as np

from ._vibeqc_core import (
    BasisSet,
    BlochKMesh,
    CoulombMethod,
    EwaldOptions,
    Functional,
    GridOptions,
    InitialGuess,
    LatticeMatrixSet,
    LatticeSumOptions,
    PeriodicKSOptions,
    PeriodicSystem,
    SCFIteration,
    bloch_sum,
    build_fock_2e_real_space,
    build_grid,
    build_jk_2e_real_space,
    build_xc_periodic,
    compute_kinetic_lattice,
    compute_overlap_lattice,
    direct_lattice_cells,
    ewald_nuclear_repulsion,
    nuclear_repulsion_per_cell,
    real_space_density_from_kpoints,
)
from .bipole_ext_el_pole import compute_ext_el_spheropole
from .guess import initial_density_closed_shell
from .level_shift_schedule import LevelShiftSchedule
from .mom import select_occupied_by_max_overlap as _mom_select
from .oda import compute_oda_lambda as _compute_oda_lambda
from .oda import oda_mix_densities as _oda_mix
from .pbc_bipole import (
    PBCBipoleEnergyComponents,
    _bloch_sum_blocks,
    _cell_key,
    _compute_nuclear_lattice_ewald_reciprocal_ft,
    _crystal_ewald_options,
    _default_bipole_v_ne_grid_options,
    _expand_ibz_kmesh_for_ewald_j,
    _lattice_contract,
    _lattice_contract_blocks,
)
from .periodic_grid import build_periodic_becke_grid
from .periodic_rhf_multi_k_ewald import (
    _canonical_orthogonalizer_complex,
    _damp_lattice_matrix,
    _diag_in_orth_basis,
)
from .periodic_scf_accelerators import (
    DynamicDamping,
    MultiKPeriodicSCFAccelerator,
)
from .periodic_v_ne import compute_nuclear_lattice_dispatch
from .progress import ProgressLogger, resolve_progress
from .scf_divergence import check_scf_divergence
from .smearing._support import reject_unsupported_smearing_temperature
from .symmetry_integrals_reduced import (
    compute_kinetic_lattice_reduced,
    compute_overlap_lattice_reduced,
)

# Public API of this module: the result container and the driver entry point.
__all__ = ["PBCBipoleRKSResult", "run_pbc_bipole_rks"]


@dataclass
class PBCBipoleRKSResult:
    """Result of :func:`run_pbc_bipole_rks`.

    Energies are per unit cell in Hartree. Per-k lists follow the order of
    ``kmesh.kpoints`` as passed to the driver.
    """

    # Total energy: electronic + nuclear repulsion + EXT EL-SPHEROPOLE term.
    energy: float
    # Electronic energy: kinetic + electron-nuclear + two-electron + XC.
    e_electronic: float
    # Nuclear repulsion per cell (Ewald sum in 3D).
    e_nuclear: float
    # Exchange-correlation energy evaluated on the DFT grid.
    e_xc: float
    # Two-electron energy 0.5 * Tr[D F2e]; for hybrid functionals this
    # also contains the HF-exchange share.
    e_coulomb: float
    # HF-exchange energy as reported by the driver (0.0 when not computed).
    e_hf_exchange: float
    # Number of SCF iterations performed.
    n_iter: int
    # Whether the energy and gradient convergence thresholds were met.
    converged: bool
    # Per-k orbital energies and MO coefficient matrices.
    mo_energies: List[np.ndarray]
    mo_coeffs: List[np.ndarray]
    # Per-k Fock matrices from the last SCF iteration (DIIS-extrapolated
    # when DIIS was active).
    fock: List[np.ndarray]
    # Per-k overlap S(k) and core Hamiltonian H(k).
    overlap: List[np.ndarray]
    hcore: List[np.ndarray]
    # Real-space lattice density matrix D(g) at the end of the SCF.
    density: LatticeMatrixSet
    # Per-iteration SCF trace (energy, dE, gradient norm, DIIS subspace).
    scf_trace: List[SCFIteration] = field(default_factory=list)
    # EXT EL-SPHEROPOLE energy correction included in ``energy``.
    e_ext_el_spheropole: Optional[float] = None
    # Shared Ewald alpha in bohr^-1; None for non-3D systems.
    ewald_alpha_bohr_inv: Optional[float] = None
    # Per-iteration energy decomposition.
    energy_components: List[PBCBipoleEnergyComponents] = field(
        default_factory=list,
    )
    # Name of the XC functional that was used.
    functional: str = ""
@dataclass
class _PBCBipoleRKSFockBuild:
    """Internal Fock-build bundle for the BIPOLE RKS driver."""

    # Real-space two-electron Fock matrix F2e(g).
    f2e_real: LatticeMatrixSet
    # Per-k Fock matrices F(k) = F2e(k) + V_xc(k) + Hcore(k).
    f_k_list: List[np.ndarray]
    # Energy components of the build; None when that term was not built.
    e_j_short_range: Optional[float] = None
    e_j_long_range: Optional[float] = None
    e_exchange: Optional[float] = None
    e_j_multipole: Optional[float] = None


def _density_set_gamma_or_lattice(
    template: LatticeMatrixSet,
    D_real: LatticeMatrixSet,
) -> LatticeMatrixSet:
    """Return a ``LatticeMatrixSet`` suitable for XC evaluation.

    If ``D_real`` already carries per-cell blocks (multi-cell), or has
    exactly as many cells as ``template``, it is returned unchanged.
    Otherwise the single (Gamma-folded) block of ``D_real`` -- or a zero
    matrix when ``D_real`` has no blocks -- is written into the home cell
    (index ``(0, 0, 0)``, falling back to position 0) of ``template`` and
    every other cell of ``template`` is zeroed.

    NOTE: in that last case ``template`` is overwritten in place (there is
    no Python constructor for ``LatticeMatrixSet``), so callers must not rely
    on its previous contents afterwards.
    """
    blocks = D_real.blocks
    n_cells = len(D_real.cells)
    # Returning D_real whenever it already matches the template avoids
    # clobbering the template (the driver passes its overlap lattice here).
    if blocks and (n_cells > 1 or n_cells == len(template.cells)):
        return D_real

    nbf = template.nbf
    D = np.asarray(blocks[0], dtype=float) if blocks else np.zeros((nbf, nbf))
    zero = np.zeros_like(D)
    template_cells = template.cells
    # Place the Gamma-folded matrix in the g = 0 cell, not blindly at slot 0.
    home = next(
        (
            i
            for i, cell in enumerate(template_cells)
            if np.all(np.asarray(cell.index) == 0)
        ),
        0,
    )
    for i in range(len(template_cells)):
        template.set_block(i, D if i == home else zero)
    return template
def run_pbc_bipole_rks(
    system: PeriodicSystem,
    basis: BasisSet,
    kmesh: BlochKMesh,
    options=None,
    *,
    functional: Optional[str] = None,
    linear_dep_threshold: float = 1e-7,
    canonical_orth_normalize_diag_first: bool = True,
    level_shift_schedule: Optional[LevelShiftSchedule] = None,
    use_mom: bool = False,
    use_oda: bool = False,
    oda_trust_lambda_max: float = 1.0,
    use_ewald_j_split: Optional[bool] = None,
    ewald_omega: Optional[float] = None,
    ewald_precision: float = 1e-8,
    v_ne_grid_options: Optional[GridOptions] = None,
    use_multipole_far_field: bool = False,
    multipole_l_max: int = 2,
    progress: Union[bool, ProgressLogger, None] = None,
    verbose: Optional[int] = None,
    initial_density: Optional[Sequence[np.ndarray]] = None,
) -> PBCBipoleRKSResult:
    """Multi-k closed-shell RKS via the CRYSTAL-gauge BIPOLE scaffold.

    The two-electron Fock contribution F2e(g) is assembled in real space.
    On the Ewald J-split path (default for ``system.dim == 3``), the Hartree
    term is split as follows:

    * an erfc-screened short-range part;
    * a reciprocal-space long-range part, sharing the Ewald alpha used for
      V_ne and E_nn;
    * a neutralising background term proportional to S(g).

    For hybrid functionals the fraction ``hf_exchange_fraction`` of
    full-range HF exchange is subtracted. The XC potential is evaluated on
    the DFT grid and added per k-point.

    Parameters
    ----------
    system, basis, kmesh
        As in :func:`run_pbc_bipole_rhf`.
    options
        :class:`PeriodicKSOptions`, or None for default options (the
        functional defaults to ``"pbe"`` when unset). NOTE: the options
        object is modified in place when ``functional`` is given or its
        functional is empty.
    functional
        XC functional name (overrides ``options.functional`` if given).
    linear_dep_threshold, canonical_orth_normalize_diag_first
        Canonical orthogonalisation controls for each S(k).
    level_shift_schedule
        Optional per-iteration level shift; otherwise ``options.level_shift``.
    use_mom
        Enable maximum-overlap orbital selection.
    use_oda, oda_trust_lambda_max
        Optimal-damping density mixing (mutually exclusive with DIIS).
    use_ewald_j_split
        None (default) enables the J split for ``system.dim == 3``.
    ewald_omega, ewald_precision
        Ewald alpha override (default: CRYSTAL rule from the cell volume)
        and Ewald precision.
    v_ne_grid_options
        If given, V_ne is computed through the grid dispatcher instead of
        the analytic AO-pair Fourier-transform route.
    use_multipole_far_field, multipole_l_max
        Optional multipole far-field replacement of J (only applied on the
        Ewald J-split path).
    progress, verbose
        Progress-logger selection.
    initial_density
        Optional warm-start density, one block per lattice cell in the
        canonical cell ordering.

    Returns
    -------
    PBCBipoleRKSResult
        When converged, the energies are recomputed on the final density.

    Raises
    ------
    ValueError
        Raised for any of the following:

        * odd electron count, or multiplicity != 1;
        * empty k-mesh, or weights not summing to 1;
        * J split requested for dim != 3;
        * non-uniform full-mesh weights with the J split;
        * ``use_oda`` together with DIIS;
        * an ``initial_density`` with the wrong number of blocks;
        * ``options.max_iter < 1``.
    TypeError
        ``level_shift_schedule`` is not a :class:`LevelShiftSchedule`.
    RuntimeError
        Canonical orthogonalisation kept fewer functions than occupied
        orbitals at some k.
    """
    opts = options if options is not None else PeriodicKSOptions()
    if functional is not None:
        opts.functional = str(functional)
    if not getattr(opts, "functional", None):
        opts.functional = "pbe"

    reject_unsupported_smearing_temperature(
        opts,
        "run_pbc_bipole_rks",
        detail=(
            "BIPOLE smearing is queued for a later smearing milestone; "
            "this driver would otherwise run integer Aufbau occupations."
        ),
    )
    max_iter = int(opts.max_iter)
    if max_iter < 1:
        # Without at least one iteration no energy or Fock is ever built.
        raise ValueError(
            f"run_pbc_bipole_rks: options.max_iter must be >= 1; got {max_iter}"
        )

    func = Functional(opts.functional, 1)  # spin-unpolarised
    alpha_hf = float(func.hf_exchange_fraction)
    lat_opts: LatticeSumOptions = opts.lattice_opts
    plog = resolve_progress(progress, verbose=verbose)

    use_ewald_j_split_auto = use_ewald_j_split is None
    use_ewald_j_split = (
        system.dim == 3 if use_ewald_j_split_auto else bool(use_ewald_j_split)
    )

    # Two-electron lattice sums are always direct-truncated; the
    # one-electron (V_ne) sums use Ewald in 3D.
    lat_opts_2e = LatticeSumOptions()
    lat_opts_2e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_2e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_2e.coulomb_method = CoulombMethod.DIRECT_TRUNCATED
    lat_opts_1e = LatticeSumOptions()
    lat_opts_1e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_1e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_1e.coulomb_method = (
        CoulombMethod.EWALD_3D if system.dim == 3 else CoulombMethod.DIRECT_TRUNCATED
    )

    plog.info(
        f"PBC BIPOLE RKS (CRYSTAL-gauge) / cutoff {lat_opts.cutoff_bohr:.2f} bohr"
    )
    plog.info(f" functional = {opts.functional}, hf_exchange_fraction = {alpha_hf}")
    plog.info(
        f" F^2e (J + V_xc{'+ K' if alpha_hf > 0 else ''}) : "
        f"{'EWALD_J_SPLIT' if use_ewald_j_split else lat_opts_2e.coulomb_method.name}"
    )

    n_elec = system.n_electrons()
    if n_elec % 2 != 0:
        raise ValueError(
            f"run_pbc_bipole_rks: closed-shell RKS requires even electron "
            f"count; got {n_elec}"
        )
    if system.multiplicity != 1:
        raise ValueError(
            f"run_pbc_bipole_rks: requires multiplicity=1; got {system.multiplicity}"
        )
    n_occ = n_elec // 2

    # ---- k-mesh (IBZ mesh is expanded to the full mesh for J^LR) -----
    _ir_mapping = np.asarray(getattr(kmesh, "ir_mapping", []), dtype=int).reshape(-1)
    k_points = list(kmesh.kpoints)
    weights = np.asarray(kmesh.weights, dtype=float)
    if use_ewald_j_split and _ir_mapping.size > 0:
        kmesh_full = _expand_ibz_kmesh_for_ewald_j(system, kmesh, plog)
        k_points_full = list(kmesh_full.kpoints)
        weights_full = np.asarray(kmesh_full.weights, dtype=float)
    else:
        k_points_full = k_points
        weights_full = weights
    n_k = len(k_points)
    if n_k == 0:
        raise ValueError("kmesh has no k-points")
    if not np.isclose(weights.sum(), 1.0):
        raise ValueError(f"kmesh.weights must sum to 1; got {weights.sum():.6f}")
    plog.info(
        f"k-mesh: {n_k} k-point{'s' if n_k != 1 else ''}, "
        f"weights sum = {weights.sum():.4f}"
    )

    # ---- Shared Ewald state (CRYSTAL COMMON/VRSMAD pattern) ----------
    ewald_options_1e: Optional[EwaldOptions] = None
    omega_used: Optional[float] = None
    ewald_cell_volume: Optional[float] = None
    ewald_k_max: Optional[float] = None
    if system.dim == 3:
        from .bipole_ext_el_pole import (
            crystal_default_ewald_alpha,
            crystal_ewald_reciprocal_cutoff,
        )

        V_cell = float(abs(np.linalg.det(np.asarray(system.lattice, dtype=float))))
        ewald_cell_volume = V_cell
        omega_used = (
            float(ewald_omega)
            if ewald_omega is not None
            else crystal_default_ewald_alpha(V_cell)
        )
        ewald_k_max = crystal_ewald_reciprocal_cutoff(V_cell)
        ewald_options_1e = _crystal_ewald_options(
            lat_opts_1e,
            alpha_bohr_inv=omega_used,
            tolerance=float(ewald_precision),
            recip_cutoff_bohr_inv=ewald_k_max,
        )

    # ---- DFT grid ----------------------------------------------------
    if getattr(opts, "use_periodic_becke", False):
        grid = build_periodic_becke_grid(
            system,
            grid_options=opts.grid,
            image_radius_bohr=float(getattr(opts, "becke_image_radius_bohr", 10.0)),
        )
    else:
        grid = build_grid(system.unit_cell_molecule(), opts.grid)

    # ---- Real-space one-electron integrals ---------------------------
    with plog.stage(
        "integrals_lattice",
        detail=f"S/T/V at cutoff {lat_opts.cutoff_bohr:.2f} bohr",
    ):
        _use_sym = False  # FIXME: LatticeMatrixSet has no Python constructor
        if _use_sym:
            ops = system.symmetry.operations
            _, S_blocks = compute_overlap_lattice_reduced(
                basis, system, lat_opts_2e, ops,
            )
            S_lat = LatticeMatrixSet()
            S_lat.nbf = basis.nbasis
            S_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            S_lat.blocks = S_blocks
            _, T_blocks = compute_kinetic_lattice_reduced(
                basis, system, lat_opts_2e, ops,
            )
            T_lat = LatticeMatrixSet()
            T_lat.nbf = basis.nbasis
            T_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            T_lat.blocks = T_blocks
        else:
            S_lat = compute_overlap_lattice(basis, system, lat_opts_2e)
            T_lat = compute_kinetic_lattice(basis, system, lat_opts_2e)

        v_ne_lr_cache = None
        if (
            system.dim == 3
            and ewald_options_1e is not None
            and v_ne_grid_options is None
        ):
            V_lat, v_ne_lr_cache = _compute_nuclear_lattice_ewald_reciprocal_ft(
                basis,
                system,
                lat_opts_1e,
                ewald_options_1e,
                S_lat,
                precision=ewald_precision,
                K_max=ewald_k_max,
            )
        else:
            v_ne_grid = (
                v_ne_grid_options
                if v_ne_grid_options is not None
                else (_default_bipole_v_ne_grid_options() if system.dim == 3 else None)
            )
            V_lat = compute_nuclear_lattice_dispatch(
                basis,
                system,
                lat_opts_1e,
                grid_options=v_ne_grid,
                ewald_options=ewald_options_1e,
            )

    cells = list(S_lat.cells)
    plog.info(f"n_cells in lattice sum = {len(cells)}")

    # ---- Per-k S(k), Hcore(k), X(k) ---------------------------------
    from .linear_dependence import (
        check_overlap_matrix,
        format_linear_dependence_report,
        raise_if_severe,
        scf_preflight_overlap_check,
    )

    S_k_list: List[np.ndarray] = []
    Hcore_k_list: List[np.ndarray] = []
    X_k_list: List[np.ndarray] = []
    for k_idx, k in enumerate(k_points):
        k_arr = np.asarray(k, dtype=float).reshape(3)
        S_k = np.asarray(bloch_sum(S_lat, k_arr))
        T_k = np.asarray(bloch_sum(T_lat, k_arr))
        V_k = np.asarray(bloch_sum(V_lat, k_arr))
        H_k = T_k + V_k
        S_k = 0.5 * (S_k + S_k.conj().T)
        H_k = 0.5 * (H_k + H_k.conj().T)

        overlap_label = f"S(k={k_idx}, k_cart={k_arr.round(4).tolist()})"
        if n_k <= 16:
            report = scf_preflight_overlap_check(
                S_k, plog=plog, label=overlap_label, basis=basis,
            )
        else:
            report = check_overlap_matrix(
                S_k, basis=basis, label=overlap_label,
            )
        if report.severity != "ok":
            prefix = {
                "warn": "WARN",
                "error": "ERROR",
                "critical": "CRITICAL",
            }[report.severity]
            plog.info(
                f"[{prefix}] overlap [{overlap_label}]: "
                f"nbf={report.n_basis}, "
                f"min eig={report.min_eigenvalue:+.2e}, "
                f"cond={report.condition_number:.2e}"
            )
            plog.write_raw(format_linear_dependence_report(report))
        raise_if_severe(report)

        X_k, n_kept = _canonical_orthogonalizer_complex(
            S_k,
            linear_dep_threshold,
            normalize_diag_first=canonical_orth_normalize_diag_first,
        )
        if n_occ > n_kept:
            raise RuntimeError(
                f"run_pbc_bipole_rks: canonical orth at k={k_idx} "
                f"dropped too many directions (n_occ={n_occ}, "
                f"n_kept={n_kept})"
            )
        S_k_list.append(S_k)
        Hcore_k_list.append(H_k)
        X_k_list.append(X_k)

    # ---- Nuclear repulsion -------------------------------------------
    if ewald_options_1e is not None:
        e_nuc = float(ewald_nuclear_repulsion(system, ewald_options_1e))
    else:
        e_nuc = float(nuclear_repulsion_per_cell(system, lat_opts_1e))
    plog.info(f"E_nuc per cell = {e_nuc:+.10f} Ha")

    # ---- Initial guess -----------------------------------------------
    C_per_k: List[np.ndarray] = []
    eps_per_k: List[np.ndarray] = []
    for H_k, X_k in zip(Hcore_k_list, X_k_list):
        C_k, eps_k = _diag_in_orth_basis(H_k, X_k)
        C_per_k.append(C_k.astype(complex))
        eps_per_k.append(eps_k)
    n_occ_per_k = [n_occ] * n_k
    D_real = real_space_density_from_kpoints(
        C_per_k, n_occ_per_k, kmesh, cells,
    )

    # Caller-supplied warm-start density takes precedence over the
    # SAD/Hcore guess engine. Block ordering matches the canonical
    # ``direct_lattice_cells(kmesh)`` ordering -- same contract as the
    # RHF driver. Used by the NEB driver for within-image density
    # warm-start (HANDOVER_PERIODIC_NEB.md M4).
    if initial_density is not None:
        blocks_in = list(initial_density)
        if len(blocks_in) != len(D_real.cells):
            raise ValueError(
                f"run_pbc_bipole_rks: initial_density has "
                f"{len(blocks_in)} blocks; expected {len(D_real.cells)}"
            )
        for g_idx, block in enumerate(blocks_in):
            D_real.set_block(g_idx, np.asarray(block, dtype=float))
        plog.info("initial guess: caller-supplied density (warm-start)")
        initial_density_is_local = True
        density_from_c_per_k = False
    else:
        guess = getattr(opts, "initial_guess", InitialGuess.HCORE)
        D_engine = initial_density_closed_shell(
            system.unit_cell_molecule(), basis, n_occ, guess, is_periodic=True,
        )
        if D_engine is not None:
            plog.info(f"initial guess: {guess.name} (g=0 density from GuessEngine)")
            # Home-cell density from the engine, zeros in all other cells.
            for g_idx in range(len(D_real.cells)):
                if (D_real.cells[g_idx].index == np.array([0, 0, 0])).all():
                    D_real.set_block(g_idx, D_engine)
                else:
                    D_real.set_block(
                        g_idx, np.zeros_like(D_engine, dtype=float),
                    )
        else:
            plog.info(f"initial guess: {guess.name} (Hcore-diag per k)")
        initial_density_is_local = D_engine is not None
        density_from_c_per_k = not initial_density_is_local
    D_real_prev: Optional[LatticeMatrixSet] = None

    # ---- SCF aids ----------------------------------------------------
    damping = float(opts.damping)
    damper: Optional[DynamicDamping] = None
    if bool(getattr(opts, "dynamic_damping", False)):
        damper = DynamicDamping(
            initial_alpha=damping,
            alpha_min=float(getattr(opts, "dynamic_damping_min", 0.0)),
            alpha_max=float(getattr(opts, "dynamic_damping_max", 0.95)),
        )
    use_diis = bool(opts.use_diis)
    diis_start_iter = int(opts.diis_start_iter)
    accel: Optional[MultiKPeriodicSCFAccelerator] = (
        MultiKPeriodicSCFAccelerator(opts) if use_diis else None
    )
    level_shift_static = float(getattr(opts, "level_shift", 0.0))
    if level_shift_schedule is not None and not isinstance(
        level_shift_schedule, LevelShiftSchedule,
    ):
        raise TypeError(
            f"level_shift_schedule must be LevelShiftSchedule; "
            f"got {type(level_shift_schedule).__name__}"
        )
    if use_mom:
        plog.info("MOM: ON")
    C_prev_occ_per_k: Optional[List[np.ndarray]] = None
    if use_oda and use_diis:
        raise ValueError("use_oda and use_diis are mutually exclusive")
    if use_oda:
        plog.info(f"ODA: ON (trust λ_max = {oda_trust_lambda_max})")

    # ---- Ewald J-split setup -----------------------------------------
    j_lr_cache = v_ne_lr_cache
    s_blocks_by_cell = {}
    if use_ewald_j_split:
        if system.dim != 3:
            raise ValueError("use_ewald_j_split requires dim=3")
        # The long-range J is evaluated on the full mesh (an IBZ mesh with
        # ir_mapping was expanded above), so that is the mesh whose weights
        # must be uniform; IBZ weights themselves are legitimately uneven.
        n_k_full = len(weights_full)
        if n_k_full > 1:
            uniform_w = 1.0 / float(n_k_full)
            if not np.allclose(weights_full, uniform_w, atol=1e-9):
                raise ValueError(
                    "use_ewald_j_split at multi-k requires uniform "
                    "full-mesh weights or IBZ-reduced mesh with ir_mapping"
                )
        from .bipole_fock_ewald import (
            _build_j_long_range_cache,
            compute_J_long_range_real_space_blocks,
            compute_rho_hat_from_k_density,
        )

        assert omega_used is not None
        cells_r_cart_arr = np.array(
            [np.asarray(c.r_cart, dtype=float) for c in cells],
            dtype=float,
        )
        if j_lr_cache is None:
            j_lr_cache = _build_j_long_range_cache(
                basis, system, cells_r_cart_arr, omega_used, ewald_precision,
                K_max=ewald_k_max,
            )
        # Snapshot S(g) once: S_lat is later handed to
        # _density_set_gamma_or_lattice as a template, which may overwrite
        # its blocks.
        s_blocks_by_cell = {
            _cell_key(cell): np.asarray(block, dtype=float)
            for cell, block in zip(S_lat.cells, S_lat.blocks)
        }

    # ---- Multipole far-field config (resolve once before SCF loop) ---
    from .bipole_fock_multipole import resolve_multipole_config

    _mp_config = resolve_multipole_config(
        system, basis, lat_opts_2e,
        user_enable=use_multipole_far_field,
        multipole_l_max=multipole_l_max,
    )
    if _mp_config.enabled:
        plog.info(
            f" BIPOLE multipole far-field: ENABLED "
            f"(L_max={_mp_config.L_max}, R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"n_cells={len(_mp_config.cache.cells) if _mp_config.cache else 0})"
        )
    else:
        plog.info(
            f" BIPOLE multipole far-field: off "
            f"(R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"cutoff={lat_opts_2e.cutoff_bohr:.1f} bohr)"
        )
    # Full-range K is only consumed for hybrids or by the multipole path;
    # skipping it for pure functionals avoids an expensive unused build.
    need_exchange_k = alpha_hf > 0.0 or bool(_mp_config.enabled)

    def _density_matrices_from_coeffs(
        coeffs_per_k: Sequence[np.ndarray],
    ) -> List[np.ndarray]:
        """Closed-shell D(k) = 2 C_occ C_occ^H for each k."""
        out: List[np.ndarray] = []
        for C_k_raw in coeffs_per_k:
            C_occ_k = np.asarray(C_k_raw)[:, :n_occ]
            out.append(2.0 * (C_occ_k @ C_occ_k.conj().T))
        return out

    def _build_xc(density: LatticeMatrixSet):
        """Evaluate the XC energy and V_xc lattice for ``density``."""
        D_xc_set = _density_set_gamma_or_lattice(S_lat, density)
        return build_xc_periodic(basis, system, grid, D_xc_set, func)

    def _build_fock_for_density(
        density: LatticeMatrixSet,
        *,
        coeffs_for_rho: Optional[Sequence[np.ndarray]],
        xc_result=None,
    ) -> _PBCBipoleRKSFockBuild:
        """Build F2e(g) and F(k) with Ewald J-split + V_xc.

        ``coeffs_for_rho`` (when not None) must be the MO coefficients that
        generated ``density``; they feed the reciprocal-space rho-hat of J^LR.
        ``xc_result`` may be passed to reuse an XC evaluation of ``density``.
        """
        e_j_short_range: Optional[float] = None
        e_j_long_range: Optional[float] = None
        e_exchange: Optional[float] = None
        e_j_multipole: Optional[float] = None

        if use_ewald_j_split:
            # Short-range J via erfc-screened ERIs
            F_J_SR_lat = build_fock_2e_real_space(
                basis, system, lat_opts_2e, density, 0.0, float(omega_used),
            )
            # Full-range K (only when something consumes it)
            K_full_blocks = None
            if need_exchange_k:
                jk_full_lat = build_jk_2e_real_space(
                    basis, system, lat_opts_2e, density, 0.0,
                )
                k_lat_blocks = jk_full_lat.K.blocks
                K_full_blocks = [
                    np.asarray(k_lat_blocks[c], dtype=float)
                    for c in range(len(jk_full_lat.K.cells))
                ]
            cells_lr = F_J_SR_lat.cells
            n_cells_lr = len(cells_lr)

            # Long-range J via reciprocal-space sum
            rho_hat_for_LR = None
            if coeffs_for_rho is not None:
                D_k_for_rho = _density_matrices_from_coeffs(coeffs_for_rho)
                if _ir_mapping.size > 0:
                    D_k_full = [D_k_for_rho[int(idx)] for idx in _ir_mapping]
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_k_full, k_points_full, weights_full, j_lr_cache,
                    )
                else:
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_k_for_rho, k_points, weights, j_lr_cache,
                    )
            F_LR_blocks = compute_J_long_range_real_space_blocks(
                density, basis, system, omega_used,
                precision=ewald_precision,
                cache=j_lr_cache,
                rho_hat=rho_hat_for_LR,
            )

            # Neutralising background: -pi N / (alpha^2 V) times S(g)
            assert ewald_cell_volume is not None
            j_background_potential = (
                -np.pi * float(n_elec)
                / (float(omega_used) * float(omega_used) * float(ewald_cell_volume))
            )
            for c, cell in enumerate(cells_lr):
                F_LR_blocks[c] = (
                    F_LR_blocks[c]
                    + j_background_potential * s_blocks_by_cell[_cell_key(cell)]
                )

            # Energy components
            sr_lat_blocks = F_J_SR_lat.blocks
            j_sr_blocks = [
                np.asarray(sr_lat_blocks[c], dtype=float)
                for c in range(n_cells_lr)
            ]
            e_j_short_range = 0.5 * _lattice_contract_blocks(
                density, cells_lr, j_sr_blocks, operator_name="J_SR",
            )
            e_j_long_range = 0.5 * _lattice_contract_blocks(
                density, cells_lr, F_LR_blocks, operator_name="J_LR",
            )

            # Assemble F2e(g) = J_SR(g) + J_LR(g) - (alpha/2) K(g)
            density_blocks = density.blocks if alpha_hf > 0.0 else None
            f2e_direct_blocks = []
            for c in range(n_cells_lr):
                j_blk = j_sr_blocks[c] + F_LR_blocks[c]
                if alpha_hf > 0.0:
                    k_blk = K_full_blocks[c]
                    j_blk = j_blk - 0.5 * alpha_hf * k_blk
                    if e_exchange is None:
                        e_exchange = 0.0
                    d_blk = np.asarray(density_blocks[c], dtype=float)
                    e_exchange = e_exchange - 0.25 * alpha_hf * float(
                        np.sum(d_blk * k_blk)
                    )
                f2e_direct_blocks.append(j_blk)
            f2e_real = F_J_SR_lat
            for c in range(n_cells_lr):
                f2e_real.set_block(c, f2e_direct_blocks[c])

            # ---- Optional: replace J_SR+J_LR with multipole far-field J
            if _mp_config.enabled:
                from .bipole_fock_multipole import apply_multipole_far_field

                mp_result = apply_multipole_far_field(
                    density, basis, system, lat_opts_2e, _mp_config,
                    J_SR_blocks=j_sr_blocks,
                    K_blocks=K_full_blocks,
                    F_LR_blocks=F_LR_blocks,
                    exchange_scale=0.5 * alpha_hf,
                )
                for c in range(n_cells_lr):
                    f2e_real.set_block(c, mp_result.f2e_blocks[c])
                e_j_multipole = mp_result.e_j_multipole
                if mp_result.n_far_cells > 0:
                    plog.info(
                        f" BIPOLE multipole far-field (L_max={_mp_config.L_max}, "
                        f"R={_mp_config.R_bipole:.1f} bohr): "
                        f"{mp_result.n_far_cells}/{mp_result.n_total_cells} cells replaced, "
                        f"E_J_far = {e_j_multipole:+.6f} Ha"
                    )
        else:
            exchange_scale = alpha_hf
            f2e_real = build_fock_2e_real_space(
                basis, system, lat_opts_2e, density, exchange_scale, 0.0,
            )

        # ---- XC potential on the DFT grid -----------------------------
        if xc_result is None:
            xc_result = _build_xc(density)

        # ---- Assemble per-k Fock matrices ----------------------------
        f2e_blocks = f2e_real.blocks
        f2e_cells = f2e_real.cells
        f_k_list: List[np.ndarray] = []
        for k_idx, k in enumerate(k_points):
            k_arr = np.asarray(k, dtype=float)
            F2e_k = _bloch_sum_blocks(f2e_blocks, f2e_cells, k_arr)
            Vxc_k = np.asarray(bloch_sum(xc_result.vxc_lattice, k_arr))
            F_k = F2e_k + Vxc_k + np.asarray(Hcore_k_list[k_idx], dtype=complex)
            f_k_list.append(0.5 * (F_k + F_k.conj().T))

        return _PBCBipoleRKSFockBuild(
            f2e_real=f2e_real,
            f_k_list=f_k_list,
            e_j_short_range=e_j_short_range,
            e_j_long_range=e_j_long_range,
            e_exchange=e_exchange,
            e_j_multipole=e_j_multipole,
        )

    plog.banner("SCF (PBC BIPOLE RKS, direct-space)")
    plog.info(" iter energy (Ha) dE ||[F,DS]|| DIIS")

    scf_trace: List[SCFIteration] = []
    energy_components: List[PBCBipoleEnergyComponents] = []
    E_prev = 0.0
    F_k_list: List[np.ndarray] = [np.zeros_like(H) for H in Hcore_k_list]
    E_elec = 0.0
    e_xc = 0.0
    converged = False
    iter_idx = 0
    for iter_idx in range(1, max_iter + 1):
        if damper is not None:
            damping = damper.alpha
        diis_active = use_diis and iter_idx >= diis_start_iter
        d_used_is_damped = iter_idx > 1 and damping > 0.0 and not diis_active
        D_used = D_real
        if d_used_is_damped:
            D_used = _damp_lattice_matrix(D_real, D_real_prev, damping)
        # The reciprocal-space rho-hat may only be built from C(k) when
        # D_used really is the density of those coefficients.
        d_used_from_coeffs = (
            density_from_c_per_k
            and not (initial_density_is_local and iter_idx == 1)
            and not d_used_is_damped
        )

        # ---- XC (evaluated once, shared by energy and Fock) ---------
        xc_result = _build_xc(D_used)
        e_xc = float(xc_result.exc)

        fock_build = _build_fock_for_density(
            D_used,
            coeffs_for_rho=(C_per_k if d_used_from_coeffs else None),
            xc_result=xc_result,
        )
        F2e_real = fock_build.f2e_real
        F_k_list = fock_build.f_k_list

        # ---- Energy --------------------------------------------------
        E_kin = _lattice_contract(D_used, T_lat, operator_name="T")
        E_ne = _lattice_contract(D_used, V_lat, operator_name="V_ne")
        E_2e = 0.5 * _lattice_contract(D_used, F2e_real, operator_name="F2e")
        E_elec = E_kin + E_ne + E_2e + e_xc

        # ---- Commutator error [F, D S] per k -------------------------
        first_local = initial_density_is_local and iter_idx == 1
        if first_local:
            D_used_blocks = D_used.blocks
            D_used_cells = D_used.cells
        grad_norm_sum = 0.0
        error_k_list: List[np.ndarray] = []
        D_k_list: List[np.ndarray] = []
        for idx in range(n_k):
            if first_local:
                k_arr = np.asarray(k_points[idx], dtype=float)
                D_k = _bloch_sum_blocks(D_used_blocks, D_used_cells, k_arr)
                D_k = 0.5 * (D_k + D_k.conj().T)
            else:
                C_occ = C_per_k[idx][:, :n_occ]
                D_k = 2.0 * (C_occ @ C_occ.conj().T)
            D_k_list.append(D_k)
            FDS = F_k_list[idx] @ D_k @ S_k_list[idx]
            grad = FDS - FDS.conj().T
            error_k_list.append(grad)
            grad_norm_sum += float(weights[idx]) * float(np.linalg.norm(grad))

        E_total = float(E_elec) + e_nuc
        # EXT EL-SPHEROPOLE
        E_sphero = compute_ext_el_spheropole(D_used, basis, system, lat_opts)
        E_total += E_sphero

        dE = E_total - E_prev if iter_idx > 1 else 0.0
        check_scf_divergence(
            "run_pbc_bipole_rks", iter_idx, E_total, grad_norm_sum, dE,
        )
        diis_subspace = accel.subspace_size if accel is not None else 0
        scf_trace.append(
            SCFIteration(
                iter=iter_idx,
                energy=float(E_total),
                delta_e=float(dE if iter_idx > 1 else 0.0),
                grad_norm=float(grad_norm_sum),
                diis_subspace=diis_subspace,
            )
        )
        plog.iteration(
            iter_idx,
            energy=float(E_total),
            dE=float(dE if iter_idx > 1 else 0.0),
            grad=float(grad_norm_sum),
            diis=diis_subspace,
        )
        energy_components.append(
            PBCBipoleEnergyComponents(
                iter=int(iter_idx),
                e_total=float(E_total),
                e_electronic=float(E_elec),
                e_kinetic=float(E_kin),
                e_nuclear_attraction=float(E_ne),
                e_two_electron=float(E_2e),
                e_nuclear_repulsion=float(e_nuc),
                e_j_short_range=fock_build.e_j_short_range,
                e_j_long_range=fock_build.e_j_long_range,
                e_exchange=fock_build.e_exchange,
                e_ext_el_spheropole=E_sphero,
            )
        )
        plog.energy_decomposition(
            iter_idx,
            E_kin=float(E_kin),
            E_ne=float(E_ne),
            E_2e=float(E_2e),
            E_elec=float(E_elec),
            E_nuc=float(e_nuc),
        )
        converged = (
            iter_idx > 1
            and abs(dE) < float(opts.conv_tol_energy)
            and grad_norm_sum < float(opts.conv_tol_grad)
        )

        # SCF-accelerator extrapolation. Full
        # {DIIS, KDIIS, EDIIS, EDIIS_DIIS, ADIIS} family wired here;
        # bridged modes route through the M2e stacked-real-block bridge
        # (see ``per_k_to_stacked_real_blocks`` in
        # ``periodic_scf_accelerators.py``). History is accumulated every
        # iteration; the extrapolated Fock is used only once DIIS is active.
        if accel is not None:
            acc_blocks = D_used.blocks
            acc_cells = D_used.cells
            density_k_list = [
                _bloch_sum_blocks(acc_blocks, acc_cells, np.asarray(k))
                for k in k_points
            ]
            F_ex_list = accel.extrapolate_rhf(
                F_k_list,
                error_k_list=error_k_list,
                density_k_list=density_k_list,
                energy=E_total,
                mo_coeffs_k_list=C_per_k,
                n_occ=n_occ,
                weights=list(weights),
                cells=cells,
                kpoints=list(k_points),
            )
            if diis_active:
                F_k_list = F_ex_list

        # ---- Level shift: F + b S - (b/2) S D S ----------------------
        if level_shift_schedule is not None:
            level_shift_b = level_shift_schedule.at(iter_idx)
        else:
            level_shift_b = level_shift_static
        if level_shift_b != 0.0:
            F_for_diag: List[np.ndarray] = []
            for idx in range(n_k):
                S_k = S_k_list[idx]
                F_shift = (
                    F_k_list[idx]
                    + level_shift_b * S_k
                    - (level_shift_b / 2.0) * (S_k @ D_k_list[idx] @ S_k)
                )
                F_for_diag.append(0.5 * (F_shift + F_shift.conj().T))
        else:
            F_for_diag = F_k_list

        # ---- Diagonalise ---------------------------------------------
        new_C_per_k = []
        new_eps_per_k = []
        for idx in range(n_k):
            C_k, eps_k = _diag_in_orth_basis(F_for_diag[idx], X_k_list[idx])
            new_C_per_k.append(C_k)
            new_eps_per_k.append(eps_k)

        # ---- MOM: reorder so MOM-selected orbitals come first ---------
        if use_mom and C_prev_occ_per_k is not None:
            for idx in range(n_k):
                C_k = new_C_per_k[idx]
                eps_k = new_eps_per_k[idx]
                sel = _mom_select(
                    C_k, S_k_list[idx], C_prev_occ_per_k[idx], n_occ, eps_new=eps_k,
                )
                virt_mask = np.ones(C_k.shape[1], dtype=bool)
                virt_mask[sel] = False
                virt_sel = np.where(virt_mask)[0]
                virt_sel = virt_sel[np.argsort(np.real(eps_k[virt_sel]))]
                order = np.concatenate([sel, virt_sel])
                new_C_per_k[idx] = C_k[:, order]
                new_eps_per_k[idx] = eps_k[order]

        C_per_k = new_C_per_k
        eps_per_k = new_eps_per_k
        D_real_new = real_space_density_from_kpoints(
            C_per_k, [n_occ] * n_k, kmesh, cells,
        )

        # ---- ODA -----------------------------------------------------
        if use_oda:
            fock_naive = _build_fock_for_density(
                D_real_new, coeffs_for_rho=C_per_k,
            )
            oda_step = _compute_oda_lambda(
                D_used, D_real_new,
                F_k_list, fock_naive.f_k_list,
                [np.asarray(k) for k in k_points], weights,
                trust_lambda_max=oda_trust_lambda_max,
            )
            _oda_mix(D_used, D_real_new, oda_step.lam)
            D_real_prev = D_real
            D_real = D_used
            # Only a full step leaves D_real equal to the density of C_per_k.
            density_from_c_per_k = oda_step.lam == 1.0
            plog.info(f" ODA: λ = {oda_step.lam:.4f}")
        else:
            D_real_prev = D_used
            D_real = D_real_new
            density_from_c_per_k = True

        if use_mom:
            C_prev_occ_per_k = [
                np.asarray(C_per_k[idx][:, :n_occ]).copy() for idx in range(n_k)
            ]
        if damper is not None:
            damper.update(E_total)
        E_prev = E_total
        if converged:
            break

    plog.converged(n_iter=iter_idx, energy=E_total, converged=converged)

    # ---- Post-loop: recompute energy on final density for consistency
    e_hf_exchange = float(fock_build.e_exchange or 0.0)
    if converged:
        xc_final = _build_xc(D_real)
        final_build = _build_fock_for_density(
            D_real,
            # An ODA-mixed D_real is not the density of C_per_k.
            coeffs_for_rho=(C_per_k if density_from_c_per_k else None),
            xc_result=xc_final,
        )
        e_xc = float(xc_final.exc)
        E_kin_f = _lattice_contract(D_real, T_lat, operator_name="T")
        E_ne_f = _lattice_contract(D_real, V_lat, operator_name="V_ne")
        E_2e = 0.5 * _lattice_contract(
            D_real, final_build.f2e_real, operator_name="F2e",
        )
        E_elec = E_kin_f + E_ne_f + E_2e + e_xc
        # The fresh total must include the spheropole term as well.
        E_sphero_final = compute_ext_el_spheropole(D_real, basis, system, lat_opts)
        E_total = float(E_elec) + e_nuc + E_sphero_final
        e_hf_exchange = float(final_build.e_exchange or 0.0)
    else:
        E_sphero_final = energy_components[-1].e_ext_el_spheropole

    return PBCBipoleRKSResult(
        energy=float(E_total),
        e_electronic=float(E_elec),
        e_nuclear=e_nuc,
        e_ext_el_spheropole=E_sphero_final,
        e_xc=float(e_xc),
        e_coulomb=float(E_2e),
        e_hf_exchange=e_hf_exchange,
        n_iter=iter_idx,
        converged=converged,
        mo_energies=eps_per_k,
        mo_coeffs=C_per_k,
        fock=F_k_list,
        overlap=S_k_list,
        hcore=Hcore_k_list,
        density=D_real,
        scf_trace=scf_trace,
        ewald_alpha_bohr_inv=omega_used,
        energy_components=energy_components,
        functional=str(opts.functional),
    )