# Source code for vibeqc.pbc_bipole_uhf

"""BIPOLE-style periodic UHF driver in CRYSTAL's electrostatic gauge.

This is the open-shell counterpart of :mod:`vibeqc.pbc_bipole`. It keeps
the same CRYSTAL-inspired composition:

* ``V_ne`` and ``E_nn`` share one explicit 3D Ewald state.
* The default 3D two-electron build uses ``J_SR(α) + J_LR(α)`` for the
  Hartree operator with that same alpha, plus full-range per-spin
  exchange from the direct real-space builder.
* Energies are evaluated by real-space lattice contractions so the
  first local SAD/Hcore cycle has the same accounting convention as the
  RHF BIPOLE driver.

The full Saunders-Dovesi-Roetti far-pair multipole replacement is still
not wired; this driver is the spin-unrestricted Ewald-J bridge used for
CRYSTAL parity work while the native BIPOLE branch is being completed.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np

from ._vibeqc_core import (
    BasisSet,
    BlochKMesh,
    CoulombMethod,
    EwaldOptions,
    GridOptions,
    InitialGuess,
    LatticeMatrixSet,
    LatticeSumOptions,
    PeriodicRHFOptions,
    PeriodicSystem,
    SCFIteration,
    bloch_sum,
    build_fock_2e_real_space,
    build_jk_2e_real_space,
    compute_kinetic_lattice,
    compute_overlap_lattice,
    direct_lattice_cells,
    ewald_nuclear_repulsion,
    nuclear_repulsion_per_cell,
    real_space_density_from_kpoints_fractional,
)
from .bipole_ext_el_pole import compute_ext_el_spheropole
from .guess import initial_densities_open_shell, initial_density_closed_shell
from .level_shift_schedule import LevelShiftSchedule
from .mom import select_occupied_by_max_overlap as _mom_select
from .oda import compute_oda_lambda as _compute_oda_lambda
from .oda import oda_mix_densities as _oda_mix
from .pbc_bipole import (
    PBCBipoleEnergyComponents,
    _bloch_sum_blocks,
    _cell_key,
    _compute_nuclear_lattice_ewald_reciprocal_ft,
    _crystal_ewald_options,
    _default_bipole_v_ne_grid_options,
    _expand_ibz_kmesh_for_ewald_j,
    _lattice_contract,
    _lattice_contract_blocks,
)
from .periodic_rhf_multi_k_ewald import (
    _canonical_orthogonalizer_complex,
    _damp_lattice_matrix,
    _diag_in_orth_basis,
)
from .periodic_scf_accelerators import (
    DynamicDamping,
    MultiKPeriodicUHFAccelerator,
)
from .periodic_uhf_ewald import _spin_squared
from .periodic_v_ne import compute_nuclear_lattice_dispatch
from .progress import ProgressLogger, resolve_progress
from .scf_divergence import check_scf_divergence
from .smearing._support import reject_unsupported_smearing_temperature
from .symmetry_integrals_reduced import (
    compute_kinetic_lattice_reduced,
    compute_overlap_lattice_reduced,
)

# Public names exported by ``from vibeqc.pbc_bipole_uhf import *``.
__all__ = ["PBCBipoleUHFResult", "run_pbc_bipole_uhf"]


@dataclass
class PBCBipoleUHFResult:
    """Result of :func:`run_pbc_bipole_uhf`.

    Scalar energies and convergence data come first. They are followed by
    the per-spin MO / Fock / density data and then the spin-independent
    overlap and core-Hamiltonian matrices. The ``List[np.ndarray]`` fields
    are presumably indexed by k-point (TODO confirm against
    ``run_pbc_bipole_uhf``).
    """

    # Energy totals; presumably energy == e_electronic + e_nuclear (confirm).
    energy: float
    e_electronic: float
    e_nuclear: float
    # SCF bookkeeping.
    n_iter: int
    converged: bool
    # Computed <S^2> and its reference value. The reference is presumably
    # S(S+1) for the requested multiplicity (confirm in run_pbc_bipole_uhf).
    s_squared: float
    s_squared_ideal: float
    # Alpha-spin orbital energies/coefficients, Fock matrices, lattice density.
    mo_energies_alpha: List[np.ndarray]
    mo_coeffs_alpha: List[np.ndarray]
    fock_alpha: List[np.ndarray]
    density_alpha: LatticeMatrixSet
    # Beta-spin counterparts of the alpha fields above.
    mo_energies_beta: List[np.ndarray]
    mo_coeffs_beta: List[np.ndarray]
    fock_beta: List[np.ndarray]
    density_beta: LatticeMatrixSet
    # Spin-independent one-electron matrices.
    overlap: List[np.ndarray]
    hcore: List[np.ndarray]
    # Per-iteration SCF records; each instance gets its own fresh list.
    scf_trace: List[SCFIteration] = field(default_factory=list)
    # Optional diagnostics; None unless populated by the driver.
    e_ext_el_spheropole: Optional[float] = None
    # Ewald splitting parameter, in inverse bohr per the field name.
    ewald_alpha_bohr_inv: Optional[float] = None
    # Energy-component breakdowns; each instance gets its own fresh list.
    energy_components: List[PBCBipoleEnergyComponents] = field(
        default_factory=list,
    )


@dataclass
class _PBCBipoleUHFFockBuild:
    """Internal Fock bundle for one UHF density pair.

    Carries the per-spin real-space two-electron Fock contributions, the
    matching k-space Fock matrices, and the optional energy pieces reported
    by the builder.
    """

    # Real-space two-electron Fock contribution for each spin channel.
    f2e_alpha_real: LatticeMatrixSet
    f2e_beta_real: LatticeMatrixSet
    # k-space Fock matrices (one per k-point) for each spin channel.
    f_alpha_k_list: List[np.ndarray]
    f_beta_k_list: List[np.ndarray]
    # Optional energy decomposition; None when the builder does not provide it.
    e_j_short_range: Optional[float] = None
    e_j_long_range: Optional[float] = None
    e_exchange: Optional[float] = None
    e_j_multipole: Optional[float] = None


def _spin_occupations(system: PeriodicSystem) -> Tuple[int, int]:
    """Split the cell's electrons into ``(n_alpha, n_beta)`` from its multiplicity.

    Raises ``ValueError`` when the multiplicity is below 1, when the electron
    count and multiplicity have incompatible parity, or when the multiplicity
    would require a negative number of beta electrons.
    """
    electrons = int(system.n_electrons())
    multiplicity = int(system.multiplicity)
    if multiplicity < 1:
        raise ValueError(
            f"run_pbc_bipole_uhf: multiplicity must be >= 1; got {multiplicity}"
        )
    unpaired = multiplicity - 1
    twice_alpha = electrons + unpaired
    twice_beta = electrons - unpaired
    if twice_alpha % 2 or twice_beta % 2:
        raise ValueError(
            f"run_pbc_bipole_uhf: (n_electrons={electrons}, "
            f"multiplicity={multiplicity}) cannot be split into integer alpha/beta."
        )
    n_alpha, n_beta = twice_alpha // 2, twice_beta // 2
    if n_beta < 0:
        raise ValueError(
            f"run_pbc_bipole_uhf: multiplicity={multiplicity} is too large for "
            f"{electrons} electrons"
        )
    return n_alpha, n_beta


def _float_blocks_by_cell_key(cells, blocks) -> dict:
    """Map ``_cell_key(cell)`` to the matching block as a float ndarray."""
    return {
        _cell_key(cell): np.asarray(block, dtype=float)
        for cell, block in zip(cells, blocks)
    }


def _combine_density_sets(
    basis: BasisSet,
    system: PeriodicSystem,
    lat_opts: LatticeSumOptions,
    D_alpha: LatticeMatrixSet,
    D_beta: LatticeMatrixSet,
) -> LatticeMatrixSet:
    """Return a fresh ``D_alpha + D_beta`` lattice set.

    The container comes from ``compute_overlap_lattice`` (its blocks are then
    overwritten), so the result uses the cell list generated for ``lat_opts``.
    Raises ``ValueError`` if either spin density lacks a block for one of
    those cells.
    """
    total = compute_overlap_lattice(basis, system, lat_opts)
    alpha_by_key = _float_blocks_by_cell_key(D_alpha.cells, D_alpha.blocks)
    beta_by_key = _float_blocks_by_cell_key(D_beta.cells, D_beta.blocks)
    for position, cell in enumerate(total.cells):
        key = _cell_key(cell)
        if not (key in alpha_by_key and key in beta_by_key):
            raise ValueError(
                f"_combine_density_sets: missing spin density block for cell {key}"
            )
        total.set_block(position, alpha_by_key[key] + beta_by_key[key])
    return total


def _copy_lattice_with_blocks(
    basis: BasisSet,
    system: PeriodicSystem,
    lat_opts: LatticeSumOptions,
    cells,
    blocks: Sequence[np.ndarray],
) -> LatticeMatrixSet:
    """Build a fresh lattice matrix set with the given cell-indexed blocks.

    ``cells`` and ``blocks`` are paired positionally; the output follows the
    cell order of a ``compute_overlap_lattice`` container for ``lat_opts``.
    Raises ``ValueError`` if a cell of that container has no supplied block.
    """
    target = compute_overlap_lattice(basis, system, lat_opts)
    supplied = _float_blocks_by_cell_key(cells, blocks)
    for position, cell in enumerate(target.cells):
        key = _cell_key(cell)
        if key not in supplied:
            raise ValueError(f"_copy_lattice_with_blocks: missing block for cell {key}")
        target.set_block(position, supplied[key])
    return target
def run_pbc_bipole_uhf(
    system: PeriodicSystem,
    basis: BasisSet,
    kmesh: BlochKMesh,
    options: Optional[PeriodicRHFOptions] = None,
    *,
    linear_dep_threshold: float = 1e-7,
    canonical_orth_normalize_diag_first: bool = True,
    level_shift_schedule: Optional[LevelShiftSchedule] = None,
    use_mom: bool = False,
    use_oda: bool = False,
    oda_trust_lambda_max: float = 1.0,
    use_ewald_j_split: Optional[bool] = None,
    ewald_omega: Optional[float] = None,
    ewald_precision: float = 1e-8,
    v_ne_grid_options: Optional[GridOptions] = None,
    use_multipole_far_field: Optional[bool] = None,
    multipole_l_max: int = 2,
    progress: Union[bool, ProgressLogger, None] = None,
    verbose: Optional[int] = None,
    init_alpha: Optional[Sequence[np.ndarray]] = None,
    init_beta: Optional[Sequence[np.ndarray]] = None,
) -> PBCBipoleUHFResult:
    """Multi-k open-shell UHF via the CRYSTAL-gauge BIPOLE scaffold.

    In 3D, ``V_ne`` and ``E_nn`` share one Ewald state whose alpha is
    ``ewald_omega`` or ``crystal_default_ewald_alpha(V_cell)``. With the
    Ewald-J split (the 3D default) the Hartree operator is
    ``J_SR(alpha) + J_LR(alpha)`` plus a uniform-background shift
    proportional to the overlap, and exchange is full-range per spin from
    the direct real-space builder. Without the split, J and K both come from
    the direct truncated real-space builder. Energies are lattice
    contractions of the real-space densities with T, V_ne and the per-spin
    two-electron Fock operators, plus E_nuc and the EXT EL-SPHEROPOLE term.

    Args:
        system: Periodic system; ``dim``, ``lattice``, ``multiplicity`` and
            ``n_electrons()`` are read.
        basis: Orbital basis.
        kmesh: k-point mesh; weights must sum to 1. An ``ir_mapping``
            attribute marks an IBZ-reduced mesh that is expanded to the full
            mesh for the Ewald-J long-range density.
        options: SCF options (lattice sums, damping, DIIS, level shift,
            initial guess, ``max_iter``, convergence tolerances). A default
            ``PeriodicRHFOptions()`` is used when None.
        linear_dep_threshold: Threshold for canonical orthogonalisation.
        canonical_orth_normalize_diag_first: Forwarded to the canonical
            orthogonaliser.
        level_shift_schedule: Per-iteration level shift; overrides
            ``options.level_shift`` when given.
        use_mom: Reorder orbitals by maximum overlap with the previous
            occupied space (from iteration 2 on).
        use_oda: Optimal-damping density mixing (one extra Fock build per
            iteration); mutually exclusive with DIIS.
        oda_trust_lambda_max: Upper bound on the ODA mixing step, in (0, 1].
        use_ewald_j_split: Use ``J_SR + J_LR``; None selects it for dim == 3.
        ewald_omega: Ewald alpha in bohr^-1 (3D only).
        ewald_precision: Ewald tolerance.
        v_ne_grid_options: When given, V_ne goes through
            ``compute_nuclear_lattice_dispatch`` with this grid instead of the
            Ewald reciprocal-FT path.
        use_multipole_far_field: Forwarded to ``resolve_multipole_config``.
        multipole_l_max: Multipole order for the far-field replacement.
        progress: Progress logger selection.
        verbose: Verbosity forwarded to ``resolve_progress``.
        init_alpha: Warm-start alpha density blocks, one per lattice cell.
        init_beta: Warm-start beta density blocks, one per lattice cell.

    Returns:
        PBCBipoleUHFResult with the final energies, orbitals, Fock matrices,
        densities and per-iteration diagnostics.

    Raises:
        ValueError: Invalid multiplicity/electron split, ``max_iter < 1``,
            empty k-mesh or weights not summing to 1, non-uniform multi-k
            weights with the Ewald-J split and no ``ir_mapping``, damping
            outside [0, 1), partial or mis-sized ``init_alpha``/``init_beta``,
            ``use_oda`` together with DIIS, ``oda_trust_lambda_max`` outside
            (0, 1], or the Ewald-J split requested with ``dim != 3``.
        RuntimeError: Canonical orthogonalisation keeps fewer directions than
            ``max(n_alpha, n_beta)`` at some k, or a prebuilt V_ne/J^LR cache
            has a different cell count from ``S_lat``.
        TypeError: ``level_shift_schedule`` is not a ``LevelShiftSchedule``.
    """
    opts = options if options is not None else PeriodicRHFOptions()
    reject_unsupported_smearing_temperature(
        opts,
        "run_pbc_bipole_uhf",
        detail=(
            "BIPOLE smearing is queued for a later smearing milestone; "
            "this driver would otherwise run integer spin occupations."
        ),
    )
    # The post-loop code reads per-iteration values (E_total, energy
    # components), so at least one SCF iteration must run.
    max_iter = int(opts.max_iter)
    if max_iter < 1:
        raise ValueError(
            f"run_pbc_bipole_uhf: options.max_iter must be >= 1; got {max_iter}"
        )
    lat_opts: LatticeSumOptions = opts.lattice_opts
    plog = resolve_progress(progress, verbose=verbose)
    use_ewald_j_split_auto = use_ewald_j_split is None
    use_ewald_j_split = (
        system.dim == 3 if use_ewald_j_split_auto else bool(use_ewald_j_split)
    )
    # Two-electron lattice sums are always direct-truncated; the one-electron
    # (V_ne / E_nn) sums use Ewald in 3D.
    lat_opts_2e = LatticeSumOptions()
    lat_opts_2e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_2e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_2e.coulomb_method = CoulombMethod.DIRECT_TRUNCATED
    lat_opts_1e = LatticeSumOptions()
    lat_opts_1e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_1e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_1e.coulomb_method = (
        CoulombMethod.EWALD_3D if system.dim == 3 else CoulombMethod.DIRECT_TRUNCATED
    )
    n_elec = int(system.n_electrons())
    n_alpha, n_beta = _spin_occupations(system)
    mult = int(system.multiplicity)
    _kmesh_ibz = kmesh
    _ir_mapping = np.asarray(getattr(kmesh, "ir_mapping", []), dtype=int).reshape(-1)
    k_points = list(_kmesh_ibz.kpoints)
    weights = np.asarray(_kmesh_ibz.weights, dtype=float)
    if use_ewald_j_split and _ir_mapping.size > 0:
        # J^LR needs the density on the full mesh, not just the IBZ.
        kmesh_full = _expand_ibz_kmesh_for_ewald_j(system, kmesh, plog)
        k_points_full = list(kmesh_full.kpoints)
        weights_full = np.asarray(kmesh_full.weights, dtype=float)
    else:
        k_points_full = k_points
        weights_full = weights
    n_k = len(k_points)
    if n_k == 0:
        raise ValueError("kmesh has no k-points")
    if not np.isclose(weights.sum(), 1.0):
        raise ValueError(f"kmesh.weights must sum to 1; got {weights.sum():.6f}")
    if use_ewald_j_split and n_k > 1 and _ir_mapping.size == 0:
        uniform_w = 1.0 / float(n_k)
        if not np.allclose(weights, uniform_w, atol=1e-9):
            raise ValueError(
                "use_ewald_j_split at multi-k requires uniform full-mesh "
                "weights or an IBZ-reduced Monkhorst-Pack mesh carrying "
                "ir_mapping metadata so the driver can expand it. "
                f"Got non-uniform weights = {weights.tolist()}."
            )
    plog.info(
        f"PBC BIPOLE UHF (CRYSTAL-gauge) / cutoff {lat_opts.cutoff_bohr:.2f} bohr"
    )
    plog.info(f" n_alpha = {n_alpha}, n_beta = {n_beta}, multiplicity = {mult}")
    plog.info(
        f" F^2e (J + K) : "
        f"{'EWALD_J_SPLIT' if use_ewald_j_split else lat_opts_2e.coulomb_method.name}"
        f"{' (auto)' if use_ewald_j_split_auto else ''}"
    )
    plog.info(
        f"k-mesh: {n_k} k-point{'s' if n_k != 1 else ''}, "
        f"weights sum = {weights.sum():.4f}"
    )

    # ---- Shared 3D Ewald state (V_ne, E_nn and, optionally, J^LR) ----------
    ewald_options_1e: Optional[EwaldOptions] = None
    omega_used: Optional[float] = None
    ewald_cell_volume: Optional[float] = None
    ewald_k_max: Optional[float] = None
    if system.dim == 3:
        from .bipole_ext_el_pole import (
            crystal_default_ewald_alpha,
            crystal_ewald_reciprocal_cutoff,
        )

        V_cell = float(
            abs(
                np.linalg.det(np.asarray(system.lattice, dtype=float)),
            )
        )
        ewald_cell_volume = V_cell
        omega_used = (
            float(ewald_omega)
            if ewald_omega is not None
            else crystal_default_ewald_alpha(V_cell)
        )
        ewald_k_max = crystal_ewald_reciprocal_cutoff(V_cell)
        ewald_options_1e = _crystal_ewald_options(
            lat_opts_1e,
            alpha_bohr_inv=omega_used,
            tolerance=float(ewald_precision),
            recip_cutoff_bohr_inv=ewald_k_max,
        )
        plog.info(
            f" Ewald state: alpha = {omega_used:.6f} bohr^-1, "
            f"real_cutoff = {lat_opts_1e.nuclear_cutoff_bohr:.2f} bohr, "
            f"K_max = {ewald_k_max:.2f} bohr^-1, "
            f"tol = {float(ewald_precision):.0e}"
        )

    # ---- One-electron lattice integrals -------------------------------------
    with plog.stage(
        "integrals_lattice",
        detail=f"S/T/V at cutoff {lat_opts.cutoff_bohr:.2f} bohr",
    ):
        _use_sym = False  # FIXME: LatticeMatrixSet has no Python constructor
        if _use_sym:
            ops = system.symmetry.operations
            _, S_blocks = compute_overlap_lattice_reduced(
                basis, system, lat_opts_2e, ops,
            )
            S_lat = LatticeMatrixSet()
            S_lat.nbf = basis.nbasis
            S_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            S_lat.blocks = S_blocks
            _, T_blocks = compute_kinetic_lattice_reduced(
                basis, system, lat_opts_2e, ops,
            )
            T_lat = LatticeMatrixSet()
            T_lat.nbf = basis.nbasis
            T_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            T_lat.blocks = T_blocks
        else:
            S_lat = compute_overlap_lattice(basis, system, lat_opts_2e)
            T_lat = compute_kinetic_lattice(basis, system, lat_opts_2e)
        v_ne_lr_cache = None
        if (
            system.dim == 3
            and ewald_options_1e is not None
            and v_ne_grid_options is None
        ):
            # The reciprocal-FT V_ne path also yields a long-range cache that
            # the J^LR builder can reuse below.
            V_lat, v_ne_lr_cache = _compute_nuclear_lattice_ewald_reciprocal_ft(
                basis,
                system,
                lat_opts_1e,
                ewald_options_1e,
                S_lat,
                precision=ewald_precision,
                K_max=ewald_k_max,
            )
        else:
            v_ne_grid = (
                v_ne_grid_options
                if v_ne_grid_options is not None
                else (_default_bipole_v_ne_grid_options() if system.dim == 3 else None)
            )
            V_lat = compute_nuclear_lattice_dispatch(
                basis,
                system,
                lat_opts_1e,
                grid_options=v_ne_grid,
                ewald_options=ewald_options_1e,
            )
    cells = list(S_lat.cells)
    plog.info(f"n_cells in lattice sum = {len(cells)}")
    from .linear_dependence import scf_preflight_overlap_check

    # ---- Per-k S, H_core and orthogonalisers --------------------------------
    S_k_list: List[np.ndarray] = []
    Hcore_k_list: List[np.ndarray] = []
    X_k_list: List[np.ndarray] = []
    for k_idx, k in enumerate(k_points):
        k_arr = np.asarray(k, dtype=float).reshape(3)
        S_k = np.asarray(bloch_sum(S_lat, k_arr))
        T_k = np.asarray(bloch_sum(T_lat, k_arr))
        V_k = np.asarray(bloch_sum(V_lat, k_arr))
        H_k = T_k + V_k
        S_k = 0.5 * (S_k + S_k.conj().T)
        H_k = 0.5 * (H_k + H_k.conj().T)
        scf_preflight_overlap_check(
            S_k,
            plog=plog,
            label=f"S(k={k_idx}, k_cart={k_arr.round(4).tolist()})",
            basis=basis,
        )
        X_k, n_kept = _canonical_orthogonalizer_complex(
            S_k,
            linear_dep_threshold,
            normalize_diag_first=canonical_orth_normalize_diag_first,
        )
        if max(n_alpha, n_beta) > n_kept:
            raise RuntimeError(
                f"run_pbc_bipole_uhf: canonical orth at k={k_idx} "
                f"dropped too many directions (n_alpha={n_alpha}, "
                f"n_beta={n_beta}, n_kept={n_kept})"
            )
        S_k_list.append(S_k)
        Hcore_k_list.append(H_k)
        X_k_list.append(X_k)
    if ewald_options_1e is not None:
        e_nuc = float(ewald_nuclear_repulsion(system, ewald_options_1e))
    else:
        e_nuc = float(nuclear_repulsion_per_cell(system, lat_opts_1e))
    plog.info(f"E_nuc per cell ({lat_opts_1e.coulomb_method.name}) = {e_nuc:+.10f} Ha")

    # ---- Initial orbitals (H_core diagonalisation per k, both spins) --------
    C_alpha_per_k: List[np.ndarray] = []
    eps_alpha_per_k: List[np.ndarray] = []
    C_beta_per_k: List[np.ndarray] = []
    eps_beta_per_k: List[np.ndarray] = []
    for H_k, X_k in zip(Hcore_k_list, X_k_list):
        C_a, eps_a = _diag_in_orth_basis(H_k, X_k)
        C_b, eps_b = _diag_in_orth_basis(H_k, X_k)
        C_alpha_per_k.append(C_a.astype(complex))
        eps_alpha_per_k.append(eps_a)
        C_beta_per_k.append(C_b.astype(complex))
        eps_beta_per_k.append(eps_b)

    def _spin_density(
        C_per_k_local: Sequence[np.ndarray],
        n_occ_each: int,
    ) -> LatticeMatrixSet:
        """Real-space spin density with the lowest ``n_occ_each`` MOs filled at every k."""
        n_mo = C_per_k_local[0].shape[1]
        occ_per_k = []
        for _ in range(n_k):
            occ = np.zeros(n_mo, dtype=float)
            occ[:n_occ_each] = 1.0
            occ_per_k.append(occ)
        return real_space_density_from_kpoints_fractional(
            C_per_k_local, occ_per_k, kmesh, cells,
        )

    D_alpha_real = _spin_density(C_alpha_per_k, n_alpha)
    D_beta_real = _spin_density(C_beta_per_k, n_beta)

    # Caller-supplied warm-start spin densities take precedence over
    # the SAD/Hcore guess engine. Both init_alpha and init_beta must
    # be provided together (or both None). Block ordering matches the
    # canonical ``direct_lattice_cells(kmesh)`` ordering — same
    # contract as the RHF driver. Used by the NEB driver for
    # within-image density warm-start (HANDOVER_PERIODIC_NEB.md M4).
    if (init_alpha is not None) != (init_beta is not None):
        raise ValueError(
            "run_pbc_bipole_uhf: init_alpha and init_beta must be "
            "provided together (both None or both populated)"
        )
    if init_alpha is not None and init_beta is not None:
        blocks_a = list(init_alpha)
        blocks_b = list(init_beta)
        if len(blocks_a) != len(D_alpha_real.cells):
            raise ValueError(
                f"run_pbc_bipole_uhf: init_alpha has {len(blocks_a)} "
                f"blocks; expected {len(D_alpha_real.cells)}"
            )
        if len(blocks_b) != len(D_beta_real.cells):
            raise ValueError(
                f"run_pbc_bipole_uhf: init_beta has {len(blocks_b)} "
                f"blocks; expected {len(D_beta_real.cells)}"
            )
        for g_idx, (ba, bb) in enumerate(zip(blocks_a, blocks_b)):
            D_alpha_real.set_block(g_idx, np.asarray(ba, dtype=float))
            D_beta_real.set_block(g_idx, np.asarray(bb, dtype=float))
        plog.info(
            "initial guess: caller-supplied spin densities (warm-start)"
        )
        initial_density_is_local = True
        density_from_c_per_k = False
    else:
        guess = getattr(opts, "initial_guess", InitialGuess.HCORE)
        D_guess = None
        if n_elec % 2 == 0:
            # CRYSTAL's UHF PATIRR/SAD starts from the total atomic SAD with
            # zero summed spin density, then the requested spin state is
            # enforced by the alpha/beta occupations after the first Fock
            # diagonalisation. Mirror that convention for CYC0 parity.
            D_total_guess = initial_density_closed_shell(
                system.unit_cell_molecule(),
                basis,
                n_elec // 2,
                guess,
                is_periodic=True,
            )
            if D_total_guess is not None:
                D_guess = (
                    0.5 * np.asarray(D_total_guess, dtype=float),
                    0.5 * np.asarray(D_total_guess, dtype=float),
                )
        if D_guess is None:
            D_guess = initial_densities_open_shell(
                system.unit_cell_molecule(),
                basis,
                n_alpha,
                n_beta,
                guess,
                is_periodic=True,
            )
        initial_density_is_local = D_guess is not None
        if D_guess is not None:
            plog.info(f"initial guess: {guess.name} (g=0 spin densities)")
            D_a0, D_b0 = D_guess
            zero_a = np.zeros_like(D_a0, dtype=float)
            zero_b = np.zeros_like(D_b0, dtype=float)
            # Place the molecular guess in the home cell only.
            for g_idx in range(len(D_alpha_real.cells)):
                is_g0 = (
                    np.asarray(D_alpha_real.cells[g_idx].index, dtype=int)
                    == np.array([0, 0, 0])
                ).all()
                D_alpha_real.set_block(g_idx, D_a0 if is_g0 else zero_a)
                D_beta_real.set_block(g_idx, D_b0 if is_g0 else zero_b)
        else:
            plog.info(f"initial guess: {guess.name} (Hcore-diag per k)")
        density_from_c_per_k = not initial_density_is_local

    # ---- Convergence accelerators -------------------------------------------
    D_alpha_prev: Optional[LatticeMatrixSet] = None
    D_beta_prev: Optional[LatticeMatrixSet] = None
    damping = float(opts.damping)
    if not (0.0 <= damping < 1.0):
        raise ValueError(f"run_pbc_bipole_uhf: damping must be in [0,1); got {damping}")
    damper: Optional[DynamicDamping] = None
    if bool(getattr(opts, "dynamic_damping", False)):
        damper = DynamicDamping(
            initial_alpha=damping,
            alpha_min=float(getattr(opts, "dynamic_damping_min", 0.0)),
            alpha_max=float(getattr(opts, "dynamic_damping_max", 0.95)),
        )
    use_diis = bool(opts.use_diis)
    diis_start_iter = int(opts.diis_start_iter)
    accel: Optional[MultiKPeriodicUHFAccelerator] = (
        MultiKPeriodicUHFAccelerator(opts) if use_diis else None
    )
    level_shift_static = float(getattr(opts, "level_shift", 0.0))
    if level_shift_schedule is not None and not isinstance(
        level_shift_schedule, LevelShiftSchedule,
    ):
        raise TypeError(
            f"level_shift_schedule must be a LevelShiftSchedule or None; "
            f"got {type(level_shift_schedule).__name__}"
        )
    if level_shift_schedule is not None:
        plog.info(f"level_shift_schedule: {level_shift_schedule.as_list()}")
    if use_mom:
        plog.info("MOM (Maximum Overlap Method): ON")
    C_prev_occ_alpha_per_k = None
    C_prev_occ_beta_per_k = None
    if use_oda and use_diis:
        raise ValueError(
            "run_pbc_bipole_uhf: use_oda and use_diis are mutually exclusive"
        )
    if use_oda:
        if not (0.0 < oda_trust_lambda_max <= 1.0):
            raise ValueError(
                f"oda_trust_lambda_max must be in (0, 1]; got {oda_trust_lambda_max}"
            )
        plog.info(
            f"ODA (Optimal Damping): ON (+1 Fock build/iter, "
            f"trust lambda_max = {oda_trust_lambda_max})"
        )

    # ---- J^LR cache (reuses the V_ne reciprocal cache when available) ------
    j_lr_cache = v_ne_lr_cache
    if use_ewald_j_split:
        if system.dim != 3:
            raise ValueError(
                f"use_ewald_j_split requires dim=3 (3D periodic). Got dim={system.dim}."
            )
        from .bipole_fock_ewald import _build_j_long_range_cache

        assert omega_used is not None
        cells_r_cart_arr = np.array(
            [np.asarray(c.r_cart, dtype=float) for c in cells],
            dtype=float,
        )
        if j_lr_cache is None:
            j_lr_cache = _build_j_long_range_cache(
                basis,
                system,
                cells_r_cart_arr,
                omega_used,
                ewald_precision,
                K_max=ewald_k_max,
            )
        elif j_lr_cache.ft_per_cell.shape[0] != len(cells):
            raise RuntimeError(
                "prebuilt V_ne/J^LR cache has a different cell count "
                f"({j_lr_cache.ft_per_cell.shape[0]}) from S_lat "
                f"({len(cells)})"
            )
        plog.info(
            f" J^LR cache: {j_lr_cache.K_vectors.shape[0]} K-vectors, "
            f"{j_lr_cache.ft_per_cell.shape[0]} lattice cells"
        )

    # S_lat is fixed for the whole run, so the overlap map used by the
    # uniform-background shift of J^LR is built once instead of per Fock build.
    s_blocks_by_cell = {
        _cell_key(cell): np.asarray(block, dtype=float)
        for cell, block in zip(S_lat.cells, S_lat.blocks)
    }

    # ---- Multipole far-field config (resolved before any Fock build) -------
    from .bipole_fock_multipole import (  # noqa: E402
        BipoleMultipoleConfig,
        resolve_multipole_config,
    )

    _mp_config: BipoleMultipoleConfig = resolve_multipole_config(
        system,
        basis,
        lat_opts_2e,
        user_enable=use_multipole_far_field,
        multipole_l_max=multipole_l_max,
    )

    def _spin_d_matrices_from_coeffs(
        coeffs_per_k: Sequence[np.ndarray],
        n_occ: int,
    ) -> List[np.ndarray]:
        """Per-k spin density matrices ``C_occ C_occ^H`` with unit occupations."""
        out: List[np.ndarray] = []
        for C_k_raw in coeffs_per_k:
            C_k = np.asarray(C_k_raw)
            C_occ = C_k[:, :n_occ] if n_occ > 0 else C_k[:, :0]
            out.append(C_occ @ C_occ.conj().T)
        return out

    def _build_fock_for_density(
        D_alpha: LatticeMatrixSet,
        D_beta: LatticeMatrixSet,
        *,
        coeffs_alpha_for_rho: Optional[Sequence[np.ndarray]],
        coeffs_beta_for_rho: Optional[Sequence[np.ndarray]],
    ) -> _PBCBipoleUHFFockBuild:
        """Build per-spin real-space and k-space Fock operators for one density pair.

        When both coefficient lists are given (Ewald-J split only), the J^LR
        density is computed from them in k-space; callers must only pass
        them when they reproduce ``D_alpha``/``D_beta``.
        """
        D_total = _combine_density_sets(
            basis, system, lat_opts_2e, D_alpha, D_beta,
        )
        # Bound on both branches: only the Ewald-J path can replace J with
        # the multipole far field. Leaving this inside that branch made the
        # non-split path fail with UnboundLocalError at the return below.
        e_j_multipole: Optional[float] = None
        if use_ewald_j_split:
            F_J_SR_lat = build_fock_2e_real_space(
                basis, system, lat_opts_2e, D_total, 0.0, float(omega_used),
            )
            jk_alpha = build_jk_2e_real_space(
                basis, system, lat_opts_2e, D_alpha, 0.0,
            )
            jk_beta = build_jk_2e_real_space(
                basis, system, lat_opts_2e, D_beta, 0.0,
            )
            K_alpha_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_alpha.K.blocks
            ]
            K_beta_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_beta.K.blocks
            ]
            rho_hat_for_LR = None
            if coeffs_alpha_for_rho is not None and coeffs_beta_for_rho is not None:
                from .bipole_fock_ewald import compute_rho_hat_from_k_density

                D_a_k = _spin_d_matrices_from_coeffs(
                    coeffs_alpha_for_rho, n_alpha,
                )
                D_b_k = _spin_d_matrices_from_coeffs(
                    coeffs_beta_for_rho, n_beta,
                )
                D_total_k = [Da + Db for Da, Db in zip(D_a_k, D_b_k)]
                if _ir_mapping.size > 0:
                    # Unfold IBZ densities onto the full mesh for rho_hat.
                    D_total_k_full = [D_total_k[int(idx)] for idx in _ir_mapping]
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_total_k_full, k_points_full, weights_full, j_lr_cache,
                    )
                else:
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_total_k, k_points, weights, j_lr_cache,
                    )
            from .bipole_fock_ewald import (
                compute_J_long_range_real_space_blocks,
            )

            F_LR_blocks = compute_J_long_range_real_space_blocks(
                D_total,
                basis,
                system,
                omega_used,
                precision=ewald_precision,
                cache=j_lr_cache,
                rho_hat=rho_hat_for_LR,
            )
            assert ewald_cell_volume is not None
            j_background_potential = (
                -np.pi
                * float(n_elec)
                / (float(omega_used) * float(omega_used) * float(ewald_cell_volume))
            )
            for c, cell in enumerate(F_J_SR_lat.cells):
                key = _cell_key(cell)
                F_LR_blocks[c] = (
                    F_LR_blocks[c] + j_background_potential * s_blocks_by_cell[key]
                )
            # Copy J_SR before F_J_SR_lat is overwritten with the alpha Fock.
            j_sr_blocks = [
                np.asarray(block, dtype=float).copy() for block in F_J_SR_lat.blocks
            ]
            alpha_blocks = [
                j_sr + j_lr - k_a
                for j_sr, j_lr, k_a in zip(j_sr_blocks, F_LR_blocks, K_alpha_blocks)
            ]
            beta_blocks = [
                j_sr + j_lr - k_b
                for j_sr, j_lr, k_b in zip(j_sr_blocks, F_LR_blocks, K_beta_blocks)
            ]
            f2e_alpha_real = F_J_SR_lat
            for c, block in enumerate(alpha_blocks):
                f2e_alpha_real.set_block(c, block)
            f2e_beta_real = _copy_lattice_with_blocks(
                basis, system, lat_opts_2e, F_J_SR_lat.cells, beta_blocks,
            )
            # ---- Optional: replace J_SR+J_LR with multipole far-field J --
            if _mp_config.enabled:
                from .bipole_fock_multipole import apply_multipole_far_field

                # Build multipole J from total density; K is spin-specific.
                mp_result = apply_multipole_far_field(
                    D_total,
                    basis,
                    system,
                    lat_opts_2e,
                    _mp_config,
                    J_SR_blocks=j_sr_blocks,
                    K_blocks=None,
                    F_LR_blocks=F_LR_blocks,
                    exchange_scale=0.0,
                )
                for c in range(len(F_J_SR_lat.cells)):
                    f2e_alpha_real.set_block(
                        c, mp_result.f2e_blocks[c] - K_alpha_blocks[c],
                    )
                    f2e_beta_real.set_block(
                        c, mp_result.f2e_blocks[c] - K_beta_blocks[c],
                    )
                e_j_multipole = mp_result.e_j_multipole
                if mp_result.n_far_cells > 0:
                    plog.info(
                        f" BIPOLE multipole far-field (L_max={_mp_config.L_max}, "
                        f"R={_mp_config.R_bipole:.1f} bohr): "
                        f"{mp_result.n_far_cells}/{mp_result.n_total_cells} cells replaced, "
                        f"E_J_far = {e_j_multipole:+.6f} Ha"
                    )
            e_j_short_range = 0.5 * _lattice_contract_blocks(
                D_total, F_J_SR_lat.cells, j_sr_blocks, operator_name="J_SR",
            )
            e_j_long_range = 0.5 * _lattice_contract_blocks(
                D_total, F_J_SR_lat.cells, F_LR_blocks, operator_name="J_LR",
            )
            e_exchange = -0.5 * (
                _lattice_contract_blocks(
                    D_alpha, F_J_SR_lat.cells, K_alpha_blocks, operator_name="K_alpha",
                )
                + _lattice_contract_blocks(
                    D_beta, F_J_SR_lat.cells, K_beta_blocks, operator_name="K_beta",
                )
            )
        else:
            F_J_lat = build_fock_2e_real_space(
                basis, system, lat_opts_2e, D_total, 0.0, 0.0,
            )
            jk_alpha = build_jk_2e_real_space(
                basis, system, lat_opts_2e, D_alpha, 0.0,
            )
            jk_beta = build_jk_2e_real_space(
                basis, system, lat_opts_2e, D_beta, 0.0,
            )
            # Copy J before F_J_lat is overwritten with the alpha Fock.
            j_blocks = [
                np.asarray(block, dtype=float).copy() for block in F_J_lat.blocks
            ]
            K_alpha_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_alpha.K.blocks
            ]
            K_beta_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_beta.K.blocks
            ]
            alpha_blocks = [j - k for j, k in zip(j_blocks, K_alpha_blocks)]
            beta_blocks = [j - k for j, k in zip(j_blocks, K_beta_blocks)]
            f2e_alpha_real = F_J_lat
            for c, block in enumerate(alpha_blocks):
                f2e_alpha_real.set_block(c, block)
            f2e_beta_real = _copy_lattice_with_blocks(
                basis, system, lat_opts_2e, F_J_lat.cells, beta_blocks,
            )
            e_j_short_range = 0.5 * _lattice_contract_blocks(
                D_total, F_J_lat.cells, j_blocks, operator_name="J",
            )
            e_j_long_range = None
            e_exchange = -0.5 * (
                _lattice_contract_blocks(
                    D_alpha, F_J_lat.cells, K_alpha_blocks, operator_name="K_alpha",
                )
                + _lattice_contract_blocks(
                    D_beta, F_J_lat.cells, K_beta_blocks, operator_name="K_beta",
                )
            )
        f_alpha_k_list: List[np.ndarray] = []
        f_beta_k_list: List[np.ndarray] = []
        for k_idx, k in enumerate(k_points):
            k_arr = np.asarray(k, dtype=float)
            F_a_2e = _bloch_sum_blocks(
                f2e_alpha_real.blocks, f2e_alpha_real.cells, k_arr,
            )
            F_b_2e = _bloch_sum_blocks(
                f2e_beta_real.blocks, f2e_beta_real.cells, k_arr,
            )
            F_a = F_a_2e + np.asarray(Hcore_k_list[k_idx], dtype=complex)
            F_b = F_b_2e + np.asarray(Hcore_k_list[k_idx], dtype=complex)
            f_alpha_k_list.append(0.5 * (F_a + F_a.conj().T))
            f_beta_k_list.append(0.5 * (F_b + F_b.conj().T))
        return _PBCBipoleUHFFockBuild(
            f2e_alpha_real=f2e_alpha_real,
            f2e_beta_real=f2e_beta_real,
            f_alpha_k_list=f_alpha_k_list,
            f_beta_k_list=f_beta_k_list,
            e_j_short_range=e_j_short_range,
            e_j_long_range=e_j_long_range,
            e_exchange=e_exchange,
            e_j_multipole=e_j_multipole,
        )

    if _mp_config.enabled:
        plog.info(
            f" BIPOLE multipole far-field: ENABLED "
            f"(L_max={_mp_config.L_max}, R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"n_cells={len(_mp_config.cache.cells) if _mp_config.cache else 0})"
        )
    else:
        plog.info(
            f" BIPOLE multipole far-field: off "
            f"(R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"cutoff={lat_opts_2e.cutoff_bohr:.1f} bohr)"
        )

    # ---- SCF loop ------------------------------------------------------------
    plog.banner("SCF (PBC BIPOLE UHF, direct-space)")
    plog.info(" iter energy (Ha) dE ||[F,DS]|| DIIS")
    scf_trace: List[SCFIteration] = []
    energy_components: List[PBCBipoleEnergyComponents] = []
    E_prev = 0.0
    E_elec = 0.0
    F_alpha_k_list: List[np.ndarray] = [np.zeros_like(H) for H in Hcore_k_list]
    F_beta_k_list: List[np.ndarray] = [np.zeros_like(H) for H in Hcore_k_list]
    converged = False
    iter_idx = 0
    for iter_idx in range(1, max_iter + 1):
        if damper is not None:
            damping = damper.alpha
        diis_active = use_diis and iter_idx >= diis_start_iter
        d_used_is_damped = iter_idx > 1 and damping > 0.0 and not diis_active
        if d_used_is_damped:
            D_alpha_used = _damp_lattice_matrix(
                D_alpha_real, D_alpha_prev, damping,
            )
            D_beta_used = _damp_lattice_matrix(
                D_beta_real, D_beta_prev, damping,
            )
        else:
            D_alpha_used = D_alpha_real
            D_beta_used = D_beta_real
        # The coefficient route to rho_hat is valid only when the density in
        # use is exactly the one implied by the current MO coefficients.
        d_used_from_coeffs = (
            density_from_c_per_k
            and not (initial_density_is_local and iter_idx == 1)
            and not d_used_is_damped
        )
        fock_build = _build_fock_for_density(
            D_alpha_used,
            D_beta_used,
            coeffs_alpha_for_rho=(C_alpha_per_k if d_used_from_coeffs else None),
            coeffs_beta_for_rho=(C_beta_per_k if d_used_from_coeffs else None),
        )
        F_alpha_k_list = fock_build.f_alpha_k_list
        F_beta_k_list = fock_build.f_beta_k_list
        # Total density reused for T, V_ne and the EXT EL-SPHEROPOLE term.
        D_total_used = _combine_density_sets(
            basis, system, lat_opts_2e, D_alpha_used, D_beta_used,
        )
        E_kin = _lattice_contract(D_total_used, T_lat, operator_name="T")
        E_ne = _lattice_contract(D_total_used, V_lat, operator_name="V_ne")
        E_2e = 0.5 * (
            _lattice_contract(
                D_alpha_used, fock_build.f2e_alpha_real, operator_name="F2e_alpha",
            )
            + _lattice_contract(
                D_beta_used, fock_build.f2e_beta_real, operator_name="F2e_beta",
            )
        )
        E_elec = E_kin + E_ne + E_2e

        # Commutator residuals [F, D S] per k and spin.
        grad_norm_sum = 0.0
        error_alpha_k_list: List[np.ndarray] = []
        error_beta_k_list: List[np.ndarray] = []
        D_alpha_k_list: List[np.ndarray] = []
        D_beta_k_list: List[np.ndarray] = []
        for idx in range(n_k):
            if initial_density_is_local and iter_idx == 1:
                # No MOs describe a local guess yet; Bloch-sum the lattice density.
                k_arr = np.asarray(k_points[idx], dtype=float)
                D_a_k = _bloch_sum_blocks(
                    D_alpha_used.blocks, D_alpha_used.cells, k_arr,
                )
                D_b_k = _bloch_sum_blocks(
                    D_beta_used.blocks, D_beta_used.cells, k_arr,
                )
                D_a_k = 0.5 * (D_a_k + D_a_k.conj().T)
                D_b_k = 0.5 * (D_b_k + D_b_k.conj().T)
            else:
                C_a = C_alpha_per_k[idx]
                C_b = C_beta_per_k[idx]
                C_a_occ = C_a[:, :n_alpha] if n_alpha > 0 else C_a[:, :0]
                C_b_occ = C_b[:, :n_beta] if n_beta > 0 else C_b[:, :0]
                D_a_k = C_a_occ @ C_a_occ.conj().T
                D_b_k = C_b_occ @ C_b_occ.conj().T
            D_alpha_k_list.append(D_a_k)
            D_beta_k_list.append(D_b_k)
            S_k = S_k_list[idx]
            F_a_k = F_alpha_k_list[idx]
            F_b_k = F_beta_k_list[idx]
            FDS_a = F_a_k @ D_a_k @ S_k
            FDS_b = F_b_k @ D_b_k @ S_k
            err_a = FDS_a - FDS_a.conj().T
            err_b = FDS_b - FDS_b.conj().T
            error_alpha_k_list.append(err_a)
            error_beta_k_list.append(err_b)
            grad_norm_sum += float(weights[idx]) * float(
                np.sqrt(np.linalg.norm(err_a) ** 2 + np.linalg.norm(err_b) ** 2)
            )
        E_total = float(E_elec) + e_nuc
        # EXT EL-SPHEROPOLE — uses total (alpha+beta) density.
        E_sphero = compute_ext_el_spheropole(D_total_used, basis, system, lat_opts)
        E_total += E_sphero
        dE = E_total - E_prev if iter_idx > 1 else 0.0
        check_scf_divergence(
            "run_pbc_bipole_uhf", iter_idx, E_total, grad_norm_sum, dE,
        )
        diis_sub = accel.subspace_size if accel is not None else 0
        scf_trace.append(
            SCFIteration(
                iter=iter_idx,
                energy=float(E_total),
                delta_e=float(dE),
                grad_norm=float(grad_norm_sum),
                diis_subspace=diis_sub,
            )
        )
        plog.iteration(
            iter_idx,
            energy=float(E_total),
            dE=float(dE),
            grad=float(grad_norm_sum),
            diis=diis_sub,
        )
        energy_components.append(
            PBCBipoleEnergyComponents(
                iter=int(iter_idx),
                e_total=float(E_total),
                e_electronic=float(E_elec),
                e_kinetic=float(E_kin),
                e_nuclear_attraction=float(E_ne),
                e_two_electron=float(E_2e),
                e_nuclear_repulsion=float(e_nuc),
                e_bielet_zone_ee=(None if use_ewald_j_split else float(E_2e)),
                e_j_short_range=fock_build.e_j_short_range,
                e_j_long_range=fock_build.e_j_long_range,
                e_exchange=fock_build.e_exchange,
                e_ext_el_spheropole=E_sphero,
            )
        )
        plog.energy_decomposition(
            iter_idx,
            E_kin=float(E_kin),
            E_ne=float(E_ne),
            E_2e=float(E_2e),
            E_elec=float(E_elec),
            E_nuc=float(e_nuc),
        )
        converged = (
            iter_idx > 1
            and abs(dE) < float(opts.conv_tol_energy)
            and grad_norm_sum < float(opts.conv_tol_grad)
        )

        # ---- DIIS extrapolation (history is fed every iteration) ----------
        if accel is not None:
            density_alpha_k_list = [
                _bloch_sum_blocks(
                    D_alpha_used.blocks, D_alpha_used.cells, np.asarray(k),
                )
                for k in k_points
            ]
            density_beta_k_list = [
                _bloch_sum_blocks(
                    D_beta_used.blocks, D_beta_used.cells, np.asarray(k),
                )
                for k in k_points
            ]
            F_a_ex, F_b_ex = accel.extrapolate_uhf(
                F_alpha_k_list,
                F_beta_k_list,
                error_alpha_k_list=error_alpha_k_list,
                error_beta_k_list=error_beta_k_list,
                density_alpha_k_list=density_alpha_k_list,
                density_beta_k_list=density_beta_k_list,
                energy=E_total,
                mo_coeffs_alpha_k_list=C_alpha_per_k,
                mo_coeffs_beta_k_list=C_beta_per_k,
                n_alpha=n_alpha,
                n_beta=n_beta,
                weights=list(weights),
                cells=cells,
                kpoints=list(k_points),
            )
            if diis_active:
                F_alpha_k_list = F_a_ex
                F_beta_k_list = F_b_ex

        # ---- Level shift ---------------------------------------------------
        if level_shift_schedule is not None:
            level_shift_b = level_shift_schedule.at(iter_idx)
        else:
            level_shift_b = level_shift_static
        if level_shift_b != 0.0:
            # NOTE(review): the b/2 factor matches a closed-shell density with
            # occupation 2; with unit spin occupations this shifts occupied
            # levels by b/2 rather than 0 — confirm this is the intended UHF
            # convention.
            F_alpha_for_diag: List[np.ndarray] = []
            F_beta_for_diag: List[np.ndarray] = []
            for idx in range(n_k):
                S_k = S_k_list[idx]
                D_a_k = D_alpha_k_list[idx]
                D_b_k = D_beta_k_list[idx]
                F_a_shift = (
                    F_alpha_k_list[idx]
                    + level_shift_b * S_k
                    - (level_shift_b / 2.0) * (S_k @ D_a_k @ S_k)
                )
                F_b_shift = (
                    F_beta_k_list[idx]
                    + level_shift_b * S_k
                    - (level_shift_b / 2.0) * (S_k @ D_b_k @ S_k)
                )
                F_alpha_for_diag.append(0.5 * (F_a_shift + F_a_shift.conj().T))
                F_beta_for_diag.append(0.5 * (F_b_shift + F_b_shift.conj().T))
        else:
            F_alpha_for_diag = F_alpha_k_list
            F_beta_for_diag = F_beta_k_list

        # ---- Diagonalise ---------------------------------------------------
        new_C_alpha: List[np.ndarray] = []
        new_eps_alpha: List[np.ndarray] = []
        new_C_beta: List[np.ndarray] = []
        new_eps_beta: List[np.ndarray] = []
        for idx in range(n_k):
            C_a, eps_a = _diag_in_orth_basis(
                F_alpha_for_diag[idx], X_k_list[idx],
            )
            C_b, eps_b = _diag_in_orth_basis(
                F_beta_for_diag[idx], X_k_list[idx],
            )
            new_C_alpha.append(C_a)
            new_eps_alpha.append(eps_a)
            new_C_beta.append(C_b)
            new_eps_beta.append(eps_b)

        # --- MOM reorder (iter >= 2 only) ---
        if use_mom and C_prev_occ_alpha_per_k is not None:
            for idx in range(n_k):
                for spin, (C_k, eps_k, n_occ_spin, C_prev_occ_k) in enumerate(
                    [
                        (
                            new_C_alpha[idx],
                            new_eps_alpha[idx],
                            n_alpha,
                            C_prev_occ_alpha_per_k[idx],
                        ),
                        (
                            new_C_beta[idx],
                            new_eps_beta[idx],
                            n_beta,
                            C_prev_occ_beta_per_k[idx],
                        ),
                    ]
                ):
                    if n_occ_spin == 0:
                        continue
                    S_k = S_k_list[idx]
                    sel = _mom_select(
                        C_k, S_k, C_prev_occ_k, n_occ_spin, eps_new=eps_k,
                    )
                    # Selected orbitals first, remaining ones sorted by energy.
                    n_kept_idx = C_k.shape[1]
                    virt_mask = np.ones(n_kept_idx, dtype=bool)
                    virt_mask[sel] = False
                    virt_sel = np.where(virt_mask)[0]
                    virt_sel = virt_sel[np.argsort(np.real(eps_k[virt_sel]))]
                    order = np.concatenate([sel, virt_sel])
                    if spin == 0:
                        new_C_alpha[idx] = C_k[:, order]
                        new_eps_alpha[idx] = eps_k[order]
                    else:
                        new_C_beta[idx] = C_k[:, order]
                        new_eps_beta[idx] = eps_k[order]
        C_alpha_per_k = new_C_alpha
        eps_alpha_per_k = new_eps_alpha
        C_beta_per_k = new_C_beta
        eps_beta_per_k = new_eps_beta
        D_alpha_new = _spin_density(C_alpha_per_k, n_alpha)
        D_beta_new = _spin_density(C_beta_per_k, n_beta)

        # --- ODA mixing (extra Fock build) ---
        if use_oda:
            fock_naive = _build_fock_for_density(
                D_alpha_new,
                D_beta_new,
                coeffs_alpha_for_rho=C_alpha_per_k,
                coeffs_beta_for_rho=C_beta_per_k,
            )
            # NOTE(review): lambda is computed from the alpha channel only and
            # then applied to both spins — confirm this is intended.
            oda_step = _compute_oda_lambda(
                D_alpha_used,
                D_alpha_new,
                F_alpha_k_list,
                fock_naive.f_alpha_k_list,
                [np.asarray(k) for k in k_points],
                weights,
                trust_lambda_max=oda_trust_lambda_max,
            )
            _oda_mix(D_alpha_used, D_alpha_new, oda_step.lam)
            _oda_mix(D_beta_used, D_beta_new, oda_step.lam)
            D_alpha_prev = D_alpha_real
            D_beta_prev = D_beta_real
            D_alpha_real = D_alpha_used
            D_beta_real = D_beta_used
            # A partial step leaves a mixed density the MOs do not represent.
            density_from_c_per_k = oda_step.lam == 1.0
            plog.info(
                f" ODA: lambda = {oda_step.lam:.4f} "
                f"(g0 = {oda_step.g0:+.3e}, g1 = {oda_step.g1:+.3e})"
            )
        else:
            D_alpha_prev = D_alpha_used
            D_beta_prev = D_beta_used
            D_alpha_real = D_alpha_new
            D_beta_real = D_beta_new
            density_from_c_per_k = True

        # Snapshot for next iter MOM
        if use_mom:
            C_prev_occ_alpha_per_k = [
                np.asarray(C_alpha_per_k[idx][:, :n_alpha]).copy()
                if n_alpha > 0
                else np.zeros((C_alpha_per_k[idx].shape[0], 0), dtype=complex)
                for idx in range(n_k)
            ]
            C_prev_occ_beta_per_k = [
                np.asarray(C_beta_per_k[idx][:, :n_beta]).copy()
                if n_beta > 0
                else np.zeros((C_beta_per_k[idx].shape[0], 0), dtype=complex)
                for idx in range(n_k)
            ]
        if damper is not None:
            damper.update(E_total)
        E_prev = E_total
        if converged:
            break

    # ---- <S^2> -----------------------------------------------------------------
    if n_alpha == 0 or n_beta == 0:
        s2 = 0.25 * (n_alpha - n_beta) * (n_alpha - n_beta + 2) + n_beta
    else:
        # Evaluated at Gamma when present, otherwise at the first k-point.
        k0_idx = 0
        for idx, k in enumerate(k_points):
            if np.allclose(np.asarray(k, dtype=float), 0.0):
                k0_idx = idx
                break
        s2 = _spin_squared(
            n_alpha,
            n_beta,
            np.real(C_alpha_per_k[k0_idx]),
            np.real(C_beta_per_k[k0_idx]),
            np.real(S_k_list[k0_idx]),
        )
    plog.converged(n_iter=iter_idx, energy=E_total, converged=converged)

    # ---- Post-loop: recompute energy on final density for consistency
    if converged:
        # Same rule as inside the loop: only route J^LR through the MO
        # coefficients when they reproduce the final (possibly ODA-mixed)
        # real-space density.
        _fb = _build_fock_for_density(
            D_alpha_real,
            D_beta_real,
            coeffs_alpha_for_rho=(C_alpha_per_k if density_from_c_per_k else None),
            coeffs_beta_for_rho=(C_beta_per_k if density_from_c_per_k else None),
        )
        D_tot = _combine_density_sets(
            basis, system, lat_opts_2e, D_alpha_real, D_beta_real
        )
        E_kin_final = _lattice_contract(D_tot, T_lat, operator_name="T")
        E_ne_final = _lattice_contract(D_tot, V_lat, operator_name="V_ne")
        E_2e_final = 0.5 * (
            _lattice_contract(D_alpha_real, _fb.f2e_alpha_real, operator_name="F2e")
            + _lattice_contract(D_beta_real, _fb.f2e_beta_real, operator_name="F2e")
        )
        E_elec = E_kin_final + E_ne_final + E_2e_final
        E_total = float(E_elec) + e_nuc
        # Fresh E_total doesn't include spheropole — add it.
        E_sphero_final = compute_ext_el_spheropole(D_tot, basis, system, lat_opts)
        E_total += E_sphero_final
    else:
        E_sphero_final = energy_components[-1].e_ext_el_spheropole
    return PBCBipoleUHFResult(
        energy=float(E_total),
        e_electronic=float(E_elec),
        e_nuclear=e_nuc,
        e_ext_el_spheropole=E_sphero_final,
        n_iter=iter_idx,
        converged=converged,
        s_squared=float(s2),
        s_squared_ideal=0.25 * (mult - 1) * (mult + 1),
        mo_energies_alpha=eps_alpha_per_k,
        mo_coeffs_alpha=C_alpha_per_k,
        fock_alpha=F_alpha_k_list,
        density_alpha=D_alpha_real,
        mo_energies_beta=eps_beta_per_k,
        mo_coeffs_beta=C_beta_per_k,
        fock_beta=F_beta_k_list,
        density_beta=D_beta_real,
        overlap=S_k_list,
        hcore=Hcore_k_list,
        scf_trace=scf_trace,
        ewald_alpha_bohr_inv=omega_used,
        energy_components=energy_components,
    )