"""BIPOLE-style periodic UKS driver in CRYSTAL's electrostatic gauge.
The open-shell DFT counterpart of :mod:`vibeqc.pbc_bipole_rks`.
Uses the same CRYSTAL-gauge Ewald J-split F²e build for the Hartree
term and adds the spin-polarised libxc XC potential on the periodic
Becke grid. For hybrid functionals a fraction of HF exchange is
retained per spin channel via the native ``build_jk_2e_real_space``.
This driver keeps the same multi-k SCF scaffold as the UHF BIPOLE
driver: shared Ewald α across V_ne/E_nn/J^LR, analytic V_ne via
AO-pair Fourier transforms, real-space energy accounting, and
the full suite of convergence accelerators (DIIS, ODA, MOM,
level-shift schedule, damping).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Union
import numpy as np
from ._vibeqc_core import (
BasisSet,
BlochKMesh,
CoulombMethod,
EwaldOptions,
Functional,
GridOptions,
InitialGuess,
LatticeMatrixSet,
LatticeSumOptions,
PeriodicKSOptions,
PeriodicSystem,
SCFIteration,
bloch_sum,
build_fock_2e_real_space,
build_grid,
build_jk_2e_real_space,
build_xc_periodic,
compute_kinetic_lattice,
compute_overlap_lattice,
direct_lattice_cells,
ewald_nuclear_repulsion,
nuclear_repulsion_per_cell,
real_space_density_from_kpoints_fractional,
)
from .bipole_ext_el_pole import compute_ext_el_spheropole
from .guess import initial_densities_open_shell, initial_density_closed_shell
from .level_shift_schedule import LevelShiftSchedule
from .mom import select_occupied_by_max_overlap as _mom_select
from .oda import compute_oda_lambda as _compute_oda_lambda
from .oda import oda_mix_densities as _oda_mix
from .pbc_bipole import (
PBCBipoleEnergyComponents,
_bloch_sum_blocks,
_cell_key,
_compute_nuclear_lattice_ewald_reciprocal_ft,
_crystal_ewald_options,
_default_bipole_v_ne_grid_options,
_expand_ibz_kmesh_for_ewald_j,
_lattice_contract,
_lattice_contract_blocks,
)
from .pbc_bipole_uhf import (
_combine_density_sets,
_copy_lattice_with_blocks,
_spin_occupations,
)
from .periodic_grid import build_periodic_becke_grid
from .periodic_rhf_multi_k_ewald import (
_canonical_orthogonalizer_complex,
_damp_lattice_matrix,
_diag_in_orth_basis,
)
from .periodic_scf_accelerators import (
DynamicDamping,
MultiKPeriodicUHFAccelerator,
)
from .periodic_uhf_ewald import _spin_squared
from .periodic_v_ne import compute_nuclear_lattice_dispatch
from .progress import ProgressLogger, resolve_progress
from .scf_divergence import check_scf_divergence
from .smearing._support import reject_unsupported_smearing_temperature
from .symmetry_integrals_reduced import (
compute_kinetic_lattice_reduced,
compute_overlap_lattice_reduced,
)
# Public API of this module: the UKS result container and its driver.
__all__ = ["PBCBipoleUKSResult", "run_pbc_bipole_uks"]
[docs]
@dataclass
class PBCBipoleUKSResult:
    """Result of :func:`run_pbc_bipole_uks`.

    Energies are in hartree. Per-k list fields follow the order of the
    k-points of the mesh passed to the driver (``kmesh.kpoints``).
    """
    # Total energy: e_electronic + e_nuclear + external electrostatic
    # spheropole correction.
    energy: float
    # Kinetic + nuclear-attraction + two-electron + XC energy.
    e_electronic: float
    # Nuclear repulsion (Ewald sum for 3-D systems, direct sum otherwise).
    e_nuclear: float
    # Exchange-correlation energy evaluated on the DFT grid.
    e_xc: float
    # Two-electron energy 0.5 * sum_sigma Tr(D_sigma F2e_sigma).
    # NOTE(review): despite the name this includes any HF-exchange part
    # for hybrid functionals.
    e_coulomb: float
    # HF-exchange energy (0.0 for pure functionals).
    e_hf_exchange: float
    # Number of SCF iterations performed.
    n_iter: int
    # True when both the energy and gradient tolerances were met.
    converged: bool
    # <S^2> estimate of the final determinant (Gamma-point coefficients
    # when both spin channels are occupied).
    s_squared: float
    # S(S+1) implied by the system multiplicity.
    s_squared_ideal: float
    # Per-k alpha orbital energies.
    mo_energies_alpha: List[np.ndarray]
    # Per-k alpha MO coefficients (AO x MO).
    mo_coeffs_alpha: List[np.ndarray]
    # Per-k alpha Fock matrices from the last iteration (after any DIIS
    # extrapolation).
    fock_alpha: List[np.ndarray]
    # Real-space alpha density lattice.
    density_alpha: LatticeMatrixSet
    # Per-k beta orbital energies.
    mo_energies_beta: List[np.ndarray]
    # Per-k beta MO coefficients (AO x MO).
    mo_coeffs_beta: List[np.ndarray]
    # Per-k beta Fock matrices from the last iteration.
    fock_beta: List[np.ndarray]
    # Real-space beta density lattice.
    density_beta: LatticeMatrixSet
    # Per-k overlap matrices S(k).
    overlap: List[np.ndarray]
    # Per-k core Hamiltonians T(k) + V_ne(k).
    hcore: List[np.ndarray]
    # One SCFIteration record per iteration.
    scf_trace: List[SCFIteration] = field(default_factory=list)
    # Spheropole correction included in ``energy``; may be None.
    e_ext_el_spheropole: Optional[float] = None
    # Ewald splitting parameter (bohr^-1); None unless the system is 3-D.
    ewald_alpha_bohr_inv: Optional[float] = None
    # Per-iteration energy breakdown.
    energy_components: List[PBCBipoleEnergyComponents] = field(
        default_factory=list,
    )
    # Name of the exchange-correlation functional that was used.
    functional: str = ""
@dataclass
class _PBCBipoleUKSFockBuild:
    """Internal Fock bundle for one UKS density pair.

    Produced by the Fock-build closure inside :func:`run_pbc_bipole_uks`.
    """
    # Real-space two-electron (J + scaled HF exchange) operator, alpha spin.
    f2e_alpha_real: LatticeMatrixSet
    # Real-space two-electron operator, beta spin.
    f2e_beta_real: LatticeMatrixSet
    # Per-k full alpha Fock matrices (Hcore + F2e + Vxc), Hermitised.
    f_alpha_k_list: List[np.ndarray]
    # Per-k full beta Fock matrices, Hermitised.
    f_beta_k_list: List[np.ndarray]
    # 0.5 * Tr(D J) for the short-range (or full, without the Ewald split) J.
    e_j_short_range: Optional[float] = None
    # 0.5 * Tr(D J_LR); None when the Ewald J split is not used.
    e_j_long_range: Optional[float] = None
    # HF-exchange energy; None for pure functionals.
    e_exchange: Optional[float] = None
    # Far-field J energy from the multipole replacement, when enabled.
    e_j_multipole: Optional[float] = None
def _density_set_gamma_or_lattice(template, D_real):
    """Return a lattice density set suitable for ``build_xc_periodic``.

    A density that already spans more than one lattice cell (and carries
    blocks) is returned unchanged. Otherwise block 0 of ``template`` receives
    the first block of ``D_real`` (an ``nbf x nbf`` zero matrix when
    ``D_real`` has no blocks), every other template block is zeroed, and
    ``template`` itself is returned.

    Warning: in that fallback path ``template`` is modified in place. Do not
    pass an object whose contents are still needed, and do not share one
    template between two densities that must stay distinct.
    """
    if len(D_real.cells) > 1 and D_real.blocks:
        return D_real
    n_bf = template.nbf
    if D_real.blocks:
        home_block = np.asarray(D_real.blocks[0], dtype=float)
    else:
        home_block = np.zeros((n_bf, n_bf))
    blank = np.zeros_like(home_block)
    for slot in range(len(template)):
        # Slot 0 gets the home-cell density; all other cells are zeroed.
        template.set_block(slot, blank if slot else home_block)
    return template
[docs]
def run_pbc_bipole_uks(
    system: PeriodicSystem,
    basis: BasisSet,
    kmesh: BlochKMesh,
    options=None,
    *,
    functional: Optional[str] = None,
    linear_dep_threshold: float = 1e-7,
    canonical_orth_normalize_diag_first: bool = True,
    level_shift_schedule: Optional[LevelShiftSchedule] = None,
    use_mom: bool = False,
    use_oda: bool = False,
    oda_trust_lambda_max: float = 1.0,
    use_ewald_j_split: Optional[bool] = None,
    ewald_omega: Optional[float] = None,
    ewald_precision: float = 1e-8,
    v_ne_grid_options: Optional[GridOptions] = None,
    use_multipole_far_field: Optional[bool] = None,
    multipole_l_max: int = 2,
    progress: Union[bool, ProgressLogger, None] = None,
    verbose: Optional[int] = None,
    init_alpha: Optional[Sequence[np.ndarray]] = None,
    init_beta: Optional[Sequence[np.ndarray]] = None,
) -> PBCBipoleUKSResult:
    """Multi-k open-shell UKS via the CRYSTAL-gauge BIPOLE scaffold.

    Per-spin Fock matrices are built as
    ``F_sigma = Hcore + J[D_alpha + D_beta] - a_HF * K[D_sigma] + Vxc_sigma``,
    where ``a_HF`` is the functional's HF-exchange fraction.

    Parameters
    ----------
    system, basis, kmesh
        Periodic system, AO basis and Bloch k-mesh. If the mesh carries an
        ``ir_mapping`` and the Ewald J split is active, it is expanded to the
        full mesh for the long-range J density.
    options
        ``PeriodicKSOptions``; a default instance is used when ``None``.
        NOTE: ``options.functional`` is written in place when ``functional``
        is given or the options carry no functional.
    functional
        Overrides ``options.functional``; ``"pbe"`` if neither is set.
    linear_dep_threshold, canonical_orth_normalize_diag_first
        Controls for the per-k canonical orthogonalisation.
    level_shift_schedule
        Optional iteration-dependent level shift; otherwise
        ``options.level_shift`` is applied every iteration.
    use_mom, use_oda, oda_trust_lambda_max
        Maximum-overlap-method orbital selection and optimal-damping mixing.
        ODA cannot be combined with DIIS.
    use_ewald_j_split, ewald_omega, ewald_precision
        Ewald split of J into real-space short-range and reciprocal-space
        long-range parts (default: on for 3-D systems). ``ewald_omega``
        overrides the CRYSTAL default splitting parameter.
    v_ne_grid_options
        Grid for the nuclear-attraction integrals. If given, it also disables
        the analytic reciprocal-space V_ne path in 3-D.
    use_multipole_far_field, multipole_l_max
        Optional multipole replacement of far-field J (Ewald-split path only).
    progress, verbose
        Progress logging controls.
    init_alpha, init_beta
        Optional real-space spin-density blocks (one per lattice cell) used
        as the starting density. They must be given together.

    Returns
    -------
    PBCBipoleUKSResult

    Raises
    ------
    ValueError
        Invalid k-mesh, non-uniform weights with the Ewald split, mismatched
        warm-start densities, ODA combined with DIIS, Ewald split requested
        for a non-3-D system, or ``options.max_iter < 1``.
    RuntimeError
        Canonical orthogonalisation keeps fewer directions than occupied
        orbitals at some k-point.
    TypeError
        ``level_shift_schedule`` is not a ``LevelShiftSchedule``.
    """
    opts = options if options is not None else PeriodicKSOptions()
    if functional is not None:
        opts.functional = str(functional)
    if not getattr(opts, "functional", None):
        opts.functional = "pbe"
    reject_unsupported_smearing_temperature(
        opts,
        "run_pbc_bipole_uks",
        detail=(
            "BIPOLE smearing is queued for a later smearing milestone; "
            "this driver would otherwise run integer spin occupations."
        ),
    )
    max_iter = int(opts.max_iter)
    if max_iter < 1:
        # Every result field (energy, Fock, spheropole) is produced inside
        # the SCF loop, so it must run at least once.
        raise ValueError(
            f"run_pbc_bipole_uks: options.max_iter must be >= 1; got {max_iter}"
        )
    func = Functional(opts.functional, 2)  # spin-polarised
    alpha_hf = float(func.hf_exchange_fraction)
    lat_opts: LatticeSumOptions = opts.lattice_opts
    plog = resolve_progress(progress, verbose=verbose)
    use_ewald_j_split_auto = use_ewald_j_split is None
    use_ewald_j_split = (
        system.dim == 3 if use_ewald_j_split_auto else bool(use_ewald_j_split)
    )
    lat_opts_2e = LatticeSumOptions()
    lat_opts_2e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_2e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_2e.coulomb_method = CoulombMethod.DIRECT_TRUNCATED
    lat_opts_1e = LatticeSumOptions()
    lat_opts_1e.cutoff_bohr = lat_opts.cutoff_bohr
    lat_opts_1e.nuclear_cutoff_bohr = lat_opts.nuclear_cutoff_bohr
    lat_opts_1e.coulomb_method = (
        CoulombMethod.EWALD_3D if system.dim == 3 else CoulombMethod.DIRECT_TRUNCATED
    )
    n_elec = int(system.n_electrons())
    n_alpha, n_beta = _spin_occupations(system)
    mult = int(system.multiplicity)
    plog.info(
        f"PBC BIPOLE UKS (CRYSTAL-gauge) / cutoff {lat_opts.cutoff_bohr:.2f} bohr"
    )
    plog.info(f"  functional = {opts.functional}, hf_exchange_fraction = {alpha_hf}")
    plog.info(f"  n_alpha = {n_alpha}, n_beta = {n_beta}, multiplicity = {mult}")
    _kmesh_ibz = kmesh
    _ir_mapping = np.asarray(getattr(kmesh, "ir_mapping", []), dtype=int).reshape(-1)
    k_points = list(_kmesh_ibz.kpoints)
    weights = np.asarray(_kmesh_ibz.weights, dtype=float)
    if use_ewald_j_split and _ir_mapping.size > 0:
        kmesh_full = _expand_ibz_kmesh_for_ewald_j(system, kmesh, plog)
        k_points_full = list(kmesh_full.kpoints)
        weights_full = np.asarray(kmesh_full.weights, dtype=float)
    else:
        k_points_full = k_points
        weights_full = weights
    n_k = len(k_points)
    if n_k == 0:
        raise ValueError("kmesh has no k-points")
    if not np.isclose(weights.sum(), 1.0):
        raise ValueError(f"kmesh.weights must sum to 1; got {weights.sum():.6f}")
    if use_ewald_j_split and n_k > 1 and _ir_mapping.size == 0:
        uniform_w = 1.0 / float(n_k)
        if not np.allclose(weights, uniform_w, atol=1e-9):
            raise ValueError(
                "use_ewald_j_split at multi-k requires uniform weights or IBZ mesh"
            )
    plog.info(f"k-mesh: {n_k} k-points, weights sum = {weights.sum():.4f}")
    # Shared Ewald state (one alpha for V_ne, E_nn and J^LR)
    ewald_options_1e = None
    omega_used = None
    ewald_cell_volume = None
    ewald_k_max: Optional[float] = None
    if system.dim == 3:
        from .bipole_ext_el_pole import (
            crystal_default_ewald_alpha,
            crystal_ewald_reciprocal_cutoff,
        )
        V_cell = float(abs(np.linalg.det(np.asarray(system.lattice, dtype=float))))
        ewald_cell_volume = V_cell
        omega_used = (
            float(ewald_omega)
            if ewald_omega is not None
            else crystal_default_ewald_alpha(V_cell)
        )
        ewald_k_max = crystal_ewald_reciprocal_cutoff(V_cell)
        ewald_options_1e = _crystal_ewald_options(
            lat_opts_1e,
            alpha_bohr_inv=omega_used,
            tolerance=float(ewald_precision),
            recip_cutoff_bohr_inv=ewald_k_max,
        )
    # DFT grid
    if getattr(opts, "use_periodic_becke", False):
        grid = build_periodic_becke_grid(
            system,
            grid_options=opts.grid,
            image_radius_bohr=float(getattr(opts, "becke_image_radius_bohr", 10.0)),
        )
    else:
        grid = build_grid(system.unit_cell_molecule(), opts.grid)
    # One-electron integrals
    with plog.stage(
        "integrals_lattice", detail=f"S/T/V at cutoff {lat_opts.cutoff_bohr:.2f} bohr"
    ):
        _use_sym = False  # FIXME: LatticeMatrixSet has no Python constructor
        if _use_sym:
            ops = system.symmetry.operations
            _, S_blocks = compute_overlap_lattice_reduced(
                basis,
                system,
                lat_opts_2e,
                ops,
            )
            S_lat = LatticeMatrixSet()
            S_lat.nbf = basis.nbasis
            S_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            S_lat.blocks = S_blocks
            _, T_blocks = compute_kinetic_lattice_reduced(
                basis,
                system,
                lat_opts_2e,
                ops,
            )
            T_lat = LatticeMatrixSet()
            T_lat.nbf = basis.nbasis
            T_lat.cells = direct_lattice_cells(system, lat_opts_2e.cutoff_bohr)
            T_lat.blocks = T_blocks
        else:
            S_lat = compute_overlap_lattice(basis, system, lat_opts_2e)
            T_lat = compute_kinetic_lattice(basis, system, lat_opts_2e)
        v_ne_lr_cache = None
        if (
            system.dim == 3
            and ewald_options_1e is not None
            and v_ne_grid_options is None
        ):
            V_lat, v_ne_lr_cache = _compute_nuclear_lattice_ewald_reciprocal_ft(
                basis,
                system,
                lat_opts_1e,
                ewald_options_1e,
                S_lat,
                precision=ewald_precision,
                K_max=ewald_k_max,
            )
        else:
            v_ne_grid = (
                v_ne_grid_options
                if v_ne_grid_options is not None
                else (_default_bipole_v_ne_grid_options() if system.dim == 3 else None)
            )
            V_lat = compute_nuclear_lattice_dispatch(
                basis,
                system,
                lat_opts_1e,
                grid_options=v_ne_grid,
                ewald_options=ewald_options_1e,
            )
    cells = list(S_lat.cells)
    # Per-k matrices
    from .linear_dependence import scf_preflight_overlap_check
    S_k_list = []
    Hcore_k_list = []
    X_k_list = []
    for k_idx, k in enumerate(k_points):
        k_arr = np.asarray(k, dtype=float).reshape(3)
        S_k = np.asarray(bloch_sum(S_lat, k_arr))
        T_k = np.asarray(bloch_sum(T_lat, k_arr))
        V_k = np.asarray(bloch_sum(V_lat, k_arr))
        H_k = T_k + V_k
        S_k = 0.5 * (S_k + S_k.conj().T)
        H_k = 0.5 * (H_k + H_k.conj().T)
        scf_preflight_overlap_check(S_k, plog=plog, label=f"S(k={k_idx})", basis=basis)
        X_k, n_kept = _canonical_orthogonalizer_complex(
            S_k,
            linear_dep_threshold,
            normalize_diag_first=canonical_orth_normalize_diag_first,
        )
        if max(n_alpha, n_beta) > n_kept:
            raise RuntimeError(
                f"run_pbc_bipole_uks: canonical orth at k={k_idx} dropped too many directions"
            )
        S_k_list.append(S_k)
        Hcore_k_list.append(H_k)
        X_k_list.append(X_k)
    # Nuclear repulsion
    e_nuc = (
        float(ewald_nuclear_repulsion(system, ewald_options_1e))
        if ewald_options_1e is not None
        else float(nuclear_repulsion_per_cell(system, lat_opts_1e))
    )
    # Dedicated per-spin XC density templates. ``_density_set_gamma_or_lattice``
    # overwrites and returns its template when a density has <= 1 cell.
    # Passing ``S_lat`` there would clobber the overlap blocks (needed for the
    # J background term) and make the alpha and beta XC densities the same
    # object.
    xc_template_alpha = _copy_lattice_with_blocks(
        basis,
        system,
        lat_opts_2e,
        S_lat.cells,
        [np.zeros_like(np.asarray(b, dtype=float)) for b in S_lat.blocks],
    )
    xc_template_beta = _copy_lattice_with_blocks(
        basis,
        system,
        lat_opts_2e,
        S_lat.cells,
        [np.zeros_like(np.asarray(b, dtype=float)) for b in S_lat.blocks],
    )
    # Initial guess
    C_alpha_per_k = []
    eps_alpha_per_k = []
    C_beta_per_k = []
    eps_beta_per_k = []
    for H_k, X_k in zip(Hcore_k_list, X_k_list):
        C_a, eps_a = _diag_in_orth_basis(H_k, X_k)
        C_b, eps_b = _diag_in_orth_basis(H_k, X_k)
        C_alpha_per_k.append(C_a.astype(complex))
        eps_alpha_per_k.append(eps_a)
        C_beta_per_k.append(C_b.astype(complex))
        eps_beta_per_k.append(eps_b)

    def _spin_density(C_per_k_local, n_occ_each):
        """Real-space density of one spin with the lowest n_occ_each MOs filled."""
        nbf = C_per_k_local[0].shape[1]
        occ_per_k = []
        for _ in range(n_k):
            occ = np.zeros(nbf, dtype=float)
            occ[:n_occ_each] = 1.0
            occ_per_k.append(occ)
        return real_space_density_from_kpoints_fractional(
            C_per_k_local, occ_per_k, kmesh, cells
        )

    D_alpha_real = _spin_density(C_alpha_per_k, n_alpha)
    D_beta_real = _spin_density(C_beta_per_k, n_beta)
    # Caller-supplied warm-start spin densities take precedence over
    # the SAD/Hcore guess engine — see run_pbc_bipole_uhf for the
    # same contract. Used by the NEB driver for within-image density
    # warm-start (HANDOVER_PERIODIC_NEB.md M4).
    if (init_alpha is not None) != (init_beta is not None):
        raise ValueError(
            "run_pbc_bipole_uks: init_alpha and init_beta must be "
            "provided together (both None or both populated)"
        )
    if init_alpha is not None and init_beta is not None:
        blocks_a = list(init_alpha)
        blocks_b = list(init_beta)
        if len(blocks_a) != len(D_alpha_real.cells):
            raise ValueError(
                f"run_pbc_bipole_uks: init_alpha has {len(blocks_a)} "
                f"blocks; expected {len(D_alpha_real.cells)}"
            )
        if len(blocks_b) != len(D_beta_real.cells):
            raise ValueError(
                f"run_pbc_bipole_uks: init_beta has {len(blocks_b)} "
                f"blocks; expected {len(D_beta_real.cells)}"
            )
        for g_idx, (ba, bb) in enumerate(zip(blocks_a, blocks_b)):
            D_alpha_real.set_block(g_idx, np.asarray(ba, dtype=float))
            D_beta_real.set_block(g_idx, np.asarray(bb, dtype=float))
        initial_density_is_local = True
        density_from_c_per_k = False
    else:
        guess = getattr(opts, "initial_guess", InitialGuess.HCORE)
        D_guess = None
        if n_elec % 2 == 0:
            D_total_guess = initial_density_closed_shell(
                system.unit_cell_molecule(), basis, n_elec // 2, guess, is_periodic=True
            )
            if D_total_guess is not None:
                D_guess = (
                    0.5 * np.asarray(D_total_guess, dtype=float),
                    0.5 * np.asarray(D_total_guess, dtype=float),
                )
        if D_guess is None:
            D_guess = initial_densities_open_shell(
                system.unit_cell_molecule(), basis, n_alpha, n_beta, guess, is_periodic=True
            )
        initial_density_is_local = D_guess is not None
        if D_guess is not None:
            # Place the molecular guess in the home cell only.
            D_a0, D_b0 = D_guess
            zero_a = np.zeros_like(D_a0, dtype=float)
            zero_b = np.zeros_like(D_b0, dtype=float)
            for g_idx in range(len(D_alpha_real.cells)):
                is_g0 = (
                    np.asarray(D_alpha_real.cells[g_idx].index, dtype=int)
                    == np.array([0, 0, 0])
                ).all()
                D_alpha_real.set_block(g_idx, D_a0 if is_g0 else zero_a)
                D_beta_real.set_block(g_idx, D_b0 if is_g0 else zero_b)
        density_from_c_per_k = not initial_density_is_local
    D_alpha_prev = None
    D_beta_prev = None
    # SCF aids
    damping = float(opts.damping)
    damper: Optional[DynamicDamping] = None
    if bool(getattr(opts, "dynamic_damping", False)):
        damper = DynamicDamping(
            initial_alpha=damping,
            alpha_min=float(getattr(opts, "dynamic_damping_min", 0.0)),
            alpha_max=float(getattr(opts, "dynamic_damping_max", 0.95)),
        )
    use_diis = bool(opts.use_diis)
    diis_start_iter = int(opts.diis_start_iter)
    accel: Optional[MultiKPeriodicUHFAccelerator] = (
        MultiKPeriodicUHFAccelerator(opts) if use_diis else None
    )
    level_shift_static = float(getattr(opts, "level_shift", 0.0))
    if level_shift_schedule is not None and not isinstance(
        level_shift_schedule, LevelShiftSchedule
    ):
        raise TypeError("level_shift_schedule must be LevelShiftSchedule")
    if use_mom:
        plog.info("MOM: ON")
    C_prev_occ_alpha_per_k = None
    C_prev_occ_beta_per_k = None
    if use_oda and use_diis:
        raise ValueError("use_oda and use_diis are mutually exclusive")
    if use_oda:
        plog.info(f"ODA: ON (trust lambda_max = {oda_trust_lambda_max})")
    # Ewald J-split cache and loop-invariant J^LR ingredients
    j_lr_cache = v_ne_lr_cache
    j_background_potential = 0.0
    s_blocks_by_cell = {}
    if use_ewald_j_split:
        if system.dim != 3:
            raise ValueError("use_ewald_j_split requires dim=3")
        from .bipole_fock_ewald import _build_j_long_range_cache
        assert omega_used is not None
        cells_r_cart_arr = np.array(
            [np.asarray(c.r_cart, dtype=float) for c in cells], dtype=float
        )
        if j_lr_cache is None:
            j_lr_cache = _build_j_long_range_cache(
                basis,
                system,
                cells_r_cart_arr,
                omega_used,
                ewald_precision,
                K_max=ewald_k_max,
            )
        assert ewald_cell_volume is not None
        # Uniform neutralising-background shift of J^LR (applied as pot * S).
        j_background_potential = (
            -np.pi
            * float(n_elec)
            / (float(omega_used) * float(omega_used) * float(ewald_cell_volume))
        )
        s_blocks_by_cell = {
            _cell_key(cell): np.asarray(block, dtype=float)
            for cell, block in zip(S_lat.cells, S_lat.blocks)
        }

    def _spin_d_matrices_from_coeffs(coeffs_per_k, n_occ):
        """Per-k density matrices C_occ C_occ^H for one spin."""
        out = []
        for C_k_raw in coeffs_per_k:
            C_k = np.asarray(C_k_raw)
            C_occ = C_k[:, :n_occ] if n_occ > 0 else C_k[:, :0]
            out.append(C_occ @ C_occ.conj().T)
        return out

    def _build_fock_for_density(
        D_alpha, D_beta, *, coeffs_alpha_for_rho, coeffs_beta_for_rho
    ):
        """Build real-space F2e and per-k Fock matrices for one density pair.

        F_sigma = Hcore + J[D_alpha + D_beta] - alpha_hf * K[D_sigma] + Vxc_sigma.
        When both coefficient lists are given (Ewald split only), the
        long-range J uses a density built from them in k-space.
        """
        D_total = _combine_density_sets(basis, system, lat_opts_2e, D_alpha, D_beta)
        if use_ewald_j_split:
            F_J_SR_lat = build_fock_2e_real_space(
                basis, system, lat_opts_2e, D_total, 0.0, float(omega_used)
            )
            jk_alpha = build_jk_2e_real_space(basis, system, lat_opts_2e, D_alpha, 0.0)
            jk_beta = build_jk_2e_real_space(basis, system, lat_opts_2e, D_beta, 0.0)
            K_alpha_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_alpha.K.blocks
            ]
            K_beta_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_beta.K.blocks
            ]
            rho_hat_for_LR = None
            if coeffs_alpha_for_rho is not None and coeffs_beta_for_rho is not None:
                from .bipole_fock_ewald import compute_rho_hat_from_k_density
                D_a_k = _spin_d_matrices_from_coeffs(coeffs_alpha_for_rho, n_alpha)
                D_b_k = _spin_d_matrices_from_coeffs(coeffs_beta_for_rho, n_beta)
                D_total_k = [Da + Db for Da, Db in zip(D_a_k, D_b_k)]
                if _ir_mapping.size > 0:
                    D_full = [D_total_k[int(idx)] for idx in _ir_mapping]
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_full, k_points_full, weights_full, j_lr_cache
                    )
                else:
                    rho_hat_for_LR = compute_rho_hat_from_k_density(
                        D_total_k, k_points, weights, j_lr_cache
                    )
            from .bipole_fock_ewald import compute_J_long_range_real_space_blocks
            F_LR_blocks = compute_J_long_range_real_space_blocks(
                D_total,
                basis,
                system,
                omega_used,
                precision=ewald_precision,
                cache=j_lr_cache,
                rho_hat=rho_hat_for_LR,
            )
            for c, cell in enumerate(F_J_SR_lat.cells):
                F_LR_blocks[c] = (
                    F_LR_blocks[c]
                    + j_background_potential * s_blocks_by_cell[_cell_key(cell)]
                )
            j_sr_blocks = [
                np.asarray(block, dtype=float).copy() for block in F_J_SR_lat.blocks
            ]
            e_j_short_range = 0.5 * _lattice_contract_blocks(
                D_total, F_J_SR_lat.cells, j_sr_blocks, operator_name="J_SR"
            )
            e_j_long_range = 0.5 * _lattice_contract_blocks(
                D_total, F_J_SR_lat.cells, F_LR_blocks, operator_name="J_LR"
            )
            j_blocks = [
                j_sr_blocks[c] + F_LR_blocks[c] for c in range(len(F_J_SR_lat.cells))
            ]
            # Spin densities carry unit occupations, so the UKS exchange term
            # is -alpha_hf * K[D_sigma]. This matches the energy
            # E_x = -0.5 * alpha_hf * sum_sigma Tr(D_sigma K_sigma) below.
            if alpha_hf > 0.0:
                alpha_blocks = [
                    j - alpha_hf * k for j, k in zip(j_blocks, K_alpha_blocks)
                ]
                beta_blocks = [
                    j - alpha_hf * k for j, k in zip(j_blocks, K_beta_blocks)
                ]
                e_exchange = (
                    -0.5
                    * alpha_hf
                    * (
                        _lattice_contract_blocks(
                            D_alpha,
                            F_J_SR_lat.cells,
                            K_alpha_blocks,
                            operator_name="K_alpha",
                        )
                        + _lattice_contract_blocks(
                            D_beta,
                            F_J_SR_lat.cells,
                            K_beta_blocks,
                            operator_name="K_beta",
                        )
                    )
                )
            else:
                alpha_blocks = list(j_blocks)
                beta_blocks = list(j_blocks)
                e_exchange = None
            f2e_alpha_real = F_J_SR_lat
            for c, block in enumerate(alpha_blocks):
                f2e_alpha_real.set_block(c, block)
            f2e_beta_real = _copy_lattice_with_blocks(
                basis, system, lat_opts_2e, F_J_SR_lat.cells, beta_blocks
            )
            # ---- Optional: replace J_SR+J_LR with multipole far-field J --
            e_j_multipole: Optional[float] = None
            if _mp_config.enabled:
                from .bipole_fock_multipole import apply_multipole_far_field
                # Build multipole J from the total density (exchange_scale=0
                # requests J only). The spin exchange is added back with the
                # same -alpha_hf factor as above, so pure functionals get none.
                mp_result = apply_multipole_far_field(
                    D_total,
                    basis,
                    system,
                    lat_opts_2e,
                    _mp_config,
                    J_SR_blocks=j_sr_blocks,
                    K_blocks=None,
                    F_LR_blocks=F_LR_blocks,
                    exchange_scale=0.0,
                )
                for c in range(len(F_J_SR_lat.cells)):
                    f2e_alpha_real.set_block(
                        c,
                        mp_result.f2e_blocks[c] - alpha_hf * K_alpha_blocks[c],
                    )
                    f2e_beta_real.set_block(
                        c,
                        mp_result.f2e_blocks[c] - alpha_hf * K_beta_blocks[c],
                    )
                e_j_multipole = mp_result.e_j_multipole
                if mp_result.n_far_cells > 0:
                    plog.info(
                        f"  BIPOLE multipole far-field (L_max={_mp_config.L_max}, "
                        f"R={_mp_config.R_bipole:.1f} bohr): "
                        f"{mp_result.n_far_cells}/{mp_result.n_total_cells} cells replaced, "
                        f"E_J_far = {e_j_multipole:+.6f} Ha"
                    )
        else:
            F_J_lat = build_fock_2e_real_space(
                basis, system, lat_opts_2e, D_total, 0.0, 0.0
            )
            jk_alpha = build_jk_2e_real_space(basis, system, lat_opts_2e, D_alpha, 0.0)
            jk_beta = build_jk_2e_real_space(basis, system, lat_opts_2e, D_beta, 0.0)
            j_blocks = [
                np.asarray(block, dtype=float).copy() for block in F_J_lat.blocks
            ]
            K_alpha_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_alpha.K.blocks
            ]
            K_beta_blocks = [
                np.asarray(block, dtype=float).copy() for block in jk_beta.K.blocks
            ]
            # UKS exchange on spin densities: -alpha_hf * K[D_sigma].
            if alpha_hf > 0.0:
                alpha_blocks = [
                    j - alpha_hf * k for j, k in zip(j_blocks, K_alpha_blocks)
                ]
                beta_blocks = [
                    j - alpha_hf * k for j, k in zip(j_blocks, K_beta_blocks)
                ]
            else:
                alpha_blocks = list(j_blocks)
                beta_blocks = list(j_blocks)
            f2e_alpha_real = F_J_lat
            for c, block in enumerate(alpha_blocks):
                f2e_alpha_real.set_block(c, block)
            f2e_beta_real = _copy_lattice_with_blocks(
                basis, system, lat_opts_2e, F_J_lat.cells, beta_blocks
            )
            e_j_short_range = 0.5 * _lattice_contract_blocks(
                D_total, F_J_lat.cells, j_blocks, operator_name="J"
            )
            e_j_long_range = None
            e_j_multipole = None
            e_exchange = None
            if alpha_hf > 0.0:
                e_exchange = (
                    -0.5
                    * alpha_hf
                    * (
                        _lattice_contract_blocks(
                            D_alpha,
                            F_J_lat.cells,
                            K_alpha_blocks,
                            operator_name="K_alpha",
                        )
                        + _lattice_contract_blocks(
                            D_beta, F_J_lat.cells, K_beta_blocks, operator_name="K_beta"
                        )
                    )
                )
        # XC potential (separate templates per spin, see above)
        D_xc_alpha = _density_set_gamma_or_lattice(xc_template_alpha, D_alpha)
        D_xc_beta = _density_set_gamma_or_lattice(xc_template_beta, D_beta)
        xc_result = build_xc_periodic(
            basis, system, grid, D_xc_alpha, func, density_beta=D_xc_beta
        )
        # Per-k Fock
        f_alpha_k_list = []
        f_beta_k_list = []
        for k_idx, k in enumerate(k_points):
            k_arr = np.asarray(k, dtype=float)
            F_a_2e = _bloch_sum_blocks(
                f2e_alpha_real.blocks, f2e_alpha_real.cells, k_arr
            )
            F_b_2e = _bloch_sum_blocks(f2e_beta_real.blocks, f2e_beta_real.cells, k_arr)
            Vxc_a_k = np.asarray(bloch_sum(xc_result.vxc_alpha_lattice, k_arr))
            Vxc_b_k = np.asarray(bloch_sum(xc_result.vxc_beta_lattice, k_arr))
            F_a = F_a_2e + Vxc_a_k + np.asarray(Hcore_k_list[k_idx], dtype=complex)
            F_b = F_b_2e + Vxc_b_k + np.asarray(Hcore_k_list[k_idx], dtype=complex)
            f_alpha_k_list.append(0.5 * (F_a + F_a.conj().T))
            f_beta_k_list.append(0.5 * (F_b + F_b.conj().T))
        return _PBCBipoleUKSFockBuild(
            f2e_alpha_real=f2e_alpha_real,
            f2e_beta_real=f2e_beta_real,
            f_alpha_k_list=f_alpha_k_list,
            f_beta_k_list=f_beta_k_list,
            e_j_short_range=e_j_short_range,
            e_j_long_range=e_j_long_range,
            e_exchange=e_exchange,
            e_j_multipole=e_j_multipole,
        )

    # ---- Multipole far-field config (resolve once before SCF loop) -----
    from .bipole_fock_multipole import (  # noqa: E402
        BipoleMultipoleConfig,
        resolve_multipole_config,
    )
    _mp_config = resolve_multipole_config(
        system,
        basis,
        lat_opts_2e,
        user_enable=use_multipole_far_field,
        multipole_l_max=multipole_l_max,
    )
    if _mp_config.enabled:
        plog.info(
            f"  BIPOLE multipole far-field: ENABLED "
            f"(L_max={_mp_config.L_max}, R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"n_cells={len(_mp_config.cache.cells) if _mp_config.cache else 0})"
        )
    else:
        plog.info(
            f"  BIPOLE multipole far-field: off "
            f"(R_bipole={_mp_config.R_bipole:.1f} bohr, "
            f"cutoff={lat_opts_2e.cutoff_bohr:.1f} bohr)"
        )
    plog.banner("SCF (PBC BIPOLE UKS, direct-space)")
    plog.info(" iter   energy (Ha)         dE            ||[F,DS]||    DIIS")
    scf_trace = []
    energy_components = []
    E_prev = 0.0
    E_elec = 0.0
    e_xc = 0.0
    F_alpha_k_list = [np.zeros_like(H) for H in Hcore_k_list]
    F_beta_k_list = [np.zeros_like(H) for H in Hcore_k_list]
    converged = False
    iter_idx = 0
    for iter_idx in range(1, max_iter + 1):
        if damper is not None:
            damping = damper.alpha
        diis_active = use_diis and iter_idx >= diis_start_iter
        # Density damping is only applied when DIIS is not driving the SCF.
        d_used_is_damped = iter_idx > 1 and damping > 0.0 and not diis_active
        if d_used_is_damped:
            D_alpha_used = _damp_lattice_matrix(D_alpha_real, D_alpha_prev, damping)
            D_beta_used = _damp_lattice_matrix(D_beta_real, D_beta_prev, damping)
        else:
            D_alpha_used = D_alpha_real
            D_beta_used = D_beta_real
        d_used_from_coeffs = (
            density_from_c_per_k
            and not (initial_density_is_local and iter_idx == 1)
            and not d_used_is_damped
        )
        fock_build = _build_fock_for_density(
            D_alpha_used,
            D_beta_used,
            coeffs_alpha_for_rho=(C_alpha_per_k if d_used_from_coeffs else None),
            coeffs_beta_for_rho=(C_beta_per_k if d_used_from_coeffs else None),
        )
        F_alpha_k_list = fock_build.f_alpha_k_list
        F_beta_k_list = fock_build.f_beta_k_list
        # XC energy
        # NOTE(review): the Fock build above already evaluated XC on the same
        # densities; the duplicate evaluation could be removed by returning exc
        # from the Fock bundle.
        D_xc_alpha = _density_set_gamma_or_lattice(xc_template_alpha, D_alpha_used)
        D_xc_beta = _density_set_gamma_or_lattice(xc_template_beta, D_beta_used)
        xc_result = build_xc_periodic(
            basis, system, grid, D_xc_alpha, func, density_beta=D_xc_beta
        )
        e_xc = float(xc_result.exc)
        # Energy
        D_total_used = _combine_density_sets(
            basis, system, lat_opts_2e, D_alpha_used, D_beta_used
        )
        E_kin = _lattice_contract(D_total_used, T_lat, operator_name="T")
        E_ne = _lattice_contract(D_total_used, V_lat, operator_name="V_ne")
        E_2e = 0.5 * (
            _lattice_contract(
                D_alpha_used, fock_build.f2e_alpha_real, operator_name="F2e_alpha"
            )
            + _lattice_contract(
                D_beta_used, fock_build.f2e_beta_real, operator_name="F2e_beta"
            )
        )
        E_elec = E_kin + E_ne + E_2e + e_xc
        # Gradient: weighted per-k norm of the commutators [F, D S].
        grad_norm_sum = 0.0
        error_alpha_k_list = []
        error_beta_k_list = []
        D_alpha_k_list = []
        D_beta_k_list = []
        for idx in range(n_k):
            if initial_density_is_local and iter_idx == 1:
                # No MOs match the local guess yet: Bloch-sum it instead.
                k_arr = np.asarray(k_points[idx], dtype=float)
                D_a_k = _bloch_sum_blocks(
                    D_alpha_used.blocks, D_alpha_used.cells, k_arr
                )
                D_b_k = _bloch_sum_blocks(D_beta_used.blocks, D_beta_used.cells, k_arr)
                D_a_k = 0.5 * (D_a_k + D_a_k.conj().T)
                D_b_k = 0.5 * (D_b_k + D_b_k.conj().T)
            else:
                C_a = C_alpha_per_k[idx]
                C_b = C_beta_per_k[idx]
                C_a_occ = C_a[:, :n_alpha] if n_alpha > 0 else C_a[:, :0]
                C_b_occ = C_b[:, :n_beta] if n_beta > 0 else C_b[:, :0]
                D_a_k = C_a_occ @ C_a_occ.conj().T
                D_b_k = C_b_occ @ C_b_occ.conj().T
            D_alpha_k_list.append(D_a_k)
            D_beta_k_list.append(D_b_k)
            S_k = S_k_list[idx]
            F_a_k = F_alpha_k_list[idx]
            F_b_k = F_beta_k_list[idx]
            err_a = F_a_k @ D_a_k @ S_k - (F_a_k @ D_a_k @ S_k).conj().T
            err_b = F_b_k @ D_b_k @ S_k - (F_b_k @ D_b_k @ S_k).conj().T
            error_alpha_k_list.append(err_a)
            error_beta_k_list.append(err_b)
            grad_norm_sum += float(weights[idx]) * float(
                np.sqrt(np.linalg.norm(err_a) ** 2 + np.linalg.norm(err_b) ** 2)
            )
        E_total = float(E_elec) + e_nuc
        # EXT EL-SPHEROPOLE — uses total (alpha+beta) density.
        E_sphero = compute_ext_el_spheropole(D_total_used, basis, system, lat_opts)
        E_total += E_sphero
        dE = E_total - E_prev if iter_idx > 1 else 0.0
        check_scf_divergence("run_pbc_bipole_uks", iter_idx, E_total, grad_norm_sum, dE)
        diis_sub = accel.subspace_size if accel is not None else 0
        scf_trace.append(
            SCFIteration(
                iter=iter_idx,
                energy=float(E_total),
                delta_e=float(dE),
                grad_norm=float(grad_norm_sum),
                diis_subspace=diis_sub,
            )
        )
        plog.iteration(
            iter_idx,
            energy=float(E_total),
            dE=float(dE),
            grad=float(grad_norm_sum),
            diis=diis_sub,
        )
        energy_components.append(
            PBCBipoleEnergyComponents(
                iter=int(iter_idx),
                e_total=float(E_total),
                e_electronic=float(E_elec),
                e_kinetic=float(E_kin),
                e_nuclear_attraction=float(E_ne),
                e_two_electron=float(E_2e),
                e_nuclear_repulsion=float(e_nuc),
                e_bielet_zone_ee=(None if use_ewald_j_split else float(E_2e)),
                e_j_short_range=fock_build.e_j_short_range,
                e_j_long_range=fock_build.e_j_long_range,
                e_exchange=fock_build.e_exchange,
            )
        )
        plog.energy_decomposition(
            iter_idx,
            E_kin=float(E_kin),
            E_ne=float(E_ne),
            E_2e=float(E_2e),
            E_elec=float(E_elec),
            E_nuc=float(e_nuc),
        )
        converged = (
            iter_idx > 1
            and abs(dE) < float(opts.conv_tol_energy)
            and grad_norm_sum < float(opts.conv_tol_grad)
        )
        # SCF-accelerator extrapolation (subspace is fed every iteration;
        # the extrapolated Fock is only used once DIIS is active).
        if accel is not None:
            density_alpha_k_list = [
                _bloch_sum_blocks(
                    D_alpha_used.blocks,
                    D_alpha_used.cells,
                    np.asarray(k),
                )
                for k in k_points
            ]
            density_beta_k_list = [
                _bloch_sum_blocks(
                    D_beta_used.blocks,
                    D_beta_used.cells,
                    np.asarray(k),
                )
                for k in k_points
            ]
            F_a_ex, F_b_ex = accel.extrapolate_uhf(
                F_alpha_k_list,
                F_beta_k_list,
                error_alpha_k_list=error_alpha_k_list,
                error_beta_k_list=error_beta_k_list,
                density_alpha_k_list=density_alpha_k_list,
                density_beta_k_list=density_beta_k_list,
                energy=E_total,
                mo_coeffs_alpha_k_list=C_alpha_per_k,
                mo_coeffs_beta_k_list=C_beta_per_k,
                n_alpha=n_alpha,
                n_beta=n_beta,
                weights=list(weights),
                cells=cells,
                kpoints=list(k_points),
            )
            if diis_active:
                F_alpha_k_list = F_a_ex
                F_beta_k_list = F_b_ex
        # Level shift
        if level_shift_schedule is not None:
            level_shift_b = level_shift_schedule.at(iter_idx)
        else:
            level_shift_b = level_shift_static
        if level_shift_b != 0.0:
            F_alpha_for_diag = []
            F_beta_for_diag = []
            for idx in range(n_k):
                S_k = S_k_list[idx]
                D_a_k = D_alpha_k_list[idx]
                D_b_k = D_beta_k_list[idx]
                F_a_shift = (
                    F_alpha_k_list[idx]
                    + level_shift_b * S_k
                    - (level_shift_b / 2.0) * (S_k @ D_a_k @ S_k)
                )
                F_b_shift = (
                    F_beta_k_list[idx]
                    + level_shift_b * S_k
                    - (level_shift_b / 2.0) * (S_k @ D_b_k @ S_k)
                )
                F_alpha_for_diag.append(0.5 * (F_a_shift + F_a_shift.conj().T))
                F_beta_for_diag.append(0.5 * (F_b_shift + F_b_shift.conj().T))
        else:
            F_alpha_for_diag = F_alpha_k_list
            F_beta_for_diag = F_beta_k_list
        # Diagonalise
        new_C_alpha = []
        new_eps_alpha = []
        new_C_beta = []
        new_eps_beta = []
        for idx in range(n_k):
            C_a, eps_a = _diag_in_orth_basis(F_alpha_for_diag[idx], X_k_list[idx])
            C_b, eps_b = _diag_in_orth_basis(F_beta_for_diag[idx], X_k_list[idx])
            new_C_alpha.append(C_a)
            new_eps_alpha.append(eps_a)
            new_C_beta.append(C_b)
            new_eps_beta.append(eps_b)
        # MOM: reorder so the MOs overlapping most with the previous occupied
        # space come first; virtuals follow in ascending energy.
        if use_mom and C_prev_occ_alpha_per_k is not None:
            for idx in range(n_k):
                for spin, (C_k, eps_k, n_occ_spin, C_prev_occ_k) in enumerate(
                    [
                        (
                            new_C_alpha[idx],
                            new_eps_alpha[idx],
                            n_alpha,
                            C_prev_occ_alpha_per_k[idx],
                        ),
                        (
                            new_C_beta[idx],
                            new_eps_beta[idx],
                            n_beta,
                            C_prev_occ_beta_per_k[idx],
                        ),
                    ]
                ):
                    if n_occ_spin == 0:
                        continue
                    S_k = S_k_list[idx]
                    sel = _mom_select(C_k, S_k, C_prev_occ_k, n_occ_spin, eps_new=eps_k)
                    n_kept_idx = C_k.shape[1]
                    virt_mask = np.ones(n_kept_idx, dtype=bool)
                    virt_mask[sel] = False
                    virt_sel = np.where(virt_mask)[0]
                    virt_sel = virt_sel[np.argsort(np.real(eps_k[virt_sel]))]
                    order = np.concatenate([sel, virt_sel])
                    if spin == 0:
                        new_C_alpha[idx] = C_k[:, order]
                        new_eps_alpha[idx] = eps_k[order]
                    else:
                        new_C_beta[idx] = C_k[:, order]
                        new_eps_beta[idx] = eps_k[order]
        C_alpha_per_k = new_C_alpha
        eps_alpha_per_k = new_eps_alpha
        C_beta_per_k = new_C_beta
        eps_beta_per_k = new_eps_beta
        D_alpha_new = _spin_density(C_alpha_per_k, n_alpha)
        D_beta_new = _spin_density(C_beta_per_k, n_beta)
        # ODA
        if use_oda:
            fock_naive = _build_fock_for_density(
                D_alpha_new,
                D_beta_new,
                coeffs_alpha_for_rho=C_alpha_per_k,
                coeffs_beta_for_rho=C_beta_per_k,
            )
            # NOTE(review): lambda is determined from the alpha channel only
            # and then applied to both spins.
            oda_step = _compute_oda_lambda(
                D_alpha_used,
                D_alpha_new,
                F_alpha_k_list,
                fock_naive.f_alpha_k_list,
                [np.asarray(k) for k in k_points],
                weights,
                trust_lambda_max=oda_trust_lambda_max,
            )
            _oda_mix(D_alpha_used, D_alpha_new, oda_step.lam)
            _oda_mix(D_beta_used, D_beta_new, oda_step.lam)
            D_alpha_prev = D_alpha_real
            D_beta_prev = D_beta_real
            D_alpha_real = D_alpha_used
            D_beta_real = D_beta_used
            density_from_c_per_k = oda_step.lam == 1.0
            plog.info(
                f"  ODA: lambda = {oda_step.lam:.4f} (g0 = {oda_step.g0:+.3e}, g1 = {oda_step.g1:+.3e})"
            )
        else:
            D_alpha_prev = D_alpha_used
            D_beta_prev = D_beta_used
            D_alpha_real = D_alpha_new
            D_beta_real = D_beta_new
            density_from_c_per_k = True
        if use_mom:
            C_prev_occ_alpha_per_k = [
                np.asarray(C_alpha_per_k[idx][:, :n_alpha]).copy()
                if n_alpha > 0
                else np.zeros((C_alpha_per_k[idx].shape[0], 0), dtype=complex)
                for idx in range(n_k)
            ]
            C_prev_occ_beta_per_k = [
                np.asarray(C_beta_per_k[idx][:, :n_beta]).copy()
                if n_beta > 0
                else np.zeros((C_beta_per_k[idx].shape[0], 0), dtype=complex)
                for idx in range(n_k)
            ]
        if damper is not None:
            damper.update(E_total)
        E_prev = E_total
        if converged:
            break
    # S^2
    if n_alpha == 0 or n_beta == 0:
        s2 = 0.25 * (n_alpha - n_beta) * (n_alpha - n_beta + 2) + n_beta
    else:
        k0_idx = next(
            (
                idx
                for idx, k in enumerate(k_points)
                if np.allclose(np.asarray(k, dtype=float), 0.0)
            ),
            None,
        )
        if k0_idx is None:
            # The estimate below needs real Gamma-point coefficients; a mesh
            # without Gamma must not crash the run after the SCF finished.
            plog.info("  <S^2>: k-mesh has no Gamma point; reporting NaN")
            s2 = float("nan")
        else:
            s2 = _spin_squared(
                n_alpha,
                n_beta,
                np.real(C_alpha_per_k[k0_idx]),
                np.real(C_beta_per_k[k0_idx]),
                np.real(S_k_list[k0_idx]),
            )
    plog.converged(n_iter=iter_idx, energy=E_total, converged=converged)
    # ---- Post-loop: recompute energy on final density for consistency
    if converged:
        # Use the MOs for the long-range density only if they actually
        # generated the final density (not the case after a partial ODA step).
        fock_build = _build_fock_for_density(
            D_alpha_real,
            D_beta_real,
            coeffs_alpha_for_rho=C_alpha_per_k if density_from_c_per_k else None,
            coeffs_beta_for_rho=C_beta_per_k if density_from_c_per_k else None,
        )
        D_xc_a = _density_set_gamma_or_lattice(xc_template_alpha, D_alpha_real)
        D_xc_b = _density_set_gamma_or_lattice(xc_template_beta, D_beta_real)
        xc_final = build_xc_periodic(
            basis, system, grid, D_xc_a, func, density_beta=D_xc_b
        )
        e_xc = float(xc_final.exc)
        D_tot = _combine_density_sets(
            basis, system, lat_opts_2e, D_alpha_real, D_beta_real
        )
        E_kin_f = _lattice_contract(D_tot, T_lat, operator_name="T")
        E_ne_f = _lattice_contract(D_tot, V_lat, operator_name="V_ne")
        E_2e = 0.5 * (
            _lattice_contract(D_alpha_real, fock_build.f2e_alpha_real, operator_name="F2e")
            + _lattice_contract(D_beta_real, fock_build.f2e_beta_real, operator_name="F2e")
        )
        E_elec = E_kin_f + E_ne_f + E_2e + e_xc
        E_total = float(E_elec) + e_nuc
        # Fresh E_total doesn't include spheropole — add it.
        E_sphero_final = compute_ext_el_spheropole(D_tot, basis, system, lat_opts)
        E_total += E_sphero_final
    else:
        # E_total from the last iteration already contains this term.
        E_sphero_final = E_sphero
    return PBCBipoleUKSResult(
        energy=float(E_total),
        e_electronic=float(E_elec),
        e_nuclear=e_nuc,
        e_ext_el_spheropole=E_sphero_final,
        e_xc=float(e_xc),
        e_coulomb=float(E_2e),
        e_hf_exchange=float(fock_build.e_exchange or 0.0),
        n_iter=iter_idx,
        converged=converged,
        s_squared=float(s2),
        s_squared_ideal=0.25 * (mult - 1) * (mult + 1),
        mo_energies_alpha=eps_alpha_per_k,
        mo_coeffs_alpha=C_alpha_per_k,
        fock_alpha=F_alpha_k_list,
        density_alpha=D_alpha_real,
        mo_energies_beta=eps_beta_per_k,
        mo_coeffs_beta=C_beta_per_k,
        fock_beta=F_beta_k_list,
        density_beta=D_beta_real,
        overlap=S_k_list,
        hcore=Hcore_k_list,
        scf_trace=scf_trace,
        ewald_alpha_bohr_inv=omega_used,
        energy_components=energy_components,
        functional=str(opts.functional),
    )