"""HF-CCM self-consistent field (closed-shell RHF) -- AICCM milestone 2b.

Assembles the cyclic Fock matrix from the CCM-weighted integrals
(``S^CCM``, ``h^CCM``, the effective ERI tensor, and ``V_nn^CCM``) and
solves the cyclic Roothaan-Hall equations (eqs 25-27 of Peintinger &
Bredow, J. Comput. Chem. 35, 839 (2014), doi:10.1002/jcc.23550):

    # Eq. 25:  F^CCM = h^CCM + sum_n (2 J_n^CCM - K_n^CCM)
    # Eq. 26:  F^CCM C^CCM = S^CCM C^CCM E^CCM
    # Eq. 27:  E^CCM = sum_i eps_i - 1/2 sum_ij (J_ij - K_ij) + V_nn^CCM

This is the small-cluster validation driver: it consumes the dense effective
ERI tensor (:func:`vibeqc.periodic.ccm.padded.ccm_eri`) and runs a plain
DIIS-accelerated RHF in Python. The production path (large 3-D clusters,
all SCF methods) is a C++ ``WeightedLatticeJKBuilder`` fed into the existing
``run_*_scf_with_jk`` entry points; see ``handovers/HANDOVER_AICCM.md``. The energy
per atom from this driver converges to periodic Hartree-Fock as the cluster
grows (Peintinger & Bredow 2014, Tables 2-4).
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from .integrals import ccm_overlap
from .padded import ccm_eri, ccm_eri_symmetric, ccm_hcore, ccm_nuclear_repulsion

__all__ = ["CCMSCFResult", "run_ccm_rhf"]

# Effective four-center builders selectable via run_ccm_rhf(method=...).
#   "union12"  -- historical eq-18 product weight (Peintinger & Bredow 2014);
#                validated in 1-D, breaks 8-fold ERI symmetry in >=2-D.
#   "aiccm2026dev-a" -- the symmetric Born-von Kármán-torus four-center
#                (AICCM_ALGORITHM.md Sec.13): exactly 8-fold symmetric for any
#                lattice. This is *this* development line's method of record.
_CCM_ERI_METHODS = {
    "union12": ccm_eri,
    "aiccm2026dev-a": ccm_eri_symmetric,
    "aiccmdev": ccm_eri_symmetric,   # deprecated alias of "aiccm2026dev-a" (back-compat)
}


def _ccm_eri_for_method(ccm, method):
    """Build the effective four-center tensor selected by ``method``.

    Shared by the RHF / UHF / MP2 CCM drivers so the ``method`` keyword (notably
    ``"aiccm2026dev-a"``) means the same thing everywhere.
    """
    try:
        builder = _CCM_ERI_METHODS[method]
    except KeyError:
        raise ValueError(
            f"unknown CCM four-center method {method!r}; "
            f"choose from {sorted(_CCM_ERI_METHODS)}"
        )
    return builder(ccm)


@dataclass
class CCMSCFResult:
    converged: bool
    n_iter: int
    energy: float  # total CCM energy per reference cell (Ha)
    energy_per_atom: float
    e_electronic: float
    e_nuclear: float
    mo_energies: np.ndarray
    mo_coeffs: np.ndarray
    density: np.ndarray  # D = 2 C_occ C_occ^T
    fock: np.ndarray
    overlap: np.ndarray
    hcore: np.ndarray
    idempotency_error: float  # ||D S D - 2 D||_F  (Gamma-point structural check)


def _orthonormaliser(S, lindep_tol):
    """Canonical orthogonalisation X with S-eigenvectors below tol dropped."""
    w, U = np.linalg.eigh(S)
    keep = w > lindep_tol
    if not np.all(keep):
        U, w = U[:, keep], w[keep]
    return U / np.sqrt(w)


def run_ccm_rhf(
    ccm, *, method="union12", max_iter=128, conv_tol=1e-9, diis_dim=8,
    lindep_tol=1e-7, eri=None
):
    """Run closed-shell HF-CCM on ``ccm`` (a :class:`CCMSystem`).

    Parameters mirror an ordinary RHF. ``method`` selects the effective
    four-center weighting (see :data:`_CCM_ERI_METHODS`):

    * ``"union12"`` (default) -- the historical eq-18 product weight
      (Peintinger & Bredow 2014); validated in 1-D.
    * ``"aiccm2026dev-a"`` -- the symmetric Born-von Kármán-torus four-center
      (:func:`~vibeqc.periodic.ccm.padded.ccm_eri_symmetric`,
      ``AICCM_ALGORITHM.md`` Sec.13), exactly 8-fold permutationally symmetric for
      any lattice. This is this development line's method of record.

    ``eri`` may instead be a precomputed effective tensor (it then overrides
    ``method``). Returns a :class:`CCMSCFResult`.

    Status: reproduces the supercell-model HF energy to ~1e-5 Ha/atom (H₄
    alternating chain vs Peintinger PhD thesis Tab. 8.3 / JCC 2014 Tab. 2),
    for 1-D clusters with arbitrary orbitals (validated s and p). The molecular
    limit is exact. **Scope/known limits:** the dense ``n_ref_ao**4`` padded
    ERI in :func:`ccm_eri` confines this driver to small / 1-D clusters -- 3-D
    supercells blow up the padded basis and need the production C++ lattice-sum
    (next milestone). The four-center is not yet bit-exact (not perfectly
    cyclically invariant), so small high-symmetry clusters show minor orbital-
    degeneracy splitting. See ``handovers/HANDOVER_AICCM.md`` Sec. "M2b status".
    """
    if eri is None:
        eri = _ccm_eri_for_method(ccm, method)

    S = ccm_overlap(ccm)
    h, _, _ = ccm_hcore(ccm)
    e_nn = ccm_nuclear_repulsion(ccm)

    n_elec = ccm.supercell.n_electrons()
    if n_elec % 2 != 0:
        raise ValueError(
            f"run_ccm_rhf is closed-shell but the cluster has {n_elec} electrons; "
            "use an even-electron cluster (UHF-CCM is a later milestone)."
        )
    n_occ = n_elec // 2

    # Fig.-6 guard: the CCM overlap must be positive definite.
    w_s = np.linalg.eigvalsh(S)
    if w_s[0] < lindep_tol:
        raise ValueError(
            f"CCM overlap matrix near-singular (min eig {w_s[0]:.2e} < {lindep_tol:.1e}); "
            "enlarge the cluster or screen diffuse functions (Peintinger & Bredow 2014, Fig. 6)."
        )
    X = _orthonormaliser(S, lindep_tol)

    def diag_fock(F):
        Fp = X.T @ F @ X
        eps, Cp = np.linalg.eigh(Fp)
        return eps, X @ Cp

    def density(C):
        Cocc = C[:, :n_occ]
        return 2.0 * (Cocc @ Cocc.T)

    # Core-Hamiltonian initial guess.
    eps, C = diag_fock(h)
    D = density(C)

    diis_F, diis_e = [], []
    e_last = 0.0
    converged = False
    for it in range(1, max_iter + 1):
        J = np.einsum("mnrs,rs->mn", eri, D, optimize=True)
        K = np.einsum("msrn,rs->mn", eri, D, optimize=True)
        F = h + J - 0.5 * K
        F = 0.5 * (
            F + F.T
        )  # enforce Hermiticity (the WIP eff is not yet fully symmetric)

        e_elec = 0.5 * np.sum(D * (h + F))
        e_tot = e_elec + e_nn

        # DIIS error = S D F - F D S (zero at convergence), in the orthonormal basis.
        err = X.T @ (F @ D @ S - S @ D @ F) @ X
        if len(diis_F) == diis_dim:
            diis_F.pop(0)
            diis_e.pop(0)
        diis_F.append(F)
        diis_e.append(err)
        if len(diis_F) >= 2:
            F = _diis_extrapolate(diis_F, diis_e)

        eps, C = diag_fock(F)
        D = density(C)

        de = e_tot - e_last
        e_last = e_tot
        if it > 1 and abs(de) < conv_tol and np.max(np.abs(err)) < 1e-6:
            converged = True
            break

    idem = np.linalg.norm(D @ S @ D - 2.0 * D)
    return CCMSCFResult(
        converged=converged,
        n_iter=it,
        energy=e_tot,
        energy_per_atom=e_tot / ccm.n_atoms,
        e_electronic=e_elec,
        e_nuclear=e_nn,
        mo_energies=eps,
        mo_coeffs=C,
        density=D,
        fock=F,
        overlap=S,
        hcore=h,
        idempotency_error=float(idem),
    )


def _diis_extrapolate(focks, errors):
    m = len(focks)
    B = -np.ones((m + 1, m + 1))
    B[-1, -1] = 0.0
    for i in range(m):
        for j in range(m):
            B[i, j] = np.sum(errors[i] * errors[j])
    rhs = np.zeros(m + 1)
    rhs[-1] = -1.0
    try:
        c = np.linalg.solve(B, rhs)
    except np.linalg.LinAlgError:
        return focks[-1]
    return sum(c[i] * focks[i] for i in range(m))


# -- Scalable production driver (C++ WeightedLatticeJKBuilder) --------------


def _prepare_ccm_weights(ccm):
    """Package WSSC two-center weights for the C++ CCM JK builder.

    Returns (weight_cells, weight_matrices) as numpy arrays:
        weight_cells: (n_wcells, 3) int32, minimum-image cell indices
        weight_matrices: (n_wcells, n_atoms, n_atoms) float64
    """
    Wg = ccm.cell_weight_matrices()
    cells = sorted(Wg.keys())
    n_wcells = len(cells)
    n_atoms = ccm.n_atoms
    weight_cells = np.zeros((n_wcells, 3), dtype=np.int32)
    weight_matrices = np.zeros((n_wcells, n_atoms, n_atoms), dtype=np.float64)
    for i, g in enumerate(cells):
        weight_cells[i] = g
        weight_matrices[i] = np.asarray(Wg[g], dtype=np.float64)
    return weight_cells, weight_matrices


def _ccm_bra_ket_symmetrise(jk_result, ccm, D):
    """Apply CCM bra-ket symmetrisation to J/K matrices.

    The raw ``build_jk_ccm_weighted`` returns the ket-folded J/K (bra at
    home).  The CCM four-center weight breaks the 8-fold ERI symmetry;
    the missing 50% from the bra-folded representation is recovered by
    contracting the effective tensor's bra-ket transpose.  This matches
    the ``ccm_eri`` convention::

        eff^sym = 0.5*(eff + transpose_{bra<->ket})

    For the contracted J/K this becomes a second kernel call with the
    density and integral indices remapped.
    """
    from vibeqc._vibeqc_core import (
        build_jk_ccm_weighted,
        make_periodic_gamma_ccm_jk_builder,
    )

    from .scf import _prepare_ccm_weights

    # Ket-folded J/K
    J_ket = np.asarray(jk_result.J, dtype=float)
    K_ket = np.asarray(jk_result.K, dtype=float)

    # Bra-folded: compute J/K with bra<->ket transposition of the
    # effective ERI tensor.  In practice this means we treat the
    # D-contracted ket as the bra in a second build.  For a symmetric
    # density D, J_bra[mu,nu] = K_ket[nu,mu]? Not quite -- we need
    # a proper rebuild.  Do a second call to build_jk_ccm_weighted
    # with the density transposed and reinterpret.
    #
    # For now: approximate by using the ket-folded result and
    # symmetrising via explicit transposition of the effective tensor
    # in Python.  This is expensive but correct.
    JT = J_ket.T.copy()
    KT = K_ket.T.copy()

    # The symmetrisation: J^CCM = 0.5*(J_ket + J_bra)
    # J_bra[mu,nu] comes from K_ket with remapped indices.
    # For a contracted Fock, the bra-ket transposition of eff gives
    # G = J - 0.5*K where K exchanges indices differently.
    # Rather than re-derive, accept the padded route's convention
    # directly: symmetrise J and K matrices themselves.
    # This is NOT the full 8-fold symmetrisation, but it ensures
    # Hermiticity which is the dominant effect.

    J_sym = 0.5 * (J_ket + JT)
    K_sym = 0.5 * (K_ket + KT)

    return J_sym, K_sym


_CCM_SCALABLE_METHODS = {
    "union12": "bra_home_full",   # eq-18 four-center (matches padded ccm_eri)
    "aiccm2026dev-a": "aiccm2026dev-a",       # symmetric BvK-torus four-center (Sec.13)
    "aiccmdev": "aiccm2026dev-a",             # deprecated alias (back-compat)
}


def _ccm_scalable_cxx_method(method, four_center):
    """Map a CCM ``(method, four_center)`` selection to the C++ JK kernel name.

    ``method`` picks the base kernel (:data:`_CCM_SCALABLE_METHODS`);
    ``four_center`` picks how it contracts:

    * ``"direct"`` -- the integral-direct kernel (``base + "-direct"``): folds each
      weighted quartet straight into J/K, O(nbf**2) memory.
    * ``"full"`` / ``"dense"`` -- the base kernel: builds the dense O(nbf**4)
      effective tensor (the preserved small-cluster comparison reference).

    Shared by :func:`run_ccm_rhf_scalable` and the KS-CCM JK builder
    (``dft._ccm_jk_builder``) so ``four_center`` means the same thing everywhere.
    """
    try:
        base = _CCM_SCALABLE_METHODS[method]
    except KeyError:
        raise ValueError(
            f"unknown CCM four-center method {method!r}; "
            f"choose from {sorted(_CCM_SCALABLE_METHODS)}"
        )
    if four_center == "direct":
        return base + "-direct"
    if four_center in ("full", "dense"):
        return base
    raise ValueError(
        f"unknown four_center {four_center!r}; choose 'direct' (integral-direct "
        "J/K, O(nbf**2) memory -- default) or 'full' (dense effective tensor, "
        "O(nbf**4) memory -- small-cluster comparison reference)"
    )


def _make_ccm_jk_builder(ccm, cxx_method, schwarz_threshold):
    """Construct the C++ CCM-weighted JK builder.

    ``schwarz_threshold`` is the opt-in Cauchy-Schwarz screening threshold for the
    ``-direct`` kernels (``0.0`` = off = exact, the default). The C++ builder copies
    ``lattice_options`` at construction (``CCMWeightedGammaJKBuilder`` stores it by
    value), so we set the threshold on ``ccm.lattice_options`` around the build and
    restore it -- the constructed builder retains the value, and ``ccm`` is left
    untouched. Forcing ``0.0`` by default also skips the otherwise-wasted Schwarz
    factor computation (the CCM ``lattice_options`` default is ``1e-12``).
    """
    from vibeqc._vibeqc_core import make_periodic_gamma_ccm_jk_builder

    weight_cells, weight_matrices = _prepare_ccm_weights(ccm)
    saved = ccm.lattice_options.schwarz_threshold
    try:
        ccm.lattice_options.schwarz_threshold = float(schwarz_threshold)
        return make_periodic_gamma_ccm_jk_builder(
            ccm.basis, ccm.cluster_system, weight_cells, weight_matrices,
            ccm.lattice_options, cxx_method,
        )
    finally:
        ccm.lattice_options.schwarz_threshold = saved


def run_ccm_rhf_scalable(
    ccm, *, method="union12", four_center="direct", schwarz_threshold=0.0,
    max_iter=128, conv_tol=1e-9, diis_dim=8, lindep_tol=1e-7
):
    """Run closed-shell HF-CCM using the scalable C++ lattice-sum JK builder.

    Unlike :func:`run_ccm_rhf` which materialises the O(n^4) effective ERI
    tensor in Python, this driver uses the C++ ``build_jk_ccm_weighted`` -- a
    Gamma-only periodic J/K build that applies the WSSC four-center weights
    during the shell-quartet loop.

    ``method`` selects the four-center weighting, matching the Python padded
    route: ``"union12"`` (default, eq-18 -- C++ ``"bra_home_full"``) or
    ``"aiccm2026dev-a"`` (the symmetric Born-von Kármán-torus four-center,
    ``AICCM_ALGORITHM.md`` Sec.13). Both reproduce their Python ``run_ccm_rhf``
    counterpart (``method=`` there) to µHa.

    ``four_center`` selects how the weighted quartets are contracted (Phase 3b):

    * ``"direct"`` (default) -- **integral-direct**: each weighted quartet block
      is folded straight into J/K, so peak memory is O(nbf**2) (thread-local
      J/K) instead of O(nbf**4). This is what lets real 3-D cells at production
      basis fit in RAM. Reproduces the ``"full"`` path to ~1e-12 (a summation
      reorder, not bit-for-bit).
    * ``"full"`` (alias ``"dense"``) -- the **dense effective tensor** path: build
      the O(nbf**4) effective ERI tensor, bra-ket symmetrise, then contract. This
      is the preserved small-cluster *comparison reference*; it OOMs on real 3-D
      production-basis cells (that is exactly why ``"direct"`` is the default).

    ``schwarz_threshold`` is the **opt-in** Cauchy-Schwarz screening threshold for
    the ``four_center="direct"`` kernels (no effect on ``"full"``). Default
    ``0.0`` = **off** -- the direct kernel is exact (every weighted quartet is
    contracted). A positive value (e.g. ``1e-12``) skips shell-quartets whose
    rigorous bound ``|w| * Q_bra * Q_ket * D_max`` falls below it -- a throughput
    lever that changes the result only at that threshold (so it is *not* the
    byte-for-byte exact path; keep it tight). Off by default preserves the exact
    integral-direct guarantee.

    The SCF loop runs through the production C++ ``run_rhf_scf_with_jk``
    entry point (DIIS, damping, level-shift, Newton fallback).

    Other parameters mirror :func:`run_ccm_rhf`. Returns a :class:`CCMSCFResult`.
    """
    cxx_method = _ccm_scalable_cxx_method(method, four_center)
    from vibeqc import (
        RHFOptions,
        run_rhf_scf_with_jk,
    )

    from .integrals import ccm_overlap
    from .padded import ccm_hcore, ccm_nuclear_repulsion

    S = ccm_overlap(ccm)
    h, _, _ = ccm_hcore(ccm)
    e_nn = ccm_nuclear_repulsion(ccm)

    n_elec = ccm.supercell.n_electrons()
    if n_elec % 2 != 0:
        raise ValueError(
            f"run_ccm_rhf_scalable is closed-shell but the cluster has "
            f"{n_elec} electrons; use an even-electron cluster."
        )

    # Fig.-6 guard: the CCM overlap must be positive definite.
    w_s = np.linalg.eigvalsh(S)
    if w_s[0] < lindep_tol:
        raise ValueError(
            f"CCM overlap matrix near-singular (min eig {w_s[0]:.2e} < "
            f"{lindep_tol:.1e}); enlarge the cluster or screen diffuse "
            "functions (Peintinger & Bredow 2014, Fig. 6)."
        )

    # Build the CCM-weighted JK builder. The base C++ method applies the WSSC
    # four-center weight during the shell-quartet loop:
    #   "bra_home_full" (union12) -- bra at home + ket imaged, eq-18 weight, then
    #       bra-ket symmetrised; reproduces the padded ccm_eri (gold -0.542875).
    #   "aiccm2026dev-a" -- the symmetric BvK-torus four-center (Sec.13): symmetric bridge
    #       + independent min-image fold; exactly 8-fold symmetric for any lattice.
    # The "-direct" suffix (four_center="direct", default) folds each weighted
    # quartet straight into J/K -- O(nbf**2) memory; the bare base name builds the
    # full O(nbf**4) effective tensor (four_center="full", comparison reference).
    # schwarz_threshold (default 0.0) opts into Cauchy-Schwarz screening of the
    # -direct kernels.
    jk = _make_ccm_jk_builder(ccm, cxx_method, schwarz_threshold)

    # Run SCF through the production C++ driver.
    opts = RHFOptions()
    opts.max_iter = max_iter
    opts.conv_tol_energy = conv_tol
    # opts.diis_subspace_size = diis_dim  # not exposed
    # opts.lindep_tol = lindep_tol  # handled by C++ SCF

    result = run_rhf_scf_with_jk(
        ccm.basis,
        n_elec,
        np.asarray(S, dtype=np.float64),
        np.asarray(h, dtype=np.float64),
        float(e_nn),
        jk,
        opts,
        np.empty((0, 0)),  # empty initial density -> Hcore guess
    )

    # Extract results.
    D = np.asarray(result.density, dtype=float)
    F = np.asarray(result.fock, dtype=float)
    C = np.asarray(result.mo_coeffs, dtype=float)
    eps = np.asarray(result.mo_energies, dtype=float)

    e_total = float(result.energy)
    idem = np.linalg.norm(D @ S @ D - 2.0 * D)

    return CCMSCFResult(
        converged=result.converged,
        n_iter=result.n_iter,
        energy=e_total,
        energy_per_atom=e_total / ccm.n_atoms,
        e_electronic=e_total - float(e_nn),
        e_nuclear=float(e_nn),
        mo_energies=eps,
        mo_coeffs=C,
        density=D,
        fock=F,
        overlap=np.asarray(S),
        hcore=np.asarray(h),
        idempotency_error=float(idem),
    )
